1use crate::accumulator::BinnedAccumulatorF64;
21
/// IEEE 754 binary16 ("half precision") value stored as its raw bit pattern.
///
/// Bit layout (most significant first): 1 sign bit, 5 exponent bits with a
/// bias of 15, and 10 mantissa bits.
///
/// The derived `PartialEq`, `Eq` and `Hash` operate on the raw bits. This means
/// two NaNs with the same payload compare equal and `ZERO != NEG_ZERO`; convert
/// with [`F16::to_f64`] first when numeric comparison is wanted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct F16(pub u16);

impl F16 {
    /// Positive zero.
    pub const ZERO: F16 = F16(0x0000);
    /// Negative zero.
    pub const NEG_ZERO: F16 = F16(0x8000);
    /// Positive infinity.
    pub const INFINITY: F16 = F16(0x7C00);
    /// Negative infinity.
    pub const NEG_INFINITY: F16 = F16(0xFC00);
    /// The NaN bit pattern produced by this type's conversions.
    pub const NAN: F16 = F16(0x7E00);
    /// Largest finite value (65504).
    pub const MAX: F16 = F16(0x7BFF);
    /// Smallest positive subnormal value (2^-24).
    pub const MIN_POSITIVE_SUBNORMAL: F16 = F16(0x0001);

    const SIGN_MASK: u16 = 0x8000;
    const EXP_MASK: u16 = 0x7C00;
    const MANT_MASK: u16 = 0x03FF;

    /// Widens to `f64`. Every finite binary16 value is exactly representable,
    /// so the result is exact.
    ///
    /// The result is assembled directly from bits instead of going through
    /// `powi`, whose precision the standard library leaves unspecified.
    /// Every NaN input maps to `f64::NAN`; the payload and sign are dropped.
    pub fn to_f64(self) -> f64 {
        let sign = u64::from(self.0 >> 15) << 63;
        let exp = (self.0 & Self::EXP_MASK) >> 10;
        let mant = u64::from(self.0 & Self::MANT_MASK);

        match (exp, mant) {
            // Signed zero.
            (0, 0) => f64::from_bits(sign),
            // Subnormal: mant * 2^-24. Scaling by a power of two is exact.
            (0, _) => {
                let scale = f64::from_bits((1023u64 - 24) << 52);
                let magnitude = mant as f64 * scale;
                f64::from_bits(magnitude.to_bits() | sign)
            }
            // Signed infinity.
            (0x1F, 0) => f64::from_bits(sign | 0x7FF0_0000_0000_0000),
            (0x1F, _) => f64::NAN,
            // Normal: re-bias the exponent (15 -> 1023) and left-align the
            // 10-bit mantissa inside the 52-bit f64 mantissa field.
            _ => {
                let exp64 = u64::from(exp) + (1023 - 15);
                f64::from_bits(sign | (exp64 << 52) | (mant << 42))
            }
        }
    }

    /// Narrows an `f64` to binary16 using IEEE 754 round-to-nearest,
    /// ties-to-even.
    ///
    /// * NaN maps to [`F16::NAN`].
    /// * Magnitudes of 65520 and above (including infinity) become a signed
    ///   infinity; magnitudes between 65504 and 65520 round down to
    ///   [`F16::MAX`].
    /// * Magnitudes below 2^-14 are rounded onto the subnormal grid (multiples
    ///   of 2^-24); a value that rounds up past the largest subnormal becomes
    ///   the smallest normal.
    /// * The sign of zero is preserved, also when a tiny value underflows.
    pub fn from_f64(value: f64) -> Self {
        if value.is_nan() {
            return F16::NAN;
        }

        let bits = value.to_bits();
        // Bit 63 of the f64 lands on bit 15 of the half.
        let sign = ((bits >> 48) & 0x8000) as u16;
        let biased = ((bits >> 52) & 0x7FF) as i32;
        let frac = bits & 0x000F_FFFF_FFFF_FFFF;

        if biased == 0x7FF {
            return F16(sign | Self::EXP_MASK);
        }

        let exp = biased - 1023;
        if exp > 15 {
            // At least 2^16: beyond anything that could round to MAX.
            return F16(sign | Self::EXP_MASK);
        }
        if exp < -25 {
            // Below half of the smallest subnormal (covers zero and every
            // f64 subnormal as well): rounds to signed zero.
            return F16(sign);
        }

        let magnitude = if exp >= -14 {
            // Normal range. Adding (not OR-ing) the rounded mantissa lets a
            // carry out of the mantissa bump the exponent, and turns
            // MAX + carry into the infinity pattern 0x7C00.
            (((exp + 15) as u64) << 10) + Self::round_shift_ties_even(frac, 42)
        } else {
            // Subnormal range: express the full significand in units of
            // 2^-24. A result of 1024 is exactly the smallest normal.
            let significand = frac | (1u64 << 52);
            Self::round_shift_ties_even(significand, (28 - exp) as u32)
        };

        // magnitude is at most 0x7C00, so the cast cannot truncate.
        F16(sign | magnitude as u16)
    }

