use super::{CubicRootRem, SquareRootRem};
use crate::DivRem;
/// Square-root and cube-root-with-remainder routines for inputs that are already normalized,
/// i.e. shifted left so that only a few leading zero bits remain.
///
/// The implementations in this file debug-assert at most 1 leading zero bit for the square root
/// and at most 2 for the cube root. Both methods return `(root, remainder)`.
pub(crate) trait NormalizedRootRem: Sized {
/// Type of the returned root (the implementations in this file use the half-width integer).
type OutputRoot;
/// Returns `(s, r)` for a normalized input; the implementations compute `r = self - s^2`.
fn normalized_sqrt_rem(self) -> (Self::OutputRoot, Self);
/// Returns `(c, r)` for a normalized input; the implementations compute `r = self - c^3`.
fn normalized_cbrt_rem(self) -> (Self::OutputRoot, Self);
}
// Seed table for the square-root routines in this file.
//
// It is indexed with the top 7 bits of a normalized input minus 32: with at most one leading
// zero those 7 bits lie in 32..=127, which gives the 96 entries. Each entry is OR-ed with 0x100
// to form a 9-bit fixed-point seed that acts as a scaled reciprocal square root (in the u16
// routine `(seed * n) >> 16` is used directly as the square-root estimate).
const RSQRT_TAB: [u8; 96] = [
0xfc, 0xf4, 0xed, 0xe6, 0xdf, 0xd9, 0xd3, 0xcd, 0xc7, 0xc2, 0xbc, 0xb7, 0xb2, 0xad, 0xa9, 0xa4,
0xa0, 0x9c, 0x98, 0x94, 0x90, 0x8c, 0x88, 0x85, 0x81, 0x7e, 0x7b, 0x77, 0x74, 0x71, 0x6e, 0x6b,
0x69, 0x66, 0x63, 0x61, 0x5e, 0x5b, 0x59, 0x57, 0x54, 0x52, 0x50, 0x4d, 0x4b, 0x49, 0x47, 0x45,
0x43, 0x41, 0x3f, 0x3d, 0x3b, 0x39, 0x37, 0x36, 0x34, 0x32, 0x30, 0x2f, 0x2d, 0x2c, 0x2a, 0x28,
0x27, 0x25, 0x24, 0x22, 0x21, 0x1f, 0x1e, 0x1d, 0x1b, 0x1a, 0x19, 0x17, 0x16, 0x15, 0x14, 0x12,
0x11, 0x10, 0x0f, 0x0d, 0x0c, 0x0b, 0x0a, 0x09, 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01,
];
// Seed table for the cube-root routines in this file.
//
// The callers shift the normalized input so that the lookup index lies in 8..=63 and subtract 8,
// which gives the 56 entries. Each entry is OR-ed with 0x100 to form a 9-bit fixed-point seed
// that acts as a scaled reciprocal cube root (in the u16 routine the squared seed multiplied by
// the input yields the cube-root estimate).
const RCBRT_TAB: [u8; 56] = [
0xf6, 0xe4, 0xd4, 0xc6, 0xb9, 0xae, 0xa4, 0x9b, 0x92, 0x8a, 0x83, 0x7c, 0x76, 0x70, 0x6b, 0x66,
0x61, 0x5c, 0x57, 0x53, 0x4f, 0x4b, 0x48, 0x44, 0x41, 0x3e, 0x3b, 0x38, 0x35, 0x32, 0x2f, 0x2d,
0x2a, 0x28, 0x25, 0x23, 0x21, 0x1f, 0x1d, 0x1b, 0x19, 0x17, 0x15, 0x13, 0x11, 0x10, 0x0e, 0x0c,
0x0b, 0x09, 0x08, 0x06, 0x05, 0x03, 0x02, 0x01,
];
/// Raises an under-estimated square root to the exact floor value.
///
/// * `$t` - integer type used for the arithmetic (the type of the radicand).
/// * `$n` - the radicand (a `$t`, or a reference to one).
/// * `$s` - mutable root estimate; it must satisfy `$s^2 <= $n`, otherwise the initial
///   subtraction underflows.
///
/// `$s` is incremented for as long as `($s + 1)^2 <= $n`. The macro evaluates to the final
/// remainder `$n - $s^2`.
macro_rules! fix_sqrt_error {
    ($t:ty, $n:ident, $s:ident) => {{
        let mut rem = $n - ($s as $t) * ($s as $t);
        // `gap` is (s + 1)^2 - s^2 = 2s + 1: the amount the remainder must cover for the root
        // to grow by one.
        let mut gap = ($s as $t) * 2 + 1;
        loop {
            if rem < gap {
                break rem;
            }
            $s += 1;
            rem -= gap;
            // Consecutive gaps 2s + 1 differ by exactly 2.
            gap += 2;
        }
    }};
}
/// Raises an under-estimated cube root to the exact floor value.
///
/// * `$t` - integer type used for the arithmetic (the type of the radicand).
/// * `$n` - the radicand (a `$t`, or a reference to one).
/// * `$c` - mutable root estimate; it must satisfy `$c^3 <= $n`, otherwise the initial
///   subtraction underflows.
///
/// `$c` is incremented for as long as `($c + 1)^3 <= $n`. The macro evaluates to the final
/// remainder `$n - $c^3`.
macro_rules! fix_cbrt_error {
    ($t:ty, $n:ident, $c:ident) => {{
        let base = $c as $t;
        let square = base * base;
        let mut rem = $n - square * base;
        // `gap` is (c + 1)^3 - c^3 = 3c^2 + 3c + 1.
        let mut gap = 3 * (square + base) + 1;
        loop {
            if rem < gap {
                break rem;
            }
            $c += 1;
            rem -= gap;
            // Consecutive gaps differ by 6 * (c + 1), i.e. six times the updated root.
            gap += 6 * ($c as $t);
        }
    }};
}
/// Table-driven roots of a normalized `u16`: a 9-bit seed gives a slightly low estimate, which
/// the `fix_*_error!` macros then raise to the exact floor root.
impl NormalizedRootRem for u16 {
    type OutputRoot = u8;

