Skip to main content

cjc_runtime/
f16.rs

//! Half Precision (f16) — IEEE 754 binary16 with promotion to f64.
//!
//! # Design
//!
//! f16 values are promoted to f64 before entering the binned accumulation
//! path. Subnormal handling is preserved by the bin 0 logic of the
//! BinnedAccumulator. Arithmetic is performed in f64, then narrowed back
//! to f16 on storage.
//!
//! # IEEE 754 binary16 Layout
//!
//! ```text
//! Bit 15:     sign
//! Bits 14-10: exponent (5 bits, bias = 15)
//! Bits 9-0:   mantissa (10 bits)
//! ```
//!
//! Range: ±65504 (max finite), ±6.1e-5 (min positive normal),
//! ±5.96e-8 (min positive subnormal)
19
20use crate::accumulator::BinnedAccumulatorF64;
21
22// ---------------------------------------------------------------------------
23// F16 Type
24// ---------------------------------------------------------------------------
25
/// IEEE 754 binary16 half-precision float.
///
/// Stored as the raw 16-bit pattern. All arithmetic is performed by
/// promoting to f64, computing, then narrowing back with
/// round-to-nearest, ties-to-even. Both conversions use only integer
/// operations on the bit patterns (plus one exact division by a power of
/// two), so results do not depend on the platform's math library.
///
/// The derived `PartialEq`/`Eq`/`Hash` compare raw bits: `+0` and `-0`
/// are distinct, and a NaN equals another NaN with the same bit pattern.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct F16(pub u16);

impl F16 {
    /// Positive zero.
    pub const ZERO: F16 = F16(0x0000);
    /// Negative zero.
    pub const NEG_ZERO: F16 = F16(0x8000);
    /// Positive infinity.
    pub const INFINITY: F16 = F16(0x7C00);
    /// Negative infinity.
    pub const NEG_INFINITY: F16 = F16(0xFC00);
    /// Canonical NaN.
    pub const NAN: F16 = F16(0x7E00);
    /// Maximum finite value: 65504.0.
    pub const MAX: F16 = F16(0x7BFF);
    /// Minimum positive subnormal: 2^-24.
    pub const MIN_POSITIVE_SUBNORMAL: F16 = F16(0x0001);

    /// Convert f16 to f64.
    ///
    /// Handles normals, subnormals, zeros, infinities, and NaNs. Every
    /// non-NaN binary16 value is exactly representable in f64, so the
    /// conversion is lossless. Any NaN input yields `f64::NAN`; its sign
    /// and payload are dropped.
    pub fn to_f64(self) -> f64 {
        let bits = self.0;
        let negative = (bits & 0x8000) != 0;
        let exp = (bits >> 10) & 0x1F;
        let mant = bits & 0x03FF;

        let magnitude = match (exp, mant) {
            // Zero; the sign is applied below so -0.0 is preserved.
            (0, 0) => 0.0,
            // Subnormal: no implicit leading 1, each mantissa step is 2^-24.
            // Dividing by 2^24 is exact.
            (0, m) => f64::from(m) / 16_777_216.0,
            (0x1F, 0) => f64::INFINITY,
            // NaN — canonicalize to a single NaN value.
            (0x1F, _) => return f64::NAN,
            // Normal: re-bias the exponent (15 -> 1023, i.e. +1008) and
            // left-align the 10 mantissa bits in f64's 52-bit fraction.
            (e, m) => f64::from_bits(((u64::from(e) + 1008) << 52) | (u64::from(m) << 42)),
        };

        if negative {
            -magnitude
        } else {
            magnitude
        }
    }

