Skip to main content

ternlang_ml/
tritfloat.rs

1// SPDX-License-Identifier: LicenseRef-Ternlang-Commercial
2// Copyright (C) 2026 RFI-IRFOS. All rights reserved.
3//
4// TritFloat — a floating-point number built on balanced ternary trits.
5//
6// FORMAT (14 trits, stored as u32 in base-3):
7//
8//   [ phase: 1t ][ exponent: 5t ][ mantissa: 6t ][ confidence: 2t ]
9//      {-,0,+}    bal. ternary    base-3 frac.     bal. ternary
10//                 [-121, +121]    [0, 728]          [-4, +4] (9 states)
11//
12//   Total: 14 trits ≈ 22.2 bits of information capacity
13//   Storage: u32, value = sum(digit_i * 3^i), digits in {0,1,2}
14//
15// VALUE SEMANTICS:
16//   - Phase=0 (digit=1) → zero; exponent and mantissa are irrelevant
17//   - value = phase × 3^exponent × (1 + mantissa/364.5)   [covers [1,3) normalized range]
18//   - Exponent range ±121 covers f32 range comfortably (max f32 ≈ 3^80)
19//
20// CONFIDENCE FIELD:
21//   The 2 confidence trits encode certainty about the value on a 9-state scale.
22//   This is the key innovation: confidence is a first-class field in the number,
23//   not a separate tensor. It propagates through arithmetic automatically.
24//
25//   c_digit = (c1_trit + 1) * 3 + (c0_trit + 1), range [0, 8]
26//   Normalized to [0.0, 1.0] as c_digit / 8.0
27//
28//   0/8 = completely unknown    (both trits -1)
29//   4/8 = neutral / unset       (both trits 0)
30//   8/8 = maximally certain     (both trits +1)
31//
32// CONFIDENCE PROPAGATION RULES:
33//   mul(a, b): c = min(conf_a, conf_b)     — chain weakest link
34//   add(a, b): c = (conf_a + conf_b) / 2  — average the evidence
35
36use serde::{Deserialize, Serialize};
37
38// ─── Constants ────────────────────────────────────────────────────────────────
39
const TRIT_BASE: u32 = 3;

// Field widths (number of trits). Every position, maximum and total below is
// derived from these so the packed layout cannot drift out of sync.
const EXP_TRITS: u32 = 5;
const MANT_TRITS: u32 = 6;
const CONF_TRITS: u32 = 2;

// Position of each field's least-significant trit in the base-3 u32 encoding.
const PHASE_POS: u32 = 0; // trit 0
const EXP_POS: u32 = PHASE_POS + 1; // trits 1-5
const MANT_POS: u32 = EXP_POS + EXP_TRITS; // trits 6-11
const CONF_POS: u32 = MANT_POS + MANT_TRITS; // trits 12-13

// Field maxima.
// Balanced fields span ±(3^n - 1) / 2; the unbalanced mantissa spans [0, 3^n - 1].
const EXP_MAX: i32 = (TRIT_BASE.pow(EXP_TRITS) as i32 - 1) / 2; // 121
const MANT_MAX: u32 = TRIT_BASE.pow(MANT_TRITS) - 1; // 728
const CONF_MAX: i32 = (TRIT_BASE.pow(CONF_TRITS) as i32 - 1) / 2; // 4

// Mantissa divisor: (MANT_MAX + 1) / 2 = 3^6 / 2 = 364.5, so that the decoded
// significand 1 + M / MANT_DIV spans [1, 1 + 728/364.5] ⊂ [1, 3).
const MANT_DIV: f32 = (MANT_MAX + 1) as f32 / 2.0;

// Total number of digits in the encoding: phase + exponent + mantissa + confidence.
const TOTAL_TRITS: u32 = 1 + EXP_TRITS + MANT_TRITS + CONF_TRITS; // 14

