1use crate::{
10 DataType, Error, Result,
11 codec::{Argument, Head, Major},
12 view::ValueView,
13};
14
/// Widens an IEEE 754 binary16 bit pattern to the `f64` holding the same value.
///
/// Every binary16 value is exactly representable as `f64`, so no rounding takes place.
/// Signed zeros and infinities map to their `f64` counterparts, subnormals are renormalised,
/// and for NaNs the ten significand bits are moved to the top of the `f64` significand.
const fn f16_to_f64(bits: u16) -> f64 {
    // Difference between the f64 exponent bias (1023) and the binary16 bias (15).
    const EXP_REBIAS: u64 = 1023 - 15;

    let wide = bits as u64;
    let sign = ((wide >> 15) & 1) << 63;
    let exponent = (wide >> 10) & 0x1f;
    let mantissa = wide & 0x03ff;

    let magnitude = match (exponent, mantissa) {
        // Signed zero.
        (0, 0) => 0,
        // Subnormal: shift the leading one out of the 10-bit field (it becomes the implicit
        // bit) and lower the exponent by the number of positions it had to travel.
        (0, _) => {
            let shift = mantissa.leading_zeros() - (64 - 10);
            let fraction = (mantissa << (shift + 1)) & 0x03ff;
            ((EXP_REBIAS - shift as u64) << 52) | (fraction << 42)
        }
        // Infinity (mantissa == 0) or NaN (mantissa != 0).
        (0x1f, _) => 0x7ff0_0000_0000_0000 | (mantissa << 42),
        // Normal number: rebias the exponent, left-align the significand.
        _ => ((exponent + EXP_REBIAS) << 52) | (mantissa << 42),
    };

    f64::from_bits(sign | magnitude)
}
46
/// Widens an IEEE 754 binary16 bit pattern to the `f32` holding the same value.
///
/// The conversion is exact. NaN significand bits are moved to the top of the `f32`
/// significand, and the sign bit is carried over unchanged.
const fn f16_to_f32(bits: u16) -> f32 {
    let raw = bits as u32;
    // Move the sign from bit 15 straight to bit 31.
    let sign = (raw & 0x8000) << 16;
    let exponent = (raw >> 10) & 0x1f;
    let mantissa = raw & 0x03ff;

    let magnitude = if exponent == 0x1f {
        // Infinity or NaN: all-ones exponent, significand left-aligned.
        0x7f80_0000 | (mantissa << 13)
    } else if exponent != 0 {
        // Normal number: rebias by 127 - 15 = 112.
        ((exponent + 112) << 23) | (mantissa << 13)
    } else if mantissa != 0 {
        // Subnormal: normalise so the leading one becomes the implicit bit.
        let shift = mantissa.leading_zeros() - 22;
        let fraction = (mantissa << (shift + 1)) & 0x03ff;
        ((112 - shift) << 23) | (fraction << 13)
    } else {
        // Signed zero.
        0
    };

    f32::from_bits(sign | magnitude)
}
72
/// Converts an `f64` to the nearest IEEE 754 binary16 bit pattern.
///
/// Rounding is round-to-nearest, ties-to-even.
///
/// * Zero and `f64` subnormals become a zero with the input's sign.
/// * Magnitudes that round past the largest finite binary16 value become infinity.
/// * Magnitudes below half of the smallest binary16 subnormal become signed zero.
/// * Infinities stay infinities. NaNs keep their sign and the top ten significand bits; if
///   those ten bits are all zero, the lowest significand bit is set so the result is still a
///   NaN rather than an infinity.
const fn f64_to_f16(value: f64) -> u16 {
    let bits = value.to_bits();
    let sign_bit = ((bits >> 48) & 0x8000) as u16;
    let exp = ((bits >> 52) & 0x7ff) as i32;
    let sig = bits & 0x000f_ffff_ffff_ffff;

    match exp {
        // Zero or an f64 subnormal: many orders of magnitude below the binary16 range.
        0 => return sign_bit,

        0x7ff => {
            if sig == 0 {
                return sign_bit | 0x7c00;
            }
            let sig16 = (sig >> 42) as u16;
            // Never let a NaN collapse into infinity by truncating its payload to zero.
            return sign_bit | 0x7c00 | if sig16 == 0 { 1 } else { sig16 };
        }

        _ => (),
    }

    // Rebias: 1023 (f64) - 15 (binary16) = 1008.
    let exp16 = exp - 1008;

    if exp16 >= 0x1f {
        return sign_bit | 0x7c00;
    }

    if exp16 <= 0 {
        // Result is subnormal (or rounds up into the smallest normal): make the implicit
        // leading one explicit and shift the 53-bit significand down to the 10-bit field.
        let full_sig = sig | 0x0010_0000_0000_0000;
        let shift = (1 - exp16) as u64 + 42;

        // `full_sig` is below 2^53, so from a shift of 54 onwards the value is strictly less
        // than half of the smallest subnormal and rounds to zero. Returning here also keeps
        // every shift amount used below within the width of `u64`.
        if shift >= 54 {
            return sign_bit;
        }

        let shifted = full_sig >> shift;
        let remainder = full_sig & ((1_u64 << shift) - 1);
        let halfway = 1_u64 << (shift - 1);
        let round_up = remainder > halfway || (remainder == halfway && (shifted & 1) != 0);
        // A carry out of the significand lands in the exponent field, which yields the
        // smallest normal number, exactly as required.
        let sig16 = (shifted as u16) + round_up as u16;
        return sign_bit | sig16;
    }

    let sig10 = (sig >> 42) as u16;
    let remainder = sig & 0x3ff_ffff_ffff;
    let halfway = 0x200_0000_0000_u64;
    let round_up = remainder > halfway || (remainder == halfway && (sig10 & 1) != 0);
    let sig16 = sig10 + round_up as u16;

    if sig16 >= 0x0400 {
        // Rounding carried out of the significand: bump the exponent, significand becomes
        // zero. If the exponent reaches 0x1f this is the bit pattern of infinity.
        sign_bit | (((exp16 as u16) + 1) << 10)
    } else {
        sign_bit | ((exp16 as u16) << 10) | sig16
    }
}
137
/// Rebuilds an `f32` NaN bit pattern as an `f64` NaN without touching its payload.
///
/// The sign bit is moved to bit 63, the 23 significand bits are left-aligned in the 52-bit
/// `f64` significand, and the exponent is forced to all ones. The exponent bits of the input
/// are ignored, so callers are expected to pass a NaN (or infinity) pattern.
const fn f32_nan_to_f64(bits: u32) -> f64 {
    let wide = bits as u64;
    let sign = (wide >> 31) << 63;
    let mantissa = (wide & 0x007f_ffff) << 29;
    f64::from_bits(0x7ff0_0000_0000_0000 | sign | mantissa)
}
149
/// Raw bit pattern of a floating-point value, tagged with the width it is stored at.
///
/// The derived comparison, ordering and hashing traits operate on the variant and its raw
/// bits, not on the numeric value the bits represent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) enum Inner {
    /// Bits of a 16-bit (half precision) float.
    F16(u16),
    /// Bits of a 32-bit (single precision) float.
    F32(u32),
    /// Bits of a 64-bit (double precision) float.
    F64(u64),
}
161
162impl Inner {
163 const fn new(x: f64) -> Self {
170 if x.is_finite() {
171 let bits16 = f64_to_f16(x);
172
173 if f16_to_f64(bits16).to_bits() == x.to_bits() {
174 Inner::F16(bits16)
175 } else if ((x as f32) as f64).to_bits() == x.to_bits() {
176 Inner::F32((x as f32).to_bits())
177 } else {
178 Inner::F64(x.to_bits())
179 }
180 } else {
181 let bits64 = x.to_bits();
182 let sign_bit = bits64 & 0x8000_0000_0000_0000;
183
184 if (bits64 & 0x3ff_ffff_ffff) == 0 {
185 let bits = (bits64 >> 42) & 0x7fff | (sign_bit >> 48);
186 Self::F16(bits as u16)
187 } else if (bits64 & 0x1fff_ffff) == 0 {
188 let bits = (bits64 >> 29) & 0x7fff_ffff | (sign_bit >> 32);
189 Self::F32(bits as u32)
190 } else {
191 Self::F64(bits64)
192 }
193 }
194 }
195}
196
/// A floating-point value stored as the raw bits of a 16-, 32- or 64-bit float.
///
/// Equality, ordering and hashing are derived from the stored representation (width and
/// bits), so they are not numeric comparisons: two values holding the same NaN bit pattern
/// compare equal, while `0.0` and `-0.0` do not.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Float(pub(crate) Inner);
232
233impl ValueView for Float {
234 fn head(&self) -> Head {
235 match self.0 {
236 Inner::F16(bits) => Head::new(Major::SimpleOrFloat, Argument::U16(bits)),
237 Inner::F32(bits) => Head::new(Major::SimpleOrFloat, Argument::U32(bits)),
238 Inner::F64(bits) => Head::new(Major::SimpleOrFloat, Argument::U64(bits)),
239 }
240 }
241
242 fn payload(&self) -> crate::view::Payload<'_> {
243 crate::view::Payload::None
244 }
245}
246
247impl Float {
248 #[must_use]
269 pub fn new(value: impl Into<Self>) -> Self {
270 value.into()
271 }
272
273 #[must_use]
307 pub const fn with_payload(payload: u64) -> Self {
308 let sign_bit = payload & 0x10_0000_0000_0000; let lower52 = payload ^ sign_bit; if lower52 <= 0x3ff {
312 let sig = ((lower52 as u16) << 6).reverse_bits();
313 let sign_bit = (sign_bit >> 37) as u16;
314 Self(Inner::F16(sign_bit | 0x7c00 | sig))
315 } else if lower52 <= 0x7f_ffff {
316 let sig = ((lower52 as u32) << 9).reverse_bits();
317 let sign_bit = (sign_bit >> 21) as u32;
318 Self(Inner::F32(sign_bit | 0x7f80_0000 | sig))
319 } else if lower52 <= 0x0f_ffff_ffff_ffff {
320 let sig = (lower52 << 12).reverse_bits();
321 let sign_bit = sign_bit << 11;
322 Self(Inner::F64(sign_bit | 0x7ff0_0000_0000_0000 | sig))
323 } else {
324 panic!("payload exceeds maximum allowed value")
325 }
326 }
327
328 #[must_use]
342 pub const fn from_f64(value: f64) -> Self {
343 Self(Inner::new(value))
344 }
345
346 #[must_use]
360 pub const fn from_f32(value: f32) -> Self {
361 if value.is_nan() {
362 Self(Inner::new(f32_nan_to_f64(value.to_bits())))
364 } else {
365 Self(Inner::new(value as f64))
366 }
367 }
368
369 #[must_use]
379 pub const fn data_type(&self) -> DataType {
380 match self.0 {
381 Inner::F16(_) => DataType::Float16,
382 Inner::F32(_) => DataType::Float32,
383 Inner::F64(_) => DataType::Float64,
384 }
385 }
386
387 #[must_use]
388 pub(crate) const fn from_bits_u16(bits: u16) -> Self {
389 Self(Inner::F16(bits))
390 }
391
392 pub(crate) const fn from_bits_u32(bits: u32) -> Self {
393 Self(Inner::F32(bits))
394 }
395
396 pub(crate) const fn from_bits_u64(bits: u64) -> Self {
397 Self(Inner::F64(bits))
398 }
399
400 #[must_use]
406 pub(crate) const fn is_deterministic(self) -> bool {
407 matches!(
408 (self.0, Inner::new(self.to_f64())),
409 (Inner::F16(_), Inner::F16(_)) | (Inner::F32(_), Inner::F32(_)) | (Inner::F64(_), Inner::F64(_))
410 )
411 }
412
413 pub(crate) const fn shortest(self) -> Self {
416 Self(Inner::new(self.to_f64()))
417 }
418
419 #[must_use]
424 pub const fn to_f64(self) -> f64 {
425 match self.0 {
426 Inner::F16(bits) => f16_to_f64(bits),
427 Inner::F32(bits) => {
428 let f = f32::from_bits(bits);
429 if f.is_nan() { f32_nan_to_f64(bits) } else { f as f64 }
430 }
431 Inner::F64(bits) => f64::from_bits(bits),
432 }
433 }
434
435 pub const fn to_f32(self) -> Result<f32> {
441 match self.0 {
442 Inner::F16(bits) => Ok(f16_to_f32(bits)),
443 Inner::F32(bits) => Ok(f32::from_bits(bits)),
444 Inner::F64(_) => Err(Error::Precision),
445 }
446 }
447
448 pub const fn to_payload(self) -> Result<u64> {
464 if self.is_finite() {
465 Err(Error::InvalidValue)
466 } else {
467 let sign_bit;
468 let sig;
469
470 match self.0 {
471 Inner::F16(bits) => {
472 sign_bit = ((bits & 0x8000) as u64) << 37;
473 sig = (bits.reverse_bits() >> 6) as u64;
474 }
475 Inner::F32(bits) => {
476 sign_bit = ((bits & 0x8000_0000) as u64) << 21;
477 sig = (bits.reverse_bits() >> 9) as u64;
478 }
479 Inner::F64(bits) => {
480 sign_bit = (bits & 0x8000_0000_0000_0000) >> 11;
481 sig = bits.reverse_bits() >> 12;
482 }
483 }
484
485 Ok(sign_bit | sig)
486 }
487 }
488
489 #[must_use]
496 pub const fn is_finite(self) -> bool {
497 match self.0 {
498 Inner::F16(bits) => bits & 0x7c00 != 0x7c00,
499 Inner::F32(bits) => bits & 0x7f80_0000 != 0x7f80_0000,
500 Inner::F64(bits) => bits & 0x7ff0_0000_0000_0000 != 0x7ff0_0000_0000_0000,
501 }
502 }
503}
504
505impl From<f64> for Float {
508 fn from(value: f64) -> Self {
509 Self::from_f64(value)
510 }
511}
512
513impl From<f32> for Float {
514 fn from(value: f32) -> Self {
515 Self::from_f32(value)
516 }
517}
518
519macro_rules! try_from {
522 ($type:ty) => {
523 impl From<$type> for Float {
524 fn from(value: $type) -> Self {
525 Self::from(value as f64)
526 }
527 }
528 };
529}
530
531try_from!(u8);
532try_from!(u16);
533try_from!(u32);
534
535try_from!(i8);
536try_from!(i16);
537try_from!(i32);
538
539impl From<bool> for Float {
540 fn from(value: bool) -> Self {
541 Self(if value { Inner::new(1.0) } else { Inner::new(0.0) })
542 }
543}
544
// Unit tests for the bit-level conversion helpers. Expected values are exact bit patterns or
// exactly representable decimal literals, so every assertion is a strict equality.
#[cfg(test)]
mod tests {
    use super::*;

