1use serde::{Deserialize, Serialize};
37
/// Radix of every digit stored in the packed `u32`.
const TRIT_BASE: u32 = 3;

// Field widths, in trits.
const EXP_TRITS: u32 = 5;
const MANT_TRITS: u32 = 6;
const CONF_TRITS: u32 = 2;

// Field start positions (trit index, least significant first). Each field
// begins where the previous one ends, so changing a width cannot leave the
// layout overlapping or gapped. The phase (sign) occupies a single trit.
const PHASE_POS: u32 = 0;
const EXP_POS: u32 = PHASE_POS + 1;
const MANT_POS: u32 = EXP_POS + EXP_TRITS;
const CONF_POS: u32 = MANT_POS + MANT_TRITS;

/// Largest exponent magnitude in `EXP_TRITS` balanced trits: (3^5 - 1) / 2 = 121.
const EXP_MAX: i32 = ((TRIT_BASE.pow(EXP_TRITS) - 1) / 2) as i32;
/// Largest mantissa in `MANT_TRITS` unsigned trits: 3^6 - 1 = 728.
const MANT_MAX: u32 = TRIT_BASE.pow(MANT_TRITS) - 1;
/// Largest confidence magnitude in `CONF_TRITS` balanced trits: (3^2 - 1) / 2 = 4.
const CONF_MAX: i32 = ((TRIT_BASE.pow(CONF_TRITS) - 1) / 2) as i32;
/// Mantissa scale 3^6 / 2 = 364.5, so `m / MANT_DIV` spans [0, 2) for `m` in `0..=MANT_MAX`.
const MANT_DIV: f32 = TRIT_BASE.pow(MANT_TRITS) as f32 / 2.0;

/// Trits in the packed layout: phase + exponent + mantissa + confidence = 14.
const TOTAL_TRITS: u32 = CONF_POS + CONF_TRITS;

/// Largest raw value that fits in `TOTAL_TRITS` trits: 3^14 - 1.
const MAX_RAW: u32 = TRIT_BASE.pow(TOTAL_TRITS) - 1;

