megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file src/core/include/megbrain/ir/base.td
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#ifndef MGB_BASE
#define MGB_BASE

include "mlir/IR/OpBase.td"

// The mgb MLIR dialect record: dialect name "mgb", C++ namespace
// "mgb::dialect". MgbOp below attaches every op to this dialect.
def Mgb_Dialect : Dialect {
  let name = "mgb";
  let cppNamespace = "mgb::dialect";
}

// -- mgb Attr mixin
// Metadata shared by every mgb attribute wrapper.
//   className: stored verbatim in `underlyingType`; the subclasses in this file
//              pass a C++ type name (e.g. "int32_t", "std::string").
class MgbAttrWrapperBase<string className> {
  string underlyingType = className;
  // Nesting level of the attribute. MgbArrayAttr sets it to its element's
  // depth plus one and appends it to the local names in its generated C++.
  int recursionDepth = 0;
}

// Mixin providing C++ expression templates for hashing, comparing and printing
// an attribute value. `$0` (and `$1` in the comparison) are placeholders for
// the value expression(s); presumably they are substituted by the mgb tblgen
// backend -- confirm against the generator.
class MgbHashableAttrMixin {
  string hashFunction = "mgb::hash($0)";
  // return 0 for eq, else for ne
  string cmpFunction = "$0 != $1";
  string reprFunction = "std::to_string($0)";
}

// Mixin that records the description of an enum next to its attribute. Each
// template argument is copied unchanged into the field of the matching role:
//   namespace -> parentNamespace
//   name      -> enumName
//   members   -> enumMembers
//   combined  -> enumCombined
//   toString  -> supportToString
// The meaning of the two bit flags is defined by the consumer of these
// records, which is not part of this file.
class MgbEnumAttrMixin<string namespace, string name, list<string> members, bit combined, bit toString> {
  string parentNamespace = namespace;
  string enumName = name;
  list<string> enumMembers = members;
  bit enumCombined = combined;
  bit supportToString = toString;
}

// Forward declaration; the full MgbAttrWrapper class is defined further down.
class MgbAttrWrapper;
// Mixin marking an attribute as an alias of another one: `aliasBase` holds the
// attribute record being aliased.
class MgbAliasAttrMixin<Attr base> {
  Attr aliasBase = base;
}

// -- mgb custom Attr
// TODO: CPred and description
// Common base of all mgb attributes: an MLIR `Attr` combined with
// MgbAttrWrapperBase. The constraint is the always-true CPred<"true"> and the
// description is the placeholder string "TODO".
//   className: forwarded to MgbAttrWrapperBase (becomes `underlyingType`).
class MgbAttrWrapper<string className>:
  Attr<CPred<"true">, "TODO">, MgbAttrWrapperBase<className> {
  // The accessor return type is the wrapped C++ type itself.
  let returnType = underlyingType;
}

// An MgbAttrWrapper that additionally carries the hash / compare / repr
// expression templates of MgbHashableAttrMixin. It adds no fields of its own.
class HashableAttr<string className>:
    MgbAttrWrapper<className>, MgbHashableAttrMixin;

// -- basic types
// Base of the integer attributes: a hashable wrapper around the C++ integer
// type `CType`, stored as ::mlir::IntegerAttr. The signless / signed / unsigned
// subclasses add the conversion and builder snippets.
class MgbIntegerAttrBase<string CType> : HashableAttr<CType> {
  let storageType = "::mlir::IntegerAttr";
}

// Integer attribute stored in a signless MLIR integer type.
//   CType: C++ integer type of the value (becomes `underlyingType`).
// Reading goes through IntegerAttr::getInt() and casts back to CType.
// Building creates an integer type whose width is sizeof(CType) * 8:
// sizeof yields bytes while getIntegerType expects a width in bits, so the
// factor must be 8 for the storage to hold the full range of CType.
class MgbSignlessIntegerAttrBase<string CType> : MgbIntegerAttrBase<CType> {
  let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getInt())";
  let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 8), $0)";
}

// Integer attribute stored in an explicitly signed MLIR integer type.
//   CType: C++ integer type of the value (becomes `underlyingType`).
// Reading goes through IntegerAttr::getSInt() and casts back to CType.
// Building creates a signed integer type (isSigned = true) whose width is
// sizeof(CType) * 8: sizeof yields bytes while getIntegerType expects a width
// in bits, so the factor must be 8 for the storage to hold the full range.
class MgbSignedIntegerAttrBase<string CType> : MgbIntegerAttrBase<CType> {
  let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getSInt())";
  let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 8, true), $0)";
}

