fast_posit/posit/convert/
float.rs

1use super::*;
2
3use crate::underlying::const_as;
4
5/// Extract the mantissa and exponent fields of an [`f64`], and represent them as a
6/// [`Decoded`], plus any sticky bits that have been lost.
7fn decode_finite_f64<
8  const N: u32,
9  const ES: u32,
10  Int: crate::Int,
11>(num: f64) -> (Decoded<N, ES, Int>, Int) {  // TODO type for `(Decoded, sticky)`
12  debug_assert!(num.is_finite());
13  const MANTISSA_BITS: u32 = f64::MANTISSA_DIGITS - 1;
14  const EXP_BIAS: i64 = f64::MIN_EXP as i64 - 1;
15  const HIDDEN_BIT: i64 = (i64::MIN as u64 >> 1) as i64;
16
17  // Extract sign, mantissa, and exponent.
18  use crate::underlying::Sealed;
19  let sign = num.is_sign_positive();
20  let bits = num.abs().to_bits() as i64;
21  let mantissa = bits.mask_lsb(MANTISSA_BITS);
22  let mut exponent = bits >> MANTISSA_BITS;
23
24  // An exponent field of 0 marks a subnormal number. Normals have implicit unit (`1.xxx`) and -1
25  // bias in the exponent; subnormals don't.
26  let is_normal = exponent != 0;
27  exponent -= i64::from(is_normal);
28
29  // Represent the mantissa as a `frac` in the target type `Int`.
30  //
31  // First, the float `mantissa` field is (1) unsigned, and (2) does not contain the hidden bit, so
32  // we need to correct that. Note that, if `frac` is 1.000… (i.e. `mantissa` = 0), it's negation
33  // is not -1.000…, but rather -2.000… with -1 in the `exp`!
34  let frac: i64 = {
35    const SHIFT_LEFT: u32 = 64 - MANTISSA_BITS - 2;
36    let unsigned_frac = (mantissa << SHIFT_LEFT) | HIDDEN_BIT;
37    if sign {
38      unsigned_frac
39    } else if mantissa != 0 {
40      -unsigned_frac
41    } else {
42      exponent -= 1;
43      i64::MIN
44    }
45  };
46  // Then, the bits have to be moved, from a width of `i64` to a width of `Int`, which may be
47  // either narrower or wider than an `i64`. Bits lost, if any, have to be accumulated onto
48  // `sticky`, to be returned.
49  let (mut frac, sticky): (Int, Int) = {
50    let shift_left = Int::BITS as i64 - 64;
51    if shift_left >= 0 {
52      // The mantissa has to be shifted left: there are no bits lost.
53      let shift_left = shift_left as u32;
54      let frac = const_as::<i64, Int>(frac) << shift_left;
55      (frac, Int::ZERO)
56    } else {
57      // The mantissa has to be shifted right: that amount of bits are lost.
58      let shift_right = -shift_left as u32;
59      let sticky = Int::from(frac.mask_lsb(shift_right) != 0);
60      let frac = const_as::<i64, Int>(frac.lshr(shift_right));
61      (frac, sticky)
62    }
63  };
64
65  // If it's a subnormal, then `frac` is "underflowing". We have to find the first 1 after the 0s,
66  // or the first 0 after the 1, and shift it to the correct place.
67  //
68  // Examples:
69  //
70  //   subnormal frac: 0000001101
71  //          becomes: 0110100000 and adjust exponent by -5
72  //
73  //   subnormal frac: 1111011011
74  //          becomes: 1011011000 and adjust exponent by -3
75  //
76  // Beware also that, if `frac` is exactly 0 (e.g. if some lowest bits have been lost) then we
77  // need to floor at `Posit::MIN`.
78  if !is_normal {
79    if frac == Int::ZERO {
80      return (Decoded { frac: Int::ONE, exp: Int::MIN >> 1 }, Int::ZERO)
81    }
82    // SAFETY: Just early returned if `frac == 0`.
83    let underflow = unsafe { frac.leading_run_minus_one() };
84    frac = frac << underflow;
85    exponent = exponent.wrapping_sub(underflow as i64);
86  }
87
88  // Represent the exponent as an `exp` in the target type `Int`.
89  //
90  // Beware to clamp it to the range representable in a `Decoded::exp` of type `Int`, otherwise
91  // there may be overflow in more extreme conversions (like f64 → p8).
92  let exponent = exponent.wrapping_add(EXP_BIAS);
93  let exp =
94    if const { Int::BITS < 64 } && exponent > const_as::<Int, i64>(Int::MAX >> 1) {
95      Int::MAX >> 1
96    } else if const { Int::BITS < 64 } && exponent < const_as::<Int, i64>(Int::MIN >> 1) {
97      Int::MIN >> 1
98    } else {
99      const_as::<_, Int>(exponent)
100    };
101
102  (Decoded { exp, frac }, sticky)
103}
104
105/// Extract the mantissa and exponent fields of an [`f64`], and represent them as a
106/// [`Decoded`], plus any sticky bits that have been lost.
107fn decode_finite_f32<
108  const N: u32,
109  const ES: u32,
110  Int: crate::Int,
111>(num: f32) -> (Decoded<N, ES, Int>, Int) {
112  debug_assert!(num.is_finite());
113  // TODO I'm lazy so for I'm just gonna call into [`decode_finite_f64`], since `f32` → `f64` is
114  // lossless; write standalone impl at some point
115  decode_finite_f64(num.into())
116}
117
118impl<
119  const N: u32,
120  const ES: u32,
121  Int: crate::Int,
122> RoundFrom<f32> for Posit<N, ES, Int> {
123  /// Convert an `f32` into a `Posit`, [rounding according to the standard]:
124  ///
125  /// - If the value is any infinity or any NaN, it converts to [NaR](Posit::NAR).
126  /// - Otherwise, the float value is rounded (if necessary).
127  ///
128  /// [rounding according to the standard]: https://posithub.org/docs/posit_standard-2.pdf#subsection.6.5
129  fn round_from(value: f32) -> Self {
130    use core::num::FpCategory;
131    match value.classify() {
132      FpCategory::Nan | FpCategory::Infinite => Self::NAR,
133      FpCategory::Zero => Self::ZERO,
134      FpCategory::Normal | FpCategory::Subnormal => {
135        let (decoded, sticky) = decode_finite_f32(value);
136        unsafe { decoded.encode_regular_round(sticky) }
137      }
138    }
139  }
140}
141
142impl<
143  const N: u32,
144  const ES: u32,
145  Int: crate::Int,
146> RoundFrom<f64> for Posit<N, ES, Int> {
147  /// Convert an `f64` into a `Posit`, [rounding according to the standard]:
148  ///
149  /// - If the value is any infinity or any NaN, it converts to [NaR](Posit::NAR).
150  /// - Otherwise, the float value is rounded (if necessary).
151  ///
152  /// [rounding according to the standard]: https://posithub.org/docs/posit_standard-2.pdf#subsection.6.5
153  fn round_from(value: f64) -> Self {
154    use core::num::FpCategory;
155    match value.classify() {
156      FpCategory::Nan | FpCategory::Infinite => Self::NAR,
157      FpCategory::Zero => Self::ZERO,
158      FpCategory::Normal | FpCategory::Subnormal => {
159        let (decoded, sticky) = decode_finite_f64(value);
160        unsafe { decoded.encode_regular_round(sticky) }
161      }
162    }
163  }
164}
165
166/// Take a [`Decoded`] and encode into an [`f64`] using IEEE 754 rounding rules.
167fn encode_finite_f64<
168  const N: u32,
169  const ES: u32,
170  Int: crate::Int,
171>(decoded: Decoded<N, ES, Int>) -> f64 {
172  // Rust assumes that the "default" IEEE rounding mode "roundTiesToEven" is always in effect
173  // (anything else is UB). This considerably simplifies this implementation.
174  const MANTISSA_BITS: u32 = f64::MANTISSA_DIGITS - 1;
175  const EXPONENT_BITS: u32 = 64 - MANTISSA_BITS - 1;
176
177  // Split `frac` into sign and absolute value (sans hidden bit).
178  let sign = decoded.frac.is_positive();
179  let (frac_abs, exp) =
180    // Small detail: a `frac` of `0b10_000…` (= -2.0) is translated to a float mantissa with
181    // absolute value 1.0, compensated by adding +1 to the exponent.
182    if decoded.frac != Int::MIN {
183      (decoded.frac.wrapping_abs().mask_lsb(Decoded::<N, ES, Int>::FRAC_WIDTH), decoded.exp)
184    } else {
185      (Int::ZERO, decoded.exp + Int::ONE)
186    };
187
188  // There are only `EXPONENT_BITS` bits for the exponent, if we overflow this then we have to
189  // return ±∞.
190  //
191  // The range of *normal* f64 numbers has -m < exponent ≤ +m, where `m` is 2^EXPONENT_BITS - 1.
192  //
193  // However, exponents ≤ -m may still be representable as *subnormal* numbers.
194  //
195  // We can also short circuit this at compile time by simply checking if Posit::MAX_EXP < m, in
196  // which case there's no need to check for overflows or subnormals at all! This is the case,
197  // e.g. when converting p8,p16,p32 to f32, or p8,p16,p32,p64 to f64 (so, a common and important
198  // case to specialise).
199  let max_exponent: i64 = (1 << (EXPONENT_BITS - 1)) - 1;
200  let exponent =
201    // No overflow possible
202    if Int::BITS < EXPONENT_BITS || Posit::<N, ES, Int>::MAX_EXP < const_as(max_exponent) {
203      const_as::<Int, i64>(exp)
204    }
205    // Can overflow
206    else {
207      // Overflow case, go to infinity
208      if exp > const_as(max_exponent) {
209        return if sign {f64::INFINITY} else {f64::NEG_INFINITY}
210      }
211      // Subnormal case, TODO
212      else if exp <= const_as(-max_exponent) {
213        todo!()
214      }
215      // Normal case
216      else {
217        const_as::<Int, i64>(exp)
218      }
219    };
220
221  // There are only `MANTISSA_BITS` bits for the mantissa, any less than that and we have to do
222  // some rounding.
223  let shift_left  = MANTISSA_BITS.saturating_sub(Decoded::<N, ES, Int>::FRAC_WIDTH);
224  let shift_right = Decoded::<N, ES, Int>::FRAC_WIDTH.saturating_sub(MANTISSA_BITS);
225  let mantissa = const_as::<Int, i64>(frac_abs >> shift_right) << shift_left;
226  // Compute also the bits lost due to right shift, and compile them into `round` and `sticky`
227  // bits.
228  //
229  // The formula for whether to round up in "round to nearest, ties to even" is the usual one; for
230  // more information, look at the comments in [`encode_regular_round`]).
231  let lost_bits = if shift_right == 0 {Int::ZERO} else {frac_abs << (Int::BITS - shift_right)};
232  let round = lost_bits < Int::ZERO;
233  let sticky = lost_bits << 1 != Int::ZERO;
234  let odd = mantissa & 1 == 1;
235  let round_up = round & (odd | sticky);
236
237  // One detail: if the mantissa overflows (i.e. we rounded up and the mantissa is all 0s), then we
238  // need to bump the exponent by 1.
239  let mantissa = mantissa + i64::from(round_up);
240  let exponent = if round_up & (mantissa == 0) {exponent + 1} else {exponent};
241
242  // Assemble the three fields of the final result: sign, (biased) exponent, and mantissa.
243  let bits =
244    (u64::from(!sign) << (u64::BITS - 1))
245    | (((exponent + max_exponent) as u64) << MANTISSA_BITS)
246    | (mantissa as u64);
247  f64::from_bits(bits)
248}
249
250/// Take a [`Decoded`] and encode into an [`f32`] using IEEE 754 rounding rules.
251fn encode_finite_f32<
252  const N: u32,
253  const ES: u32,
254  Int: crate::Int,
255>(decoded: Decoded<N, ES, Int>) -> f32 {
256  // Again, I'm lazy so shortcut for now.
257  encode_finite_f64(decoded) as f32
258}
259
260impl<
261  const N: u32,
262  const ES: u32,
263  Int: crate::Int,
264> RoundFrom<Posit<N, ES, Int>> for f32 {
265  /// Convert a `Posit` into an `f32`, [rounding according to the standard]:
266  ///
267  /// - If the value is [0](Posit::ZERO), the result is `+0.0`.
268  /// - If the value is [NaR](Posit::NAR), the result is a [quiet NaN](f32::NAN).
269  /// - Otherwise, the posit value is rounded to a float (if necessary) using the "roundTiesToEven"
270  ///   rule in the IEEE 754 standard (in short: underflow to ±0, overflow to ±∞, otherwise round
271  ///   to nearest, in case of a tie round to nearest even bit pattern).
272  ///
273  /// [rounding according to the standard]: https://posithub.org/docs/posit_standard-2.pdf#subsection.6.5
274  fn round_from(value: Posit<N, ES, Int>) -> Self {
275    if value == Posit::ZERO {
276      0.
277    } else if value == Posit::NAR {
278      f32::NAN
279    } else {
280      // SAFETY: `value` is not 0 nor NaR
281      let decoded = unsafe { value.decode_regular() };
282      encode_finite_f32(decoded)
283    }
284  }
285}
286
287impl<
288  const N: u32,
289  const ES: u32,
290  Int: crate::Int,
291> RoundFrom<Posit<N, ES, Int>> for f64 {
292  /// Convert a `Posit` into an `f64`, [rounding according to the standard]:
293  ///
294  /// - If the value is [0](Posit::ZERO), the result is `+0.0`.
295  /// - If the value is [NaR](Posit::NAR), the result is a [quiet NaN](f64::NAN).
296  /// - Otherwise, the posit value is rounded to a float (if necessary) using the "roundTiesToEven"
297  ///   rule in the IEEE 754 standard (in short: underflow to ±0, overflow to ±∞, otherwise round
298  ///   to nearest, in case of a tie round to nearest even bit pattern).
299  ///
300  /// [rounding according to the standard]: https://posithub.org/docs/posit_standard-2.pdf#subsection.6.5
301  fn round_from(value: Posit<N, ES, Int>) -> Self {
302    if value == Posit::ZERO {
303      0.
304    } else if value == Posit::NAR {
305      f64::NAN
306    } else {
307      // SAFETY: `value` is not 0 nor NaR
308      let decoded = unsafe { value.decode_regular() };
309      encode_finite_f64(decoded)
310    }
311  }
312}
313
#[cfg(test)]
mod tests {
  use super::*;
  use malachite::rational::Rational;
  use proptest::prelude::*;