    /// Computes `value / 2^shift` rounded to nearest, ties to even.
    /// `shift` must be in `1..64`.
    fn round_shift_ties_even(value: u64, shift: u32) -> u64 {
        let kept = value >> shift;
        let remainder = value & ((1u64 << shift) - 1);
        let half = 1u64 << (shift - 1);
        if remainder > half || (remainder == half && (kept & 1) == 1) {
            kept + 1
        } else {
            kept
        }
    }

    /// Narrows an `f32` to binary16 (widening to `f64` first is exact, so the
    /// value is rounded only once).
    pub fn from_f32(value: f32) -> Self {
        Self::from_f64(f64::from(value))
    }

    /// Widens to `f32`; exact for every non-NaN value.
    pub fn to_f32(self) -> f32 {
        self.to_f64() as f32
    }

    /// True when the exponent field is all ones and the mantissa is non-zero.
    pub fn is_nan(self) -> bool {
        (self.0 & !Self::SIGN_MASK) > Self::EXP_MASK
    }

    /// True for positive or negative infinity.
    pub fn is_infinite(self) -> bool {
        (self.0 & !Self::SIGN_MASK) == Self::EXP_MASK
    }

    /// True for anything that is neither infinite nor NaN.
    pub fn is_finite(self) -> bool {
        (self.0 & Self::EXP_MASK) != Self::EXP_MASK
    }

    /// True for non-zero values with a zero exponent field.
    pub fn is_subnormal(self) -> bool {
        (self.0 & Self::EXP_MASK) == 0 && (self.0 & Self::MANT_MASK) != 0
    }

    /// Sum computed in `f64`, then rounded to binary16 by [`F16::from_f64`].
    pub fn add(self, rhs: Self) -> Self {
        Self::from_f64(self.to_f64() + rhs.to_f64())
    }

    /// Difference computed in `f64`, then rounded to binary16.
    pub fn sub(self, rhs: Self) -> Self {
        Self::from_f64(self.to_f64() - rhs.to_f64())
    }

    /// Product computed in `f64`, then rounded to binary16.
    pub fn mul(self, rhs: Self) -> Self {
        Self::from_f64(self.to_f64() * rhs.to_f64())
    }

    /// Quotient computed in `f64`, then rounded to binary16.
    pub fn div(self, rhs: Self) -> Self {
        Self::from_f64(self.to_f64() / rhs.to_f64())
    }

    /// Flips the sign bit only; no rounding occurs and NaN payloads and
    /// zeros keep their remaining bits.
    pub fn neg(self) -> Self {
        F16(self.0 ^ Self::SIGN_MASK)
    }
}

