dashu_base/ring/
root.rs

1use super::{CubicRootRem, SquareRootRem};
2use crate::DivRem;
3
/// Root-with-remainder primitives for inputs that have already been normalized.
///
/// Both methods consume `self` and return `(root, remainder)`. They are internal
/// helpers: the public root traits in this file shift the operand into normalized
/// position before calling them.
pub(crate) trait NormalizedRootRem: Sized {
    /// Type used for the returned root.
    type OutputRoot;

    /// Computes the square root and remainder of a normalized input, i.e. an input
    /// whose highest or second-highest bit is set. For internal use only.
    fn normalized_sqrt_rem(self) -> (Self::OutputRoot, Self);

    /// Computes the cubic root and remainder of a normalized input, i.e. an input
    /// with at least one of its three highest bits set. For internal use only.
    fn normalized_cbrt_rem(self) -> (Self::OutputRoot, Self);
}
15
/// Lookup table of normalized `1/sqrt(x)` estimates with 9 bits of precision.
///
/// Only the low 8 bits are stored; the implicit ninth bit is restored with `0x100 | entry`.
/// For every index `i`, `(RSQRT_TAB[i] + 0x100) / 0x200 ≈ sqrt(32) / sqrt(32 + i)`.
/// The table is indexed by the top 7 bits of a normalized input minus 32.
const RSQRT_TAB: [u8; 96] = [
    0xfc, 0xf4, 0xed, 0xe6, 0xdf, 0xd9, 0xd3, 0xcd, // i = 0..8
    0xc7, 0xc2, 0xbc, 0xb7, 0xb2, 0xad, 0xa9, 0xa4, // i = 8..16
    0xa0, 0x9c, 0x98, 0x94, 0x90, 0x8c, 0x88, 0x85, // i = 16..24
    0x81, 0x7e, 0x7b, 0x77, 0x74, 0x71, 0x6e, 0x6b, // i = 24..32
    0x69, 0x66, 0x63, 0x61, 0x5e, 0x5b, 0x59, 0x57, // i = 32..40
    0x54, 0x52, 0x50, 0x4d, 0x4b, 0x49, 0x47, 0x45, // i = 40..48
    0x43, 0x41, 0x3f, 0x3d, 0x3b, 0x39, 0x37, 0x36, // i = 48..56
    0x34, 0x32, 0x30, 0x2f, 0x2d, 0x2c, 0x2a, 0x28, // i = 56..64
    0x27, 0x25, 0x24, 0x22, 0x21, 0x1f, 0x1e, 0x1d, // i = 64..72
    0x1b, 0x1a, 0x19, 0x17, 0x16, 0x15, 0x14, 0x12, // i = 72..80
    0x11, 0x10, 0x0f, 0x0d, 0x0c, 0x0b, 0x0a, 0x09, // i = 80..88
    0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, // i = 88..96
];
26
/// Lookup table of normalized `1/cbrt(x)` estimates with 9 bits of precision.
///
/// Only the low 8 bits are stored; the implicit ninth bit is restored with `0x100 | entry`.
/// For every index `i`, `(RCBRT_TAB[i] + 0x100) / 0x200 ≈ cbrt(8) / cbrt(8 + i)`.
/// The table is indexed by the top 6 bits of a suitably shifted input minus 8.
const RCBRT_TAB: [u8; 56] = [
    0xf6, 0xe4, 0xd4, 0xc6, 0xb9, 0xae, 0xa4, 0x9b, // i = 0..8
    0x92, 0x8a, 0x83, 0x7c, 0x76, 0x70, 0x6b, 0x66, // i = 8..16
    0x61, 0x5c, 0x57, 0x53, 0x4f, 0x4b, 0x48, 0x44, // i = 16..24
    0x41, 0x3e, 0x3b, 0x38, 0x35, 0x32, 0x2f, 0x2d, // i = 24..32
    0x2a, 0x28, 0x25, 0x23, 0x21, 0x1f, 0x1d, 0x1b, // i = 32..40
    0x19, 0x17, 0x15, 0x13, 0x11, 0x10, 0x0e, 0x0c, // i = 40..48
    0x0b, 0x09, 0x08, 0x06, 0x05, 0x03, 0x02, 0x01, // i = 48..56
];
35
/// Corrects an underestimated square root in place and yields the remainder.
///
/// * `$t` - integer type of `$n`, used for all arithmetic
/// * `$n` - the radicand
/// * `$s` - mutable variable holding the estimate; must satisfy `s <= sqrt(n)` on entry
///
/// `$s` is incremented until it equals `floor(sqrt(n))`; the macro evaluates to `n - s^2`.
macro_rules! fix_sqrt_error {
    ($t:ty, $n:ident, $s:ident) => {{
        let mut rem = $n - ($s as $t).pow(2);
        // (s + 1)^2 - s^2 = 2s + 1: the amount consumed by bumping the root once.
        let mut step = 2 * ($s as $t) + 1;
        while rem >= step {
            $s += 1;
            rem -= step;
            // consecutive odd numbers: the next increment costs 2 more
            step += 2;
        }
        rem
    }};
}
50
/// Corrects an underestimated cubic root in place and yields the remainder.
///
/// * `$t` - integer type of `$n`, used for all arithmetic
/// * `$n` - the radicand
/// * `$c` - mutable variable holding the estimate; must satisfy `c <= cbrt(n)` on entry
///
/// `$c` is incremented until it equals `floor(cbrt(n))`; the macro evaluates to `n - c^3`.
macro_rules! fix_cbrt_error {
    ($t:ty, $n:ident, $c:ident) => {{
        let square = ($c as $t).pow(2);
        let mut rem = $n - square * ($c as $t);
        // (c + 1)^3 - c^3 = 3(c^2 + c) + 1: the amount consumed by bumping the root once.
        let mut step = 3 * (square + ($c as $t)) + 1;
        while rem >= step {
            $c += 1;
            rem -= step;
            // the difference between successive increments is 6 * (new c)
            step += 6 * ($c as $t);
        }
        rem
    }};
}
66
67impl NormalizedRootRem for u16 {
68    type OutputRoot = u8;
69
70    fn normalized_sqrt_rem(self) -> (u8, u16) {
71        debug_assert!(self.leading_zeros() <= 1);
72
73        // retrieved r ≈ √32 / √(n >> 9) * 0x200 = 1 / √(n >> 14) * 2^9 = 2^16 / √n.
74        let r = 0x100 | RSQRT_TAB[(self >> 9) as usize - 32] as u32; // 9 bits
75        let s = (r * self as u32) >> 16;
76        let mut s = (s - 1) as u8; // to make sure s is an underestimate
77
78        // then fix the estimation error
79        let e = fix_sqrt_error!(u16, self, s);
80        (s, e)
81    }
82
83    fn normalized_cbrt_rem(self) -> (u8, u16) {
84        debug_assert!(self.leading_zeros() <= 2);
85
86        // retrieved r ≈ ∛8 / ∛(n >> 9) * 0x200 = 1 / ∛(n >> 12) * 2^9 = 2^13 / ∛n.
87        let adjust = self.leading_zeros() == 0;
88        let r = 0x100 | RCBRT_TAB[(self >> (9 + (3 * adjust as u8))) as usize - 8] as u32; // 9 bits
89        let r2 = (r * r) >> (2 + 2 * adjust as u8);
90        let c = (r2 * self as u32) >> 24;
91        let mut c = (c - 1) as u8; // to make sure c is an underestimate
92
93        // step6: fix the estimation error, at most 2 steps are needed
94        // if we use more bits to estimate the initial guess, less steps can be required
95        let e = fix_cbrt_error!(u16, self, c);
96        (c, e)
97    }
98}
99
/// Returns the upper 16 bits of the full 32-bit product `a * b`.
#[inline]
fn wmul16_hi(a: u16, b: u16) -> u16 {
    let product = u32::from(a) * u32::from(b);
    (product >> u16::BITS) as u16
}
105
106impl NormalizedRootRem for u32 {
107    type OutputRoot = u16;
108
109    fn normalized_sqrt_rem(self) -> (u16, u32) {
110        // Use newton's method on 1/sqrt(n)
111        // x_{i+1} = x_i * (3 - n*x_i^2) / 2
112        debug_assert!(self.leading_zeros() <= 1);
113
114        // step1: lookup initial estimation of normalized 1/√n. The lookup table uses the highest 7 bits,
115        // since the input is normalized, the lookup index must be larger than 2**(7-2) = 32.
116        // then the retrieved r ≈ √32 / √(n >> 25) * 0x200 = 1 / √(n >> 30) / 2^9 = 2^24 / √n.
117        let n16 = (self >> 16) as u16;
118        let r = 0x100 | RSQRT_TAB[(n16 >> 9) as usize - 32] as u32; // 9 bits
119
120        // step2: first Newton iteration (without dividing by 2)
121        // r will be an estimation of 2^(24+6) / √n with 16 bits effective precision
122        let r = ((3 * r as u16) << 5) - (wmul32_hi(self, r * r * r) >> 11) as u16; // 15 bits
123
124        // step3: √n = x * 1/√n
125        let r = r << 1; // normalize to 16 bits, now r estimates 2^31 / √n
126        let mut s = wmul16_hi(r, n16).saturating_mul(2); // overflowing can happen
127        s -= 4; // to make sure s is an underestimate
128
129        // step4: second Newton iteration on √n
130        let e = self - (s as u32) * (s as u32);
131        s += wmul16_hi((e >> 16) as u16, r);
132
133        // step5: fix the estimation error, at most 2 steps are needed
134        // if we use more bits to estimate the initial guess, less steps can be required
135        let e = fix_sqrt_error!(u32, self, s);
136        (s, e)
137    }
138
139    fn normalized_cbrt_rem(self) -> (u16, u32) {
140        // Use newton's method on 1/cbrt(n)
141        // x_{i+1} = x_i * (4 - n*x_i^3) / 3
142        debug_assert!(self.leading_zeros() <= 2);
143
144        // step1: lookup initial estimation of 1/∛x. The lookup table uses the highest 6 bits up to 30rd.
145        // if the input is 32/31 bit, then shift it to 29/28 bit.
146        // retrieved r ≈ ∛8 / ∛(n >> 24) * 0x200 = 1 / ∛(n >> 27) * 2^9 = 2^18 / ∛n.
147        let adjust = self.leading_zeros() < 2;
148        let n16 = (self >> (16 + 3 * adjust as u8)) as u16;
149        let r = 0x100 | RCBRT_TAB[(n16 >> 8) as usize - 8] as u32; // 9 bits
150
151        // step2: first Newton iteration
152        // required shift = 18 * 3 - 11 - 16 * 2 - * 2 = 11
153        // afterwards, r ≈ 2^(18+11-4) / ∛n
154        let r3 = (r * r * r) >> 11;
155        let t = (4 << 11) - wmul16_hi(n16, r3 as u16); // 13 bits
156        let mut r = ((r * t as u32 / 3) >> 4) as u16; // 16 bits
157        r >>= adjust as u8; // recover the adjustment if needed
158
159        // step5: ∛x = x * (1/∛x)^2
160        let r = r - 10; // to make sure c is an underestimate
161        let mut c = wmul16_hi(r, wmul16_hi(r, (self >> 16) as u16)) >> 2;
162
163        // step6: fix the estimation error, at most 2 steps are needed
164        // if we use more bits to estimate the initial guess, less steps can be required
165        let e = fix_cbrt_error!(u32, self, c);
166        (c, e)
167    }
168}
169
/// Returns the upper 32 bits of the full 64-bit product `a * b`.
#[inline]
fn wmul32_hi(a: u32, b: u32) -> u32 {
    let product = u64::from(a) * u64::from(b);
    (product >> u32::BITS) as u32
}
175
176impl NormalizedRootRem for u64 {
177    type OutputRoot = u32;
178
179    fn normalized_sqrt_rem(self) -> (u32, u64) {
180        // Use newton's method on 1/sqrt(n)
181        // x_{i+1} = x_i * (3 - n*x_i^2) / 2
182        debug_assert!(self.leading_zeros() <= 1);
183
184        // step1: lookup initial estimation of normalized 1/√n. The lookup table uses the highest 7 bits,
185        // since the input is normalized, the lookup index must be larger than 2**(7-2) = 32.
186        // then the retrieved r ≈ √32 / √(n >> 57) * 0x200 = 1 / √(n >> 62) / 2^9 = 2^40 / √n.
187        let n32 = (self >> 32) as u32;
188        let r = 0x100 | RSQRT_TAB[(n32 >> 25) as usize - 32] as u32; // 9 bits
189
190        // step2: first Newton iteration (without dividing by 2)
191        // afterwards, r ≈ 2^(40+22) / √n with 16 bits effective precision
192        let r = ((3 * r) << 21) - wmul32_hi(n32, (r * r * r) << 5); // 31 bits
193
194        // step3: second Newton iteration (without dividing by 2)
195        // afterwards, r ≈ 2^(40+19) / √n with 32 bits effective precision
196        let t = (3 << 28) - wmul32_hi(r, wmul32_hi(r, n32)); // 29 bits
197        let r = wmul32_hi(r, t); // 28 bits
198
199        // step4: √n = x * 1/√n
200        let r = r << 4; // normalize to 32 bits, now r estimates 2^63 / √n
201        let mut s = wmul32_hi(r, n32) << 1;
202        s -= 10; // to make sure s is an underestimate
203
204        // step5: third Newton iteration on √n
205        let e = self - (s as u64) * (s as u64);
206        s += wmul32_hi((e >> 32) as u32, r);
207
208        // step6: fix the estimation error, at most 2 steps are needed
209        // if we use more bits to estimate the initial guess, less steps can be required
210        let e = fix_sqrt_error!(u64, self, s);
211        (s, e)
212    }
213
214    fn normalized_cbrt_rem(self) -> (u32, u64) {
215        // Use newton's method on 1/cbrt(n)
216        // x_{i+1} = x_i * (4 - n*x_i^3) / 3
217        debug_assert!(self.leading_zeros() <= 2);
218
219        // step1: lookup initial estimation of 1/∛x. The lookup table uses the highest 6 bits up to 63rd.
220        // if the input has 64 bits, then shift it to 61 bits.
221        // retrieved r ≈ ∛8 / ∛(n >> 57) * 0x200 = 1 / ∛(n >> 60) * 2^9 = 2^29 / ∛n.
222        let adjust = self.leading_zeros() == 0;
223        let n32 = (self >> (32 + 3 * adjust as u8)) as u32;
224        let r = 0x100 | RCBRT_TAB[(n32 >> 25) as usize - 8] as u32; // 9 bits
225
226        // step2: first Newton iteration
227        // required shift = 29 * 3 - 32 * 2 = 23
228        // afterwards, r ≈ 2^(29+23) / ∛n = 2^52 / ∛n
229        let t = (4 << 23) - wmul32_hi(n32, r * r * r);
230        let r = r * (t / 3); // 32 bits
231
232        // step3: second Newton iteration
233        // required shift = 52 * 3 - 32 * 4 = 28
234        // afterwards, r ≈ 2^(52+28-32) / ∛n = 2^48 / ∛n
235        let t = (4 << 28) - wmul32_hi(r, wmul32_hi(r, wmul32_hi(r, n32)));
236        let mut r = wmul32_hi(r, t) / 3; // 28 bits
237        r >>= adjust as u8; // recover the adjustment if needed
238
239        // step4: ∛x = x * (1/∛x)^2 = x * (2^48/∛x)^2 / 2^(32*3)
240        let r = r - 1; // to make sure c is an underestimate
241        let mut c = wmul32_hi(r, wmul32_hi(r, (self >> 32) as u32));
242
243        // step5: fix the estimation error, at most 3 steps are needed
244        // if we use more bits to estimate the initial guess, less steps can be required
245        let e = fix_cbrt_error!(u64, self, c);
246        (c, e)
247    }
248}
249
250impl NormalizedRootRem for u128 {
251    type OutputRoot = u64;
252
253    fn normalized_sqrt_rem(self) -> (u64, u128) {
254        debug_assert!(self.leading_zeros() <= 1);
255
256        // use the "Karatsuba Square Root" algorithm
257        // (see the implementation in dashu_int, or https://hal.inria.fr/inria-00072854/en/)
258
259        // step1: calculate sqrt on high parts
260        let (a, b) = (self >> u64::BITS, self & u64::MAX as u128);
261        let (a, b) = (a as u64, b as u64);
262        let (s1, r1) = a.normalized_sqrt_rem();
263
264        // step2: estimate the result with low parts
265        // note that r1 <= 2*s1 < 2^(KBITS + 1)
266        // here r0 = (r1*B + b) / 2
267        const KBITS: u32 = u64::BITS / 2;
268        let r0 = r1 << (KBITS - 1) | b >> (KBITS + 1);
269        let (mut q, mut u) = r0.div_rem(s1 as u64);
270        if q >> KBITS > 0 {
271            // if q >= B (then q = B), reduce the overestimate
272            q -= 1;
273            u += s1 as u64;
274        }
275
276        let mut s = (s1 as u64) << KBITS | q;
277        let r = (u << (KBITS + 1)) | (b & ((1 << (KBITS + 1)) - 1));
278        let q2 = q * q;
279        let mut c = (u >> (KBITS - 1)) as i8 - (r < q2) as i8;
280        let mut r = r.wrapping_sub(q2);
281
282        // step3: fix the estimation error if necessary
283        if c < 0 {
284            let (new_r, c1) = r.overflowing_add(s);
285            s -= 1;
286            let (new_r, c2) = new_r.overflowing_add(s);
287            r = new_r;
288            c += c1 as i8 + c2 as i8;
289        }
290        (s, (c as u128) << u64::BITS | r as u128)
291    }
292
293    fn normalized_cbrt_rem(self) -> (u64, u128) {
294        debug_assert!(self.leading_zeros() <= 2);
295
296        /*
297         * the following algorithm is similar to the "Karatsuba Square Root" above:
298         * assume n = a*B^3 + b2*B^2 + b1*B + b0, B=2^k, a has roughly 3k bits
299         * 1. calculate cbrt on high part:
300         *     c1, r1 = cbrt_rem(a)
301         * 2. estimate the root with low part
302         *     q, u = div_rem(r1*B + b2, 3*c1^2)
303         *     c = c1*B + q
304         *     r = u*B^2 + b1*B + b0 - 3*c1*q^2*B - q^3
305         *
306         * 3. if a5 is normalized, then only few adjustments are needed
307         *     while r < 0 {
308         *         r += 3*c^2 - 3*c + 1
309         *         c -= 1
310         *     }
311         */
312
313        // step1: calculate cbrt on high 62 bits
314        let (c1, r1) = if self.leading_zeros() > 0 {
315            // actually on high 65 bits
316            let a = (self >> 63) as u64;
317            let (mut c, _) = a.normalized_cbrt_rem();
318            c >>= 1;
319            (c, (a >> 3) - (c as u64).pow(3))
320        } else {
321            let a = (self >> 66) as u64;
322            a.normalized_cbrt_rem()
323        };
324
325        // step2: estimate the root with low part
326        const KBITS: u32 = 22;
327        let r0 = ((r1 as u128) << KBITS) | (self >> (2 * KBITS) & ((1 << KBITS) - 1));
328        let (q, u) = r0.div_rem(3 * (c1 as u128).pow(2));
329        let mut c = ((c1 as u64) << KBITS) + (q as u64); // here q might be larger than B
330
331        // r = u*B^2 + b1*B + b0 - 3*c1*q^2*B - q^3
332        let t1 = (u << (2 * KBITS)) | (self & ((1 << (2 * KBITS)) - 1));
333        let t2 = (((3 * (c1 as u128)) << KBITS) + q) * q.pow(2);
334        let mut r = t1 as i128 - t2 as i128;
335
336        // step3: adjustment, finishes in at most 4 steps
337        while r < 0 {
338            r += 3 * (c as i128 - 1) * c as i128 + 1;
339            c -= 1;
340        }
341        (c, r as u128)
342    }
343}
344
345// The implementation for u8 is very naive, because it's rarely used
346impl SquareRootRem for u8 {
347    type Output = u8;
348
349    #[inline]
350    fn sqrt_rem(&self) -> (u8, u8) {
351        // brute-force search, because there are only 16 possibilites.
352        let mut s = 0;
353        let e = fix_sqrt_error!(u8, self, s);
354        (s, e)
355    }
356}
357
358impl CubicRootRem for u8 {
359    type Output = u8;
360
361    #[inline]
362    fn cbrt_rem(&self) -> (u8, u8) {
363        // brute-force search, because there are only 7 possibilites.
364        let mut c = 0;
365        let e = fix_cbrt_error!(u8, self, c);
366        (c, e)
367    }
368}
369
370macro_rules! impl_rootrem_using_normalized {
371    ($t:ty, $half:ty) => {
372        impl SquareRootRem for $t {
373            type Output = $half;
374
375            fn sqrt_rem(&self) -> ($half, $t) {
376                if *self == 0 {
377                    return (0, 0);
378                }
379
380                // normalize the input and call the normalized subroutine
381                let shift = self.leading_zeros() & !1; // make sure shift is divisible by 2
382                let (mut root, mut rem) = (self << shift).normalized_sqrt_rem();
383                if shift != 0 {
384                    root >>= shift / 2;
385                    rem = self - (root as $t).pow(2);
386                }
387                (root, rem)
388            }
389        }
390
391        impl CubicRootRem for $t {
392            type Output = $half;
393
394            fn cbrt_rem(&self) -> ($half, $t) {
395                if *self == 0 {
396                    return (0, 0);
397                }
398
399                // normalize the input and call the normalized subroutine
400                let mut shift = self.leading_zeros();
401                shift -= shift % 3; // make sure shift is divisible by 3
402                let (mut root, mut rem) = (self << shift).normalized_cbrt_rem();
403                if shift != 0 {
404                    root >>= shift / 3;
405                    rem = self - (root as $t).pow(3);
406                }
407                (root, rem)
408            }
409        }
410    };
411}
412impl_rootrem_using_normalized!(u16, u8);
413impl_rootrem_using_normalized!(u32, u16);
414impl_rootrem_using_normalized!(u64, u32);
415impl_rootrem_using_normalized!(u128, u64);
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::{CubicRoot, SquareRoot};
    use rand::random;

    /// Checks `sqrt_rem` on fixed boundary values and on random inputs of every width.
    #[test]
    fn test_sqrt() {
        // zero: the wider types return early, u8 goes through the linear search
        assert_eq!(0u8.sqrt_rem(), (0, 0));
        assert_eq!(0u16.sqrt_rem(), (0, 0));
        assert_eq!(0u32.sqrt_rem(), (0, 0));
        assert_eq!(0u64.sqrt_rem(), (0, 0));
        assert_eq!(0u128.sqrt_rem(), (0, 0));

        assert_eq!(2u8.sqrt_rem(), (1, 1));
        assert_eq!(2u16.sqrt_rem(), (1, 1));
        assert_eq!(2u32.sqrt_rem(), (1, 1));
        assert_eq!(2u64.sqrt_rem(), (1, 1));
        assert_eq!(2u128.sqrt_rem(), (1, 1));

        // largest representable inputs (already normalized, no shift involved)
        assert_eq!(u8::MAX.sqrt_rem(), (15, 30));
        assert_eq!(u16::MAX.sqrt_rem(), (u8::MAX, (u8::MAX as u16) * 2));
        assert_eq!(u32::MAX.sqrt_rem(), (u16::MAX, (u16::MAX as u32) * 2));
        assert_eq!(u64::MAX.sqrt_rem(), (u32::MAX, (u32::MAX as u64) * 2));
        assert_eq!(u128::MAX.sqrt_rem(), (u64::MAX, (u64::MAX as u128) * 2));

        // inputs with exactly one leading zero
        assert_eq!((u8::MAX / 2).sqrt_rem(), (11, 6));
        assert_eq!((u16::MAX / 2).sqrt_rem(), (181, 6));
        assert_eq!((u32::MAX / 2).sqrt_rem(), (46340, 88047));
        assert_eq!((u64::MAX / 2).sqrt_rem(), (3037000499, 5928526806));
        assert_eq!((u128::MAX / 2).sqrt_rem(), (13043817825332782212, 9119501915260492783));

        // some cases from previous bugs
        assert_eq!(65533u32.sqrt_rem(), (255, 508));

        // A random input must agree with `sqrt()` and satisfy n = root^2 + rem, rem <= 2*root.
        macro_rules! random_case {
            ($T:ty) => {
                let n: $T = random();
                let (root, rem) = n.sqrt_rem();
                assert_eq!(root, n.sqrt());

                assert!(rem <= (root as $T) * 2, "sqrt({}) remainder too large", n);
                assert_eq!(n, (root as $T).pow(2) + rem, "sqrt({}) != {}, {}", n, root, rem);
            };
        }

        const N: u32 = 10000;
        for _ in 0..N {
            random_case!(u8);
            random_case!(u16);
            random_case!(u32);
            random_case!(u64);
            random_case!(u128);
        }
    }

    /// Checks `cbrt_rem` on fixed boundary values and on random inputs.
    #[test]
    fn test_cbrt() {
        // zero: the wider types return early, u8 goes through the linear search
        assert_eq!(0u8.cbrt_rem(), (0, 0));
        assert_eq!(0u16.cbrt_rem(), (0, 0));
        assert_eq!(0u32.cbrt_rem(), (0, 0));
        assert_eq!(0u64.cbrt_rem(), (0, 0));
        assert_eq!(0u128.cbrt_rem(), (0, 0));

        assert_eq!(2u8.cbrt_rem(), (1, 1));
        assert_eq!(2u16.cbrt_rem(), (1, 1));
        assert_eq!(2u32.cbrt_rem(), (1, 1));
        assert_eq!(2u64.cbrt_rem(), (1, 1));
        assert_eq!(2u128.cbrt_rem(), (1, 1));

        // largest representable inputs (no leading zero, exercises the adjusted lookup)
        assert_eq!(u8::MAX.cbrt_rem(), (6, 39));
        assert_eq!(u16::MAX.cbrt_rem(), (40, 1535));
        assert_eq!(u32::MAX.cbrt_rem(), (1625, 3951670));

        // inputs with exactly one leading zero
        assert_eq!((u8::MAX / 2).cbrt_rem(), (5, 2));
        assert_eq!((u16::MAX / 2).cbrt_rem(), (31, 2976));
        assert_eq!((u32::MAX / 2).cbrt_rem(), (1290, 794647));
        assert_eq!((u64::MAX / 2).cbrt_rem(), (2097151, 13194133241856));
        assert_eq!((u128::MAX / 2).cbrt_rem(), (5541191377756, 58550521324026917344808511));
        // inputs with exactly two leading zeros
        assert_eq!((u8::MAX / 4).cbrt_rem(), (3, 36));
        assert_eq!((u16::MAX / 4).cbrt_rem(), (25, 758));
        assert_eq!((u32::MAX / 4).cbrt_rem(), (1023, 3142656));
        assert_eq!((u64::MAX / 4).cbrt_rem(), (1664510, 5364995536903));
        assert_eq!((u128::MAX / 4).cbrt_rem(), (4398046511103, 58028439341489006246363136));

        // A random input must agree with `cbrt()` and satisfy
        // n = root^3 + rem, rem <= 3*(root^2 + root).
        macro_rules! random_case {
            ($T:ty) => {
                let n: $T = random();
                let (root, rem) = n.cbrt_rem();
                assert_eq!(root, n.cbrt());

                let root = root as $T;
                assert!(rem <= 3 * (root * root + root), "cbrt({}) remainder too large", n);
                assert_eq!(n, root.pow(3) + rem, "cbrt({}) != {}, {}", n, root, rem);
            };
        }

        const N: u32 = 10000;
        for _ in 0..N {
            random_case!(u16);
            random_case!(u32);
            random_case!(u64);
            random_case!(u128);
        }
    }
}