lilliput_float/
truncate.rs

1use crate::bits::{FpFromBits, FpToBits};
2use crate::floats::{F16, F24, F32, F40, F48, F56, F64, F8};
3use crate::repr::FpRepr;
4use crate::sealed::Sealed;
5
6pub trait FpTruncate<T>: Sized + Sealed {
7    fn truncate(self) -> (Self, T);
8}
9
// Source: https://github.com/rust-lang/compiler-builtins/blob/3dea633a80d32da75e923a940d16ce98cce74822/src/float/trunc.rs#L4
//
// Implements `FpTruncate<$dst> for $src`, rounding to nearest, ties to even.
// Besides the narrowed value, each impl returns that value widened back into
// the source format (NaN inputs keep their original source bits there).
//
// Unlike upstream, which rounds one combined `exponent | significand` word
// (so carries propagate by plain addition), this version tracks exponent and
// significand separately and OR-s them together at the end. Every rounding
// carry out of the significand is therefore propagated explicitly below.
macro_rules! impl_float_truncate {
    // `$src => [$dst, ...]`: one impl per destination type.
    ($src:ty => [$($dst:ty),* $(,)?]) => {
        $(
            impl_float_truncate!($src => $dst);
        )*
    };
    // NOTE: a literal special-case arm such as `(F64 => F32)` can never match
    // here, because fragments forwarded as `ty` are opaque to the matcher.
    // Every pair goes through this generic arm.
    ($src:ty => $dst:ty) => {
        impl FpTruncate<$dst> for $src {
            fn truncate(self) -> ($src, $dst) {
                type Src = $src;
                type Dst = $dst;

                type SrcBits = <Src as FpRepr>::Bits;
                type DstBits = <Dst as FpRepr>::Bits;

                let src_exp_bias: SrcBits = Src::EXPONENT_BIAS;
                let dst_exp_bias: DstBits = Dst::EXPONENT_BIAS;

                let src_infinity: SrcBits = Src::EXPONENT_MASK;

                let src_abs_mask: SrcBits = Src::SIGN_MASK - 1;
                let round_mask: SrcBits =
                    (1 << (Src::SIGNIFICAND_BITS - Dst::SIGNIFICAND_BITS)) - 1;
                let halfway: SrcBits = 1 << (Src::SIGNIFICAND_BITS - Dst::SIGNIFICAND_BITS - 1);
                let src_qnan: SrcBits = 1 << (Src::SIGNIFICAND_BITS - 1);
                let src_nan_code: SrcBits = src_qnan - 1;

                let src_inf_exp: SrcBits = Src::EXPONENT_MAX;
                let dst_inf_exp: DstBits = Dst::EXPONENT_MAX;

                let dst_qnan: DstBits = 1 << (Dst::SIGNIFICAND_BITS - 1);
                let dst_nan_code: DstBits = dst_qnan - 1;

                // A one in the lowest bit of each in-place exponent field
                // (the same position as the implicit bit).
                let src_exponent_one: SrcBits = 1 << Src::SIGNIFICAND_BITS;
                let dst_exponent_one: DstBits = 1 << Dst::SIGNIFICAND_BITS;

                // Biased source exponents of the destination's smallest normal
                // and of the first value that overflows the destination.
                let underflow_exponent: SrcBits = src_exp_bias + 1 - (dst_exp_bias as SrcBits);
                let overflow_exponent: SrcBits =
                    src_exp_bias + (dst_inf_exp - dst_exp_bias) as SrcBits;

                let underflow: SrcBits = underflow_exponent << Src::SIGNIFICAND_BITS;
                let overflow: SrcBits = overflow_exponent << Src::SIGNIFICAND_BITS;

                let bits_delta: u32 = Src::BITS - Dst::BITS;
                let significand_bits_delta: u32 = Src::SIGNIFICAND_BITS - Dst::SIGNIFICAND_BITS;

                let bits: SrcBits = self.to_bits();
                let src_abs: SrcBits = bits & src_abs_mask;

                let src_sign: SrcBits = bits & Src::SIGN_MASK;
                let mut src_exponent: SrcBits = bits & Src::EXPONENT_MASK;
                let mut src_significand: SrcBits = bits & Src::SIGNIFICAND_MASK;

                let dst_sign: DstBits = (src_sign >> bits_delta) as DstBits;
                let mut dst_exponent: DstBits;
                let mut dst_significand: DstBits;

                let exp_bias_delta: SrcBits = src_exp_bias.wrapping_sub(dst_exp_bias as SrcBits);
                let shifted_exp_bias_delta: SrcBits = exp_bias_delta << Src::SIGNIFICAND_BITS;

                if src_abs.wrapping_sub(underflow) < src_abs.wrapping_sub(overflow) {
                    // The exponent is within the range of normal numbers in the
                    // destination format: re-bias the exponent and right-shift
                    // the significand with rounding.
                    dst_exponent = (src_exponent.wrapping_sub(shifted_exp_bias_delta)
                        >> significand_bits_delta) as DstBits;
                    dst_significand = (src_significand >> significand_bits_delta) as DstBits;

                    let round_bits = src_significand & round_mask;

                    if round_bits > halfway {
                        // Round significand to nearest.
                        dst_significand += 1;
                    } else if round_bits == halfway {
                        // Tie significand to even.
                        dst_significand += dst_significand & 1;
                    }

                    if dst_significand > Dst::SIGNIFICAND_MASK {
                        // Rounding carried out of an all-ones significand: the
                        // value rounds up to the next power of two. The fields
                        // are OR-ed together below, so move the carry into the
                        // exponents explicitly instead of letting it alias the
                        // exponent's lowest bit.
                        dst_significand = 0;
                        dst_exponent += dst_exponent_one;
                        src_exponent += src_exponent_one;
                    }

                    if dst_exponent == dst_inf_exp << Dst::SIGNIFICAND_BITS {
                        // Rounded past the largest finite destination value:
                        // the widened result must be infinite as well.
                        src_exponent = src_inf_exp << Src::SIGNIFICAND_BITS;
                    }

                    src_significand = ((dst_significand as SrcBits) << significand_bits_delta)
                        & Src::SIGNIFICAND_MASK;
                } else if src_abs > src_infinity {
                    // The value is NaN.

                    // Conjure the result by beginning with infinity, setting the
                    // qNaN bit and inserting the (truncated) trailing NaN field.
                    // The widened result keeps the original source NaN bits.
                    dst_exponent = dst_inf_exp << Dst::SIGNIFICAND_BITS;

                    dst_significand = dst_qnan
                        | dst_nan_code
                            & ((src_significand & src_nan_code) >> significand_bits_delta)
                                as DstBits;
                } else if src_abs >= overflow {
                    // Value overflows to infinity.
                    dst_exponent = dst_inf_exp << Dst::SIGNIFICAND_BITS;
                    src_exponent = src_inf_exp << Src::SIGNIFICAND_BITS;

                    dst_significand = 0;
                    src_significand = 0;
                } else {
                    // Value underflows on conversion to the destination type
                    // or is an exact zero. The result may be a denormal or zero.
                    dst_exponent = 0;

                    // A source subnormal (biased exponent 0) has no implicit bit
                    // and is scaled like biased exponent 1.
                    let src_biased_exp: SrcBits = src_abs >> Src::SIGNIFICAND_BITS;
                    let (src_exp, significand): (SrcBits, SrcBits) = if src_biased_exp == 0 {
                        (1, src_significand)
                    } else {
                        (src_biased_exp, src_significand | Src::IMPLICIT_BIT)
                    };

                    // Denormalization shift: how many binades the value sits
                    // below the destination's smallest normal exponent. The
                    // branch condition guarantees `src_exp <= underflow_exponent`.
                    let shift: u32 = (underflow_exponent - src_exp) as u32;

                    if shift >= Src::SIGNIFICAND_BITS {
                        // Value underflows to zero.
                        dst_significand = 0;
                        src_exponent = 0;
                        src_significand = 0;
                    } else {
                        // Value underflows to a destination denormal (or rounds
                        // up to the destination's smallest normal).

                        // Sticky bit: set if any bit shifted out is nonzero, so
                        // an apparent tie is treated as above halfway. The mask
                        // covers exactly the `shift` low bits, independent of
                        // how much wider `SrcBits` is than `Src::BITS`.
                        let shifted_out_mask: SrcBits = (1 << shift) - 1;
                        let sticky: SrcBits = if (significand & shifted_out_mask) != 0 {
                            1
                        } else {
                            0
                        };

                        // Right shift by the denormalization amount with sticky.
                        let denormalized: SrcBits = (significand >> shift) | sticky;
                        dst_significand = (denormalized >> significand_bits_delta) as DstBits;

                        let round_bits = denormalized & round_mask;

                        if round_bits > halfway {
                            // Round to nearest.
                            dst_significand += 1;
                        } else if round_bits == halfway {
                            // Ties to even.
                            dst_significand += dst_significand & 1;
                        }

                        // A carry out of the significand leaves
                        // `dst_significand == dst_exponent_one`. Since
                        // `dst_exponent` is 0 and the fields are OR-ed below,
                        // that encodes the smallest normal number, which is the
                        // correctly rounded result. Do not mask it away.

                        if dst_significand == 0 {
                            src_exponent = 0;
                            src_significand = 0;
                        } else {
                            // Renormalize for the source format. `scale` is how
                            // far the leading one sits below the implicit-bit
                            // position (0 after a carry to the smallest normal).
                            let scale: u32 =
                                dst_significand.leading_zeros() - dst_exponent_one.leading_zeros();

                            // Biased source exponent would be
                            // `exp_bias_delta + 1 - scale`. Zero (or less) means
                            // the widened value is subnormal in the source too.
                            let src_biased: SrcBits =
                                (exp_bias_delta + 1).saturating_sub(scale as SrcBits);

                            if src_biased == 0 {
                                src_exponent = 0;
                                src_significand = (dst_significand as SrcBits)
                                    << (significand_bits_delta + exp_bias_delta as u32);
                            } else {
                                src_exponent = src_biased << Src::SIGNIFICAND_BITS;
                                src_significand = ((dst_significand as SrcBits)
                                    << (significand_bits_delta + scale))
                                    & Src::SIGNIFICAND_MASK;
                            }
                        }
                    }
                }

                let src_result_bits: SrcBits = src_sign | src_exponent | src_significand;
                let dst_result_bits: DstBits = dst_sign | dst_exponent | dst_significand;

                (Src::from_bits(src_result_bits), Dst::from_bits(dst_result_bits))
            }
        }
    };
}
195
196#[cfg(feature = "full")]
197impl_float_truncate!(F8 => []);
198#[cfg(feature = "full")]
199impl_float_truncate!(F16 => [F8]);
200#[cfg(feature = "full")]
201impl_float_truncate!(F24 => [F8, F16]);
202
203impl_float_truncate!(F32 => [F8, F16, F24]);
204#[cfg(feature = "full")]
205impl_float_truncate!(F40 => [F8, F16, F24, F32]);
206#[cfg(feature = "full")]
207impl_float_truncate!(F48 => [F8, F16, F24, F32, F40]);
208#[cfg(feature = "full")]
209impl_float_truncate!(F56 => [F8, F16, F24, F32, F40, F48]);
210
211impl_float_truncate!(F64 => [F8, F16, F24, F32, F40, F48, F56]);
212
#[cfg(test)]
mod tests {
    use std::num::FpCategory;

