from .typevar import TypeVar
from .ast import Def, Var
from copy import copy
from itertools import product
try:
from typing import Dict, TYPE_CHECKING, Union, Tuple, Optional, Set from typing import Iterable, List, Any, TypeVar as MTypeVar from typing import cast
from .xform import Rtl, XForm from .ast import Expr from .typevar import TypeSet if TYPE_CHECKING:
T = MTypeVar('T')
TypeMap = Dict[TypeVar, TypeVar]
VarTyping = Dict[Var, TypeVar]
except ImportError:
TYPE_CHECKING = False
pass
class TypeConstraint(object):
    """
    Common base class for type constraints.

    Subclasses provide `_args()`, `is_trivial()` and `eval()`; equality,
    hashing, printing and translation are all expressed through `_args()`.
    """

    def __init__(self, tv, tc):
        """Placeholder constructor: the base class cannot be instantiated."""
        assert False, "Abstract"

    def translate(self, m):
        """
        Return a new constraint of the same class whose TypeVar arguments
        have been rewritten through `m`.

        When `m` is a TypeEnv each TypeVar is looked up with `m[tv]`;
        otherwise `m` is treated as a mapping and applied with `subst`.
        Arguments that are not TypeVars are passed through untouched.
        """
        new_args = []
        for arg in self._args():
            if isinstance(arg, TypeVar):
                if isinstance(m, TypeEnv):
                    arg = m[arg]
                else:
                    arg = subst(arg, m)
            new_args.append(arg)
        return self.__class__(*new_args)

    def __eq__(self, other):
        """Two constraints are equal when `other` is an instance of this
        constraint's class and both have equal `_args()`."""
        if isinstance(other, self.__class__):
            assert isinstance(other, TypeConstraint)
            return self._args() == other._args()
        return False

    def is_concrete(self):
        """Return True iff every typevar in `tvs()` has a singleton type."""
        non_singletons = [tv for tv in self.tvs()
                          if tv.singleton_type() is None]
        return not non_singletons

    def __hash__(self):
        """Hash on `_args()`, consistent with `__eq__`."""
        return hash(self._args())

    def _args(self):
        """
        Abstract. Subclasses return the tuple of arguments that `translate`
        feeds back into the class constructor.
        """
        assert False, "Abstract"

    def tvs(self):
        """Return the list of TypeVar instances found in `_args()`."""
        return [arg for arg in self._args() if isinstance(arg, TypeVar)]

    def is_trivial(self):
        """
        Abstract. Subclasses report whether the constraint needs no further
        tracking; `TypeEnv.extract` drops constraints for which this is true.
        """
        assert False, "Abstract"

    def eval(self):
        """
        Abstract. Subclasses return the truth value of the constraint; the
        implementations in this file assert `is_concrete()` first.
        """
        assert False, "Abstract"

    def __repr__(self):
        """Render as `ClassName(arg1, arg2, ...)` using `str` of each arg."""
        rendered_args = ', '.join(str(arg) for arg in self._args())
        return self.__class__.__name__ + '(' + rendered_args + ')'
class TypesEqual(TypeConstraint):
    """
    Constraint stating that two type variables have the same singleton type.
    """

    def __init__(self, tv1, tv2):
        """Store both typevars, ordered by their `repr`, so that the order in
        which they were passed does not affect `_args()`."""
        ordered = sorted([tv1, tv2], key=repr)
        self.tv1 = ordered[0]
        self.tv2 = ordered[1]

    def _args(self):
        """Return `(tv1, tv2)` in their stored (repr-sorted) order."""
        return (self.tv1, self.tv2)

    def is_trivial(self):
        """True when both sides are the same typevar or the constraint is
        concrete."""
        if self.tv1 == self.tv2:
            return True
        return self.is_concrete()

    def eval(self):
        """Compare the singleton types of both sides; the constraint must be
        concrete (asserted)."""
        assert self.is_concrete()
        left = self.tv1.singleton_type()
        right = self.tv2.singleton_type()
        return left == right
