polycool/
poly.rs

1// Copyright 2025 the Kurbo Authors
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4use arrayvec::ArrayVec;
5
/// A polynomial whose degree is known at compile-time.
///
/// `N` is the number of coefficients, so a `Poly<N>` has degree at most
/// `N - 1`. The coefficient of `x^i` is stored at index `i`.
///
/// Although this supports polynomials of arbitrary degree, it is intended
/// for low-degree polynomials. For example, the coefficients are stored
/// in an array, so they are stack-allocated (unless you `Box` the `Poly`,
/// of course) and tend to be copied around.
///
/// The derived `PartialOrd` compares the coefficient arrays lexicographically,
/// starting from the constant coefficient; it is not an ordering of the
/// polynomials as functions.
///
/// Polynomial multiplication is not yet implemented, because doing it "nicely"
/// would require const generic expressions: ideally we'd do something like
///
/// ```ignore
/// impl<N, M> Mul<Poly<M>> for Poly<N> {
///     type Output = Poly<{M + N - 1}>;
/// }
/// ```
///
/// It's possible to work around this with macros, but there are lots of
/// possibilities and I didn't feel like it was worth the trouble (and the hit
/// to compilation time).
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct Poly<const N: usize> {
    /// The coefficients, constant term first: `coeffs[i]` multiplies `x^i`.
    pub(crate) coeffs: [f64; N],
}
29
30/// A polynomial of degree 2.
31pub type Quadratic = Poly<3>;
32
33/// A polynomial of degree 3.
34pub type Cubic = Poly<4>;
35
36/// A polynomial of degree 4.
37pub type Quartic = Poly<5>;
38
39/// A polynomial of degree 5.
40pub type Quintic = Poly<6>;
41
42impl<const N: usize> Poly<N> {
43    /// Creates a new polynomial with the provided coefficients.
44    ///
45    /// The constant coefficient comes first, then the linear coefficient, and
46    /// so on. So if you pass `[c, b, a]` you'll get the polynomial
47    /// `a x^2 + b x + c`.
48    pub const fn new(coeffs: [f64; N]) -> Poly<N> {
49        Poly { coeffs }
50    }
51
52    /// The coefficients of this polynomial.
53    ///
54    /// In the returned array, the coefficient of `x^i` is at index `i`.
55    pub fn coeffs(&self) -> &[f64; N] {
56        &self.coeffs
57    }
58
59    /// Evaluates this polynomial at a point.
60    pub fn eval(&self, x: f64) -> f64 {
61        let mut acc = 0.0;
62        for c in self.coeffs.iter().rev() {
63            // It would be nice to use `f64::mul_add` here, but it's slow on
64            // architectures that don't have a dedicated instruction.
65            acc = acc * x + c;
66        }
67        acc
68    }
69
70    /// Returns the largest absolute value of any coefficient.
71    ///
72    /// Always returns a non-negative number, or NaN if some coefficient is NaN.
73    pub fn max_abs_coefficient(&self) -> f64 {
74        let mut max = 0.0f64;
75        for c in &self.coeffs {
76            max = max.max(c.abs());
77        }
78        max
79    }
80
81    /// Are all the coefficients finite?
82    pub fn is_finite(&self) -> bool {
83        self.coeffs.iter().all(|c| c.is_finite())
84    }
85}
86
// Generates `deriv` and `deflate` for `Poly<$N>`, both of which return a
// polynomial with one coefficient fewer (`Poly<$N_MINUS_ONE>`).
macro_rules! impl_deriv_and_deflate {
    ($N:literal, $N_MINUS_ONE:literal) => {
        impl Poly<$N> {
            /// Returns the derivative of this polynomial.
            ///
            /// Differentiating lowers the degree by one, so the result has one
            /// coefficient fewer than `self`: its coefficient of `x^i` is
            /// `i + 1` times the coefficient of `x^(i + 1)` in `self`.
            pub fn deriv(&self) -> Poly<$N_MINUS_ONE> {
                Poly::new(core::array::from_fn(|i| (i + 1) as f64 * self.coeffs[i + 1]))
            }

            /// Divides this polynomial by `x - root` with synthetic division,
            /// returning the quotient and discarding the remainder.
            ///
            /// The quotient has one coefficient fewer than `self`. The discarded
            /// remainder equals `self.eval(root)`, so it is zero exactly when
            /// `root` is a root of `self`. That is the intended use, but it is
            /// not required.
            pub fn deflate(&self, root: f64) -> Poly<$N_MINUS_ONE> {
                let mut quotient = [0.0; $N_MINUS_ONE];
                let mut carry = 0.0;
                // Walk from the leading coefficient down to the linear one.
                for k in (0..$N_MINUS_ONE).rev() {
                    carry = carry * root + self.coeffs[k + 1];
                    quotient[k] = carry;
                }
                Poly::new(quotient)
            }
        }
    };
}
120
// Generates a recursive `roots_between` for `Poly<$N>`.
//
// The derivative is a `Poly<$N_MINUS_ONE>`, which must already have its own
// `roots_between_with_buffer`. Its roots (the critical points of `self`) split
// the search interval into pieces on which `self` is monotonic (up to rounding).
// Each piece therefore brackets at most one sign change, and every sign change
// found is refined with `yuksel::find_root`.
macro_rules! impl_roots_between_recursive {
    ($N:literal, $N_MINUS_ONE:literal) => {
        impl Poly<$N> {
            /// Computes all roots between `lower` and `upper`, to the desired accuracy.
            ///
            /// We make no guarantees about multiplicity. For example, if there's a
            /// double-root that isn't a triple-root (and therefore has no sign change
            /// nearby) then there's a good chance we miss it altogether. This is
            /// fine if you're using this root-finding to find critical points for
            /// optimizing a polynomial, because roots that don't come with a sign
            /// change aren't local extrema.
            ///
            /// `x_error` is the accuracy passed to the root refiner for each
            /// bracketed root. A polynomial of degree `d` has at most `d` real
            /// roots, which is why the returned `ArrayVec` has one slot fewer than
            /// the number of coefficients. Roots are reported in the order their
            /// bracketing intervals are scanned, from `lower` towards `upper`.
            pub fn roots_between(
                self,
                lower: f64,
                upper: f64,
                x_error: f64,
            ) -> ArrayVec<f64, $N_MINUS_ONE> {
                let mut ret = ArrayVec::new();
                let mut scratch = ArrayVec::new();
                self.roots_between_with_buffer(lower, upper, x_error, &mut ret, &mut scratch);
                ret
            }

            // Writes the sign-changing roots of `self` in `[lower, upper]` into
            // `out`, using `scratch` as working space. The two buffers swap roles
            // at each level of the recursion.
            //
            // This would ideally have a `where M >= N - 1` bound on it,
            // but it's private so it shouldn't matter too much.
            // We assume that `scratch` and `out` are both empty.
            fn roots_between_with_buffer<const M: usize>(
                self,
                lower: f64,
                upper: f64,
                x_error: f64,
                out: &mut ArrayVec<f64, M>,
                scratch: &mut ArrayVec<f64, M>,
            ) {
                let deriv = self.deriv();
                // Report no roots (leave `out` empty) if differentiation produced
                // non-finite coefficients.
                if !deriv.is_finite() {
                    return;
                }
                // Recurse: the derivative's roots land in `scratch`, and `out` is
                // lent to the recursive call as its working space. `out` is then
                // cleared so it can collect this polynomial's own roots.
                deriv.roots_between_with_buffer(lower, upper, x_error, scratch, out);
                scratch.push(upper);
                out.clear();
                let mut last = lower;
                let mut last_val = self.eval(last);

                // `scratch` now contains all the critical points (in increasing order)
                // and the upper endpoint of the interval. These are the endpoints
                // of the potential bracketing intervals of our polynomial.
                for &mut x in scratch {
                    let val = self.eval(x);
                    // `different_signs` decides whether `[last, x]` brackets a root.
                    // If it does, refine it with Yuksel's method, which is given the
                    // derivative as well as the function.
                    if $crate::different_signs(last_val, val) {
                        out.push($crate::yuksel::find_root(
                            |x| self.eval(x),
                            |x| deriv.eval(x),
                            last,
                            x,
                            last_val,
                            val,
                            x_error,
                        ));
                    }

                    last = x;
                    last_val = val;
                }
            }
        }
    };
}
189
190impl_deriv_and_deflate!(3, 2);
191impl_deriv_and_deflate!(4, 3);
192impl_deriv_and_deflate!(5, 4);
193impl_deriv_and_deflate!(6, 5);
194impl_deriv_and_deflate!(7, 6);
195impl_deriv_and_deflate!(8, 7);
196impl_deriv_and_deflate!(9, 8);
197impl_deriv_and_deflate!(10, 9);
198
199impl_roots_between_recursive!(5, 4);
200impl_roots_between_recursive!(6, 5);
201impl_roots_between_recursive!(7, 6);
202impl_roots_between_recursive!(8, 7);
203impl_roots_between_recursive!(9, 8);
204impl_roots_between_recursive!(10, 9);
205
206impl<const N: usize> core::ops::Mul<f64> for Poly<N> {
207    type Output = Poly<N>;
208
209    fn mul(mut self, scale: f64) -> Poly<N> {
210        self *= scale;
211        self
212    }
213}
214
215impl<const N: usize> core::ops::MulAssign<f64> for Poly<N> {
216    fn mul_assign(&mut self, scale: f64) {
217        for c in &mut self.coeffs {
218            *c *= scale;
219        }
220    }
221}
222
223impl<const N: usize> core::ops::Mul<f64> for &Poly<N> {
224    type Output = Poly<N>;
225
226    fn mul(self, scale: f64) -> Poly<N> {
227        (*self) * scale
228    }
229}
230
231impl<const N: usize> core::ops::Div<f64> for Poly<N> {
232    type Output = Poly<N>;
233
234    fn div(mut self, scale: f64) -> Poly<N> {
235        self /= scale;
236        self
237    }
238}
239
240impl<const N: usize> core::ops::DivAssign<f64> for Poly<N> {
241    fn div_assign(&mut self, scale: f64) {
242        for c in &mut self.coeffs {
243            *c /= scale;
244        }
245    }
246}
247
248impl<const N: usize> core::ops::Div<f64> for &Poly<N> {
249    type Output = Poly<N>;
250
251    fn div(self, scale: f64) -> Poly<N> {
252        (*self) / scale
253    }
254}
255
256impl<const N: usize> core::ops::AddAssign<&Poly<N>> for Poly<N> {
257    fn add_assign(&mut self, rhs: &Poly<N>) {
258        for (c, d) in self.coeffs.iter_mut().zip(rhs.coeffs) {
259            *c += d;
260        }
261    }
262}
263
264impl<const N: usize> core::ops::AddAssign<Poly<N>> for Poly<N> {
265    fn add_assign(&mut self, rhs: Poly<N>) {
266        *self += &rhs;
267    }
268}
269
270impl<const N: usize> core::ops::Add<Poly<N>> for Poly<N> {
271    type Output = Poly<N>;
272
273    fn add(mut self, rhs: Poly<N>) -> Poly<N> {
274        self += rhs;
275        self
276    }
277}
278
279impl<const N: usize> core::ops::Add<&Poly<N>> for Poly<N> {
280    type Output = Poly<N>;
281
282    fn add(mut self, rhs: &Poly<N>) -> Poly<N> {
283        self += rhs;
284        self
285    }
286}
287
288impl<const N: usize> core::ops::Add<Poly<N>> for &Poly<N> {
289    type Output = Poly<N>;
290
291    fn add(self, mut rhs: Poly<N>) -> Poly<N> {
292        rhs += self;
293        rhs
294    }
295}
296
297impl<const N: usize> core::ops::SubAssign<&Poly<N>> for Poly<N> {
298    fn sub_assign(&mut self, rhs: &Poly<N>) {
299        for (c, d) in self.coeffs.iter_mut().zip(rhs.coeffs) {
300            *c -= d;
301        }
302    }
303}
304
305impl<const N: usize> core::ops::SubAssign<Poly<N>> for Poly<N> {
306    fn sub_assign(&mut self, rhs: Poly<N>) {
307        *self -= &rhs;
308    }
309}
310
311impl<const N: usize> core::ops::Sub<Poly<N>> for Poly<N> {
312    type Output = Poly<N>;
313
314    fn sub(mut self, rhs: Poly<N>) -> Poly<N> {
315        self -= rhs;
316        self
317    }
318}
319
320impl<const N: usize> core::ops::Sub<&Poly<N>> for Poly<N> {
321    type Output = Poly<N>;
322
323    fn sub(mut self, rhs: &Poly<N>) -> Poly<N> {
324        self -= rhs;
325        self
326    }
327}
328
329impl<const N: usize> core::ops::Sub<Poly<N>> for &Poly<N> {
330    type Output = Poly<N>;
331
332    fn sub(self, mut rhs: Poly<N>) -> Poly<N> {
333        rhs -= self;
334        rhs
335    }
336}
337
// We do property-testing with two strategies:
//
// - for the "value-testing" strategy, we test that the polynomial
//   approximately evaluates to zero on all the claimed roots.
// - for the "planted root" strategy, we generate a polynomial with
//   a known root and check that we find it.
//
// There is also a deterministic `smoke` test on polynomials whose roots are
// known integers.
#[cfg(test)]
mod tests {
    use super::*;

