aberth/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg_attr(not(any(feature = "std", test, doctest)), no_std)]
3
4pub mod internal;
5#[cfg(any(test, doctest))]
6mod tests;
7
8pub use ::num_complex::Complex;
9
10use {
11  crate::internal::*,
12  ::arrayvec::ArrayVec,
13  ::core::{fmt::Debug, ops::Deref},
14  ::num_traits::{
15    float::{Float, FloatConst},
16    identities::Zero,
17    MulAdd,
18  },
19};
20
21/// Find all of the roots of a polynomial using Aberth's method
22///
23/// Polynomial of the form `f(x) = a + b*x + c*x^2 + d*x^3 + ...`
24///
25/// `polynomial` is a slice containing the coefficients `[a, b, c, d, ...]`
26///
27/// When two successive iterations produce roots with less than `epsilon`
28/// delta, the roots are returned.
29pub fn aberth<
30  const TERMS: usize,
31  F: Float + FloatConst + MulAdd<Output = F> + Debug,
32  C: ComplexCoefficient<F> + Into<Complex<F>>,
33>(
34  polynomial: &[C; TERMS],
35  max_iterations: u32,
36  epsilon: F,
37) -> Roots<ArrayVec<Complex<F>, TERMS>> {
38  let degree = TERMS - 1;
39
40  let polynomial: &[Complex<F>; TERMS] = &polynomial.map(|v| v.into());
41
42  let mut dydx: ArrayVec<_, TERMS> = ArrayVec::new_const();
43  // SAFETY: we immediately populate every entry in dydx.
44  unsafe {
45    dydx.set_len(degree);
46    derivative::<F>(polynomial, dydx.as_mut());
47  }
48
49  let mut guesses: ArrayVec<_, TERMS> = ArrayVec::new_const();
50  // SAFETY: we immediately populate every entry in guesses.
51  unsafe {
52    guesses.set_len(TERMS);
53    initial_guesses(polynomial, guesses.as_mut());
54    guesses.set_len(degree);
55  }
56
57  let mut output: ArrayVec<_, TERMS> = ArrayVec::new_const();
58  // SAFETY: we push 1 less elements than there are terms.
59  unsafe {
60    for _ in 0..degree {
61      output.push_unchecked(Complex::zero());
62    }
63  }
64
65  let stop_reason = aberth_raw(
66    polynomial,
67    dydx.as_ref(),
68    guesses.as_mut(),
69    output.as_mut(),
70    max_iterations,
71    epsilon,
72  );
73
74  Roots {
75    roots: output,
76    stop_reason,
77  }
78}
79
80/// The roots of a polynomial
81///
82/// Dereferences to an array-slice containing `roots`.
83///
84/// `stop_reason` contains information for how the solver terminated and how
85/// many iterations it took.
86#[derive(Clone, Debug, PartialEq)]
87pub struct Roots<Arr> {
88  pub roots: Arr,
89  pub stop_reason: StopReason,
90}
91
92impl<Arr> Deref for Roots<Arr> {
93  type Target = Arr;
94
95  fn deref(&self) -> &Arr {
96    &self.roots
97  }
98}
99
/// Why the solver stopped; every variant carries the number of iterations
/// that were performed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StopReason {
  /// The roots converged to within the required precision; holds the
  /// iteration count.
  Converged(u32),
  /// The iteration limit was reached; holds the iteration count.
  MaxIteration(u32),
  /// A NaN or Inf was detected while iterating; holds the iteration count.
  Failed(u32),
}
110
111#[cfg(feature = "std")]
112pub use feature_std::AberthSolver;
113
114#[cfg(feature = "std")]
115mod feature_std {
116  use {
117    crate::{internal::*, Roots},
118    ::core::fmt::Debug,
119    ::num_complex::Complex,
120    ::num_traits::{
121      cast,
122      float::{Float, FloatConst},
123      identities::Zero,
124      MulAdd,
125    },
126  };
127
128  impl<F: Clone> Roots<&[Complex<F>]> {
129    /// Create an owned duplicate of `Roots` by allocating a `Vec` to hold the
130    /// values
131    pub fn to_owned(&self) -> Roots<Vec<Complex<F>>> {
132      Roots {
133        roots: self.roots.to_vec(),
134        stop_reason: self.stop_reason,
135      }
136    }
137  }
138
139  /// A solver for polynomials with Float or ComplexFloat coefficients. Will
140  /// find all complex-roots, using the Aberth-Ehrlich method.
141  ///
142  /// The solver allocates some memory, and will reuse this allocation for
143  /// subsequent calls. This is good to use for polynomials of varying lengths,
144  /// polynomials with many terms, and for use in hot-loops where you want to
145  /// avoid repeated allocations.
146  ///
147  /// Note the returned solutions are not sorted in any particular order.
148  ///
149  /// usage example:
150  ///
151  /// ```rust
152  /// use aberth::AberthSolver;
153  ///
154  /// let mut solver = AberthSolver::new();
155  /// solver.epsilon = 0.001;
156  /// solver.max_iterations = 10;
157  ///
158  /// // 11x^4 + 4x^3 + 2x - 1 = 0
159  /// let polynomial_a = [-1., 2., 0., 4., 11.];
160  /// // x^4 -12x^3 + 39x^2 - 28 = 0
161  /// let polynomial_b = [-28., 0., 39., -12., 1.];
162  ///
163  /// for polynomial in [polynomial_a, polynomial_b] {
164  ///   let roots = solver.find_roots(&polynomial);
165  ///   // ...
166  /// }
167  /// ```
168  ///
169  /// If you want to hold onto the roots you previously found while reusing the
170  /// solver, then you can create an owned version:
171  /// ```rust
172  /// use aberth::AberthSolver;
173  ///
174  /// let mut solver = AberthSolver::new();
175  /// let roots_a = solver.find_roots(&[-1., 2., 0., 4., 11.]).to_owned();
176  /// let roots_b = solver.find_roots(&[-28., 39., -12., 1.]);
177  /// roots_a[0];
178  /// ```
179  /// or alternatively just copy the `.roots` field into a vec
180  /// ```rust
181  /// use aberth::{AberthSolver, Complex};
182  ///
183  /// let mut solver = AberthSolver::new();
184  /// let roots_a: Vec<Complex<f32>> =
185  ///   solver.find_roots(&[-1., 2., 0., 4., 11.]).to_vec();
186  /// let roots_b = solver.find_roots(&[-28., 39., -12., 1.]);
187  /// roots_a[0];
188  /// ```
189  #[derive(Debug, Clone)]
190  pub struct AberthSolver<F>
191  where
192    F: Float,
193  {
194    pub max_iterations: u32,
195    pub epsilon: F,
196    data: Vec<Complex<F>>,
197  }
198
199  impl<F: Float + FloatConst + MulAdd<Output = F> + Default + Debug> Default
200    for AberthSolver<F>
201  {
202    fn default() -> Self {
203      AberthSolver::new()
204    }
205  }
206
207  impl<F: Float + FloatConst + MulAdd<Output = F> + Default + Debug>
208    AberthSolver<F>
209  {
210    pub fn new() -> Self {
211      AberthSolver {
212        max_iterations: 100,
213        data: Vec::new(),
214        epsilon: cast(0.001).unwrap(),
215      }
216    }
217
218    /// Find all the complex roots of the polynomial
219    ///
220    /// Polynomial is given in the form `f(x) = a + b*x + c*x^2 + d*x^3 + ...`
221    ///
222    /// `polynomial` is a slice containing the coefficients `[a, b, c, d, ...]`
223    pub fn find_roots<C: ComplexCoefficient<F>>(
224      &mut self,
225      polynomial: &[C],
226    ) -> Roots<&[Complex<F>]> {
227      let len = polynomial.len();
228      let degree = len - 1;
229      // ensure we have enough space allocated
230      self
231        .data
232        .resize_with(len + degree + len + degree, Complex::zero);
233      // get mutable slices to our data
234      let (complex_poly, tail) = self.data.split_at_mut(len);
235      let (dydx, tail) = tail.split_at_mut(degree);
236      let (guesses, output) = tail.split_at_mut(len);
237
238      // convert the polynomial to a complex type
239      polynomial
240        .iter()
241        .enumerate()
242        .for_each(|(i, &coefficient)| complex_poly[i] = coefficient.into());
243
244      initial_guesses(complex_poly, guesses);
245      let guesses = &mut guesses[0..degree];
246      derivative(complex_poly, dydx);
247
248      let stop_reason = aberth_raw(
249        complex_poly,
250        dydx,
251        guesses,
252        output,
253        self.max_iterations,
254        self.epsilon,
255      );
256
257      Roots {
258        roots: output,
259        stop_reason,
260      }
261    }
262  }
263}