class InTypeset(TypeConstraint):
    """
    Constraint stating that a type variable's types lie within a typeset.
    """

    def __init__(self, tv, ts):
        """`tv` must be a non-derived typevar whose name starts with
        "typeof_"; `ts` is the typeset it is checked against."""
        assert not tv.is_derived
        assert tv.name.startswith("typeof_")
        self.tv = tv
        self.ts = ts

    def _args(self):
        """Return `(tv, ts)`."""
        return (self.tv, self.ts)

    def is_trivial(self):
        """
        True when tv's typeset is a subset of `ts`, when the two typesets do
        not overlap at all, or when the constraint is concrete.
        """
        # Work on a copy: the in-place intersection below must not touch tv.
        candidate = self.tv.get_typeset().copy()
        if candidate.issubset(self.ts):
            return True

        candidate &= self.ts
        return candidate.size() == 0 or self.is_concrete()

    def eval(self):
        """Check that tv's typeset is a subset of `ts`; the constraint must
        be concrete (asserted)."""
        assert self.is_concrete()
        own_typeset = self.tv.get_typeset()
        return own_typeset.issubset(self.ts)
class WiderOrEq(TypeConstraint):
    """
    Constraint evaluated as `typ1.wider_or_equal(typ2)` on the singleton
    types of its two type variables.
    """

    def __init__(self, tv1, tv2):
        """Record the two typevars; `tv1` is the left-hand side of the check."""
        self.tv1 = tv1
        self.tv2 = tv2

    def _args(self):
        """Return `(tv1, tv2)`."""
        return (self.tv1, self.tv2)

    def is_trivial(self):
        """
        True when both sides are the same typevar, when one of the typeset
        comparisons below holds for ints, floats and bools alike, when the
        lane sets do not intersect, or when the constraint is concrete.
        """
        if self.tv1 == self.tv2:
            return True

        ts1 = self.tv1.get_typeset()
        ts2 = self.tv2.get_typeset()

        def set_wider_or_equal(s1, s2):
            # Both sets non-empty and every member of s1 >= every member of s2
            return len(s1) > 0 and len(s2) > 0 and min(s1) >= max(s2)

        def set_narrower(s1, s2):
            # NOTE(review): this holds as soon as the smallest member of s1
            # is below the largest member of s2 - confirm that is the
            # intended test.
            return len(s1) > 0 and len(s2) > 0 and min(s1) < max(s2)

        def holds_for_every_kind(pred):
            # Same left-to-right, short-circuiting order: ints, floats, bools
            return (pred(ts1.ints, ts2.ints) and
                    pred(ts1.floats, ts2.floats) and
                    pred(ts1.bools, ts2.bools))

        if holds_for_every_kind(set_wider_or_equal):
            return True

        if holds_for_every_kind(set_narrower):
            return True

        common_lanes = ts1.lanes.intersection(ts2.lanes)
        if len(common_lanes) == 0:
            return True

        return self.is_concrete()

    def eval(self):
        """Return `typ1.wider_or_equal(typ2)` for the two singleton types;
        the constraint must be concrete (asserted)."""
        assert self.is_concrete()
        left = self.tv1.singleton_type()
        right = self.tv2.singleton_type()
        return left.wider_or_equal(right)
class SameWidth(TypeConstraint):
    """
    Constraint stating that the singleton types of two type variables have
    equal `width()`.
    """

    def __init__(self, tv1, tv2):
        """Record the two typevars being compared."""
        self.tv1 = tv1
        self.tv2 = tv2

    def _args(self):
        """Return `(tv1, tv2)`."""
        return (self.tv1, self.tv2)

    def is_trivial(self):
        """
        True when both sides are the same typevar, when their typesets share
        no width at all, or when the constraint is concrete.
        """
        if self.tv1 == self.tv2:
            return True

        widths1 = self.tv1.get_typeset().widths()
        widths2 = self.tv2.get_typeset().widths()
        if len(widths1.intersection(widths2)) == 0:
            return True

        return self.is_concrete()

    def eval(self):
        """Compare `width()` of the two singleton types; the constraint must
        be concrete (asserted)."""
        assert self.is_concrete()
        left = self.tv1.singleton_type()
        right = self.tv2.singleton_type()
        return left.width() == right.width()
