1use super::{CubicRootRem, SquareRootRem};
2use crate::DivRem;
3
/// Root extraction for inputs already normalized so that their high bits are set.
///
/// The public `SquareRootRem` / `CubicRootRem` impls in this file shift their
/// argument left until it satisfies these preconditions, call into this trait, and
/// then scale the result back.
pub(crate) trait NormalizedRootRem: Sized {
    /// Root type; it has half the bit width of `Self`.
    type OutputRoot;

    /// Square root with remainder of an input whose highest or second highest bit
    /// is set (`self.leading_zeros() <= 1`, checked with `debug_assert!`).
    ///
    /// Returns `(s, r)` with `s = ⌊√self⌋` and `r = self − s²`.
    fn normalized_sqrt_rem(self) -> (Self::OutputRoot, Self);

    /// Cubic root with remainder of an input with at least one of its three highest
    /// bits set (`self.leading_zeros() <= 2`, checked with `debug_assert!`).
    ///
    /// Returns `(c, r)` with `c = ⌊∛self⌋` and `r = self − c³`.
    fn normalized_cbrt_rem(self) -> (Self::OutputRoot, Self);
}
15
/// 9-bit estimates of a reciprocal square root, stored without their implied
/// leading bit (callers OR in `0x100`).
///
/// Index `i` corresponds to a normalized input whose top 7 bits equal `32 + i`.
/// The estimate satisfies `(0x100 | RSQRT_TAB[i]) / 0x200 ≈ √(32 / (32 + i))`.
/// The entries appear to be evaluated near the middle of each interval; for
/// example, entry 0 gives 508/512 ≈ √(32 / 32.5).
const RSQRT_TAB: [u8; 96] = [
    0xfc, 0xf4, 0xed, 0xe6, 0xdf, 0xd9, 0xd3, 0xcd, 0xc7, 0xc2, 0xbc, 0xb7, 0xb2, 0xad, 0xa9, 0xa4,
    0xa0, 0x9c, 0x98, 0x94, 0x90, 0x8c, 0x88, 0x85, 0x81, 0x7e, 0x7b, 0x77, 0x74, 0x71, 0x6e, 0x6b,
    0x69, 0x66, 0x63, 0x61, 0x5e, 0x5b, 0x59, 0x57, 0x54, 0x52, 0x50, 0x4d, 0x4b, 0x49, 0x47, 0x45,
    0x43, 0x41, 0x3f, 0x3d, 0x3b, 0x39, 0x37, 0x36, 0x34, 0x32, 0x30, 0x2f, 0x2d, 0x2c, 0x2a, 0x28,
    0x27, 0x25, 0x24, 0x22, 0x21, 0x1f, 0x1e, 0x1d, 0x1b, 0x1a, 0x19, 0x17, 0x16, 0x15, 0x14, 0x12,
    0x11, 0x10, 0x0f, 0x0d, 0x0c, 0x0b, 0x0a, 0x09, 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01,
];
26
/// 9-bit estimates of a reciprocal cube root, stored without their implied
/// leading bit (callers OR in `0x100`).
///
/// Index `i` corresponds to a 6-bit input prefix equal to `8 + i`. The estimate
/// satisfies `(0x100 | RCBRT_TAB[i]) / 0x200 ≈ ∛(8 / (8 + i))`; for example,
/// entry 0 gives 502/512 ≈ ∛(8 / 8.5).
const RCBRT_TAB: [u8; 56] = [
    0xf6, 0xe4, 0xd4, 0xc6, 0xb9, 0xae, 0xa4, 0x9b, 0x92, 0x8a, 0x83, 0x7c, 0x76, 0x70, 0x6b, 0x66,
    0x61, 0x5c, 0x57, 0x53, 0x4f, 0x4b, 0x48, 0x44, 0x41, 0x3e, 0x3b, 0x38, 0x35, 0x32, 0x2f, 0x2d,
    0x2a, 0x28, 0x25, 0x23, 0x21, 0x1f, 0x1d, 0x1b, 0x19, 0x17, 0x15, 0x13, 0x11, 0x10, 0x0e, 0x0c,
    0x0b, 0x09, 0x08, 0x06, 0x05, 0x03, 0x02, 0x01,
];
35
/// Steps an underestimated square root up to the exact floor root.
///
/// `$s` must name a mutable integer variable holding an estimate with
/// `$s² <= $n`. Afterwards `$s` equals `⌊√$n⌋`, and the macro evaluates to the
/// remainder `$n − $s²`. All arithmetic is carried out in type `$t`.
macro_rules! fix_sqrt_error {
    ($t:ty, $n:ident, $s:ident) => {{
        // Invariants: `rem == $n − $s²` and `step == ($s + 1)² − $s² == 2·$s + 1`.
        let mut rem = $n - ($s as $t) * ($s as $t);
        let mut step = ($s as $t) * 2 + 1;
        loop {
            if rem < step {
                break rem;
            }
            $s += 1;
            rem -= step;
            step += 2;
        }
    }};
}
50
/// Steps an underestimated cube root up to the exact floor root.
///
/// `$c` must name a mutable integer variable holding an estimate with
/// `$c³ <= $n`. Afterwards `$c` equals `⌊∛$n⌋`, and the macro evaluates to the
/// remainder `$n − $c³`. All arithmetic is carried out in type `$t`.
macro_rules! fix_cbrt_error {
    ($t:ty, $n:ident, $c:ident) => {{
        let sq = ($c as $t) * ($c as $t);
        let mut rem = $n - sq * ($c as $t);
        // `step` is ($c + 1)³ − $c³ = 3·$c² + 3·$c + 1.
        let mut step = 3 * sq + 3 * ($c as $t) + 1;
        while rem >= step {
            $c += 1;
            rem -= step;
            // The next difference exceeds the current one by 6·(new $c).
            step += 6 * ($c as $t);
        }
        rem
    }};
}
66
impl NormalizedRootRem for u16 {
    type OutputRoot = u8;

