use serde::{Deserialize, Serialize};
/// Radix of every digit stored in a `TritFloat`.
const TRIT_BASE: u32 = 3;

// Number of trits allotted to each field.
const EXP_TRITS: u32 = 5;
const MANT_TRITS: u32 = 6;
const CONF_TRITS: u32 = 2;

// Field layout, least-significant trit first:
//   phase (1 trit) | exponent (EXP_TRITS) | mantissa (MANT_TRITS) | confidence (CONF_TRITS)
// Positions are derived from the widths so the layout cannot drift.
const PHASE_POS: u32 = 0;
const EXP_POS: u32 = PHASE_POS + 1;
const MANT_POS: u32 = EXP_POS + EXP_TRITS;
const CONF_POS: u32 = MANT_POS + MANT_TRITS;

/// Largest balanced-ternary magnitude in the exponent field: (3^5 - 1) / 2 = 121.
const EXP_MAX: i32 = ((TRIT_BASE.pow(EXP_TRITS) - 1) / 2) as i32;
/// Largest unsigned mantissa value: 3^6 - 1 = 728.
const MANT_MAX: u32 = TRIT_BASE.pow(MANT_TRITS) - 1;
/// Largest balanced-ternary magnitude in the confidence field: (3^2 - 1) / 2 = 4.
const CONF_MAX: i32 = ((TRIT_BASE.pow(CONF_TRITS) - 1) / 2) as i32;
/// Mantissa divisor `(MANT_MAX + 1) / 2 = 364.5`, so `1 + m / MANT_DIV` spans [1, 3).
const MANT_DIV: f32 = (MANT_MAX + 1) as f32 / 2.0;

/// Total number of trits in the packed word.
const TOTAL_TRITS: u32 = CONF_POS + CONF_TRITS;
/// Largest valid raw word: 3^TOTAL_TRITS - 1 (every digit equal to 2).
const MAX_RAW: u32 = TRIT_BASE.pow(TOTAL_TRITS) - 1;

