use core;
use core::mem;
use traits::checked_pow;
use traits::PrimInt;
use Integer;
/// Integer root extraction.
///
/// Implementors supply `nth_root`; the square and cube roots default to
/// `nth_root(2)` and `nth_root(3)` respectively. The primitive
/// implementations in this module return the root rounded toward zero and
/// panic on a zero degree or on even roots of negative values.
pub trait Roots: Integer {
    /// Returns the `n`th root of `self`.
    fn nth_root(&self, n: u32) -> Self;

    /// Returns the square root of `self`; by default this is `nth_root(2)`.
    #[inline]
    fn sqrt(&self) -> Self {
        <Self as Roots>::nth_root(self, 2)
    }

    /// Returns the cube root of `self`; by default this is `nth_root(3)`.
    #[inline]
    fn cbrt(&self) -> Self {
        <Self as Roots>::nth_root(self, 3)
    }
}
#[inline]
pub fn sqrt<T: Roots>(x: T) -> T {
x.sqrt()
}
#[inline]
pub fn cbrt<T: Roots>(x: T) -> T {
x.cbrt()
}
#[inline]
pub fn nth_root<T: Roots>(x: T, n: u32) -> T {
x.nth_root(n)
}
// Implements `Roots` for the signed primitive `$T` by delegating to the
// implementation for `$U`, the unsigned type of the same width. Negative
// inputs take the root of the magnitude and negate the result, so roots are
// rounded toward zero.
macro_rules! signed_roots {
    ($T:ty, $U:ty) => {
        impl Roots for $T {
            /// Returns the `n`th root of `self`, rounded toward zero.
            ///
            /// # Panics
            /// Panics if `self` is negative and `n` is even (including 0).
            /// For non-negative `self`, panics from the `$U` implementation
            /// (such as `n == 0`) propagate.
            #[inline]
            fn nth_root(&self, n: u32) -> Self {
                if *self >= 0 {
                    (*self as $U).nth_root(n) as Self
                } else {
                    assert!(n.is_odd(), "even roots of a negative are imaginary");
                    // `wrapping_neg` then casting to `$U` gives the exact
                    // magnitude even for `$T::MIN`. The root is negated with
                    // `wrapping_neg` as well: when `n == 1` the root of `MIN`
                    // is `MIN` itself, whose magnitude does not fit in `$T`, so
                    // a plain `-` would overflow (and panic in debug builds).
                    // For every other input the root is strictly in range and
                    // `wrapping_neg` is identical to `-`.
                    ((self.wrapping_neg() as $U).nth_root(n) as Self).wrapping_neg()
                }
            }

            /// Returns the square root of `self`, rounded down.
            ///
            /// # Panics
            /// Panics if `self` is negative.
            #[inline]
            fn sqrt(&self) -> Self {
                assert!(*self >= 0, "the square root of a negative is imaginary");
                (*self as $U).sqrt() as Self
            }

            /// Returns the cube root of `self`, rounded toward zero.
            #[inline]
            fn cbrt(&self) -> Self {
                if *self >= 0 {
                    (*self as $U).cbrt() as Self
                } else {
                    // The cube root of any magnitude up to 2^(bits-1) is far
                    // below `$T::MAX`, so this negation cannot overflow.
                    -((self.wrapping_neg() as $U).cbrt() as Self)
                }
            }
        }
    };
}

