1use crate::accumulator::BinnedAccumulatorF64;
21
/// IEEE 754 binary16 (half-precision) number held as its raw bit pattern.
///
/// Bit layout, most significant first: 1 sign bit, 5 exponent bits (bias 15),
/// 10 mantissa bits. The derived `PartialEq`/`Eq`/`Hash` compare the raw bits,
/// so identical NaN patterns compare equal and `ZERO != NEG_ZERO`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct F16(pub u16);

impl F16 {
    /// Positive zero.
    pub const ZERO: F16 = F16(0x0000);
    /// Negative zero.
    pub const NEG_ZERO: F16 = F16(0x8000);
    /// Positive infinity.
    pub const INFINITY: F16 = F16(0x7C00);
    /// Negative infinity.
    pub const NEG_INFINITY: F16 = F16(0xFC00);
    /// The NaN pattern produced by [`F16::from_f64`] for any NaN input.
    pub const NAN: F16 = F16(0x7E00);
    /// Largest finite value (`65504.0`).
    pub const MAX: F16 = F16(0x7BFF);
    /// Smallest positive subnormal (`2^-24`).
    pub const MIN_POSITIVE_SUBNORMAL: F16 = F16(0x0001);

    /// Raw 5-bit biased exponent field.
    fn exponent_field(self) -> u16 {
        (self.0 >> 10) & 0x1F
    }

    /// Raw 10-bit mantissa field.
    fn mantissa_field(self) -> u16 {
        self.0 & 0x03FF
    }

    /// Widens the value to `f64`. Every finite half-precision value is
    /// representable exactly; any NaN pattern becomes `f64::NAN`.
    pub fn to_f64(self) -> f64 {
        let negative = (self.0 & 0x8000) != 0;
        let exp = self.exponent_field();
        let mant = self.mantissa_field();

        let magnitude = match (exp, mant) {
            (0, 0) => 0.0,
            // Subnormal: no implicit leading one, fixed scale of 2^-24.
            (0, _) => f64::from(mant) * 2.0f64.powi(-24),
            (0x1F, 0) => f64::INFINITY,
            // NaN payload and sign are not carried over.
            (0x1F, _) => return f64::NAN,
            _ => 2.0f64.powi(i32::from(exp) - 15) * (1.0 + f64::from(mant) / 1024.0),
        };

        if negative {
            -magnitude
        } else {
            magnitude
        }
    }

    /// Rounds an `f64` to the nearest half-precision value, breaking ties
    /// towards the even mantissa (IEEE 754 round-to-nearest-even).
    ///
    /// * NaN maps to [`F16::NAN`].
    /// * Magnitudes of `65520.0` or more (including infinity) map to the
    ///   signed infinity; magnitudes between `65504.0` and `65520.0` round
    ///   down to [`F16::MAX`].
    /// * Magnitudes too small for a subnormal map to a zero that keeps the
    ///   input's sign.
    pub fn from_f64(value: f64) -> Self {
        if value.is_nan() {
            return F16::NAN;
        }

        let bits = value.to_bits();
        // Bit 63 of the f64 lands on bit 15; the mask keeps only that bit.
        let sign = ((bits >> 48) & 0x8000) as u16;
        // 11-bit field, always fits in i32.
        let biased = ((bits >> 52) & 0x7FF) as i32;

        if biased == 0x7FF {
            // NaN was handled above, so this is an infinity.
            return F16(sign | 0x7C00);
        }

        let exp = biased - 1023;
        if exp > 15 {
            // At least 2^16: beyond anything that could round to MAX.
            return F16(sign | 0x7C00);
        }
        if exp < -25 {
            // Below 2^-25, i.e. less than half the smallest subnormal. This
            // also covers f64 zeros and f64 subnormals (biased == 0).
            return F16(sign);
        }

        // 53-bit significand with the implicit leading one made explicit.
        let significand = (bits & 0x000F_FFFF_FFFF_FFFF) | (1u64 << 52);

        // `shift` is the number of low bits to discard. For normal results
        // the kept value still contains the leading one at bit 10, so the
        // exponent contribution is pre-decremented by one (`exp + 14`).
        // Subnormal results are expressed directly in units of 2^-24.
        let (shift, base) = if exp >= -14 {
            (42u32, ((exp + 14) as u16) << 10)
        } else {
            ((28 - exp) as u32, 0u16)
        };

        // At most 11 significant bits remain after the shift.
        let kept = (significand >> shift) as u16;
        let remainder = significand & ((1u64 << shift) - 1);
        let halfway = 1u64 << (shift - 1);
        let round_up = remainder > halfway || (remainder == halfway && (kept & 1) == 1);

        // Plain addition lets a mantissa carry bump the exponent: a subnormal
        // can become the smallest normal and MAX can become infinity.
        F16(sign | (base + kept + u16::from(round_up)))
    }