    /// Square root with remainder of a 16-bit input with `leading_zeros() <= 1`.
    fn normalized_sqrt_rem(self) -> (u8, u16) {
        debug_assert!(self.leading_zeros() <= 1);

        // The top 7 bits of `self` lie in [32, 128), so `(self >> 9) - 32` indexes
        // RSQRT_TAB. With the implicit bit restored by `0x100 |`, r ≈ 2^16 / √self.
        let r = 0x100 | RSQRT_TAB[(self >> 9) as usize - 32] as u32;
        // s ≈ r · self / 2^16 ≈ √self.
        let s = (r * self as u32) >> 16;

        // Subtract one to obtain an underestimate: `fix_sqrt_error!` computes
        // `self − s²` in unsigned arithmetic, so it needs s² <= self. It then
        // steps `s` up to ⌊√self⌋ and yields the remainder.
        let mut s = (s - 1) as u8;
        let e = fix_sqrt_error!(u16, self, s);
        (s, e)
    }

    /// Cubic root with remainder of a 16-bit input with `leading_zeros() <= 2`.
    fn normalized_cbrt_rem(self) -> (u8, u16) {
        debug_assert!(self.leading_zeros() <= 2);

        // RCBRT_TAB is indexed by a 6-bit prefix in [8, 64). When the top bit is set,
        // the prefix is taken 3 bits further right (`self >> 12`) to stay in range.
        let adjust = self.leading_zeros() == 0;
        // r ≈ 2^13 / ∛self, or r ≈ 2^14 / ∛self when `adjust` is set.
        let r = 0x100 | RCBRT_TAB[(self >> (9 + (3 * adjust as u8))) as usize - 8] as u32;
        // In both cases r2 ≈ 2^24 / ∛self²; the extra shift compensates for `adjust`.
        let r2 = (r * r) >> (2 + 2 * adjust as u8);
        // c ≈ self · r2 / 2^24 ≈ ∛self.
        let c = (r2 * self as u32) >> 24;
        // Bias down to an underestimate (`fix_cbrt_error!` computes `self − c³`
        // unsigned), then step up to ⌊∛self⌋.
        let mut c = (c - 1) as u8;
        let e = fix_cbrt_error!(u16, self, c);
        (c, e)
    }
}
99
/// Returns the high 16 bits of the full 32-bit product `a · b`.
#[inline]
fn wmul16_hi(a: u16, b: u16) -> u16 {
    let wide = u32::from(a) * u32::from(b);
    (wide >> u16::BITS) as u16
}
105
impl NormalizedRootRem for u32 {
    type OutputRoot = u16;