    // True when `bits` is a half-precision NaN: exponent all ones and a non-zero significand,
    // regardless of sign.
    fn f16_is_nan(bits: u16) -> bool {
        (bits & 0x7fff) > 0x7c00
    }

    // ---- f16_to_f64 ----

    #[test]
    fn to_f64_zero() {
        assert_eq!(f16_to_f64(0x0000), 0.0);
        assert!(f16_to_f64(0x0000).is_sign_positive());
    }

    #[test]
    fn to_f64_neg_zero() {
        let v = f16_to_f64(0x8000);
        // Compare bits: `-0.0 == 0.0` numerically, so `assert_eq!` on values would not
        // detect a lost sign.
        assert_eq!(v.to_bits(), (-0.0_f64).to_bits());
    }

    #[test]
    fn to_f64_one() {
        assert_eq!(f16_to_f64(0x3c00), 1.0);
    }

    #[test]
    fn to_f64_neg_one() {
        assert_eq!(f16_to_f64(0xbc00), -1.0);
    }

    #[test]
    fn to_f64_max_normal() {
        assert_eq!(f16_to_f64(0x7bff), 65504.0);
    }

    #[test]
    fn to_f64_min_positive_normal() {
        assert_eq!(f16_to_f64(0x0400), 0.00006103515625);
    }