class TypeEnv(object):
    """
    Type environment built up during type inference.

    Instance state:

    * `type_map`    - maps a type variable to the type variable replacing it.
    * `constraints` - list of TypeConstraint objects still being tracked.
    * `ranks`       - rank recorded for the typevar of each registered Var.
    * `vars`        - set of Vars passed to `register()`.
    * `idx`         - counter backing `get_uid()`.
    """

    # Rank values returned by `rank()`; `register()` assigns the four middle
    # ones, the outer two are the defaults for unregistered typevars.
    RANK_SINGLETON = 5
    RANK_INPUT = 4
    RANK_INTERMEDIATE = 3
    RANK_OUTPUT = 2
    RANK_TEMP = 1
    RANK_INTERNAL = 0

    def __init__(self, arg=None):
        """
        Create an environment. `arg`, when given, is a
        `(type_map, constraints)` pair that is adopted as-is (not copied).
        """
        self.ranks = {}
        self.vars = set()
        if arg is not None:
            (self.type_map, self.constraints) = arg
        else:
            self.type_map = {}
            self.constraints = []
        self.idx = 0

    def __getitem__(self, arg):
        """
        Resolve a Var or TypeVar to its representative typevar.

        A Var must have been registered. `type_map` links are followed until
        none applies; a derived result is rebuilt on top of its resolved
        base.
        """
        if not isinstance(arg, Var):
            assert (isinstance(arg, TypeVar))
            current = arg
        else:
            assert arg in self.vars
            current = arg.get_typevar()

        while current in self.type_map:
            current = self.type_map[current]

        if not current.is_derived:
            return current
        return TypeVar.derived(self[current.base], current.derived_func)

    def equivalent(self, tv1, tv2):
        """
        Record that `tv1` is replaced by `tv2`.

        `tv1` must be non-derived and currently its own representative. When
        `tv2` is derived, its resolved base must differ from `tv1`.
        """
        assert not tv1.is_derived
        assert self[tv1] == tv1
        if tv2.is_derived:
            # Guards against tv1 being mapped onto something built from tv1
            assert self[tv2.base] != tv1
        self.type_map[tv1] = tv2

    def add_constraint(self, constr):
        """
        Add `constr` unless it is already present.

        InTypeset constraints are not stored: they are applied immediately
        by constraining the typeset of the resolved typevar.
        """
        if constr in self.constraints:
            return

        if not isinstance(constr, InTypeset):
            self.constraints.append(constr)
            return

        self[constr.tv].constrain_types_by_ts(constr.ts)

    def get_uid(self):
        """Return the current counter value as a string and advance it."""
        uid = str(self.idx)
        self.idx += 1
        return uid

    def __repr__(self):
        """Same text as `dot()`."""
        return self.dot()

    def rank(self, tv):
        """
        Return the rank of `tv`.

        A derived typevar is ranked by its free typevar. Typevars without a
        recorded rank get RANK_SINGLETON if `tv` has a singleton type and
        RANK_INTERNAL otherwise.
        """
        if tv.singleton_type() is None:
            fallback = TypeEnv.RANK_INTERNAL
        else:
            fallback = TypeEnv.RANK_SINGLETON

        key = tv.free_typevar() if tv.is_derived else tv
        return self.ranks.get(key, fallback)

    def register(self, v):
        """
        Add Var `v` to the environment and record the rank of its typevar
        according to whether it is an input, intermediate, output or temp.
        """
        self.vars.add(v)
        if v.is_input():
            var_rank = TypeEnv.RANK_INPUT
        elif v.is_intermediate():
            var_rank = TypeEnv.RANK_INTERMEDIATE
        elif v.is_output():
            var_rank = TypeEnv.RANK_OUTPUT
        else:
            assert(v.is_temp())
            var_rank = TypeEnv.RANK_TEMP
        self.ranks[v.get_typevar()] = var_rank

    def free_typevars(self):
        """
        Return the free typevars of all resolved `type_map` keys and
        registered vars, without None, sorted by name.
        """
        from_map = {self[tv].free_typevar() for tv in self.type_map.keys()}
        from_vars = {self[v].free_typevar() for v in self.vars}
        candidates = [tv for tv in from_map.union(from_vars)
                      if tv is not None]
        candidates.sort(key=lambda tv: tv.name)
        return candidates

    def normalize(self):
        """
        Remove `type_map` entries along chains that start at a free typevar
        which does not belong to a registered var and has exactly one child.
        """
        source_tvs = set([v.get_typevar() for v in self.vars])

        # children[t] holds the typevars that depend directly on t: derived
        # map targets hang off their free typevar, and every map key hangs
        # off its target.
        children = {}
        for target in self.type_map.values():
            if target.is_derived:
                children.setdefault(target.free_typevar(), set()).add(target)

        for (key, target) in self.type_map.items():
            children.setdefault(target, set()).add(key)

        for root in self.free_typevars():
            node = root
            while (node not in source_tvs and node in children and
                   len(children[node]) == 1):
                only_child = next(iter(children[node]))
                if only_child in self.type_map:
                    assert self.type_map[only_child] == node
                    del self.type_map[only_child]
                node = only_child

    def extract(self):
        """
        Return a new TypeEnv restricted to the typevars of registered vars.

        The new `type_map` sends each such typevar to its representative
        (identity entries are omitted). Constraints are translated through
        this environment, and trivial or duplicate ones are dropped. `ranks`
        and `vars` are shallow copies.
        """
        vars_tvs = set([v.get_typevar() for v in self.vars])

        new_type_map = {}
        for tv in vars_tvs:
            representative = self[tv]
            if tv != representative:
                new_type_map[tv] = representative

        new_constraints = []
        for original in self.constraints:
            translated = original.translate(self)
            if translated.is_trivial() or translated in new_constraints:
                continue

            # Sanity check: only typevars of registered vars may remain
            for arg in translated._args():
                if isinstance(arg, TypeVar):
                    arg_free_tv = arg.free_typevar()
                    assert arg_free_tv is None or arg_free_tv in vars_tvs

            new_constraints.append(translated)

        # Sanity check: the same must hold for the new type map
        for (key, target) in new_type_map.items():
            assert key in vars_tvs
            assert (target.free_typevar() is None or
                    target.free_typevar() in vars_tvs)

        extracted = TypeEnv()
        extracted.type_map = new_type_map
        extracted.constraints = new_constraints
        extracted.ranks = copy(self.ranks)
        extracted.vars = copy(self.vars)
        return extracted

    def concrete_typings(self):
        """
        Generate every concrete typing allowed by this environment.

        Each yielded value maps every registered Var to a singleton typevar.
        All combinations of concrete types for the free typevars are tried;
        a combination is skipped as soon as one constraint evaluates false.
        """
        free_tvs = self.free_typevars()
        choices = [tv.get_typeset().concrete_types() for tv in free_tvs]

        for combination in product(*choices):
            m = {tv: TypeVar.singleton(typ)
                 for (tv, typ) in zip(free_tvs, combination)}

            concrete_var_map = {v: subst(self[v.get_typevar()], m)
                                for v in self.vars}

            for constr in self.constraints:
                if not constr.translate(m).eval():
                    break
            else:
                # No constraint rejected this combination
                yield concrete_var_map

    def permits(self, concrete_typing):
        """
        Return True iff `concrete_typing` (a Var -> singleton typevar map)
        is compatible with this environment: every type lies within the
        typeset of its var's representative, and no translated constraint
        evaluates to false.
        """
        for (v, typ) in concrete_typing.items():
            assert typ.singleton_type() is not None
            allowed = self[v].get_typeset()
            if not typ.get_typeset().issubset(allowed):
                return False

        m = {self[v]: typ for (v, typ) in concrete_typing.items()}

        for constr in self.constraints:
            # A KeyError raised while translating or evaluating a constraint
            # causes that constraint to be ignored.
            # NOTE(review): with a plain dict, `translate` goes through
            # `subst`, which does not raise KeyError for missing keys -
            # confirm which call is expected to raise here.
            try:
                constr = constr.translate(m)
                if not constr.eval():
                    return False
            except KeyError:
                pass

        return True

    def dot(self):
        """
        Render the environment as a Graphviz `digraph` string.

        Solid edges link derived typevars to their base, dotted edges follow
        `type_map`, and dashed edges show constraints. Nodes that are neither
        mapped nor derived are annotated with their typeset.
        """
        def quoted(obj):
            return "\"" + str(obj) + "\""

        nodes = set()
        edges = set()

        def add_nodes(*tvs):
            for tv in tvs:
                nodes.add(tv)
                # Walk the derivation chain down to the underlying typevar
                while (tv.is_derived):
                    base = tv.base
                    nodes.add(base)
                    edges.add((tv, base, "solid", "forward",
                               tv.derived_func))
                    tv = base

        for var in self.vars:
            add_nodes(var.get_typevar())

        for (src, dst) in self.type_map.items():
            add_nodes(src, dst)
            edges.add((src, dst, "dotted", "forward", None))

        for constr in self.constraints:
            if isinstance(constr, TypesEqual):
                direction, text = "none", "equal"
            elif isinstance(constr, WiderOrEq):
                direction, text = "forward", ">="
            elif isinstance(constr, SameWidth):
                direction, text = "none", "same_width"
            else:
                assert False, "Can't display constraint {}".format(constr)
                continue
            add_nodes(constr.tv1, constr.tv2)
            edges.add((constr.tv1, constr.tv2, "dashed", direction, text))

        root_nodes = set([n for n in nodes
                          if n not in self.type_map and not n.is_derived])

        parts = ["digraph {\n"]
        for n in nodes:
            parts.append(quoted(n))
            if n in root_nodes:
                parts.append("[xlabel=\"{}\"]".format(self[n].get_typeset()))
            parts.append(";\n")

        for (n1, n2, style, direction, elabel) in edges:
            parts.append(quoted(n1) + "->" + quoted(n2))
            parts.append("[style={},dir={}".format(style, direction))
            if elabel is not None:
                parts.append(",label=\"{}\"".format(elabel))
            parts.append("];\n")

        parts.append("}")
        return "".join(parts)