signed_roots!(i8, u8);
signed_roots!(i16, u16);
signed_roots!(i32, u32);
signed_roots!(i64, u64);
#[cfg(has_i128)]
signed_roots!(i128, u128);
signed_roots!(isize, usize);
/// Runs the iteration `x -> f(x)` until it settles, and returns the result.
///
/// The iteration first climbs for as long as `f` yields a strictly larger
/// value. It then descends for as long as `f` yields a strictly smaller
/// value. The value at which the descent stops is returned. The root
/// functions in this module use it to drive integer Newton steps.
#[inline]
fn fixpoint<T, F>(x: T, f: F) -> T
where
    T: Integer + Copy,
    F: Fn(T) -> T,
{
    let mut current = x;
    let mut candidate = f(current);

    // Ascent phase: follow the sequence upward.
    while current < candidate {
        current = candidate;
        candidate = f(current);
    }

    // Descent phase: follow the sequence downward until it stops shrinking.
    while candidate < current {
        current = candidate;
        candidate = f(current);
    }

    current
}
/// Width of `T` in bits: its size in bytes times eight.
#[inline]
fn bits<T>() -> u32 {
    let byte_count = mem::size_of::<T>() as u32;
    byte_count << 3
}
/// Floor of the base-2 logarithm of `x`, which must be strictly positive.
///
/// A `B`-bit value with `z` leading zeros has `B - z` significant bits, so
/// its highest set bit sits at index `B - z - 1`.
#[inline]
fn log2<T: PrimInt>(x: T) -> u32 {
    debug_assert!(x > T::zero());
    let significant_bits = bits::<T>() - x.leading_zeros();
    significant_bits - 1
}
// Implements `Roots` for the unsigned primitive `$T`. Every method returns
// the floor root: the largest `r` such that `r^n <= self`.
//
// Types wider than 64 bits delegate to the `u64` implementation when the
// value fits. Otherwise they recurse on a right-shifted value and correct
// the lowest bit of the result. Types of at most 64 bits seed an integer
// Newton iteration (run by `fixpoint`). The seed comes from `f64` math when
// the `std` feature is enabled, and from a power-of-two estimate otherwise.
// The cube root of types of at most 32 bits uses a division-free
// bit-by-bit method instead.
macro_rules! unsigned_roots {
    ($T:ident) => {
        impl Roots for $T {
            /// Returns the floor `n`th root of `self`.
            ///
            /// # Panics
            /// Panics if `n == 0`.
            fn nth_root(&self, n: u32) -> Self {
                // Degrees 0-3 are special-cased; 2 and 3 have dedicated code.
                match n {
                    0 => panic!("can't find a root of degree 0!"),
                    1 => return *self,
                    2 => return self.sqrt(),
                    3 => return self.cbrt(),
                    _ => (),
                }
                // If `self < 2^n` the root is below 2, so it is 1 for any
                // nonzero value and 0 for zero. When `n` is at least the bit
                // width this always holds; checking that first also keeps
                // `1 << n` from overflowing.
                if bits::<$T>() <= n || *self < (1 << n) {
                    return (*self > 0) as $T;
                }
                if bits::<$T>() > 64 {
                    return if *self <= core::u64::MAX as $T {
                        (*self as u64).nth_root(n) as $T
                    } else {
                        // The root of `self >> n` equals floor(root(self) / 2).
                        // Doubling it therefore yields `lo`, and the true root
                        // is either `lo` or `lo + 1`.
                        let lo = (self >> n).nth_root(n) << 1;
                        let hi = lo + 1;
                        // Let `2^k` be the next power of two at or above `hi`.
                        // Then `hi^n <= 2^(k*n)`. Only when `k*n` reaches the
                        // bit width might `hi^n` overflow, so only then is the
                        // checked power used.
                        if hi.next_power_of_two().trailing_zeros() * n >= bits::<$T>() {
                            match checked_pow(hi, n as usize) {
                                Some(x) if x <= *self => hi,
                                _ => lo,
                            }
                        } else {
                            if hi.pow(n) <= *self {
                                hi
                            } else {
                                lo
                            }
                        }
                    };
                }
                // Newton seed; only reached for types of at most 64 bits.
                // The integer form is 2^ceil(floor(log2 x) / n). Above
                // u32::MAX, the `std` build uses exp(ln(x) / n) instead.
                #[cfg(feature = "std")]
                #[inline]
                fn guess(x: $T, n: u32) -> $T {
                    if bits::<$T>() <= 32 || x <= core::u32::MAX as $T {
                        1 << ((log2(x) + n - 1) / n)
                    } else {
                        ((x as f64).ln() / f64::from(n)).exp() as $T
                    }
                }
                #[cfg(not(feature = "std"))]
                #[inline]
                fn guess(x: $T, n: u32) -> $T {
                    1 << ((log2(x) + n - 1) / n)
                }
                let n1 = n - 1;
                // One integer Newton step for f(r) = r^n - self:
                //     x' = (self / x^(n-1) + (n-1) * x) / n
                // If x^(n-1) overflows then x^(n-1) > self, so the quotient
                // term is 0 and the step shrinks `x`.
                let next = |x: $T| {
                    let y = match checked_pow(x, n1 as usize) {
                        Some(ax) => self / ax,
                        None => 0,
                    };
                    (y + x * n1 as $T) / n as $T
                };
                fixpoint(guess(*self, n), next)
            }

            /// Returns the floor square root of `self`.
            fn sqrt(&self) -> Self {
                if bits::<$T>() > 64 {
                    return if *self <= core::u64::MAX as $T {
                        (*self as u64).sqrt() as $T
                    } else {
                        // sqrt(self >> 2) == floor(sqrt(self) / 2), so the
                        // root is `lo` or `lo + 1`. Here `hi` is below 2^64,
                        // so `hi * hi` fits in a type wider than 64 bits.
                        let lo = (self >> 2u32).sqrt() << 1;
                        let hi = lo + 1;
                        if hi * hi <= *self {
                            hi
                        } else {
                            lo
                        }
                    };
                }
                // The values 0..=3 have root 0 or 1.
                if *self < 4 {
                    return (*self > 0) as Self;
                }
                // Newton seed: the `f64` square root, or 2^ceil(floor(log2 x) / 2).
                #[cfg(feature = "std")]
                #[inline]
                fn guess(x: $T) -> $T {
                    (x as f64).sqrt() as $T
                }
                #[cfg(not(feature = "std"))]
                #[inline]
                fn guess(x: $T) -> $T {
                    1 << ((log2(x) + 1) / 2)
                }
                // Integer Newton (Heron) step: x' = (self / x + x) / 2.
                let next = |x: $T| (self / x + x) >> 1;
                fixpoint(guess(*self), next)
            }

            /// Returns the floor cube root of `self`.
            fn cbrt(&self) -> Self {
                if bits::<$T>() > 64 {
                    return if *self <= core::u64::MAX as $T {
                        (*self as u64).cbrt() as $T
                    } else {
                        // cbrt(self >> 3) == floor(cbrt(self) / 2), so the
                        // root is `lo` or `lo + 1`.
                        let lo = (self >> 3u32).cbrt() << 1;
                        let hi = lo + 1;
                        if hi * hi * hi <= *self {
                            hi
                        } else {
                            lo
                        }
                    };
                }
                if bits::<$T>() <= 32 {
                    // Division-free bit-by-bit cube root. The input is
                    // consumed three bits at a time from the top. `y` is the
                    // partial root, `y2` always equals `y * y`, and `x` holds
                    // the remainder not yet accounted for by `y`.
                    let mut x = *self;
                    let mut y2 = 0;
                    let mut y = 0;
                    let smax = bits::<$T>() / 3;
                    for s in (0..smax + 1).rev() {
                        let s = s * 3;
                        // Append a 0 bit to the root: y -> 2y, y^2 -> 4y^2.
                        y2 *= 4;
                        y *= 2;
                        // b = (y + 1)^3 - y^3. If the remainder at this bit
                        // position can absorb it, the new root bit is 1.
                        // The comparison guarantees `b << s <= x`, so the
                        // shift cannot overflow.
                        let b = 3 * (y2 + y) + 1;
                        if x >> s >= b {
                            x -= b << s;
                            y2 += 2 * y + 1;
                            y += 1;
                        }
                    }
                    return y;
                }
                // The values 0..=7 have root 0 or 1.
                if *self < 8 {
                    return (*self > 0) as Self;
                }
                // Values that fit in 32 bits reuse the bitwise `u32` path.
                if *self <= core::u32::MAX as $T {
                    return (*self as u32).cbrt() as $T;
                }
                // Newton seed: the `f64` cube root, or 2^ceil(floor(log2 x) / 3).
                #[cfg(feature = "std")]
                #[inline]
                fn guess(x: $T) -> $T {
                    (x as f64).cbrt() as $T
                }
                #[cfg(not(feature = "std"))]
                #[inline]
                fn guess(x: $T) -> $T {
                    1 << ((log2(x) + 2) / 3)
                }
                // Integer Newton step for r^3 - self: x' = (self / x^2 + 2x) / 3.
                let next = |x: $T| (self / (x * x) + x * 2) / 3;
                fixpoint(guess(*self), next)
            }
        }
    };
}

// `u128` is only available when the build detected 128-bit integer support.
unsigned_roots!(u8);
unsigned_roots!(u16);
unsigned_roots!(u32);
unsigned_roots!(u64);
#[cfg(has_i128)]
unsigned_roots!(u128);
unsigned_roots!(usize);