// 3^21 exceeds u32::MAX, so at most 20 trits fit in the backing integer.
const _: () = assert!(TOTAL_TRITS <= 20);
/// A compact ternary floating-point value: 14 base-3 digits packed into a `u32`.
///
/// The value decodes as `phase · 3^exponent · (1 + mantissa / 364.5)`, with a
/// separate confidence field. The trit layout is defined by the `*_POS` /
/// `*_TRITS` constants. Equality and hashing compare the raw packed word.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct TritFloat(u32);
/// Reads the base-3 digit (0, 1 or 2) of `raw` at trit position `pos`.
/// Position 0 is the least-significant trit.
#[inline]
fn get_digit(raw: u32, pos: u32) -> u32 {
    let shifted = raw / TRIT_BASE.pow(pos);
    shifted % TRIT_BASE
}
/// Returns `raw` with the base-3 digit at position `pos` replaced by `digit`.
///
/// `digit` must be 0, 1 or 2; this is checked only in debug builds.
#[inline]
fn set_digit(raw: u32, pos: u32, digit: u32) -> u32 {
    debug_assert!(digit < 3, "digit must be in {{0,1,2}}");
    let weight = TRIT_BASE.pow(pos);
    let old_digit = (raw / weight) % TRIT_BASE;
    raw - old_digit * weight + digit * weight
}
/// Maps a balanced trit (-1, 0, +1) to its stored digit (0, 1, 2).
#[inline]
fn balanced_to_digit(t: i8) -> u32 {
    let shifted: i8 = t + 1;
    shifted as u32
}
/// Maps a stored digit (0, 1, 2) back to its balanced trit (-1, 0, +1).
#[inline]
fn digit_to_balanced(d: u32) -> i8 {
    let narrowed = d as i8;
    narrowed - 1
}
/// Decodes `n_trits` consecutive balanced trits of `raw`, starting at
/// `start_pos` (least-significant first), into a signed integer.
fn decode_balanced_int(raw: u32, start_pos: u32, n_trits: u32) -> i32 {
    let (value, _) = (0..n_trits)
        .map(|offset| i32::from(digit_to_balanced(get_digit(raw, start_pos + offset))))
        .fold((0i32, 1i32), |(sum, weight), trit| {
            (sum + trit * weight, weight * 3)
        });
    value
}
/// Writes `value` as `n_trits` balanced trits of `raw`, starting at
/// `start_pos` (least-significant first), and returns the updated word.
///
/// `value` saturates to the representable range `±(3^n_trits - 1) / 2`.
fn encode_balanced_int(raw: u32, start_pos: u32, n_trits: u32, value: i32) -> u32 {
    let limit = (TRIT_BASE.pow(n_trits) as i32 - 1) / 2;
    let mut rest = value.clamp(-limit, limit);
    let mut out = raw;
    for offset in 0..n_trits {
        // Residue 2 is expressed as -1 with a carry into the next trit.
        let trit: i8 = match rest.rem_euclid(3) {
            0 => 0,
            1 => 1,
            _ => -1,
        };
        out = set_digit(out, start_pos + offset, balanced_to_digit(trit));
        rest = (rest - i32::from(trit)) / 3;
    }
    out
}
/// Returns `floor(log3(x))` for positive `x`, and 0 when `x <= 0`.
///
/// A NaN input also yields 0, because a NaN `floor` casts to 0. `+inf`
/// saturates to `i32::MAX` through the float-to-int cast.
fn log3_floor(x: f32) -> i32 {
    match x {
        v if v <= 0.0 => 0,
        v => {
            let log3 = v.ln() / 3f32.ln();
            log3.floor() as i32
        }
    }
}
impl TritFloat {
pub fn zero() -> Self {
let mut raw = 0u32;
for i in 0..TOTAL_TRITS {
raw = set_digit(raw, i, 1); }
Self(raw)
}
pub fn from_f32(x: f32) -> Self {
Self::from_f32_with_confidence(x, 1.0)
}
pub fn from_f32_with_confidence(x: f32, confidence: f32) -> Self {
let mut raw = Self::zero().0;
if x == 0.0 || x.is_nan() {
raw = set_digit(raw, PHASE_POS, 1);
raw = Self::encode_confidence_into(raw, confidence);
return Self(raw);
}
let phase: i8 = if x > 0.0 { 1 } else { -1 };
raw = set_digit(raw, PHASE_POS, balanced_to_digit(phase));
let x_abs = x.abs();
let exp = log3_floor(x_abs).clamp(-EXP_MAX, EXP_MAX);
raw = encode_balanced_int(raw, EXP_POS, EXP_TRITS, exp);
let scale = (3f32).powi(exp);
let mantissa_f = (x_abs / scale - 1.0).clamp(0.0, 1.9999);
let m = (mantissa_f * MANT_DIV).round().clamp(0.0, MANT_MAX as f32) as u32;
let mut m_remaining = m;
for i in 0..MANT_TRITS {
let digit = m_remaining % 3;
raw = set_digit(raw, MANT_POS + i, digit);
m_remaining /= 3;
}
raw = Self::encode_confidence_into(raw, confidence);
Self(raw)
}
fn encode_confidence_into(raw: u32, confidence: f32) -> u32 {
let c_int = (confidence.clamp(0.0, 1.0) * (CONF_MAX * 2) as f32).round() as i32;
let c_int_shifted = c_int - CONF_MAX; encode_balanced_int(raw, CONF_POS, CONF_TRITS, c_int_shifted)
}
pub fn to_f32(self) -> f32 {
let phase = digit_to_balanced(get_digit(self.0, PHASE_POS));
if phase == 0 {
return 0.0;
}
let exp = decode_balanced_int(self.0, EXP_POS, EXP_TRITS);
let mut m = 0u32;
let mut place = 1u32;
for i in 0..MANT_TRITS {
m += get_digit(self.0, MANT_POS + i) * place;
place *= 3;
}
let mantissa_f = m as f32 / MANT_DIV;
let scale = (3f32).powi(exp);
(phase as f32) * scale * (1.0 + mantissa_f)
}
pub fn phase(self) -> i8 {
digit_to_balanced(get_digit(self.0, PHASE_POS))
}
pub fn exponent(self) -> i32 {
decode_balanced_int(self.0, EXP_POS, EXP_TRITS)
}
pub fn mantissa(self) -> u32 {
let mut m = 0u32;
let mut place = 1u32;
for i in 0..MANT_TRITS {
m += get_digit(self.0, MANT_POS + i) * place;
place *= 3;
}
m
}
pub fn confidence(self) -> f32 {
let c_balanced = decode_balanced_int(self.0, CONF_POS, CONF_TRITS);
(c_balanced + CONF_MAX) as f32 / (CONF_MAX * 2) as f32
}
pub fn is_zero(self) -> bool {
digit_to_balanced(get_digit(self.0, PHASE_POS)) == 0
}
pub fn is_uncertain(self) -> bool {
self.confidence() < 0.5
}
pub fn raw(self) -> u32 {
self.0
}
pub fn from_raw(raw: u32) -> Self {
debug_assert!(raw <= MAX_RAW, "raw value exceeds 14-trit maximum");
Self(raw.min(MAX_RAW))
}
pub fn mul_confidence(a: Self, b: Self) -> f32 {
a.confidence().min(b.confidence())
}
pub fn add_confidence(a: Self, b: Self) -> f32 {
(a.confidence() + b.confidence()) * 0.5
}
pub fn neg(self) -> Self {
let new_phase = -self.phase();
let new_digit = balanced_to_digit(new_phase);
let raw = set_digit(self.0, PHASE_POS, new_digit);
Self(raw)
}
pub fn abs(self) -> Self {
if self.is_zero() { return self; }
let raw = set_digit(self.0, PHASE_POS, balanced_to_digit(1));
Self(raw)
}
pub fn add(self, rhs: Self) -> Self {
let value = self.to_f32() + rhs.to_f32();
let conf = Self::add_confidence(self, rhs);
Self::from_f32_with_confidence(value, conf)
}
pub fn sub(self, rhs: Self) -> Self {
self.add(rhs.neg())
}
pub fn mul(self, rhs: Self) -> Self {
if self.is_zero() || rhs.is_zero() {
let conf = Self::mul_confidence(self, rhs);
return Self::from_f32_with_confidence(0.0, conf);
}
let value = self.to_f32() * rhs.to_f32();
let conf = Self::mul_confidence(self, rhs);
Self::from_f32_with_confidence(value, conf)
}
pub fn dot(a: &[Self], b: &[Self]) -> Self {
assert_eq!(a.len(), b.len(), "dot product requires equal-length slices");
let mut acc_value = 0.0f32;
let mut min_conf = 1.0f32;
let mut skipped = 0usize;
for (&ai, &bi) in a.iter().zip(b.iter()) {
if ai.is_zero() || bi.is_zero() {
let term_conf = Self::mul_confidence(ai, bi);
min_conf = min_conf.min(term_conf);
skipped += 1;
continue;
}
acc_value += ai.to_f32() * bi.to_f32();
min_conf = min_conf.min(Self::mul_confidence(ai, bi));
}
let _ = skipped;
Self::from_f32_with_confidence(acc_value, min_conf)
}
pub fn dot_with_skips(a: &[Self], b: &[Self]) -> (Self, usize) {
assert_eq!(a.len(), b.len(), "dot product requires equal-length slices");
let mut acc_value = 0.0f32;
let mut min_conf = 1.0f32;
let mut skipped = 0usize;
for (&ai, &bi) in a.iter().zip(b.iter()) {
if ai.is_zero() || bi.is_zero() {
let term_conf = Self::mul_confidence(ai, bi);
min_conf = min_conf.min(term_conf);
skipped += 1;
continue;
}
acc_value += ai.to_f32() * bi.to_f32();
min_conf = min_conf.min(Self::mul_confidence(ai, bi));
}
(Self::from_f32_with_confidence(acc_value, min_conf), skipped)
}
pub fn should_route(self, threshold: f32) -> bool {
!self.is_zero() && self.confidence() >= threshold
}
pub fn div(self, rhs: Self) -> Self {
if rhs.is_zero() {
return Self::from_f32_with_confidence(0.0, 0.0);
}
let conf = Self::mul_confidence(self, rhs);
Self::from_f32_with_confidence(self.to_f32() / rhs.to_f32(), conf)
}
pub fn recip(self) -> Self {
if self.is_zero() {
return Self::from_f32_with_confidence(0.0, 0.0);
}
Self::from_f32_with_confidence(1.0 / self.to_f32(), self.confidence())
}
pub fn powi(self, n: i32) -> Self {
Self::from_f32_with_confidence(self.to_f32().powi(n), self.confidence())
}
pub fn sqrt(self) -> Self {
if self.is_zero() { return self; }
if self.phase() < 0 {
return Self::from_f32_with_confidence(0.0, 0.0);
}
Self::from_f32_with_confidence(self.to_f32().sqrt(), self.confidence())
}
pub fn clamp(self, lo: f32, hi: f32) -> Self {
Self::from_f32_with_confidence(self.to_f32().clamp(lo, hi), self.confidence())
}
pub fn cmp_trit(self, rhs: Self) -> Self {
let (va, vb) = (self.to_f32(), rhs.to_f32());
let r = if va > vb { 1.0f32 } else if va < vb { -1.0 } else { 0.0 };
Self::from_f32_with_confidence(r, Self::mul_confidence(self, rhs))
}
pub fn softmax(slice: &[Self]) -> Vec<Self> {
if slice.is_empty() { return vec![]; }
let vals: Vec<f32> = slice.iter().map(|x| x.to_f32()).collect();
let max_v = vals.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = vals.iter().map(|&v| (v - max_v).exp()).collect();
let sum: f32 = exps.iter().sum::<f32>().max(f32::EPSILON);
let min_c = slice.iter().map(|x| x.confidence()).fold(1.0f32, f32::min);
exps.iter()
.map(|&e| Self::from_f32_with_confidence(e / sum, min_c))
.collect()
}
#[inline]
pub fn phase_digits(slice: &[Self]) -> Vec<u8> {
slice.iter().map(|x| (x.0 % 3) as u8).collect()
}
pub fn pack_phases_u64(slice: &[Self]) -> u64 {
debug_assert!(slice.len() <= 64, "pack_phases_u64: slice too long (max 64)");
let mut mask = 0u64;
for (i, x) in slice.iter().take(64).enumerate() {
if x.0 % 3 == 1 {
mask |= 1u64 << i;
}
}
mask
}
pub fn dot_prescan(a: &[Self], b: &[Self]) -> (Self, usize) {
assert_eq!(a.len(), b.len(), "dot_prescan requires equal-length slices");
let pa = Self::phase_digits(a);
let pb = Self::phase_digits(b);
let mut acc = 0.0f32;
let mut min_conf = 1.0f32;
let mut skipped = 0usize;
for i in 0..a.len() {
let c = Self::mul_confidence(a[i], b[i]);
if c < min_conf { min_conf = c; }
if pa[i] == 1 || pb[i] == 1 {
skipped += 1;
} else {
acc += a[i].to_f32() * b[i].to_f32();
}
}
(Self::from_f32_with_confidence(acc, min_conf), skipped)
}
}
impl std::fmt::Debug for TritFloat {
    /// Shows the decoded value alongside the confidence and the raw exponent/mantissa fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let value = self.to_f32();
        let conf = self.confidence();
        let (exp, mant) = (self.exponent(), self.mantissa());
        write!(f, "TritFloat({:.6} conf={:.2} exp={} mant={})", value, conf, exp, mant)
    }
}
impl std::fmt::Display for TritFloat {
    /// Renders the decoded value to six decimals followed by the confidence as a whole percentage.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let value = self.to_f32();
        let percent = self.confidence() * 100.0;
        write!(f, "{:.6}±{:.0}%", value, percent)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default relative tolerance for round-trip comparisons.
    const TOL: f32 = 0.01;

