import argparse
import collections
import textwrap
import os
import hashlib
import struct
import io
from gen_param_defs import member_defs, ParamDef, IndentWriterBase
# C/C++ type names accepted in param definitions, mapped to the type names
# written into the generated FlatBuffers schema.  Built once at import time
# rather than on every call.
_CNAME_TO_FBNAME = {
    "uint32_t": "uint",
    "uint64_t": "ulong",
    "int32_t": "int",
    "float": "float",
    "double": "double",
    "DTypeEnum": "DTypeEnum",
    "bool": "bool",
}


def _cname_to_fbname(cname):
    """Return the FlatBuffers type name for the C type name ``cname``.

    Args:
        cname: C/C++ type name as a string, e.g. ``"uint32_t"``.

    Returns:
        The corresponding FlatBuffers type name, e.g. ``"uint"``.

    Raises:
        KeyError: if ``cname`` has no entry in the mapping table.
    """
    try:
        return _CNAME_TO_FBNAME[cname]
    except KeyError:
        # Same exception type as a plain dict lookup, with a clearer message.
        raise KeyError(
            "no FlatBuffers type mapping for C type %r" % (cname,)) from None
def scramble_enum_member_name(name):
    """Rewrite an enum member string for use in the FlatBuffers schema.

    Two rewrites are applied:

    * ``"NAME = 1 << n"`` becomes ``"NAME =  n"``: everything between the
      ``=`` and the end of ``<<`` is dropped, leaving the shift amount.
      (Presumably because ``bit_flags`` enums take the bit position as the
      value -- confirm against the FlatBuffers schema documentation.)
    * A member whose identifier is exactly ``MIN`` or ``MAX`` gets a
      trailing underscore (``MIN_`` / ``MAX_``).  (Presumably to avoid a
      clash with names that flatc generates -- confirm.)

    Args:
        name: member text, either a bare identifier (``"MAX"``) or an
            identifier with a value (``"MAX = 3"``, ``"A = 1 << 2"``).

    Returns:
        The rewritten member text; unchanged if neither rule applies.
    """
    shift_pos = name.find('<<')
    if shift_pos != -1:
        # Keep "NAME =" and the text after "<<"; the original spacing
        # (including the resulting double space) is preserved on purpose so
        # generated output stays byte-identical.
        name = name[:name.find('=') + 1] + ' ' + name[shift_pos + 2:]

    # The identifier is the prefix of ``name`` up to the first space or '='.
    ident = name.split(' ')[0].split('=')[0]
    if ident in ("MIN", "MAX"):
        # Rename only the leading identifier.  A blanket str.replace would
        # also rewrite any later occurrence of MIN/MAX in the value part.
        return ident + "_" + name[len(ident):]
    return name