// 3^20 is the largest power of three below 2^32; a wider layout would not fit a u32.
const _: () = assert!(TOTAL_TRITS <= 20, "trit layout does not fit in a u32");
66
67#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
77pub struct TritFloat(u32);
78
79#[inline]
83fn get_digit(raw: u32, pos: u32) -> u32 {
84 let divisor = TRIT_BASE.pow(pos);
85 (raw / divisor) % TRIT_BASE
86}
87
88#[inline]
90fn set_digit(raw: u32, pos: u32, digit: u32) -> u32 {
91 debug_assert!(digit < 3, "digit must be in {{0,1,2}}");
92 let place = TRIT_BASE.pow(pos);
93 let cleared = raw - (raw / place % TRIT_BASE) * place;
94 cleared + digit * place
95}
96
/// Maps a balanced trit (-1, 0, +1) to its stored unsigned digit (0, 1, 2).
#[inline]
fn balanced_to_digit(t: i8) -> u32 {
    let shifted = t + 1;
    shifted as u32
}
102
/// Maps a stored unsigned digit (0, 1, 2) back to its balanced trit (-1, 0, +1).
#[inline]
fn digit_to_balanced(d: u32) -> i8 {
    let narrowed = d as i8;
    narrowed - 1
}
108
109fn decode_balanced_int(raw: u32, start_pos: u32, n_trits: u32) -> i32 {
114 let mut value = 0i32;
115 let mut place = 1i32;
116 for i in 0..n_trits {
117 let digit = get_digit(raw, start_pos + i);
118 let trit = digit_to_balanced(digit) as i32;
119 value += trit * place;
120 place *= 3;
121 }
122 value
123}
124
125fn encode_balanced_int(mut raw: u32, start_pos: u32, n_trits: u32, mut value: i32) -> u32 {
127 value = value.clamp(-((TRIT_BASE.pow(n_trits) as i32 - 1) / 2),
128 (TRIT_BASE.pow(n_trits) as i32 - 1) / 2);
129 let mut remaining = value;
131 for i in 0..n_trits {
132 let low = (remaining % 3 + 3) % 3; let trit = if low <= 1 { low as i8 } else { (low as i8) - 3 }; let digit = balanced_to_digit(trit);
136 raw = set_digit(raw, start_pos + i, digit);
137 remaining -= trit as i32;
138 remaining /= 3;
139 }
140 raw
141}
142
/// Returns `floor(log3(x))` for positive `x`.
///
/// * `x <= 0.0` and NaN return 0.
/// * `+inf` returns `i32::MAX`, via the saturating float-to-int cast.
///
/// The first estimate uses `f32::ln`, whose rounding can land just below an
/// integer at exact powers of three (e.g. giving 2 for 27). When `3^estimate`
/// is a normal `f32`, the estimate is corrected by at most one step so that
/// `3^e <= x < 3^(e + 1)` holds for the powers computed by `f32::powi`.
fn log3_floor(x: f32) -> i32 {
    if x <= 0.0 {
        return 0;
    }
    let estimate = (x.ln() / 3f32.ln()).floor() as i32;
    if !x.is_finite() {
        // NaN and +inf: keep the cast result (0 and i32::MAX respectively).
        return estimate;
    }
    let power = 3f32.powi(estimate);
    if !power.is_normal() {
        // 3^estimate under- or overflowed f32, so it cannot be compared with
        // `x` reliably; keep the logarithm estimate.
        return estimate;
    }
    if power > x {
        estimate - 1
    } else if power * 3.0 <= x {
        estimate + 1
    } else {
        estimate
    }
}
148
149impl TritFloat {
152 pub fn zero() -> Self {
156 let mut raw = 0u32;
162 for i in 0..TOTAL_TRITS {
163 raw = set_digit(raw, i, 1); }
165 Self(raw)
167 }
168
169 pub fn from_f32(x: f32) -> Self {
171 Self::from_f32_with_confidence(x, 1.0)
172 }
173
174 pub fn from_f32_with_confidence(x: f32, confidence: f32) -> Self {
176 let mut raw = Self::zero().0;
177
178 if x == 0.0 || x.is_nan() {
180 raw = set_digit(raw, PHASE_POS, 1);
183 raw = Self::encode_confidence_into(raw, confidence);
184 return Self(raw);
185 }
186
187 let phase: i8 = if x > 0.0 { 1 } else { -1 };
188 raw = set_digit(raw, PHASE_POS, balanced_to_digit(phase));
189
190 let x_abs = x.abs();
191
192 let exp = log3_floor(x_abs).clamp(-EXP_MAX, EXP_MAX);
196 raw = encode_balanced_int(raw, EXP_POS, EXP_TRITS, exp);
197
198 let scale = (3f32).powi(exp);
202 let mantissa_f = (x_abs / scale - 1.0).clamp(0.0, 1.9999);
203 let m = (mantissa_f * MANT_DIV).round().clamp(0.0, MANT_MAX as f32) as u32;
204
205 let mut m_remaining = m;
207 for i in 0..MANT_TRITS {
208 let digit = m_remaining % 3;
209 raw = set_digit(raw, MANT_POS + i, digit);
210 m_remaining /= 3;
211 }
212
213 raw = Self::encode_confidence_into(raw, confidence);
215
216 Self(raw)
217 }
218
219 fn encode_confidence_into(raw: u32, confidence: f32) -> u32 {
221 let c_int = (confidence.clamp(0.0, 1.0) * (CONF_MAX * 2) as f32).round() as i32;
223 let c_int_shifted = c_int - CONF_MAX; encode_balanced_int(raw, CONF_POS, CONF_TRITS, c_int_shifted)
226 }
227
228 pub fn to_f32(self) -> f32 {
232 let phase = digit_to_balanced(get_digit(self.0, PHASE_POS));
233 if phase == 0 {
234 return 0.0;
235 }
236
237 let exp = decode_balanced_int(self.0, EXP_POS, EXP_TRITS);
238
239 let mut m = 0u32;
241 let mut place = 1u32;
242 for i in 0..MANT_TRITS {
243 m += get_digit(self.0, MANT_POS + i) * place;
244 place *= 3;
245 }
246 let mantissa_f = m as f32 / MANT_DIV;
247
248 let scale = (3f32).powi(exp);
249 (phase as f32) * scale * (1.0 + mantissa_f)
250 }
251
252 pub fn phase(self) -> i8 {
254 digit_to_balanced(get_digit(self.0, PHASE_POS))
255 }
256
257 pub fn exponent(self) -> i32 {
259 decode_balanced_int(self.0, EXP_POS, EXP_TRITS)
260 }
261
262 pub fn mantissa(self) -> u32 {
264 let mut m = 0u32;
265 let mut place = 1u32;
266 for i in 0..MANT_TRITS {
267 m += get_digit(self.0, MANT_POS + i) * place;
268 place *= 3;
269 }
270 m
271 }
272
273 pub fn confidence(self) -> f32 {
277 let c_balanced = decode_balanced_int(self.0, CONF_POS, CONF_TRITS);
278 (c_balanced + CONF_MAX) as f32 / (CONF_MAX * 2) as f32
280 }
281
282 pub fn is_zero(self) -> bool {
284 digit_to_balanced(get_digit(self.0, PHASE_POS)) == 0
285 }
286
287 pub fn is_uncertain(self) -> bool {
289 self.confidence() < 0.5
290 }
291
292 pub fn raw(self) -> u32 {
294 self.0
295 }
296
297 pub fn from_raw(raw: u32) -> Self {
299 debug_assert!(raw <= MAX_RAW, "raw value exceeds 14-trit maximum");
300 Self(raw.min(MAX_RAW))
301 }
302
303 pub fn mul_confidence(a: Self, b: Self) -> f32 {
308 a.confidence().min(b.confidence())
309 }
310
311 pub fn add_confidence(a: Self, b: Self) -> f32 {
313 (a.confidence() + b.confidence()) * 0.5
314 }
315
316 pub fn neg(self) -> Self {
323 let new_phase = -self.phase();
324 let new_digit = balanced_to_digit(new_phase);
325 let raw = set_digit(self.0, PHASE_POS, new_digit);
326 Self(raw)
327 }
328
329 pub fn abs(self) -> Self {
331 if self.is_zero() { return self; }
332 let raw = set_digit(self.0, PHASE_POS, balanced_to_digit(1));
333 Self(raw)
334 }
335
336 pub fn add(self, rhs: Self) -> Self {
338 let value = self.to_f32() + rhs.to_f32();
339 let conf = Self::add_confidence(self, rhs);
340 Self::from_f32_with_confidence(value, conf)
341 }
342
343 pub fn sub(self, rhs: Self) -> Self {
345 self.add(rhs.neg())
346 }
347
348 pub fn mul(self, rhs: Self) -> Self {
350 if self.is_zero() || rhs.is_zero() {
354 let conf = Self::mul_confidence(self, rhs);
355 return Self::from_f32_with_confidence(0.0, conf);
356 }
357 let value = self.to_f32() * rhs.to_f32();
358 let conf = Self::mul_confidence(self, rhs);
359 Self::from_f32_with_confidence(value, conf)
360 }
361
362 pub fn dot(a: &[Self], b: &[Self]) -> Self {
367 assert_eq!(a.len(), b.len(), "dot product requires equal-length slices");
368
369 let mut acc_value = 0.0f32;
370 let mut min_conf = 1.0f32;
371 let mut skipped = 0usize;
372
373 for (&ai, &bi) in a.iter().zip(b.iter()) {
374 if ai.is_zero() || bi.is_zero() {
376 let term_conf = Self::mul_confidence(ai, bi);
378 min_conf = min_conf.min(term_conf);
379 skipped += 1;
380 continue;
381 }
382 acc_value += ai.to_f32() * bi.to_f32();
383 min_conf = min_conf.min(Self::mul_confidence(ai, bi));
384 }
385
386 let _ = skipped; Self::from_f32_with_confidence(acc_value, min_conf)
389 }
390
391 pub fn dot_with_skips(a: &[Self], b: &[Self]) -> (Self, usize) {
393 assert_eq!(a.len(), b.len(), "dot product requires equal-length slices");
394
395 let mut acc_value = 0.0f32;
396 let mut min_conf = 1.0f32;
397 let mut skipped = 0usize;
398
399 for (&ai, &bi) in a.iter().zip(b.iter()) {
400 if ai.is_zero() || bi.is_zero() {
401 let term_conf = Self::mul_confidence(ai, bi);
402 min_conf = min_conf.min(term_conf);
403 skipped += 1;
404 continue;
405 }
406 acc_value += ai.to_f32() * bi.to_f32();
407 min_conf = min_conf.min(Self::mul_confidence(ai, bi));
408 }
409
410 (Self::from_f32_with_confidence(acc_value, min_conf), skipped)
411 }
412
413 pub fn should_route(self, threshold: f32) -> bool {
422 !self.is_zero() && self.confidence() >= threshold
423 }
424
425 pub fn div(self, rhs: Self) -> Self {
430 if rhs.is_zero() {
431 return Self::from_f32_with_confidence(0.0, 0.0);
432 }
433 let conf = Self::mul_confidence(self, rhs);
434 Self::from_f32_with_confidence(self.to_f32() / rhs.to_f32(), conf)
435 }
436
437 pub fn recip(self) -> Self {
439 if self.is_zero() {
440 return Self::from_f32_with_confidence(0.0, 0.0);
441 }
442 Self::from_f32_with_confidence(1.0 / self.to_f32(), self.confidence())
443 }
444
445 pub fn powi(self, n: i32) -> Self {
447 Self::from_f32_with_confidence(self.to_f32().powi(n), self.confidence())
448 }
449
450 pub fn sqrt(self) -> Self {
452 if self.is_zero() { return self; }
453 if self.phase() < 0 {
454 return Self::from_f32_with_confidence(0.0, 0.0);
455 }
456 Self::from_f32_with_confidence(self.to_f32().sqrt(), self.confidence())
457 }
458
459 pub fn clamp(self, lo: f32, hi: f32) -> Self {
461 Self::from_f32_with_confidence(self.to_f32().clamp(lo, hi), self.confidence())
462 }
463
464 pub fn cmp_trit(self, rhs: Self) -> Self {
467 let (va, vb) = (self.to_f32(), rhs.to_f32());
468 let r = if va > vb { 1.0f32 } else if va < vb { -1.0 } else { 0.0 };
469 Self::from_f32_with_confidence(r, Self::mul_confidence(self, rhs))
470 }
471
472 pub fn softmax(slice: &[Self]) -> Vec<Self> {
480 if slice.is_empty() { return vec![]; }
481 let vals: Vec<f32> = slice.iter().map(|x| x.to_f32()).collect();
482 let max_v = vals.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
483 let exps: Vec<f32> = vals.iter().map(|&v| (v - max_v).exp()).collect();
484 let sum: f32 = exps.iter().sum::<f32>().max(f32::EPSILON);
485 let min_c = slice.iter().map(|x| x.confidence()).fold(1.0f32, f32::min);
486 exps.iter()
487 .map(|&e| Self::from_f32_with_confidence(e / sum, min_c))
488 .collect()
489 }
490
491 #[inline]
499 pub fn phase_digits(slice: &[Self]) -> Vec<u8> {
500 slice.iter().map(|x| (x.0 % 3) as u8).collect()
501 }
502
503 pub fn pack_phases_u64(slice: &[Self]) -> u64 {
510 debug_assert!(slice.len() <= 64, "pack_phases_u64: slice too long (max 64)");
511 let mut mask = 0u64;
512 for (i, x) in slice.iter().take(64).enumerate() {
513 if x.0 % 3 == 1 {
514 mask |= 1u64 << i;
515 }
516 }
517 mask
518 }
519
520 pub fn dot_prescan(a: &[Self], b: &[Self]) -> (Self, usize) {
529 assert_eq!(a.len(), b.len(), "dot_prescan requires equal-length slices");
530 let pa = Self::phase_digits(a);
531 let pb = Self::phase_digits(b);
532
533 let mut acc = 0.0f32;
534 let mut min_conf = 1.0f32;
535 let mut skipped = 0usize;
536
537 for i in 0..a.len() {
538 let c = Self::mul_confidence(a[i], b[i]);
539 if c < min_conf { min_conf = c; }
540 if pa[i] == 1 || pb[i] == 1 {
541 skipped += 1;
542 } else {
543 acc += a[i].to_f32() * b[i].to_f32();
544 }
545 }
546
547 (Self::from_f32_with_confidence(acc, min_conf), skipped)
548 }
549}
550
551impl std::fmt::Debug for TritFloat {
554 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555 write!(f, "TritFloat({:.6} conf={:.2} exp={} mant={})",
556 self.to_f32(),
557 self.confidence(),
558 self.exponent(),
559 self.mantissa(),
560 )
561 }
562}
563
564impl std::fmt::Display for TritFloat {
565 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
566 write!(f, "{:.6}±{:.0}%", self.to_f32(), self.confidence() * 100.0)
567 }
568}
569
#[cfg(test)]
mod tests {
    use super::*;