  mod float_to_posit {
    use super::*;

    /// Instantiate the float → posit test suite for float type `$float` and posit type `$posit`.
    macro_rules! make_tests {
      ($float:ty, $posit:ty) => {
        use super::*;

        /// The tests involving the float's extreme values (largest/smallest magnitudes and
        /// subnormals) only assert anything for posits with `MAX_EXP <= 127`; for wider posits
        /// they return early.
        const EXTREMES_APPLY: bool = <$posit>::MAX_EXP as i64 <= 127;

        #[test]
        fn zero() {
          let converted = <$posit>::round_from(0.0 as $float);
          assert_eq!(converted, <$posit>::ZERO)
        }

        #[test]
        fn one() {
          let converted = <$posit>::round_from(1.0 as $float);
          assert_eq!(converted, <$posit>::ONE)
        }

        #[test]
        fn minus_one() {
          let converted = <$posit>::round_from(-1.0 as $float);
          assert_eq!(converted, <$posit>::MINUS_ONE)
        }

        #[test]
        fn nan() {
          let converted = <$posit>::round_from(<$float>::NAN);
          assert_eq!(converted, <$posit>::NAR)
        }

        #[test]
        fn min() {
          if !EXTREMES_APPLY { return }
          assert_eq!(<$posit>::round_from(<$float>::MIN), <$posit>::MIN)
        }

        #[test]
        fn max() {
          if !EXTREMES_APPLY { return }
          assert_eq!(<$posit>::round_from(<$float>::MAX), <$posit>::MAX)
        }

        #[test]
        fn min_positive() {
          if !EXTREMES_APPLY { return }
          assert_eq!(<$posit>::round_from(<$float>::MIN_POSITIVE), <$posit>::MIN_POSITIVE)
        }

        #[test]
        fn max_negative() {
          if !EXTREMES_APPLY { return }
          assert_eq!(<$posit>::round_from(-<$float>::MIN_POSITIVE), <$posit>::MAX_NEGATIVE)
        }

        #[test]
        fn subnormal_positive() {
          if !EXTREMES_APPLY { return }
          assert_eq!(<$posit>::round_from(<$float>::from_bits(1)), <$posit>::MIN_POSITIVE)
        }

        #[test]
        fn subnormal_negative() {
          if !EXTREMES_APPLY { return }
          assert_eq!(<$posit>::round_from(-<$float>::from_bits(1)), <$posit>::MAX_NEGATIVE)
        }

        proptest!{
          #![proptest_config(ProptestConfig::with_cases(crate::PROPTEST_CASES))]
          #[test]
          fn proptest(float: $float) {
            let posit = <$posit>::round_from(float);
            if let Ok(exact) = Rational::try_from(float) {
              assert!(super::rational::is_correct_rounded(exact, posit))
            } else {
              assert!(posit == <$posit>::NAR)
            }
          }
        }
      };
    }

