aberth/lib.rs
1#![doc = include_str!("../README.md")]
2#![cfg_attr(not(any(feature = "std", test, doctest)), no_std)]
3
4pub mod internal;
5#[cfg(any(test, doctest))]
6mod tests;
7
8pub use ::num_complex::Complex;
9
10use {
11 crate::internal::*,
12 ::arrayvec::ArrayVec,
13 ::core::{fmt::Debug, ops::Deref},
14 ::num_traits::{
15 float::{Float, FloatConst},
16 identities::Zero,
17 MulAdd,
18 },
19};
20
21/// Find all of the roots of a polynomial using Aberth's method
22///
23/// Polynomial of the form `f(x) = a + b*x + c*x^2 + d*x^3 + ...`
24///
25/// `polynomial` is a slice containing the coefficients `[a, b, c, d, ...]`
26///
27/// When two successive iterations produce roots with less than `epsilon`
28/// delta, the roots are returned.
29pub fn aberth<
30 const TERMS: usize,
31 F: Float + FloatConst + MulAdd<Output = F> + Debug,
32 C: ComplexCoefficient<F> + Into<Complex<F>>,
33>(
34 polynomial: &[C; TERMS],
35 max_iterations: u32,
36 epsilon: F,
37) -> Roots<ArrayVec<Complex<F>, TERMS>> {
38 let degree = TERMS - 1;
39
40 let polynomial: &[Complex<F>; TERMS] = &polynomial.map(|v| v.into());
41
42 let mut dydx: ArrayVec<_, TERMS> = ArrayVec::new_const();
43 // SAFETY: we immediately populate every entry in dydx.
44 unsafe {
45 dydx.set_len(degree);
46 derivative::<F>(polynomial, dydx.as_mut());
47 }
48
49 let mut guesses: ArrayVec<_, TERMS> = ArrayVec::new_const();
50 // SAFETY: we immediately populate every entry in guesses.
51 unsafe {
52 guesses.set_len(TERMS);
53 initial_guesses(polynomial, guesses.as_mut());
54 guesses.set_len(degree);
55 }
56
57 let mut output: ArrayVec<_, TERMS> = ArrayVec::new_const();
58 // SAFETY: we push 1 less elements than there are terms.
59 unsafe {
60 for _ in 0..degree {
61 output.push_unchecked(Complex::zero());
62 }
63 }
64
65 let stop_reason = aberth_raw(
66 polynomial,
67 dydx.as_ref(),
68 guesses.as_mut(),
69 output.as_mut(),
70 max_iterations,
71 epsilon,
72 );
73
74 Roots {
75 roots: output,
76 stop_reason,
77 }
78}
79
80/// The roots of a polynomial
81///
82/// Dereferences to an array-slice containing `roots`.
83///
84/// `stop_reason` contains information for how the solver terminated and how
85/// many iterations it took.
86#[derive(Clone, Debug, PartialEq)]
87pub struct Roots<Arr> {
88 pub roots: Arr,
89 pub stop_reason: StopReason,
90}
91
92impl<Arr> Deref for Roots<Arr> {
93 type Target = Arr;
94
95 fn deref(&self) -> &Arr {
96 &self.roots
97 }
98}
99
/// Why the solver stopped, together with the iteration count it reached.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StopReason {
    /// the roots settled to within the requested precision
    Converged(/* iterations */ u32),
    /// the configured iteration limit was hit first
    MaxIteration(/* iterations */ u32),
    /// a NaN or Inf appeared during iteration
    Failed(/* iterations */ u32),
}
110
111#[cfg(feature = "std")]
112pub use feature_std::AberthSolver;
113
#[cfg(feature = "std")]
mod feature_std {
    use {
        crate::{internal::*, Roots, StopReason},
        ::core::fmt::Debug,
        ::num_complex::Complex,
        ::num_traits::{
            cast,
            float::{Float, FloatConst},
            identities::Zero,
            MulAdd,
        },
    };

    impl<F: Clone> Roots<&[Complex<F>]> {
        /// Create an owned duplicate of `Roots` by copying the borrowed roots
        /// into a newly allocated `Vec`.
        pub fn to_owned(&self) -> Roots<Vec<Complex<F>>> {
            Roots {
                roots: self.roots.to_vec(),
                stop_reason: self.stop_reason,
            }
        }
    }