# Aliases that exist only for static type checking. The inference functions
# below return either a TypeEnv or an error message carried as a plain str.
if TYPE_CHECKING:
    TypingError = str
    TypingOrError = Union[TypeEnv, TypingError]
def get_error(typing_or_err):
    """
    Return `typing_or_err` if it is an error (a `str`), otherwise None.

    The `cast` is only applied under static type checking and does not
    change the returned value.
    """
    if not isinstance(typing_or_err, str):
        return None

    if (TYPE_CHECKING):
        return cast(TypingError, typing_or_err)
    return typing_or_err
def get_type_env(typing_or_err):
    """
    Return `typing_or_err`, asserting that it is a TypeEnv and not an error.

    The assertion message includes the unexpected value.
    """
    assert isinstance(typing_or_err, TypeEnv), \
        "Unexpected error: {}".format(typing_or_err)

    # The cast only matters to the static type checker.
    return cast(TypeEnv, typing_or_err) if (TYPE_CHECKING) else typing_or_err
def subst(tv, tv_map):
    """
    Apply the substitution `tv_map` to `tv`.

    A typevar found in the map is replaced by its image. A derived typevar
    not in the map is rebuilt over its substituted base. Anything else is
    returned unchanged.
    """
    if tv in tv_map:
        return tv_map[tv]

    if not tv.is_derived:
        return tv

    new_base = subst(tv.base, tv_map)
    return TypeVar.derived(new_base, tv.derived_func)