    /// Relative comparison of `a` against `b`; uses an absolute check when `b` is exactly zero.
    fn approx(a: f32, b: f32, tol: f32) -> bool {
        if b == 0.0 {
            return a.abs() < tol;
        }
        ((a - b) / b).abs() < tol
    }

    #[test]
    fn zero_roundtrip() {
        let z = TritFloat::from_f32(0.0);
        assert!(z.is_zero());
        assert_eq!(z.to_f32(), 0.0);
        assert_eq!(z.phase(), 0);
    }

    #[test]
    fn positive_roundtrip() {
        for &x in &[0.001f32, 0.1, 0.5, 1.0, 3.0, 9.0, 100.0, 12345.678, 1e10, 1e-10] {
            let tf = TritFloat::from_f32(x);
            let back = tf.to_f32();
            assert!(approx(back, x, TOL),
                "roundtrip failed for x={}: got {} ({})", x, back, tf);
            assert_eq!(tf.phase(), 1);
        }
    }

    #[test]
    fn negative_roundtrip() {
        // 3.25 rather than 3.14: clippy's deny-by-default `approx_constant` rejects PI look-alikes.
        for &x in &[-0.5f32, -1.0, -3.25, -999.9] {
            let tf = TritFloat::from_f32(x);
            let back = tf.to_f32();
            assert!(approx(back.abs(), x.abs(), TOL),
                "negative roundtrip failed for x={}: got {}", x, back);
            assert_eq!(tf.phase(), -1);
        }
    }

