1use dashu_base::{CubicRoot, CubicRootRem, Sign, SquareRoot, SquareRootRem};
2
3use crate::{
4 error::{panic_root_negative, panic_root_zeroth},
5 ibig::IBig,
6 ubig::UBig,
7};
8
9impl UBig {
10 #[inline]
25 pub fn nth_root(&self, n: usize) -> UBig {
26 UBig(self.repr().nth_root(n))
27 }
28}
29
30impl SquareRoot for UBig {
31 type Output = UBig;
32 #[inline]
33 fn sqrt(&self) -> Self::Output {
34 UBig(self.repr().sqrt())
35 }
36}
37
38impl SquareRootRem for UBig {
39 type Output = UBig;
40 #[inline]
41 fn sqrt_rem(&self) -> (Self, Self) {
42 let (s, r) = self.repr().sqrt_rem();
43 (UBig(s), UBig(r))
44 }
45}
46
47impl CubicRoot for UBig {
48 type Output = UBig;
49 #[inline]
50 fn cbrt(&self) -> Self::Output {
51 self.nth_root(3)
52 }
53}
54
55impl CubicRootRem for UBig {
56 type Output = UBig;
57 #[inline]
58 fn cbrt_rem(&self) -> (Self::Output, Self) {
59 let c = self.nth_root(3);
60 let r = self - c.pow(3);
61 (c, r)
62 }
63}
64
65impl IBig {
66 #[inline]
81 pub fn nth_root(&self, n: usize) -> IBig {
82 if n == 0 {
83 panic_root_zeroth()
84 }
85
86 let (sign, mag) = self.as_sign_repr();
87 if sign == Sign::Negative && n % 2 == 0 {
88 panic_root_negative()
89 }
90
91 IBig(mag.nth_root(n).with_sign(sign))
92 }
93}
94
95impl SquareRoot for IBig {
96 type Output = UBig;
97 #[inline]
98 fn sqrt(&self) -> UBig {
99 let (sign, mag) = self.as_sign_repr();
100 if sign == Sign::Negative {
101 panic_root_negative()
102 }
103 UBig(mag.sqrt())
104 }
105}
106
107impl CubicRoot for IBig {
108 type Output = IBig;
109 #[inline]
110 fn cbrt(&self) -> IBig {
111 let (sign, mag) = self.as_sign_repr();
112 if sign == Sign::Negative {
113 panic_root_negative()
114 }
115 IBig(mag.nth_root(3).with_sign(sign))
116 }
117}
118
119mod repr {
120 use super::*;
121 use crate::{
122 add,
123 arch::word::Word,
124 buffer::Buffer,
125 memory::MemoryAllocation,
126 mul,
127 primitive::{extend_word, shrink_dword, WORD_BITS, WORD_BITS_USIZE},
128 repr::{
129 Repr,
130 TypedReprRef::{self, *},
131 },
132 root, shift, shift_ops,
133 };
134 use dashu_base::{SquareRoot, SquareRootRem};
135
136 impl<'a> TypedReprRef<'a> {
137 #[inline]
138 pub fn sqrt(self) -> Repr {
139 match self {
140 RefSmall(dw) => {
141 if let Some(w) = shrink_dword(dw) {
142 Repr::from_word(w.sqrt() as Word)
143 } else {
144 Repr::from_word(dw.sqrt())
145 }
146 }
147 RefLarge(words) => sqrt_rem_large(words, true).0,
148 }
149 }
150
151 #[inline]
152 pub fn sqrt_rem(self) -> (Repr, Repr) {
153 match self {
154 RefSmall(dw) => {
155 if let Some(w) = shrink_dword(dw) {
156 let (s, r) = w.sqrt_rem();
157 (Repr::from_word(s as Word), Repr::from_word(r))
158 } else {
159 let (s, r) = dw.sqrt_rem();
160 (Repr::from_word(s), Repr::from_dword(r))
161 }
162 }
163 RefLarge(words) => sqrt_rem_large(words, false),
164 }
165 }
166 }
167
168 fn sqrt_rem_large(words: &[Word], root_only: bool) -> (Repr, Repr) {
169 let shift = WORD_BITS_USIZE * (words.len() & 1)
172 + (words.last().unwrap().leading_zeros() & !1) as usize;
173 let n = (words.len() + 1) / 2;
174 let mut buffer = shift_ops::repr::shl_large_ref(words, shift).into_buffer();
175 let mut out = Buffer::allocate(n);
176 out.push_zeros(n);
177
178 let mut allocation = MemoryAllocation::new(root::memory_requirement_sqrt_rem(n));
179 let r_top = root::sqrt_rem(&mut out, &mut buffer, &mut allocation.memory());
180
181 if shift != 0 {
184 if !root_only {
187 let s0 = out[0] & ((1 << (shift / 2)) - 1);
188 let c1 = mul::add_mul_word_in_place(&mut buffer[..n], 2 * s0, &out);
189 let c2 =
190 add::sub_dword_in_place(&mut buffer[..n], extend_word(s0) * extend_word(s0));
191 buffer[n] = r_top as Word + c1 - c2 as Word;
192 }
193
194 let _ = shift::shr_in_place(&mut out, shift as u32 / 2);
196 if !root_only {
197 if shift > WORD_BITS_USIZE {
198 shift::shr_in_place_one_word(&mut buffer);
199 buffer.truncate(n);
200 } else {
201 buffer.truncate(n + 1);
202 }
203 let _ = shift::shr_in_place(&mut buffer, shift as u32 % WORD_BITS);
204 }
205 } else if !root_only {
206 buffer[n] = r_top as Word;
207 buffer.truncate(n + 1);
208 }
209
210 (Repr::from_buffer(out), Repr::from_buffer(buffer))
211 }
212
213 impl<'a> TypedReprRef<'a> {
214 pub fn nth_root(self, n: usize) -> Repr {
215 match n {
216 0 => panic_root_zeroth(),
217 1 => return Repr::from_ref(self),
218 2 => return self.sqrt(),
219 _ => {}
220 }
221
222 let bits = self.bit_len();
224 if bits <= n {
225 return Repr::one();
227 }
228
229 let nm1 = n - 1;
231 let mut guess = UBig::ONE << (self.bit_len() / n); let next = |x: &UBig| {
233 let y = UBig(self / x.pow(nm1).into_repr());
234 (y + x * nm1) / n
235 };
236
237 let mut fixpoint = next(&guess);
238 while fixpoint > guess {
240 guess = fixpoint;
241 fixpoint = next(&guess);
242 }
243 while fixpoint < guess {
244 guess = fixpoint;
245 fixpoint = next(&guess);
246 }
247 guess.0
248 }
249 }
250}