    mod f64 {
      use super::*;

      mod p8 { make_tests!{f64, crate::p8} }
      mod p16 { make_tests!{f64, crate::p16} }
      mod p32 { make_tests!{f64, crate::p32} }
      mod p64 { make_tests!{f64, crate::p64} }

      mod posit_8_0 { make_tests!{f64, Posit::<8, 0, i8>} }
      mod posit_10_0 { make_tests!{f64, Posit::<10, 0, i16>} }
      mod posit_10_1 { make_tests!{f64, Posit::<10, 1, i16>} }
      mod posit_10_2 { make_tests!{f64, Posit::<10, 2, i16>} }
      mod posit_10_3 { make_tests!{f64, Posit::<10, 3, i16>} }
      mod posit_20_4 { make_tests!{f64, Posit::<20, 4, i32>} }

      mod posit_3_0 { make_tests!{f64, Posit::<3, 0, i8>} }
      mod posit_4_0 { make_tests!{f64, Posit::<4, 0, i8>} }
      mod posit_4_1 { make_tests!{f64, Posit::<4, 1, i8>} }
    }

    mod f32 {
      use super::*;

      mod p8 { make_tests!{f32, crate::p8} }
      mod p16 { make_tests!{f32, crate::p16} }
      mod p32 { make_tests!{f32, crate::p32} }
      mod p64 { make_tests!{f32, crate::p64} }