    #[test]
    fn confidence_from_f32_is_max() {
        let tf = TritFloat::from_f32(1.0);
        assert!((tf.confidence() - 1.0).abs() < 0.15,
            "from_f32 should give near-max confidence, got {}", tf.confidence());
    }

    #[test]
    fn confidence_custom() {
        let tf = TritFloat::from_f32_with_confidence(1.0, 0.0);
        assert!(tf.confidence() < 0.2, "expected low confidence, got {}", tf.confidence());
        let tf = TritFloat::from_f32_with_confidence(1.0, 0.5);
        assert!((tf.confidence() - 0.5).abs() < 0.2, "expected mid confidence, got {}", tf.confidence());
    }

    #[test]
    fn zero_confidence_neutral() {
        let z = TritFloat::zero();
        assert!(z.is_zero());
        assert!((z.confidence() - 0.5).abs() < 0.2, "zero should have neutral confidence");
    }

    #[test]
    fn neg_flips_phase() {
        let pos = TritFloat::from_f32(2.5);
        let neg = pos.neg();
        assert_eq!(pos.phase(), 1);
        assert_eq!(neg.phase(), -1);
        assert!(approx(pos.to_f32(), -neg.to_f32(), TOL));
        assert!((pos.confidence() - neg.confidence()).abs() < 0.15);
    }

