import functools
import sys
import numpy as np
import decimal
from decimal import Decimal
# Decimal arithmetic in this script must be exact. NumType.interesting_values()
# builds values as large as 2**1023 +/- 0.5 (roughly 310 significant digits),
# and the default 28-digit context would round them. 1000 digits leaves ample
# headroom. This sets the precision of the current thread's decimal context.
decimal.getcontext().prec = 1000
@functools.total_ordering
class NumType:
    """A primitive numeric type named like a Rust primitive (``u8``, ``i64``, ``f32``, ...).

    Attributes:
        bits, signed, isfloat: the constructor arguments.
        name: ``u<bits>``, ``i<bits>`` or ``f<bits>``.
        int_min, int_max: the contiguous range of integers the type represents
            exactly. For floats this is ``±2**precision`` (``±2**24`` for f32,
            ``±2**53`` for f64).
        min, max: the value bounds used by :meth:`can_exactly_represent`.
            For integer types they equal ``int_min``/``int_max``. For float
            types they are the power-of-two bounds ``±2**127`` / ``±2**1023``.
            These lie below the true finite maxima (~3.4e38 / ~1.8e308).

    Equality, hashing and ordering all use the key ``(isfloat, bits, signed)``.
    So integer types sort before float types, narrower before wider, and
    unsigned before signed.
    """

    def __init__(self, bits, signed, isfloat):
        self.bits = bits
        self.signed = signed
        self.isfloat = isfloat
        if isfloat:
            # Raise instead of assert so invalid arguments are still rejected
            # when Python runs with -O (which strips assert statements).
            if not signed:
                raise ValueError("floating-point types must be signed")
            self.name = f"f{bits}"
            if bits == 32:
                precision, max_exp = 24, 127
            elif bits == 64:
                precision, max_exp = 53, 1023
            else:
                raise ValueError(f"unsupported float width: {bits}")
            self.int_min = -2**precision
            self.int_max = 2**precision
            self.min = -2**max_exp
            self.max = 2**max_exp
        else:
            if signed:
                self.name = f"i{bits}"
                self.int_min = -2**(bits - 1)
                self.int_max = 2**(bits - 1) - 1
            else:
                self.name = f"u{bits}"
                self.int_min = 0
                self.int_max = 2**bits - 1
            self.min = self.int_min
            self.max = self.int_max

    def can_exactly_represent(self, x):
        """Return True if ``x`` (an int or Decimal) is exactly a value of this type.

        Values outside ``[self.min, self.max]`` are rejected first. Float types
        are checked by round-tripping ``x`` through the numpy float of the same
        width and comparing the result exactly. Integer types check that ``x``
        has no fractional part.
        """
        if x < self.min or x > self.max:
            return False
        if self.isfloat:
            np_type = np.float32 if self.bits == 32 else np.float64
            # Decimal(float) is an exact conversion, so any rounding done by
            # the numpy cast makes the comparison fail.
            return Decimal(float(np_type(x))) == x
        return x == int(x)

    def interesting_values(self):
        """Return this type's boundary values plus their neighbours at ±0.5 and ±1.

        The result is a set mixing ints and Decimals. Not every value in it is
        representable by this type (for example ``int_max + 1``).
        """
        bounds = {self.int_min, self.int_max, self.min, self.max}
        vals = set(bounds)
        for val in bounds:
            vals.update((val + Decimal("0.5"), val - Decimal("0.5"), val + 1, val - 1))
        return vals

    def is_subset_eq(self, other):
        """Return True if this type's exact-integer range lies within ``other``'s.

        Only ``int_min``/``int_max`` are compared; whether either type is a
        float type is not considered here.
        """
        return other.int_min <= self.int_min <= self.int_max <= other.int_max

    def _key(self):
        # Single source of truth for equality, hashing and ordering.
        return (self.isfloat, self.bits, self.signed)

    def __eq__(self, other):
        if not isinstance(other, NumType):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # Defining __eq__ would otherwise make instances unhashable.
        return hash(self._key())

    def __lt__(self, other):
        if not isinstance(other, NumType):
            return NotImplemented
        return self._key() < other._key()

    def __repr__(self):
        return self.name
def implies(a, b):
    """Logical implication ``a -> b``.

    Returns True whenever ``a`` is falsy; otherwise returns ``b`` itself.
    """
    if a:
        return b
    return True
