num-ord 0.1.0

Numerically ordered wrapper type for cross-type comparisons
Documentation
import functools
import sys
import numpy as np
import decimal
from decimal import Decimal
decimal.getcontext().prec = 1000

@functools.total_ordering
class NumType:
    """Description of one primitive numeric type, named with Rust suffixes.

    Instances are ordered by the tuple ``(isfloat, bits, signed)``, so every
    integer type sorts before every float type, and narrower types sort
    before wider ones. Equality is the default identity comparison.

    Attributes (all set in ``__init__``):
        bits, signed, isfloat: the constructor arguments, stored unchanged.
        name: the Rust type suffix, e.g. ``"i8"``, ``"u64"`` or ``"f32"``.
        int_min, int_max: for integer types, the type's full range; for
            float types, ``-2**p`` .. ``2**p`` with ``p`` = 24 (f32) or
            53 (f64).
        min, max: the value range this generator treats as representable.
            Same as the integer range for integer types; for float types it
            is ``±2**127`` (f32) or ``±2**1023`` (f64) -- a power of two,
            not the format's largest finite value.
    """

    def __init__(self, bits, signed, isfloat):
        self.bits = bits
        self.signed = signed
        self.isfloat = isfloat

        if not isfloat:
            prefix = "i" if signed else "u"
            self.name = f"{prefix}{bits}"
            if signed:
                lowest, highest = -2 ** (bits - 1), 2 ** (bits - 1) - 1
            else:
                lowest, highest = 0, 2 ** bits - 1
            # For integer types the "integer range" and the value range coincide.
            self.int_min = self.min = lowest
            self.int_max = self.max = highest
            return

        # Only signed 32- and 64-bit floats are supported.
        assert signed
        self.name = f"f{bits}"
        assert bits in (32, 64)
        precision, top_exponent = (24, 127) if bits == 32 else (53, 1023)
        self.int_min, self.int_max = -2 ** precision, 2 ** precision
        self.min, self.max = -2 ** top_exponent, 2 ** top_exponent

    def can_exactly_represent(self, x):
        """Return True if ``x`` lies within ``[min, max]`` and is a value of this type.

        ``x`` may be an ``int`` or a ``Decimal``. Integer types require ``x`` to
        be integral. Float types convert ``x`` to the matching NumPy scalar and
        compare the exact (Decimal) value of the result back against ``x``.
        """
        if x < self.min or x > self.max:
            return False
        if not self.isfloat:
            return x == int(x)
        native = np.float32 if self.bits == 32 else np.float64
        round_tripped = Decimal(float(native(x)))
        return round_tripped == x

    def interesting_values(self):
        """Return the range boundaries together with their neighbours.

        The neighbours of each boundary are the values at distance 0.5 and 1 on
        either side. The half-step neighbours are ``Decimal`` values; the rest
        are ``int``.
        """
        half = Decimal("0.5")
        anchors = {self.int_min, self.int_max, self.min, self.max}
        neighbours = {
            anchor + delta
            for anchor in anchors
            for delta in (half, -half, 1, -1)
        }
        return anchors | neighbours

    def is_subset_eq(self, other):
        """Return True if this type's integer range lies inside ``other``'s."""
        return (other.int_min <= self.int_min
                and self.int_min <= self.int_max
                and self.int_max <= other.int_max)

    def __lt__(self, other):
        mine = (self.isfloat, self.bits, self.signed)
        theirs = (other.isfloat, other.bits, other.signed)
        return mine < theirs

    def __repr__(self):
        return self.name
                

def implies(a, b):
    """Material implication ``a -> b``: false only when ``a`` is truthy and ``b`` falsy.

    Mirrors ``not a or b``: returns ``b`` itself (not coerced to bool) when
    ``a`` is truthy, and ``True`` otherwise.
    """
    if a:
        return b
    return True

