Skip to main content

fast_posit/posit/math/
sqrt.rs

1use super::*;
2
3impl<
4  const N: u32,
5  const ES: u32,
6  Int: crate::Int,
7  const RS: u32,
8> Posit<N, ES, Int, RS> {
9  /// Return a [normalised](Decoded::is_normalised) `Decoded` that is the result of √x, if `x` is
10  /// non-negative.
11  ///
12  /// # Safety
13  ///
14  /// `x` must to be [normalised](Decoded::is_normalised) and `x.frac` must be positive, or calling
15  /// this function is *undefined behaviour*.
16  #[inline]
17  pub(crate) unsafe fn sqrt_kernel(x: Decoded<N, ES, RS, Int>) -> (Decoded<N, ES, RS, Int>, Int) {
18    // Taking the square root of a number in the form `frac × 2^exp` has two steps.
19    //
20    // First, ensure that `exp` is an even number. If it's odd, add 1 to exp and compensate `frac`
21    // accordingly. That is:
22    //
23    //   frac', exp' = frac     , exp        if exp is even
24    //               = frac << 1, exp - 1    if exp is odd
25    //
26    // This is fine: `x.frac` is positive, meaning we have exactly 1 bit available to do the shift
27    // if we cast to `Int::Unsigned`.
28    //
29    // Then, the square root is easy.
30    //
31    //   √(frac' / FRAC_DENOM * 2^exp')
32    //   = √(frac') / √(FRAC_DENOM) * 2^(exp' / 2)
33    //   = √(frac' * FRAC_DENOM) / FRAC_DENOM * 2^(exp' / 2)
34    //
35    // In other words: the resulting `exp` is `exp >> 1` (not forgetting to accumulate this lost
36    // bit onto `sticky`), and the `frac` is the integer square root of `frac << FRAC_WIDTH`.
37
38    use crate::underlying::Unsigned;
39    let exp_odd = x.exp & Int::ONE;
40
41    let frac_adjusted = (x.frac).as_unsigned() << exp_odd.as_u32();
42    let exp_adjusted = x.exp - exp_odd;
43
44    let (result, _) = frac_adjusted.shift_sqrt(Decoded::<N, ES, RS, Int>::FRAC_WIDTH);
45    let frac = Int::of_unsigned(result);
46    let exp = exp_adjusted >> 1;
47    let sticky = Int::ONE;
48
49    (Decoded{frac, exp}, sticky)
50  }
51
52  /// Returns the square root of `self`, rounded. If `self` is negative or [NaR](Self::NAR),
53  /// returns NaR.
54  ///
55  /// Standard: "[**sqrt**](https://posithub.org/docs/posit_standard-2.pdf#subsection.5.5)".
56  ///
57  /// # Example
58  ///
59  /// ```
60  /// # use fast_posit::*;
61  /// # use core::f64::consts::PI;
62  /// assert_eq!(p16::sqrt((4. * PI).round_into()), p16::round_from(3.5449));
63  /// assert_eq!(p16::MINUS_ONE.sqrt(), p16::NAR);
64  /// ```
65  pub fn sqrt(self) -> Self {
66    if self < Self::ZERO {
67      Self::NAR
68    } else if self == Self::ZERO {
69      Self::ZERO
70    } else {
71      // SAFETY: `self` is not 0 or NaR
72      let x = unsafe { self.decode_regular() };
73      // SAFETY: `self` is non-negative
74      let (result, sticky) = unsafe { Self::sqrt_kernel(x) };
75      // SAFETY: `result.is_normalised()` holds
76      unsafe { result.encode_regular_round(sticky) }
77    }
78  }
79}
80
#[cfg(test)]
mod tests {
  use crate::Posit;
  use malachite::{rational::Rational, Natural};
  use proptest::prelude::*;