    #[test]
    fn abs_always_positive() {
        let neg = TritFloat::from_f32(-7.0);
        let a = neg.abs();
        assert_eq!(a.phase(), 1);
        assert!(a.to_f32() > 0.0);
    }

    #[test]
    fn mul_confidence_weakest_link() {
        let certain = TritFloat::from_f32_with_confidence(2.0, 1.0);
        let uncertain = TritFloat::from_f32_with_confidence(3.0, 0.0);
        let product = certain.mul(uncertain);
        assert!(product.confidence() < 0.2,
            "mul confidence should be dominated by uncertain operand");
    }

    #[test]
    fn mul_zero_propagates_uncertainty() {
        let zero = TritFloat::from_f32_with_confidence(0.0, 0.0);
        let certain = TritFloat::from_f32_with_confidence(5.0, 1.0);
        let product = certain.mul(zero);
        assert!(product.is_zero());
        assert!(product.confidence() < 0.2);
    }

    #[test]
    fn add_confidence_averages() {
        let a = TritFloat::from_f32_with_confidence(1.0, 1.0);
        let b = TritFloat::from_f32_with_confidence(1.0, 0.0);
        let sum = a.add(b);
        assert!((sum.confidence() - 0.5).abs() < 0.2,
            "add confidence should average, got {}", sum.confidence());
    }