// Integer attribute stored in an explicitly unsigned MLIR integer type.
//   CType: C++ integer type of the value (becomes `underlyingType`).
// Reading goes through IntegerAttr::getUInt() and casts back to CType.
// Building creates an unsigned integer type (isSigned = false) whose width is
// sizeof(CType) * 8: sizeof yields bytes while getIntegerType expects a width
// in bits, so the factor must be 8 for the storage to hold the full range.
class MgbUnsignedIntegerAttrBase<string CType> : MgbIntegerAttrBase<CType> {
  let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getUInt())";
  let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 8, false), $0)";
}

// Concrete integer attributes. The signed C++ types (int8_t, int32_t, int64_t)
// use the signless storage variant; the unsigned C++ types use the unsigned
// variant.
// NOTE(review): "MgbSizeTAddr" looks like a typo for "...Attr", but the name is
// referenced by MgbTensorShapeAttr in this file and possibly elsewhere, so
// renaming it needs a coordinated change.
def MgbI8Attr: MgbSignlessIntegerAttrBase<"int8_t">;
def MgbI32Attr: MgbSignlessIntegerAttrBase<"int32_t">;
def MgbI64Attr: MgbSignlessIntegerAttrBase<"int64_t">;
def MgbUI32Attr: MgbUnsignedIntegerAttrBase<"uint32_t">;
def MgbUI64Attr: MgbUnsignedIntegerAttrBase<"uint64_t">;
def MgbSizeTAddr: MgbUnsignedIntegerAttrBase<"size_t">;

// Floating-point attribute stored as ::mlir::FloatAttr.
//   CType: C++ type of the value (becomes `underlyingType`).
//   DType: name fragment spliced into the builder call as
//          `$_builder.get<DType>Type()`, e.g. "F32" -> getF32Type().
// Reading always goes through getValueAsDouble() and casts to CType.
class MgbFloatAttrBase<string CType, string DType> : HashableAttr<CType> {
  let storageType = "::mlir::FloatAttr";
  let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getValueAsDouble())";
  let constBuilderCall = "$_builder.getFloatAttr($_builder.get" # DType # "Type(), $0)";
}

// Concrete float attributes: C++ `float` built with the F32 type and C++
// `double` built with the F64 type.
def MgbF32Attr : MgbFloatAttrBase<"float", "F32">;
def MgbF64Attr : MgbFloatAttrBase<"double", "F64">;

// C++ `bool` stored as ::mlir::BoolAttr and built with getBoolAttr.
// No convertFromStorage is set here, so whatever the base Attr class provides
// is used for reading the value back.
def MgbBoolAttr : HashableAttr<"bool"> {
  let storageType = "::mlir::BoolAttr";
  let constBuilderCall = "$_builder.getBoolAttr($0)";
}

// C++ `std::string` stored as ::mlir::StringAttr. Reading copies the stored
// value into a std::string via `.str()`.
def MgbStringAttr : HashableAttr<"std::string"> {
  let storageType = "::mlir::StringAttr";
  let convertFromStorage = "$_self.getValue().str()";
  let constBuilderCall = "$_builder.getStringAttr($0)"; // llvm::StringRef implicit ctor
  // The value is used as its own repr, replacing the std::to_string default of
  // MgbHashableAttrMixin. Written as a field declaration rather than `let`.
  string reprFunction = "$0";
}