    #[test]
    fn to_f64_min_positive_subnormal() {
        assert_eq!(f16_to_f64(0x0001), 5.960464477539063e-8);
    }

    #[test]
    fn to_f64_max_subnormal() {
        assert_eq!(f16_to_f64(0x03ff), 0.00006097555160522461);
    }

    #[test]
    fn to_f64_infinity() {
        assert_eq!(f16_to_f64(0x7c00), f64::INFINITY);
    }

    #[test]
    fn to_f64_neg_infinity() {
        assert_eq!(f16_to_f64(0xfc00), f64::NEG_INFINITY);
    }

    #[test]
    fn to_f64_nan() {
        assert!(f16_to_f64(0x7e00).is_nan());
    }

    #[test]
    fn to_f64_nan_preserves_payload() {
        let bits = f16_to_f64(0x7c01).to_bits();
        // The lowest half-precision significand bit (bit 0) lands at bit 42 of the f64.
        assert_eq!(bits, 0x7ff0_0400_0000_0000);
    }

    #[test]
    fn to_f64_two() {
        assert_eq!(f16_to_f64(0x4000), 2.0);
    }

    #[test]
    fn to_f64_one_point_five() {
        assert_eq!(f16_to_f64(0x3e00), 1.5);
    }

    // ---- f16_to_f32 ----