    /// Default relative tolerance for value round-trips.
    const TOL: f32 = 0.01;

    /// Compares `a` against `b` relatively.
    ///
    /// Falls back to an absolute comparison when `b` is exactly zero.
    fn approx(a: f32, b: f32, tol: f32) -> bool {
        if b == 0.0 {
            return a.abs() < tol;
        }
        ((a - b) / b).abs() < tol
    }

    #[test]
    fn zero_roundtrip() {
        let z = TritFloat::from_f32(0.0);
        assert!(z.is_zero());
        assert_eq!(z.to_f32(), 0.0);
        assert_eq!(z.phase(), 0);
    }

    #[test]
    fn negative_zero_is_zero() {
        let nz = TritFloat::from_f32(-0.0);
        assert!(nz.is_zero());
        assert_eq!(nz, TritFloat::from_f32(0.0));
    }

    #[test]
    fn nan_encodes_as_zero() {
        assert!(TritFloat::from_f32(f32::NAN).is_zero());
    }

    #[test]
    fn positive_roundtrip() {
        for &x in &[0.001f32, 0.1, 0.5, 1.0, 3.0, 9.0, 100.0, 12345.678, 1e10, 1e-10] {
            let tf = TritFloat::from_f32(x);
            let back = tf.to_f32();
            assert!(approx(back, x, TOL),
                "roundtrip failed for x={}: got {} ({})", x, back, tf);
            assert_eq!(tf.phase(), 1);
        }
    }