    /// A solver for polynomials with Float or ComplexFloat coefficients. Will
    /// find all complex-roots, using the Aberth-Ehrlich method.
    ///
    /// The solver allocates some memory, and will reuse this allocation for
    /// subsequent calls. This is good to use for polynomials of varying lengths,
    /// polynomials with many terms, and for use in hot-loops where you want to
    /// avoid repeated allocations.
    ///
    /// Note the returned solutions are not sorted in any particular order.
    ///
    /// usage example:
    ///
    /// ```rust
    /// use aberth::AberthSolver;
    ///
    /// let mut solver = AberthSolver::new();
    /// solver.epsilon = 0.001;
    /// solver.max_iterations = 10;
    ///
    /// // 11x^4 + 4x^3 + 2x - 1 = 0
    /// let polynomial_a = [-1., 2., 0., 4., 11.];
    /// // x^4 -12x^3 + 39x^2 - 28 = 0
    /// let polynomial_b = [-28., 0., 39., -12., 1.];
    ///
    /// for polynomial in [polynomial_a, polynomial_b] {
    ///   let roots = solver.find_roots(&polynomial);
    ///   // ...
    /// }
    /// ```
    ///
    /// If you want to hold onto the roots you previously found while reusing the
    /// solver, then you can create an owned version:
    /// ```rust
    /// use aberth::AberthSolver;
    ///
    /// let mut solver = AberthSolver::new();
    /// let roots_a = solver.find_roots(&[-1., 2., 0., 4., 11.]).to_owned();
    /// let roots_b = solver.find_roots(&[-28., 39., -12., 1.]);
    /// roots_a[0];
    /// ```
    /// or alternatively just copy the `.roots` field into a vec
    /// ```rust
    /// use aberth::{AberthSolver, Complex};
    ///
    /// let mut solver = AberthSolver::new();
    /// let roots_a: Vec<Complex<f32>> =
    ///   solver.find_roots(&[-1., 2., 0., 4., 11.]).to_vec();
    /// let roots_b = solver.find_roots(&[-28., 39., -12., 1.]);
    /// roots_a[0];
    /// ```
    #[derive(Debug, Clone)]
    pub struct AberthSolver<F>
    where
        F: Float,
    {
        /// Maximum number of solver iterations per call to `find_roots`.
        pub max_iterations: u32,
        /// Convergence threshold passed to the solver on each call.
        pub epsilon: F,
        /// Scratch storage reused across calls; its layout is described in
        /// `find_roots`.
        data: Vec<Complex<F>>,
    }

    impl<F: Float + FloatConst + MulAdd<Output = F> + Debug> Default
        for AberthSolver<F>
    {
        fn default() -> Self {
            AberthSolver::new()
        }
    }

    impl<F: Float + FloatConst + MulAdd<Output = F> + Debug> AberthSolver<F> {
        /// Create a solver with `max_iterations = 100`, `epsilon = 0.001` and
        /// no scratch memory allocated yet.
        pub fn new() -> Self {
            AberthSolver {
                max_iterations: 100,
                data: Vec::new(),
                epsilon: cast(0.001).unwrap(),
            }
        }

        /// Find all the complex roots of the polynomial
        ///
        /// Polynomial is given in the form `f(x) = a + b*x + c*x^2 + d*x^3 + ...`
        ///
        /// `polynomial` is a slice containing the coefficients `[a, b, c, d, ...]`
        ///
        /// The returned roots borrow the solver's internal storage, so they must
        /// be dropped (or copied out, e.g. with `to_owned`) before the solver is
        /// used again. An empty `polynomial` has no roots: an empty result with
        /// `StopReason::Converged(0)` is returned without running the solver.
        pub fn find_roots<C: ComplexCoefficient<F>>(
            &mut self,
            polynomial: &[C],
        ) -> Roots<&[Complex<F>]> {
            let len = polynomial.len();
            // `len - 1` would underflow for an empty slice; nothing to solve.
            let degree = match len.checked_sub(1) {
                Some(degree) => degree,
                None => {
                    return Roots {
                        roots: &[],
                        stop_reason: StopReason::Converged(0),
                    }
                }
            };

            // Scratch layout, reused across calls:
            // [ polynomial (len) | dydx (degree) | guesses (len) | output (degree) ]
            self
                .data
                .resize_with(len + degree + len + degree, Complex::zero);
            let (complex_poly, tail) = self.data.split_at_mut(len);
            let (dydx, tail) = tail.split_at_mut(degree);
            let (guesses, output) = tail.split_at_mut(len);

            // convert the polynomial to a complex type
            complex_poly
                .iter_mut()
                .zip(polynomial)
                .for_each(|(slot, &coefficient)| *slot = coefficient.into());

            // `initial_guesses` fills `len` slots; the solver only uses the
            // first `degree` of them.
            initial_guesses(complex_poly, guesses);
            let guesses = &mut guesses[0..degree];
            derivative(complex_poly, dydx);

            let stop_reason = aberth_raw(
                complex_poly,
                dydx,
                guesses,
                output,
                self.max_iterations,
                self.epsilon,
            );

            Roots {
                roots: output,
                stop_reason,
            }
        }
    }
}