    #[test]
    fn add_value_correct() {
        let a = TritFloat::from_f32(1.5);
        let b = TritFloat::from_f32(2.5);
        let sum = a.add(b);
        assert!(approx(sum.to_f32(), 4.0, TOL), "1.5 + 2.5 should ≈ 4.0, got {}", sum.to_f32());
    }

    #[test]
    fn mul_value_correct() {
        let a = TritFloat::from_f32(3.0);
        let b = TritFloat::from_f32(4.0);
        let p = a.mul(b);
        assert!(approx(p.to_f32(), 12.0, 0.02), "3 × 4 should ≈ 12, got {}", p.to_f32());
    }

    #[test]
    fn dot_basic() {
        let a: Vec<TritFloat> = [1.0f32, 2.0, 3.0].iter().map(|&x| TritFloat::from_f32(x)).collect();
        let b: Vec<TritFloat> = [4.0f32, 5.0, 6.0].iter().map(|&x| TritFloat::from_f32(x)).collect();
        let result = TritFloat::dot(&a, &b);
        assert!(approx(result.to_f32(), 32.0, 0.02),
            "dot([1,2,3],[4,5,6]) should ≈ 32, got {}", result.to_f32());
    }

    #[test]
    fn dot_agrees_with_dot_with_skips() {
        let a: Vec<TritFloat> = [1.0f32, 0.0, -2.0].iter().map(|&x| TritFloat::from_f32(x)).collect();
        let b: Vec<TritFloat> = [0.5f32, 7.0, 4.0].iter().map(|&x| TritFloat::from_f32(x)).collect();
        let (with_skips, skips) = TritFloat::dot_with_skips(&a, &b);
        assert_eq!(TritFloat::dot(&a, &b), with_skips);
        assert_eq!(skips, 1);
    }

    #[test]
    #[should_panic(expected = "dot product requires equal-length slices")]
    fn dot_rejects_mismatched_lengths() {
        let a = [TritFloat::from_f32(1.0)];
        let _ = TritFloat::dot(&a, &[]);
    }

    #[test]
    fn dot_skips_zeros() {
        let a = [
            TritFloat::from_f32(0.0),
            TritFloat::from_f32(2.0),
            TritFloat::from_f32(0.0),
        ];
        let b = [
            TritFloat::from_f32(1.0),
            TritFloat::from_f32(3.0),
            TritFloat::from_f32(1.0),
        ];
        let (result, skips) = TritFloat::dot_with_skips(&a, &b);
        assert_eq!(skips, 2, "two zero phases should produce 2 skips");
        assert!(approx(result.to_f32(), 6.0, 0.02),
            "0*1 + 2*3 + 0*1 = 6, got {}", result.to_f32());
    }

    #[test]
    fn should_route_confidence_gate() {
        let certain = TritFloat::from_f32_with_confidence(1.0, 0.9);
        let uncertain = TritFloat::from_f32_with_confidence(1.0, 0.1);
        let zero = TritFloat::from_f32(0.0);
        assert!(certain.should_route(0.5), "certain should route");
        assert!(!uncertain.should_route(0.5), "uncertain should not route");
        assert!(!zero.should_route(0.0), "zero phase never routes");
    }