    #[test]
    fn negative_roundtrip() {
        for &x in &[-0.5f32, -1.0, -3.14, -999.9] {
            let tf = TritFloat::from_f32(x);
            let back = tf.to_f32();
            assert!(approx(back.abs(), x.abs(), TOL),
                "negative roundtrip failed for x={}: got {}", x, back);
            assert_eq!(tf.phase(), -1);
        }
    }

    #[test]
    fn confidence_from_f32_is_max() {
        let tf = TritFloat::from_f32(1.0);
        assert!((tf.confidence() - 1.0).abs() < 0.15,
            "from_f32 should give near-max confidence, got {}", tf.confidence());
    }

    #[test]
    fn confidence_custom() {
        let tf = TritFloat::from_f32_with_confidence(1.0, 0.0);
        assert!(tf.confidence() < 0.2, "expected low confidence, got {}", tf.confidence());

        let tf = TritFloat::from_f32_with_confidence(1.0, 0.5);
        assert!((tf.confidence() - 0.5).abs() < 0.2, "expected mid confidence, got {}", tf.confidence());
    }

    #[test]
    fn confidence_levels_roundtrip_exactly() {
        // Multiples of 1/8 are exactly representable in the confidence field.
        for k in 0..=8 {
            let c = k as f32 / 8.0;
            let tf = TritFloat::from_f32_with_confidence(1.0, c);
            assert_eq!(tf.confidence(), c, "confidence level {k}/8");
        }
    }

