cubic_splines/
cubic_poly.rs

1use std::ops;
2
3use roots::{find_roots_cubic, Roots};
4#[cfg(feature = "serialization")]
5use serde_derive::{Deserialize, Serialize};
6
/// A cubic polynomial `f(x) = a·x³ + b·x² + c·x + d`.
///
/// The coefficient type `T` is generic: the inherent methods only require
/// that coefficients can be added, subtracted and scaled by an `f64`, so the
/// same type works for scalar and non-scalar coefficients alike.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct CubicPoly<T> {
    /// Coefficient of `x³`.
    a: T,
    /// Coefficient of `x²`.
    b: T,
    /// Coefficient of `x`.
    c: T,
    /// Constant term.
    d: T,
}
15
/// Factorisation of a real cubic polynomial over the reals, as returned by
/// `CubicPoly::<f64>::factors`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub enum Factors {
    /// All three roots are real: `f(x) = a(x-x1)(x-x2)(x-x3)`.
    ThreeLinear { a: f64, x1: f64, x2: f64, x3: f64 },
    /// One real root and a monic quadratic with no real roots
    /// (negative discriminant): `f(x) = a(x-x1)(x²+bx+c)`.
    LinearAndQuadratic { a: f64, x1: f64, b: f64, c: f64 },
}
24
25impl<T> CubicPoly<T>
26where
27    T: ops::Add<T, Output = T>
28        + ops::AddAssign<T>
29        + ops::Sub<T, Output = T>
30        + ops::SubAssign<T>
31        + ops::Mul<f64, Output = T>
32        + Copy,
33{
34    pub fn new(a: T, b: T, c: T, d: T) -> Self {
35        Self { a, b, c, d }
36    }
37
38    /// Creates a polynomial g(x) = f(x-x0), where f(x) = self
39    pub fn shifted(self, x0: f64) -> Self {
40        let a = self.a;
41        let b = self.b - self.a * 3.0 * x0;
42        let c = self.c + self.a * 3.0 * x0 * x0 - self.b * 2.0 * x0;
43        let d = self.d - self.a * x0 * x0 * x0 + self.b * x0 * x0 - self.c * x0;
44        Self { a, b, c, d }
45    }
46
47    pub fn eval(&self, x: f64) -> T {
48        self.a * x * x * x + self.b * x * x + self.c * x + self.d
49    }
50
51    pub fn derivative(&self, x: f64) -> T {
52        self.a * 3.0 * x * x + self.b * 2.0 * x + self.c
53    }
54}
55
56impl CubicPoly<f64> {
57    pub fn factors(&self) -> Factors {
58        let roots = find_roots_cubic(self.a, self.b, self.c, self.d);
59        match roots {
60            Roots::One([x1]) | Roots::Two([x1, _]) => {
61                let b = self.b / self.a + x1;
62                let c = self.c / self.a + b * x1;
63                // make sure that we haven't missed any real roots
64                let delta = b * b - 4.0 * c;
65                if delta >= 0.0 {
66                    let x2 = 0.5 * (-b - delta.sqrt());
67                    let x3 = 0.5 * (-b + delta.sqrt());
68                    // sort the roots
69                    let (x1, x2) = if x1 < x2 { (x1, x2) } else { (x2, x1) };
70                    let (x1, x3) = if x1 < x3 { (x1, x3) } else { (x3, x1) };
71                    let (x2, x3) = if x2 < x3 { (x2, x3) } else { (x3, x2) };
72                    Factors::ThreeLinear {
73                        a: self.a,
74                        x1,
75                        x2,
76                        x3,
77                    }
78                } else {
79                    Factors::LinearAndQuadratic {
80                        a: self.a,
81                        x1,
82                        b,
83                        c,
84                    }
85                }
86            }
87            Roots::Three([x1, x2, x3]) => Factors::ThreeLinear {
88                a: self.a,
89                x1,
90                x2,
91                x3,
92            },
93            _ => panic!("should have either one or three roots! {:?}", roots),
94        }
95    }
96}
97
98impl<T> ops::AddAssign<CubicPoly<T>> for CubicPoly<T>
99where
100    T: ops::AddAssign<T>,
101{
102    fn add_assign(&mut self, other: CubicPoly<T>) {
103        self.a += other.a;
104        self.b += other.b;
105        self.c += other.c;
106        self.d += other.d;
107    }
108}
109
110impl<T> ops::SubAssign<CubicPoly<T>> for CubicPoly<T>
111where
112    T: ops::SubAssign<T>,
113{
114    fn sub_assign(&mut self, other: CubicPoly<T>) {
115        self.a -= other.a;
116        self.b -= other.b;
117        self.c -= other.c;
118        self.d -= other.d;
119    }
120}
121
122impl<T> ops::Add<CubicPoly<T>> for CubicPoly<T>
123where
124    T: ops::AddAssign<T>,
125{
126    type Output = CubicPoly<T>;
127
128    fn add(mut self, other: CubicPoly<T>) -> CubicPoly<T> {
129        self += other;
130        self
131    }
132}
133
134impl<T> ops::Sub<CubicPoly<T>> for CubicPoly<T>
135where
136    T: ops::SubAssign<T>,
137{
138    type Output = CubicPoly<T>;
139
140    fn sub(mut self, other: CubicPoly<T>) -> CubicPoly<T> {
141        self -= other;
142        self
143    }
144}
145
146impl<T> ops::MulAssign<f64> for CubicPoly<T>
147where
148    T: ops::MulAssign<f64>,
149{
150    fn mul_assign(&mut self, other: f64) {
151        self.a *= other;
152        self.b *= other;
153        self.c *= other;
154        self.d *= other;
155    }
156}
157
158impl<T> ops::Mul<f64> for CubicPoly<T>
159where
160    T: ops::MulAssign<f64>,
161{
162    type Output = CubicPoly<T>;
163
164    fn mul(mut self, other: f64) -> CubicPoly<T> {
165        self *= other;
166        self
167    }
168}
169
170impl<T> ops::DivAssign<f64> for CubicPoly<T>
171where
172    T: ops::DivAssign<f64>,
173{
174    fn div_assign(&mut self, other: f64) {
175        self.a /= other;
176        self.b /= other;
177        self.c /= other;
178        self.d /= other;
179    }
180}
181
182impl<T> ops::Div<f64> for CubicPoly<T>
183where
184    T: ops::DivAssign<f64>,
185{
186    type Output = CubicPoly<T>;
187
188    fn div(mut self, other: f64) -> CubicPoly<T> {
189        self /= other;
190        self
191    }
192}
193
#[cfg(test)]
// Exact float comparison is intentional: the expected values are small
// integers or halves, which f64 represents exactly.
#[allow(clippy::float_cmp)]
mod tests {
    use super::{CubicPoly, Factors};