// Attribute wrapping a homogeneous sequence: the C++ type is
// std::vector<elem.underlyingType>, stored as ::mlir::ArrayAttr.
//   elem: attribute wrapper describing a single element.
// Both conversion snippets are immediately-invoked C++ lambdas that loop over
// the elements and apply the element attribute's own snippet to each of them.
class MgbArrayAttr<MgbAttrWrapper elem>:
    HashableAttr<"std::vector<" # elem.underlyingType # ">"> {
  let storageType = "::mlir::ArrayAttr";
  // One level deeper than the element. The depth is appended to the generated
  // local names (ret<N>, i<N>); presumably this keeps the names of nested
  // array lambdas distinct.
  let recursionDepth = !add(elem.recursionDepth, 1);
  // ArrayAttr -> std::vector: `$_self` inside the element's convertFromStorage
  // snippet is replaced by the loop variable cast to the element's storage
  // type, and the result is pushed onto the vector.
  let convertFromStorage =
    "[&] {\n"
    "    " # underlyingType # " ret" # recursionDepth # ";\n"
    "    std::for_each($_self.begin(), $_self.end(), [&](auto&& i" # recursionDepth # ") {\n"
    "        ret" # recursionDepth # ".push_back(\n"
    "            " # !subst("$_self", "i" # recursionDepth # ".template cast<" # elem.storageType # ">()", "" # elem.convertFromStorage) # "\n"
    "        );\n"
    "    });\n"
    "    return ret" # recursionDepth # ";}()";
  // std::vector -> ArrayAttr: `$0` inside the element's constBuilderCall
  // snippet is replaced by the loop variable; the collected attributes are
  // passed to getArrayAttr.
  let constBuilderCall =
    "[&] {\n"
    "    std::vector<mlir::Attribute> ret" # recursionDepth # ";\n"
    "    std::for_each($0.begin(), $0.end(), [&](auto&& i" # recursionDepth # ") {\n"
    "        ret" # recursionDepth # ".push_back(\n"
    "            " # !subst("$0", "i" # recursionDepth, "" # elem.constBuilderCall) # "\n"
    "        );\n"
    "    });\n"
    "    return $_builder.getArrayAttr(ret" # recursionDepth # ");"
    "}()";
   // Fixed placeholder text (a C++ string literal); the elements themselves
   // are not part of the repr.
   let reprFunction = "\"{std::vector}\"";
}

// Zero-length list<string>; used as the seed of the !foldl accumulations in
// the ApplyTuple* helpers.
defvar EmptyStrList = !listsplat("", 0);

// Helper "function": `r` is the list `l` extended by the single element `s`.
class StrListAppend<list<string> l, string s> {
  list<string> r = !listconcat(l, [s]);
}