    /// Square root with remainder of a 32-bit input with `leading_zeros() <= 1`.
    fn normalized_sqrt_rem(self) -> (u16, u32) {
        debug_assert!(self.leading_zeros() <= 1);

        // The reciprocal-square-root estimate starts from the top 16 bits.
        let n16 = (self >> 16) as u16;
        // 9-bit table estimate, r ≈ 2^16 / √n16 (same indexing as the u16 case).
        let r = 0x100 | RSQRT_TAB[(n16 >> 9) as usize - 32] as u32;
        // One Newton step for 1/√x, r ← r·(3 − x·r²)/2, evaluated in fixed point
        // against the full 32-bit input; afterwards r ≈ 2^30 / √self.
        let r = ((3 * r as u16) << 5) - (wmul32_hi(self, r * r * r) >> 11) as u16;
        let r = r << 1;
        // s ≈ r · n16 / 2^15 ≈ √self; saturate rather than overflow if the
        // doubled estimate exceeds u16::MAX.
        let mut s = wmul16_hi(r, n16).saturating_mul(2);
        // Back off so that s² <= self and the subtraction below cannot underflow.
        s -= 4;
        let e = self - (s as u32) * (s as u32);
        // Newton correction s ← s + e / (2s), using r ≈ 2^31 / √self.
        s += wmul16_hi((e >> 16) as u16, r);

        // Step up to ⌊√self⌋ and compute the exact remainder.
        let e = fix_sqrt_error!(u32, self, s);
        (s, e)
    }

    /// Cubic root with remainder of a 32-bit input with `leading_zeros() <= 2`.
    fn normalized_cbrt_rem(self) -> (u16, u32) {
        debug_assert!(self.leading_zeros() <= 2);

        // When one of the top two bits is set, take the 16-bit prefix 3 bits further
        // right so that the table index below stays within [8, 64). That prefix is 8×
        // smaller, so its reciprocal cube root is 2× larger; `r >>= adjust` undoes this.
        let adjust = self.leading_zeros() < 2;
        let n16 = (self >> (16 + 3 * adjust as u8)) as u16;
        // 9-bit table estimate of the reciprocal cube root of n16.
        let r = 0x100 | RCBRT_TAB[(n16 >> 8) as usize - 8] as u32;
        let r3 = (r * r * r) >> 11;
        // One Newton step for 1/∛x, r ← r·(4 − x·r³)/3, in fixed point.
        let t = (4 << 11) - wmul16_hi(n16, r3 as u16);
        let mut r = ((r * t as u32 / 3) >> 4) as u16;
        r >>= adjust as u8;
        // Bias the reciprocal down so that the root derived from it is an
        // underestimate, which `fix_cbrt_error!` requires. The margin of 10 is
        // presumably chosen to cover the worst-case estimation error.
        let r = r - 10;
        // c ≈ n · r² ≈ ∛n, computed from the top 16 bits of the input.
        let mut c = wmul16_hi(r, wmul16_hi(r, (self >> 16) as u16)) >> 2;

        // Step up to ⌊∛self⌋ and compute the exact remainder.
        let e = fix_cbrt_error!(u32, self, c);
        (c, e)
    }
}
169
/// Returns the high 32 bits of the full 64-bit product `a · b`.
#[inline]
fn wmul32_hi(a: u32, b: u32) -> u32 {
    let wide = u64::from(a) * u64::from(b);
    (wide >> u32::BITS) as u32
}
175
impl NormalizedRootRem for u64 {
    type OutputRoot = u32;

