from __future__ import absolute_import
import math
from . import types, is_power_of_two
from copy import copy
try:
    from typing import (  # noqa
        Tuple, Union, Iterable, Any, Set, TYPE_CHECKING)
    if TYPE_CHECKING:
        # Imported only for static type checking; TYPE_CHECKING is False at
        # runtime, so this import never executes when the module is run.
        from srcgen import Formatter  # noqa
    # A closed range of sizes: (low, high).
    Interval = Tuple[int, int]
    # An interval spec as accepted by decode_interval: an explicit Interval,
    # or a truth value (`True` selects the full range).
    BoolInterval = Union[bool, Interval]
    # A special-type spec: a bool, or an iterable of special types.
    SpecialSpec = Union[bool, Iterable[types.SpecialType]]
except ImportError:
    # The `typing` module is unavailable; the aliases are simply not defined.
    pass
# Upper bound on the SIMD lane count of a type.
MAX_LANES = 256
# Upper bound, in bits, on the width of a single lane.
MAX_BITS = 64
# Upper bound on a bit-vector width: the widest lane times the most lanes.
MAX_BITVEC = MAX_LANES * MAX_BITS
def int_log2(x):
    """
    Return ``floor(log2(x))`` as an int.

    Integers are handled exactly via ``int.bit_length()``. Going through
    ``math.log`` converts a large integer to float first (``2**60 - 1``
    rounds to ``2**60``), so the floating-point result is not reliable.
    Non-integer inputs still fall back to ``math.log``.

    :param x: a positive number (callers in this module pass powers of two).
    :returns: the integer base-2 logarithm, rounded down.
    :raises ValueError: if ``x`` is not positive, as ``math.log`` would.
    """
    if isinstance(x, int):
        if x <= 0:
            raise ValueError("math domain error")
        return x.bit_length() - 1
    return int(math.log(x, 2))
def intersect(a, b):
    """
    Intersect two closed intervals ``(lo, hi)``.

    An interval whose low end is ``None`` counts as empty. The result is
    ``(None, None)`` when either input is empty or the ranges do not overlap.
    """
    nothing = (None, None)
    if a[0] is None or b[0] is None:
        return nothing
    lo = max(a[0], b[0])
    assert lo is not None
    hi = min(a[1], b[1])
    assert hi is not None
    if lo <= hi:
        return (lo, hi)
    return nothing
def is_empty(intv):
    """Return True if ``intv`` denotes no values: ``None``, ``False`` or ``(None, None)``."""
    if intv is None or intv is False:
        return True
    return intv == (None, None)
def encode_bitset(vals, size):
    """
    Pack bit indices into an integer mask with bit ``v`` set for each ``v``.

    :param vals: iterable of bit indices, each in ``[0, size)``.
    :param size: width of the bitset; a power of two no larger than 64.
    :returns: the integer mask.
    """
    assert is_power_of_two(size) and size <= 64
    mask = 0
    for bit in vals:
        assert 0 <= bit < size
        mask |= 1 << bit
    return mask
def pp_set(s):
    """Format ``s`` as ``{a, b, c}`` with the elements sorted and repr'd."""
    items = sorted(s)
    return '{%s}' % ', '.join(map(repr, items))
def decode_interval(intv, full_range, default=None):
    """
    Normalize an interval spec into a ``(lo, hi)`` tuple.

    :param intv: an explicit ``(lo, hi)`` tuple of powers of two lying inside
        ``full_range``, or a truth value. A true value selects all of
        ``full_range``; a false value yields ``(default, default)``.
    :param full_range: the ``(min, max)`` bounds allowed for this interval.
    :param default: endpoint used when ``intv`` is false.
    :returns: a ``(lo, hi)`` tuple.
    """
    if not isinstance(intv, tuple):
        return full_range if intv else (default, default)
    lo, hi = intv
    assert is_power_of_two(lo)
    assert is_power_of_two(hi)
    assert lo <= hi
    assert lo >= full_range[0]
    assert hi <= full_range[1]
    return intv
