1use crate::{ArgLength, DataType, Error, Result};
2
/// Widens an IEEE 754 binary16 bit pattern to an `f64`.
///
/// Every binary16 value (signed zeros, subnormals, normals, infinities) is
/// exactly representable in `f64`, so the conversion never rounds. For NaNs
/// the 10 binary16 significand bits are copied into the top of the 52-bit
/// `f64` significand, preserving the payload.
const fn f16_to_f64(bits: u16) -> f64 {
    let half = bits as u64;
    // Move the binary16 sign (bit 15) straight to the f64 sign (bit 63).
    let sign = (half & 0x8000) << 48;
    let exponent = (half >> 10) & 0x1f;
    let fraction = half & 0x03ff;

    let magnitude = match (exponent, fraction) {
        // Signed zero: nothing but the sign bit.
        (0, 0) => 0,
        // Subnormal: shift the leading one up into the implicit-bit position
        // and lower the exponent by the same amount. 54 == 64 - 10 because
        // the fraction occupies the low 10 bits of a u64.
        (0, _) => {
            let shift = fraction.leading_zeros() - 54;
            let fraction = (fraction << (shift + 1)) & 0x03ff;
            // 1008 == 1023 - 15: f64 bias minus binary16 bias.
            let exponent = 1008 - shift as u64;
            (exponent << 52) | (fraction << 42)
        }
        // Infinity (zero fraction) or NaN (payload copied verbatim).
        (0x1f, _) => 0x7ff0_0000_0000_0000 | (fraction << 42),
        // Normal: rebias the exponent from 15 to 1023.
        _ => ((exponent + 1008) << 52) | (fraction << 42),
    };

    f64::from_bits(sign | magnitude)
}
29
/// Widens an IEEE 754 binary16 bit pattern to an `f32`.
///
/// Exact for every input: binary32 has both the range and the precision to
/// hold any binary16 value. NaN payload bits are placed in the top 10 bits of
/// the 23-bit `f32` significand.
const fn f16_to_f32(bits: u16) -> f32 {
    let half = bits as u32;
    // Move the binary16 sign (bit 15) straight to the f32 sign (bit 31).
    let sign = (half & 0x8000) << 16;
    let exponent = (half >> 10) & 0x1f;
    let fraction = half & 0x03ff;

    // Infinity or NaN: all-ones exponent, payload copied verbatim.
    if exponent == 0x1f {
        return f32::from_bits(sign | 0x7f80_0000 | (fraction << 13));
    }

    // Normal: rebias the exponent from 15 to 127 (127 - 15 == 112).
    if exponent != 0 {
        return f32::from_bits(sign | ((exponent + 112) << 23) | (fraction << 13));
    }

    // Signed zero.
    if fraction == 0 {
        return f32::from_bits(sign);
    }

    // Subnormal: renormalise so the leading one becomes the implicit bit.
    // 22 == 32 - 10 because the fraction occupies the low 10 bits of a u32.
    let shift = fraction.leading_zeros() - 22;
    let fraction = (fraction << (shift + 1)) & 0x03ff;
    f32::from_bits(sign | ((112 - shift) << 23) | (fraction << 13))
}
54
/// Narrows an `f64` to the nearest IEEE 754 binary16 bit pattern.
///
/// Finite values are rounded to nearest, ties to even:
/// - magnitudes that round past the largest finite binary16 value (65504)
///   become infinity;
/// - magnitudes at or below 2^-25 (half the smallest binary16 subnormal)
///   round to zero carrying the input's sign;
/// - `f64` subnormals always become signed zero.
///
/// Infinities map to binary16 infinities. For NaNs the top 10 significand
/// bits are kept; if those are all zero the payload is forced to `0x001` so
/// the result still encodes a NaN rather than infinity.
const fn f64_to_f16(value: f64) -> u16 {
    let bits = value.to_bits();
    let sign_bit = ((bits >> 48) & 0x8000) as u16;
    let exp = ((bits >> 52) & 0x7ff) as i32;
    let sig = bits & 0x000f_ffff_ffff_ffff;

    match exp {
        // Zero or f64 subnormal: far below the binary16 subnormal range.
        0 => return sign_bit,
        // Infinity or NaN.
        0x7ff => {
            if sig == 0 {
                return sign_bit | 0x7c00;
            }
            let sig16 = (sig >> 42) as u16;
            return sign_bit | 0x7c00 | if sig16 == 0 { 1 } else { sig16 };
        }
        _ => (),
    }

    // Rebias the exponent from 1023 to 15 (1023 - 15 == 1008).
    let exp16 = exp - 1008;

    // Magnitude >= 2^16: overflows the binary16 range.
    if exp16 >= 0x1f {
        return sign_bit | 0x7c00;
    }

    if exp16 <= 0 {
        // Subnormal (or zero) result. The binary16 significand is
        // full_sig * 2^(exp16 - 43), i.e. full_sig shifted right by 43 - exp16.
        //
        // For exp16 < -10 the magnitude is below 2^-25, half the smallest
        // subnormal, so it always rounds to zero. Handling that here keeps the
        // shift within 43..=53, so `1 << shift` below cannot overflow. (The
        // previous special case for shift == 64 wrongly returned the smallest
        // subnormal for magnitudes around 2^-36.)
        if exp16 < -10 {
            return sign_bit;
        }
        let full_sig = sig | 0x0010_0000_0000_0000;
        let shift = (43 - exp16) as u32;
        let shifted = full_sig >> shift;
        let remainder = full_sig & ((1_u64 << shift) - 1);
        let halfway = 1_u64 << (shift - 1);
        let round_up = remainder > halfway || (remainder == halfway && (shifted & 1) != 0);
        // A carry out of 0x3ff produces 0x400, which is exactly the smallest
        // normal binary16 value, so no separate fix-up is needed.
        return sign_bit | ((shifted as u16) + round_up as u16);
    }

    // Normal result: keep the top 10 significand bits and round on the low 42.
    let sig10 = (sig >> 42) as u16;
    let remainder = sig & 0x3ff_ffff_ffff;
    let halfway = 0x200_0000_0000_u64;
    let round_up = remainder > halfway || (remainder == halfway && (sig10 & 1) != 0);
    let sig16 = sig10 + round_up as u16;

    if sig16 >= 0x0400 {
        // Rounding carried into the exponent; from exp16 == 30 this lands on
        // 0x7c00, i.e. infinity.
        sign_bit | (((exp16 as u16) + 1) << 10)
    } else {
        sign_bit | ((exp16 as u16) << 10) | sig16
    }
}
115
/// Widens an `f32` NaN bit pattern to an `f64` NaN, keeping sign and payload.
///
/// Only the sign bit and the 23 significand bits of `bits` are read. The
/// payload is moved into the top of the 52-bit `f64` significand, and the
/// exponent is forced to all ones. Doing this by hand avoids relying on an
/// `as` cast, whose NaN payload handling Rust does not guarantee.
const fn f32_nan_to_f64(bits: u32) -> f64 {
    const F64_EXP_MASK: u64 = 0x7ff0_0000_0000_0000;
    let wide = bits as u64;
    let sign = (wide >> 31) << 63;
    let payload = (wide & 0x007f_ffff) << 29;
    f64::from_bits(F64_EXP_MASK | sign | payload)
}
122
/// Raw IEEE 754 bit pattern of a [`Float`], tagged with the width it is stored at.
///
/// NOTE: variant declaration order is significant. The derived
/// `PartialOrd`/`Ord` compare the variant first (`F16` < `F32` < `F64`) and
/// only then the raw bits, so reordering the variants changes how values sort.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum Inner {
    /// binary16 (half-precision) bits.
    F16(u16),
    /// binary32 (single-precision) bits.
    F32(u32),
    /// binary64 (double-precision) bits.
    F64(u64),
}
130
131impl Inner {
132 const fn new(x: f64) -> Self {
133 if x.is_finite() {
134 let bits16 = f64_to_f16(x);
135
136 if f16_to_f64(bits16).to_bits() == x.to_bits() {
137 Inner::F16(bits16)
138 } else if ((x as f32) as f64).to_bits() == x.to_bits() {
139 Inner::F32((x as f32).to_bits())
140 } else {
141 Inner::F64(x.to_bits())
142 }
143 } else {
144 let bits64 = x.to_bits();
145 let sign_bit = bits64 & 0x8000_0000_0000_0000;
146
147 if (bits64 & 0x3ff_ffff_ffff) == 0 {
148 let bits = (bits64 >> 42) & 0x7fff | (sign_bit >> 48);
149 Self::F16(bits as u16)
150 } else if (bits64 & 0x1fff_ffff) == 0 {
151 let bits = (bits64 >> 29) & 0x7fff_ffff | (sign_bit >> 32);
152 Self::F32(bits as u32)
153 } else {
154 Self::F64(bits64)
155 }
156 }
157 }
158}
159
/// A floating-point value stored at the narrowest IEEE 754 width (binary16,
/// binary32 or binary64) that represents it exactly.
///
/// Equality, ordering and hashing are derived from the stored width and raw
/// bits, not from the numeric value. For example, `0.0` and `-0.0` are
/// distinct, and two NaNs compare equal only if their width and bit patterns
/// match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Float(Inner);
167
168impl Float {
169 #[must_use]
171 pub const fn data_type(&self) -> DataType {
172 match self.0 {
173 Inner::F16(_) => DataType::Float16,
174 Inner::F32(_) => DataType::Float32,
175 Inner::F64(_) => DataType::Float64,
176 }
177 }
178
179 pub(crate) const fn cbor_argument(&self) -> (u8, u64) {
180 match self.0 {
181 Inner::F16(arg) => (ArgLength::U16, arg as u64),
182 Inner::F32(arg) => (ArgLength::U32, arg as u64),
183 Inner::F64(arg) => (ArgLength::U64, arg),
184 }
185 }
186
187 pub(crate) const fn from_u16(bits: u16) -> Self {
188 Self(Inner::F16(bits))
189 }
190
191 pub(crate) const fn from_u32(bits: u32) -> Result<Self> {
192 let float = Self(Inner::F32(bits));
193 if matches!(Inner::new(float.to_f64()), Inner::F32(_)) {
194 Ok(float)
195 } else {
196 Err(Error::Precision)
197 }
198 }
199
200 pub(crate) const fn from_u64(bits: u64) -> Result<Self> {
201 let float = Self(Inner::F64(bits));
202 if matches!(Inner::new(float.to_f64()), Inner::F64(_)) {
203 Ok(float)
204 } else {
205 Err(Error::Precision)
206 }
207 }
208
209 #[must_use]
211 pub const fn to_f64(self) -> f64 {
212 match self.0 {
213 Inner::F16(bits) => f16_to_f64(bits),
214 Inner::F32(bits) => {
215 let f = f32::from_bits(bits);
216 if f.is_nan() { f32_nan_to_f64(bits) } else { f as f64 }
217 }
218 Inner::F64(bits) => f64::from_bits(bits),
219 }
220 }
221
222 pub const fn to_f32(self) -> Result<f32> {
224 match self.0 {
225 Inner::F16(bits) => Ok(f16_to_f32(bits)),
226 Inner::F32(bits) => Ok(f32::from_bits(bits)),
227 Inner::F64(_) => Err(Error::Precision),
228 }
229 }
230}
231
232impl From<f64> for Float {
235 fn from(value: f64) -> Self {
236 Self(Inner::new(value))
237 }
238}
239
240impl From<f32> for Float {
241 fn from(value: f32) -> Self {
242 if value.is_nan() {
243 Self(Inner::new(f32_nan_to_f64(value.to_bits())))
245 } else {
246 Self(Inner::new(value as f64))
247 }
248 }
249}
250
251impl From<u8> for Float {
254 fn from(value: u8) -> Self {
255 Self::from(value as f64)
256 }
257}
258
259impl From<u16> for Float {
260 fn from(value: u16) -> Self {
261 Self::from(value as f64)
262 }
263}
264
265impl From<u32> for Float {
266 fn from(value: u32) -> Self {
267 Self::from(value as f64)
268 }
269}
270
271impl From<i8> for Float {
272 fn from(value: i8) -> Self {
273 Self::from(value as f64)
274 }
275}
276
277impl From<i16> for Float {
278 fn from(value: i16) -> Self {
279 Self::from(value as f64)
280 }
281}
282
283impl From<i32> for Float {
284 fn from(value: i32) -> Self {
285 Self::from(value as f64)
286 }
287}
288
289impl From<bool> for Float {
290 fn from(value: bool) -> Self {
291 Self(if value { Inner::new(1.0) } else { Inner::new(0.0) })
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// True for binary16 NaN encodings: with the sign masked off, anything
    /// above the infinity pattern 0x7c00 has an all-ones exponent and a
    /// non-zero significand.
    fn f16_is_nan(bits: u16) -> bool {
        (bits & 0x7fff) > 0x7c00
    }

    // ---------------------------------------------------------------------
    // f16_to_f64: decoding binary16 bit patterns into f64.
    // ---------------------------------------------------------------------

    #[test]
    fn to_f64_zero() {
        assert_eq!(f16_to_f64(0x0000), 0.0);
        assert!(f16_to_f64(0x0000).is_sign_positive());
    }

    #[test]
    fn to_f64_neg_zero() {
        // Compare bits: `-0.0 == 0.0` numerically, so assert_eq on values
        // would not catch a lost sign.
        let v = f16_to_f64(0x8000);
        assert_eq!(v.to_bits(), (-0.0_f64).to_bits());
    }

    #[test]
    fn to_f64_one() {
        assert_eq!(f16_to_f64(0x3c00), 1.0);
    }

    #[test]
    fn to_f64_neg_one() {
        assert_eq!(f16_to_f64(0xbc00), -1.0);
    }

    #[test]
    fn to_f64_max_normal() {
        assert_eq!(f16_to_f64(0x7bff), 65504.0);
    }

    #[test]
    fn to_f64_min_positive_normal() {
        // 0x0400 is 2^-14.
        assert_eq!(f16_to_f64(0x0400), 0.00006103515625);
    }

    #[test]
    fn to_f64_min_positive_subnormal() {
        // 0x0001 is 2^-24.
        assert_eq!(f16_to_f64(0x0001), 5.960464477539063e-8);
    }

    #[test]
    fn to_f64_max_subnormal() {
        // 0x03ff is 1023 * 2^-24.
        assert_eq!(f16_to_f64(0x03ff), 0.00006097555160522461);
    }

    #[test]
    fn to_f64_infinity() {
        assert_eq!(f16_to_f64(0x7c00), f64::INFINITY);
    }

    #[test]
    fn to_f64_neg_infinity() {
        assert_eq!(f16_to_f64(0xfc00), f64::NEG_INFINITY);
    }

    #[test]
    fn to_f64_nan() {
        assert!(f16_to_f64(0x7e00).is_nan());
    }

    #[test]
    fn to_f64_nan_preserves_payload() {
        // Payload 1 in the binary16 significand lands in bit 42 of the f64.
        let bits = f16_to_f64(0x7c01).to_bits();
        assert_eq!(bits, 0x7ff0_0400_0000_0000);
    }

    #[test]
    fn to_f64_two() {
        assert_eq!(f16_to_f64(0x4000), 2.0);
    }

    #[test]
    fn to_f64_one_point_five() {
        assert_eq!(f16_to_f64(0x3e00), 1.5);
    }

    // ---------------------------------------------------------------------
    // f16_to_f32: decoding binary16 bit patterns into f32.
    // ---------------------------------------------------------------------

    #[test]
    fn to_f32_zero() {
        assert_eq!(f16_to_f32(0x0000), 0.0_f32);
        assert!(f16_to_f32(0x0000).is_sign_positive());
    }

    #[test]
    fn to_f32_neg_zero() {
        assert_eq!(f16_to_f32(0x8000).to_bits(), (-0.0_f32).to_bits());
    }

    #[test]
    fn to_f32_one() {
        assert_eq!(f16_to_f32(0x3c00), 1.0_f32);
    }

    #[test]
    fn to_f32_neg_one() {
        assert_eq!(f16_to_f32(0xbc00), -1.0_f32);
    }

    #[test]
    fn to_f32_two() {
        assert_eq!(f16_to_f32(0x4000), 2.0_f32);
    }

    #[test]
    fn to_f32_one_point_five() {
        assert_eq!(f16_to_f32(0x3e00), 1.5_f32);
    }

    #[test]
    fn to_f32_max_normal() {
        assert_eq!(f16_to_f32(0x7bff), 65504.0_f32);
    }

    #[test]
    fn to_f32_min_positive_normal() {
        assert_eq!(f16_to_f32(0x0400), 0.000061035156_f32);
    }

    #[test]
    fn to_f32_min_positive_subnormal() {
        assert_eq!(f16_to_f32(0x0001), 5.9604645e-8_f32);
    }

    #[test]
    fn to_f32_max_subnormal() {
        assert_eq!(f16_to_f32(0x03ff), 0.00006097555_f32);
    }

    #[test]
    fn to_f32_infinity() {
        assert_eq!(f16_to_f32(0x7c00), f32::INFINITY);
    }

    #[test]
    fn to_f32_neg_infinity() {
        assert_eq!(f16_to_f32(0xfc00), f32::NEG_INFINITY);
    }

    #[test]
    fn to_f32_nan() {
        assert!(f16_to_f32(0x7e00).is_nan());
    }

    #[test]
    fn to_f32_nan_preserves_payload() {
        // Payload 1 in the binary16 significand is shifted left by 13 into
        // the f32 significand.
        let bits = f16_to_f32(0x7c01).to_bits();
        assert_eq!(bits, 0x7f80_2000);
    }

    #[test]
    fn to_f32_agrees_with_f16_to_f64() {
        // Cross-check both decoders over every binary16 magnitude, with both
        // signs. NaNs are skipped: the comparison is bitwise, and the
        // f64 -> f32 `as` cast is not relied upon to carry NaN payloads.
        for bits in 0..=0x7fff_u16 {
            if f16_is_nan(bits) {
                continue;
            }
            let via_f32 = f16_to_f32(bits);
            let via_f64 = f16_to_f64(bits) as f32;
            assert_eq!(via_f32.to_bits(), via_f64.to_bits(), "mismatch for bits 0x{bits:04x}");

            let neg = bits | 0x8000;
            let via_f32n = f16_to_f32(neg);
            let via_f64n = f16_to_f64(neg) as f32;
            assert_eq!(via_f32n.to_bits(), via_f64n.to_bits(), "mismatch for bits 0x{neg:04x}");
        }
    }

    // ---------------------------------------------------------------------
    // f64_to_f16: encoding f64 values as binary16 bit patterns.
    // ---------------------------------------------------------------------

    #[test]
    fn from_f64_zero() {
        assert_eq!(f64_to_f16(0.0), 0x0000);
    }

    #[test]
    fn from_f64_neg_zero() {
        assert_eq!(f64_to_f16(-0.0), 0x8000);
    }

    #[test]
    fn from_f64_one() {
        assert_eq!(f64_to_f16(1.0), 0x3c00);
    }

    #[test]
    fn from_f64_neg_one() {
        assert_eq!(f64_to_f16(-1.0), 0xbc00);
    }

    #[test]
    fn from_f64_max_normal() {
        assert_eq!(f64_to_f16(65504.0), 0x7bff);
    }

    #[test]
    fn from_f64_overflow_to_infinity() {
        // 65520 is halfway between 65504 (0x7bff, odd significand) and 2^16;
        // ties-to-even rounds up, which overflows to infinity.
        assert_eq!(f64_to_f16(65520.0), 0x7c00);
    }

    #[test]
    fn from_f64_infinity() {
        assert_eq!(f64_to_f16(f64::INFINITY), 0x7c00);
    }

    #[test]
    fn from_f64_neg_infinity() {
        assert_eq!(f64_to_f16(f64::NEG_INFINITY), 0xfc00);
    }

    #[test]
    fn from_f64_nan() {
        assert!(f16_is_nan(f64_to_f16(f64::NAN)));
    }

    #[test]
    fn from_f64_min_positive_subnormal() {
        assert_eq!(f64_to_f16(5.960464477539063e-8), 0x0001);
    }

    #[test]
    fn from_f64_min_positive_normal() {
        assert_eq!(f64_to_f16(0.00006103515625), 0x0400);
    }

    // ---------------------------------------------------------------------
    // f64_to_f16 rounding: round to nearest, ties to even.
    // ---------------------------------------------------------------------

    #[test]
    fn rounding_exactly_halfway_rounds_to_even_down() {
        // 1.0 + 2^-11: exactly halfway between 0x3c00 (1.0) and 0x3c01;
        // the even significand (0x3c00) wins.
        let halfway = f64::from_bits(0x3FF0_0200_0000_0000);
        assert_eq!(f64_to_f16(halfway), 0x3c00);
    }

    #[test]
    fn rounding_exactly_halfway_rounds_to_even_up() {
        // 1.0 + 3 * 2^-11: exactly halfway between 0x3c01 and 0x3c02;
        // the even significand (0x3c02) wins.
        let halfway = f64::from_bits(0x3FF0_0600_0000_0000);
        assert_eq!(f64_to_f16(halfway), 0x3c02);
    }

    #[test]
    fn rounding_just_below_halfway_rounds_down() {
        // One f64 ulp below the 1.0 / 0x3c01 midpoint.
        let below = f64::from_bits(0x3FF0_01FF_FFFF_FFFF);
        assert_eq!(f64_to_f16(below), 0x3c00);
    }

    #[test]
    fn rounding_just_above_halfway_rounds_up() {
        // One f64 ulp above the 1.0 / 0x3c01 midpoint.
        let above = f64::from_bits(0x3FF0_0200_0000_0001);
        assert_eq!(f64_to_f16(above), 0x3c01);
    }

    #[test]
    fn rounding_subnormal_halfway_rounds_to_even() {
        // 1.5 * 2^-24: halfway between 0x0001 and 0x0002 -> even 0x0002.
        let val = 1.5 * 5.960464477539063e-8;
        assert_eq!(f64_to_f16(val), 0x0002);
    }

    #[test]
    fn rounding_subnormal_halfway_even_down() {
        // 2.5 * 2^-24: halfway between 0x0002 and 0x0003 -> even 0x0002.
        let val = 2.5 * 5.960464477539063e-8;
        assert_eq!(f64_to_f16(val), 0x0002);
    }

    #[test]
    fn rounding_normal_to_subnormal_boundary() {
        let min_normal = 0.00006103515625_f64;
        assert_eq!(f64_to_f16(min_normal), 0x0400);

        // One f64 ulp below 2^-14 takes the subnormal path, and rounding
        // carries it back up to the smallest normal.
        let below = f64::from_bits(min_normal.to_bits() - 1);
        assert_eq!(f64_to_f16(below), 0x0400);
    }

    #[test]
    fn rounding_overflow_at_max() {
        assert_eq!(f64_to_f16(65504.0), 0x7bff);
        assert_eq!(f64_to_f16(65519.99), 0x7bff);
        assert_eq!(f64_to_f16(65520.0), 0x7c00);
    }

    #[test]
    fn rounding_tiny_to_zero() {
        assert_eq!(f64_to_f16(1e-30), 0x0000);
        assert_eq!(f64_to_f16(-1e-30), 0x8000);
    }

    #[test]
    fn rounding_tiny_to_min_subnormal() {
        // 2^-25 is exactly halfway between 0 and the smallest subnormal;
        // ties-to-even gives 0.
        let half_min: f64 = 0.5 * 5.960464477539063e-8;
        assert_eq!(f64_to_f16(half_min), 0x0000);

        // Anything above the midpoint rounds up to 0x0001.
        let above = f64::from_bits(half_min.to_bits() + 1);
        assert_eq!(f64_to_f16(above), 0x0001);
    }

    // ---------------------------------------------------------------------
    // Round trip: f16 -> f64 -> f16 must be the identity on non-NaN values.
    // ---------------------------------------------------------------------

    #[test]
    fn roundtrip_all_exact_f16_values() {
        for bits in 0..=0x7fff_u16 {
            if f16_is_nan(bits) {
                continue;
            }
            let f = f16_to_f64(bits);
            let h2 = f64_to_f16(f);
            assert_eq!(bits, h2, "roundtrip failed for bits 0x{bits:04x}");

            // Same check for the negative counterpart (sign bit set).
            let neg_bits = bits | 0x8000;
            let fn_ = f16_to_f64(neg_bits);
            let hn2 = f64_to_f16(fn_);
            assert_eq!(neg_bits, hn2, "roundtrip failed for bits 0x{neg_bits:04x}");
        }
    }
}