    /// Square root with remainder of a 64-bit input with `leading_zeros() <= 1`.
    fn normalized_sqrt_rem(self) -> (u32, u64) {
        debug_assert!(self.leading_zeros() <= 1);

        // The reciprocal-square-root estimate is computed from the top 32 bits.
        let n32 = (self >> 32) as u32;
        // 9-bit table estimate of 1/√n32.
        let r = 0x100 | RSQRT_TAB[(n32 >> 25) as usize - 32] as u32;
        // First Newton step for 1/√x: r ← r·(3 − x·r²)/2, in fixed point.
        let r = ((3 * r) << 21) - wmul32_hi(n32, (r * r * r) << 5);
        // Second Newton step: t ≈ (3 − n32·r²) scaled by 2^28, then r ← r·t.
        let t = (3 << 28) - wmul32_hi(r, wmul32_hi(r, n32));
        let r = wmul32_hi(r, t);
        let r = r << 4;
        // s ≈ r · n32 ≈ √self.
        let mut s = wmul32_hi(r, n32) << 1;

        // Back off so that s² <= self and the subtraction cannot underflow.
        s -= 10;
        let e = self - (s as u64) * (s as u64);
        // Newton correction s ← s + e / (2s), using the reciprocal estimate r.
        s += wmul32_hi((e >> 32) as u32, r);

        // Step up to ⌊√self⌋ and compute the exact remainder.
        let e = fix_sqrt_error!(u64, self, s);
        (s, e)
    }

    /// Cubic root with remainder of a 64-bit input with `leading_zeros() <= 2`.
    fn normalized_cbrt_rem(self) -> (u32, u64) {
        debug_assert!(self.leading_zeros() <= 2);

        // When the top bit is set, take the 32-bit prefix 3 bits further right so the
        // table index stays within [8, 64). The reciprocal estimate is then twice as
        // large, and `r >>= adjust` halves it again below.
        let adjust = self.leading_zeros() == 0;
        let n32 = (self >> (32 + 3 * adjust as u8)) as u32;
        // 9-bit table estimate of 1/∛n32, and t ≈ (4 − n32·r³) in fixed point.
        let r = 0x100 | RCBRT_TAB[(n32 >> 25) as usize - 8] as u32;
        let t = (4 << 23) - wmul32_hi(n32, r * r * r);
        // First Newton step for 1/∛x: r ← r·(4 − x·r³)/3.
        let r = r * (t / 3);
        // Second Newton step, with t ≈ (4 − n32·r³) scaled by 2^28.
        let t = (4 << 28) - wmul32_hi(r, wmul32_hi(r, wmul32_hi(r, n32)));
        let mut r = wmul32_hi(r, t) / 3;
        r >>= adjust as u8;
        // Bias down so the derived root is an underestimate for `fix_cbrt_error!`.
        let r = r - 1;
        // c ≈ n · r² ≈ ∛n, computed from the top 32 bits of the input.
        let mut c = wmul32_hi(r, wmul32_hi(r, (self >> 32) as u32));

        // Step up to ⌊∛self⌋ and compute the exact remainder.
        let e = fix_cbrt_error!(u64, self, c);
        (c, e)
    }
}
249
impl NormalizedRootRem for u128 {
    type OutputRoot = u64;

