use std::ops;
use roots::{find_roots_cubic, Roots};
#[cfg(feature = "serialization")]
use serde_derive::{Deserialize, Serialize};
/// A cubic polynomial `a*x^3 + b*x^2 + c*x + d` whose coefficients have type `T`.
///
/// The coefficient type is generic; the methods in the generic `impl` only need
/// `T` to support addition, subtraction and scaling by `f64` (presumably so that
/// vector-like coefficient types can be used — confirm against callers).
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct CubicPoly<T> {
    /// Coefficient of `x^3`.
    a: T,
    /// Coefficient of `x^2`.
    b: T,
    /// Coefficient of `x`.
    c: T,
    /// Constant term.
    d: T,
}
/// Factorization of a real cubic polynomial, as produced by `CubicPoly::factors`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub enum Factors {
    /// All three roots are real: the polynomial equals
    /// `a * (x - x1) * (x - x2) * (x - x3)`.
    ThreeLinear { a: f64, x1: f64, x2: f64, x3: f64 },
    /// One real root `x1` times a monic quadratic: the polynomial equals
    /// `a * (x - x1) * (x^2 + b*x + c)`. `factors` only builds this variant when
    /// the quadratic's discriminant `b^2 - 4c` did not test as non-negative.
    LinearAndQuadratic { a: f64, x1: f64, b: f64, c: f64 },
}
impl<T> CubicPoly<T>
where
T: ops::Add<T, Output = T>
+ ops::AddAssign<T>
+ ops::Sub<T, Output = T>
+ ops::SubAssign<T>
+ ops::Mul<f64, Output = T>
+ Copy,
{
pub fn new(a: T, b: T, c: T, d: T) -> Self {
Self { a, b, c, d }
}
pub fn shifted(self, x0: f64) -> Self {
let a = self.a;
let b = self.b - self.a * 3.0 * x0;
let c = self.c + self.a * 3.0 * x0 * x0 - self.b * 2.0 * x0;
let d = self.d - self.a * x0 * x0 * x0 + self.b * x0 * x0 - self.c * x0;
Self { a, b, c, d }
}
pub fn eval(&self, x: f64) -> T {
self.a * x * x * x + self.b * x * x + self.c * x + self.d
}
pub fn derivative(&self, x: f64) -> T {
self.a * 3.0 * x * x + self.b * 2.0 * x + self.c
}
}
impl CubicPoly<f64> {
    /// Factors the cubic over the reals.
    ///
    /// Returns either
    /// * [`Factors::ThreeLinear`], i.e. `a * (x - x1) * (x - x2) * (x - x3)`, when all
    ///   three roots are real, or
    /// * [`Factors::LinearAndQuadratic`], i.e. `a * (x - x1) * (x^2 + b*x + c)`, when
    ///   the quadratic left after dividing out the real root `x1` has a negative
    ///   discriminant.
    ///
    /// When the roots are obtained by deflation (the `One`/`Two` cases below)
    /// they are sorted ascending here. In the `Three` case they are passed through
    /// in the order `find_roots_cubic` returns them (NOTE(review): assumed
    /// ascending; confirm against the `roots` crate).
    ///
    /// # Panics
    ///
    /// Panics if `find_roots_cubic` reports a root count other than one, two or
    /// three. Every real cubic has at least one real root, so this presumably only
    /// happens for a degenerate `a == 0` input. That input is not supported, since
    /// the deflation below divides by `a`.
    pub fn factors(&self) -> Factors {
        let roots = find_roots_cubic(self.a, self.b, self.c, self.d);
        // A real cubic with exactly two distinct real roots has one of them
        // repeated (non-real roots come in conjugate pairs), so `Two` guarantees
        // three real roots. `One` leaves open whether the other two are real.
        let (x1, all_real) = match roots {
            Roots::One([x1]) => (x1, false),
            Roots::Two([x1, _]) => (x1, true),
            Roots::Three([x1, x2, x3]) => {
                return Factors::ThreeLinear {
                    a: self.a,
                    x1,
                    x2,
                    x3,
                };
            }
            _ => panic!("should have either one or three roots! {:?}", roots),
        };

        // Synthetic division of the monic cubic by (x - x1) leaves x^2 + b*x + c.
        let b = self.b / self.a + x1;
        let c = self.c / self.a + b * x1;
        let delta = b * b - 4.0 * c;
        // With a repeated root the exact discriminant is >= 0, but rounding can
        // make it slightly negative and misreport the pair as complex. Clamp it in
        // that case only. `delta < 0.0` is false for NaN, so NaN still propagates.
        let delta = if all_real && delta < 0.0 { 0.0 } else { delta };

        if delta >= 0.0 {
            let sqrt_delta = delta.sqrt();
            let x2 = 0.5 * (-b - sqrt_delta);
            let x3 = 0.5 * (-b + sqrt_delta);
            // Three compare-and-swap steps leave (x1, x2, x3) in ascending order.
            let (x1, x2) = if x1 < x2 { (x1, x2) } else { (x2, x1) };
            let (x1, x3) = if x1 < x3 { (x1, x3) } else { (x3, x1) };
            let (x2, x3) = if x2 < x3 { (x2, x3) } else { (x3, x2) };
            Factors::ThreeLinear {
                a: self.a,
                x1,
                x2,
                x3,
            }
        } else {
            Factors::LinearAndQuadratic {
                a: self.a,
                x1,
                b,
                c,
            }
        }
    }
}
impl<T> ops::AddAssign<CubicPoly<T>> for CubicPoly<T>
where
T: ops::AddAssign<T>,
{
fn add_assign(&mut self, other: CubicPoly<T>) {
self.a += other.a;
self.b += other.b;
self.c += other.c;
self.d += other.d;
}
}
impl<T> ops::SubAssign<CubicPoly<T>> for CubicPoly<T>
where
T: ops::SubAssign<T>,
{
fn sub_assign(&mut self, other: CubicPoly<T>) {
self.a -= other.a;
self.b -= other.b;
self.c -= other.c;
self.d -= other.d;
}
}
impl<T> ops::Add<CubicPoly<T>> for CubicPoly<T>
where
T: ops::AddAssign<T>,
{
type Output = CubicPoly<T>;
fn add(mut self, other: CubicPoly<T>) -> CubicPoly<T> {
self += other;
self
}
}
impl<T> ops::Sub<CubicPoly<T>> for CubicPoly<T>
where
T: ops::SubAssign<T>,
{
type Output = CubicPoly<T>;
fn sub(mut self, other: CubicPoly<T>) -> CubicPoly<T> {
self -= other;
self
}
}
impl<T> ops::MulAssign<f64> for CubicPoly<T>
where
T: ops::MulAssign<f64>,
{
fn mul_assign(&mut self, other: f64) {
self.a *= other;
self.b *= other;
self.c *= other;
self.d *= other;
}
}
impl<T> ops::Mul<f64> for CubicPoly<T>
where
T: ops::MulAssign<f64>,
{
type Output = CubicPoly<T>;
fn mul(mut self, other: f64) -> CubicPoly<T> {
self *= other;
self
}
}
impl<T> ops::DivAssign<f64> for CubicPoly<T>
where
T: ops::DivAssign<f64>,
{
fn div_assign(&mut self, other: f64) {
self.a /= other;
self.b /= other;
self.c /= other;
self.d /= other;
}
}
impl<T> ops::Div<f64> for CubicPoly<T>
where
T: ops::DivAssign<f64>,
{
type Output = CubicPoly<T>;
fn div(mut self, other: f64) -> CubicPoly<T> {
self /= other;
self
}
}
#[cfg(test)]
// Exact float comparisons are intentional: every input below is chosen so the
// expected results are exactly representable.
#[allow(clippy::float_cmp)]
mod tests {
    use super::{CubicPoly, Factors};