    /// Convert f64 to f16 using round-to-nearest, ties-to-even.
    ///
    /// * NaN maps to [`F16::NAN`].
    /// * Magnitudes of 65520 or more (the midpoint between `MAX` and the
    ///   next power of two) and infinities map to signed infinity; finite
    ///   magnitudes below 65520 round to at most `MAX`.
    /// * Magnitudes of 2^-25 or less map to signed zero.
    /// * Results below 2^-14 are encoded as subnormals; rounding may carry
    ///   a subnormal up into the smallest normal value.
    pub fn from_f64(value: f64) -> Self {
        if value.is_nan() {
            return F16::NAN;
        }

        let bits = value.to_bits();
        // Move the f64 sign (bit 63) down to the f16 sign position (bit 15).
        let sign = ((bits >> 48) & 0x8000) as u16;
        let biased = ((bits >> 52) & 0x7FF) as i32;

        // Infinity (NaN was handled above).
        if biased == 0x7FF {
            return F16(sign | 0x7C00);
        }

        let exp = biased - 1023;
        // |value| >= 2^16 is past the rounding midpoint 65520: overflow.
        if exp > 15 {
            return F16(sign | 0x7C00);
        }
        // |value| < 2^-25 is below half the smallest subnormal. This branch
        // also catches f64 zeros and f64 subnormals (biased exponent 0).
        if exp < -25 {
            return F16(sign);
        }

        // 53-bit significand with the implicit leading 1 made explicit.
        let significand = (bits & 0x000F_FFFF_FFFF_FFFF) | (1u64 << 52);

        // Normal results keep 11 significant bits (shift by 42). Subnormal
        // results are expressed in units of 2^-24, which needs a larger
        // shift the smaller the value is (43..=53).
        let (exp_field, shift) = if exp >= -14 {
            ((exp + 14) as u16, 42u32)
        } else {
            (0u16, (28 - exp) as u32)
        };

        let kept = (significand >> shift) as u16;
        let remainder = significand & ((1u64 << shift) - 1);
        let half = 1u64 << (shift - 1);
        let round_up = remainder > half || (remainder == half && (kept & 1) == 1);

        // For normal results `kept` still contains the implicit bit (value
        // 1024), which is why `exp_field` is biased by 14 rather than 15.
        // Using addition lets a rounding carry propagate naturally:
        // subnormal -> smallest normal, mantissa overflow -> next exponent,
        // and largest finite -> infinity (0x7C00).
        F16(sign | ((exp_field << 10) + kept + u16::from(round_up)))
    }

    /// Convert f32 to f16.
    ///
    /// Widening f32 to f64 is exact, so the value is rounded only once.
    pub fn from_f32(value: f32) -> Self {
        Self::from_f64(f64::from(value))
    }

    /// Convert f16 to f32.
    pub fn to_f32(self) -> f32 {
        self.to_f64() as f32
    }

    /// Check if this is NaN (exponent all ones, mantissa non-zero).
    pub fn is_nan(self) -> bool {
        (self.0 & 0x7FFF) > 0x7C00
    }

    /// Check if this is infinite (exponent all ones, mantissa zero).
    pub fn is_infinite(self) -> bool {
        (self.0 & 0x7FFF) == 0x7C00
    }

    /// Check if this is finite (not NaN or Inf).
    pub fn is_finite(self) -> bool {
        (self.0 & 0x7C00) != 0x7C00
    }

    /// Check if this is subnormal (exponent == 0, mantissa != 0).
    pub fn is_subnormal(self) -> bool {
        (self.0 & 0x7C00) == 0 && (self.0 & 0x03FF) != 0
    }

    /// Add two f16 values. Promotes both to f64, adds, then narrows back.
    pub fn add(self, rhs: Self) -> Self {
        Self::from_f64(self.to_f64() + rhs.to_f64())
    }

    /// Subtract `rhs` from `self`. Promotes both to f64, subtracts, then narrows back.
    pub fn sub(self, rhs: Self) -> Self {
        Self::from_f64(self.to_f64() - rhs.to_f64())
    }

    /// Multiply two f16 values. Promotes both to f64, multiplies, then narrows back.
    pub fn mul(self, rhs: Self) -> Self {
        Self::from_f64(self.to_f64() * rhs.to_f64())
    }

    /// Divide `self` by `rhs`. Promotes both to f64, divides, then narrows back.
    pub fn div(self, rhs: Self) -> Self {
        Self::from_f64(self.to_f64() / rhs.to_f64())
    }

