1use std::ops;
2
3use roots::{find_roots_cubic, Roots};
4#[cfg(feature = "serialization")]
5use serde_derive::{Deserialize, Serialize};
6
/// A cubic polynomial `a·x³ + b·x² + c·x + d` with coefficients of type `T`.
///
/// `T` is left unconstrained on the type itself; each `impl` block states the
/// arithmetic it needs. Real-valued factoring is provided for `CubicPoly<f64>`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct CubicPoly<T> {
    /// Coefficient of `x³`.
    a: T,
    /// Coefficient of `x²`.
    b: T,
    /// Coefficient of `x`.
    c: T,
    /// Constant term.
    d: T,
}
15
/// Factorization of a real cubic polynomial over the reals.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub enum Factors {
    /// `a·(x - x1)·(x - x2)·(x - x3)`: all three roots (counting multiplicity)
    /// are real.
    ThreeLinear { a: f64, x1: f64, x2: f64, x3: f64 },
    /// `a·(x - x1)·(x² + b·x + c)`: one real root `x1` times a monic quadratic
    /// factor whose discriminant `b² - 4c` was found to be negative.
    LinearAndQuadratic { a: f64, x1: f64, b: f64, c: f64 },
}
24
25impl<T> CubicPoly<T>
26where
27 T: ops::Add<T, Output = T>
28 + ops::AddAssign<T>
29 + ops::Sub<T, Output = T>
30 + ops::SubAssign<T>
31 + ops::Mul<f64, Output = T>
32 + Copy,
33{
34 pub fn new(a: T, b: T, c: T, d: T) -> Self {
35 Self { a, b, c, d }
36 }
37
38 pub fn shifted(self, x0: f64) -> Self {
40 let a = self.a;
41 let b = self.b - self.a * 3.0 * x0;
42 let c = self.c + self.a * 3.0 * x0 * x0 - self.b * 2.0 * x0;
43 let d = self.d - self.a * x0 * x0 * x0 + self.b * x0 * x0 - self.c * x0;
44 Self { a, b, c, d }
45 }
46
47 pub fn eval(&self, x: f64) -> T {
48 self.a * x * x * x + self.b * x * x + self.c * x + self.d
49 }
50
51 pub fn derivative(&self, x: f64) -> T {
52 self.a * 3.0 * x * x + self.b * 2.0 * x + self.c
53 }
54}
55
56impl CubicPoly<f64> {
57 pub fn factors(&self) -> Factors {
58 let roots = find_roots_cubic(self.a, self.b, self.c, self.d);
59 match roots {
60 Roots::One([x1]) | Roots::Two([x1, _]) => {
61 let b = self.b / self.a + x1;
62 let c = self.c / self.a + b * x1;
63 let delta = b * b - 4.0 * c;
65 if delta >= 0.0 {
66 let x2 = 0.5 * (-b - delta.sqrt());
67 let x3 = 0.5 * (-b + delta.sqrt());
68 let (x1, x2) = if x1 < x2 { (x1, x2) } else { (x2, x1) };
70 let (x1, x3) = if x1 < x3 { (x1, x3) } else { (x3, x1) };
71 let (x2, x3) = if x2 < x3 { (x2, x3) } else { (x3, x2) };
72 Factors::ThreeLinear {
73 a: self.a,
74 x1,
75 x2,
76 x3,
77 }
78 } else {
79 Factors::LinearAndQuadratic {
80 a: self.a,
81 x1,
82 b,
83 c,
84 }
85 }
86 }
87 Roots::Three([x1, x2, x3]) => Factors::ThreeLinear {
88 a: self.a,
89 x1,
90 x2,
91 x3,
92 },
93 _ => panic!("should have either one or three roots! {:?}", roots),
94 }
95 }
96}
97
98impl<T> ops::AddAssign<CubicPoly<T>> for CubicPoly<T>
99where
100 T: ops::AddAssign<T>,
101{
102 fn add_assign(&mut self, other: CubicPoly<T>) {
103 self.a += other.a;
104 self.b += other.b;
105 self.c += other.c;
106 self.d += other.d;
107 }
108}
109
110impl<T> ops::SubAssign<CubicPoly<T>> for CubicPoly<T>
111where
112 T: ops::SubAssign<T>,
113{
114 fn sub_assign(&mut self, other: CubicPoly<T>) {
115 self.a -= other.a;
116 self.b -= other.b;
117 self.c -= other.c;
118 self.d -= other.d;
119 }
120}
121
122impl<T> ops::Add<CubicPoly<T>> for CubicPoly<T>
123where
124 T: ops::AddAssign<T>,
125{
126 type Output = CubicPoly<T>;
127
128 fn add(mut self, other: CubicPoly<T>) -> CubicPoly<T> {
129 self += other;
130 self
131 }
132}
133
134impl<T> ops::Sub<CubicPoly<T>> for CubicPoly<T>
135where
136 T: ops::SubAssign<T>,
137{
138 type Output = CubicPoly<T>;
139
140 fn sub(mut self, other: CubicPoly<T>) -> CubicPoly<T> {
141 self -= other;
142 self
143 }
144}
145
146impl<T> ops::MulAssign<f64> for CubicPoly<T>
147where
148 T: ops::MulAssign<f64>,
149{
150 fn mul_assign(&mut self, other: f64) {
151 self.a *= other;
152 self.b *= other;
153 self.c *= other;
154 self.d *= other;
155 }
156}
157
158impl<T> ops::Mul<f64> for CubicPoly<T>
159where
160 T: ops::MulAssign<f64>,
161{
162 type Output = CubicPoly<T>;
163
164 fn mul(mut self, other: f64) -> CubicPoly<T> {
165 self *= other;
166 self
167 }
168}
169
170impl<T> ops::DivAssign<f64> for CubicPoly<T>
171where
172 T: ops::DivAssign<f64>,
173{
174 fn div_assign(&mut self, other: f64) {
175 self.a /= other;
176 self.b /= other;
177 self.c /= other;
178 self.d /= other;
179 }
180}
181
182impl<T> ops::Div<f64> for CubicPoly<T>
183where
184 T: ops::DivAssign<f64>,
185{
186 type Output = CubicPoly<T>;
187
188 fn div(mut self, other: f64) -> CubicPoly<T> {
189 self /= other;
190 self
191 }
192}
193
#[cfg(test)]
// These tests deliberately compare floats with `==`; every expected value is a
// small integer.
#[allow(clippy::float_cmp)]
mod tests {
    use super::{CubicPoly, Factors};