    /// Converts from `f32` by widening to `f64` (lossless) and then rounding
    /// with [`F16::from_f64`].
    pub fn from_f32(value: f32) -> Self {
        Self::from_f64(f64::from(value))
    }

    /// Converts to `f32` by way of [`F16::to_f64`].
    pub fn to_f32(self) -> f32 {
        self.to_f64() as f32
    }

    /// True for any NaN pattern (exponent all ones, non-zero mantissa).
    pub fn is_nan(self) -> bool {
        self.exponent_field() == 0x1F && self.mantissa_field() != 0
    }

    /// True for positive or negative infinity.
    pub fn is_infinite(self) -> bool {
        self.exponent_field() == 0x1F && self.mantissa_field() == 0
    }

    /// True when the value is neither infinite nor NaN.
    pub fn is_finite(self) -> bool {
        self.exponent_field() != 0x1F
    }

    /// True for non-zero values with a zero exponent field.
    pub fn is_subnormal(self) -> bool {
        self.exponent_field() == 0 && self.mantissa_field() != 0
    }

    /// `self + rhs`, evaluated in `f64` and rounded back with [`F16::from_f64`].
    pub fn add(self, rhs: Self) -> Self {
        Self::from_f64(self.to_f64() + rhs.to_f64())
    }

    /// `self - rhs`, evaluated in `f64` and rounded back with [`F16::from_f64`].
    pub fn sub(self, rhs: Self) -> Self {
        Self::from_f64(self.to_f64() - rhs.to_f64())
    }

    /// `self * rhs`, evaluated in `f64` and rounded back with [`F16::from_f64`].
    pub fn mul(self, rhs: Self) -> Self {
        Self::from_f64(self.to_f64() * rhs.to_f64())
    }

    /// `self / rhs`, evaluated in `f64` and rounded back with [`F16::from_f64`].
    pub fn div(self, rhs: Self) -> Self {
        Self::from_f64(self.to_f64() / rhs.to_f64())
    }