    // The two polynomials are (x - 1)(x - 2)(x - 3) and
    // (x - 1)(x - 2)(x - 3)(x - 4), so every root is known exactly.
    #[test]
    fn smoke() {
        let p = Poly::new([-6.0, 11.0, -6.0, 1.0]);

        let roots = p.roots_between(0.0, 5.0, 1e-6);
        assert_eq!(roots.len(), 3);
        assert!((roots[0] - 1.0).abs() <= 1e-6);
        assert!((roots[1] - 2.0).abs() <= 1e-6);
        assert!((roots[2] - 3.0).abs() <= 1e-6);

        let p = Poly::new([24.0, -50.0, 35.0, -10.0, 1.0]);

        let roots = p.roots_between(0.0, 5.0, 1e-6);
        assert_eq!(roots.len(), 4);
        assert!((roots[0] - 1.0).abs() <= 1e-6);
        assert!((roots[1] - 2.0).abs() <= 1e-6);
        assert!((roots[2] - 3.0).abs() <= 1e-6);
        assert!((roots[3] - 4.0).abs() <= 1e-6);
    }

    // Asserts that the supplied "roots" are close to being roots of the
    // polynomial, in the sense that the polynomial evaluates to approximately
    // zero on each of them.
    fn check_root_values<const N: usize>(p: &Poly<N>, roots: &[f64]) {
        // Arbitrary polynomials can have coefficients with wild magnitudes,
        // so we need to adjust our error expectations accordingly.
        let magnitude = p.max_abs_coefficient().max(1.0);
        let accuracy = magnitude * 1e-12;

        for r in roots {
            // We can't expect great accuracy for very large roots,
            // because the polynomial evaluation will involve very
            // large terms (powers of `r` up to `r^(N - 1)`).
            let accuracy = accuracy * r.abs().powi(N as i32 - 1).max(1.0);
            let y = p.eval(*r);
            assert!(
                y.abs() <= accuracy,
                "poly {p:?} had root {r} evaluate to {y:?}, but expected {accuracy:?}"
            );
        }
    }