    #[test]
    fn test_poly_shift() {
        // p(x) = x³ - x² + x - 1 = (x - 1)(x² + 1)
        let p = CubicPoly::new(1.0, -1.0, 1.0, -1.0);
        for &(x, y) in &[(0.0, -1.0), (1.0, 0.0), (2.0, 5.0)] {
            assert_eq!(p.eval(x), y);
        }

        // Shifting right by one moves every sample one unit to the right.
        let q = p.shifted(1.0);
        for &(x, y) in &[(1.0, -1.0), (2.0, 0.0), (3.0, 5.0)] {
            assert_eq!(q.eval(x), y);
        }
    }

    #[test]
    fn test_triple_root() {
        // 2(x - 1)³
        let p = CubicPoly::new(2.0, -6.0, 6.0, -2.0);
        let expected = Factors::ThreeLinear {
            a: 2.0,
            x1: 1.0,
            x2: 1.0,
            x3: 1.0,
        };
        assert_eq!(p.factors(), expected);
    }

    #[test]
    fn test_double_root() {
        // (x + 1)²(x - 1)
        let p = CubicPoly::new(1.0, 1.0, -1.0, -1.0);
        let expected = Factors::ThreeLinear {
            a: 1.0,
            x1: -1.0,
            x2: -1.0,
            x3: 1.0,
        };
        assert_eq!(p.factors(), expected);
    }

    #[test]
    fn test_single_root() {
        // (x - 1)(x² + 1)
        let p = CubicPoly::new(1.0, -1.0, 1.0, -1.0);
        let expected = Factors::LinearAndQuadratic {
            a: 1.0,
            x1: 1.0,
            b: 0.0,
            c: 1.0,
        };
        assert_eq!(p.factors(), expected);
    }
}