megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import collections
import textwrap
import os
import hashlib
import struct
import io

from gen_param_defs import member_defs, ParamDef, IndentWriterBase

# FIXME: move supportToString flag definition into the param def source file
#
# (param name, enum name) pairs. An enum listed here gets an extra trailing
# ", 1" template argument on its generated MgbEnumAttr definition, which asks
# for a ToStringTrait to be generated for it.
ENUM_TO_STRING_SPECIAL_RULES = [
    ("Elemwise", "Mode"),
    ("ElemwiseMultiType", "Mode"),
]

class ConverterWriter(IndentWriterBase):
    """Emit TableGen definitions for op params.

    Calling an instance with an output stream and a list of param
    definitions writes the complete output file.  The ``_on_*`` methods are
    per-param / per-member hooks; they are presumably driven by
    ``_process`` of the base class (not visible in this file -- confirm
    against ``gen_param_defs``).  Each member hook appends one TableGen
    field to ``_current_tparams`` and ``_on_param_end`` writes the record.
    """

    # When true, the current param is not emitted; reset by _on_param_end.
    # Nothing in this class sets it to True.
    _skip_current_param = False
    # The param whose members are currently being visited.
    _last_param = None
    # TableGen field strings collected for the current param.
    _current_tparams = None
    # Whether the current param is emitted in its "packed" form; cleared
    # when a DTypeEnum field is encountered.
    _packed = None
    # Names of the const fields seen so far in the current param.
    _const = None

    # C type name -> TableGen attribute class, for the types whose default
    # value is passed through unchanged.  DTypeEnum is handled separately in
    # _ctype2attr because it rewrites the value and has a side effect.
    _PLAIN_CTYPE_ATTRS = {
        'uint32_t': 'MgbUI32Attr',
        'uint64_t': 'MgbUI64Attr',
        'int32_t': 'MgbI32Attr',
        'float': 'MgbF32Attr',
        'double': 'MgbF64Attr',
        'bool': 'MgbBoolAttr',
    }

    def __call__(self, fout, defs):
        """Write the whole generated file for ``defs`` to ``fout``.

        The generated body is preceded by a ``//`` header comment line and
        wrapped in an ``MGB_PARAM`` include guard.
        """
        super().__call__(fout)
        self._write("// %s", self._get_header())
        self._write("#ifndef MGB_PARAM")
        self._write("#define MGB_PARAM")
        self._process(defs)
        self._write("#endif // MGB_PARAM")

    @staticmethod
    def _enum_member_ident(member):
        """Return the bare identifier of an enum member.

        The member is stringified and everything from the first space or
        ``=`` onwards is dropped, e.g. ``"FOO = 1"`` -> ``"FOO"``.
        """
        return str(member).split(' ')[0].split('=')[0]

    def _ctype2attr(self, ctype, value):
        """Map a C type name to ``(attribute class name, default value)``.

        For ``DTypeEnum`` the value is wrapped in a
        ``megdnn::DType::from_enum(...)`` expression and, as a side effect,
        ``self._packed`` is cleared so that the enclosing param is not
        emitted in packed form.

        Raises:
            RuntimeError: if ``ctype`` is not a supported type name.
        """
        if ctype == 'DTypeEnum':
            self._packed = False
            return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value)
        try:
            return self._PLAIN_CTYPE_ATTRS[ctype], value
        except KeyError:
            # Name the offending type so a bad param def is easy to locate.
            raise RuntimeError("unknown ctype: {!r}".format(ctype)) from None

    def _on_param_begin(self, p):
        """Reset the per-param state before the members of ``p`` are visited."""
        self._last_param = p
        self._packed = True
        self._current_tparams = []
        self._const = set()

    def _on_param_end(self, p):
        """Write the TableGen record for ``p`` from the collected fields.

        A packed param is written as a ``<name>ParamBase`` class plus a
        ``<name>Param`` def instantiating it with accessor ``"param"``; a
        non-packed param is written as a single ``<name>Param`` def.
        If the param was marked as skipped, nothing is written and only the
        skip flag is cleared.
        """
        if self._skip_current_param:
            self._skip_current_param = False
            return
        if self._packed:
            self._write("class {0}ParamBase<string accessor> : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1)
        else:
            self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1)
        self._write("let fields = (ins", indent=1)
        # One field per line; continuation lines carry the current indent
        # explicitly because the whole list goes out in a single write.
        field_sep = ",\n{}".format(self._cur_indent)
        self._write(field_sep.join(self._current_tparams))
        self._write(");", indent=-1)
        self._write("}\n", indent=-1)
        if self._packed:
            self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name))
        self._current_tparams = None
        self._packed = None
        self._const = None

    def _wrapped_with_default_value(self, attr, default):
        """Return ``attr`` wrapped in ``MgbDefaultValuedAttr`` with ``default``."""
        return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default)

    def _on_member_enum(self, e):
        """Write the ``MgbEnumAttr`` def for enum ``e`` and record its field.

        The def is written even when the current param is skipped; the
        field is only recorded when it is not.
        """
        p = self._last_param

        # Note: always generate llvm Record def for enum attribute even it was not
        # directly used by any operator, or other enum couldn't alias to this enum
        td_class = "{}{}".format(p.name, e.name)
        fullname = "::megdnn::param::{}".format(p.name)
        member_list = ','.join(
            '\"{}\"'.format(self._enum_member_ident(m)) for m in e.members)
        enum_def = "MgbEnumAttr<\"{}\", \"{}\", [{}], {}".format(
            fullname, e.name, member_list, 1 if e.combined else 0)
        if (p.name, e.name) in ENUM_TO_STRING_SPECIAL_RULES:
            enum_def += ", 1"  # whether generate ToStringTrait
        enum_def += ">"

        self._write("def {} : {};".format(td_class, enum_def))
        if self._skip_current_param:
            return

        # wrapped with default value
        if e.combined:
            default_val = "static_cast<{}::{}>({})".format(
                fullname, e.name, e.compose_combined_enum(e.default))
        else:
            default_val = "{}::{}::{}".format(
                fullname, e.name, self._enum_member_ident(e.members[e.default]))

        wrapped = self._wrapped_with_default_value(td_class, default_val)

        self._current_tparams.append("{}:${}".format(wrapped, e.name_field))

    def _on_member_enum_alias(self, e):
        """Write the ``MgbEnumAliasAttr`` def for alias ``e`` and record its field.

        Unlike ``_on_member_enum``, nothing at all is written when the
        current param is skipped.
        """
        p = self._last_param
        if self._skip_current_param:
            return

        # write enum attr def
        td_class = "{}{}".format(p.name, e.name)
        fullname = "::megdnn::param::{}".format(p.name)
        base_td_class = "{}{}".format(e.src_class, e.src_name)
        enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class)
        self._write("def {} : {};".format(td_class, enum_def))

        # wrapped with default value
        s = e.src_enum
        if s.combined:
            default_val = "static_cast<{}::{}>({})".format(
                fullname, e.name, s.compose_combined_enum(e.get_default()))
        else:
            default_val = "{}::{}::{}".format(
                fullname, e.name,
                self._enum_member_ident(s.members[e.get_default()]))

        wrapped = self._wrapped_with_default_value(td_class, default_val)

        self._current_tparams.append("{}:${}".format(wrapped, e.name_field))

    def _on_member_field(self, f):
        """Record the TableGen field for plain member ``f``.

        If the default value is the name of a const field seen earlier in
        this param, it is qualified with ``::megdnn::param::<param name>::``.
        """
        if self._skip_current_param:
            return
        attr, value = self._ctype2attr(f.dtype.cname, str(f.default))
        if str(value) in self._const:
            value = '::megdnn::param::{}::{}'.format(self._last_param.name, value)
        wrapped = self._wrapped_with_default_value(attr, value)
        self._current_tparams.append("{}:${}".format(wrapped, f.name))

    def _on_const_field(self, f):
        """Remember the name of const field ``f`` for later default lookups."""
        self._const.add(str(f.name))

def main():
    """Command-line entry point: generate the op param TableGen file.

    Takes two positional arguments, ``input`` (the param definition script)
    and ``output`` (the file to generate).  The input script is executed
    with ``pdef`` and ``Doc`` in scope, its SHA-256 hex digest is handed to
    the writer, and the writer output for ``ParamDef.all_param_defs`` is
    written to ``output``.
    """
    # The text is a description; passed positionally it would be taken as
    # ``prog`` and show up as the program name in the usage line.
    parser = argparse.ArgumentParser(
        description='generate op param tablegen file')
    parser.add_argument('input')
    parser.add_argument('output')
    args = parser.parse_args()

    with open(args.input) as fin:
        inputs = fin.read()

    # NOTE(security): the input file is run as Python code with exec(); only
    # use this tool on trusted param definition files.
    # Presumably the script registers its params into
    # ParamDef.all_param_defs -- confirm against gen_param_defs.
    exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
    input_hash = hashlib.sha256(inputs.encode(encoding='UTF-8')).hexdigest()

    writer = ConverterWriter()
    with open(args.output, 'w') as fout:
        writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)

# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()