    #[test]
    fn test_poly_shift() {
        let poly = CubicPoly::new(1.0, -1.0, 1.0, -1.0);
        assert_eq!(poly.eval(0.0), -1.0);
        assert_eq!(poly.eval(1.0), 0.0);
        assert_eq!(poly.eval(2.0), 5.0);
        let poly2 = poly.shifted(1.0);
        assert_eq!(poly2.eval(1.0), -1.0);
        assert_eq!(poly2.eval(2.0), 0.0);
        assert_eq!(poly2.eval(3.0), 5.0);
    }

    #[test]
    fn test_derivative() {
        // d/dx (x^3 - x^2 + x - 1) = 3x^2 - 2x + 1
        let poly = CubicPoly::new(1.0, -1.0, 1.0, -1.0);
        assert_eq!(poly.derivative(0.0), 1.0);
        assert_eq!(poly.derivative(1.0), 2.0);
        assert_eq!(poly.derivative(2.0), 9.0);
    }

    #[test]
    fn test_arithmetic_ops() {
        let p = CubicPoly::new(1.0, 2.0, 3.0, 4.0);
        let q = CubicPoly::new(4.0, 3.0, 2.0, 1.0);
        assert_eq!(p + q, CubicPoly::new(5.0, 5.0, 5.0, 5.0));
        assert_eq!(p - q, CubicPoly::new(-3.0, -1.0, 1.0, 3.0));
        assert_eq!(p * 2.0, CubicPoly::new(2.0, 4.0, 6.0, 8.0));
        assert_eq!(p / 2.0, CubicPoly::new(0.5, 1.0, 1.5, 2.0));

        let mut r = p;
        r += q;
        r -= q;
        r *= 4.0;
        r /= 4.0;
        assert_eq!(r, p);
    }

    #[test]
    fn test_triple_root() {
        let poly = CubicPoly::new(2.0, -6.0, 6.0, -2.0);
        assert_eq!(
            poly.factors(),
            Factors::ThreeLinear {
                a: 2.0,
                x1: 1.0,
                x2: 1.0,
                x3: 1.0,
            }
        );
    }

    #[test]
    fn test_double_root() {
        let poly = CubicPoly::new(1.0, 1.0, -1.0, -1.0);
        assert_eq!(
            poly.factors(),
            Factors::ThreeLinear {
                a: 1.0,
                x1: -1.0,
                x2: -1.0,
                x3: 1.0,
            }
        );
    }

    #[test]
    fn test_single_root() {
        let poly = CubicPoly::new(1.0, -1.0, 1.0, -1.0);
        assert_eq!(
            poly.factors(),
            Factors::LinearAndQuadratic {
                a: 1.0,
                x1: 1.0,
                b: 0.0,
                c: 1.0,
            }
        );
    }
}