    #[test]
    fn confidence_is_clamped() {
        assert_eq!(TritFloat::from_f32_with_confidence(1.0, 2.0).confidence(), 1.0);
        assert_eq!(TritFloat::from_f32_with_confidence(1.0, -1.0).confidence(), 0.0);
    }

    #[test]
    fn balanced_int_roundtrip_and_clamp() {
        let base = TritFloat::zero().raw();
        for v in -EXP_MAX..=EXP_MAX {
            let raw = encode_balanced_int(base, EXP_POS, EXP_TRITS, v);
            assert_eq!(decode_balanced_int(raw, EXP_POS, EXP_TRITS), v);
        }
        let high = encode_balanced_int(base, EXP_POS, EXP_TRITS, 10_000);
        assert_eq!(decode_balanced_int(high, EXP_POS, EXP_TRITS), EXP_MAX);
        let low = encode_balanced_int(base, EXP_POS, EXP_TRITS, -10_000);
        assert_eq!(decode_balanced_int(low, EXP_POS, EXP_TRITS), -EXP_MAX);
    }

    #[test]
    fn zero_confidence_neutral() {
        let z = TritFloat::zero();
        assert!(z.is_zero());
        assert!((z.confidence() - 0.5).abs() < 0.2, "zero should have neutral confidence");
    }

