Skip to main content

dsfb_gpu_debug_core/
fixed.rs

//! Q16.16 signed fixed-point arithmetic.
//!
//! Why fixed-point: the CPU reference path and the CUDA kernels must produce
//! byte-identical outputs cell-for-cell so the hash chain in the case file is
//! the same regardless of whether evidence was produced on the host or the
//! device. Floating point makes byte-equivalence fragile across compilers,
//! drivers, fused-multiply-add behavior, and reduction order. Q16.16 with
//! explicit rounding sidesteps all of that.
//!
//! Why Q16.16 specifically: 16.16 strikes a balance between dynamic range
//! (i32 covers roughly ±32_767 in the integer half) and precision (1/65_536
//! in the fractional half) that fits this workload — latencies in
//! milliseconds, error-rate fractions, and smoothed-residual magnitudes all
//! sit comfortably inside that band once inputs are clamped at the boundary.
//!
//! Why these specific operations: `sat_add`, `sat_sub`, `sat_mul`, `sat_div`,
//! and `abs` are the only operations the downstream pipeline needs. No FMA,
//! no SIMD intrinsics. The same scalar code must run on the CPU and inside a
//! CUDA kernel, so the implementation is restricted to plain integer math
//! that both backends can express identically.
//!
//! Rounding rule: multiplication widens to i64, then applies round-half-to-
//! even (banker's rounding) at bit 15 before shifting right by 16. This rule
//! is symmetric, deterministic across architectures, and reduces the bias
//! that round-half-up would inject into long EWMA recurrences.
26
27#![allow(clippy::module_name_repetitions)]
28
/// A signed fixed-point number in Q16.16 format.
///
/// The value lives in one `i32`: the upper 16 bits carry the integer part
/// and the lower 16 bits carry the fraction. The wrapper is
/// `#[repr(transparent)]`, so it has the same layout as the `i32` it wraps
/// and can cross the FFI boundary as a plain `int32_t` without conversion.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct Q16(pub i32);
37
38impl Q16 {
39    /// Q16.16 representation of zero. Used as an EWMA seed and as the
40    /// "no evidence" axis value before any detector has fired.
41    pub const ZERO: Q16 = Q16(0);
42
43    /// Q16.16 representation of one. The fractional half is zero, the
44    /// integer half is one. `1 << 16 == 65_536`.
45    pub const ONE: Q16 = Q16(1 << 16);
46
47    /// The smallest representable value. Saturating arithmetic clamps here
48    /// rather than wrapping.
49    pub const MIN: Q16 = Q16(i32::MIN);
50
51    /// The largest representable value. Saturating arithmetic clamps here
52    /// rather than wrapping.
53    pub const MAX: Q16 = Q16(i32::MAX);
54
55    /// Construct a Q16.16 value from an integer. Sign-extended into the
56    /// high 16 bits. Caller is responsible for keeping `|x| ≤ 32_767`.
57    #[must_use]
58    pub const fn from_int(x: i16) -> Q16 {
59        Q16((x as i32) << 16)
60    }
61
62    /// Construct a Q16.16 value from its raw `i32` bit pattern. Used by the
63    /// contract loader (`ewma_alpha_q16_raw`) and by tests that need to pin
64    /// a specific bit value.
65    #[must_use]
66    pub const fn from_raw(raw: i32) -> Q16 {
67        Q16(raw)
68    }
69
70    /// Return the raw `i32` representation. The hash-chain serializer
71    /// writes raw words as zero-padded big-endian hex so the canonical bytes
72    /// are stable regardless of host endianness.
73    #[must_use]
74    pub const fn raw(self) -> i32 {
75        self.0
76    }
77
78    /// Saturating addition. Both CPU and CUDA paths must use the same rule
79    /// because per-cell results would otherwise disagree at the boundaries.
80    #[must_use]
81    pub const fn sat_add(self, b: Q16) -> Q16 {
82        Q16(self.0.saturating_add(b.0))
83    }
84
85    /// Saturating subtraction. Mirrors `sat_add`.
86    #[must_use]
87    pub const fn sat_sub(self, b: Q16) -> Q16 {
88        Q16(self.0.saturating_sub(b.0))
89    }
90
91    /// Saturating multiplication with round-half-to-even.
92    ///
93    /// The product is widened to i64 to avoid overflow before the rounding
94    /// shift. Banker's rounding is applied at bit 15: ties round toward the
95    /// even result. After rounding the value is shifted right by 16 (the
96    /// fractional width) and saturated back to i32.
97    ///
98    /// This function is the load-bearing primitive for CPU↔GPU bit equality.
99    /// Any divergence in its definition between the two backends would break
100    /// stage-by-stage hash equivalence. The CUDA implementation in
101    /// `cuda/common.cuh` is intentionally a line-for-line transliteration.
102    #[must_use]
103    pub const fn sat_mul(self, b: Q16) -> Q16 {
104        let prod: i64 = (self.0 as i64) * (b.0 as i64);
105        let rounded: i64 = round_half_even_i64_shift(prod, 16);
106        Q16(saturate_i64_to_i32(rounded))
107    }
108
109    /// Saturating division. Returns `Q16::ZERO` when the divisor is zero.
110    /// The pipeline never divides on the hot path; this exists for the
111    /// occasional baseline computation and is included for completeness.
112    #[must_use]
113    pub const fn sat_div(self, b: Q16) -> Q16 {
114        if b.0 == 0 {
115            return Q16::ZERO;
116        }
117        // Shift numerator left by the fractional width first, then divide.
118        // Performed in i64 to keep the pre-divide shift safe.
119        let num: i64 = (self.0 as i64) << 16;
120        Q16(saturate_i64_to_i32(num / (b.0 as i64)))
121    }
122
123    /// Absolute value, saturating. `Q16::MIN.abs()` returns `Q16::MAX`
124    /// rather than panicking or wrapping.
125    #[must_use]
126    pub const fn abs(self) -> Q16 {
127        Q16(self.0.saturating_abs())
128    }
129
130    /// Test for the additive identity. Used by axis gates that need to know
131    /// "no evidence" without comparing two `Q16` values.
132    #[must_use]
133    pub const fn is_zero(self) -> bool {
134        self.0 == 0
135    }
136
137    /// Linear interpolation: `self + alpha * (other - self)`, computed in
138    /// saturating Q16.16. Convenience wrapper used by the EWMA recurrence.
139    /// `alpha` is expected to live in `[0, ONE]`; values outside that band
140    /// still produce a deterministic result but the EWMA interpretation
141    /// stops being meaningful.
142    #[must_use]
143    pub const fn lerp(self, other: Q16, alpha: Q16) -> Q16 {
144        let diff = other.sat_sub(self);
145        let step = alpha.sat_mul(diff);
146        self.sat_add(step)
147    }
148}
149
/// Clamp an `i64` into the `i32` range and narrow it.
///
/// Values above `i32::MAX` become `i32::MAX`, values below `i32::MIN` become
/// `i32::MIN`, and everything else is returned unchanged. Kept as a
/// standalone function so both the CPU and CUDA paths use the same clamp
/// rule.
#[must_use]
const fn saturate_i64_to_i32(x: i64) -> i32 {
    const FLOOR: i64 = i32::MIN as i64;
    const CEILING: i64 = i32::MAX as i64;

    if x < FLOOR {
        return i32::MIN;
    }
    if x > CEILING {
        return i32::MAX;
    }
    // In range, so the narrowing cast cannot lose information.
    x as i32
}
162
/// Arithmetic right shift by `bits` with round-half-to-even applied to the
/// bits that are shifted out.
///
/// `bits` is expected to be in `1..=32`. For `bits == 0` the input is
/// returned unchanged. Shift amounts of 63 or more are not supported: the
/// mask computation `(1 << bits) - 1` overflows there.
///
/// The rule is evaluated on the two's-complement bit pattern, so it is the
/// same for both signs. `value >> bits` is the floor of `value / 2^bits`
/// (rounded toward negative infinity), and the masked-off low bits are the
/// non-negative remainder in `[0, 2^bits)`. Then:
///
/// - remainder below the halfway point: the result is the floor;
/// - remainder above the halfway point: the result is the floor plus one;
/// - remainder exactly at the halfway point: the result is whichever of the
///   floor and the floor plus one is even.
///
/// Non-tie inputs therefore round to the nearest integer regardless of
/// sign. With `bits == 4`, `-4` (-0.25) gives `0` and `-12` (-0.75) gives
/// `-1`; negative values are not pushed away from zero.
#[must_use]
const fn round_half_even_i64_shift(value: i64, bits: u32) -> i64 {
    if bits == 0 {
        return value;
    }
    // Arithmetic shift: floor division by 2^bits, also for negative inputs.
    let floor: i64 = value >> bits;
    let half: i64 = 1i64 << (bits - 1);
    // Low `bits` bits of the two's-complement pattern: value - floor * 2^bits.
    let remainder: i64 = value & ((1i64 << bits) - 1);

    let round_up: bool = if remainder == half {
        // Exact midpoint: step up only when that lands on the even neighbor.
        floor & 1 != 0
    } else {
        remainder > half
    };

    if round_up {
        floor.saturating_add(1)
    } else {
        floor
    }
}
196
/// Unit tests pinning the bit-exact behavior of `Q16` and its rounding and
/// clamping helpers. Every expected value is a raw bit pattern or an exact
/// integer, so a change in rounding, truncation, or saturation shows up as a
/// test failure.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn constants_have_expected_raw_bits() {
        assert_eq!(Q16::ZERO.raw(), 0);
        assert_eq!(Q16::ONE.raw(), 1 << 16);
        assert_eq!(Q16::MIN.raw(), i32::MIN);
        assert_eq!(Q16::MAX.raw(), i32::MAX);
    }

    #[test]
    fn from_int_round_trips_through_integer_part() {
        for x in [-32_767_i16, -1, 0, 1, 32_767] {
            let q = Q16::from_int(x);
            assert_eq!(q.raw() >> 16, i32::from(x));
            assert_eq!(q.raw() & 0xFFFF, 0);
        }
    }

    #[test]
    fn from_int_maps_i16_min_to_q16_min() {
        // -32_768 << 16 is exactly i32::MIN, so the most negative integer
        // input lands on the lower saturation bound.
        assert_eq!(Q16::from_int(i16::MIN), Q16::MIN);
    }

    #[test]
    fn sat_add_clamps_at_max() {
        assert_eq!(Q16::MAX.sat_add(Q16::ONE), Q16::MAX);
        assert_eq!(Q16::MIN.sat_add(Q16::from_int(-1)), Q16::MIN);
        assert_eq!(Q16::from_int(2).sat_add(Q16::from_int(3)), Q16::from_int(5));
    }

    #[test]
    fn sat_sub_clamps_at_min() {
        assert_eq!(Q16::MIN.sat_sub(Q16::ONE), Q16::MIN);
        assert_eq!(Q16::MAX.sat_sub(Q16::from_int(-1)), Q16::MAX);
        assert_eq!(Q16::from_int(5).sat_sub(Q16::from_int(3)), Q16::from_int(2));
    }

    #[test]
    fn sat_mul_is_one_identity() {
        for x in [-100, -1, 0, 1, 100, 32_767] {
            let q = Q16::from_int(x);
            assert_eq!(q.sat_mul(Q16::ONE), q);
            assert_eq!(Q16::ONE.sat_mul(q), q);
        }
    }

    #[test]
    fn sat_mul_half_times_half_is_quarter() {
        // 0.5 raw == 0x8000 == 32_768. 0.5 * 0.5 = 0.25 raw == 0x4000.
        let half = Q16::from_raw(0x8000);
        let quarter = Q16::from_raw(0x4000);
        assert_eq!(half.sat_mul(half), quarter);
    }

    #[test]
    fn sat_mul_handles_negative_signs() {
        let a = Q16::from_int(-3);
        let b = Q16::from_int(4);
        assert_eq!(a.sat_mul(b), Q16::from_int(-12));
        assert_eq!(b.sat_mul(a), Q16::from_int(-12));
        assert_eq!(a.sat_mul(a), Q16::from_int(9));
    }

    #[test]
    fn sat_mul_saturates_on_overflow() {
        // 30_000 * 30_000 = 9e8 — far beyond i32::MAX / 65_536 ≈ 32_768.
        let big = Q16::from_int(30_000);
        assert_eq!(big.sat_mul(big), Q16::MAX);
        assert_eq!(big.sat_mul(Q16::from_int(-30_000)), Q16::MIN);
    }

    #[test]
    fn sat_mul_rounds_raw_ties_to_even_for_both_signs() {
        // An odd raw value times 0.5 lands exactly on a midpoint between two
        // raw units, so the result must be the even neighbor.
        let half = Q16::from_raw(0x8000);
        let cases: [(i32, i32); 4] = [(1, 0), (3, 2), (-1, 0), (-3, -2)];
        for (raw, expected) in cases {
            assert_eq!(
                Q16::from_raw(raw).sat_mul(half).raw(),
                expected,
                "raw={raw}"
            );
        }
    }

    #[test]
    fn sat_div_one_is_identity() {
        for x in [-100, -1, 0, 1, 100] {
            let q = Q16::from_int(x);
            assert_eq!(q.sat_div(Q16::ONE), q);
        }
    }

    #[test]
    fn sat_div_by_zero_returns_zero() {
        assert_eq!(Q16::from_int(5).sat_div(Q16::ZERO), Q16::ZERO);
        assert_eq!(Q16::from_int(-5).sat_div(Q16::ZERO), Q16::ZERO);
    }

    #[test]
    fn sat_div_truncates_toward_zero() {
        // 1/3 is 21_845.33… raw units. Integer division drops the fraction
        // toward zero for both signs; no rounding step is applied.
        let third = Q16::from_int(1).sat_div(Q16::from_int(3));
        assert_eq!(third.raw(), 21_845);
        let neg_third = Q16::from_int(-1).sat_div(Q16::from_int(3));
        assert_eq!(neg_third.raw(), -21_845);
    }

    #[test]
    fn sat_div_saturates_on_overflow() {
        // Dividing by one raw unit scales the numerator by 65_536, which
        // leaves the i32 range for large magnitudes.
        let tiny = Q16::from_raw(1);
        assert_eq!(Q16::MAX.sat_div(tiny), Q16::MAX);
        assert_eq!(Q16::MIN.sat_div(tiny), Q16::MIN);
        assert_eq!(Q16::MIN.sat_div(Q16::from_raw(-1)), Q16::MAX);
    }

    #[test]
    fn abs_saturates_at_min() {
        assert_eq!(Q16::MIN.abs(), Q16::MAX);
        assert_eq!(Q16::from_int(-7).abs(), Q16::from_int(7));
        assert_eq!(Q16::from_int(7).abs(), Q16::from_int(7));
    }

    #[test]
    fn is_zero_only_for_zero_raw() {
        assert!(Q16::ZERO.is_zero());
        assert!(!Q16::from_raw(1).is_zero());
        assert!(!Q16::from_raw(-1).is_zero());
        assert!(!Q16::ONE.is_zero());
    }

    #[test]
    fn lerp_at_zero_returns_self() {
        let a = Q16::from_int(10);
        let b = Q16::from_int(20);
        assert_eq!(a.lerp(b, Q16::ZERO), a);
    }

    #[test]
    fn lerp_at_one_returns_other() {
        let a = Q16::from_int(10);
        let b = Q16::from_int(20);
        assert_eq!(a.lerp(b, Q16::ONE), b);
    }

    #[test]
    fn lerp_at_half_returns_midpoint() {
        let a = Q16::from_int(10);
        let b = Q16::from_int(20);
        let half = Q16::from_raw(0x8000);
        assert_eq!(a.lerp(b, half), Q16::from_int(15));
    }

    #[test]
    fn lerp_saturates_each_step_independently() {
        // MAX - MIN saturates to MAX, ONE * MAX is MAX, and MIN + MAX is -1
        // raw. Pinned so any backend reproduces the same boundary result.
        assert_eq!(Q16::MIN.lerp(Q16::MAX, Q16::ONE), Q16::from_raw(-1));
    }

    #[test]
    fn ewma_recurrence_with_locked_alpha_is_stable() {
        // alpha = 0.125 (0x2000 raw). Constant input x. The EWMA should
        // monotonically approach x without ever overshooting it.
        //
        // With banker's rounding and alpha = 1/8, convergence stalls when
        // the per-step delta drops to 4 raw Q16 units (4 * 8192 = 32_768,
        // which is exactly halfway and rounds to the even neighbor 0). So
        // the asymptotic gap is 4 raw units rather than 0 — a known
        // property of the discrete recurrence we are intentionally pinning.
        let alpha = Q16::from_raw(0x2000);
        let x = Q16::from_int(100);
        let mut ewma = Q16::ZERO;
        let mut last = Q16::MIN;
        for _ in 0..200 {
            ewma = ewma.lerp(x, alpha);
            // Each step is at least as close to x as the previous one
            // (monotone non-decreasing) and never overshoots.
            assert!(ewma.raw() >= last.raw());
            assert!(ewma.raw() <= x.raw());
            last = ewma;
        }
        // After convergence, the EWMA sits within the rounding floor of x.
        let final_gap = x.sat_sub(ewma).raw();
        assert!(
            final_gap <= 4,
            "expected gap ≤ 4 raw units, got {final_gap}"
        );
    }

    #[test]
    fn round_half_even_breaks_ties_to_even() {
        // 0.5 rounds to 0 (even); 1.5 rounds to 2 (even); 2.5 rounds to 2; 3.5 rounds to 4.
        // Test by constructing values whose discarded bits are exactly the halfway pattern.
        let cases: [(i64, i64); 4] = [
            (0b0_1000, 0),  // 0.5 ↦ 0
            (0b1_1000, 2),  // 1.5 ↦ 2
            (0b10_1000, 2), // 2.5 ↦ 2
            (0b11_1000, 4), // 3.5 ↦ 4
        ];
        for (input, expected) in cases {
            assert_eq!(
                round_half_even_i64_shift(input, 4),
                expected,
                "input={input:#b}"
            );
        }
    }

    #[test]
    fn round_half_even_handles_negative_midpoints() {
        // -0.5 ↦ 0 (even); -1.5 ↦ -2; -2.5 ↦ -2; -3.5 ↦ -4.
        let cases: [(i64, i64); 4] = [
            (-(0b0_1000_i64), 0),
            (-(0b1_1000_i64), -2),
            (-(0b10_1000_i64), -2),
            (-(0b11_1000_i64), -4),
        ];
        for (input, expected) in cases {
            assert_eq!(
                round_half_even_i64_shift(input, 4),
                expected,
                "input={input}"
            );
        }
    }

    #[test]
    fn round_half_even_rounds_non_ties_to_nearest() {
        // 0.25 ↦ 0; 0.75 ↦ 1; -0.25 ↦ 0; -0.75 ↦ -1. Negative non-ties go
        // to the nearest integer, not away from zero.
        let cases: [(i64, i64); 4] = [(4, 0), (12, 1), (-4, 0), (-12, -1)];
        for (input, expected) in cases {
            assert_eq!(
                round_half_even_i64_shift(input, 4),
                expected,
                "input={input}"
            );
        }
    }

    #[test]
    fn round_half_even_with_zero_shift_is_identity() {
        for value in [i64::MIN, -1, 0, 1, i64::MAX] {
            assert_eq!(round_half_even_i64_shift(value, 0), value);
        }
    }
}