    #[test]
    fn to_f32_zero() {
        assert_eq!(f16_to_f32(0x0000), 0.0_f32);
        assert!(f16_to_f32(0x0000).is_sign_positive());
    }

    #[test]
    fn to_f32_neg_zero() {
        assert_eq!(f16_to_f32(0x8000).to_bits(), (-0.0_f32).to_bits());
    }

    #[test]
    fn to_f32_one() {
        assert_eq!(f16_to_f32(0x3c00), 1.0_f32);
    }

    #[test]
    fn to_f32_neg_one() {
        assert_eq!(f16_to_f32(0xbc00), -1.0_f32);
    }

    #[test]
    fn to_f32_two() {
        assert_eq!(f16_to_f32(0x4000), 2.0_f32);
    }

    #[test]
    fn to_f32_one_point_five() {
        assert_eq!(f16_to_f32(0x3e00), 1.5_f32);
    }

    #[test]
    fn to_f32_max_normal() {
        assert_eq!(f16_to_f32(0x7bff), 65504.0_f32);
    }

    #[test]
    fn to_f32_min_positive_normal() {
        assert_eq!(f16_to_f32(0x0400), 0.000061035156_f32);
    }

    #[test]
    fn to_f32_min_positive_subnormal() {
        assert_eq!(f16_to_f32(0x0001), 5.9604645e-8_f32);
    }

