Skip to main content

oxinum_float/native/
float_sqrt.rs

//! Square root for `BigFloat`.
//!
//! Strategy: integer floor square root of a *scaled* mantissa.
//!
//! For a positive finite value `v = m * 2^e` (the square root of a negative
//! finite value is reported as a `Domain` error, since the native `BigFloat`
//! does not model complex numbers; NaN and `-Inf` inputs yield NaN, and
//! `+Inf` yields `+Inf`):
//!
//! ```text
//! sqrt(v) = sqrt(m * 2^e) = sqrt(m) * 2^(e/2)
//! ```
//!
//! The exponent split must be an *integer*, so we first force `e` to be even:
//! if `e` is odd, shift the mantissa left by one bit and decrement the
//! exponent — a value-preserving transform.
//!
//! Then, to get `P` (or `P + 1`) significant bits in the integer floor sqrt,
//! where `P` is the target precision (the `prec` argument), we left-shift
//! `m` by `k` bits before calling [`BigUint::sqrt`]. We choose `k` so that the
//! scaled mantissa has approximately `2P` bits; its integer floor sqrt then
//! has approximately `P` bits.
//!
//! Specifically, with `b = m.bit_length()`:
//!
//! - if `b` is even, choose `k = 2P - b` (even).
//! - if `b` is odd,  choose `k = 2P - b + 1` (even).
//! - if `b >= 2P`, no shift is applied (`k = 0`), and the integer sqrt may
//!   then have more than `P + 1` bits.
//!
//! Every branch keeps `k` even. This is required because the exponent of the
//! square root is `(e - k) / 2`, which must be an exact integer division.
//! [`BigFloat::from_parts`] then re-normalizes the integer sqrt and rounds
//! it to the requested precision under the chosen mode.
33
34use oxinum_core::{OxiNumError, OxiNumResult, Sign};
35
36use super::float::{BigFloat, RoundingMode};
37
38impl BigFloat {
39    /// Return `sqrt(self)` at `prec` bits using the chosen rounding mode.
40    ///
41    /// # Errors
42    ///
43    /// - [`OxiNumError::Domain`] if `self < 0` (real-valued sqrt is
44    ///   undefined for negative inputs).
45    ///
46    /// # Examples
47    ///
48    /// ```
49    /// use oxinum_float::native::{BigFloat, RoundingMode};
50    /// let four = BigFloat::from_i64(4, 32, RoundingMode::HalfEven);
51    /// let two = four.sqrt(32, RoundingMode::HalfEven).expect("sqrt(4) is real");
52    /// assert_eq!(two.to_f64(), 2.0);
53    /// ```
54    pub fn sqrt(&self, prec: u32, mode: RoundingMode) -> OxiNumResult<Self> {
55        assert!(prec > 0, "BigFloat precision must be > 0");
56
57        // --- IEEE-754 non-finite guards ---
58        if self.is_nan() {
59            return Ok(BigFloat::nan(prec));
60        }
61        if self.is_infinite() {
62            return if self.sign == Sign::Negative {
63                Ok(BigFloat::nan(prec)) // sqrt(-Inf) = NaN
64            } else {
65                Ok(BigFloat::infinity(prec)) // sqrt(+Inf) = +Inf
66            };
67        }
68
69        if self.is_zero() {
70            return Ok(Self::zero(prec));
71        }
72        if self.sign == Sign::Negative {
73            return Err(OxiNumError::Domain(
74                "sqrt of negative is undefined for real BigFloat".into(),
75            ));
76        }
77
78        // --- Step 1: make the exponent even by shifting m left by one bit
79        // if needed. Adds one bit of precision to the mantissa, decrements
80        // the exponent — value preserved exactly.
81        let (mut even_exp, mut work_mantissa, mut work_bits) = {
82            let cur_e = self.exponent;
83            let cur_bits = self.mantissa.bit_length();
84            if cur_e.rem_euclid(2) == 0 {
85                (cur_e, self.mantissa.clone(), cur_bits)
86            } else {
87                // exponent is odd — left-shift mantissa by 1, decrement exp.
88                let shifted = self.mantissa.shl_bits(1);
89                (cur_e.saturating_sub(1), shifted, cur_bits + 1)
90            }
91        };
92        // After parity-fixup `even_exp` is even and `work_bits` reflects the
93        // current mantissa.
94        debug_assert_eq!(even_exp.rem_euclid(2), 0);
95
96        // --- Step 2: scale work_mantissa so its bit length is approximately
97        // 2*prec. The integer floor sqrt of a 2P-bit value lands in [2^(P-1),
98        // 2^P + epsilon], so the result mantissa has bit_length P or P+1.
99        let p = prec as u64;
100        let target_scaled_bits: u64 = p.saturating_mul(2);
101        let extra_shift: u64 = if target_scaled_bits > work_bits {
102            let raw = target_scaled_bits - work_bits;
103            // Round shift up to the next even number so the exponent
104            // halving lands cleanly in i64 — required because final
105            // exponent is (even_exp - extra_shift) / 2.
106            if raw.rem_euclid(2) == 0 {
107                raw
108            } else {
109                raw + 1
110            }
111        } else {
112            // Mantissa already has more bits than 2*prec — no extra shift
113            // needed. The integer sqrt will yield more than `prec` bits
114            // which from_parts then rounds down. The extra shift must
115            // still be even for clean exponent halving.
116            0
117        };
118        if extra_shift > 0 {
119            work_mantissa = work_mantissa.shl_bits(extra_shift);
120            work_bits = work_bits.saturating_add(extra_shift);
121            even_exp = even_exp.saturating_sub(extra_shift as i64);
122        }
123        debug_assert_eq!(extra_shift.rem_euclid(2), 0);
124        debug_assert_eq!(even_exp.rem_euclid(2), 0);
125        let _ = work_bits;
126
127        // --- Step 3: integer floor sqrt of the scaled mantissa.
128        let sqrt_mantissa = work_mantissa.sqrt();
129        debug_assert!(
130            !sqrt_mantissa.is_zero(),
131            "non-zero input must yield non-zero sqrt"
132        );
133
134        // --- Step 4: the exponent of the sqrt is even_exp / 2.
135        let new_exp = even_exp / 2;
136
137        // --- Step 5: land at canonical form at the requested precision.
138        Ok(BigFloat::from_parts(
139            Sign::Positive,
140            sqrt_mantissa,
141            new_exp,
142            prec,
143            mode,
144        ))
145    }
146}