    /// Square root with remainder of a 128-bit input with `leading_zeros() <= 1`.
    ///
    /// This performs one step of the Karatsuba square root. The root of the high
    /// word comes from the u64 routine, and the low half of the root is obtained
    /// with a single division followed by at most one downward correction.
    fn normalized_sqrt_rem(self) -> (u64, u128) {
        debug_assert!(self.leading_zeros() <= 1);

        // Split self = a·2^64 + b. `a` holds the top 64 bits, so it inherits the
        // normalization of `self`.
        let (a, b) = (self >> u64::BITS, self & u64::MAX as u128);
        let (a, b) = (a as u64, b as u64);
        // s1 = ⌊√a⌋, r1 = a − s1².
        let (s1, r1) = a.normalized_sqrt_rem();

        // The low root digit is q = ⌊(r1·2^32 + (b >> 32)) / (2·s1)⌋. The division
        // by 2·s1 is done by halving the numerator (the shifts below, which drop
        // bit 32 of b) and then dividing by s1.
        const KBITS: u32 = u64::BITS / 2;
        let r0 = r1 << (KBITS - 1) | b >> (KBITS + 1);
        let (mut q, mut u) = r0.div_rem(s1 as u64);
        if q >> KBITS > 0 {
            // q can reach 2^32, which does not fit in the low half of the root.
            // Clamp it to 2^32 − 1 and move one s1 back into the remainder.
            q -= 1;
            u += s1 as u64;
        }

        // Candidate root s = s1·2^32 + q. Its remainder, self − s², equals
        // u·2^33 + (b mod 2^33) − q², which may need more than 64 bits. `c` holds
        // the bits of u shifted out of the 64-bit word, minus the borrow from
        // subtracting q².
        let mut s = (s1 as u64) << KBITS | q;
        let r = (u << (KBITS + 1)) | (b & ((1 << (KBITS + 1)) - 1));
        let q2 = q * q;
        let mut c = (u >> (KBITS - 1)) as i8 - (r < q2) as i8;
        let mut r = r.wrapping_sub(q2);

        if c < 0 {
            // Negative remainder: the candidate overshoots. Use s − 1 and add
            // s² − (s − 1)² = s + (s − 1) back, tracking the carries in `c`.
            let (new_r, c1) = r.overflowing_add(s);
            s -= 1;
            let (new_r, c2) = new_r.overflowing_add(s);
            r = new_r;
            c += c1 as i8 + c2 as i8;
        }
        // The full remainder is c·2^64 + r.
        (s, (c as u128) << u64::BITS | r as u128)
    }