    /// Returns `(s, e)` with `e = self - s^2`; expects at most one leading zero bit.
    fn normalized_sqrt_rem(self) -> (u8, u16) {
        debug_assert!(self.leading_zeros() <= 1);

        // The top 7 bits of a normalized input are at least 32, hence the offset.
        let index = usize::from(self >> 9) - 32;
        let seed = 0x100 | u32::from(RSQRT_TAB[index]);

        // seed * n / 2^16 estimates the root; one is subtracted before the exact correction,
        // which can only move the estimate upwards.
        let estimate = (seed * u32::from(self)) >> 16;
        let mut root = (estimate - 1) as u8;

        let rem = fix_sqrt_error!(u16, self, root);
        (root, rem)
    }

    /// Returns `(c, e)` with `e = self - c^3`; expects at most two leading zero bits.
    fn normalized_cbrt_rem(self) -> (u8, u16) {
        debug_assert!(self.leading_zeros() <= 2);

        // A full 16-bit input is looked up with 3 fewer bits; the squared seed is then scaled
        // down by 2 extra bits to compensate.
        let (index_shift, square_shift): (u32, u32) = if self.leading_zeros() == 0 {
            (12, 4)
        } else {
            (9, 2)
        };
        let index = usize::from(self >> index_shift) - 8;
        let seed = 0x100 | u32::from(RCBRT_TAB[index]);

        // seed^2 * n (rescaled) estimates the root; one is subtracted before the exact
        // correction, which can only move the estimate upwards.
        let seed_sq = (seed * seed) >> square_shift;
        let estimate = (seed_sq * u32::from(self)) >> 24;
        let mut root = (estimate - 1) as u8;

        let rem = fix_cbrt_error!(u16, self, root);
        (root, rem)
    }
}
/// Returns the upper 16 bits of the full 32-bit product `a * b`.
#[inline]
fn wmul16_hi(a: u16, b: u16) -> u16 {
    let wide = u32::from(a) * u32::from(b);
    (wide >> u16::BITS) as u16
}
/// Roots of a normalized `u32`: a table seed is refined in 16/32-bit fixed-point arithmetic,
/// turned into a root estimate, and finally corrected exactly by `fix_*_error!`.
impl NormalizedRootRem for u32 {
type OutputRoot = u16;
/// Returns `(s, e)` with `e = self - s^2` and `e < 2s + 1` (guaranteed by `fix_sqrt_error!`).
/// Expects at most one leading zero bit.
fn normalized_sqrt_rem(self) -> (u16, u32) {
debug_assert!(self.leading_zeros() <= 1);
// Work on the top 16 bits of the input.
let n16 = (self >> 16) as u16;
// 9-bit seed, looked up with the top 7 bits (at least 32 for a normalized input).
let r = 0x100 | RSQRT_TAB[(n16 >> 9) as usize - 32] as u32;
// One refinement of the form 3r - n*r^3 in fixed point. This has the shape of a Newton step
// for 1/sqrt(n), x * (3 - n*x^2) / 2, with the scaling folded into the shifts.
let r = ((3 * r as u16) << 5) - (wmul32_hi(self, r * r * r) >> 11) as u16;
// Rescale r, then multiply by n to get a root estimate. `saturating_mul` guards the doubling
// against overflow. The estimate is lowered by 4, presumably so that `s * s <= self` holds
// for the subtraction below (NOTE(review): the bound is not derivable from this file alone).
let r = r << 1; let mut s = wmul16_hi(r, n16).saturating_mul(2); s -= 4;
// Remainder of the current estimate.
let e = self - (s as u32) * (s as u32);
// Correction: add the high half of (top 16 bits of e) * r.
s += wmul16_hi((e >> 16) as u16, r);
// Exact correction: raise s while (s + 1)^2 still fits, yielding the final remainder.
let e = fix_sqrt_error!(u32, self, s);
(s, e)
}
/// Returns `(c, e)` with `e = self - c^3` and `e < 3c^2 + 3c + 1` (guaranteed by
/// `fix_cbrt_error!`). Expects at most two leading zero bits.
fn normalized_cbrt_rem(self) -> (u16, u32) {
debug_assert!(self.leading_zeros() <= 2);
// With fewer than two leading zeros the input is shifted right by 3 extra bits before the
// lookup; this is undone further down by halving r (2 is the cube root of 2^3).
let adjust = self.leading_zeros() < 2;
let n16 = (self >> (16 + 3 * adjust as u8)) as u16;
// 9-bit seed; the lookup index lies in 8..=63 for a normalized input.
let r = 0x100 | RCBRT_TAB[(n16 >> 8) as usize - 8] as u32;
let r3 = (r * r * r) >> 11;
// One refinement of the form r * (4 - n*r^3) / 3 in fixed point, which has the shape of a
// Newton step for 1/cbrt(n). Afterwards the extra shift is undone if it was applied.
let t = (4 << 11) - wmul16_hi(n16, r3 as u16); let mut r = ((r * t as u32 / 3) >> 4) as u16; r >>= adjust as u8;
// Lower r by 10 (presumably to keep the estimate below the true root, since the final
// correction only moves upwards), then estimate the root as n * r^2 in fixed point.
let r = r - 10; let mut c = wmul16_hi(r, wmul16_hi(r, (self >> 16) as u16)) >> 2;
// Exact correction: raise c while (c + 1)^3 still fits, yielding the final remainder.
let e = fix_cbrt_error!(u32, self, c);
(c, e)
}
}
/// Returns the upper 32 bits of the full 64-bit product `a * b`.
#[inline]
fn wmul32_hi(a: u32, b: u32) -> u32 {
    let wide = u64::from(a) * u64::from(b);
    (wide >> u32::BITS) as u32
}
/// Roots of a normalized `u64`: a table seed is refined twice in 32-bit fixed-point arithmetic,
/// turned into a root estimate, and finally corrected exactly by `fix_*_error!`.
impl NormalizedRootRem for u64 {
type OutputRoot = u32;
/// Returns `(s, e)` with `e = self - s^2` and `e < 2s + 1` (guaranteed by `fix_sqrt_error!`).
/// Expects at most one leading zero bit.
fn normalized_sqrt_rem(self) -> (u32, u64) {
debug_assert!(self.leading_zeros() <= 1);
// Work on the top 32 bits of the input.
let n32 = (self >> 32) as u32;
// 9-bit seed, looked up with the top 7 bits (at least 32 for a normalized input).
let r = 0x100 | RSQRT_TAB[(n32 >> 25) as usize - 32] as u32;
// First refinement, of the form 3r - n*r^3 in fixed point (the shape of a Newton step for
// 1/sqrt(n), with the division by 2 folded into the scaling).
let r = ((3 * r) << 21) - wmul32_hi(n32, (r * r * r) << 5);
// Second refinement: t = 3 - n*r^2 (scaled), then r = r * t.
let t = (3 << 28) - wmul32_hi(r, wmul32_hi(r, n32)); let r = wmul32_hi(r, t);
// Rescale r, then multiply by n and double to get a root estimate.
let r = r << 4; let mut s = wmul32_hi(r, n32) << 1;
// Lower the estimate, presumably so that `s * s <= self` holds for the subtraction below
// (NOTE(review): the bound is not derivable from this file alone).
s -= 10;
// Remainder of the current estimate.
let e = self - (s as u64) * (s as u64);
// Correction: add the high half of (top 32 bits of e) * r.
s += wmul32_hi((e >> 32) as u32, r);
// Exact correction: raise s while (s + 1)^2 still fits, yielding the final remainder.
let e = fix_sqrt_error!(u64, self, s);
(s, e)
}
/// Returns `(c, e)` with `e = self - c^3` and `e < 3c^2 + 3c + 1` (guaranteed by
/// `fix_cbrt_error!`). Expects at most two leading zero bits.
fn normalized_cbrt_rem(self) -> (u32, u64) {
debug_assert!(self.leading_zeros() <= 2);
// A full 64-bit input is shifted right by 3 extra bits before the lookup; this is undone
// further down by halving r (2 is the cube root of 2^3).
let adjust = self.leading_zeros() == 0;
let n32 = (self >> (32 + 3 * adjust as u8)) as u32;
// 9-bit seed; the lookup index lies in 8..=63 for a normalized input.
let r = 0x100 | RCBRT_TAB[(n32 >> 25) as usize - 8] as u32;
// First refinement, of the form r * (4 - n*r^3) / 3 in fixed point (the shape of a Newton
// step for 1/cbrt(n)).
let t = (4 << 23) - wmul32_hi(n32, r * r * r);
let r = r * (t / 3);
// Second refinement with the same formula at higher precision.
let t = (4 << 28) - wmul32_hi(r, wmul32_hi(r, wmul32_hi(r, n32)));
// Finish the step, then undo the extra shift if it was applied.
let mut r = wmul32_hi(r, t) / 3; r >>= adjust as u8;
// Lower r by one (presumably to keep the estimate below the true root, since the final
// correction only moves upwards), then estimate the root as n * r^2 in fixed point.
let r = r - 1; let mut c = wmul32_hi(r, wmul32_hi(r, (self >> 32) as u32));
// Exact correction: raise c while (c + 1)^3 still fits, yielding the final remainder.
let e = fix_cbrt_error!(u64, self, c);
(c, e)
}
}
/// Roots of a normalized `u128`, built on the `u64` routines: the root of the high part is
/// computed first, then extended with one division and a short correction.
impl NormalizedRootRem for u128 {
type OutputRoot = u64;
/// Returns `(s, r)` with `r = self - s^2`. Expects at most one leading zero bit.
///
/// The structure (root of the high half, one division by that root, conditional fix-up by
/// `2s - 1`) follows the divide-and-conquer scheme known as "Karatsuba Square Root".
fn normalized_sqrt_rem(self) -> (u64, u128) {
debug_assert!(self.leading_zeros() <= 1);
// Split into the high 64 bits `a` and the low 64 bits `b`.
let (a, b) = (self >> u64::BITS, self & u64::MAX as u128);
let (a, b) = (a as u64, b as u64);
// Root and remainder of the high half (`a` has the same leading zeros as `self`).
let (s1, r1) = a.normalized_sqrt_rem();
// KBITS = 32; below, B stands for 2^KBITS.
const KBITS: u32 = u64::BITS / 2;
// r0 is (r1 * B + top half of b) / 2, assembled from bit fields.
// NOTE(review): this assumes r1 fits in KBITS + 1 bits, which holds if r1 <= 2 * s1.
let r0 = r1 << (KBITS - 1) | b >> (KBITS + 1);
// Dividing the halved value by s1 stands in for dividing by 2 * s1.
let (mut q, mut u) = r0.div_rem(s1 as u64);
if q >> KBITS > 0 {
// The quotient must fit in KBITS bits: step it down by one and compensate in the remainder.
q -= 1;
u += s1 as u64;
}
// Candidate root: s1 * B + q.
let mut s = (s1 as u64) << KBITS | q;
// Candidate remainder: u * 2B + (low KBITS + 1 bits of b) - q^2, kept as a 64-bit value `r`
// plus a small signed overflow word `c`.
let r = (u << (KBITS + 1)) | (b & ((1 << (KBITS + 1)) - 1));
let q2 = q * q;
// c starts as the bits of u shifted out above bit 63, minus the borrow of r - q^2.
let mut c = (u >> (KBITS - 1)) as i8 - (r < q2) as i8;
let mut r = r.wrapping_sub(q2);
if c < 0 {
// Negative remainder: the root was one too large. Add s and then s - 1 (2s - 1 in total),
// folding the carries into c.
let (new_r, c1) = r.overflowing_add(s);
s -= 1;
let (new_r, c2) = new_r.overflowing_add(s);
r = new_r;
c += c1 as i8 + c2 as i8;
}
// Reassemble the 128-bit remainder from the overflow word and the low 64 bits.
(s, (c as u128) << u64::BITS | r as u128)
}
/// Returns `(c, r)` with `r = self - c^3`. Expects at most two leading zero bits.
///
/// Writing `self = a * B^3 + b2 * B^2 + b1 * B + b0` with `B = 2^22`, the root of `a` is
/// computed first, extended by one division, and then stepped down until the remainder is
/// non-negative.
fn normalized_cbrt_rem(self) -> (u64, u128) {
debug_assert!(self.leading_zeros() <= 2);
// Step 1: root and remainder (c1, r1) of a = self >> 66.
let (c1, r1) = if self.leading_zeros() > 0 {
// With a leading zero, take the cube root of `self >> 63` (8 times larger, so that the u64
// routine receives a normalized value), halve the root, and recompute the remainder
// against `self >> 66`.
let a = (self >> 63) as u64;
let (mut c, _) = a.normalized_cbrt_rem();
c >>= 1;
(c, (a >> 3) - (c as u64).pow(3))
} else {
// Otherwise `self >> 66` has exactly two leading zeros as a u64 and can be used directly.
let a = (self >> 66) as u64;
a.normalized_cbrt_rem()
};
const KBITS: u32 = 22;
// Step 2: divide r1 * B + b2 by 3 * c1^2 to obtain the low part q of the root.
let r0 = ((r1 as u128) << KBITS) | (self >> (2 * KBITS) & ((1 << KBITS) - 1));
let (q, u) = r0.div_rem(3 * (c1 as u128).pow(2));
// Candidate root c1 * B + q. Addition rather than OR is used, so q is not assumed to fit in
// KBITS bits.
let mut c = ((c1 as u64) << KBITS) + (q as u64);
// Candidate remainder: t1 - t2 with t1 = u * B^2 + b1 * B + b0 and
// t2 = 3 * c1 * q^2 * B + q^3.
let t1 = (u << (2 * KBITS)) | (self & ((1 << (2 * KBITS)) - 1));
let t2 = (((3 * (c1 as u128)) << KBITS) + q) * q.pow(2);
let mut r = t1 as i128 - t2 as i128;
// Step 3: while the remainder is negative, step the root down by one. Each step adds
// c^3 - (c - 1)^3 = 3 * c * (c - 1) + 1 to the remainder.
while r < 0 {
r += 3 * (c as i128 - 1) * c as i128 + 1;
c -= 1;
}
(c, r as u128)
}
}
impl SquareRootRem for u8 {
type Output = u8;
#[inline]
fn sqrt_rem(&self) -> (u8, u8) {
let mut s = 0;
let e = fix_sqrt_error!(u8, self, s);
(s, e)
}
}
impl CubicRootRem for u8 {
type Output = u8;
#[inline]
fn cbrt_rem(&self) -> (u8, u8) {
let mut c = 0;
let e = fix_cbrt_error!(u8, self, c);
(c, e)
}
}
/// Implements `SquareRootRem` and `CubicRootRem` for `$t` (root type `$half`) on top of the
/// `NormalizedRootRem` routines.
///
/// The input is shifted left until it is normalized, by a multiple of 2 bits for the square
/// root and of 3 bits for the cube root, so the root of the shifted value is the true root
/// scaled by a power of two. The root is then shifted back and the remainder recomputed
/// against the original input.
macro_rules! impl_rootrem_using_normalized {
    ($t:ty, $half:ty) => {
        impl SquareRootRem for $t {
            type Output = $half;

            fn sqrt_rem(&self) -> ($half, $t) {
                let n = *self;
                if n == 0 {
                    return (0, 0);
                }

                // Clearing the lowest bit rounds the shift down to an even number.
                let shift = n.leading_zeros() & !1;
                let (root, rem) = (n << shift).normalized_sqrt_rem();
                if shift == 0 {
                    return (root, rem);
                }

                let root = root >> (shift / 2);
                (root, n - (root as $t).pow(2))
            }
        }

        impl CubicRootRem for $t {
            type Output = $half;

            fn cbrt_rem(&self) -> ($half, $t) {
                let n = *self;
                if n == 0 {
                    return (0, 0);
                }

                // Round the shift down to a multiple of three.
                let zeros = n.leading_zeros();
                let shift = zeros - zeros % 3;
                let (root, rem) = (n << shift).normalized_cbrt_rem();
                if shift == 0 {
                    return (root, rem);
                }

                let root = root >> (shift / 3);
                (root, n - (root as $t).pow(3))
            }
        }
    };
}