    #[test]
    fn test_poly_shift() {
        let poly = CubicPoly::new(1.0, -1.0, 1.0, -1.0);
        assert_eq!(poly.eval(0.0), -1.0);
        assert_eq!(poly.eval(1.0), 0.0);
        assert_eq!(poly.eval(2.0), 5.0);
        let poly2 = poly.shifted(1.0); // poly2(x) = poly(x - 1)
        assert_eq!(poly2.eval(1.0), -1.0);
        assert_eq!(poly2.eval(2.0), 0.0);
        assert_eq!(poly2.eval(3.0), 5.0);
    }

    #[test]
    fn test_derivative() {
        // f(x) = x³ - x² + x - 1  =>  f'(x) = 3x² - 2x + 1
        let poly = CubicPoly::new(1.0, -1.0, 1.0, -1.0);
        assert_eq!(poly.derivative(0.0), 1.0);
        assert_eq!(poly.derivative(1.0), 2.0);
        assert_eq!(poly.derivative(2.0), 9.0);
    }

    #[test]
    fn test_arithmetic() {
        let p = CubicPoly::new(1.0, 2.0, 3.0, 4.0);
        let q = CubicPoly::new(4.0, 3.0, 2.0, 1.0);
        assert_eq!(p + q, CubicPoly::new(5.0, 5.0, 5.0, 5.0));
        assert_eq!(p - q, CubicPoly::new(-3.0, -1.0, 1.0, 3.0));
        assert_eq!(p * 2.0, CubicPoly::new(2.0, 4.0, 6.0, 8.0));
        assert_eq!(p / 2.0, CubicPoly::new(0.5, 1.0, 1.5, 2.0));

        // The compound-assignment forms must round-trip.
        let mut r = p;
        r += q;
        r -= q;
        assert_eq!(r, p);
        r *= 4.0;
        r /= 4.0;
        assert_eq!(r, p);
    }

    #[test]
    fn test_triple_root() {
        let poly = CubicPoly::new(2.0, -6.0, 6.0, -2.0);
        assert_eq!(
            poly.factors(),
            Factors::ThreeLinear {
                a: 2.0,
                x1: 1.0,
                x2: 1.0,
                x3: 1.0,
            }
        );
    }

    #[test]
    fn test_double_root() {
        let poly = CubicPoly::new(1.0, 1.0, -1.0, -1.0);
        assert_eq!(
            poly.factors(),
            Factors::ThreeLinear {
                a: 1.0,
                x1: -1.0,
                x2: -1.0,
                x3: 1.0,
            }
        );
    }

    #[test]
    fn test_single_root() {
        let poly = CubicPoly::new(1.0, -1.0, 1.0, -1.0);
        assert_eq!(
            poly.factors(),
            Factors::LinearAndQuadratic {
                a: 1.0,
                x1: 1.0,
                b: 0.0,
                c: 1.0,
            }
        );
    }
}