1use crate::{
2 util::{
3 complex::{c_from_f128, c_neg, c_sqrt, c_to_f128},
4 vec::slice_mean,
5 },
6 Poly, RealScalar,
7};
8use anyhow::anyhow;
9use f128::f128;
10use itertools::Itertools;
11use num::{Complex, FromPrimitive, One, Zero};
12
13mod single_root;
14pub use single_root::{halley, naive, newton};
15mod all_roots;
16pub use all_roots::{aberth_ehrlich, deflate, halley_deflate, naive_deflate, newton_deflate};
17mod many_roots;
18pub use many_roots::{halley_parallel, naive_parallel, newton_parallel, parallel};
19mod initial_guess;
20pub use initial_guess::{initial_guess_smallest, initial_guesses_circle};
21
22#[derive(thiserror::Error, Debug)]
23#[non_exhaustive]
24pub enum Error<T> {
25 #[error("root finder did not converge within the given constraints")]
26 NoConverge(T),
27
28 #[error("unexpected error while running root finder")]
29 Other(#[from] anyhow::Error),
30}
31
32pub type Result<T> = std::result::Result<Vec<Complex<T>>, Error<Vec<Complex<T>>>>;
35
36pub type Roots<T> = Vec<Complex<T>>;
38
/// How the roots found by the main root finder are refined afterwards.
///
/// Polishing is applied by `Poly::roots_expert` to the roots found by the
/// Aberth–Ehrlich iteration; closed-form results (degree ≤ 2) are returned as is.
#[derive(Debug, Clone, PartialEq)]
pub enum PolishingMode<T> {
    /// Return the roots without any refinement.
    None,
    /// Refine with Newton iterations in the working precision `T`.
    StandardPrecision {
        /// Convergence tolerance handed to the Newton iteration.
        epsilon: T,
        /// Minimum number of iterations (currently not used by `roots_expert`).
        min_iter: usize,
        /// Maximum number of Newton iterations.
        max_iter: usize,
    },
    /// Refine with Newton iterations after casting the polynomial and roots to
    /// `f128`; only available on `x86_64`.
    #[cfg(target_arch = "x86_64")]
    HighPrecision {
        /// Convergence tolerance (converted through `f64` to `f128`).
        epsilon: T,
        /// Minimum number of iterations (currently not used by `roots_expert`).
        min_iter: usize,
        /// Maximum number of Newton iterations.
        max_iter: usize,
    },
}
53
/// How nearly-coincident roots (likely multiple roots) are post-processed.
///
/// Roots are grouped when their squared distance to a group's running mean is at
/// most `detection_epsilon`.
///
/// - *Broadcast* variants return one root per group member, so the number of
///   roots is preserved.
/// - *Keep* variants return a single root per group.
#[derive(Debug, Clone, PartialEq)]
pub enum MultiplesHandlingMode<T> {
    /// Leave the roots untouched.
    None,
    /// Replace every member of a group with the member whose polynomial value has
    /// the smallest squared magnitude.
    BroadcastBest { detection_epsilon: T },
    /// Replace every member of a group with the group's arithmetic mean.
    BroadcastAverage { detection_epsilon: T },
    /// Keep only the member whose polynomial value has the smallest squared magnitude.
    KeepBest { detection_epsilon: T },
    /// Keep only the group's arithmetic mean.
    KeepAverage { detection_epsilon: T },
}
61
/// Where the initial guesses for the iterative root finder come from.
#[derive(Debug, Clone, PartialEq)]
pub enum InitialGuessMode<T> {
    /// Use only the caller-supplied guess pool.
    ///
    /// `Poly::roots_expert` fails with `Error::Other` when the pool holds fewer
    /// guesses than there are roots left to find.
    GuessPoolOnly,
    /// Complete the guess pool with guesses from `initial_guesses_circle`, which
    /// receives `bias`, `seed` and `perturbation` unchanged (see that function for
    /// their exact meaning).
    RandomAnnulus { bias: T, perturbation: T, seed: u64 },
}
68
69impl<T: RealScalar> Poly<T> {
70 pub fn roots(&self, epsilon: T, max_iter: usize) -> Result<T> {
80 self.roots_expert(
81 epsilon.clone(),
82 max_iter,
83 0,
84 PolishingMode::StandardPrecision {
85 epsilon: epsilon.clone(),
86 min_iter: 0,
87 max_iter,
88 },
89 MultiplesHandlingMode::BroadcastBest {
90 detection_epsilon: epsilon * T::from_f64(1.5).expect("overflow"),
92 },
93 &[],
94 InitialGuessMode::RandomAnnulus {
95 bias: T::from_f64(0.5).expect("overflow"),
96 perturbation: T::from_f64(0.5).expect("overflow"),
97 seed: 1,
98 },
99 )
100 }
101
102 pub fn roots_expert(
114 &self,
115 epsilon: T,
116 max_iter: usize,
117 _min_iter: usize,
118 polishing_mode: PolishingMode<T>,
119 multiples_handling_mode: MultiplesHandlingMode<T>,
120 initial_guess_pool: &[Complex<T>],
121 initial_guess_mode: InitialGuessMode<T>,
122 ) -> Result<T> {
123 debug_assert!(self.is_normalized());
124
125 let mut this = self.clone();
126
127 let mut roots: Vec<Complex<T>> = this.zero_roots(epsilon.clone());
128
129 match this.degree_raw() {
130 1 => {
131 roots.extend(this.linear_roots());
132 return Ok(roots);
133 }
134 2 => {
135 roots.extend(this.quadratic_roots());
136 return Ok(roots);
137 }
138 _ => {}
139 }
140
141 this.make_monic();
142
143 debug_assert!(this.is_normalized());
144 let mut initial_guesses = Vec::with_capacity(this.degree_raw());
145 for guess in initial_guess_pool.iter().cloned() {
146 initial_guesses.push(guess);
147 }
148
149 let delta = this.degree_raw() - initial_guesses.len();
152 for _ in 0..delta {
153 initial_guesses.push(Complex::<T>::zero());
154 }
155 let remaining_guesses_view =
156 &mut initial_guesses[initial_guess_pool.len()..this.degree_raw()];
157
158 match initial_guess_mode {
159 InitialGuessMode::GuessPoolOnly => {
160 if initial_guess_pool.len() < this.degree_raw() {
161 return Err(Error::Other(anyhow!("not enough initial guesses, you must provide one guess per root when using GuessPoolOnly")));
162 }
163 }
164 InitialGuessMode::RandomAnnulus {
165 bias,
166 perturbation,
167 seed,
168 } => {
169 initial_guesses_circle(&this, bias, seed, perturbation, remaining_guesses_view);
170 } }
173
174 log::trace!("{initial_guesses:?}");
175
176 roots.extend(aberth_ehrlich(
177 &mut this,
178 Some(epsilon.clone()),
179 Some(max_iter),
180 &initial_guesses,
181 )?);
182
183 let roots: Roots<T> = match polishing_mode {
185 PolishingMode::None => Ok(roots),
186 PolishingMode::StandardPrecision {
187 epsilon,
188 min_iter,
189 max_iter,
190 } => newton_parallel(&mut this, Some(epsilon), Some(max_iter), &roots),
191
192 #[cfg(target_arch = "x86_64")]
193 PolishingMode::HighPrecision {
194 epsilon,
195 min_iter,
196 max_iter,
197 } => {
198 let mut this = this.clone().cast_to_f128();
199 let roots = roots.iter().cloned().map(|z| c_to_f128(z)).collect_vec();
200 newton_parallel(
201 &mut this,
202 Some(f128::from(epsilon.to_f64().expect("overflow"))),
203 Some(max_iter),
204 &roots,
205 )
206 .map(|v| v.into_iter().map(|z| c_from_f128::<T>(z)).collect_vec())
207 .map_err(|e| match e {
208 Error::NoConverge(v) => {
209 Error::NoConverge(v.into_iter().map(|z| c_from_f128::<T>(z)).collect_vec())
210 }
211 Error::Other(o) => Error::Other(o),
212 })
213 }
214 }?;
215
216 match multiples_handling_mode {
217 MultiplesHandlingMode::None => Ok(roots),
218 MultiplesHandlingMode::BroadcastBest { detection_epsilon } => Ok(best_multiples(
219 &this,
220 group_multiples(roots, detection_epsilon),
221 true,
222 )),
223 MultiplesHandlingMode::BroadcastAverage { detection_epsilon } => Ok(average_multiples(
224 &this,
225 group_multiples(roots, detection_epsilon),
226 true,
227 )),
228 MultiplesHandlingMode::KeepBest { detection_epsilon } => Ok(best_multiples(
229 &this,
230 group_multiples(roots, detection_epsilon),
231 false,
232 )),
233 MultiplesHandlingMode::KeepAverage { detection_epsilon } => Ok(average_multiples(
234 &this,
235 group_multiples(roots, detection_epsilon),
236 false,
237 )),
238 }
239 }
240}
241
242impl<T: RealScalar> Poly<T> {
244 fn zero_roots(&mut self, epsilon: T) -> Vec<Complex<T>> {
245 debug_assert!(self.is_normalized());
246
247 let mut roots = vec![];
248 for _ in 0..self.degree_raw() {
249 if self.eval(Complex::zero()).norm_sqr() < epsilon {
250 roots.push(Complex::zero());
251 *self = self.shift_down(1);
253 } else {
254 break;
255 }
256 }
257
258 roots
259 }
260
261 fn linear_roots(&mut self) -> Vec<Complex<T>> {
262 debug_assert!(self.is_normalized());
263 debug_assert_eq!(self.degree_raw(), 1);
264
265 self.trim();
266 if self.degree_raw() < 1 {
267 return vec![];
268 }
269
270 let a = self.0[1].clone();
271 let b = self.0[0].clone();
272
273 *self = Self::one();
275
276 vec![-b / a]
277 }
278
279 fn quadratic_roots(&mut self) -> Vec<Complex<T>> {
281 debug_assert!(self.is_normalized());
282 debug_assert_eq!(self.degree_raw(), 2);
283
284 self.trim();
286 if self.degree_raw() == 1 {
287 return self.linear_roots();
288 }
289 if self.degree_raw() == 0 {
290 return vec![];
291 }
292
293 let a = self.0[2].clone();
294 let b = self.0[1].clone();
295 let c = self.0[0].clone();
296 let four = Complex::<T>::from_u8(4).expect("overflow");
297 let two = Complex::<T>::from_u8(2).expect("overflow");
298
299 let plus_minus_term = c_sqrt(b.clone() * b.clone() - four * a.clone() * c);
302 let x1 = (plus_minus_term.clone() - b.clone()) / (two.clone() * a.clone());
303 let x2 = (c_neg(b.clone()) - plus_minus_term) / (two * a);
304
305 *self = Self::one();
307
308 vec![x1, x2]
309 }
310}
311
312fn group_multiples<T: RealScalar>(roots: Roots<T>, epsilon: T) -> Vec<Roots<T>> {
314 let mut groups: Vec<(Roots<T>, Complex<T>)> = vec![];
316
317 let mut roots = roots;
318
319 while !roots.is_empty() {
320 'roots_loop: for root in roots.drain(..) {
324 for group in &mut groups {
325 if (group.1.clone() - root.clone()).norm_sqr() <= epsilon {
326 group.0.push(root.clone());
327 group.1 = slice_mean(&group.0);
328 continue 'roots_loop;
329 }
330 }
331 groups.push((vec![root.clone()], root));
332 }
333
334 for group in &mut groups {
337 group.0.retain(|r| {
340 if (r.clone() - group.1.clone()).norm_sqr() <= epsilon {
341 true
342 } else {
343 roots.push(r.clone());
344 false
345 }
346 });
347 }
348
349 groups.retain(|g| !g.0.is_empty());
351 }
352
353 groups.into_iter().map(|(r, _)| r).collect_vec()
354}
355
356fn best_multiples<T: RealScalar>(
357 poly: &Poly<T>,
358 groups: Vec<Roots<T>>,
359 do_broadcast: bool,
360) -> Roots<T> {
361 groups
363 .into_iter()
364 .flat_map(|group| {
365 let len = group.len();
366 let best = group
367 .into_iter()
368 .map(|root| (root.clone(), poly.eval(root).norm_sqr()))
369 .reduce(|(a_root, a_eval), (b_root, b_eval)| {
370 if a_eval < b_eval {
371 (a_root, a_eval)
372 } else {
373 (b_root, b_eval)
374 }
375 })
376 .expect("empty groups not allowed")
377 .0;
378 if do_broadcast {
379 vec![best; len]
380 } else {
381 vec![best]
382 }
383 })
384 .collect_vec()
385}
386
387fn average_multiples<T: RealScalar>(
388 poly: &Poly<T>,
389 groups: Vec<Roots<T>>,
390 do_broadcast: bool,
391) -> Roots<T> {
392 groups
393 .into_iter()
394 .flat_map(|group| {
395 let len_usize = group.len();
396 debug_assert!(len_usize > 0);
397 let len = T::from_usize(len_usize).expect("infallible");
398 let sum: Complex<T> = group.into_iter().sum();
399 let avg = sum / len;
400 if do_broadcast {
401 vec![avg; len_usize]
402 } else {
403 vec![avg]
404 }
405 })
406 .collect_vec()
407}
408
#[cfg(test)]
mod test {
    use num::complex::ComplexFloat;

    use crate::Poly64;

    /// The degree-2 reverse Bessel polynomial is expected to have two roots with
    /// real part ≈ -1.5 and |imaginary part| ≈ 0.866. Only `|im|` is checked, so
    /// the sign of each imaginary part does not matter.
    #[test]
    fn roots_of_reverse_bessel() {
        let found = Poly64::reverse_bessel(2)
            .unwrap()
            .roots(1E-10, 1000)
            .unwrap();
        for &z in &found[..2] {
            assert!((z.re() - -1.5).abs() < 0.01);
            assert!((z.im().abs() - 0.866).abs() < 0.01);
        }
    }
}