def normalize_tv(tv):
    """
    Rewrite a (possibly derived) typevar into a normal form.

    Two rewrites are applied recursively to adjacent derivation functions:

    * a vector derivation applied directly on top of a width derivation is
      swapped so that the width derivation ends up outermost;
    * a half/double pair of the same kind (vector or width) cancels out.

    Non-derived typevars are returned unchanged.
    """
    vector_derives = [TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR]
    width_derives = [TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH]

    if not tv.is_derived:
        return tv

    outer = tv.derived_func
    base = tv.base

    if base.is_derived:
        inner = base.derived_func

        # Reorder: vector-over-width becomes width-over-vector
        if outer in vector_derives and inner in width_derives:
            swapped = TypeVar.derived(TypeVar.derived(base.base, outer),
                                      inner)
            return normalize_tv(swapped)

        # Cancel: a derivation directly undone by the one below it
        inverse_pairs = [
            (TypeVar.HALFVECTOR, TypeVar.DOUBLEVECTOR),
            (TypeVar.DOUBLEVECTOR, TypeVar.HALFVECTOR),
            (TypeVar.HALFWIDTH, TypeVar.DOUBLEWIDTH),
            (TypeVar.DOUBLEWIDTH, TypeVar.HALFWIDTH),
        ]
        if (outer, inner) in inverse_pairs:
            return normalize_tv(base.base)

    return TypeVar.derived(normalize_tv(base), outer)