    /// Negate by toggling the sign bit. Does not promote to f64, so NaN
    /// payloads and signed zeros are flipped bit-for-bit.
    pub fn neg(self) -> Self {
        F16(self.0 ^ 0x8000)
    }
}
189
190impl std::fmt::Display for F16 {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        write!(f, "{}", self.to_f64())
193    }
194}
195
196// ---------------------------------------------------------------------------
197// f16 Accumulation via BinnedAccumulator
198// ---------------------------------------------------------------------------
199
200/// Sum f16 values by promoting to f64 and using BinnedAccumulator.
201///
202/// This ensures order-invariant, deterministic results regardless of
203/// the f16 precision limitations.
204pub fn f16_binned_sum(values: &[F16]) -> f64 {
205    let mut acc = BinnedAccumulatorF64::new();
206    for &v in values {
207        acc.add(v.to_f64());
208    }
209    acc.finalize()
210}
211
212/// Dot product of two f16 slices, accumulated in f64 via BinnedAccumulator.
213pub fn f16_binned_dot(a: &[F16], b: &[F16]) -> f64 {
214    debug_assert_eq!(a.len(), b.len());
215    let mut acc = BinnedAccumulatorF64::new();
216    for i in 0..a.len() {
217        // Promote both operands to f64, multiply in f64, then accumulate.
218        acc.add(a[i].to_f64() * b[i].to_f64());
219    }
220    acc.finalize()
221}
222
223/// Matrix multiply for f16 arrays, computing in f64 via BinnedAccumulator.
224///
225/// Result is in f64 for full precision.
226pub fn f16_matmul(
227    a: &[F16], b: &[F16], out: &mut [f64],
228    m: usize, k: usize, n: usize,
229) {
230    debug_assert_eq!(a.len(), m * k);
231    debug_assert_eq!(b.len(), k * n);
232    debug_assert_eq!(out.len(), m * n);
233
234    for i in 0..m {
235        for j in 0..n {
236            let mut acc = BinnedAccumulatorF64::new();
237            for p in 0..k {
238                let av = a[i * k + p].to_f64();
239                let bv = b[p * n + j].to_f64();
240                acc.add(av * bv);
241            }
242            out[i * n + j] = acc.finalize();
243        }
244    }
245}
246
247// ---------------------------------------------------------------------------
248// Inline tests
249// ---------------------------------------------------------------------------
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Narrows every f64 in `xs` to half precision.
    fn halves(xs: &[f64]) -> Vec<F16> {
        xs.iter().map(|&x| F16::from_f64(x)).collect()
    }

    // --- Special values -----------------------------------------------------

    #[test]
    fn test_f16_zero() {
        let widened = F16::ZERO.to_f64();
        assert_eq!(widened, 0.0);
        assert!(widened.is_sign_positive());
    }

    #[test]
    fn test_f16_neg_zero() {
        let widened = F16::NEG_ZERO.to_f64();
        // -0.0 compares equal to 0.0; the sign is checked separately.
        assert_eq!(widened, 0.0);
        assert!(widened.is_sign_negative());
    }

    #[test]
    fn test_f16_one() {
        assert_eq!(F16::from_f64(1.0).to_f64(), 1.0);
    }

    #[test]
    fn test_f16_max() {
        assert_eq!(F16::MAX.to_f64(), 65504.0);
    }

    #[test]
    fn test_f16_infinity() {
        let widened = F16::INFINITY.to_f64();
        assert!(widened.is_infinite());
        assert!(widened.is_sign_positive());
    }

    #[test]
    fn test_f16_neg_infinity() {
        let widened = F16::NEG_INFINITY.to_f64();
        assert!(widened.is_infinite());
        assert!(widened.is_sign_negative());
    }

    #[test]
    fn test_f16_nan() {
        assert!(F16::NAN.to_f64().is_nan());
        assert!(F16::NAN.is_nan());
    }

    #[test]
    fn test_f16_subnormal() {
        let tiny = F16::MIN_POSITIVE_SUBNORMAL;
        let val = tiny.to_f64();
        assert!(val > 0.0);
        assert!(tiny.is_subnormal());
        // Smallest f16 subnormal: 2^(-24) ≈ 5.96e-8
        assert!((val - 5.960464477539063e-8).abs() < 1e-15);
    }

    // --- Conversions --------------------------------------------------------

    #[test]
    fn test_f16_roundtrip() {
        // Every value here is exactly representable in binary16.
        let exact = [0.0, 1.0, -1.0, 0.5, 2.0, 100.0, -0.25, 65504.0];
        for v in exact.iter().copied() {
            let back = F16::from_f64(v).to_f64();
            assert_eq!(back, v, "Roundtrip failed for {v}");
        }
    }

    #[test]
    fn test_f16_overflow_to_inf() {
        assert!(F16::from_f64(100000.0).is_infinite());
    }

    #[test]
    fn test_f16_underflow_to_zero() {
        assert_eq!(F16::from_f64(1e-10).to_f64(), 0.0);
    }

    // --- Arithmetic ---------------------------------------------------------

    #[test]
    fn test_f16_arithmetic() {
        let two = F16::from_f64(2.0);
        let three = F16::from_f64(3.0);
        assert_eq!(two.add(three).to_f64(), 5.0);
        assert_eq!(two.mul(three).to_f64(), 6.0);
        assert_eq!(three.sub(two).to_f64(), 1.0);
    }

    #[test]
    fn test_f16_neg_preserves_bits() {
        let original = F16::from_f64(3.5);
        let flipped = original.neg();
        assert_eq!(flipped.to_f64(), -3.5);
        // Double negation round-trips.
        assert_eq!(flipped.neg().0, original.0);
    }

    // --- Binned accumulation ------------------------------------------------

    #[test]
    fn test_f16_binned_sum_basic() {
        let inputs: Vec<F16> = (0..10).map(|i| F16::from_f64(i as f64)).collect();
        assert_eq!(f16_binned_sum(&inputs), 45.0);
    }

    #[test]
    fn test_f16_binned_sum_order_invariant() {
        let forward: Vec<F16> = (0..200)
            .map(|i| F16::from_f64(i as f64 * 0.5 - 50.0))
            .collect();
        let backward: Vec<F16> = forward.iter().rev().copied().collect();

        let r1 = f16_binned_sum(&forward);
        let r2 = f16_binned_sum(&backward);
        assert_eq!(r1.to_bits(), r2.to_bits(), "f16 sum must be order-invariant");
    }

    #[test]
    fn test_f16_dot_basic() {
        let lhs = halves(&[1.0, 2.0, 3.0]);
        let rhs = halves(&[4.0, 5.0, 6.0]);
        assert_eq!(f16_binned_dot(&lhs, &rhs), 32.0);
    }

    #[test]
    fn test_f16_matmul_identity() {
        let identity = halves(&[
            1.0, 0.0,
            0.0, 1.0,
        ]);
        let rhs = halves(&[
            3.0, 4.0,
            5.0, 6.0,
        ]);
        let mut product = vec![0.0f64; 4];
        f16_matmul(&identity, &rhs, &mut product, 2, 2, 2);
        assert_eq!(product, vec![3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_f16_subnormal_accumulation() {
        // Test that subnormals are correctly accumulated via the binned path.
        let tiny = F16::MIN_POSITIVE_SUBNORMAL;
        let result = f16_binned_sum(&vec![tiny; 1000]);
        let expected = tiny.to_f64() * 1000.0;
        assert!((result - expected).abs() < 1e-12,
            "Subnormal accumulation: {result} vs expected {expected}");
    }

    #[test]
    fn test_f16_signed_zero_preserved() {
        assert!(F16::ZERO.to_f64().is_sign_positive());
        assert!(F16::NEG_ZERO.to_f64().is_sign_negative());
        // From f64 preserves sign.
        assert_eq!(F16::from_f64(0.0).0, F16::ZERO.0);
        assert_eq!(F16::from_f64(-0.0).0, F16::NEG_ZERO.0);
    }
}
414}