def common_type(a, b, types):
    """Return the smallest type in ``types`` whose integer range covers ``a`` and ``b``.

    When either operand is a float type, only float candidates qualify.
    "Smallest" follows the candidates' ``<`` ordering; among equally small
    candidates the earliest one in ``types`` wins. Returns ``None`` when no
    candidate qualifies.
    """
    float_required = a.isfloat or b.isfloat
    best = None
    for candidate in types:
        covers_both = a.is_subset_eq(candidate) and b.is_subset_eq(candidate)
        if not covers_both:
            continue
        if float_required and not candidate.isfloat:
            continue
        # Strict "<" keeps the first of several equally small candidates.
        if best is None or candidate < best:
            best = candidate
    return best


# Preamble of the generated Rust test file (content identical to the original literal).
_TEST_FILE_HEADER = (
    "// Automatically generated by tools/gen.py.\n"
    "use core::cmp::Ordering;\n"
    "use num_ord::NumOrd;\n"
    "\n"
)


def _all_types():
    """Return the integer types (u8, i8, ..., u128, i128) followed by f32 and f64."""
    types = [
        NumType(bits, signed, False)
        for bits in [8, 16, 32, 64, 128]
        for signed in [False, True]]
    types += [NumType(32, True, True), NumType(64, True, True)]
    return types


def _print_common_types(types):
    """Print the common type of each ordered pair, then every pair that has none."""
    unhandled = []
    for a in types:
        for b in types:
            if not a < b:
                continue
            c = common_type(a, b, types)
            if c is None:
                unhandled.append((a, b))
            else:
                print(f"{str(a):>4}, {str(b):>4} => {str(c):>4};")

    for a, b in unhandled:
        print("NO COMMON TYPE", a, b)


def _print_all_types(types):
    """Print every (a, b) pair of types, including a type paired with itself."""
    for a in types:
        for b in types:
            print(f"{a}, {b};")


def _ordering_name(v1, v2):
    """Return the ``core::cmp::Ordering`` variant name for comparing v1 with v2."""
    if v1 > v2:
        return "Greater"
    if v1 < v2:
        return "Less"
    return "Equal"


def _print_tests(types):
    """Print a Rust test module comparing interesting values across type pairs."""
    print(_TEST_FILE_HEADER)
    for t1 in types:
        for t2 in types:
            if t1 == t2:
                continue
            print(f"#[test] fn test_{t1}_{t2}() {{")
            values = sorted(t1.interesting_values() | t2.interesting_values())
            # Representability depends on one side only, so filter each side once
            # instead of re-checking inside the pair loop; emission order is unchanged.
            lhs_values = [v for v in values if t1.can_exactly_represent(v)]
            rhs_values = [v for v in values if t2.can_exactly_represent(v)]
            for v1 in lhs_values:
                for v2 in rhs_values:
                    lhs = f"NumOrd({v1}{t1})"
                    rhs = f"NumOrd({v2}{t2})"
                    checks = (("<", v1 < v2), ("<=", v1 <= v2),
                              (">", v1 > v2), (">=", v1 >= v2))
                    for op, result in checks:
                        print(f"    assert_eq!({lhs} {op} {rhs}, {str(result).lower()});")
                    ordering = _ordering_name(v1, v2)
                    print(f"    assert_eq!({lhs}.partial_cmp(&{rhs}), Some(Ordering::{ordering}));")
            print("}\n")


# Subcommand name -> handler; each handler receives the list from _all_types().
_COMMANDS = {
    "common-types": _print_common_types,
    "all-types": _print_all_types,
    "tests": _print_tests,
}


def main(argv=None):
    """Run the subcommand named by ``argv[0]``; return a process exit status.

    ``argv`` defaults to ``sys.argv[1:]``. Arguments after the first are
    ignored. A missing or unknown subcommand prints a usage line to stderr
    and returns 2 instead of crashing (missing) or silently doing nothing
    (unknown).
    """
    if argv is None:
        argv = sys.argv[1:]
    if not argv or argv[0] not in _COMMANDS:
        usage = "usage: gen.py {" + "|".join(_COMMANDS) + "}"
        print(usage, file=sys.stderr)
        return 2
    _COMMANDS[argv[0]](_all_types())
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))