#[cfg(feature = "num-bigint")]
use crate::tables::{CUBIC_MODULI, CUBIC_RESIDUAL, QUAD_MODULI, QUAD_RESIDUAL};
use crate::traits::{BitTest, ExactRoots};
use num_integer::Roots;
#[cfg(feature = "num-bigint")]
use num_bigint::{BigInt, BigUint, ToBigInt};
#[cfg(feature = "num-bigint")]
use num_traits::{One, Signed, ToPrimitive, Zero};
// Implements `BitTest` for the primitive unsigned integer types using the
// built-in bit-counting intrinsics.
macro_rules! impl_bittest_prim {
    ($($T:ty)*) => {$(
        impl BitTest for $T {
            // Number of significant bits: the position of the highest set bit
            // plus one, or 0 for the value 0.
            #[inline]
            fn bits(&self) -> usize {
                (<$T>::BITS - self.leading_zeros()) as usize
            }

            // Whether the bit at `position` (0 = least significant) is set.
            // Positions at or beyond the type's width are reported as unset,
            // as if the value were zero-extended (the same answer
            // `BigUint::bit` gives). Checking the width first avoids an
            // overflowing shift, which panics in debug builds and silently
            // wraps the shift amount in release builds.
            #[inline]
            fn bit(&self, position: usize) -> bool {
                let width = <$T>::BITS as usize;
                position < width && (*self >> position) & 1 == 1
            }

            // Number of trailing zero bits; equals the type's width for 0.
            // The inherent method is called explicitly so this does not
            // recurse into the trait method of the same name.
            #[inline]
            fn trailing_zeros(&self) -> usize {
                <$T>::trailing_zeros(*self) as usize
            }
        }
    )*}
}
impl_bittest_prim!(u8 u16 u32 u64 u128 usize);
#[cfg(feature = "num-bigint")]
impl BitTest for BigUint {
    /// Tests bit `position` (0 = least significant) using the inherent
    /// `BigUint::bit`, widening the index to the `u64` it expects.
    fn bit(&self, position: usize) -> bool {
        BigUint::bit(self, position as u64)
    }

    /// Bit length as reported by the inherent `BigUint::bits`, narrowed to `usize`.
    fn bits(&self) -> usize {
        let count: u64 = BigUint::bits(self);
        count as usize
    }