    #[test]
    fn root_evaluation_deg3() {
        arbtest::arbtest(|u| {
            let poly: Poly<4> = crate::arbitrary::poly(u)?;
            // Ignore very large polynomials, because they'll just overflow everything.
            if (poly.max_abs_coefficient() * 10.0f64.powi(4)).is_infinite() {
                return Err(arbitrary::Error::IncorrectFormat);
            }
            let roots = poly.roots_between(-10.0, 10.0, 1e-13);

            check_root_values(&poly, &roots);
            Ok(())
        })
        .budget_ms(5_000);
    }

    #[test]
    fn root_evaluation_deg4() {
        arbtest::arbtest(|u| {
            let poly: Poly<5> = crate::arbitrary::poly(u)?;
            // Ignore very large polynomials, as in the degree-3 test.
            if (poly.max_abs_coefficient() * 10.0f64.powi(5)).is_infinite() {
                return Err(arbitrary::Error::IncorrectFormat);
            }
            let roots = poly.roots_between(-10.0, 10.0, 1e-13);
            check_root_values(&poly, &roots);
            Ok(())
        })
        .budget_ms(5_000);
    }

    #[test]
    fn root_evaluation_deg9() {
        arbtest::arbtest(|u| {
            let poly: Poly<10> = crate::arbitrary::poly(u)?;
            // Ignore very large polynomials, as in the degree-3 test.
            if (poly.max_abs_coefficient() * 10.0f64.powi(11)).is_infinite() {
                return Err(arbitrary::Error::IncorrectFormat);
            }
            let roots = poly.roots_between(-10.0, 10.0, 1e-13);
            check_root_values(&poly, &roots);
            Ok(())
        })
        .budget_ms(5_000);
    }

    #[test]
    fn planted_root_deg5() {
        arbtest::arbtest(|u| {
            let planted_root = crate::arbitrary::float_in_unit_interval(u)?;
            let poly: Poly<6> = crate::arbitrary::poly_with_planted_root(u, planted_root, 1e-6)?;

            // Yuksel's algorithm needs the iterated derivatives to be finite,
            // and we aren't doing any preconditioning or normalization yet, so
            // reject polynomials whose coefficients are too big.
            if (poly.max_abs_coefficient() * 1024.0).is_infinite() {
                return Err(arbitrary::Error::IncorrectFormat);
            }
            let roots = poly.roots_between(-2.0, 2.0, 1e-13);

            assert!(roots.iter().all(|r| r.is_finite()));

            // Check that the roots are sorted.
            assert!(roots.is_sorted());
            assert!(roots.iter().all(|r| (-2.0..=2.0).contains(r)));

            // We can't expect great accuracy for huge coefficients, because the
            // evaluations during Newton iteration are subject to error.
            let error = poly.max_abs_coefficient().max(1.0) * 1e-12;
            assert!(roots.iter().any(|r| (r - planted_root).abs() <= error));
            Ok(())
        })
        .budget_ms(5_000);
    }
}