def constrain_fixpoint(tv1, tv2):
    """
    Repeatedly call `tv2.constrain_types(tv1)` until a call leaves tv1's
    typeset unchanged, then call `tv1.constrain_types(tv2)` once.

    Nothing is returned; the typevars may be modified through their
    `constrain_types` methods (presumably narrowing their typesets - confirm
    against TypeVar).
    """
    # NOTE(review): the three statements after the loop are assumed to sit
    # outside it (the loop only exits through `break`) - confirm.
    while True:
        # Snapshot tv1's typeset so a change made by the call is detectable
        old_tv1_ts = tv1.get_typeset().copy()
        tv2.constrain_types(tv1)
        if tv1.get_typeset() == old_tv1_ts:
            break

    old_tv2_ts = tv2.get_typeset().copy()
    tv1.constrain_types(tv2)
    # The final call is expected to leave tv2's typeset untouched
    assert old_tv2_ts == tv2.get_typeset()
def unify(tv1, tv2, typ):
    """
    Unify `tv1` and `tv2` within the TypeEnv `typ`.

    Returns `typ` (updated in place) on success, or an error string when the
    unification leaves either typeset empty.

    Both sides are first resolved through `typ` and normalized. The side
    with the lower rank is treated as `tv1`, i.e. the one that gets mapped
    onto the other.
    """
    left = normalize_tv(typ[tv1])
    right = normalize_tv(typ[tv2])

    # Nothing to do when both already resolve to the same typevar
    if left == right:
        return typ

    # Keep the lower-ranked typevar on the left
    if typ.rank(right) < typ.rank(left):
        return unify(right, left, typ)

    constrain_fixpoint(left, right)

    left_empty = left.get_typeset().size() == 0
    if (left_empty or right.get_typeset().size() == 0):
        return "Error: empty type created when unifying {} and {}"\
            .format(left, right)

    # A non-derived left side is simply mapped onto the right side
    if not left.is_derived:
        typ.equivalent(left, right)
        return typ

    # A derived left side with an invertible derivation: peel the derivation
    # off and apply its inverse to the right side instead.
    if TypeVar.is_bijection(left.derived_func):
        inv_f = TypeVar.inverse_func(left.derived_func)
        inverted_right = normalize_tv(TypeVar.derived(right, inv_f))
        return unify(left.base, inverted_right, typ)

    # Otherwise the equality is kept as an explicit constraint
    typ.add_constraint(TypesEqual(left, right))
    return typ
