use core::mem;
use crate::ArithmeticError;
/// Integer square root of a value.
///
/// The primitive-integer implementations generated in this module return the
/// floor of the exact root.
pub(crate) trait Sqrt
where
    Self: Sized,
{
    /// Error type carried by the returned `Result`.
    type Error;

    /// Consumes `self` and returns its integer square root.
    fn sqrt(self) -> Result<Self, <Self as Sqrt>::Error>;
}
/// Implements [`Sqrt`] for each listed primitive integer type.
///
/// The generated `sqrt` returns the floor of the exact square root and never
/// produces `Err` itself; the `Result` only mirrors the trait signature so the
/// recursive calls can use `?`. Negative input violates the precondition: it
/// trips a `debug_assert!` in debug builds and yields `Ok(0)` in release
/// builds for every signed type, including `i128`.
macro_rules! impl_sqrt {
    ($( $int:ty ),+ $(,)?) => {
        $( impl_sqrt!(@single $int); )*
    };
    (@single $int:ty) => {
        impl Sqrt for $int {
            type Error = ArithmeticError;

            #[inline]
            fn sqrt(self) -> Result<Self, Self::Error> {
                // Width of `T` in bits.
                #[inline]
                const fn bits<T>() -> u32 {
                    (mem::size_of::<T>() * 8) as _
                }

                // Starting point for Newton's iteration: the truncated
                // floating-point root.
                #[cfg(feature = "std")]
                #[inline]
                fn guess(x: $int) -> $int {
                    (x as f64).sqrt() as $int
                }

                // Starting point for Newton's iteration without `f64::sqrt`:
                // the power of two 2^ceil(floor(log2 x) / 2).
                #[cfg(not(feature = "std"))]
                #[inline]
                fn guess(x: $int) -> $int {
                    // floor(log2 x); `x` must be positive.
                    #[inline]
                    fn log2_estimate(x: $int) -> u32 {
                        debug_assert!(x > 0);
                        bits::<$int>() - 1 - x.leading_zeros()
                    }
                    1 << ((log2_estimate(x) + 1) / 2)
                }

                // Applies `f` while the value moves upward, then while it
                // moves downward, and returns the value at which the downward
                // walk stops. With the Newton step used below this settles on
                // floor(sqrt(self)) even when the iterates alternate between
                // two neighbours (e.g. for self = k*k - 1).
                #[inline]
                fn fixpoint(mut x: $int, f: impl Fn($int) -> $int) -> $int {
                    let mut xn = f(x);
                    while x < xn {
                        x = xn;
                        xn = f(x);
                    }
                    while x > xn {
                        x = xn;
                        xn = f(x);
                    }
                    x
                }

                // The comparison is trivially true for unsigned types.
                #[allow(unused_comparisons)]
                {
                    debug_assert!(self >= 0);
                }

                // 0..=3 map to 0, 1, 1, 1. Checking this before the
                // wide-integer branch also makes a negative `i128` return
                // `Ok(0)` like the narrower signed types, instead of recursing
                // forever (an arithmetic shift of a negative value never
                // reaches the `u64` range).
                if self < 4 {
                    return Ok((self > 0).into());
                }

                if bits::<$int>() > 64 {
                    // Wide division is slow, so strip two bits per level until
                    // the value fits in a `u64`, using the identity
                    // isqrt(n) == 2*isqrt(n >> 2) or 2*isqrt(n >> 2) + 1.
                    let root = match u64::try_from(self) {
                        Ok(narrow) => Sqrt::sqrt(narrow)? as $int,
                        Err(_) => {
                            let lo = Sqrt::sqrt(self >> 2u32)? << 1;
                            let hi = lo + 1;
                            // For `i128` inputs near `i128::MAX`, `hi * hi`
                            // does not fit in the type. A square that cannot be
                            // represented is certainly larger than `self`, so
                            // `lo` is the answer in that case.
                            match hi.checked_mul(hi) {
                                Some(square) if square <= self => hi,
                                _ => lo,
                            }
                        }
                    };
                    return Ok(root);
                }

                // Integer Newton/Heron step: x -> (self / x + x) / 2.
                let next = |x: $int| (self / x + x) >> 1;
                Ok(fixpoint(guess(self), next))
            }
        }
    };
}
// Every primitive integer width gets an integer square root, unsigned types
// first, then their signed counterparts.
impl_sqrt! {
    u8, u16, u32, u64, u128,
    i8, i16, i32, i64, i128,
}