  /// Test helper: returns whether `x.sqrt()` is the correctly rounded square root of `x`.
  ///
  /// When `x` converts to a non-negative `Rational`, a reference root is computed with
  /// arbitrary-precision arithmetic and compared against the posit result via
  /// `super::rational::is_correct_rounded`. Otherwise (the conversion fails, or the value is
  /// negative) the result is expected to be NaR.
  fn is_correct_rounded<const N: u32, const ES: u32, Int: crate::Int, const RS: u32>(
    x: Posit<N, ES, Int, RS>,
  ) -> bool
  where
    Rational: TryFrom<Posit<N, ES, Int, RS>, Error = super::rational::IsNaR>,
  {
    let rounded = x.sqrt();
    match Rational::try_from(x) {
      Ok(value) if value >= Rational::from(0) => {
        use malachite::base::num::arithmetic::traits::{FloorSqrt, PowerOf2};
        // Scale by `scale²` so that the radicand becomes a natural number (the `unwrap` asserts
        // this), take the floored integer square root, then divide the root back by `scale`.
        let scale = Rational::power_of_2((N as u64) << ES << 1);
        let radicand = Natural::try_from(value * &scale * &scale).unwrap();
        let reference = Rational::from_naturals(radicand.floor_sqrt(), scale.into_numerator());
        super::rational::is_correct_rounded(reference, rounded)
      },
      _ => rounded == Posit::NAR,
    }
  }

  /// Declares a `#[test]` named `$test` that checks `is_correct_rounded` on every value yielded
  /// by `<$ty>::cases_exhaustive_all()`.
  macro_rules! test_exhaustive {
    ($test:ident, $ty:ty) => {
      #[test]
      fn $test() {
        <$ty>::cases_exhaustive_all()
          .into_iter()
          .for_each(|p| assert!(is_correct_rounded(p), "{p:?}"));
      }
    };
  }

  /// Declares a property test named `$test` that checks `is_correct_rounded` on values drawn
  /// from `<$ty>::cases_proptest_all()`, running `crate::PROPTEST_CASES` cases.
  macro_rules! test_proptest {
    ($test:ident, $ty:ty) => {
      proptest!{
        #![proptest_config(ProptestConfig::with_cases(crate::PROPTEST_CASES))]
        #[test]
        fn $test(p in <$ty>::cases_proptest_all()) {
          assert!(is_correct_rounded(p), "{p:?}")
        }
      }
    };
  }

  // Tiny posits.
  test_exhaustive!{posit_3_0_exhaustive, Posit::<3, 0, i8>}
  test_exhaustive!{posit_4_0_exhaustive, Posit::<4, 0, i8>}
  test_exhaustive!{posit_4_1_exhaustive, Posit::<4, 1, i8>}

  // 8-bit and 10-bit posits over a range of exponent sizes.
  test_exhaustive!{posit_8_0_exhaustive, Posit::<8, 0, i8>}

  test_exhaustive!{posit_10_0_exhaustive, Posit::<10, 0, i16>}
  test_exhaustive!{posit_10_1_exhaustive, Posit::<10, 1, i16>}
  test_exhaustive!{posit_10_2_exhaustive, Posit::<10, 2, i16>}
  test_exhaustive!{posit_10_3_exhaustive, Posit::<10, 3, i16>}

  // The crate's named posit aliases: exhaustive for the small ones, property tests for the rest.
  test_exhaustive!{p8_exhaustive, crate::p8}
  test_exhaustive!{p16_exhaustive, crate::p16}
  test_proptest!{p32_proptest, crate::p32}
  test_proptest!{p64_proptest, crate::p64}
  // test_proptest!{p128_proptest, crate::p128}

  // Posits that also set the fourth const parameter (`RS`).
  test_exhaustive!{bposit_8_3_6_exhaustive, Posit::<8, 3, i8, 6>}
  test_exhaustive!{bposit_16_5_6_exhaustive, Posit::<16, 5, i16, 6>}
  test_proptest!{bposit_32_5_6_proptest, Posit::<32, 5, i32, 6>}
  test_proptest!{bposit_64_5_6_proptest, Posit::<64, 5, i64, 6>}

  test_exhaustive!{bposit_10_2_6_exhaustive, Posit::<10, 2, i16, 6>}
  test_exhaustive!{bposit_10_2_7_exhaustive, Posit::<10, 2, i16, 7>}
  test_exhaustive!{bposit_10_2_8_exhaustive, Posit::<10, 2, i16, 8>}
  test_exhaustive!{bposit_10_2_9_exhaustive, Posit::<10, 2, i16, 9>}
}