    use proptest::prelude::*;

    use crate::FpClassify as _;

    use super::*;

    /// Asserts that truncation moved a value only into a category reachable by
    /// rounding into a narrower format.
    ///
    /// NaN, infinity and zero are preserved. Finite values may lose range
    /// (overflow to infinity, underflow to subnormal or zero), and a subnormal
    /// may round up to the smallest normal number.
    fn assert_valid_category(before: FpCategory, after: FpCategory) -> Result<(), TestCaseError> {
        match before {
            FpCategory::Nan => {
                prop_assert_eq!(after, FpCategory::Nan);
            }
            FpCategory::Infinite => {
                prop_assert_eq!(after, FpCategory::Infinite);
            }
            FpCategory::Zero => {
                prop_assert_eq!(after, FpCategory::Zero);
            }
            FpCategory::Subnormal => {
                prop_assert!(matches!(
                    after,
                    FpCategory::Zero
                        | FpCategory::Subnormal
                        | FpCategory::Normal
                        | FpCategory::Infinite
                ));
            }
            FpCategory::Normal => {
                prop_assert!(matches!(
                    after,
                    FpCategory::Zero
                        | FpCategory::Subnormal
                        | FpCategory::Normal
                        | FpCategory::Infinite
                ));
            }
        }

        Ok(())
    }