impl std::fmt::Display for F16 {
    /// Formats the value exactly as the widened `f64` would be formatted,
    /// honouring width, precision and sign flags supplied by the caller.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.to_f64(), f)
    }
}
195
196pub fn f16_binned_sum(values: &[F16]) -> f64 {
205 let mut acc = BinnedAccumulatorF64::new();
206 for &v in values {
207 acc.add(v.to_f64());
208 }
209 acc.finalize()
210}
211
212pub fn f16_binned_dot(a: &[F16], b: &[F16]) -> f64 {
214 debug_assert_eq!(a.len(), b.len());
215 let mut acc = BinnedAccumulatorF64::new();
216 for i in 0..a.len() {
217 acc.add(a[i].to_f64() * b[i].to_f64());
219 }
220 acc.finalize()
221}
222
223pub fn f16_matmul(
227 a: &[F16], b: &[F16], out: &mut [f64],
228 m: usize, k: usize, n: usize,
229) {
230 debug_assert_eq!(a.len(), m * k);
231 debug_assert_eq!(b.len(), k * n);
232 debug_assert_eq!(out.len(), m * n);
233
234 for i in 0..m {
235 for j in 0..n {
236 let mut acc = BinnedAccumulatorF64::new();
237 for p in 0..k {
238 let av = a[i * k + p].to_f64();
239 let bv = b[p * n + j].to_f64();
240 acc.add(av * bv);
241 }
242 out[i * n + j] = acc.finalize();
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an `F16` from an `f64`.
    fn half(x: f64) -> F16 {
        F16::from_f64(x)
    }

    // ---- constants and special values ----

    #[test]
    fn test_f16_zero() {
        let widened = F16::ZERO.to_f64();
        assert_eq!(widened, 0.0);
        assert!(widened.is_sign_positive());
    }

    #[test]
    fn test_f16_neg_zero() {
        let widened = F16::NEG_ZERO.to_f64();
        // -0.0 == 0.0 numerically; the sign is checked separately.
        assert_eq!(widened, 0.0);
        assert!(widened.is_sign_negative());
    }

    #[test]
    fn test_f16_one() {
        assert_eq!(half(1.0).to_f64(), 1.0);
    }

    #[test]
    fn test_f16_max() {
        assert_eq!(F16::MAX.to_f64(), 65504.0);
    }

    #[test]
    fn test_f16_infinity() {
        let widened = F16::INFINITY.to_f64();
        assert!(widened.is_infinite());
        assert!(widened.is_sign_positive());
    }

    #[test]
    fn test_f16_neg_infinity() {
        let widened = F16::NEG_INFINITY.to_f64();
        assert!(widened.is_infinite());
        assert!(widened.is_sign_negative());
    }

    #[test]
    fn test_f16_nan() {
        assert!(F16::NAN.to_f64().is_nan());
        assert!(F16::NAN.is_nan());
    }

    #[test]
    fn test_f16_subnormal() {
        let tiny = F16::MIN_POSITIVE_SUBNORMAL;
        let widened = tiny.to_f64();
        assert!(widened > 0.0);
        assert!(tiny.is_subnormal());
        // 2^-24
        assert!((widened - 5.960464477539063e-8).abs() < 1e-15);
    }

    // ---- conversions ----

    #[test]
    fn test_f16_roundtrip() {
        let exactly_representable = [0.0, 1.0, -1.0, 0.5, 2.0, 100.0, -0.25, 65504.0];
        for v in exactly_representable.iter().copied() {
            let back = half(v).to_f64();
            assert_eq!(back, v, "Roundtrip failed for {v}");
        }
    }

    #[test]
    fn test_f16_overflow_to_inf() {
        assert!(half(100000.0).is_infinite());
    }

    #[test]
    fn test_f16_underflow_to_zero() {
        assert_eq!(half(1e-10).to_f64(), 0.0);
    }

    // ---- arithmetic ----

    #[test]
    fn test_f16_arithmetic() {
        let two = half(2.0);
        let three = half(3.0);
        assert_eq!(two.add(three).to_f64(), 5.0);
        assert_eq!(two.mul(three).to_f64(), 6.0);
        assert_eq!(three.sub(two).to_f64(), 1.0);
    }

    #[test]
    fn test_f16_neg_preserves_bits() {
        let original = half(3.5);
        let flipped = original.neg();
        assert_eq!(flipped.to_f64(), -3.5);
        // Negating twice must restore the exact bit pattern.
        assert_eq!(flipped.neg().0, original.0);
    }

    // ---- binned reductions ----

    #[test]
    fn test_f16_binned_sum_basic() {
        let inputs: Vec<F16> = (0..10).map(f64::from).map(half).collect();
        assert_eq!(f16_binned_sum(&inputs), 45.0);
    }

    #[test]
    fn test_f16_binned_sum_order_invariant() {
        let forward: Vec<F16> = (0..200)
            .map(|i| half(f64::from(i) * 0.5 - 50.0))
            .collect();
        let backward: Vec<F16> = forward.iter().rev().copied().collect();

        let sum_forward = f16_binned_sum(&forward);
        let sum_backward = f16_binned_sum(&backward);
        assert_eq!(
            sum_forward.to_bits(),
            sum_backward.to_bits(),
            "f16 sum must be order-invariant"
        );
    }

    #[test]
    fn test_f16_dot_basic() {
        let lhs = [1.0, 2.0, 3.0].map(half);
        let rhs = [4.0, 5.0, 6.0].map(half);
        // 1*4 + 2*5 + 3*6
        assert_eq!(f16_binned_dot(&lhs, &rhs), 32.0);
    }

    #[test]
    fn test_f16_matmul_identity() {
        let identity = [1.0, 0.0, 0.0, 1.0].map(half);
        let matrix = [3.0, 4.0, 5.0, 6.0].map(half);
        let mut product = [0.0f64; 4];
        f16_matmul(&identity, &matrix, &mut product, 2, 2, 2);
        assert_eq!(product, [3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_f16_subnormal_accumulation() {
        let tiny = F16::MIN_POSITIVE_SUBNORMAL;
        let inputs = [tiny; 1000];
        let result = f16_binned_sum(&inputs);
        let expected = tiny.to_f64() * 1000.0;
        assert!(
            (result - expected).abs() < 1e-12,
            "Subnormal accumulation: {result} vs expected {expected}"
        );
    }

    #[test]
    fn test_f16_signed_zero_preserved() {
        assert!(F16::ZERO.to_f64().is_sign_positive());
        assert!(F16::NEG_ZERO.to_f64().is_sign_negative());
        // The sign of zero must survive the f64 -> F16 direction as well.
        assert_eq!(half(0.0).0, F16::ZERO.0);
        assert_eq!(half(-0.0).0, F16::NEG_ZERO.0);
    }
}