def common_type(a, b, types):
    """Pick the smallest entry of ``types`` able to hold both ``a`` and ``b``.

    A candidate ``t`` qualifies when both inputs' integer ranges fit inside it
    (via ``is_subset_eq``). If either input is a float type, ``t`` must also be
    a float type. "Smallest" means the minimum under the types' ``<`` ordering;
    ties keep the earliest candidate. Returns None when no candidate qualifies.
    """
    needs_float = a.isfloat or b.isfloat
    candidates = [
        t
        for t in types
        if a.is_subset_eq(t) and b.is_subset_eq(t) and implies(needs_float, t.isfloat)
    ]
    return min(candidates, default=None)
def _all_types():
    """Return the NumTypes the generator covers: u/i 8..128, then f32 and f64."""
    types = [
        NumType(bits, signed, False)
        for bits in [8, 16, 32, 64, 128]
        for signed in [False, True]
    ]
    types += [NumType(32, True, True), NumType(64, True, True)]
    return types


def _print_common_types(types):
    """Print ``a, b => c;`` for each pair ``a < b`` that has a common type.

    Pairs without a common type are listed afterwards as NO COMMON TYPE lines.
    """
    unhandled = []
    for a in types:
        for b in types:
            if not a < b:
                continue
            c = common_type(a, b, types)
            if c is None:
                unhandled.append((a, b))
            else:
                print(f"{str(a):>4}, {str(b):>4} => {str(c):>4};")
    for a, b in unhandled:
        print("NO COMMON TYPE", a, b)


def _print_all_types(types):
    """Print ``a, b;`` for every ordered pair of types, including a == b."""
    for a in types:
        for b in types:
            print(f"{a}, {b};")


def _print_tests(types):
    """Print Rust #[test] functions comparing NumOrd values across type pairs.

    For each pair of distinct types, every pair of interesting values that both
    types represent exactly gets assertions for <, <=, >, >= and partial_cmp.
    The expected results are computed here with exact int/Decimal arithmetic.
    """
    print("""// Automatically generated by tools/gen.py.
use core::cmp::Ordering;
use num_ord::NumOrd;
""")
    for t1 in types:
        for t2 in types:
            if t1 == t2:
                continue
            print(f"#[test] fn test_{t1}_{t2}() {{")
            values = sorted(t1.interesting_values() | t2.interesting_values())
            # Filter once per type instead of once per (v1, v2) pair. The lists
            # keep the sorted order, so the printed output is unchanged.
            vals1 = [v for v in values if t1.can_exactly_represent(v)]
            vals2 = [v for v in values if t2.can_exactly_represent(v)]
            for v1 in vals1:
                for v2 in vals2:
                    print(f" assert_eq!(NumOrd({v1}{t1}) < NumOrd({v2}{t2}), {str(v1 < v2).lower()});")
                    print(f" assert_eq!(NumOrd({v1}{t1}) <= NumOrd({v2}{t2}), {str(v1 <= v2).lower()});")
                    print(f" assert_eq!(NumOrd({v1}{t1}) > NumOrd({v2}{t2}), {str(v1 > v2).lower()});")
                    print(f" assert_eq!(NumOrd({v1}{t1}) >= NumOrd({v2}{t2}), {str(v1 >= v2).lower()});")
                    if v1 > v2:
                        ordering = "Greater"
                    elif v1 < v2:
                        ordering = "Less"
                    else:
                        ordering = "Equal"
                    print(f" assert_eq!(NumOrd({v1}{t1}).partial_cmp(&NumOrd({v2}{t2})), Some(Ordering::{ordering}));")
            print("}\n")


# Sub-command name -> printer; each printer receives the list from _all_types().
_COMMANDS = {
    "common-types": _print_common_types,
    "all-types": _print_all_types,
    "tests": _print_tests,
}


def main(argv=None):
    """Run the sub-command named by ``argv[0]`` (default: ``sys.argv[1:]``).

    Returns the process exit status: 0 on success, 2 (after printing a usage
    message to stderr) when the command is missing or unknown. Extra arguments
    after the command are ignored.
    """
    argv = sys.argv[1:] if argv is None else argv
    if not argv or argv[0] not in _COMMANDS:
        prog = sys.argv[0] if sys.argv and sys.argv[0] else "gen.py"
        choices = "|".join(_COMMANDS)
        print(f"usage: {prog} {{{choices}}}", file=sys.stderr)
        return 2
    _COMMANDS[argv[0]](_all_types())
    return 0


if __name__ == "__main__":
    sys.exit(main())