    proptest! {
        // MARK: - F32

        #[test]
        fn truncate_f32_to_f8(native in f32::arbitrary()) {
            let subject = F32::from(native);
            let (src_actual, dst_actual): (F32, F8) = subject.truncate();

            let category_before = subject.classify();
            let src_category_after = src_actual.classify();
            let dst_category_after = dst_actual.classify();

            assert_valid_category(category_before, src_category_after)?;
            assert_valid_category(category_before, dst_category_after)?;
        }

        #[test]
        fn truncate_f32_to_f16(native in f32::arbitrary()) {
            let subject = F32::from(native);
            let (src_actual, dst_actual): (F32, F16) = subject.truncate();

            let category_before = subject.classify();
            let src_category_after = src_actual.classify();
            let dst_category_after = dst_actual.classify();

            assert_valid_category(category_before, src_category_after)?;
            assert_valid_category(category_before, dst_category_after)?;
        }

        #[test]
        fn truncate_f32_to_f24(native in f32::arbitrary()) {
            let subject = F32::from(native);
            let (src_actual, dst_actual): (F32, F24) = subject.truncate();

            let category_before = subject.classify();
            let src_category_after = src_actual.classify();
            let dst_category_after = dst_actual.classify();

            assert_valid_category(category_before, src_category_after)?;
            assert_valid_category(category_before, dst_category_after)?;
        }

        // MARK: - F64

        #[test]
        fn truncate_f64_to_f8(native in f64::arbitrary()) {
            let subject = F64::from(native);
            let (src_actual, dst_actual): (F64, F8) = subject.truncate();

            let category_before = subject.classify();
            let src_category_after = src_actual.classify();
            let dst_category_after = dst_actual.classify();

            assert_valid_category(category_before, src_category_after)?;
            assert_valid_category(category_before, dst_category_after)?;
        }

        #[test]
        fn truncate_f64_to_f16(native in f64::arbitrary()) {
            let subject = F64::from(native);
            let (src_actual, dst_actual): (F64, F16) = subject.truncate();

            let category_before = subject.classify();
            let src_category_after = src_actual.classify();
            let dst_category_after = dst_actual.classify();

            assert_valid_category(category_before, src_category_after)?;
            assert_valid_category(category_before, dst_category_after)?;
        }

        #[test]
        fn truncate_f64_to_f24(native in f64::arbitrary()) {
            let subject = F64::from(native);
            let (src_actual, dst_actual): (F64, F24) = subject.truncate();

            let category_before = subject.classify();
            let src_category_after = src_actual.classify();
            let dst_category_after = dst_actual.classify();

            assert_valid_category(category_before, src_category_after)?;
            assert_valid_category(category_before, dst_category_after)?;
        }

        #[test]
        fn truncate_f64_to_f32(native in f64::arbitrary()) {
            let subject = F64::from(native);
            let (src_actual, dst_actual): (F64, F32) = subject.truncate();

            let category_before = subject.classify();
            let src_category_after = src_actual.classify();
            let dst_category_after = dst_actual.classify();

            assert_valid_category(category_before, src_category_after)?;
            assert_valid_category(category_before, dst_category_after)?;

            // The native `as` casts are the reference implementation.
            let dst_native = native as f32;
            let src_native = dst_native as f64;

            let dst_expected = F32::from(dst_native);
            let src_expected = F64::from(src_native);

            prop_assert_eq!(dst_actual, dst_expected);
            prop_assert_eq!(src_actual, src_expected);
        }

        #[test]
        fn truncate_f64_to_f40(native in f64::arbitrary()) {
            let subject = F64::from(native);
            let (src_actual, dst_actual): (F64, F40) = subject.truncate();

            let category_before = subject.classify();
            let src_category_after = src_actual.classify();
            let dst_category_after = dst_actual.classify();

            assert_valid_category(category_before, src_category_after)?;
            assert_valid_category(category_before, dst_category_after)?;
        }

        #[test]
        fn truncate_f64_to_f48(native in f64::arbitrary()) {
            let subject = F64::from(native);
            let (src_actual, dst_actual): (F64, F48) = subject.truncate();

            let category_before = subject.classify();
            let src_category_after = src_actual.classify();
            let dst_category_after = dst_actual.classify();

            assert_valid_category(category_before, src_category_after)?;
            assert_valid_category(category_before, dst_category_after)?;
        }

        #[test]
        fn truncate_f64_to_f56(native in f64::arbitrary()) {
            let subject = F64::from(native);
            let (src_actual, dst_actual): (F64, F56) = subject.truncate();

            let category_before = subject.classify();
            let src_category_after = src_actual.classify();
            let dst_category_after = dst_actual.classify();

            assert_valid_category(category_before, src_category_after)?;
            assert_valid_category(category_before, dst_category_after)?;
        }
    }