class FlatBuffersWriter(IndentWriterBase):
    """Render param definitions as a FlatBuffers schema.

    Tables are first rendered into an in-memory buffer while the enums they
    use are collected; the final output is then written as header, enums,
    tables.

    The ``_on_*`` methods are callbacks.  They are presumably invoked by
    ``IndentWriterBase._process`` while it walks the definitions; the base
    class is not part of this file, so confirm the exact protocol there.
    """

    # When true, member callbacks emit nothing and ``_on_param_end`` resets
    # the flag instead of closing a table.  Nothing in this class sets it to
    # True.
    _skip_current_param = False
    # Param most recently passed to ``_on_param_begin``.
    _last_param = None
    # (param name, enum name) -> enum definition, for every enum seen.
    _enums = None
    # Set of (param name, enum name) keys whose enums must be emitted.
    _used_enum = None
    # Constant name -> default value (both str) for the current param.
    # Class-level fallback only: ``__call__`` and ``_on_param_begin`` bind a
    # fresh per-instance dict so this shared object is never mutated.
    _cur_const_val = {}

    def __call__(self, fout, defs):
        """Write the schema generated from ``defs`` to ``fout``.

        Args:
            fout: writable text stream receiving the final schema.
            defs: param definitions, passed through to ``self._process``.
        """
        # Render the tables into a scratch buffer first: the enum section
        # must precede them in the output, but which enums are used is only
        # known after all params have been processed.
        param_io = io.StringIO()
        super().__call__(param_io)
        self._used_enum = set()
        self._enums = {}
        self._cur_const_val = {}
        self._process(defs)

        # Switch to the real output and emit header, enums, then tables.
        super().__call__(fout)
        self._write("// %s", self._get_header())
        self._write('include "dtype.fbs";')
        self._write("namespace mgb.serialization.fbs.param;\n")
        self._write_enums()
        self._write(param_io.getvalue())

    def _write_enums(self):
        """Emit one ``enum`` declaration per used enum, in sorted key order."""
        for param_name, enum_name in sorted(self._used_enum):
            enum = self._enums[(param_name, enum_name)]
            self._write_doc(enum.name)
            attribute = "(bit_flags)" if enum.combined else ""
            self._write("enum %s%s : uint %s {",
                        param_name, enum.name, attribute, indent=1)
            for member in enum.members:
                self._write_doc(member)
                self._write("%s,", scramble_enum_member_name(str(member)))
            self._write("}\n", indent=-1)

    def _write_doc(self, doc):
        """Emit ``doc`` as ``///`` comment lines.

        Does nothing unless ``doc`` is a ``member_defs.Doc`` with non-empty
        ``doc`` text.  Text is re-wrapped to fit 80 columns at the current
        indent unless ``doc.no_reformat`` is set, in which case
        ``doc.raw_lines`` is written as-is.
        """
        if not isinstance(doc, member_defs.Doc) or not doc.doc:
            return
        if doc.no_reformat:
            doc_lines = doc.raw_lines
        else:
            text = doc.doc.replace('\n', ' ')
            # 4 == len("/// "), the prefix added to every line below.
            text_width = 80 - len(self._cur_indent) - 4
            doc_lines = textwrap.wrap(text, text_width)
        for line in doc_lines:
            self._write("/// " + line)

    def _on_param_begin(self, p):
        """Open the ``table`` for param ``p`` and reset per-param state."""
        self._last_param = p
        self._cur_const_val = {}
        self._write_doc(p.name)
        self._write("table %s {", p.name, indent=1)

    def _on_param_end(self, p):
        """Close the current ``table`` (or just clear the skip flag)."""
        if self._skip_current_param:
            self._skip_current_param = False
            return
        self._write("}\n", indent=-1)

    @staticmethod
    def _enum_default_literal(enum, default):
        """Return the schema text for ``default`` of ``enum``.

        Combined enums delegate to ``enum.compose_combined_enum``; otherwise
        ``default`` indexes ``enum.members`` and the member's identifier
        (text before the first space or ``=``) is used.
        """
        if enum.combined:
            return enum.compose_combined_enum(default)
        member_text = str(enum.members[default])
        return scramble_enum_member_name(
            member_text.split(' ')[0].split('=')[0])

    def _on_member_enum(self, e):
        """Register enum ``e`` and emit the table field that uses it."""
        p = self._last_param
        key = str(p.name), str(e.name)
        # Registered even when the param is skipped, so that aliases in
        # other params can still refer to this enum.
        self._enums[key] = e
        if self._skip_current_param:
            return
        self._write_doc(e.name)
        self._used_enum.add(key)
        default = self._enum_default_literal(e, e.default)
        self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default)

    def _resolve_const(self, v):
        """Follow ``v`` through the current param's constants.

        Returns the first value that is not itself a constant name.

        Raises:
            ValueError: if the constants refer to each other in a cycle
                (previously this looped forever).
        """
        seen = set()
        while v in self._cur_const_val:
            if v in seen:
                raise ValueError(
                    "circular constant definition involving %r" % (v,))
            seen.add(v)
            v = self._cur_const_val[v]
        return v

    def _on_member_field(self, f):
        """Emit the table field for plain member ``f``."""
        if self._skip_current_param:
            return
        self._write_doc(f.name)
        self._write("%s:%s = %s;", f.name, _cname_to_fbname(f.dtype.cname),
                    self._get_fb_default(self._resolve_const(f.default)))

    def _on_const_field(self, f):
        """Record constant ``f`` so later defaults can refer to it by name."""
        self._cur_const_val[str(f.name)] = str(f.default)

    def _on_member_enum_alias(self, e):
        """Emit a table field whose type is an enum owned by another param."""
        if self._skip_current_param:
            return
        self._used_enum.add((e.src_class, e.src_name))
        enum_name = e.src_class + e.src_name
        default = self._enum_default_literal(e.src_enum, e.get_default())
        self._write("%s:%s = %s;", e.name_field, enum_name, default)

    def _get_fb_default(self, cppdefault):
        """Convert a C++ default-value literal to FlatBuffers schema text.

        Non-string values are returned unchanged.  For strings:

        * a trailing ``f`` is dropped when the rest is a float literal
          (``"1.f"`` -> ``"1."``); strings such as ``"0xff"`` or ``"inf"``
          are left intact;
        * a trailing ``ull`` is dropped;
        * a leading ``DTypeEnum::`` is dropped.
        """
        if not isinstance(cppdefault, str):
            return cppdefault

        d = cppdefault
        if d.endswith('f'):
            # Only treat the 'f' as a float suffix if what precedes it is
            # really a float literal; otherwise fall through untouched.
            try:
                float(d[:-1])
            except ValueError:
                pass
            else:
                return d[:-1]
        if d.endswith('ull'):
            return d[:-3]
        if d.startswith("DTypeEnum::"):
            return d[len("DTypeEnum::"):]
        return d
def main():
    """Command-line entry point.

    Reads the param description file given as ``input``, executes it to
    populate ``ParamDef.all_param_defs``, and writes the generated
    FlatBuffers schema to ``output``.  The SHA-256 of the input text is
    handed to the writer via ``set_input_hash``.
    """
    # The text must be passed as ``description=``: the first positional
    # parameter of ArgumentParser is ``prog``, which would put this sentence
    # into the usage line as the program name.
    parser = argparse.ArgumentParser(
        description='generate FlatBuffers schema of operator param from '
                    'description file')
    parser.add_argument('input', help='param description file (Python)')
    parser.add_argument('output', help='FlatBuffers schema file to write')
    args = parser.parse_args()

    # Read as UTF-8 so the text matches the UTF-8 bytes hashed below,
    # independent of the platform's locale encoding.
    with open(args.input, encoding='utf-8') as fin:
        inputs = fin.read()

    # NOTE(security): the description file is executed as Python code.
    # Only run this tool on trusted input.
    exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})

    input_hash = hashlib.sha256(inputs.encode('utf-8')).hexdigest()

    writer = FlatBuffersWriter()
    with open(args.output, 'w', encoding='utf-8') as fout:
        writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)
if __name__ == "__main__":
    # Run the generator only when executed as a script, not on import.
    main()