// The largest valid raw value: every digit 2, i.e. 3^14 - 1 = 4_782_968.
const MAX_RAW: u32 = TRIT_BASE.pow(TOTAL_TRITS) - 1;
66
67// ─── Core type ────────────────────────────────────────────────────────────────
68
69/// A floating-point number encoded in balanced ternary with a native confidence field.
70///
71/// The confidence field propagates automatically through arithmetic, giving any
72/// computation a live uncertainty estimate without a separate Bayesian layer.
73///
74/// Use `TritFloat::from_f32` to construct, `.to_f32()` to read the value,
75/// and `.confidence()` to read the certainty in [0, 1].
76#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
77pub struct TritFloat(u32);
78
79// ─── Internal helpers ─────────────────────────────────────────────────────────
80
81/// Extract one base-3 digit at position `pos` from a u32 base-3 encoding.
82#[inline]
83fn get_digit(raw: u32, pos: u32) -> u32 {
84    let divisor = TRIT_BASE.pow(pos);
85    (raw / divisor) % TRIT_BASE
86}
87
88/// Set base-3 digit at position `pos` in a u32 base-3 encoding.
89#[inline]
90fn set_digit(raw: u32, pos: u32, digit: u32) -> u32 {
91    debug_assert!(digit < 3, "digit must be in {{0,1,2}}");
92    let place = TRIT_BASE.pow(pos);
93    let cleared = raw - (raw / place % TRIT_BASE) * place;
94    cleared + digit * place
95}
96
/// Map a balanced trit (-1, 0, +1) onto its stored digit (0, 1, 2).
#[inline]
fn balanced_to_digit(t: i8) -> u32 {
    let offset = t + 1;
    offset as u32
}
102
/// Map a stored digit (0, 1, 2) back to its balanced trit (-1, 0, +1).
#[inline]
fn digit_to_balanced(d: u32) -> i8 {
    let signed = d as i8;
    signed - 1
}
108
109/// Decode a multi-trit balanced ternary integer from a packed u32 region.
110/// `start_pos`: position of the least-significant trit in the packed u32.
111/// `n_trits`: number of trits to read.
112/// Returns the balanced integer value.
113fn decode_balanced_int(raw: u32, start_pos: u32, n_trits: u32) -> i32 {
114    let mut value = 0i32;
115    let mut place = 1i32;
116    for i in 0..n_trits {
117        let digit = get_digit(raw, start_pos + i);
118        let trit = digit_to_balanced(digit) as i32;
119        value += trit * place;
120        place *= 3;
121    }
122    value
123}
124
125/// Encode a balanced integer into n_trits starting at start_pos in a u32.
126fn encode_balanced_int(mut raw: u32, start_pos: u32, n_trits: u32, mut value: i32) -> u32 {
127    value = value.clamp(-((TRIT_BASE.pow(n_trits) as i32 - 1) / 2),
128                         (TRIT_BASE.pow(n_trits) as i32 - 1) / 2);
129    // Convert to balanced ternary digits (least significant first)
130    let mut remaining = value;
131    for i in 0..n_trits {
132        // Find the trit that minimises |remaining|: try 0, then +1 or -1
133        let low = (remaining % 3 + 3) % 3; // non-negative remainder mod 3
134        let trit = if low <= 1 { low as i8 } else { (low as i8) - 3 }; // balanced: pick closest
135        let digit = balanced_to_digit(trit);
136        raw = set_digit(raw, start_pos + i, digit);
137        remaining -= trit as i32;
138        remaining /= 3;
139    }
140    raw
141}
142
/// Integer floor of log base 3 of `x`: the largest `e` with `3^e <= x`.
///
/// Returns 0 for `x <= 0` and for NaN, and `i32::MAX` for `+inf`. The `+inf`
/// result is what the saturating float-to-int cast of an infinite logarithm
/// yields.
///
/// The first estimate uses natural logarithms. That quotient can land a hair
/// below (or above) an integer when `x` is on or next to an exact power of
/// three, which makes `floor` off by one. The estimate is therefore corrected
/// by comparing against `3^e` directly. The comparison is done in f64, where
/// every power reachable from a finite f32 is representable.
fn log3_floor(x: f32) -> i32 {
    if x.is_nan() || x <= 0.0 {
        return 0;
    }
    if x.is_infinite() {
        return i32::MAX;
    }
    let xd = f64::from(x);
    let mut e = (xd.ln() / 3f64.ln()).floor() as i32;
    // The estimate is off by at most one; these loops each run 0 or 1 times.
    while 3f64.powi(e) > xd {
        e -= 1;
    }
    while 3f64.powi(e + 1) <= xd {
        e += 1;
    }
    e
}
148
149// ─── TritFloat implementation ─────────────────────────────────────────────────
150
151impl TritFloat {
152    // ── Constructors ──────────────────────────────────────────────────────────
153
154    /// The canonical zero, with neutral confidence.
155    pub fn zero() -> Self {
156        // phase digit = 1 (balanced 0), all others = 1 (balanced 0), conf = neutral (digit 1,1)
157        // This gives raw = 0 for all-zero balanced = digit 1 everywhere...
158        // Actually: digit 1 = balanced 0 everywhere = all trits zero
159        // raw = 1*3^0 + 1*3^1 + ... + 1*3^13 = (3^14 - 1)/2 = 2391484
160        // But simpler: just build it procedurally.
161        let mut raw = 0u32;
162        for i in 0..TOTAL_TRITS {
163            raw = set_digit(raw, i, 1); // digit 1 = balanced trit 0
164        }
165        // Set confidence to neutral (both trits 0 = digit 1 each)
166        Self(raw)
167    }
168
169    /// Convert an `f32` to TritFloat with maximum confidence (certainty=1.0).
170    pub fn from_f32(x: f32) -> Self {
171        Self::from_f32_with_confidence(x, 1.0)
172    }
173
174    /// Convert an `f32` to TritFloat with a specified confidence in [0, 1].
175    pub fn from_f32_with_confidence(x: f32, confidence: f32) -> Self {
176        let mut raw = Self::zero().0;
177
178        // ── Phase ─────────────────────────────────────────────────────────────
179        if x == 0.0 || x.is_nan() {
180            // phase = 0 (digit 1) — zero case; exponent/mantissa don't matter
181            // confidence still applies to zero (we know it's zero)
182            raw = set_digit(raw, PHASE_POS, 1);
183            raw = Self::encode_confidence_into(raw, confidence);
184            return Self(raw);
185        }
186
187        let phase: i8 = if x > 0.0 { 1 } else { -1 };
188        raw = set_digit(raw, PHASE_POS, balanced_to_digit(phase));
189
190        let x_abs = x.abs();
191
192        // ── Exponent ──────────────────────────────────────────────────────────
193        // E = floor(log3(x_abs)), clamped to [-121, +121]
194        // After this, x_abs / 3^E ∈ [1, 3)
195        let exp = log3_floor(x_abs).clamp(-EXP_MAX, EXP_MAX);
196        raw = encode_balanced_int(raw, EXP_POS, EXP_TRITS, exp);
197
198        // ── Mantissa ──────────────────────────────────────────────────────────
199        // mantissa_f = x_abs / 3^exp - 1, in [0, 2)
200        // M = round(mantissa_f * 729), clamped to [0, 728]
201        let scale = (3f32).powi(exp);
202        let mantissa_f = (x_abs / scale - 1.0).clamp(0.0, 1.9999);
203        let m = (mantissa_f * MANT_DIV).round().clamp(0.0, MANT_MAX as f32) as u32;
204
205        // Encode as 6 base-3 digits {0,1,2} (unbalanced mantissa — pure magnitude)
206        let mut m_remaining = m;
207        for i in 0..MANT_TRITS {
208            let digit = m_remaining % 3;
209            raw = set_digit(raw, MANT_POS + i, digit);
210            m_remaining /= 3;
211        }
212
213        // ── Confidence ────────────────────────────────────────────────────────
214        raw = Self::encode_confidence_into(raw, confidence);
215
216        Self(raw)
217    }
218
219    /// Encode a [0,1] confidence float into the confidence trit field of a raw value.
220    fn encode_confidence_into(raw: u32, confidence: f32) -> u32 {
221        // Map [0,1] → [0, 8] → two balanced trits
222        let c_int = (confidence.clamp(0.0, 1.0) * (CONF_MAX * 2) as f32).round() as i32;
223        // c_int in [0, 8]: decode as c1*3 + c0 = c_int, balanced trits c0, c1 ∈ {-1,0,+1}
224        let c_int_shifted = c_int - CONF_MAX; // shift to [-4, +4]
225        encode_balanced_int(raw, CONF_POS, CONF_TRITS, c_int_shifted)
226    }
227
228    // ── Value extraction ──────────────────────────────────────────────────────
229
230    /// Convert to f32. Confidence is discarded; use `.confidence()` separately.
231    pub fn to_f32(self) -> f32 {
232        let phase = digit_to_balanced(get_digit(self.0, PHASE_POS));
233        if phase == 0 {
234            return 0.0;
235        }
236
237        let exp = decode_balanced_int(self.0, EXP_POS, EXP_TRITS);
238
239        // Decode mantissa (unbalanced base-3 digits {0,1,2})
240        let mut m = 0u32;
241        let mut place = 1u32;
242        for i in 0..MANT_TRITS {
243            m += get_digit(self.0, MANT_POS + i) * place;
244            place *= 3;
245        }
246        let mantissa_f = m as f32 / MANT_DIV;
247
248        let scale = (3f32).powi(exp);
249        (phase as f32) * scale * (1.0 + mantissa_f)
250    }
251
252    /// The phase trit: -1 (negative), 0 (zero), or +1 (positive).
253    pub fn phase(self) -> i8 {
254        digit_to_balanced(get_digit(self.0, PHASE_POS))
255    }
256
257    /// The exponent as a signed integer in [-121, +121].
258    pub fn exponent(self) -> i32 {
259        decode_balanced_int(self.0, EXP_POS, EXP_TRITS)
260    }
261
262    /// The mantissa as a u32 in [0, 728]. Represents fractional part as M/729.
263    pub fn mantissa(self) -> u32 {
264        let mut m = 0u32;
265        let mut place = 1u32;
266        for i in 0..MANT_TRITS {
267            m += get_digit(self.0, MANT_POS + i) * place;
268            place *= 3;
269        }
270        m
271    }
272
273    /// Confidence as a float in [0.0, 1.0].
274    ///
275    /// 0.0 = completely unknown, 0.5 = neutral/unset, 1.0 = maximally certain.
276    pub fn confidence(self) -> f32 {
277        let c_balanced = decode_balanced_int(self.0, CONF_POS, CONF_TRITS);
278        // c_balanced in [-4, +4] → shift to [0, 8] → divide by 8
279        (c_balanced + CONF_MAX) as f32 / (CONF_MAX * 2) as f32
280    }
281
282    /// True if this value is zero (phase trit = 0).
283    pub fn is_zero(self) -> bool {
284        digit_to_balanced(get_digit(self.0, PHASE_POS)) == 0
285    }
286
287    /// True if confidence is below 0.5 (both confidence trits ≤ 0).
288    pub fn is_uncertain(self) -> bool {
289        self.confidence() < 0.5
290    }
291
292    /// The raw u32 backing value (for serialization and hardware interop).
293    pub fn raw(self) -> u32 {
294        self.0
295    }
296
297    /// Reconstruct from a raw u32 (as returned by `.raw()`).
298    pub fn from_raw(raw: u32) -> Self {
299        debug_assert!(raw <= MAX_RAW, "raw value exceeds 14-trit maximum");
300        Self(raw.min(MAX_RAW))
301    }
302
303    // ── Confidence propagation ─────────────────────────────────────────────────
304
305    /// Propagation rule for multiplication: weakest link.
306    /// The result is only as confident as the less certain operand.
307    pub fn mul_confidence(a: Self, b: Self) -> f32 {
308        a.confidence().min(b.confidence())
309    }
310
311    /// Propagation rule for addition: average the evidence.
312    pub fn add_confidence(a: Self, b: Self) -> f32 {
313        (a.confidence() + b.confidence()) * 0.5
314    }
315
316    // ── Arithmetic ─────────────────────────────────────────────────────────────
317    //
318    // Software path: converts to f32, operates, converts back with propagated
319    // confidence. Hardware-native trit arithmetic is a future optimization.
320
321    /// Negate: flip phase, preserve all other fields including confidence.
322    pub fn neg(self) -> Self {
323        let new_phase = -self.phase();
324        let new_digit = balanced_to_digit(new_phase);
325        let raw = set_digit(self.0, PHASE_POS, new_digit);
326        Self(raw)
327    }
328
329    /// Absolute value: force phase to +1 (or 0 if zero).
330    pub fn abs(self) -> Self {
331        if self.is_zero() { return self; }
332        let raw = set_digit(self.0, PHASE_POS, balanced_to_digit(1));
333        Self(raw)
334    }
335
336    /// Addition with confidence propagation (average rule).
337    pub fn add(self, rhs: Self) -> Self {
338        let value = self.to_f32() + rhs.to_f32();
339        let conf = Self::add_confidence(self, rhs);
340        Self::from_f32_with_confidence(value, conf)
341    }
342
343    /// Subtraction with confidence propagation (average rule).
344    pub fn sub(self, rhs: Self) -> Self {
345        self.add(rhs.neg())
346    }
347
348    /// Multiplication with confidence propagation (weakest-link rule).
349    pub fn mul(self, rhs: Self) -> Self {
350        // Short-circuit: if either operand is phase-zero, result is zero.
351        // Confidence of zero = min(conf_a, conf_b) — we know it's zero, but
352        // only as confidently as our least-certain input.
353        if self.is_zero() || rhs.is_zero() {
354            let conf = Self::mul_confidence(self, rhs);
355            return Self::from_f32_with_confidence(0.0, conf);
356        }
357        let value = self.to_f32() * rhs.to_f32();
358        let conf = Self::mul_confidence(self, rhs);
359        Self::from_f32_with_confidence(value, conf)
360    }
361
362    /// Dot product of two slices of TritFloats.
363    ///
364    /// Confidence of the result = min confidence across all terms.
365    /// Zero-phase terms are skipped entirely (@sparseskip at activation level).
366    pub fn dot(a: &[Self], b: &[Self]) -> Self {
367        assert_eq!(a.len(), b.len(), "dot product requires equal-length slices");
368
369        let mut acc_value = 0.0f32;
370        let mut min_conf = 1.0f32;
371        let mut skipped = 0usize;
372
373        for (&ai, &bi) in a.iter().zip(b.iter()) {
374            // @sparseskip: neutral phase on either operand → contributes zero, skip MAC
375            if ai.is_zero() || bi.is_zero() {
376                // Still track the minimum confidence across skipped terms
377                let term_conf = Self::mul_confidence(ai, bi);
378                min_conf = min_conf.min(term_conf);
379                skipped += 1;
380                continue;
381            }
382            acc_value += ai.to_f32() * bi.to_f32();
383            min_conf = min_conf.min(Self::mul_confidence(ai, bi));
384        }
385
386        let _ = skipped; // available for instrumentation if needed
387
388        Self::from_f32_with_confidence(acc_value, min_conf)
389    }
390
391    /// Dot product returning (result, skip_count) for sparsity instrumentation.
392    pub fn dot_with_skips(a: &[Self], b: &[Self]) -> (Self, usize) {
393        assert_eq!(a.len(), b.len(), "dot product requires equal-length slices");
394
395        let mut acc_value = 0.0f32;
396        let mut min_conf = 1.0f32;
397        let mut skipped = 0usize;
398
399        for (&ai, &bi) in a.iter().zip(b.iter()) {
400            if ai.is_zero() || bi.is_zero() {
401                let term_conf = Self::mul_confidence(ai, bi);
402                min_conf = min_conf.min(term_conf);
403                skipped += 1;
404                continue;
405            }
406            acc_value += ai.to_f32() * bi.to_f32();
407            min_conf = min_conf.min(Self::mul_confidence(ai, bi));
408        }
409
410        (Self::from_f32_with_confidence(acc_value, min_conf), skipped)
411    }
412
413    // ── Routing hint ──────────────────────────────────────────────────────────
414
415    /// Returns true if this activation should be routed to an expert.
416    ///
417    /// Uncertain activations (confidence < threshold) can skip expensive expert
418    /// layers entirely — the confidence field directly gates MoE routing.
419    ///
420    /// `threshold` = minimum confidence to route (suggested: 0.3–0.5)
421    pub fn should_route(self, threshold: f32) -> bool {
422        !self.is_zero() && self.confidence() >= threshold
423    }
424
425    // ── Extended arithmetic ───────────────────────────────────────────────────
426
427    /// Division with weakest-link confidence. Division by zero returns zero
428    /// with 0 confidence — the caller can detect this via `is_uncertain`.
429    pub fn div(self, rhs: Self) -> Self {
430        if rhs.is_zero() {
431            return Self::from_f32_with_confidence(0.0, 0.0);
432        }
433        let conf = Self::mul_confidence(self, rhs);
434        Self::from_f32_with_confidence(self.to_f32() / rhs.to_f32(), conf)
435    }
436
437    /// Reciprocal: 1/x. Confidence preserved; zero input returns 0-confidence zero.
438    pub fn recip(self) -> Self {
439        if self.is_zero() {
440            return Self::from_f32_with_confidence(0.0, 0.0);
441        }
442        Self::from_f32_with_confidence(1.0 / self.to_f32(), self.confidence())
443    }
444
445    /// Integer power. Confidence preserved — single-operand chain.
446    pub fn powi(self, n: i32) -> Self {
447        Self::from_f32_with_confidence(self.to_f32().powi(n), self.confidence())
448    }
449
450    /// Square root. Negative input returns 0-confidence zero (not a real number).
451    pub fn sqrt(self) -> Self {
452        if self.is_zero() { return self; }
453        if self.phase() < 0 {
454            return Self::from_f32_with_confidence(0.0, 0.0);
455        }
456        Self::from_f32_with_confidence(self.to_f32().sqrt(), self.confidence())
457    }
458
459    /// Clamp the value to [lo, hi]. Confidence is preserved unchanged.
460    pub fn clamp(self, lo: f32, hi: f32) -> Self {
461        Self::from_f32_with_confidence(self.to_f32().clamp(lo, hi), self.confidence())
462    }
463
464    /// Ternary comparison: returns +1 if self > rhs, −1 if self < rhs, 0 if equal.
465    /// Confidence = min(conf_self, conf_rhs) — comparison is only as reliable as inputs.
466    pub fn cmp_trit(self, rhs: Self) -> Self {
467        let (va, vb) = (self.to_f32(), rhs.to_f32());
468        let r = if va > vb { 1.0f32 } else if va < vb { -1.0 } else { 0.0 };
469        Self::from_f32_with_confidence(r, Self::mul_confidence(self, rhs))
470    }
471
472    // ── Slice operations ──────────────────────────────────────────────────────
473
474    /// Numerically stable softmax over a slice of TritFloats.
475    ///
476    /// Values are computed in f32; each output element carries the minimum
477    /// confidence of all inputs (softmax mixes every element, so the whole
478    /// slice's certainty bounds the result).
479    pub fn softmax(slice: &[Self]) -> Vec<Self> {
480        if slice.is_empty() { return vec![]; }
481        let vals: Vec<f32> = slice.iter().map(|x| x.to_f32()).collect();
482        let max_v = vals.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
483        let exps: Vec<f32> = vals.iter().map(|&v| (v - max_v).exp()).collect();
484        let sum: f32 = exps.iter().sum::<f32>().max(f32::EPSILON);
485        let min_c = slice.iter().map(|x| x.confidence()).fold(1.0f32, f32::min);
486        exps.iter()
487            .map(|&e| Self::from_f32_with_confidence(e / sum, min_c))
488            .collect()
489    }
490
491    // ── Phase packing — SIMD-lite ─────────────────────────────────────────────
492
493    /// Extract phase digits (0=neg, 1=zero, 2=pos) for a slice into a `Vec<u8>`.
494    ///
495    /// The pre-scan buffer: a single contiguous pass over raw u32 values (% 3)
496    /// before the arithmetic loop. Separating phase-check from f32 math eliminates
497    /// branch misprediction in the hot loop at high sparsity (≥50% zeros).
498    #[inline]
499    pub fn phase_digits(slice: &[Self]) -> Vec<u8> {
500        slice.iter().map(|x| (x.0 % 3) as u8).collect()
501    }
502
503    /// Pack zero-phase flags for up to 64 TritFloats into a u64 bitmask.
504    ///
505    /// Bit i = 1 if slice[i].is_zero(), else 0. `mask.count_ones()` instantly
506    /// gives the skip count for a 64-element chunk. `mask == 0` means all
507    /// elements are active — no branch needed in the arithmetic loop.
508    /// This is the preparation layer for AVX2 vectorization of the dot product.
509    pub fn pack_phases_u64(slice: &[Self]) -> u64 {
510        debug_assert!(slice.len() <= 64, "pack_phases_u64: slice too long (max 64)");
511        let mut mask = 0u64;
512        for (i, x) in slice.iter().take(64).enumerate() {
513            if x.0 % 3 == 1 {
514                mask |= 1u64 << i;
515            }
516        }
517        mask
518    }
519
520    /// Dot product with two-pass pre-scan for reduced branch misprediction.
521    ///
522    /// Pass 1: extract all phase flags into u8 arrays (cache-hot, no branching).
523    /// Pass 2: arithmetic only for active (non-zero-phase) pairs.
524    ///
525    /// Outperforms `dot_with_skips` at ≥50% sparsity where misprediction of the
526    /// inline zero-check dominates. At low sparsity the extra allocation cost
527    /// makes it slightly slower — profile before choosing.
528    pub fn dot_prescan(a: &[Self], b: &[Self]) -> (Self, usize) {
529        assert_eq!(a.len(), b.len(), "dot_prescan requires equal-length slices");
530        let pa = Self::phase_digits(a);
531        let pb = Self::phase_digits(b);
532
533        let mut acc = 0.0f32;
534        let mut min_conf = 1.0f32;
535        let mut skipped = 0usize;
536
537        for i in 0..a.len() {
538            let c = Self::mul_confidence(a[i], b[i]);
539            if c < min_conf { min_conf = c; }
540            if pa[i] == 1 || pb[i] == 1 {
541                skipped += 1;
542            } else {
543                acc += a[i].to_f32() * b[i].to_f32();
544            }
545        }
546
547        (Self::from_f32_with_confidence(acc, min_conf), skipped)
548    }
549}
550
551// ─── Display ─────────────────────────────────────────────────────────────────
552
553impl std::fmt::Debug for TritFloat {
554    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555        write!(f, "TritFloat({:.6} conf={:.2} exp={} mant={})",
556            self.to_f32(),
557            self.confidence(),
558            self.exponent(),
559            self.mantissa(),
560        )
561    }
562}
563
564impl std::fmt::Display for TritFloat {
565    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
566        write!(f, "{:.6}±{:.0}%", self.to_f32(), self.confidence() * 100.0)
567    }
568}
569
570// ─── Tests ───────────────────────────────────────────────────────────────────
571
572#[cfg(test)]
573mod tests {
574    use super::*;
575
576    const TOL: f32 = 0.01; // ~1% relative tolerance for roundtrip
577
578    fn approx(a: f32, b: f32, tol: f32) -> bool {
579        if b == 0.0 { return a.abs() < tol; }
580        ((a - b) / b).abs() < tol
581    }
582
583    #[test]
584    fn zero_roundtrip() {
585        let z = TritFloat::from_f32(0.0);
586        assert!(z.is_zero());
587        assert_eq!(z.to_f32(), 0.0);
588        assert_eq!(z.phase(), 0);
589    }
590
591    #[test]
592    fn positive_roundtrip() {
593        for &x in &[0.001f32, 0.1, 0.5, 1.0, 3.0, 9.0, 100.0, 12345.678, 1e10, 1e-10] {
594            let tf = TritFloat::from_f32(x);
595            let back = tf.to_f32();
596            assert!(approx(back, x, TOL),
597                "roundtrip failed for x={}: got {} ({})", x, back, tf);
598            assert_eq!(tf.phase(), 1);
599        }
600    }
601
602    #[test]
603    fn negative_roundtrip() {
604        for &x in &[-0.5f32, -1.0, -3.14, -999.9] {
605            let tf = TritFloat::from_f32(x);
606            let back = tf.to_f32();
607            assert!(approx(back.abs(), x.abs(), TOL),
608                "negative roundtrip failed for x={}: got {}", x, back);
609            assert_eq!(tf.phase(), -1);
610        }
611    }
612
613    #[test]
614    fn confidence_from_f32_is_max() {
615        let tf = TritFloat::from_f32(1.0);
616        assert!((tf.confidence() - 1.0).abs() < 0.15,
617            "from_f32 should give near-max confidence, got {}", tf.confidence());
618    }
619
620    #[test]
621    fn confidence_custom() {
622        let tf = TritFloat::from_f32_with_confidence(1.0, 0.0);
623        assert!(tf.confidence() < 0.2, "expected low confidence, got {}", tf.confidence());
624
625        let tf = TritFloat::from_f32_with_confidence(1.0, 0.5);
626        assert!((tf.confidence() - 0.5).abs() < 0.2, "expected mid confidence, got {}", tf.confidence());
627    }
628
629    #[test]
630    fn zero_confidence_neutral() {
631        let z = TritFloat::zero();
632        assert!(z.is_zero());
633        assert!((z.confidence() - 0.5).abs() < 0.2, "zero should have neutral confidence");
634    }
635
636    #[test]
637    fn neg_flips_phase() {
638        let pos = TritFloat::from_f32(2.5);
639        let neg = pos.neg();
640        assert_eq!(pos.phase(), 1);
641        assert_eq!(neg.phase(), -1);
642        assert!(approx(pos.to_f32(), -neg.to_f32(), TOL));
643        // confidence is preserved
644        assert!((pos.confidence() - neg.confidence()).abs() < 0.15);
645    }
646
647    #[test]
648    fn abs_always_positive() {
649        let neg = TritFloat::from_f32(-7.0);
650        let a = neg.abs();
651        assert_eq!(a.phase(), 1);
652        assert!(a.to_f32() > 0.0);
653    }
654
655    #[test]
656    fn mul_confidence_weakest_link() {
657        let certain = TritFloat::from_f32_with_confidence(2.0, 1.0);
658        let uncertain = TritFloat::from_f32_with_confidence(3.0, 0.0);
659        let product = certain.mul(uncertain);
660        assert!(product.confidence() < 0.2,
661            "mul confidence should be dominated by uncertain operand");
662    }
663
664    #[test]
665    fn mul_zero_propagates_uncertainty() {
666        let zero = TritFloat::from_f32_with_confidence(0.0, 0.0);
667        let certain = TritFloat::from_f32_with_confidence(5.0, 1.0);
668        let product = certain.mul(zero);
669        assert!(product.is_zero());
670        // confidence = min(1.0, 0.0) = 0.0
671        assert!(product.confidence() < 0.2);
672    }
673
674    #[test]
675    fn add_confidence_averages() {
676        let a = TritFloat::from_f32_with_confidence(1.0, 1.0);
677        let b = TritFloat::from_f32_with_confidence(1.0, 0.0);
678        let sum = a.add(b);
679        assert!((sum.confidence() - 0.5).abs() < 0.2,
680            "add confidence should average, got {}", sum.confidence());
681    }
682
683    #[test]
684    fn add_value_correct() {
685        let a = TritFloat::from_f32(1.5);
686        let b = TritFloat::from_f32(2.5);
687        let sum = a.add(b);
688        assert!(approx(sum.to_f32(), 4.0, TOL), "1.5 + 2.5 should ≈ 4.0, got {}", sum.to_f32());
689    }
690
691    #[test]
692    fn mul_value_correct() {
693        let a = TritFloat::from_f32(3.0);
694        let b = TritFloat::from_f32(4.0);
695        let p = a.mul(b);
696        assert!(approx(p.to_f32(), 12.0, 0.02), "3 × 4 should ≈ 12, got {}", p.to_f32());
697    }
698
699    #[test]
700    fn dot_basic() {
701        let a: Vec<TritFloat> = [1.0f32, 2.0, 3.0].iter().map(|&x| TritFloat::from_f32(x)).collect();
702        let b: Vec<TritFloat> = [4.0f32, 5.0, 6.0].iter().map(|&x| TritFloat::from_f32(x)).collect();
703        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
704        let result = TritFloat::dot(&a, &b);
705        assert!(approx(result.to_f32(), 32.0, 0.02),
706            "dot([1,2,3],[4,5,6]) should ≈ 32, got {}", result.to_f32());
707    }
708
709    #[test]
710    fn dot_skips_zeros() {
711        // Two zeros in a = 2/3 skipped
712        let a: Vec<TritFloat> = vec![
713            TritFloat::from_f32(0.0),
714            TritFloat::from_f32(2.0),
715            TritFloat::from_f32(0.0),
716        ];
717        let b: Vec<TritFloat> = vec![
718            TritFloat::from_f32(1.0),
719            TritFloat::from_f32(3.0),
720            TritFloat::from_f32(1.0),
721        ];
722        let (result, skips) = TritFloat::dot_with_skips(&a, &b);
723        assert_eq!(skips, 2, "two zero phases should produce 2 skips");
724        assert!(approx(result.to_f32(), 6.0, 0.02),
725            "0*1 + 2*3 + 0*1 = 6, got {}", result.to_f32());
726    }
727
728    #[test]
729    fn should_route_confidence_gate() {
730        let certain = TritFloat::from_f32_with_confidence(1.0, 0.9);
731        let uncertain = TritFloat::from_f32_with_confidence(1.0, 0.1);
732        let zero = TritFloat::from_f32(0.0);
733
734        assert!(certain.should_route(0.5),   "certain should route");
735        assert!(!uncertain.should_route(0.5), "uncertain should not route");
736        assert!(!zero.should_route(0.0),      "zero phase never routes");
737    }
738
739    #[test]
740    fn raw_roundtrip() {
741        let tf = TritFloat::from_f32(42.0);
742        let raw = tf.raw();
743        let restored = TritFloat::from_raw(raw);
744        assert_eq!(tf, restored);
745    }
746
747    #[test]
748    fn display_shows_confidence() {
749        let tf = TritFloat::from_f32(3.14);
750        let s = format!("{tf}");
751        assert!(s.contains('%'), "display should show confidence %: got '{}'", s);
752    }
753
754    #[test]
755    fn exponent_range_covered() {
756        let large = TritFloat::from_f32(1e30f32);
757        let small = TritFloat::from_f32(1e-30f32);
758        assert!(large.exponent().abs() <= EXP_MAX as i32);
759        assert!(small.exponent().abs() <= EXP_MAX as i32);
760        assert!(approx(large.to_f32(), 1e30, 0.05));
761        assert!(approx(small.to_f32(), 1e-30, 0.05));
762    }
763
764    // ── Extended arithmetic tests ─────────────────────────────────────────────
765
766    #[test]
767    fn div_basic() {
768        let a = TritFloat::from_f32(6.0);
769        let b = TritFloat::from_f32(2.0);
770        let r = a.div(b);
771        assert!(approx(r.to_f32(), 3.0, TOL), "6/2 should be 3, got {}", r.to_f32());
772    }
773
774    #[test]
775    fn div_by_zero_returns_zero_confidence() {
776        let a = TritFloat::from_f32(5.0);
777        let z = TritFloat::from_f32(0.0);
778        let r = a.div(z);
779        assert!(r.is_zero());
780        assert!(r.confidence() < 0.15, "div-by-zero should have 0 confidence");
781    }
782
783    #[test]
784    fn recip_basic() {
785        let r = TritFloat::from_f32(4.0).recip();
786        assert!(approx(r.to_f32(), 0.25, TOL), "recip(4) should be 0.25, got {}", r.to_f32());
787    }
788
789    #[test]
790    fn recip_zero_returns_zero_confidence() {
791        let r = TritFloat::zero().recip();
792        assert!(r.is_zero());
793        assert!(r.confidence() < 0.15);
794    }
795
// 2 raised to the integer power 3 must decode to 8 within TOL.
#[test]
fn powi_basic() {
    let cubed = TritFloat::from_f32(2.0).powi(3).to_f32();
    assert!(approx(cubed, 8.0, TOL), "2^3 should be 8, got {}", cubed);
}
801
// An integer power must keep the input's confidence (0.75 = 6/8) within
// one quantization step; the message reports the observed confidence.
#[test]
fn powi_confidence_preserved() {
    let a = TritFloat::from_f32_with_confidence(2.0, 0.75);
    let r = a.powi(2);
    assert!(
        (r.confidence() - 0.75).abs() < 0.15,
        "powi should preserve confidence 0.75, got {}",
        r.confidence()
    );
}
808
// The square root of 9 must decode to 3 within TOL.
#[test]
fn sqrt_basic() {
    let root = TritFloat::from_f32(9.0).sqrt().to_f32();
    assert!(approx(root, 3.0, TOL), "sqrt(9) should be 3, got {}", root);
}
814
// The square root of a negative input is signalled by a zero-phase value
// with near-zero (< 0.15) confidence.
#[test]
fn sqrt_negative_returns_zero_confidence() {
    let root_of_negative = TritFloat::from_f32(-4.0).sqrt();
    assert!(root_of_negative.is_zero());
    assert!(
        root_of_negative.confidence() < 0.15,
        "sqrt of negative should have 0 confidence"
    );
}
821
// Clamping to [0, 3] caps values above the range at 3 and lifts values
// below the range to 0.
#[test]
fn clamp_caps_value() {
    let (low, high) = (0.0, 3.0);

    let capped = TritFloat::from_f32(5.0).clamp(low, high).to_f32();
    assert!(approx(capped, 3.0, TOL), "clamp(5, 0, 3) should be 3, got {}", capped);

    let lifted = TritFloat::from_f32(-2.0).clamp(low, high).to_f32();
    assert!(approx(lifted, 0.0, 0.01), "clamp(-2, 0, 3) should be 0");
}
829
// Clamping changes the value but must keep the input's confidence
// (0.625 = 5/8) within one quantization step; the message reports the
// observed confidence.
#[test]
fn clamp_preserves_confidence() {
    let a = TritFloat::from_f32_with_confidence(10.0, 0.625);
    let r = a.clamp(0.0, 1.0);
    assert!(
        (r.confidence() - 0.625).abs() < 0.15,
        "clamp should preserve confidence 0.625, got {}",
        r.confidence()
    );
}
836
// `cmp_trit` encodes the ordering in the result's phase:
// +1 for greater, -1 for less, 0 for equal.
#[test]
fn cmp_trit_ordering() {
    let three = TritFloat::from_f32(3.0);
    let two = TritFloat::from_f32(2.0);

    let greater = three.cmp_trit(two).phase();
    assert_eq!(greater, 1, "3 > 2 should give +1");

    let less = two.cmp_trit(three).phase();
    assert_eq!(less, -1, "2 < 3 should give -1");

    let equal = three.cmp_trit(three).phase();
    assert_eq!(equal, 0, "x == x should give 0");
}
845
// Comparing a fully certain value with a 1/8-confidence value must yield
// a result whose confidence follows the weaker input (below 0.2).
#[test]
fn cmp_trit_confidence_is_min() {
    let certain = TritFloat::from_f32_with_confidence(3.0, 1.0);
    let doubtful = TritFloat::from_f32_with_confidence(2.0, 0.125);
    let combined = certain.cmp_trit(doubtful).confidence();
    assert!(combined < 0.2, "cmp confidence should be min of inputs");
}
853
854    // ── Slice / SIMD-lite tests ───────────────────────────────────────────────
855
// The decoded softmax outputs of a small vector must sum to 1 within 1e-4.
#[test]
fn softmax_sums_to_one() {
    let logits: Vec<TritFloat> = [1.0f32, 2.0, 3.0, 0.5]
        .iter()
        .copied()
        .map(TritFloat::from_f32)
        .collect();
    let probabilities = TritFloat::softmax(&logits);
    let sum: f32 = probabilities.iter().map(|p| p.to_f32()).sum();
    assert!((sum - 1.0).abs() < 1e-4, "softmax should sum to 1.0, got {sum}");
}
864
// When one softmax input has confidence 0.125, every output's confidence
// must fall to that minimum (checked as below 0.2).
#[test]
fn softmax_confidence_is_min_of_inputs() {
    let logits = vec![
        TritFloat::from_f32_with_confidence(1.0, 1.0),
        TritFloat::from_f32_with_confidence(2.0, 0.125),
        TritFloat::from_f32_with_confidence(3.0, 1.0),
    ];
    let outputs = TritFloat::softmax(&logits);
    outputs.iter().for_each(|out| {
        assert!(
            out.confidence() < 0.2,
            "softmax conf should be min of inputs (0.125), got {}",
            out.confidence()
        );
    });
}
878
// Softmax over an empty slice must produce an empty result, not panic.
#[test]
fn softmax_empty_slice() {
    let produced = TritFloat::softmax(&[]);
    assert!(produced.is_empty());
}
883
// `pack_phases_u64` sets bit i exactly when element i has zero phase.
// Here only indices 1 and 3 (the 0.0 entries) should be set.
#[test]
fn pack_phases_u64_correctness() {
    let inputs: Vec<TritFloat> = [1.0f32, 0.0, -1.0, 0.0, 2.0]
        .iter()
        .copied()
        .map(TritFloat::from_f32)
        .collect();
    let mask = TritFloat::pack_phases_u64(&inputs);

    // (bit to probe, expected masked value, failure message), checked in order.
    let expectations = [
        (1, 0, "index 0 (1.0) should not be zero-phase"),
        (2, 2, "index 1 (0.0) should be zero-phase"),
        (4, 0, "index 2 (-1.0) should not be zero-phase"),
        (8, 8, "index 3 (0.0) should be zero-phase"),
        (16, 0, "index 4 (2.0) should not be zero-phase"),
    ];
    for (bit, want, msg) in expectations {
        assert_eq!(mask & bit, want, "{}", msg);
    }
    assert_eq!(mask.count_ones(), 2);
}
897
// `dot_prescan` must agree with `dot_with_skips` on both the dot-product
// value (within 0.001 via `approx`) and the reported skip count.
#[test]
fn dot_prescan_matches_dot_with_skips() {
    let encode = |xs: &[f32]| -> Vec<TritFloat> {
        xs.iter().map(|&x| TritFloat::from_f32(x)).collect()
    };
    let lhs = encode(&[1.0, 0.0, 2.0, 0.0, 3.0]);
    let rhs = encode(&[4.0, 5.0, 0.0, 6.0, 7.0]);

    let (reference, s1) = TritFloat::dot_with_skips(&lhs, &rhs);
    let (prescanned, s2) = TritFloat::dot_prescan(&lhs, &rhs);

    let (want, got) = (reference.to_f32(), prescanned.to_f32());
    assert!(
        approx(want, got, 0.001),
        "prescan and dot_with_skips should match: {} vs {}",
        want,
        got
    );
    assert_eq!(s1, s2, "skip counts should match: {s1} vs {s2}");
}
912
// `phase_digits` maps phases -1/0/+1 to storage digits 0/1/2, matching the
// base-3 encoding described in the file header.
#[test]
fn phase_digits_correct() {
    let inputs: Vec<TritFloat> = [-1.0f32, 0.0, 1.0]
        .iter()
        .copied()
        .map(TritFloat::from_f32)
        .collect();
    let digits = TritFloat::phase_digits(&inputs);

    // (element index, expected digit, failure message), checked in order.
    let expectations = [
        (0usize, 0, "neg phase → digit 0"),
        (1usize, 1, "zero phase → digit 1"),
        (2usize, 2, "pos phase → digit 2"),
    ];
    for (idx, digit, msg) in expectations {
        assert_eq!(digits[idx], digit, "{}", msg);
    }
}
922}