    #[test]
    fn raw_roundtrip() {
        let tf = TritFloat::from_f32(42.0);
        let raw = tf.raw();
        let restored = TritFloat::from_raw(raw);
        assert_eq!(tf, restored);
    }

    #[test]
    fn display_shows_confidence() {
        let tf = TritFloat::from_f32(3.25);
        let s = format!("{tf}");
        assert!(s.contains('%'), "display should show confidence %: got '{}'", s);
    }

    #[test]
    fn exponent_range_covered() {
        let large = TritFloat::from_f32(1e30f32);
        let small = TritFloat::from_f32(1e-30f32);
        assert!(large.exponent().abs() <= EXP_MAX);
        assert!(small.exponent().abs() <= EXP_MAX);
        assert!(approx(large.to_f32(), 1e30, 0.05));
        assert!(approx(small.to_f32(), 1e-30, 0.05));
    }

    #[test]
    fn div_basic() {
        let a = TritFloat::from_f32(6.0);
        let b = TritFloat::from_f32(2.0);
        let r = a.div(b);
        assert!(approx(r.to_f32(), 3.0, TOL), "6/2 should be 3, got {}", r.to_f32());
    }

    #[test]
    fn div_by_zero_returns_zero_confidence() {
        let a = TritFloat::from_f32(5.0);
        let z = TritFloat::from_f32(0.0);
        let r = a.div(z);
        assert!(r.is_zero());
        assert!(r.confidence() < 0.15, "div-by-zero should have 0 confidence");
    }

    #[test]
    fn recip_basic() {
        let r = TritFloat::from_f32(4.0).recip();
        assert!(approx(r.to_f32(), 0.25, TOL), "recip(4) should be 0.25, got {}", r.to_f32());
    }

    #[test]
    fn recip_zero_returns_zero_confidence() {
        let r = TritFloat::zero().recip();
        assert!(r.is_zero());
        assert!(r.confidence() < 0.15);
    }

    #[test]
    fn powi_basic() {
        let r = TritFloat::from_f32(2.0).powi(3);
        assert!(approx(r.to_f32(), 8.0, TOL), "2^3 should be 8, got {}", r.to_f32());
    }

    #[test]
    fn powi_confidence_preserved() {
        let a = TritFloat::from_f32_with_confidence(2.0, 0.75);
        let r = a.powi(2);
        assert!((r.confidence() - 0.75).abs() < 0.15);
    }

    #[test]
    fn sqrt_basic() {
        let r = TritFloat::from_f32(9.0).sqrt();
        assert!(approx(r.to_f32(), 3.0, TOL), "sqrt(9) should be 3, got {}", r.to_f32());
    }

    #[test]
    fn sqrt_negative_returns_zero_confidence() {
        let r = TritFloat::from_f32(-4.0).sqrt();
        assert!(r.is_zero());
        assert!(r.confidence() < 0.15, "sqrt of negative should have 0 confidence");
    }

    #[test]
    fn clamp_caps_value() {
        let hi = TritFloat::from_f32(5.0).clamp(0.0, 3.0);
        assert!(approx(hi.to_f32(), 3.0, TOL), "clamp(5, 0, 3) should be 3, got {}", hi.to_f32());
        let lo = TritFloat::from_f32(-2.0).clamp(0.0, 3.0);
        assert!(approx(lo.to_f32(), 0.0, 0.01), "clamp(-2, 0, 3) should be 0");
    }

    #[test]
    fn clamp_preserves_confidence() {
        let a = TritFloat::from_f32_with_confidence(10.0, 0.625);
        let r = a.clamp(0.0, 1.0);
        assert!((r.confidence() - 0.625).abs() < 0.15);
    }

