use oxinum_core::{OxiNumError, OxiNumResult, Sign};
use super::float::{BigFloat, RoundingMode};
impl BigFloat {
    /// Computes the square root of `self` as a `BigFloat` with `prec` bits.
    ///
    /// The value is treated as `mantissa * 2^exponent`. The exponent is made
    /// even (absorbing one bit into the mantissa if needed), the mantissa is
    /// scaled up by an even number of bits, its integer square root is taken,
    /// and the halved exponent is attached. Final rounding to `prec` bits with
    /// `mode` is delegated to [`BigFloat::from_parts`].
    ///
    /// Special values:
    /// - NaN yields NaN.
    /// - Positive infinity yields positive infinity; negative infinity yields
    ///   NaN (returned as `Ok`, not as an error).
    /// - Zero yields `Self::zero(prec)`.
    ///
    /// # Errors
    ///
    /// Returns [`OxiNumError::Domain`] when `self` is finite, non-zero and
    /// negative.
    ///
    /// # Panics
    ///
    /// Panics if `prec` is zero.
    pub fn sqrt(&self, prec: u32, mode: RoundingMode) -> OxiNumResult<Self> {
        // Number of extra low-order bits computed beyond `prec` so that the
        // final rounding step has real guard bits to look at. Without them the
        // integer root has only `prec` (or `prec + 1`) bits and the result is
        // effectively truncated whatever `mode` is.
        //
        // NOTE(review): no sticky bit is derived from the square-root
        // remainder, so an inexact root whose guard bits happen to be exactly
        // a tie pattern (or all zero, for directed modes) can still round the
        // wrong way. Closing that gap needs an exactness check on the
        // mantissa type (e.g. a remainder-returning sqrt).
        const GUARD_BITS: u64 = 64;

        assert!(prec > 0, "BigFloat precision must be > 0");

        if self.is_nan() {
            return Ok(BigFloat::nan(prec));
        }
        if self.is_infinite() {
            return if self.sign == Sign::Negative {
                Ok(BigFloat::nan(prec))
            } else {
                Ok(BigFloat::infinity(prec))
            };
        }
        if self.is_zero() {
            return Ok(Self::zero(prec));
        }
        if self.sign == Sign::Negative {
            return Err(OxiNumError::Domain(
                "sqrt of negative is undefined for real BigFloat".into(),
            ));
        }

        // Make the exponent even so it can be halved exactly: an odd exponent
        // gives up one unit, compensated by doubling the mantissa.
        let mantissa_bits = self.mantissa.bit_length();
        let (even_exp, work_mantissa, work_bits) = if self.exponent.rem_euclid(2) == 0 {
            (self.exponent, self.mantissa.clone(), mantissa_bits)
        } else {
            (
                self.exponent.saturating_sub(1),
                self.mantissa.shl_bits(1),
                mantissa_bits + 1,
            )
        };
        debug_assert_eq!(even_exp.rem_euclid(2), 0);

        // The root of an n-bit integer has about n/2 bits, so the radicand
        // needs at least 2 * (prec + GUARD_BITS) bits. `prec` is a u32, so
        // this arithmetic cannot overflow a u64.
        let target_scaled_bits = (u64::from(prec) + GUARD_BITS) * 2;
        let deficit = target_scaled_bits.saturating_sub(work_bits);
        // Round the shift up to an even count so the exponent stays even.
        let extra_shift = deficit + deficit % 2;
        debug_assert_eq!(extra_shift.rem_euclid(2), 0);

        let (work_mantissa, even_exp) = if extra_shift > 0 {
            (
                work_mantissa.shl_bits(extra_shift),
                // `extra_shift` is at most 2 * (u32::MAX + GUARD_BITS) + 1,
                // far below i64::MAX, so the cast is lossless.
                even_exp.saturating_sub(extra_shift as i64),
            )
        } else {
            (work_mantissa, even_exp)
        };
        debug_assert_eq!(even_exp.rem_euclid(2), 0);

        // Presumably the integer (floor) square root of the mantissa.
        let sqrt_mantissa = work_mantissa.sqrt();
        debug_assert!(
            !sqrt_mantissa.is_zero(),
            "non-zero input must yield non-zero sqrt"
        );

        // sqrt(m * 2^e) = sqrt(m) * 2^(e / 2) for even e.
        Ok(BigFloat::from_parts(
            Sign::Positive,
            sqrt_mantissa,
            even_exp / 2,
            prec,
            mode,
        ))
    }
}