    /// Number of trailing zero bits. The inherent method yields no count for
    /// a zero value; that case is reported as 0 here.
    #[inline]
    fn trailing_zeros(&self) -> usize {
        BigUint::trailing_zeros(self).map_or(0, |zeros| zeros as usize)
    }
}
// Implements `ExactRoots` for the primitive integer types.
//
// `nth_root_exact` takes the truncated root from `num_integer::Roots::nth_root`
// and accepts it only if raising it back to the `n`th power reproduces the
// input. Note: `nth_root` panics when `n == 0`, and that panic propagates.
// `sqrt_exact` applies cheap necessary conditions for a perfect square before
// falling back to `nth_root_exact(2)`.
macro_rules! impl_exactroot_prim {
    ($($T:ty)*) => {$(
        impl ExactRoots for $T {
            fn nth_root_exact(&self, n: u32) -> Option<Self> {
                // Even roots of negative numbers have no integer value. For
                // unsigned types this comparison is always false.
                if self < &0 && n % 2 == 0 {
                    return None;
                }
                // The root is truncated, so its `n`th power never exceeds the
                // input in magnitude and cannot overflow.
                let r = self.nth_root(n);
                if r.pow(n) == *self {
                    Some(r)
                } else {
                    None
                }
            }

            fn sqrt_exact(&self) -> Option<Self> {
                // Zero is its own square root. It must be handled before the
                // bit tricks below: `trailing_zeros()` of 0 is the full bit
                // width, and shifting by the full width overflows (a panic in
                // debug builds, a wrong `None` in release builds).
                if self == &0 {
                    return Some(0);
                }
                if self < &0 {
                    return None;
                }
                // A nonzero perfect square has an even number of trailing
                // zero bits...
                let shift = self.trailing_zeros();
                if shift & 1 == 1 {
                    return None;
                }
                // ...and its odd part is congruent to 1 modulo 8.
                if (*self >> shift) & 7 != 1 {
                    return None;
                }
                self.nth_root_exact(2)
            }
        }
    )*};
}
impl_exactroot_prim!(u8 u16 u32 u64 u128 usize i8 i16 i32 i64 i128 isize);
#[cfg(feature = "num-bigint")]
impl ExactRoots for BigUint {
    /// Exact square root of `self`, or `None` if `self` is not a perfect square.
    ///
    /// Runs a series of cheap necessary conditions and computes the root only
    /// if all of them pass. The root itself comes from `nth_root_exact(2)`,
    /// which this impl does not override.
    fn sqrt_exact(&self) -> Option<Self> {
        if self.is_zero() {
            return Some(BigUint::zero());
        }
        // Values that fit in a `u64` use the primitive implementation.
        if let Some(v) = self.to_u64() {
            return v.sqrt_exact().map(BigUint::from);
        }
        // `self` is nonzero here, so `trailing_zeros` returns `Some`.
        let shift = self.trailing_zeros().unwrap();
        // A perfect square has an even number of trailing zero bits...
        if shift & 1 == 1 {
            return None;
        }
        // ...and its odd part is congruent to 1 modulo 8.
        if !((self >> shift) & BigUint::from(7u8)).is_one() {
            return None;
        }
        // Residue filter: reject when bit `self % m` of the mask `res` is
        // clear. NOTE(review): presumably bit r of `res` is set iff r is a
        // quadratic residue mod m; the table contents are not visible here.
        // The `& 63` keeps the shift amount below 64. That is presumably a
        // no-op if every modulus is <= 64 — confirm against QUAD_MODULI.
        #[cfg(not(feature = "big-table"))]
        for (m, res) in QUAD_MODULI.iter().zip(QUAD_RESIDUAL) {
            if (res >> ((self % m).to_u8().unwrap() & 63)) & 1 == 0 {
                return None;
            }
        }
        // Same filter with larger tables: each mask is stored as an array of
        // words, where word `rem / 64` holds bit `rem % 64` (presumably `u64`
        // words).
        #[cfg(feature = "big-table")]
        for (m, res) in QUAD_MODULI.iter().zip(QUAD_RESIDUAL) {
            let rem = (self % m).to_u16().unwrap();
            if (res[(rem / 64) as usize] >> (rem % 64)) & 1 == 0 {
                return None;
            }
        }
        self.nth_root_exact(2)
    }
    /// Exact cube root of `self`, or `None` if `self` is not a perfect cube.
    ///
    /// Uses the same pre-filter structure as `sqrt_exact`, then falls back to
    /// `nth_root_exact(3)`.
    fn cbrt_exact(&self) -> Option<Self> {
        if self.is_zero() {
            return Some(BigUint::zero());
        }
        // Values that fit in a `u64` use the primitive implementation.
        if let Some(v) = self.to_u64() {
            return v.cbrt_exact().map(BigUint::from);
        }
        // `self` is nonzero here, so `trailing_zeros` returns `Some`.
        let shift = self.trailing_zeros().unwrap();
        // A perfect cube has a multiple of 3 trailing zero bits.
        if shift % 3 != 0 {
            return None;
        }
        // Residue filter against the cubic tables. NOTE(review): unlike the
        // quadratic filter, the shift amount is not masked with `& 63`. This
        // relies on every remainder being a valid shift for `res` — confirm
        // against CUBIC_MODULI / CUBIC_RESIDUAL.
        #[cfg(not(feature = "big-table"))]
        for (m, res) in CUBIC_MODULI.iter().zip(CUBIC_RESIDUAL) {
            if (res >> (self % m).to_u8().unwrap()) & 1 == 0 {
                return None;
            }
        }
        // Word-array variant: word `rem / 64`, bit `rem % 64`.
        #[cfg(feature = "big-table")]
        for (m, res) in CUBIC_MODULI.iter().zip(CUBIC_RESIDUAL) {
            let rem = (self % m).to_u16().unwrap();
            if (res[(rem / 64) as usize] >> (rem % 64)) & 1 == 0 {
                return None;
            }
        }
        self.nth_root_exact(3)
    }
}
#[cfg(feature = "num-bigint")]
impl ExactRoots for BigInt {
    /// Exact `n`th root of a signed big integer, or `None` if there is none.
    ///
    /// Even roots of negative values are rejected outright. Otherwise the root
    /// of the magnitude is taken and, for negative inputs (odd `n` only), the
    /// sign is restored.
    fn nth_root_exact(&self, n: u32) -> Option<Self> {
        let negative = self.is_negative();
        if negative && n % 2 == 0 {
            return None;
        }
        let root = self.magnitude().nth_root_exact(n)?.to_bigint()?;
        Some(if negative { -root } else { root })
    }