    /// Flips the sign bit. No rounding is involved; this also flips the sign
    /// bit of zeros, infinities and NaN patterns.
    pub fn neg(self) -> Self {
        F16(self.0 ^ 0x8000)
    }
}
185
186impl std::fmt::Display for F16 {
187 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188 write!(f, "{}", self.to_f64())
189 }
190}
191
192pub fn f16_binned_sum(values: &[F16]) -> f64 {
201 let mut acc = BinnedAccumulatorF64::new();
202 for &v in values {
203 acc.add(v.to_f64());
204 }
205 acc.finalize()
206}
207
208pub fn f16_binned_dot(a: &[F16], b: &[F16]) -> f64 {
210 debug_assert_eq!(a.len(), b.len());
211 let mut acc = BinnedAccumulatorF64::new();
212 for i in 0..a.len() {
213 acc.add(a[i].to_f64() * b[i].to_f64());
215 }
216 acc.finalize()
217}
218
219pub fn f16_matmul(
223 a: &[F16], b: &[F16], out: &mut [f64],
224 m: usize, k: usize, n: usize,
225) {
226 debug_assert_eq!(a.len(), m * k);
227 debug_assert_eq!(b.len(), k * n);
228 debug_assert_eq!(out.len(), m * n);
229
230 for i in 0..m {
231 for j in 0..n {
232 let mut acc = BinnedAccumulatorF64::new();
233 for p in 0..k {
234 let av = a[i * k + p].to_f64();
235 let bv = b[p * n + j].to_f64();
236 acc.add(av * bv);
237 }
238 out[i * n + j] = acc.finalize();
239 }
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    // `ZERO` widens to a positive zero.
    #[test]
    fn test_f16_zero() {
        let widened = F16::ZERO.to_f64();
        assert_eq!(widened, 0.0);
        assert!(widened.is_sign_positive());
    }

    // `NEG_ZERO` widens to a zero whose sign bit is set.
    #[test]
    fn test_f16_neg_zero() {
        let widened = F16::NEG_ZERO.to_f64();
        assert_eq!(widened, 0.0);
        assert!(widened.is_sign_negative());
    }

    // 1.0 survives a conversion in both directions.
    #[test]
    fn test_f16_one() {
        assert_eq!(F16::from_f64(1.0).to_f64(), 1.0);
    }

    // `MAX` is 65504.
    #[test]
    fn test_f16_max() {
        assert_eq!(F16::MAX.to_f64(), 65504.0);
    }

    // `INFINITY` widens to +inf.
    #[test]
    fn test_f16_infinity() {
        let widened = F16::INFINITY.to_f64();
        assert!(widened.is_infinite());
        assert!(widened.is_sign_positive());
    }

    // `NEG_INFINITY` widens to -inf.
    #[test]
    fn test_f16_neg_infinity() {
        let widened = F16::NEG_INFINITY.to_f64();
        assert!(widened.is_infinite());
        assert!(widened.is_sign_negative());
    }

    // `NAN` is NaN both after widening and according to `is_nan`.
    #[test]
    fn test_f16_nan() {
        assert!(F16::NAN.to_f64().is_nan());
        assert!(F16::NAN.is_nan());
    }

    // The smallest subnormal is positive, classified as subnormal, and 2^-24.
    #[test]
    fn test_f16_subnormal() {
        let tiny = F16::MIN_POSITIVE_SUBNORMAL;
        let widened = tiny.to_f64();
        assert!(widened > 0.0);
        assert!(tiny.is_subnormal());
        assert!((widened - 5.960464477539063e-8).abs() < 1e-15);
    }

    // Exactly representable values come back unchanged.
    #[test]
    fn test_f16_roundtrip() {
        let samples = [0.0, 1.0, -1.0, 0.5, 2.0, 100.0, -0.25, 65504.0];
        for v in samples.iter().copied() {
            let back = F16::from_f64(v).to_f64();
            assert_eq!(back, v, "Roundtrip failed for {v}");
        }
    }

    // Values far above `MAX` become infinity.
    #[test]
    fn test_f16_overflow_to_inf() {
        assert!(F16::from_f64(100000.0).is_infinite());
    }

    // Values far below the subnormal range become zero.
    #[test]
    fn test_f16_underflow_to_zero() {
        assert_eq!(F16::from_f64(1e-10).to_f64(), 0.0);
    }

    // Small-integer arithmetic is exact.
    #[test]
    fn test_f16_arithmetic() {
        let two = F16::from_f64(2.0);
        let three = F16::from_f64(3.0);
        assert_eq!(two.add(three).to_f64(), 5.0);
        assert_eq!(two.mul(three).to_f64(), 6.0);
        assert_eq!(three.sub(two).to_f64(), 1.0);
    }

    // Negation flips the sign, and negating twice restores the bits.
    #[test]
    fn test_f16_neg_preserves_bits() {
        let original = F16::from_f64(3.5);
        let flipped = original.neg();
        assert_eq!(flipped.to_f64(), -3.5);
        assert_eq!(flipped.neg().0, original.0);
    }

    // 0 + 1 + ... + 9 == 45.
    #[test]
    fn test_f16_binned_sum_basic() {
        let inputs: Vec<F16> = (0..10).map(|i| F16::from_f64(i as f64)).collect();
        assert_eq!(f16_binned_sum(&inputs), 45.0);
    }

    // Summing forwards and backwards yields bit-identical results.
    #[test]
    fn test_f16_binned_sum_order_invariant() {
        let forward: Vec<F16> = (0..200)
            .map(|i| F16::from_f64(i as f64 * 0.5 - 50.0))
            .collect();
        let backward: Vec<F16> = forward.iter().rev().copied().collect();

        let r1 = f16_binned_sum(&forward);
        let r2 = f16_binned_sum(&backward);
        assert_eq!(r1.to_bits(), r2.to_bits(), "f16 sum must be order-invariant");
    }

    // [1, 2, 3] . [4, 5, 6] == 32.
    #[test]
    fn test_f16_dot_basic() {
        let lhs = [1.0, 2.0, 3.0].map(F16::from_f64);
        let rhs = [4.0, 5.0, 6.0].map(F16::from_f64);
        assert_eq!(f16_binned_dot(&lhs, &rhs), 32.0);
    }

    // Multiplying by the 2x2 identity returns the other operand.
    #[test]
    fn test_f16_matmul_identity() {
        let identity = [1.0, 0.0, 0.0, 1.0].map(F16::from_f64);
        let rhs = [3.0, 4.0, 5.0, 6.0].map(F16::from_f64);
        let mut product = [0.0f64; 4];
        f16_matmul(&identity, &rhs, &mut product, 2, 2, 2);
        assert_eq!(product, [3.0, 4.0, 5.0, 6.0]);
    }

    // A thousand copies of the smallest subnormal sum to 1000 * 2^-24.
    #[test]
    fn test_f16_subnormal_accumulation() {
        let tiny = F16::MIN_POSITIVE_SUBNORMAL;
        let inputs = [tiny; 1000];
        let result = f16_binned_sum(&inputs);
        let expected = tiny.to_f64() * 1000.0;
        assert!(
            (result - expected).abs() < 1e-12,
            "Subnormal accumulation: {result} vs expected {expected}"
        );
    }

    // Both zeros keep their sign through conversion in either direction.
    #[test]
    fn test_f16_signed_zero_preserved() {
        assert!(F16::ZERO.to_f64().is_sign_positive());
        assert!(F16::NEG_ZERO.to_f64().is_sign_negative());
        assert_eq!(F16::from_f64(0.0).0, F16::ZERO.0);
        assert_eq!(F16::from_f64(-0.0).0, F16::NEG_ZERO.0);
    }
}