      mod posit_8_0 { make_tests!{f32, Posit::<8, 0, i8>} }
      mod posit_10_0 { make_tests!{f32, Posit::<10, 0, i16>} }
      mod posit_10_1 { make_tests!{f32, Posit::<10, 1, i16>} }
      mod posit_10_2 { make_tests!{f32, Posit::<10, 2, i16>} }
      mod posit_10_3 { make_tests!{f32, Posit::<10, 3, i16>} }
      mod posit_20_4 { make_tests!{f32, Posit::<20, 4, i32>} }

      mod posit_3_0 { make_tests!{f32, Posit::<3, 0, i8>} }
      mod posit_4_0 { make_tests!{f32, Posit::<4, 0, i8>} }
      mod posit_4_1 { make_tests!{f32, Posit::<4, 1, i8>} }
    }
  }

  mod posit_to_float {
    use super::*;

    // TODO These tests are very incomplete! They only check posit → float → posit roundtrips,
    // which is not even an exact check of correct rounding. Float → posit → float roundtrips are
    // not tested at all.

    /// Instantiate an exhaustive posit → float → posit roundtrip test (for small posit types).
    macro_rules! test_exhaustive {
      ($float:ty, $posit:ty) => {
        use super::*;

        #[test]
        fn posit_roundtrip_exhaustive() {
          <$posit>::cases_exhaustive_all().into_iter().for_each(|posit| {
            let through_float = <$float>::round_from(posit);
            assert_eq!(posit, <$posit>::round_from(through_float))
          })
        }
      };
    }