def move_first(l, i):
    """
    Return a new list with element `l[i]` moved to the front and the
    remaining elements kept in their original order.

    `l` is not modified. Negative `i` counts from the end, as in normal
    list indexing. An out-of-range `i` raises IndexError.
    """
    # Look the element up with the caller's index first, so range checking
    # follows ordinary list indexing.
    item = l[i]

    # Normalize a negative index before slicing: `l[:i] + l[i + 1:]` only
    # removes the chosen element when `i` is non-negative.
    if i < 0:
        i += len(l)

    return [item] + l[:i] + l[i + 1:]
def ti_def(definition, typ):
    """
    Run type inference for a single Def within the TypeEnv `typ`.

    Returns the updated TypeEnv, or an error string when an actual typevar
    cannot be unified with its formal counterpart.
    """
    expr = definition.expr
    inst = expr.inst

    # Each formal typevar of the instruction gets a fresh copy named after a
    # new uid ...
    free_formal_tvs = inst.all_typevars()
    m = {}
    for tv in free_formal_tvs:
        m[tv] = tv.get_fresh_copy(str(typ.get_uid()))

    # ... except those bound on the expression, which become singletons.
    for (idx, bound_typ) in enumerate(expr.typevars):
        m[free_formal_tvs[idx]] = TypeVar.singleton(bound_typ)

    # Formal typevars of value results followed by those of value operands
    fresh_formal_tvs = [subst(inst.outs[i].typevar, m)
                        for i in inst.value_results]
    fresh_formal_tvs.extend(subst(inst.ins[i].typevar, m)
                            for i in inst.value_opnums)

    # The matching actual variables, in the same order
    actual_vars = [definition.defs[i] for i in inst.value_results]
    actual_vars.extend(expr.args[i] for i in inst.value_opnums)

    actual_tvs = []
    for var in actual_vars:
        assert(isinstance(var, Var))
        typ.register(var)
        actual_tvs.append(var.get_typevar())

    # For polymorphic instructions, the controlling typevar is unified first
    if inst.is_polymorphic:
        ctrl_pos = fresh_formal_tvs.index(m[inst.ctrl_typevar])
        fresh_formal_tvs = move_first(fresh_formal_tvs, ctrl_pos)
        actual_tvs = move_first(actual_tvs, ctrl_pos)

    for (actual_tv, formal_tv) in zip(actual_tvs, fresh_formal_tvs):
        outcome = unify(actual_tv, formal_tv, typ)
        err = get_error(outcome)
        if (err):
            return "fail ti on {} <: {}: ".format(actual_tv, formal_tv) + err
        typ = get_type_env(outcome)

    # Finally add the instruction's own constraints, translated through m
    for constr in inst.constraints:
        typ.add_constraint(constr.translate(m))

    return typ
def ti_rtl(rtl, typ):
    """
    Run type inference over every Def in `rtl.rtl`, threading the TypeEnv
    through in order.

    Returns the final TypeEnv, or an error string prefixed with the index of
    the failing Def.
    """
    for (line_no, definition) in enumerate(rtl.rtl):
        assert (isinstance(definition, Def))
        outcome = ti_def(definition, typ)
        err = get_error(outcome)
        if (err):
            return "On line {}: ".format(line_no) + err
        typ = get_type_env(outcome)

    return typ
def ti_xform(xform, typ):
    """
    Run type inference over the source and then the destination pattern of
    `xform`, starting from the TypeEnv `typ`.

    Returns the resulting TypeEnv, or an error string saying which pattern
    failed.
    """
    src_outcome = ti_rtl(xform.src, typ)
    src_err = get_error(src_outcome)
    if (src_err):
        return "In src pattern: " + src_err

    # The destination is inferred in the environment produced by the source
    dst_outcome = ti_rtl(xform.dst, get_type_env(src_outcome))
    dst_err = get_error(dst_outcome)
    if (dst_err):
        return "In dst pattern: " + dst_err

    return get_type_env(dst_outcome)