    #[test]
    fn neg_flips_phase() {
        let pos = TritFloat::from_f32(2.5);
        let neg = pos.neg();
        assert_eq!(pos.phase(), 1);
        assert_eq!(neg.phase(), -1);
        assert!(approx(pos.to_f32(), -neg.to_f32(), TOL));
        assert!((pos.confidence() - neg.confidence()).abs() < 0.15);
    }

    #[test]
    fn neg_of_zero_is_zero() {
        let z = TritFloat::zero();
        assert_eq!(z.neg(), z);
    }

    #[test]
    fn abs_always_positive() {
        let neg = TritFloat::from_f32(-7.0);
        let a = neg.abs();
        assert_eq!(a.phase(), 1);
        assert!(a.to_f32() > 0.0);
    }

    #[test]
    fn mul_confidence_weakest_link() {
        let certain = TritFloat::from_f32_with_confidence(2.0, 1.0);
        let uncertain = TritFloat::from_f32_with_confidence(3.0, 0.0);
        let product = certain.mul(uncertain);
        assert!(product.confidence() < 0.2,
            "mul confidence should be dominated by uncertain operand");
    }

    #[test]
    fn mul_zero_propagates_uncertainty() {
        let zero = TritFloat::from_f32_with_confidence(0.0, 0.0);
        let certain = TritFloat::from_f32_with_confidence(5.0, 1.0);
        let product = certain.mul(zero);
        assert!(product.is_zero());
        assert!(product.confidence() < 0.2);
    }

    #[test]
    fn add_confidence_averages() {
        let a = TritFloat::from_f32_with_confidence(1.0, 1.0);
        let b = TritFloat::from_f32_with_confidence(1.0, 0.0);
        let sum = a.add(b);
        assert!((sum.confidence() - 0.5).abs() < 0.2,
            "add confidence should average, got {}", sum.confidence());
    }

    #[test]
    fn add_value_correct() {
        let a = TritFloat::from_f32(1.5);
        let b = TritFloat::from_f32(2.5);
        let sum = a.add(b);
        assert!(approx(sum.to_f32(), 4.0, TOL), "1.5 + 2.5 should ≈ 4.0, got {}", sum.to_f32());
    }

    #[test]
    fn mul_value_correct() {
        let a = TritFloat::from_f32(3.0);
        let b = TritFloat::from_f32(4.0);
        let p = a.mul(b);
        assert!(approx(p.to_f32(), 12.0, 0.02), "3 × 4 should ≈ 12, got {}", p.to_f32());
    }

    #[test]
    fn dot_basic() {
        let a: Vec<TritFloat> = [1.0f32, 2.0, 3.0].iter().map(|&x| TritFloat::from_f32(x)).collect();
        let b: Vec<TritFloat> = [4.0f32, 5.0, 6.0].iter().map(|&x| TritFloat::from_f32(x)).collect();
        let result = TritFloat::dot(&a, &b);
        assert!(approx(result.to_f32(), 32.0, 0.02),
            "dot([1,2,3],[4,5,6]) should ≈ 32, got {}", result.to_f32());
    }

    #[test]
    fn dot_matches_dot_with_skips() {
        let a: Vec<TritFloat> = [1.0f32, 0.0, 2.5, -3.0].iter().map(|&x| TritFloat::from_f32(x)).collect();
        let b: Vec<TritFloat> = [4.0f32, 5.0, 0.0, 6.0].iter().map(|&x| TritFloat::from_f32(x)).collect();
        assert_eq!(TritFloat::dot(&a, &b), TritFloat::dot_with_skips(&a, &b).0);
    }

    #[test]
    fn dot_skips_zeros() {
        let a: Vec<TritFloat> = vec![
            TritFloat::from_f32(0.0),
            TritFloat::from_f32(2.0),
            TritFloat::from_f32(0.0),
        ];
        let b: Vec<TritFloat> = vec![
            TritFloat::from_f32(1.0),
            TritFloat::from_f32(3.0),
            TritFloat::from_f32(1.0),
        ];
        let (result, skips) = TritFloat::dot_with_skips(&a, &b);
        assert_eq!(skips, 2, "two zero phases should produce 2 skips");
        assert!(approx(result.to_f32(), 6.0, 0.02),
            "0*1 + 2*3 + 0*1 = 6, got {}", result.to_f32());
    }