    #[test]
    fn cmp_trit_ordering() {
        let big = TritFloat::from_f32(3.0);
        let small = TritFloat::from_f32(2.0);
        assert_eq!(big.cmp_trit(small).phase(), 1, "3 > 2 should give +1");
        assert_eq!(small.cmp_trit(big).phase(), -1, "2 < 3 should give -1");
        assert_eq!(big.cmp_trit(big).phase(), 0, "x == x should give 0");
    }

    #[test]
    fn cmp_trit_confidence_is_min() {
        let a = TritFloat::from_f32_with_confidence(3.0, 1.0);
        let b = TritFloat::from_f32_with_confidence(2.0, 0.125);
        let r = a.cmp_trit(b);
        assert!(r.confidence() < 0.2, "cmp confidence should be min of inputs");
    }

    #[test]
    fn softmax_sums_to_one() {
        let vals: Vec<TritFloat> = [1.0f32, 2.0, 3.0, 0.5]
            .iter().map(|&x| TritFloat::from_f32(x)).collect();
        let sm = TritFloat::softmax(&vals);
        let sum: f32 = sm.iter().map(|x| x.to_f32()).sum();
        assert!((sum - 1.0).abs() < 1e-4, "softmax should sum to 1.0, got {sum}");
    }

    #[test]
    fn softmax_confidence_is_min_of_inputs() {
        let vals = [
            TritFloat::from_f32_with_confidence(1.0, 1.0),
            TritFloat::from_f32_with_confidence(2.0, 0.125),
            TritFloat::from_f32_with_confidence(3.0, 1.0),
        ];
        let sm = TritFloat::softmax(&vals);
        for s in &sm {
            assert!(s.confidence() < 0.2,
                "softmax conf should be min of inputs (0.125), got {}", s.confidence());
        }
    }

    #[test]
    fn softmax_empty_slice() {
        assert_eq!(TritFloat::softmax(&[]).len(), 0);
    }

    #[test]
    fn pack_phases_u64_correctness() {
        let vals: Vec<TritFloat> = [1.0f32, 0.0, -1.0, 0.0, 2.0]
            .iter().map(|&x| TritFloat::from_f32(x)).collect();
        let mask = TritFloat::pack_phases_u64(&vals);
        assert_eq!(mask & 1, 0, "index 0 (1.0) should not be zero-phase");
        assert_eq!(mask & 2, 2, "index 1 (0.0) should be zero-phase");
        assert_eq!(mask & 4, 0, "index 2 (-1.0) should not be zero-phase");
        assert_eq!(mask & 8, 8, "index 3 (0.0) should be zero-phase");
        assert_eq!(mask & 16, 0, "index 4 (2.0) should not be zero-phase");
        assert_eq!(mask.count_ones(), 2);
    }

    #[test]
    fn dot_prescan_matches_dot_with_skips() {
        let a: Vec<TritFloat> = [1.0f32, 0.0, 2.0, 0.0, 3.0]
            .iter().map(|&x| TritFloat::from_f32(x)).collect();
        let b: Vec<TritFloat> = [4.0f32, 5.0, 0.0, 6.0, 7.0]
            .iter().map(|&x| TritFloat::from_f32(x)).collect();
        let (r1, s1) = TritFloat::dot_with_skips(&a, &b);
        let (r2, s2) = TritFloat::dot_prescan(&a, &b);
        assert!(approx(r1.to_f32(), r2.to_f32(), 0.001),
            "prescan and dot_with_skips should match: {} vs {}", r1.to_f32(), r2.to_f32());
        assert_eq!(s1, s2, "skip counts should match: {s1} vs {s2}");
    }

    #[test]
    fn phase_digits_correct() {
        let vals: Vec<TritFloat> = [-1.0f32, 0.0, 1.0]
            .iter().map(|&x| TritFloat::from_f32(x)).collect();
        let pd = TritFloat::phase_digits(&vals);
        assert_eq!(pd[0], 0, "neg phase → digit 0");
        assert_eq!(pd[1], 1, "zero phase → digit 1");
        assert_eq!(pd[2], 2, "pos phase → digit 2");
    }
}