bitcode_lightyear_patch/encoding/
expect_normalized_float.rs

1use crate::code::{Decode, Encode};
2use crate::encoding::prelude::*;
3use crate::encoding::Fixed;
4
/// Float `Encoding` tuned for values expected to lie in `[0, 1)`.
///
/// Non-negative values in `[2^-12, 1)` are written compactly: a run of zero bits
/// selecting the exponent, one marker bit, then the raw mantissa
/// (24–35 bits for `f32`, 53–64 bits for `f64`). Every other value (negatives,
/// zero, subnormals, values >= 1, infinities, NaN) is written as a 12-zero-bit
/// marker followed by the `Fixed` encoding, so any bit pattern round-trips.
#[derive(Copy, Clone, Debug)]
pub struct ExpectNormalizedFloat;
7
// Length of the zero-bit exponent prefix at which the fast path gives up. The
// same number of zero bits is written as the marker that introduces the `Fixed`
// fallback. The fast path emits `mantissa + exp_zeros + 1` bits with
// `exp_zeros < MAX_EXP_ZEROS`. For f64 (52 mantissa bits) that has to fit in a
// single 64-bit read/write, so the value cannot exceed 64 - 52 = 12.
const MAX_EXP_ZEROS: usize = u64::BITS as usize - 52;
10
// Generates a `$write` / `$read` pair of `Encoding` methods for one float type.
//
// - `$t`: the float type; `$i`: the unsigned integer used by `to_bits`/`from_bits`.
// - `$mantissa`: number of explicitly stored mantissa bits of `$t`.
// - `$exp_bias`: exponent bias of `$t`.
//
// Fast path: a non-negative value whose biased exponent `exp` satisfies
// `exp_zeros = exp_bias - 1 - exp < MAX_EXP_ZEROS` (i.e. a value in
// `[2^-MAX_EXP_ZEROS, 1)`) is written with `exp_zeros` zero bits in the least
// significant positions, then a 1 marker bit, then the mantissa. The reader
// recovers `exp_zeros` by counting trailing zeros of the peeked bits.
// Fallback: `MAX_EXP_ZEROS` zero bits followed by the `Fixed` encoding.
macro_rules! impl_float {
    ($write:ident, $read:ident, $t:ty, $i:ty, $mantissa:literal, $exp_bias:literal) => {
        #[inline(always)]
        fn $write(self, writer: &mut impl Write, v: $t) {
            // Compile-time guards. The macro arguments must describe `$t`, and the
            // largest fast-path payload (mantissa + (MAX_EXP_ZEROS - 1) zeros + 1
            // marker bit) must fit in the single u64 passed to `write_bits`.
            const _: () = {
                assert!(<$t>::MANTISSA_DIGITS == $mantissa + 1);
                assert!(<$t>::MAX_EXP == $exp_bias + 1);
                assert!($mantissa + MAX_EXP_ZEROS <= u64::BITS as usize);
            };

            let mantissa_bits = $mantissa as usize;
            let exp_bias = $exp_bias as u32;
            let sign_bit = 1 << (<$i>::BITS - 1);

            let bits = v.to_bits();
            let sign = bits & sign_bit;
            let bits_without_sign = bits & !sign_bit;
            let exp = (bits_without_sign >> mantissa_bits) as u32;
            // Wraps to a huge number when `exp >= exp_bias` (values >= 1.0 and
            // non-finite values), so those also land in the fallback below.
            let exp_zeros = (exp_bias - 1).wrapping_sub(exp) as usize;

            // OR-ing in the sign bit makes every negative value fail the check.
            if (sign | exp_zeros as $i) < MAX_EXP_ZEROS as $i {
                let mantissa = bits & !(<$i>::MAX << mantissa_bits);
                let v = (((mantissa as u64) << 1) | 1) << exp_zeros;
                writer.write_bits(v, mantissa_bits + exp_zeros + 1);
            } else {
                #[cold]
                fn cold(writer: &mut impl Write, v: $t) {
                    writer.write_zeros(MAX_EXP_ZEROS);
                    v.encode(Fixed, writer).unwrap()
                }
                cold(writer, v);
            }
        }

        #[inline(always)]
        fn $read(self, reader: &mut impl Read) -> Result<$t> {
            let mantissa_bits = $mantissa as usize;
            let exp_bias = $exp_bias as u32;

            let v = reader.peek_bits()?;
            // The writer placed `exp_zeros` zero bits below the 1 marker bit. A run
            // of at least MAX_EXP_ZEROS zeros is the fallback marker.
            let exp_zeros = v.trailing_zeros() as usize;

            if exp_zeros < MAX_EXP_ZEROS {
                let exp_bits = exp_zeros + 1;
                reader.advance(mantissa_bits + exp_bits);

                let mantissa = (v >> exp_bits) as $i & !(<$i>::MAX << mantissa_bits);
                let exp = (exp_bias - 1) - exp_zeros as u32;
                // The sign bit stays 0: only non-negative values use the fast path.
                Ok(<$t>::from_bits(exp as $i << mantissa_bits | mantissa))
            } else {
                #[cold]
                fn cold(reader: &mut impl Read) -> Result<$t> {
                    reader.advance(MAX_EXP_ZEROS);
                    <$t>::decode(Fixed, reader)
                }
                cold(reader)
            }
        }
    }
}
65
66impl Encoding for ExpectNormalizedFloat {
67    impl_float!(write_f32, read_f32, f32, u32, 23, 127);
68    impl_float!(write_f64, read_f64, f64, u64, 52, 1023);
69}
70
#[cfg(all(test, not(miri)))]
mod benches {
    // Expands to a module named after the float type `$float` that benchmarks
    // `ExpectNormalizedFloat` over `dataset::<$float>()`.
    macro_rules! float_bench {
        ($float:ident) => {
            mod $float {
                use crate::encoding::bench_prelude::*;
                bench_encoding!(crate::encoding::ExpectNormalizedFloat, dataset::<$float>);
            }
        };
    }