    #[test]
    fn should_route_confidence_gate() {
        let certain = TritFloat::from_f32_with_confidence(1.0, 0.9);
        let uncertain = TritFloat::from_f32_with_confidence(1.0, 0.1);
        let zero = TritFloat::from_f32(0.0);

        assert!(certain.should_route(0.5), "certain should route");
        assert!(!uncertain.should_route(0.5), "uncertain should not route");
        assert!(!zero.should_route(0.0), "zero phase never routes");
    }

    #[test]
    fn raw_roundtrip() {
        let tf = TritFloat::from_f32(42.0);
        let raw = tf.raw();
        let restored = TritFloat::from_raw(raw);
        assert_eq!(tf, restored);
    }

    #[test]
    fn display_shows_confidence() {
        let tf = TritFloat::from_f32(3.14);
        let s = format!("{tf}");
        assert!(s.contains('%'), "display should show confidence %: got '{}'", s);
    }

    #[test]
    fn exponent_range_covered() {
        let large = TritFloat::from_f32(1e30f32);
        let small = TritFloat::from_f32(1e-30f32);
        assert!(large.exponent().abs() <= EXP_MAX);
        assert!(small.exponent().abs() <= EXP_MAX);
        assert!(approx(large.to_f32(), 1e30, 0.05));
        assert!(approx(small.to_f32(), 1e-30, 0.05));
    }

    #[test]
    fn div_basic() {
        let a = TritFloat::from_f32(6.0);
        let b = TritFloat::from_f32(2.0);
        let r = a.div(b);
        assert!(approx(r.to_f32(), 3.0, TOL), "6/2 should be 3, got {}", r.to_f32());
    }

    #[test]
    fn div_by_zero_returns_zero_confidence() {
        let a = TritFloat::from_f32(5.0);
        let z = TritFloat::from_f32(0.0);
        let r = a.div(z);
        assert!(r.is_zero());
        assert!(r.confidence() < 0.15, "div-by-zero should have 0 confidence");
    }

    #[test]
    fn recip_basic() {
        let r = TritFloat::from_f32(4.0).recip();
        assert!(approx(r.to_f32(), 0.25, TOL), "recip(4) should be 0.25, got {}", r.to_f32());
    }

    #[test]
    fn recip_zero_returns_zero_confidence() {
        let r = TritFloat::zero().recip();
        assert!(r.is_zero());
        assert!(r.confidence() < 0.15);
    }

    #[test]
    fn powi_basic() {
        let r = TritFloat::from_f32(2.0).powi(3);
        assert!(approx(r.to_f32(), 8.0, TOL), "2^3 should be 8, got {}", r.to_f32());
    }

    #[test]
    fn powi_confidence_preserved() {
        let a = TritFloat::from_f32_with_confidence(2.0, 0.75);
        let r = a.powi(2);
        assert!((r.confidence() - 0.75).abs() < 0.15);
    }

    #[test]
    fn sqrt_basic() {
        let r = TritFloat::from_f32(9.0).sqrt();
        assert!(approx(r.to_f32(), 3.0, TOL), "sqrt(9) should be 3, got {}", r.to_f32());
    }

    #[test]
    fn sqrt_negative_returns_zero_confidence() {
        let r = TritFloat::from_f32(-4.0).sqrt();
        assert!(r.is_zero());
        assert!(r.confidence() < 0.15, "sqrt of negative should have 0 confidence");
    }

    #[test]
    fn clamp_caps_value() {
        let hi = TritFloat::from_f32(5.0).clamp(0.0, 3.0);
        assert!(approx(hi.to_f32(), 3.0, TOL), "clamp(5, 0, 3) should be 3, got {}", hi.to_f32());
        let lo = TritFloat::from_f32(-2.0).clamp(0.0, 3.0);
        assert!(approx(lo.to_f32(), 0.0, 0.01), "clamp(-2, 0, 3) should be 0");
    }