// Helper "function" producing the read snippet for one tuple element.
//   attr: attribute wrapper of the element.
//   idx:  position of the element inside the tuple.
// `r` is attr.convertFromStorage with every `$_self` replaced by
// `$_self[idx].template cast<attr.storageType>()`, i.e. the idx-th entry of
// the enclosing array cast to the element's storage type.
class TupleConvertFromStorage<MgbAttrWrapper attr, int idx> {
  string r = !subst(
    "$_self",
    "$_self[" # !cast<string>(idx) # "].template cast<"# attr.storageType #">()",
    "" # attr.convertFromStorage);
}

// Helper "function" producing the build snippet for one tuple element.
//   attr: attribute wrapper of the element.
//   idx:  position of the element inside the tuple.
// `r` is attr.constBuilderCall with every `$0` replaced by
// `std::get<idx>($0)`, i.e. the idx-th member of the std::tuple value.
class TupleConstBuilderCall<MgbAttrWrapper attr, int idx> {
  string r = !subst(
    "$0",
    "std::get<" # !cast<string>(idx) # ">($0)",
    "" # attr.constBuilderCall);
}

// Maps TupleConvertFromStorage over `args`. `r` holds one read snippet per
// element, in order. The fold starts from the empty list and appends one
// snippet per step, so the current accumulator length `!size(l)` is the index
// of the element being processed.
class ApplyTupleConvertFromStorage<list<MgbAttrWrapper> args> {
  list<string> r = !foldl(
    EmptyStrList, args, l, arg, StrListAppend<l, TupleConvertFromStorage<arg, !size(l)>.r>.r);
}

// Maps TupleConstBuilderCall over `args`. `r` holds one build snippet per
// element, in order. The fold starts from the empty list and appends one
// snippet per step, so the current accumulator length `!size(l)` is the index
// of the element being processed.
class ApplyTupleConstBuilderCall<list<MgbAttrWrapper> args> {
  list<string> r = !foldl(
    EmptyStrList, args, l, arg, StrListAppend<l, TupleConstBuilderCall<arg, !size(l)>.r>.r);
}

// Attribute wrapping a heterogeneous tuple: the C++ type is std::tuple<...>
// over the underlying types of `args`, stored as ::mlir::ArrayAttr with one
// entry per tuple member.
//   args: attribute wrappers of the tuple members, in order.
// StrJoin is not defined in this file (presumably it comes from the included
// mlir/IR/OpBase.td); it joins the per-element strings into one list.
class MgbTupleAttr<list<MgbAttrWrapper> args>:
    HashableAttr<"std::tuple<" # StrJoin<!foreach(i, args, i.underlyingType)>.result # ">"> {
  let storageType = "::mlir::ArrayAttr";
  // ArrayAttr -> std::tuple: make_tuple over the per-element read snippets.
  let convertFromStorage = "std::make_tuple(" # StrJoin<ApplyTupleConvertFromStorage<args>.r>.result # ")";
  // std::tuple -> ArrayAttr: braced list of the per-element build snippets.
  let constBuilderCall = "$_builder.getArrayAttr({" # StrJoin<ApplyTupleConstBuilderCall<args>.r>.result # "})";
}

// -- enum types
// Attribute for the C++ enum `namespace::enumName`, stored as
// ::mlir::IntegerAttr.
//   namespace, enumName: combined into the underlying C++ type name.
//   members:  enum member names, recorded through MgbEnumAttrMixin.
//   combined: recorded as `enumCombined`.
//   toString: recorded as `supportToString`; defaults to 0.
class MgbEnumAttr<string namespace, string enumName, list<string> members, bit combined, bit toString=0>:
    HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members, combined, toString> {
  let storageType = "::mlir::IntegerAttr";
  // Read the stored integer and cast it back to the enum type.
  let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
  // Always built as a 32-bit integer attribute from the enum value cast to
  // int32_t.
  let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
  let hashFunction = "mgb::enumhash()($0)";
  // Printed as the integer value of the enumerator.
  string reprFunction = "std::to_string((int)$0)";
}

// Enum attribute declared as an alias of another enum attribute: it reuses the
// member list of `base` and records `base` as `aliasBase`.
// NOTE(review): `combined` is passed as 0 and `toString` keeps its default of
// 0 regardless of what `base` uses -- confirm that not inheriting these two
// flags from `base` is intended.
class MgbEnumAliasAttr<string namespace, string enumName, MgbEnumAttr base>:
    MgbEnumAttr<namespace, enumName, base.enumMembers, 0>, MgbAliasAttrMixin<base>;

// -- other types
// ::megdnn::DType stored as an ::mlir::IntegerAttr holding its DTypeEnum value.
def MgbDTypeAttr: HashableAttr<"::megdnn::DType"> {
  let storageType = "::mlir::IntegerAttr";
  // Integer -> DTypeEnum -> DType via DType::from_enum.
  let convertFromStorage = underlyingType # "::from_enum(static_cast<::megdnn::DTypeEnum>($_self.getInt()))";
  // DType -> 32-bit integer attribute holding enumv().
  let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0.enumv()))";
  // Hash the value returned by handle() rather than the DType object itself.
  let hashFunction = "mgb::hash($0.handle())";
  let reprFunction = "$0.name()";
}

// ::mgb::CompNode stored as an ::mlir::StringAttr.
def MgbCompNodeAttr: HashableAttr<"::mgb::CompNode"> {
  let storageType = "::mlir::StringAttr";
  // String -> CompNode via CompNode::load.
  let convertFromStorage = underlyingType # "::load($_self.getValue().str())";
  // CompNode -> string attribute holding to_string_logical().
  let constBuilderCall = "$_builder.getStringAttr($0.to_string_logical())";
  // Note: the repr uses to_string(), not the to_string_logical() form that is
  // stored in the attribute.
  string reprFunction = "$0.to_string()";
}