    /// Instantiate a randomised posit → float → posit roundtrip test (for large posit types).
    macro_rules! test_proptest {
      ($float:ty, $posit:ty) => {
        use super::*;

        proptest!{
          #![proptest_config(ProptestConfig::with_cases(crate::PROPTEST_CASES))]

          #[test]
          fn posit_roundtrip_proptest(posit in <$posit>::cases_proptest_all()) {
            let through_float = <$float>::round_from(posit);
            assert_eq!(posit, <$posit>::round_from(through_float))
          }
        }
      };
    }

    mod f64 {
      use super::*;

      mod p8 { test_exhaustive!{f64, crate::p8} }
      mod p16 { test_exhaustive!{f64, crate::p16} }
      mod p32 { test_proptest!{f64, crate::p32} }
      // Not instantiated: p64, whose fraction can be wider than an f64 mantissa, so it cannot
      // roundtrip exactly.

      mod posit_8_0 { test_exhaustive!{f64, Posit::<8, 0, i8>} }
      mod posit_10_0 { test_exhaustive!{f64, Posit::<10, 0, i16>} }
      mod posit_10_1 { test_exhaustive!{f64, Posit::<10, 1, i16>} }
      mod posit_10_2 { test_exhaustive!{f64, Posit::<10, 2, i16>} }
      mod posit_10_3 { test_exhaustive!{f64, Posit::<10, 3, i16>} }
      mod posit_20_4 { test_proptest!{f64, Posit::<20, 4, i32>} }