    /// Exact square root, or `None` for negative values and non-squares.
    fn sqrt_exact(&self) -> Option<Self> {
        // Conversion to an unsigned value fails for negative inputs.
        let unsigned = self.to_biguint()?;
        unsigned.sqrt_exact()?.to_bigint()
    }

    /// Exact cube root; the input's sign carries over to the root.
    fn cbrt_exact(&self) -> Option<Self> {
        let root = self.magnitude().cbrt_exact()?.to_bigint()?;
        Some(if self.is_negative() { -root } else { root })
    }
}
// Unit tests for the `ExactRoots` implementations in this file. Many checks
// draw inputs from `rand`, so each run exercises different values.
#[cfg(test)]
mod tests {
    use super::*;
    // Checks `sqrt_exact` / `cbrt_exact` against `nth_root_exact`, on known
    // squares/cubes and on non-squares, for primitive integers and (with the
    // `num-bigint` feature) big integers.
    #[test]
    fn exact_root_test() {
        // Small hand-picked cases for `u8` and `i8`.
        assert!(ExactRoots::sqrt_exact(&3u8).is_none());
        assert!(matches!(ExactRoots::sqrt_exact(&4u8), Some(2)));
        assert!(matches!(ExactRoots::sqrt_exact(&9u8), Some(3)));
        assert!(ExactRoots::sqrt_exact(&18u8).is_none());
        assert!(ExactRoots::sqrt_exact(&3i8).is_none());
        assert!(matches!(ExactRoots::sqrt_exact(&4i8), Some(2)));
        assert!(matches!(ExactRoots::sqrt_exact(&9i8), Some(3)));
        assert!(ExactRoots::sqrt_exact(&18i8).is_none());
        // The specialised `sqrt_exact` / `cbrt_exact` must agree with the
        // generic `nth_root_exact` on arbitrary values.
        for _ in 0..100 {
            let x = rand::random::<u32>();
            assert_eq!(
                ExactRoots::sqrt_exact(&x),
                ExactRoots::nth_root_exact(&x, 2)
            );
            assert_eq!(
                ExactRoots::cbrt_exact(&x),
                ExactRoots::nth_root_exact(&x, 3)
            );
            let x = rand::random::<i32>();
            assert_eq!(
                ExactRoots::cbrt_exact(&x),
                ExactRoots::nth_root_exact(&x, 3)
            );
        }
        // Round trip: squares of u32 values (these fit in u64) and cubes of
        // i16 values (|x|^3 <= 2^45, which fits in i64) must have exact roots.
        for _ in 0..100 {
            let x = u64::from(rand::random::<u32>());
            assert!(matches!(ExactRoots::sqrt_exact(&(x * x)), Some(v) if v == x));
            let x = i64::from(rand::random::<i16>());
            assert!(matches!(ExactRoots::cbrt_exact(&(x * x * x)), Some(v) if v == x));
        }
        // Products of two distinct random values are expected to be
        // non-squares. NOTE(review): this is only probabilistic, because a
        // product of distinct values can still be a square (e.g. 2 * 8 = 16,
        // or any product with 0). The assertion can therefore fail
        // spuriously, though very rarely.
        for _ in 0..100 {
            let x = u64::from(rand::random::<u32>());
            let y = u64::from(rand::random::<u32>());
            if x == y {
                continue;
            }
            assert!(ExactRoots::sqrt_exact(&(x * y)).is_none());
        }
        // The same three checks for `BigUint` / `BigInt` with 150-bit inputs.
        #[cfg(feature = "num-bigint")]
        {
            use num_bigint::RandBigInt;
            let mut rng = rand::thread_rng();
            for _ in 0..10 {
                let x = rng.gen_biguint(150);
                assert_eq!(
                    ExactRoots::sqrt_exact(&x),
                    ExactRoots::nth_root_exact(&x, 2)
                );
                assert_eq!(
                    ExactRoots::cbrt_exact(&x),
                    ExactRoots::nth_root_exact(&x, 3)
                );
                let x = rng.gen_bigint(150);
                assert_eq!(
                    ExactRoots::cbrt_exact(&x),
                    ExactRoots::nth_root_exact(&x, 3)
                );
            }
            for _ in 0..10 {
                let x = rng.gen_biguint(150);
                assert!(matches!(ExactRoots::sqrt_exact(&(&x * &x)), Some(v) if v == x));
                let x = rng.gen_biguint(150);
                assert!(
                    matches!(ExactRoots::cbrt_exact(&(&x * &x * &x)), Some(v) if v == x),
                    "failed at {}",
                    x
                );
            }
            // NOTE(review): probabilistic, for the same reason as the
            // primitive version above.
            for _ in 0..10 {
                let x = rng.gen_biguint(150);
                let y = rng.gen_biguint(150);
                if x == y {
                    continue;
                }
                assert!(ExactRoots::sqrt_exact(&(x * y)).is_none());
            }
        }
    }
    // Even roots of negative integers must be rejected; odd roots keep the sign.
    #[test]
    fn test_nth_root_exact_negative_even_root() {
        let result = (-1i32).nth_root_exact(2);
        assert!(
            result.is_none(),
            "nth_root_exact(2) should return None for negative numbers"
        );
        let result = (-4i32).nth_root_exact(2);
        assert!(
            result.is_none(),
            "nth_root_exact(2) should return None for negative numbers"
        );
        let result = (-8i32).nth_root_exact(4);
        assert!(
            result.is_none(),
            "nth_root_exact(4) should return None for negative numbers"
        );
        let result = (-8i32).nth_root_exact(3);
        assert_eq!(
            result,
            Some(-2),
            "nth_root_exact(3) should work for negative numbers"
        );
        let result = (-27i32).nth_root_exact(3);
        assert_eq!(
            result,
            Some(-3),
            "nth_root_exact(3) should work for negative numbers"
        );
    }
    // The negative/even-root rule across every signed primitive width, plus
    // a few odd and even roots of positive values.
    #[test]
    fn test_nth_root_exact_all_signed_types() {
        assert_eq!((-1i8).nth_root_exact(2), None);
        assert_eq!((-1i16).nth_root_exact(2), None);
        assert_eq!((-1i32).nth_root_exact(2), None);
        assert_eq!((-1i64).nth_root_exact(2), None);
        assert_eq!((-1i128).nth_root_exact(2), None);
        assert_eq!((-1isize).nth_root_exact(2), None);
        assert_eq!((-8i32).nth_root_exact(3), Some(-2));
        assert_eq!((-32i32).nth_root_exact(5), Some(-2));
        assert_eq!(16i32.nth_root_exact(4), Some(2));
        assert_eq!(32i32.nth_root_exact(5), Some(2));
    }
    // The same sign rules for the `BigInt` implementation.
    #[test]
    #[cfg(feature = "num-bigint")]
    fn test_nth_root_exact_bigint_negative() {
        use num_bigint::BigInt;
        assert_eq!(BigInt::from(-1).nth_root_exact(2), None);
        assert_eq!(BigInt::from(-16).nth_root_exact(4), None);
        assert_eq!(BigInt::from(-8).nth_root_exact(3), Some(BigInt::from(-2)));
        assert_eq!(
            BigInt::from(-1000000000i64).nth_root_exact(3),
            Some(BigInt::from(-1000i32))
        );
    }
}