    #[test]
    fn to_f32_max_subnormal() {
        assert_eq!(f16_to_f32(0x03ff), 0.00006097555_f32);
    }

    #[test]
    fn to_f32_infinity() {
        assert_eq!(f16_to_f32(0x7c00), f32::INFINITY);
    }

    #[test]
    fn to_f32_neg_infinity() {
        assert_eq!(f16_to_f32(0xfc00), f32::NEG_INFINITY);
    }

    #[test]
    fn to_f32_nan() {
        assert!(f16_to_f32(0x7e00).is_nan());
    }

    #[test]
    fn to_f32_nan_preserves_payload() {
        let bits = f16_to_f32(0x7c01).to_bits();
        // The lowest half-precision significand bit (bit 0) lands at bit 13 of the f32.
        assert_eq!(bits, 0x7f80_2000);
    }

    // Exhaustively cross-checks the two widening helpers against each other for every
    // non-NaN half-precision value of either sign.
    #[test]
    fn to_f32_agrees_with_f16_to_f64() {
        for bits in 0..=0x7fff_u16 {
            // NaNs are skipped here; payload placement is covered by the dedicated tests.
            if f16_is_nan(bits) {
                continue;
            }
            let via_f32 = f16_to_f32(bits);
            let via_f64 = f16_to_f64(bits) as f32;
            assert_eq!(via_f32.to_bits(), via_f64.to_bits(), "mismatch for bits 0x{bits:04x}");

            let neg = bits | 0x8000;
            let via_f32n = f16_to_f32(neg);
            let via_f64n = f16_to_f64(neg) as f32;
            assert_eq!(via_f32n.to_bits(), via_f64n.to_bits(), "mismatch for bits 0x{neg:04x}");
        }
    }

    // ---- f64_to_f16: exact values ----

    #[test]
    fn from_f64_zero() {
        assert_eq!(f64_to_f16(0.0), 0x0000);
    }

    #[test]
    fn from_f64_neg_zero() {
        assert_eq!(f64_to_f16(-0.0), 0x8000);
    }

    #[test]
    fn from_f64_one() {
        assert_eq!(f64_to_f16(1.0), 0x3c00);
    }

    #[test]
    fn from_f64_neg_one() {
        assert_eq!(f64_to_f16(-1.0), 0xbc00);
    }

    #[test]
    fn from_f64_max_normal() {
        assert_eq!(f64_to_f16(65504.0), 0x7bff);
    }

    #[test]
    fn from_f64_overflow_to_infinity() {
        // 65520 is halfway between the largest finite half (65504) and the next power of
        // two; the tie goes to the even significand, which is infinity.
        assert_eq!(f64_to_f16(65520.0), 0x7c00);
    }

    #[test]
    fn from_f64_infinity() {
        assert_eq!(f64_to_f16(f64::INFINITY), 0x7c00);
    }

    #[test]
    fn from_f64_neg_infinity() {
        assert_eq!(f64_to_f16(f64::NEG_INFINITY), 0xfc00);
    }

