dashu_int/
root_ops.rs

1use dashu_base::{CubicRoot, CubicRootRem, Sign, SquareRoot, SquareRootRem};
2
3use crate::{
4    error::{panic_root_negative, panic_root_zeroth},
5    ibig::IBig,
6    ubig::UBig,
7};
8
9impl UBig {
10    /// Calculate the nth-root of the integer rounding towards zero
11    ///
12    /// # Examples
13    ///
14    /// ```
15    /// # use dashu_int::UBig;
16    /// assert_eq!(UBig::from(4u8).nth_root(2), UBig::from(2u8));
17    /// assert_eq!(UBig::from(4u8).nth_root(3), UBig::from(1u8));
18    /// assert_eq!(UBig::from(1024u16).nth_root(5), UBig::from(4u8));
19    /// ```
20    ///
21    /// # Panics
22    ///
23    /// If `n` is zero
24    #[inline]
25    pub fn nth_root(&self, n: usize) -> UBig {
26        UBig(self.repr().nth_root(n))
27    }
28}
29
30impl SquareRoot for UBig {
31    type Output = UBig;
32    #[inline]
33    fn sqrt(&self) -> Self::Output {
34        UBig(self.repr().sqrt())
35    }
36}
37
38impl SquareRootRem for UBig {
39    type Output = UBig;
40    #[inline]
41    fn sqrt_rem(&self) -> (Self, Self) {
42        let (s, r) = self.repr().sqrt_rem();
43        (UBig(s), UBig(r))
44    }
45}
46
47impl CubicRoot for UBig {
48    type Output = UBig;
49    #[inline]
50    fn cbrt(&self) -> Self::Output {
51        self.nth_root(3)
52    }
53}
54
55impl CubicRootRem for UBig {
56    type Output = UBig;
57    #[inline]
58    fn cbrt_rem(&self) -> (Self::Output, Self) {
59        let c = self.nth_root(3);
60        let r = self - c.pow(3);
61        (c, r)
62    }
63}
64
65impl IBig {
66    /// Calculate the nth-root of the integer rounding towards zero
67    ///
68    /// # Examples
69    ///
70    /// ```
71    /// # use dashu_int::IBig;
72    /// assert_eq!(IBig::from(4).nth_root(2), IBig::from(2));
73    /// assert_eq!(IBig::from(-4).nth_root(3), IBig::from(-1));
74    /// assert_eq!(IBig::from(-1024).nth_root(5), IBig::from(-4));
75    /// ```
76    ///
77    /// # Panics
78    ///
79    /// If `n` is zero, or if `n` is even when the integer is negative.
80    #[inline]
81    pub fn nth_root(&self, n: usize) -> IBig {
82        if n == 0 {
83            panic_root_zeroth()
84        }
85
86        let (sign, mag) = self.as_sign_repr();
87        if sign == Sign::Negative && n % 2 == 0 {
88            panic_root_negative()
89        }
90
91        IBig(mag.nth_root(n).with_sign(sign))
92    }
93}
94
95impl SquareRoot for IBig {
96    type Output = UBig;
97    #[inline]
98    fn sqrt(&self) -> UBig {
99        let (sign, mag) = self.as_sign_repr();
100        if sign == Sign::Negative {
101            panic_root_negative()
102        }
103        UBig(mag.sqrt())
104    }
105}
106
107impl CubicRoot for IBig {
108    type Output = IBig;
109    #[inline]
110    fn cbrt(&self) -> IBig {
111        let (sign, mag) = self.as_sign_repr();
112        if sign == Sign::Negative {
113            panic_root_negative()
114        }
115        IBig(mag.nth_root(3).with_sign(sign))
116    }
117}
118
119mod repr {
120    use super::*;
121    use crate::{
122        add,
123        arch::word::Word,
124        buffer::Buffer,
125        memory::MemoryAllocation,
126        mul,
127        primitive::{extend_word, shrink_dword, WORD_BITS, WORD_BITS_USIZE},
128        repr::{
129            Repr,
130            TypedReprRef::{self, *},
131        },
132        root, shift, shift_ops,
133    };
134    use dashu_base::{SquareRoot, SquareRootRem};
135
136    impl<'a> TypedReprRef<'a> {
137        #[inline]
138        pub fn sqrt(self) -> Repr {
139            match self {
140                RefSmall(dw) => {
141                    if let Some(w) = shrink_dword(dw) {
142                        Repr::from_word(w.sqrt() as Word)
143                    } else {
144                        Repr::from_word(dw.sqrt())
145                    }
146                }
147                RefLarge(words) => sqrt_rem_large(words, true).0,
148            }
149        }
150
151        #[inline]
152        pub fn sqrt_rem(self) -> (Repr, Repr) {
153            match self {
154                RefSmall(dw) => {
155                    if let Some(w) = shrink_dword(dw) {
156                        let (s, r) = w.sqrt_rem();
157                        (Repr::from_word(s as Word), Repr::from_word(r))
158                    } else {
159                        let (s, r) = dw.sqrt_rem();
160                        (Repr::from_word(s), Repr::from_dword(r))
161                    }
162                }
163                RefLarge(words) => sqrt_rem_large(words, false),
164            }
165        }
166    }
167
168    fn sqrt_rem_large(words: &[Word], root_only: bool) -> (Repr, Repr) {
169        // first shift the words so that there are even words and
170        // the top word is normalized. Note: shift <= 2 * WORD_BITS - 2
171        let shift = WORD_BITS_USIZE * (words.len() & 1)
172            + (words.last().unwrap().leading_zeros() & !1) as usize;
173        let n = (words.len() + 1) / 2;
174        let mut buffer = shift_ops::repr::shl_large_ref(words, shift).into_buffer();
175        let mut out = Buffer::allocate(n);
176        out.push_zeros(n);
177
178        let mut allocation = MemoryAllocation::new(root::memory_requirement_sqrt_rem(n));
179        let r_top = root::sqrt_rem(&mut out, &mut buffer, &mut allocation.memory());
180
181        // afterwards, s = out[..], r = buffer[..n] + r_top << n*WORD_BITS
182        // then recover the result if shift != 0
183        if shift != 0 {
184            // to get the final result, let s0 = s mod 2^(shift/2), then
185            // 2^shift*n = (s-s0)^2 + 2s*s0 - s0^2 + r, so final r = (r + 2s*s0 - s0^2) / 2^shift
186            if !root_only {
187                let s0 = out[0] & ((1 << (shift / 2)) - 1);
188                let c1 = mul::add_mul_word_in_place(&mut buffer[..n], 2 * s0, &out);
189                let c2 =
190                    add::sub_dword_in_place(&mut buffer[..n], extend_word(s0) * extend_word(s0));
191                buffer[n] = r_top as Word + c1 - c2 as Word;
192            }
193
194            // s >>= shift/2, r >>= shift
195            let _ = shift::shr_in_place(&mut out, shift as u32 / 2);
196            if !root_only {
197                if shift > WORD_BITS_USIZE {
198                    shift::shr_in_place_one_word(&mut buffer);
199                    buffer.truncate(n);
200                } else {
201                    buffer.truncate(n + 1);
202                }
203                let _ = shift::shr_in_place(&mut buffer, shift as u32 % WORD_BITS);
204            }
205        } else if !root_only {
206            buffer[n] = r_top as Word;
207            buffer.truncate(n + 1);
208        }
209
210        (Repr::from_buffer(out), Repr::from_buffer(buffer))
211    }
212
213    impl<'a> TypedReprRef<'a> {
214        pub fn nth_root(self, n: usize) -> Repr {
215            match n {
216                0 => panic_root_zeroth(),
217                1 => return Repr::from_ref(self),
218                2 => return self.sqrt(),
219                _ => {}
220            }
221
222            // shortcut
223            let bits = self.bit_len();
224            if bits <= n {
225                // the result must be 1
226                return Repr::one();
227            }
228
229            // then use newton's method
230            let nm1 = n - 1;
231            let mut guess = UBig::ONE << (self.bit_len() / n); // underestimate
232            let next = |x: &UBig| {
233                let y = UBig(self / x.pow(nm1).into_repr());
234                (y + x * nm1) / n
235            };
236
237            let mut fixpoint = next(&guess);
238            // first go up then go down, to ensure an underestimate
239            while fixpoint > guess {
240                guess = fixpoint;
241                fixpoint = next(&guess);
242            }
243            while fixpoint < guess {
244                guess = fixpoint;
245                fixpoint = next(&guess);
246            }
247            guess.0
248        }
249    }
250}