// ::megdnn::TensorShape stored as an ::mlir::ArrayAttr with one entry per
// dimension; each entry uses the MgbSizeTAddr (size_t) encoding.
def MgbTensorShapeAttr: HashableAttr<"::megdnn::TensorShape"> {
  let storageType = "::mlir::ArrayAttr";
  // Hash computed by mgb::PODHash<size_t>::perform over ($0.shape, $0.ndim).
  let hashFunction = "mgb::PODHash<size_t>::perform($0.shape, $0.ndim)";
  // Negated eq_shape, matching the "0 means equal" convention of cmpFunction.
  let cmpFunction = "!$0.eq_shape($1)";
  // Element attribute whose snippets are reused for every dimension.
  defvar elemInst = MgbSizeTAddr;
  // ArrayAttr -> TensorShape: every array entry is converted with the element
  // snippet and written to ret[ret.ndim], post-incrementing ndim. No check of
  // the entry count against a maximum number of dimensions is visible here.
  let convertFromStorage =
    "[&] {\n"
    "    " # underlyingType # " ret;\n"
    "    std::for_each($_self.begin(), $_self.end(), [&ret](auto&& i) {\n"
    "        ret[ret.ndim ++] = " # !subst("$_self", "i.template cast<"# elemInst.storageType #">()", "" # elemInst.convertFromStorage) # ";\n"
    "    });\n"
    "    return ret;}()";
  // TensorShape -> ArrayAttr: the first `ndim` entries, `$0[i]`, are each built
  // with the element's builder snippet and collected into an array attribute.
  let constBuilderCall =
    "[&] {\n"
    "    std::vector<mlir::Attribute> ret;\n"
    "    for (size_t i = 0; i < $0.ndim; ++ i) {\n"
    "        ret.push_back(\n"
    "            " # !subst("$0", "$0[i]", "" # elemInst.constBuilderCall) # "\n"
    "        );\n"
    "    }\n"
    "    return $_builder.getArrayAttr(ret);"
    "}()";
    let reprFunction = "$0.to_string()";
}

// DefaultValuedAttr for mgb attributes.
//   attr:  the wrapped mgb attribute.
//   value: default value, forwarded to DefaultValuedAttr.
class MgbDefaultValuedAttr<MgbAttrWrapper attr, string value>:
    DefaultValuedAttr<attr, value>, MgbAttrWrapperBase<attr.underlyingType> {
  // Note: this class is similar to DefaultValuedAttr but carries the extra
  // meta information used by the mgb dialect tblgen, so it has to be kept
  // up to date with class MgbAttrWrapperBase: every field declared there
  // must be copied over from `attr` here.
  let recursionDepth = attr.recursionDepth;
}

// -- dnn params
// Base record describing one dnn parameter struct.
//   className: unqualified parameter type name, stored as `paramType`.
class MgbParamBase<string className> {
  string paramType = className;
  // Fully qualified C++ name inside ::megdnn::param.
  string fullName = "::megdnn::param::" # paramType;
  // Unset here; each concrete parameter record is expected to provide the dag,
  // which MgbOp splices into the op arguments.
  dag fields = ?;
}

// A dnn parameter record that additionally stores an accessor string.
//   className: forwarded to MgbParamBase.
//   accessor:  stored verbatim as `paramAccessor`; how it is used is decided
//              by the consumer of these records, which is not part of this
//              file.
class MgbPackedParamBase<string className, string accessor>:
  MgbParamBase<className> {
  string paramAccessor = accessor;
}

// -- mgb ops
// Mixin declaring the hash / compare hooks of an op. Both fields are left
// unset (`?`) here, so a concrete op may provide its own values.
class MgbHashableOpMixin {
  string hashFunction = ?;
  string cmpFunction = ?;
}

// Base class of all mgb ops, registered in Mgb_Dialect.
//   mnemonic: op mnemonic, forwarded to Op.
//   params:   dnn parameter records; their `fields` dags become op arguments
//             and the list itself is kept in `dnnParams`.
//   traits:   op traits, forwarded to Op.
class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>:
    Op<Mgb_Dialect, mnemonic, traits> {
  // Both dags are empty by default; a concrete op overrides them as needed.
  dag inputs = (ins);
  dag extraArguments = (ins);
  // TODO: remove it
  code extraOpdefDecl = ?;
  code nameFunction = ?;

  // arguments = inputs ++ fields of every param (in list order) ++
  // extraArguments, concatenated with !con.
  let arguments = !con(
    !foldl(inputs, params, args, param, !con(args, param.fields)),
    extraArguments);

  list<MgbParamBase> dnnParams = params;
}

// An MgbOp that also carries the hashFunction / cmpFunction fields of
// MgbHashableOpMixin. The template arguments are forwarded to MgbOp unchanged.
class MgbHashableOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>:
  MgbOp<mnemonic, params, traits>, MgbHashableOpMixin;

#endif // MGB_BASE