    #[test]
    fn from_f64_nan() {
        assert!(f16_is_nan(f64_to_f16(f64::NAN)));
    }

    #[test]
    fn from_f64_min_positive_subnormal() {
        assert_eq!(f64_to_f16(5.960464477539063e-8), 0x0001);
    }

    #[test]
    fn from_f64_min_positive_normal() {
        assert_eq!(f64_to_f16(0.00006103515625), 0x0400);
    }

    // ---- f64_to_f16: rounding (round-to-nearest, ties-to-even) ----

    #[test]
    fn rounding_exactly_halfway_rounds_to_even_down() {
        // Halfway between 0x3c00 (even significand) and 0x3c01.
        let halfway = f64::from_bits(0x3FF0_0200_0000_0000);
        assert_eq!(f64_to_f16(halfway), 0x3c00);
    }

    #[test]
    fn rounding_exactly_halfway_rounds_to_even_up() {
        // Halfway between 0x3c01 and 0x3c02 (even significand).
        let halfway = f64::from_bits(0x3FF0_0600_0000_0000);
        assert_eq!(f64_to_f16(halfway), 0x3c02);
    }

    #[test]
    fn rounding_just_below_halfway_rounds_down() {
        let below = f64::from_bits(0x3FF0_01FF_FFFF_FFFF);
        assert_eq!(f64_to_f16(below), 0x3c00);
    }

    #[test]
    fn rounding_just_above_halfway_rounds_up() {
        let above = f64::from_bits(0x3FF0_0200_0000_0001);
        assert_eq!(f64_to_f16(above), 0x3c01);
    }

    #[test]
    fn rounding_subnormal_halfway_rounds_to_even() {
        // 1.5 subnormal steps: tie between 1 and 2, the even neighbour is 2.
        let val = 1.5 * 5.960464477539063e-8;
        assert_eq!(f64_to_f16(val), 0x0002);
    }

    #[test]
    fn rounding_subnormal_halfway_even_down() {
        // 2.5 subnormal steps: tie between 2 and 3, the even neighbour is 2.
        let val = 2.5 * 5.960464477539063e-8;
        assert_eq!(f64_to_f16(val), 0x0002);
    }

    #[test]
    fn rounding_normal_to_subnormal_boundary() {
        let min_normal = 0.00006103515625_f64;
        assert_eq!(f64_to_f16(min_normal), 0x0400);

        // One f64 ulp below the smallest normal still rounds up to it: the carry out of
        // the subnormal significand produces the smallest normal bit pattern.
        let below = f64::from_bits(min_normal.to_bits() - 1);
        assert_eq!(f64_to_f16(below), 0x0400);
    }

    #[test]
    fn rounding_overflow_at_max() {
        assert_eq!(f64_to_f16(65504.0), 0x7bff);
        assert_eq!(f64_to_f16(65519.99), 0x7bff);
        assert_eq!(f64_to_f16(65520.0), 0x7c00);
    }

    #[test]
    fn rounding_tiny_to_zero() {
        assert_eq!(f64_to_f16(1e-30), 0x0000);
        assert_eq!(f64_to_f16(-1e-30), 0x8000);
    }

    #[test]
    fn rounding_tiny_to_min_subnormal() {
        // Exactly half of the smallest subnormal is a tie between 0 and 1; even wins.
        let half_min: f64 = 0.5 * 5.960464477539063e-8;
        assert_eq!(f64_to_f16(half_min), 0x0000);

        // Anything above the tie rounds up to the smallest subnormal.
        let above = f64::from_bits(half_min.to_bits() + 1);
        assert_eq!(f64_to_f16(above), 0x0001);
    }

    // ---- round trip ----

    // Every non-NaN half-precision value, of either sign, must survive widening to f64 and
    // narrowing back unchanged.
    #[test]
    fn roundtrip_all_exact_f16_values() {
        for bits in 0..=0x7fff_u16 {
            if f16_is_nan(bits) {
                continue;
            }
            let f = f16_to_f64(bits);
            let h2 = f64_to_f16(f);
            assert_eq!(bits, h2, "roundtrip failed for bits 0x{bits:04x}");

            let neg_bits = bits | 0x8000;
            let fn_ = f16_to_f64(neg_bits);
            let hn2 = f64_to_f16(fn_);
            assert_eq!(neg_bits, hn2, "roundtrip failed for bits 0x{neg_bits:04x}");
        }
    }
}