impl_rootrem_using_normalized!(u16, u8);
impl_rootrem_using_normalized!(u32, u16);
impl_rootrem_using_normalized!(u64, u32);
impl_rootrem_using_normalized!(u128, u64);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::{CubicRoot, SquareRoot};
    use rand::random;

    /// Checks `sqrt_rem` on zero, small values, the extremes of every width, and random inputs.
    #[test]
    fn test_sqrt() {
        // zero takes the early-return path (or the empty loop for u8)
        assert_eq!(0u8.sqrt_rem(), (0, 0));
        assert_eq!(0u16.sqrt_rem(), (0, 0));
        assert_eq!(0u32.sqrt_rem(), (0, 0));
        assert_eq!(0u64.sqrt_rem(), (0, 0));
        assert_eq!(0u128.sqrt_rem(), (0, 0));

        assert_eq!(2u8.sqrt_rem(), (1, 1));
        assert_eq!(2u16.sqrt_rem(), (1, 1));
        assert_eq!(2u32.sqrt_rem(), (1, 1));
        assert_eq!(2u64.sqrt_rem(), (1, 1));
        assert_eq!(2u128.sqrt_rem(), (1, 1));

        assert_eq!(u8::MAX.sqrt_rem(), (15, 30));
        assert_eq!(u16::MAX.sqrt_rem(), (u8::MAX, (u8::MAX as u16) * 2));
        assert_eq!(u32::MAX.sqrt_rem(), (u16::MAX, (u16::MAX as u32) * 2));
        assert_eq!(u64::MAX.sqrt_rem(), (u32::MAX, (u32::MAX as u64) * 2));
        assert_eq!(u128::MAX.sqrt_rem(), (u64::MAX, (u64::MAX as u128) * 2));

        assert_eq!((u8::MAX / 2).sqrt_rem(), (11, 6));
        assert_eq!((u16::MAX / 2).sqrt_rem(), (181, 6));
        assert_eq!((u32::MAX / 2).sqrt_rem(), (46340, 88047));
        assert_eq!((u64::MAX / 2).sqrt_rem(), (3037000499, 5928526806));
        assert_eq!((u128::MAX / 2).sqrt_rem(), (13043817825332782212, 9119501915260492783));

        assert_eq!(65533u32.sqrt_rem(), (255, 508));

        macro_rules! random_case {
            ($T:ty) => {
                let n: $T = random();
                let (root, rem) = n.sqrt_rem();
                assert_eq!(root, n.sqrt());
                assert!(rem <= (root as $T) * 2, "sqrt({}) remainder too large", n);
                assert_eq!(n, (root as $T).pow(2) + rem, "sqrt({}) != {}, {}", n, root, rem);
            };
        }

        const N: u32 = 10000;
        for _ in 0..N {
            random_case!(u8);
            random_case!(u16);
            random_case!(u32);
            random_case!(u64);
            random_case!(u128);
        }
    }

    /// Checks `cbrt_rem` on zero, small values, the extremes of every width, every `u8`, and
    /// random inputs.
    #[test]
    fn test_cbrt() {
        // zero takes the early-return path (or the empty loop for u8)
        assert_eq!(0u8.cbrt_rem(), (0, 0));
        assert_eq!(0u16.cbrt_rem(), (0, 0));
        assert_eq!(0u32.cbrt_rem(), (0, 0));
        assert_eq!(0u64.cbrt_rem(), (0, 0));
        assert_eq!(0u128.cbrt_rem(), (0, 0));

        assert_eq!(2u8.cbrt_rem(), (1, 1));
        assert_eq!(2u16.cbrt_rem(), (1, 1));
        assert_eq!(2u32.cbrt_rem(), (1, 1));
        assert_eq!(2u64.cbrt_rem(), (1, 1));
        assert_eq!(2u128.cbrt_rem(), (1, 1));

        // largest inputs: exact values for the narrow types ...
        assert_eq!(u8::MAX.cbrt_rem(), (6, 39));
        assert_eq!(u16::MAX.cbrt_rem(), (40, 1535));
        assert_eq!(u32::MAX.cbrt_rem(), (1625, 3951670));

        // ... and the defining invariants for every width
        macro_rules! max_case {
            ($T:ty) => {
                let n = <$T>::MAX;
                let (root, rem) = n.cbrt_rem();
                let root = root as $T;
                assert!(rem <= 3 * (root * root + root), "cbrt({}) remainder too large", n);
                assert_eq!(n, root.pow(3) + rem, "cbrt({}) != {}, {}", n, root, rem);
            };
        }
        max_case!(u8);
        max_case!(u16);
        max_case!(u32);
        max_case!(u64);
        max_case!(u128);

        assert_eq!((u8::MAX / 2).cbrt_rem(), (5, 2));
        assert_eq!((u16::MAX / 2).cbrt_rem(), (31, 2976));
        assert_eq!((u32::MAX / 2).cbrt_rem(), (1290, 794647));
        assert_eq!((u64::MAX / 2).cbrt_rem(), (2097151, 13194133241856));
        assert_eq!((u128::MAX / 2).cbrt_rem(), (5541191377756, 58550521324026917344808511));

        assert_eq!((u8::MAX / 4).cbrt_rem(), (3, 36));
        assert_eq!((u16::MAX / 4).cbrt_rem(), (25, 758));
        assert_eq!((u32::MAX / 4).cbrt_rem(), (1023, 3142656));
        assert_eq!((u64::MAX / 4).cbrt_rem(), (1664510, 5364995536903));
        assert_eq!((u128::MAX / 4).cbrt_rem(), (4398046511103, 58028439341489006246363136));

        // u8 is small enough to be checked exhaustively
        for n in 0..=u8::MAX {
            let (root, rem) = n.cbrt_rem();
            assert!(rem <= 3 * (root * root + root), "cbrt({}) remainder too large", n);
            assert_eq!(n, root.pow(3) + rem, "cbrt({}) != {}, {}", n, root, rem);
        }

        macro_rules! random_case {
            ($T:ty) => {
                let n: $T = random();
                let (root, rem) = n.cbrt_rem();
                assert_eq!(root, n.cbrt());
                let root = root as $T;
                assert!(rem <= 3 * (root * root + root), "cbrt({}) remainder too large", n);
                assert_eq!(n, root.pow(3) + rem, "cbrt({}) != {}, {}", n, root, rem);
            };
        }

        const N: u32 = 10000;
        for _ in 0..N {
            random_case!(u16);
            random_case!(u32);
            random_case!(u64);
            random_case!(u128);
        }
    }
}