      mod posit_3_0 { test_exhaustive!{f64, Posit::<3, 0, i8>} }
      mod posit_4_0 { test_exhaustive!{f64, Posit::<4, 0, i8>} }
      mod posit_4_1 { test_exhaustive!{f64, Posit::<4, 1, i8>} }
    }

    mod f32 {
      use super::*;

      mod p8 { test_exhaustive!{f32, crate::p8} }
      mod p16 { test_exhaustive!{f32, crate::p16} }
      // Not instantiated: p32 and p64 (fractions wider than an f32 mantissa) and posit_20_4
      // (exponent range wider than f32's), none of which can roundtrip exactly through f32.

      mod posit_8_0 { test_exhaustive!{f32, Posit::<8, 0, i8>} }
      mod posit_10_0 { test_exhaustive!{f32, Posit::<10, 0, i16>} }
      mod posit_10_1 { test_exhaustive!{f32, Posit::<10, 1, i16>} }
      mod posit_10_2 { test_exhaustive!{f32, Posit::<10, 2, i16>} }
      mod posit_10_3 { test_exhaustive!{f32, Posit::<10, 3, i16>} }

      mod posit_3_0 { test_exhaustive!{f32, Posit::<3, 0, i8>} }
      mod posit_4_0 { test_exhaustive!{f32, Posit::<4, 0, i8>} }
      mod posit_4_1 { test_exhaustive!{f32, Posit::<4, 1, i8>} }
    }
  }
}