def interval_to_set(intv):
    """
    Expand an interval of powers of two into the set of those powers.

    For example ``(8, 64)`` becomes ``{8, 16, 32, 64}``. An empty interval
    (as judged by ``is_empty``) gives an empty set.
    """
    if is_empty(intv):
        return set()
    (lo, hi) = intv
    assert is_power_of_two(lo)
    assert is_power_of_two(hi)
    assert lo <= hi
    exponents = range(int_log2(lo), int_log2(hi) + 1)
    return set(2 ** e for e in exponents)
def legal_bool(bits):
    """
    Return whether ``bits`` is a valid boolean lane width.

    Legal widths are 1 and the powers of two from 8 up to ``MAX_BITS``.
    """
    if bits == 1:
        return True
    return 8 <= bits <= MAX_BITS and is_power_of_two(bits)
class TypeSet(object):
    """
    A set of types, stored as independent per-component sets.

    The represented types are the cross product of ``lanes`` with each of
    ``ints``, ``floats``, ``bools`` and ``bitvecs`` (all powers of two), plus
    the explicit ``specials`` (see ``size`` and ``concrete_types``). Each
    width argument is an interval spec handled by ``decode_interval``:
    ``True`` selects the full legal range, a ``(lo, hi)`` tuple selects a
    sub-range, and a false value selects nothing. For ``lanes``, a false
    value selects lane count 1 only.

    :param lanes: lane-count interval within ``(1, MAX_LANES)``.
    :param ints: integer lane-width interval within ``(8, MAX_BITS)``.
    :param floats: float lane-width interval within ``(32, 64)``.
    :param bools: boolean lane-width interval within ``(1, MAX_BITS)``;
        widths rejected by ``legal_bool`` are dropped.
    :param bitvecs: bit-vector width interval within ``(1, MAX_BITVEC)``.
    :param specials: ``True`` for every special type, a false value for
        none, or an iterable of special types.
    """

    def __init__(
            self, lanes=None, ints=None, floats=None, bools=None,
            bitvecs=None, specials=None):
        self.lanes = interval_to_set(decode_interval(lanes, (1, MAX_LANES), 1))
        self.ints = interval_to_set(decode_interval(ints, (8, MAX_BITS)))
        self.floats = interval_to_set(decode_interval(floats, (32, 64)))
        self.bools = interval_to_set(decode_interval(bools, (1, MAX_BITS)))
        # The (1, MAX_BITS) range includes widths 2 and 4, which are not
        # legal boolean widths.
        self.bools = set(filter(legal_bool, self.bools))
        self.bitvecs = interval_to_set(decode_interval(bitvecs,
                                                       (1, MAX_BITVEC)))
        self.specials = set()
        if isinstance(specials, bool):
            if specials:
                self.specials = set(types.ValueType.all_special_types)
        elif specials:
            self.specials = set(specials)

    def copy(self):
        """
        Return an independent copy: each component set is shallow-copied, so
        mutating the copy's sets does not affect ``self``.
        """
        n = TypeSet()
        n.lanes = copy(self.lanes)
        n.ints = copy(self.ints)
        n.floats = copy(self.floats)
        n.bools = copy(self.bools)
        n.bitvecs = copy(self.bitvecs)
        n.specials = copy(self.specials)
        return n

    def typeset_key(self):
        """
        Return a hashable, order-independent key for this set.

        Each component is a sorted tuple; special types are keyed by name.
        """
        return (tuple(sorted(self.lanes)),
                tuple(sorted(self.ints)),
                tuple(sorted(self.floats)),
                tuple(sorted(self.bools)),
                tuple(sorted(self.bitvecs)),
                tuple(sorted(s.name for s in self.specials)))

    def __hash__(self):
        """
        Hash by ``typeset_key``.

        The first hash is remembered, and later calls assert it is unchanged.
        This catches a TypeSet that was mutated after being hashed.
        """
        h = hash(self.typeset_key())
        assert h == getattr(self, 'prev_hash', h), "TypeSet changed"
        self.prev_hash = h
        return h

    def __eq__(self, other):
        """Equal to another TypeSet with the same key; unequal to anything else."""
        if isinstance(other, TypeSet):
            return self.typeset_key() == other.typeset_key()
        else:
            return False

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        """Debug form: lanes always, other components only when non-empty."""
        s = 'TypeSet(lanes={}'.format(pp_set(self.lanes))
        if len(self.ints) > 0:
            s += ', ints={}'.format(pp_set(self.ints))
        if len(self.floats) > 0:
            s += ', floats={}'.format(pp_set(self.floats))
        if len(self.bools) > 0:
            s += ', bools={}'.format(pp_set(self.bools))
        if len(self.bitvecs) > 0:
            s += ', bitvecs={}'.format(pp_set(self.bitvecs))
        if len(self.specials) > 0:
            s += ', specials=[{}]'.format(pp_set(self.specials))
        return s + ')'

    def emit_fields(self, fmt):
        """
        Emit field initializer lines for this set through ``fmt``.

        First writes ``repr(self)`` as a comment. Then, for each of lanes,
        ints, floats and bools, writes a ``name: BitSet::<uN>(mask),`` line
        whose mask has bit ``log2(w)`` set for every member ``w``. Bit
        vectors cannot be emitted, and specials are not emitted here.

        :param fmt: output formatter providing ``comment`` and ``line``
            (presumably a ``srcgen.Formatter``).
        """
        assert len(self.bitvecs) == 0, "Bitvector types are not emitable."
        fmt.comment(repr(self))

        fields = (('lanes', 16),
                  ('ints', 8),
                  ('floats', 8),
                  ('bools', 8))

        for (field, bits) in fields:
            vals = [int_log2(x) for x in getattr(self, field)]
            fmt.line('{}: BitSet::<u{}>({}),'
                     .format(field, bits, encode_bitset(vals, bits)))

    def __iand__(self, other):
        """Intersect every component with ``other``'s, in place; return self."""
        self.lanes.intersection_update(other.lanes)
        self.ints.intersection_update(other.ints)
        self.floats.intersection_update(other.floats)
        self.bools.intersection_update(other.bools)
        self.bitvecs.intersection_update(other.bitvecs)
        self.specials.intersection_update(other.specials)
        return self

    def issubset(self, other):
        """Return True if every component is a subset of ``other``'s."""
        return self.lanes.issubset(other.lanes) and \
            self.ints.issubset(other.ints) and \
            self.floats.issubset(other.floats) and \
            self.bools.issubset(other.bools) and \
            self.bitvecs.issubset(other.bitvecs) and \
            self.specials.issubset(other.specials)

    def lane_of(self):
        """Image under lane_of: same lane widths with lanes = {1}; no bit vectors."""
        new = self.copy()
        new.lanes = {1}
        new.bitvecs = set()
        return new

    def as_bool(self):
        """
        Image under as_bool.

        Ints, floats and bit vectors are removed. If any lane count other
        than 1 is present, every int/float/bool lane width becomes a bool
        width. If lane count 1 is present, b1 is added.
        """
        new = self.copy()
        new.ints = set()
        new.floats = set()
        new.bitvecs = set()

        if len(self.lanes - {1}) > 0:
            new.bools = self.ints | self.floats | self.bools

        # NOTE(review): when lanes == {1}, the copied self.bools are kept
        # alongside b1 -- confirm that over-approximation is intended.
        if 1 in self.lanes:
            new.bools.add(1)
        return new

    def half_width(self):
        """
        Image under half_width.

        Every lane width is halved. Widths that cannot be halved are dropped:
        i8, f32, b1, b8 and bv1. Specials are cleared.
        """
        new = self.copy()
        new.ints = {x // 2 for x in self.ints if x > 8}
        new.floats = {x // 2 for x in self.floats if x > 32}
        new.bools = {x // 2 for x in self.bools if x > 8}
        new.bitvecs = {x // 2 for x in self.bitvecs if x > 1}
        new.specials = set()
        return new

    def double_width(self):
        """
        Image under double_width.

        Every lane width below the maximum is doubled. Doubled bool widths
        are filtered through ``legal_bool`` (b1 would become the illegal b2).
        Specials are cleared.
        """
        new = self.copy()
        new.ints = {x * 2 for x in self.ints if x < MAX_BITS}
        new.floats = {x * 2 for x in self.floats if x < MAX_BITS}
        new.bools = set(filter(legal_bool,
                               {x * 2 for x in self.bools if x < MAX_BITS}))
        new.bitvecs = {x * 2 for x in self.bitvecs if x < MAX_BITVEC}
        new.specials = set()
        return new

    def half_vector(self):
        """Image under half_vector: lane counts above 1 halved; no bitvecs/specials."""
        new = self.copy()
        new.bitvecs = set()
        new.lanes = {x // 2 for x in self.lanes if x > 1}
        new.specials = set()
        return new

    def double_vector(self):
        """Image under double_vector: lane counts below MAX_LANES doubled."""
        new = self.copy()
        new.bitvecs = set()
        new.lanes = {x * 2 for x in self.lanes if x < MAX_LANES}
        new.specials = set()
        return new

    def to_bitvec(self):
        """
        Image under to_bitvec.

        The result is single-lane bit vectors whose widths are every
        (lane width x lane count) product. Requires that ``self`` holds no
        bit vectors.
        """
        assert len(self.bitvecs) == 0
        all_scalars = self.ints | self.floats | self.bools

        new = self.copy()
        new.lanes = {1}
        new.ints = set()
        new.bools = set()
        new.floats = set()
        new.bitvecs = {lane_w * nlanes for lane_w in all_scalars
                       for nlanes in self.lanes}
        new.specials = set()
        return new

    def image(self, func):
        """Return the image of this set under the derived function named ``func``."""
        if func == TypeVar.LANEOF:
            return self.lane_of()
        elif func == TypeVar.ASBOOL:
            return self.as_bool()
        elif func == TypeVar.HALFWIDTH:
            return self.half_width()
        elif func == TypeVar.DOUBLEWIDTH:
            return self.double_width()
        elif func == TypeVar.HALFVECTOR:
            return self.half_vector()
        elif func == TypeVar.DOUBLEVECTOR:
            return self.double_vector()
        elif func == TypeVar.TOBITVEC:
            return self.to_bitvec()
        else:
            assert False, "Unknown derived function: " + func

    def preimage(self, func):
        """
        Return a TypeSet of the types whose image under ``func`` lies in self.

        The component-wise representation cannot express exact relations
        between lane counts and widths, so the result may contain extra
        types. An empty ``self`` is returned unchanged.
        """
        if self.size() == 0:
            return self

        if func == TypeVar.LANEOF:
            new = self.copy()
            new.bitvecs = set()
            new.lanes = {2**i for i in range(0, int_log2(MAX_LANES) + 1)}
            return new
        elif func == TypeVar.ASBOOL:
            new = self.copy()
            new.bitvecs = set()

            if 1 not in self.bools:
                new.ints = self.bools - {1}
                new.floats = self.bools & {32, 64}
                # as_bool() always adds b1 for lane count 1, so without b1
                # in the image, lane count 1 cannot be in the preimage.
                new.lanes = self.lanes - {1}
            else:
                new.ints = {2**x for x in range(3, 7)}
                new.floats = {32, 64}

            return new
        elif func == TypeVar.HALFWIDTH:
            return self.double_width()
        elif func == TypeVar.DOUBLEWIDTH:
            return self.half_width()
        elif func == TypeVar.HALFVECTOR:
            return self.double_vector()
        elif func == TypeVar.DOUBLEVECTOR:
            return self.half_vector()
        elif func == TypeVar.TOBITVEC:
            # TypeSet() starts with lanes == {1} (the default lane interval),
            # so lane count 1 is always present in the result.
            new = TypeSet()

            # Candidate lane counts and widths. Bools are filtered through
            # legal_bool, as in __init__, so the illegal widths 2 and 4
            # cannot leak into the result.
            lanes = interval_to_set(decode_interval(True, (1, MAX_LANES), 1))
            ints = interval_to_set(decode_interval(True, (8, MAX_BITS)))
            floats = interval_to_set(decode_interval(True, (32, 64)))
            bools = set(filter(legal_bool, interval_to_set(
                decode_interval(True, (1, MAX_BITS)))))

            # Pair each candidate width set with the set of `new` it feeds.
            candidates = ((ints, new.ints),
                          (bools, new.bools),
                          (floats, new.floats))

            # Keep every (lane count, width) whose product is a bitvec in self.
            for nlanes in lanes:
                for widths, dest in candidates:
                    for width in widths:
                        if width * nlanes in self.bitvecs:
                            new.lanes.add(nlanes)
                            dest.add(width)

            return new
        else:
            assert False, "Unknown derived function: " + func

    def size(self):
        """Return the number of concrete types: lanes x widths, plus specials."""
        return (len(self.lanes) * (len(self.ints) + len(self.floats) +
                                   len(self.bools) + len(self.bitvecs)) +
                len(self.specials))

    def concrete_types(self):
        """Yield every concrete type this set represents, specials last."""
        def by(scalar, lanes):
            if lanes == 1:
                return scalar
            else:
                return scalar.by(lanes)

        for nlanes in self.lanes:
            for bits in self.ints:
                yield by(types.IntType.with_bits(bits), nlanes)
            for bits in self.floats:
                yield by(types.FloatType.with_bits(bits), nlanes)
            for bits in self.bools:
                yield by(types.BoolType.with_bits(bits), nlanes)
            for bits in self.bitvecs:
                assert nlanes == 1
                yield types.BVType.with_bits(bits)

        for spec in self.specials:
            yield spec

    def get_singleton(self):
        """Return the only concrete type; asserts there is exactly one."""
        # Named `found` rather than `types` to avoid shadowing the module.
        found = list(self.concrete_types())
        assert len(found) == 1
        return found[0]

    def widths(self):
        """
        Return the set of total bit widths (lane width x lane count).

        Bit vectors are included; specials are not.
        """
        scalar_w = self.ints.union(self.floats.union(self.bools))
        scalar_w = scalar_w.union(self.bitvecs)
        return set(w * l for l in self.lanes for w in scalar_w)
class TypeVar(object):
    """
    A type variable.

    A free type variable owns a ``TypeSet`` (``type_set``). A derived type
    variable instead records a ``base`` TypeVar and the name of a
    ``derived_func``; its type set is computed from the base on demand.

    :param name: short name of the variable.
    :param doc: documentation string, stored as ``__doc__``.
    :param ints: integer lane widths to allow (interval spec).
    :param floats: float lane widths to allow (interval spec).
    :param bools: boolean lane widths to allow (interval spec).
    :param scalars: whether lane count 1 may be selected; when false the
        lane interval derived from ``simd`` starts at 2.
    :param simd: lane-count interval spec.
    :param bitvecs: bit-vector widths to allow (interval spec).
    :param base: for a derived variable, the TypeVar it is derived from.
    :param derived_func: for a derived variable, the derived-function name
        (one of the class constants such as ``LANEOF``).
    :param specials: special types to allow; see ``TypeSet``.
    """

    def __init__(
            self, name, doc, ints=False, floats=False, bools=False,
            scalars=True, simd=False, bitvecs=False, base=None,
            derived_func=None, specials=None):
        self.name = name
        self.__doc__ = doc
        self.is_derived = isinstance(base, TypeVar)
        if base:
            assert self.is_derived
            assert derived_func
            self.base = base
            self.derived_func = derived_func
            self.name = '{}({})'.format(derived_func, base.name)
        else:
            min_lanes = 1 if scalars else 2
            lanes = decode_interval(simd, (min_lanes, MAX_LANES), 1)
            self.type_set = TypeSet(
                lanes=lanes,
                ints=ints,
                floats=floats,
                bools=bools,
                bitvecs=bitvecs,
                specials=specials)

    @staticmethod
    def singleton(typ):
        """Return a free TypeVar whose type set contains exactly ``typ``."""
        scalar = None
        if isinstance(typ, types.VectorType):
            scalar = typ.base
            lanes = (typ.lanes, typ.lanes)
        elif isinstance(typ, types.LaneType):
            scalar = typ
            lanes = (1, 1)
        elif isinstance(typ, types.SpecialType):
            return TypeVar(typ.name, typ.__doc__, specials=[typ])
        else:
            assert isinstance(typ, types.BVType)
            scalar = typ
            lanes = (1, 1)

        ints = None
        floats = None
        bools = None
        bitvecs = None

        if isinstance(scalar, types.IntType):
            ints = (scalar.bits, scalar.bits)
        elif isinstance(scalar, types.FloatType):
            floats = (scalar.bits, scalar.bits)
        elif isinstance(scalar, types.BoolType):
            bools = (scalar.bits, scalar.bits)
        elif isinstance(scalar, types.BVType):
            bitvecs = (scalar.bits, scalar.bits)

        tv = TypeVar(
            typ.name, typ.__doc__,
            ints=ints, floats=floats, bools=bools,
            bitvecs=bitvecs, simd=lanes)
        return tv

    def __str__(self):
        return "`{}`".format(self.name)

    def __repr__(self):
        if self.is_derived:
            return (
                'TypeVar({}, base={}, derived_func={})'
                .format(self.name, self.base, self.derived_func))
        else:
            return (
                'TypeVar({}, {})'
                .format(self.name, self.type_set))

    def __hash__(self):
        """Free variables hash by identity; derived ones by (func, base)."""
        if not self.is_derived:
            return object.__hash__(self)

        return hash((self.derived_func, self.base))

    def __eq__(self, other):
        """
        Compare derived variables structurally (same function, equal base).

        Any comparison involving a free variable is by identity.
        """
        if not isinstance(other, TypeVar):
            return False
        if self.is_derived and other.is_derived:
            return (
                self.derived_func == other.derived_func and
                self.base == other.base)
        else:
            return self is other

    def __ne__(self, other):
        return not self.__eq__(other)

    # Supported derived functions. The string values match the TypeSet
    # method names and are also emitted as method calls by rust_expr().
    LANEOF = 'lane_of'
    ASBOOL = 'as_bool'
    HALFWIDTH = 'half_width'
    DOUBLEWIDTH = 'double_width'
    HALFVECTOR = 'half_vector'
    DOUBLEVECTOR = 'double_vector'
    TOBITVEC = 'to_bitvec'

    @staticmethod
    def is_bijection(func):
        """True for the width/lane halving and doubling functions (see inverse_func)."""
        return func in [
            TypeVar.HALFWIDTH,
            TypeVar.DOUBLEWIDTH,
            TypeVar.HALFVECTOR,
            TypeVar.DOUBLEVECTOR]

    @staticmethod
    def inverse_func(func):
        """Return the inverse of a bijective derived function (KeyError otherwise)."""
        return {
            TypeVar.HALFWIDTH: TypeVar.DOUBLEWIDTH,
            TypeVar.DOUBLEWIDTH: TypeVar.HALFWIDTH,
            TypeVar.HALFVECTOR: TypeVar.DOUBLEVECTOR,
            TypeVar.DOUBLEVECTOR: TypeVar.HALFVECTOR
        }[func]

    @staticmethod
    def derived(base, derived_func):
        """
        Create a TypeVar derived from ``base`` by ``derived_func``.

        Asserts that the function applies to every type in ``base``'s current
        type set, and that the set contains no special types.
        """
        ts = base.get_typeset()

        assert len(ts.specials) == 0, "Can't derive from special types"

        if derived_func == TypeVar.HALFWIDTH:
            if len(ts.ints) > 0:
                assert min(ts.ints) > 8, "Can't halve all integer types"
            if len(ts.floats) > 0:
                assert min(ts.floats) > 32, "Can't halve all float types"
            if len(ts.bools) > 0:
                assert min(ts.bools) > 8, "Can't halve all boolean types"
        elif derived_func == TypeVar.DOUBLEWIDTH:
            if len(ts.ints) > 0:
                assert max(ts.ints) < MAX_BITS,\
                    "Can't double all integer types."
            if len(ts.floats) > 0:
                assert max(ts.floats) < MAX_BITS,\
                    "Can't double all float types."
            if len(ts.bools) > 0:
                assert max(ts.bools) < MAX_BITS, "Can't double all bool types."
        elif derived_func == TypeVar.HALFVECTOR:
            assert min(ts.lanes) > 1, "Can't halve a scalar type"
        elif derived_func == TypeVar.DOUBLEVECTOR:
            assert max(ts.lanes) < MAX_LANES, "Can't double 256 lanes."

        return TypeVar(None, None, base=base, derived_func=derived_func)

    @staticmethod
    def from_typeset(ts):
        """Create an unnamed free TypeVar whose type set is ``ts`` itself (not a copy)."""
        tv = TypeVar(None, None)
        tv.type_set = ts
        return tv

    def lane_of(self):
        return TypeVar.derived(self, self.LANEOF)

    def as_bool(self):
        return TypeVar.derived(self, self.ASBOOL)

    def half_width(self):
        return TypeVar.derived(self, self.HALFWIDTH)

    def double_width(self):
        return TypeVar.derived(self, self.DOUBLEWIDTH)

    def half_vector(self):
        return TypeVar.derived(self, self.HALFVECTOR)

    def double_vector(self):
        return TypeVar.derived(self, self.DOUBLEVECTOR)

    def to_bitvec(self):
        return TypeVar.derived(self, self.TOBITVEC)

    def singleton_type(self):
        """Return the one concrete type if the type set has size 1, else None."""
        ts = self.get_typeset()
        if ts.size() != 1:
            return None
        return ts.get_singleton()

    def free_typevar(self):
        """
        Return the free TypeVar this one ultimately depends on.

        A derived variable defers to its base. A free variable returns
        itself, or None when its type set is a singleton.
        """
        if self.is_derived:
            return self.base.free_typevar()
        elif self.singleton_type() is not None:
            return None
        else:
            return self

    def rust_expr(self):
        """
        Return the expression text for this type variable.

        A derived variable gives ``<base expr>.<derived_func>()``. A singleton
        gives the type's ``rust_name()``. Otherwise the result is the
        variable's name.
        """
        if self.is_derived:
            return '{}.{}()'.format(
                self.base.rust_expr(), self.derived_func)
        # singleton_type() enumerates the type set's concrete types, so it
        # is computed once here instead of once for the test and again for
        # the use.
        singleton = self.singleton_type()
        if singleton:
            return singleton.rust_name()
        return self.name

    def constrain_types_by_ts(self, ts):
        """
        Restrict this variable to the types in ``ts``.

        A free variable intersects its own set in place. A derived variable
        constrains its base with ``ts``'s preimage under the derived function.
        """
        if not self.is_derived:
            self.type_set &= ts
        else:
            self.base.constrain_types_by_ts(ts.preimage(self.derived_func))

    def constrain_types(self, other):
        """Restrict this variable to ``other``'s current types (no-op if same object)."""
        if self is other:
            return

        self.constrain_types_by_ts(other.get_typeset())

    def get_typeset(self):
        """
        Return the current TypeSet.

        For a free variable this is its own set. For a derived variable it is
        the image of the base's set, recomputed on each call.
        """
        if not self.is_derived:
            return self.type_set
        else:
            return self.base.get_typeset().image(self.derived_func)

    def get_fresh_copy(self, name):
        """Return a new free TypeVar named ``name`` with a copy of this type set."""
        assert not self.is_derived
        tv = TypeVar.from_typeset(self.type_set.copy())
        tv.name = name
        return tv