    float_bench!(f32);
    float_bench!(f64);
}
83
#[cfg(all(test, debug_assertions, not(miri)))]
mod tests {
    // Expands to round-trip and size tests for the float type `$t`, whose raw bit
    // pattern is the unsigned integer type `$i`.
    macro_rules! impl_test {
        ($t:ty, $i:ty) => {
            use crate::encoding::expect_normalized_float::*;
            use crate::encoding::prelude::test_prelude::*;
            use rand::{Rng, SeedableRng};

            // Round-trips `value` through `ExpectNormalizedFloat`. Equality compares
            // raw bit patterns, so NaN payloads and the sign of zero must survive.
            fn t(value: $t) {
                #[derive(Copy, Clone, Debug, Encode, Decode)]
                struct ExactBits(#[bitcode_hint(expected_range = "0.0..1.0")] $t);

                impl PartialEq for ExactBits {
                    fn eq(&self, other: &Self) -> bool {
                        self.0.to_bits() == other.0.to_bits()
                    }
                }
                test_encoding(ExpectNormalizedFloat, ExactBits(value));
            }

            // Arbitrary bit patterns: any sign, exponent, subnormal, infinity or NaN.
            // The all-zero seed keeps the run deterministic.
            #[test]
            fn test_random() {
                let mut rng = rand_chacha::ChaCha20Rng::from_seed(Default::default());
                for _ in 0..100000 {
                    let f = <$t>::from_bits(rng.gen::<$i>());
                    t(f)
                }
            }

            #[test]
            fn test2() {
                // Mix of fast-path values (0.5) and fallback values
                // (zero, >= 1.0, negative, non-finite, very small).
                t(0.0);
                t(0.5);
                t(1.0);
                t(-1.0);
                t(<$t>::INFINITY);
                t(<$t>::NEG_INFINITY);
                t(<$t>::NAN);
                t(0.0000000000001);

                // `n` evenly spaced values in [0, 1).
                fn normalized_floats(n: usize) -> impl Iterator<Item = $t> {
                    let scale = 1.0 / n as $t;
                    (0..n).map(move |i| i as $t * scale)
                }

                // Encodes `normalized_floats(n)` back to back, checks they decode to
                // the same values, and returns the average encoded size in bits.
                fn normalized_float_bits(n: usize) -> $t {
                    use crate::buffer::BufferTrait;
                    use crate::word_buffer::WordBuffer;

                    let mut buffer = WordBuffer::default();
                    let mut writer = buffer.start_write();
                    for v in normalized_floats(n) {
                        v.encode(ExpectNormalizedFloat, &mut writer).unwrap();
                    }
                    let bytes = buffer.finish_write(writer).to_vec();

                    let (mut reader, context) = buffer.start_read(&bytes);
                    for v in normalized_floats(n) {
                        let decoded = <$t>::decode(ExpectNormalizedFloat, &mut reader).unwrap();
                        assert_eq!(decoded, v);
                    }
                    WordBuffer::finish_read(reader, context).unwrap();

                    (bytes.len() * u8::BITS as usize) as $t / n as $t
                }

                // Previously measured with n = 6_000_000: about 25.013674 bits per
                // value for f32 and 54.019532 for f64.
                let bits_per_value = normalized_float_bits(1 << 12);
                let expected: std::ops::Range<$t> = if <$i>::BITS == 32 {
                    25.0..25.5
                } else {
                    54.0..54.5
                };
                assert!(
                    expected.contains(&bits_per_value),
                    "bits {} not in {:?}",
                    bits_per_value,
                    expected
                );
            }
        };
    }

    mod f32 {
        impl_test!(f32, u32);
    }

    mod f64 {
        impl_test!(f64, u64);
    }
}