import tvm
from tvm import relay
from calyx.py_ast import CompVar, Stdlib, CompInst, Cell, Invoke, CompPort
from calyx.utils import bits_needed
from typing import List
from dataclasses import dataclass
# Lookup table from a memory's number of dimensions to the `Stdlib` method
# that builds the matching primitive: a register for 0 dimensions, and the
# sequential memory `seq_mem_dN` for N = 1..4 dimensions. `get_memory`
# indexes this table with the rank of a tensor's shape.
NumDimsToCell = {
    0: Stdlib().register,
    1: Stdlib().seq_mem_d1,
    2: Stdlib().seq_mem_d2,
    3: Stdlib().seq_mem_d3,
    4: Stdlib().seq_mem_d4,
}
@dataclass
class DahliaFuncDef:
    """Plain record bundling the pieces that describe one Dahlia function.

    NOTE(review): the field descriptions below are inferred from the field
    names and types only; nothing in this file constructs or reads this
    class, so confirm them against the callers.
    """

    # Presumably the identifier of the source function/operator.
    function_id: str
    # Presumably the name of the component generated for the function.
    component_name: str
    # Variable naming the destination cell.
    dest: CompVar
    # Variables naming the argument cells.
    args: List[CompVar]
    # TVM attributes object; presumably those of the originating call.
    attributes: tvm.ir.Attrs
    # Data type as a string; presumably a Dahlia type such as the ones
    # produced by `get_dahlia_data_type` -- TODO confirm.
    data_type: str
    # Component instance associated with this function.
    component: CompInst
def get_dims(c: CompInst):
    """Return the number of dimensions of the primitive instance `c`.

    The count is looked up from `c.id`: a `std_reg` has 0 dimensions and a
    `seq_mem_dN` has N dimensions (N = 1..4).

    Raises:
        AssertionError: if `c.id` is not one of the supported primitives.
    """
    dims_by_primitive = {
        "std_reg": 0,
        "seq_mem_d1": 1,
        "seq_mem_d2": 2,
        "seq_mem_d3": 3,
        "seq_mem_d4": 4,
    }
    primitive = c.id
    num_dims = dims_by_primitive.get(primitive)
    assert num_dims is not None, f"{primitive} not supported."
    return num_dims
def get_dimension_sizes(c: CompInst) -> List[int]:
    """Return the per-dimension size arguments of the primitive instance `c`.

    Reads `c.args[1]` through `c.args[dims]`, where `dims` is the result of
    `get_dims(c)`; `c.args[0]` is skipped. A register (0 dimensions) yields
    an empty list.
    """
    sizes = []
    for position in range(get_dims(c)):
        # Offset by one to skip the leading argument.
        sizes.append(c.args[position + 1])
    return sizes
def get_addr_ports(c: CompInst):
    """Return `(port_name, arg)` pairs for the address ports of `c`.

    For a primitive with `dims` dimensions (see `get_dims`), the result is
    `[("addr0", c.args[dims + 1]), ..., (f"addr{dims-1}", c.args[2 * dims])]`.
    A register (0 dimensions) yields an empty list.

    NOTE(review): with the argument layout built in `get_memory`
    (`[width, *sizes, *index_widths]`), the second element of each pair is
    the index width of that dimension -- confirm that `c.args` always has
    this layout.
    """
    dims = get_dims(c)
    # The original upper bound was written `dims << 1 + 1`, which Python
    # parses as `dims << 2`; zip() truncation hid the over-long range.
    # Index the second group of `dims` arguments directly instead.
    return [(f"addr{i}", c.args[dims + 1 + i]) for i in range(dims)]