    // MARK: - Rounding carries
    //
    // Random bit patterns almost never yield a truncated significand of all
    // ones that then rounds up, so these carry cases are pinned explicitly.

    #[test]
    fn truncate_f64_to_f32_carries_into_odd_exponent() {
        // Largest f64 below 2.0: the f32 significand is all ones and rounds up.
        // The f32 exponent of 1.x is odd, so the carry must be added to the
        // exponent rather than OR-ed into it.
        let native = 2.0_f64 - f64::EPSILON;
        let (src_actual, dst_actual): (F64, F32) = F64::from(native).truncate();

        assert_eq!(dst_actual, F32::from(native as f32));
        assert_eq!(dst_actual, F32::from(2.0_f32));
        assert_eq!(src_actual, F64::from(2.0_f64));
    }

    #[test]
    fn truncate_f64_to_f32_rounds_past_max_to_infinity() {
        // Exactly halfway between f32::MAX and 2^128. f32::MAX has an odd
        // significand, so ties-to-even rounds up to infinity, and the widened
        // value must be infinite too.
        let native = f64::from(f32::MAX) + 2.0_f64.powi(103);
        let (src_actual, dst_actual): (F64, F32) = F64::from(native).truncate();

        assert_eq!(dst_actual, F32::from(native as f32));
        assert_eq!(dst_actual, F32::from(f32::INFINITY));
        assert_eq!(src_actual, F64::from(f64::INFINITY));
    }

    #[test]
    fn truncate_f32_to_f16_rounds_up_to_smallest_normal() {
        // Just below f16's smallest normal (2^-14), within half a subnormal
        // ulp of it: must round up to 2^-14, not flush to zero.
        let smallest_normal = 2.0_f32.powi(-14);
        let native = smallest_normal - 2.0_f32.powi(-30);
        let (src_actual, dst_actual): (F32, F16) = F32::from(native).truncate();

        assert_eq!(src_actual, F32::from(smallest_normal));
        assert_eq!(dst_actual.classify(), FpCategory::Normal);
    }
}