    #[test]
    fn clamp_preserves_confidence() {
        let a = TritFloat::from_f32_with_confidence(10.0, 0.625);
        let r = a.clamp(0.0, 1.0);
        assert!((r.confidence() - 0.625).abs() < 0.15);
    }

    #[test]
    fn cmp_trit_ordering() {
        let big = TritFloat::from_f32(3.0);
        let small = TritFloat::from_f32(2.0);
        assert_eq!(big.cmp_trit(small).phase(), 1, "3 > 2 should give +1");
        assert_eq!(small.cmp_trit(big).phase(), -1, "2 < 3 should give -1");
        assert_eq!(big.cmp_trit(big).phase(), 0, "x == x should give 0");
    }

    #[test]
    fn cmp_trit_confidence_is_min() {
        let a = TritFloat::from_f32_with_confidence(3.0, 1.0);
        let b = TritFloat::from_f32_with_confidence(2.0, 0.125);
        let r = a.cmp_trit(b);
        assert!(r.confidence() < 0.2, "cmp confidence should be min of inputs");
    }

    #[test]
    fn softmax_sums_to_one() {
        let vals: Vec<TritFloat> = [1.0f32, 2.0, 3.0, 0.5]
            .iter().map(|&x| TritFloat::from_f32(x)).collect();
        let sm = TritFloat::softmax(&vals);
        let sum: f32 = sm.iter().map(|x| x.to_f32()).sum();
        assert!((sum - 1.0).abs() < 1e-4, "softmax should sum to 1.0, got {sum}");
    }

    #[test]
    fn softmax_confidence_is_min_of_inputs() {
        let vals = vec![
            TritFloat::from_f32_with_confidence(1.0, 1.0),
            TritFloat::from_f32_with_confidence(2.0, 0.125),
            TritFloat::from_f32_with_confidence(3.0, 1.0),
        ];
        let sm = TritFloat::softmax(&vals);
        for s in &sm {
            assert!(s.confidence() < 0.2,
                "softmax conf should be min of inputs (0.125), got {}", s.confidence());
        }
    }

    #[test]
    fn softmax_empty_slice() {
        assert_eq!(TritFloat::softmax(&[]).len(), 0);
    }

    #[test]
    fn pack_phases_u64_correctness() {
        let vals: Vec<TritFloat> = [1.0f32, 0.0, -1.0, 0.0, 2.0]
            .iter().map(|&x| TritFloat::from_f32(x)).collect();
        let mask = TritFloat::pack_phases_u64(&vals);
        assert_eq!(mask & 1, 0, "index 0 (1.0) should not be zero-phase");
        assert_eq!(mask & 2, 2, "index 1 (0.0) should be zero-phase");
        assert_eq!(mask & 4, 0, "index 2 (-1.0) should not be zero-phase");
        assert_eq!(mask & 8, 8, "index 3 (0.0) should be zero-phase");
        assert_eq!(mask & 16, 0, "index 4 (2.0) should not be zero-phase");
        assert_eq!(mask.count_ones(), 2);
    }

    #[test]
    fn dot_prescan_matches_dot_with_skips() {
        let a: Vec<TritFloat> = [1.0f32, 0.0, 2.0, 0.0, 3.0]
            .iter().map(|&x| TritFloat::from_f32(x)).collect();
        let b: Vec<TritFloat> = [4.0f32, 5.0, 0.0, 6.0, 7.0]
            .iter().map(|&x| TritFloat::from_f32(x)).collect();

        let (r1, s1) = TritFloat::dot_with_skips(&a, &b);
        let (r2, s2) = TritFloat::dot_prescan(&a, &b);

        assert!(approx(r1.to_f32(), r2.to_f32(), 0.001),
            "prescan and dot_with_skips should match: {} vs {}", r1.to_f32(), r2.to_f32());
        assert_eq!(s1, s2, "skip counts should match: {s1} vs {s2}");
    }

    #[test]
    fn phase_digits_correct() {
        let vals: Vec<TritFloat> = [-1.0f32, 0.0, 1.0]
            .iter().map(|&x| TritFloat::from_f32(x)).collect();
        let pd = TritFloat::phase_digits(&vals);
        assert_eq!(pd[0], 0, "neg phase → digit 0");
        assert_eq!(pd[1], 1, "zero phase → digit 1");
        assert_eq!(pd[2], 2, "pos phase → digit 2");
    }
}
922}