def emit_invoke_control(
    decl: CompVar, dest: Cell, args: List[Cell], old_args=None, old_dest=None
) -> Invoke:
    """Build the `Invoke` of `decl` that wires up `args` and `dest`.

    Each cell is bound to a parameter of the invoked component:

    * cells whose component id contains "reg" or "const" are passed by
      value, as an input `(param, <cell>.out)`;
    * every other cell is passed as a reference cell `(param, <cell>)`.

    When `old_args` is empty or omitted, every parameter is named after the
    cell it receives. Otherwise an existing component is being reused:
    parameters are named after the cells in `old_args` / `old_dest`, while
    the values come from `args` / `dest`.

    Args:
        decl: Variable naming the component to invoke.
        dest: Cell receiving the result; bound after all of `args`.
        args: Cells passed as arguments, in order.
        old_args: Cells supplying the parameter names when reusing a
            component. `None` or an empty sequence means "no reuse".
        old_dest: Cell supplying the destination parameter name; required
            whenever `old_args` is non-empty.

    Returns:
        `Invoke(decl, inputs, [], ref_cells)`.

    Raises:
        AssertionError: if `old_args` and `args` differ in length, if
            `old_dest` is missing while `old_args` is given, or if a reused
            parameter cell has a different component than its argument.
    """
    ref_cells = []
    inputs = []

    def bind(arg_cell, param_cell=None):
        """Bind `arg_cell` to the parameter named after `param_cell`."""
        if param_cell is None:
            # Fresh invocation: the parameter shares the argument's name.
            param_cell = arg_cell
        else:
            assert (
                arg_cell.comp == param_cell.comp
            ), "arg cell and param cell must be same component"
        param = f"{param_cell.id.name}"
        arg = CompVar(arg_cell.id.name)
        if any(p in arg_cell.comp.id for p in ("reg", "const")):
            inputs.append((param, CompPort(arg, "out")))
        else:
            ref_cells.append((param, arg))

    if not old_args:
        for cell in args:
            bind(cell)
        bind(dest)
    else:
        assert len(old_args) == len(
            args
        ), "we are reusing a dahlia function but the args are different lengths"
        assert old_dest is not None, "if using old_args must provide an old_dest too"
        for arg_cell, param_cell in zip(args, old_args):
            bind(arg_cell, param_cell)
        bind(dest, old_dest)

    return Invoke(decl, inputs, [], ref_cells)
def get_dahlia_data_type(relay_type) -> str:
    """Return the Dahlia type string for `relay_type`.

    The bit width comes from `get_bitwidth(relay_type)`. A dtype containing
    "int" maps to `bit<width>`; otherwise a dtype containing "float" maps to
    `fix<width, width // 2>`.

    Raises:
        AssertionError: if the dtype contains neither "int" nor "float".
    """
    width = get_bitwidth(relay_type)
    dtype = relay_type.dtype
    if "int" in dtype:
        return f"bit<{width}>"
    if "float" in dtype:
        return f"fix<{width}, {width // 2}>"
    # Raised explicitly rather than via `assert 0`: assert statements are
    # removed under `python -O`, which would let this fall through and
    # return None. The exception type is kept for existing callers.
    raise AssertionError(f"{relay_type} is not supported.")
def get_bitwidth(relay_type) -> int:
    """Return the bit width encoded in `relay_type.dtype`.

    The width is the integer formed by concatenating every decimal digit
    found in the dtype string, e.g. "int32" -> 32 and "float16" -> 16.

    Raises:
        AssertionError: if the dtype contains neither "int" nor "float".
    """
    dtype = relay_type.dtype
    is_supported = any(kind in dtype for kind in ("int", "float"))
    assert is_supported, f"{relay_type} not supported."
    digits = [ch for ch in dtype if ch.isdigit()]
    return int("".join(digits))
def get_memory(name: str, type: tvm.ir.Type) -> Cell:
    """Create the external cell named `name` that stores a value of `type`.

    The primitive is chosen from `NumDimsToCell` by the rank of
    `type.concrete_shape`. It is constructed with the arguments
    `[bit width, *dimension sizes, *bits_needed(size) per dimension]`.

    Raises:
        AssertionError: if the rank has no entry in `NumDimsToCell`, or if
            `get_bitwidth` rejects the dtype.
    """
    shape = type.concrete_shape
    width = get_bitwidth(type)
    sizes = list(shape)
    index_widths = [bits_needed(extent) for extent in shape]
    rank = len(shape)
    assert rank in NumDimsToCell, f"Memory of size {rank} not supported."
    return Cell(
        CompVar(name),
        NumDimsToCell[rank](width, *sizes, *index_widths),
        is_external=True,
    )
def python2relay(func) -> "relay.Function":
    """Simplify a Relay expression and return the resulting `main` function.

    `func` is wrapped in an `IRModule`, then the passes `SimplifyExpr`,
    `SimplifyInference` and `ToANormalForm` are applied in that order.

    Returns:
        The `"main"` entry of the transformed module. This is a Relay
        function object, not a string (the previous `-> str` annotation was
        inaccurate); call `str()` on it if text is needed.
    """
    passes = tvm.transform.Sequential(
        [
            relay.transform.SimplifyExpr(),
            relay.transform.SimplifyInference(),
            relay.transform.ToANormalForm(),
        ]
    )
    module = tvm.IRModule.from_expr(func)
    module = passes(module)
    return module["main"]