    /// Cubic root with remainder of a 128-bit input with `leading_zeros() <= 2`.
    ///
    /// The root of the top bits is computed with the u64 routine, extended by one
    /// 22-bit digit via a single division, and then corrected downward.
    fn normalized_cbrt_rem(self) -> (u64, u128) {
        debug_assert!(self.leading_zeros() <= 2);

        // (c1, r1) = cube root and remainder of `self >> 66` (66 = 3·22).
        let (c1, r1) = if self.leading_zeros() > 0 {
            // Here `self >> 66` has more than 2 leading zeros and is not normalized
            // for the u64 routine. Use a = self >> 63, which is normalized, and
            // recover ⌊∛(a / 8)⌋ as ⌊⌊∛a⌋ / 2⌋; note that a >> 3 == self >> 66.
            let a = (self >> 63) as u64;
            let (mut c, _) = a.normalized_cbrt_rem();
            c >>= 1;
            (c, (a >> 3) - (c as u64).pow(3))
        } else {
            let a = (self >> 66) as u64;
            a.normalized_cbrt_rem()
        };

        // Write self = A·2^66 + m·2^44 + l with m < 2^22 and l < 2^44. The next
        // root digit is estimated as q = ⌊(r1·2^22 + m) / (3·c1²)⌋.
        const KBITS: u32 = 22;
        let r0 = ((r1 as u128) << KBITS) | (self >> (2 * KBITS) & ((1 << KBITS) - 1));
        let (q, u) = r0.div_rem(3 * (c1 as u128).pow(2));
        let mut c = ((c1 as u64) << KBITS) + (q as u64);
        // self − c³ = (u·2^44 + l) − (3·c1·2^22 + q)·q². This is computed signed
        // because the estimate q may be too large; the loop below corrects that.
        let t1 = (u << (2 * KBITS)) | (self & ((1 << (2 * KBITS)) - 1));
        let t2 = (((3 * (c1 as u128)) << KBITS) + q) * q.pow(2);
        let mut r = t1 as i128 - t2 as i128;

        while r < 0 {
            // Step c down by one, adding back c³ − (c − 1)³ = 3·(c − 1)·c + 1.
            r += 3 * (c as i128 - 1) * c as i128 + 1;
            c -= 1;
        }
        (c, r as u128)
    }
}
344
345impl SquareRootRem for u8 {
347 type Output = u8;
348
349 #[inline]
350 fn sqrt_rem(&self) -> (u8, u8) {
351 let mut s = 0;
353 let e = fix_sqrt_error!(u8, self, s);
354 (s, e)
355 }
356}
357
358impl CubicRootRem for u8 {
359 type Output = u8;
360
361 #[inline]
362 fn cbrt_rem(&self) -> (u8, u8) {
363 let mut c = 0;
365 let e = fix_cbrt_error!(u8, self, c);
366 (c, e)
367 }
368}
369
370macro_rules! impl_rootrem_using_normalized {
371 ($t:ty, $half:ty) => {
372 impl SquareRootRem for $t {
373 type Output = $half;
374
375 fn sqrt_rem(&self) -> ($half, $t) {
376 if *self == 0 {
377 return (0, 0);
378 }
379
380 let shift = self.leading_zeros() & !1; let (mut root, mut rem) = (self << shift).normalized_sqrt_rem();
383 if shift != 0 {
384 root >>= shift / 2;
385 rem = self - (root as $t).pow(2);
386 }
387 (root, rem)
388 }
389 }
390
391 impl CubicRootRem for $t {
392 type Output = $half;
393
394 fn cbrt_rem(&self) -> ($half, $t) {
395 if *self == 0 {
396 return (0, 0);
397 }
398
399 let mut shift = self.leading_zeros();
401 shift -= shift % 3; let (mut root, mut rem) = (self << shift).normalized_cbrt_rem();
403 if shift != 0 {
404 root >>= shift / 3;
405 rem = self - (root as $t).pow(3);
406 }
407 (root, rem)
408 }
409 }
410 };
411}
412impl_rootrem_using_normalized!(u16, u8);
413impl_rootrem_using_normalized!(u32, u16);
414impl_rootrem_using_normalized!(u64, u32);
415impl_rootrem_using_normalized!(u128, u64);
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::{CubicRoot, SquareRoot};
    use rand::random;

    /// Checks `$n.sqrt_rem()` against its defining property.
    ///
    /// `n == root² + rem` together with `rem <= 2·root` pins `root` to ⌊√n⌋,
    /// because `(root + 1)² = root² + 2·root + 1 > n`.
    macro_rules! check_sqrt_rem {
        ($T:ty, $n:expr) => {{
            let n: $T = $n;
            let (root, rem) = n.sqrt_rem();
            let root = root as $T;
            assert!(rem <= 2 * root, "sqrt({}) remainder too large", n);
            assert_eq!(n, root * root + rem, "sqrt({}) != {}, {}", n, root, rem);
        }};
    }

    /// Checks `$n.cbrt_rem()` against its defining property.
    ///
    /// `n == root³ + rem` together with `rem <= 3·root·(root + 1)` pins `root`
    /// to ⌊∛n⌋, because `(root + 1)³ = root³ + 3·root·(root + 1) + 1 > n`.
    macro_rules! check_cbrt_rem {
        ($T:ty, $n:expr) => {{
            let n: $T = $n;
            let (root, rem) = n.cbrt_rem();
            let root = root as $T;
            assert!(rem <= 3 * root * (root + 1), "cbrt({}) remainder too large", n);
            assert_eq!(n, root * root * root + rem, "cbrt({}) != {}, {}", n, root, rem);
        }};
    }

    #[test]
    fn test_sqrt() {
        // Zero takes the early-return path in every implementation.
        assert_eq!(0u8.sqrt_rem(), (0, 0));
        assert_eq!(0u16.sqrt_rem(), (0, 0));
        assert_eq!(0u32.sqrt_rem(), (0, 0));
        assert_eq!(0u64.sqrt_rem(), (0, 0));
        assert_eq!(0u128.sqrt_rem(), (0, 0));

        assert_eq!(2u8.sqrt_rem(), (1, 1));
        assert_eq!(2u16.sqrt_rem(), (1, 1));
        assert_eq!(2u32.sqrt_rem(), (1, 1));
        assert_eq!(2u64.sqrt_rem(), (1, 1));
        assert_eq!(2u128.sqrt_rem(), (1, 1));

        assert_eq!(u8::MAX.sqrt_rem(), (15, 30));
        assert_eq!(u16::MAX.sqrt_rem(), (u8::MAX, (u8::MAX as u16) * 2));
        assert_eq!(u32::MAX.sqrt_rem(), (u16::MAX, (u16::MAX as u32) * 2));
        assert_eq!(u64::MAX.sqrt_rem(), (u32::MAX, (u32::MAX as u64) * 2));
        assert_eq!(u128::MAX.sqrt_rem(), (u64::MAX, (u64::MAX as u128) * 2));

        assert_eq!((u8::MAX / 2).sqrt_rem(), (11, 6));
        assert_eq!((u16::MAX / 2).sqrt_rem(), (181, 6));
        assert_eq!((u32::MAX / 2).sqrt_rem(), (46340, 88047));
        assert_eq!((u64::MAX / 2).sqrt_rem(), (3037000499, 5928526806));
        assert_eq!((u128::MAX / 2).sqrt_rem(), (13043817825332782212, 9119501915260492783));

        // Just below 2^16: the root is 255 with a large remainder.
        assert_eq!(65533u32.sqrt_rem(), (255, 508));

        macro_rules! random_case {
            ($T:ty) => {
                let n: $T = random();
                let (root, rem) = n.sqrt_rem();
                assert_eq!(root, n.sqrt());

                assert!(rem <= (root as $T) * 2, "sqrt({}) remainder too large", n);
                assert_eq!(n, (root as $T).pow(2) + rem, "sqrt({}) != {}, {}", n, root, rem);
            };
        }

        const N: u32 = 10000;
        for _ in 0..N {
            random_case!(u8);
            random_case!(u16);
            random_case!(u32);
            random_case!(u64);
            random_case!(u128);
        }
    }

    #[test]
    fn test_cbrt() {
        // Zero takes the early-return path in every implementation.
        assert_eq!(0u8.cbrt_rem(), (0, 0));
        assert_eq!(0u16.cbrt_rem(), (0, 0));
        assert_eq!(0u32.cbrt_rem(), (0, 0));
        assert_eq!(0u64.cbrt_rem(), (0, 0));
        assert_eq!(0u128.cbrt_rem(), (0, 0));

        assert_eq!(2u8.cbrt_rem(), (1, 1));
        assert_eq!(2u16.cbrt_rem(), (1, 1));
        assert_eq!(2u32.cbrt_rem(), (1, 1));
        assert_eq!(2u64.cbrt_rem(), (1, 1));
        assert_eq!(2u128.cbrt_rem(), (1, 1));

        // Largest inputs: 6³ = 216, 40³ = 64000, 1625³ = 4291015625,
        // 2642245³ = 18446724184312856125.
        assert_eq!(u8::MAX.cbrt_rem(), (6, 39));
        assert_eq!(u16::MAX.cbrt_rem(), (40, 1535));
        assert_eq!(u32::MAX.cbrt_rem(), (1625, 3951670));
        assert_eq!(u64::MAX.cbrt_rem(), (2642245, 19889396695490));
        check_cbrt_rem!(u128, u128::MAX);

        assert_eq!((u8::MAX / 2).cbrt_rem(), (5, 2));
        assert_eq!((u16::MAX / 2).cbrt_rem(), (31, 2976));
        assert_eq!((u32::MAX / 2).cbrt_rem(), (1290, 794647));
        assert_eq!((u64::MAX / 2).cbrt_rem(), (2097151, 13194133241856));
        assert_eq!((u128::MAX / 2).cbrt_rem(), (5541191377756, 58550521324026917344808511));
        assert_eq!((u8::MAX / 4).cbrt_rem(), (3, 36));
        assert_eq!((u16::MAX / 4).cbrt_rem(), (25, 758));
        assert_eq!((u32::MAX / 4).cbrt_rem(), (1023, 3142656));
        assert_eq!((u64::MAX / 4).cbrt_rem(), (1664510, 5364995536903));
        assert_eq!((u128::MAX / 4).cbrt_rem(), (4398046511103, 58028439341489006246363136));

        macro_rules! random_case {
            ($T:ty) => {
                let n: $T = random();
                let (root, rem) = n.cbrt_rem();
                assert_eq!(root, n.cbrt());

                let root = root as $T;
                assert!(rem <= 3 * (root * root + root), "cbrt({}) remainder too large", n);
                assert_eq!(n, root.pow(3) + rem, "cbrt({}) != {}, {}", n, root, rem);
            };
        }

        const N: u32 = 10000;
        for _ in 0..N {
            random_case!(u16);
            random_case!(u32);
            random_case!(u64);
            random_case!(u128);
        }
    }

    #[test]
    fn test_small_types_exhaustive() {
        // Every u8 and u16 input, including the u8 cube roots the random test skips.
        for n in 0..=u8::MAX {
            check_sqrt_rem!(u8, n);
            check_cbrt_rem!(u8, n);
        }
        for n in 0..=u16::MAX {
            check_sqrt_rem!(u16, n);
            check_cbrt_rem!(u16, n);
        }
    }

    #[test]
    fn test_power_of_two_boundaries() {
        // Hits every leading-zero count, and therefore every normalization shift in
        // `impl_rootrem_using_normalized`, at both edges of each power-of-two range.
        macro_rules! boundaries {
            ($T:ty) => {
                let bits = <$T>::BITS;
                for k in 0..bits {
                    let p: $T = 1 << k;
                    for n in [p - 1, p, p + (p - 1)] {
                        check_sqrt_rem!($T, n);
                        check_cbrt_rem!($T, n);
                    }
                }
            };
        }
        boundaries!(u8);
        boundaries!(u16);
        boundaries!(u32);
        boundaries!(u64);
        boundaries!(u128);
    }

    #[test]
    fn test_perfect_powers() {
        // An exact power must give a zero remainder. One less must give the
        // previous root with the largest possible remainder.
        macro_rules! perfect_powers {
            ($T:ty) => {
                let (max_s, _) = <$T>::MAX.sqrt_rem();
                for s in [1, 2, 3, max_s / 2, max_s - 1, max_s] {
                    let sq = (s as $T).pow(2);
                    assert_eq!(sq.sqrt_rem(), (s, 0));
                    assert_eq!((sq - 1).sqrt_rem(), (s - 1, 2 * (s as $T - 1)));
                }

                let (max_c, _) = <$T>::MAX.cbrt_rem();
                for c in [1, 2, 3, max_c / 2, max_c - 1, max_c] {
                    let cube = (c as $T).pow(3);
                    assert_eq!(cube.cbrt_rem(), (c, 0));
                    let prev = (c - 1) as $T;
                    assert_eq!((cube - 1).cbrt_rem(), (c - 1, cube - 1 - prev.pow(3)));
                }
            };
        }
        perfect_powers!(u8);
        perfect_powers!(u16);
        perfect_powers!(u32);
        perfect_powers!(u64);
        perfect_powers!(u128);
    }
}