use std::{
fmt::Debug,
ops::{Add, Div, Mul, Neg, Sub},
};
use nalgebra::DMatrix;
use crate::{
complex::Complex,
convex_hull::{self, Point2D},
Abs, Const, Cos, Inv, Ln, NumCast, One, Poly, Pow, Sign, Sin, Sqrt, Zero,
};
/// Maximum number of sweeps of the Aberth-Ehrlich iteration performed by
/// `Poly::iterative_roots` (and by `RootsFinder` in the unit tests).
const DEFAULT_ITERATIONS: u32 = 30;
/// State of the simultaneous (Aberth-Ehrlich) root iteration used by
/// `Poly::iterative_roots_with_max` for polynomials of degree three or more.
#[derive(Debug)]
pub(super) struct RootsFinder<T> {
    /// Polynomial whose roots are sought.
    poly: Poly<T>,
    /// Derivative of `poly`, computed once at construction.
    derivative: Poly<T>,
    /// Current root approximations, one per root; updated in place.
    solution: Vec<Complex<T>>,
    /// Maximum number of sweeps over all approximations.
    iterations: u32,
}
impl<T> RootsFinder<T>
where
    T: Abs
        + Add<Output = T>
        + Clone
        + Const
        + Cos
        + Div<Output = T>
        + Inv
        + Ln
        + Mul<Output = T>
        + Neg<Output = T>
        + NumCast
        + One
        + PartialEq
        + PartialOrd
        + Pow
        + Sin
        + Sub<Output = T>
        + Zero,
{
    /// Create a root finder for `poly` performing at most `iterations` sweeps.
    ///
    /// The derivative is computed once here; the starting approximations are
    /// produced by `init`, one per root (checked by a debug assertion).
    pub(super) fn new(poly: Poly<T>, iterations: u32) -> Self {
        let initial_guess = init(&poly);
        debug_assert!(poly.degree().unwrap_or(0) == initial_guess.len());
        Self {
            derivative: poly.derive(),
            poly,
            solution: initial_guess,
            iterations,
        }
    }

    /// Run the iteration and return the final root approximations.
    ///
    /// Approximations are updated in place during a sweep, so later updates in
    /// the same sweep already see the earlier ones. Iteration stops after
    /// `iterations` sweeps, or earlier once a whole sweep left every
    /// approximation unchanged.
    pub(super) fn roots_finder(mut self) -> Vec<Complex<T>> {
        let mut settled = vec![false; self.solution.len()];
        let mut sweeps_left = self.iterations;
        while sweeps_left > 0 && settled.iter().any(|&s| !s) {
            sweeps_left -= 1;
            for idx in 0..settled.len() {
                let current = self.solution[idx].clone();
                let candidate = current.clone() - self.correction(idx, &current);
                settled[idx] = current == candidate;
                if !settled[idx] {
                    self.solution[idx] = candidate;
                }
            }
        }
        self.solution
    }

    /// Aberth correction for the approximation at position `idx`, whose value is `z`.
    ///
    /// With `N = p(z) / p'(z)` and `A = sum over j != idx of 1 / (z - z_j)`, the
    /// correction is `N / (1 - N * A)`; when `p'(z)` is zero, `-1 / A` is used.
    fn correction(&self, idx: usize, z: &Complex<T>) -> Complex<T> {
        let slope = self.derivative.eval(z);
        let repulsion: Complex<T> = self
            .solution
            .iter()
            .enumerate()
            .filter(|(j, _)| *j != idx)
            .map(|(_, other)| (z.clone() - other).inv())
            .fold(Complex::zero(), |acc, term| acc + term);
        if slope.is_zero() {
            -repulsion.inv()
        } else {
            let newton = self.poly.eval(z) / slope;
            newton.clone() / (Complex::<T>::one() - newton * repulsion)
        }
    }
}
/// Point `(k, ln|a_k|)` built from the polynomial coefficient `a_k`, fed to
/// `convex_hull::convex_hull_top` by `init`.
///
/// Field 0 is the coefficient index `k` (used to map hull vertices back to
/// coefficients), field 1 is `k` converted to `T`, field 2 is `ln|a_k|`.
#[derive(Clone, Debug)]
struct CoeffPoint<T>(usize, T, T);
impl<T: Clone> Point2D for CoeffPoint<T> {
type Output = T;
fn x(&self) -> Self::Output {
self.1.clone()
}
fn y(&self) -> Self::Output {
self.2.clone()
}
}
/// Compute the starting approximations for the root iteration.
///
/// The points `(k, ln|a_k|)` are built from the coefficients and reduced with
/// `convex_hull::convex_hull_top`. For each pair of consecutive hull vertices
/// with indices `k_lo`, `k_hi` (assumed increasing, since `k_hi - k_lo` is a
/// `usize` subtraction), `k_hi - k_lo` points are spread evenly, starting at
/// angle 0, on the circle of radius `|a_k_lo / a_k_hi|^(1 / (k_hi - k_lo))`.
fn init<T>(poly: &Poly<T>) -> Vec<Complex<T>>
where
    T: Abs
        + Add<Output = T>
        + Clone
        + Const
        + Cos
        + Div<Output = T>
        + Inv
        + Ln
        + Mul<Output = T>
        + Neg<Output = T>
        + NumCast
        + PartialOrd
        + Pow
        + Sin
        + Sub<Output = T>
        + Zero,
{
    let log_points = poly
        .coeffs
        .iter()
        .enumerate()
        .map(|(k, c)| CoeffPoint(k, T::from(k).unwrap(), c.abs().ln()));
    let hull = convex_hull::convex_hull_top(log_points);
    let vertices: Vec<_> = hull.iter().collect();

    let mut guesses = Vec::new();
    for pair in vertices.windows(2) {
        let CoeffPoint(k_lo, x_lo, _) = pair[0];
        let CoeffPoint(k_hi, x_hi, _) = pair[1];
        let ratio = (poly.coeffs[*k_lo].clone() / poly.coeffs[*k_hi].clone()).abs();
        let radius = ratio.powf((x_hi.clone() - x_lo.clone()).inv());
        let count = k_hi - k_lo;
        let count_f = T::from(count).unwrap();
        guesses.extend((0..count).map(|i| {
            let angle = T::tau() * T::from(i).unwrap() / count_f.clone();
            Complex::from_polar(radius.clone(), angle)
        }));
    }
    guesses
}
/// Implement the companion-matrix based root finders for a floating point type.
macro_rules! eigen_roots {
    ($ty:ty) => {
        impl Poly<$ty> {
            /// Build the companion matrix of the polynomial.
            ///
            /// The last column holds `-a_i / a_n` and the sub-diagonal holds ones,
            /// so the eigenvalues of the matrix are the roots of the polynomial.
            /// Returns `None` when the degree is undefined or zero.
            fn companion(&self) -> Option<DMatrix<$ty>> {
                match self.degree() {
                    Some(degree) if degree > 0 => {
                        let hi_coeff = self.coeffs[degree];
                        let comp = DMatrix::from_fn(degree, degree, |i, j| {
                            if j == degree - 1 {
                                -self.coeffs[i] / hi_coeff
                            } else if i == j + 1 {
                                <$ty as One>::one()
                            } else {
                                <$ty as Zero>::zero()
                            }
                        });
                        debug_assert!(comp.is_square());
                        Some(comp)
                    }
                    _ => None,
                }
            }
        }

        impl Poly<$ty> {
            #[doc = "Calculate the real roots of the polynomial"]
            #[doc = " using companion matrix eigenvalues decomposition.\n"]
            #[doc = "Zero roots are appended at the end of the result, so a monomial"]
            #[doc = " `c * x^k` (with `k > 0`) yields `k` zeros.\n"]
            #[doc = "Returns `None` for the zero polynomial, for non-zero constants,"]
            #[doc = " and when the remaining roots cannot all be computed as real values.\n"]
            #[doc = "# Example\n```\nuse polynomen::Poly;"]
            #[doc = concat!("let roots = &[1.0_", stringify!($ty), ", -1., 0.];")]
            #[doc = "let p = Poly::new_from_roots(roots);"]
            #[doc = "assert_eq!(roots, p.real_roots().unwrap().as_slice());"]
            #[doc = concat!("let q = Poly::new_from_coeffs(&[0., 0., 1.0_", stringify!($ty), "]);")]
            #[doc = "assert_eq!(Some(vec![0., 0.]), q.real_roots());\n```"]
            #[must_use]
            pub fn real_roots(&self) -> Option<Vec<$ty>> {
                let (zeros, cropped) = self.find_zero_roots();
                let roots = match cropped.degree() {
                    // Only zero roots were present (the polynomial is `c * x^k`):
                    // they are all added by `extend_roots` below.
                    Some(0) if zeros > 0 => Some(Vec::new()),
                    // Zero polynomial or non-zero constant.
                    Some(0) | None => None,
                    Some(1) => Some(cropped.real_deg1_root()),
                    Some(2) => cropped.real_deg2_roots(),
                    _ => {
                        let comp = cropped.companion()?;
                        comp.eigenvalues().map(|e| e.as_slice().to_vec())
                    }
                };
                roots.map(|r| extend_roots(r, zeros))
            }

            #[doc = "Calculate the complex roots of the polynomial"]
            #[doc = " using companion matrix eigenvalues decomposition.\n"]
            #[doc = "Each root is a `(real, imaginary)` pair; zero roots are appended"]
            #[doc = " at the end of the result.\n"]
            #[doc = "# Example\n```\nuse polynomen::Poly;"]
            #[doc = concat!("let p = Poly::new_from_coeffs(&[1.0_", stringify!($ty), ", 0., 1.]);")]
            #[doc = "assert_eq!(vec![(0., -1.), (0., 1.)], p.complex_roots());\n```"]
            #[must_use]
            pub fn complex_roots(&self) -> Vec<($ty, $ty)> {
                let (zeros, cropped) = self.find_zero_roots();
                let roots = match cropped.degree() {
                    Some(0) | None => Vec::new(),
                    Some(1) => cropped.complex_deg1_root(),
                    Some(2) => cropped.complex_deg2_roots(),
                    _ => {
                        let comp = match cropped.companion() {
                            Some(comp) => comp,
                            None => return Vec::new(),
                        };
                        comp.complex_eigenvalues()
                            .as_slice()
                            .iter()
                            .map(|x| Complex::new(x.re, x.im))
                            .collect()
                    }
                };
                extend_roots(roots, zeros)
                    .iter()
                    .map(|r| (r.re.x, r.im.x))
                    .collect::<Vec<_>>()
            }
        }
    };
}
// Provide the companion-matrix based `real_roots` / `complex_roots` methods
// for the two floating point types.
eigen_roots!(f64);
eigen_roots!(f32);
impl<T> Poly<T>
where
    T: Abs
        + Add<Output = T>
        + Clone
        + Const
        + Cos
        + Div<Output = T>
        + Inv
        + Ln
        + Mul<Output = T>
        + Neg<Output = T>
        + NumCast
        + One
        + PartialOrd
        + Pow
        + Sign
        + Sin
        + Sqrt
        + Sub<Output = T>
        + Zero,
{
    /// Calculate the complex roots of the polynomial with the iterative
    /// Aberth-Ehrlich method, using `DEFAULT_ITERATIONS` as the sweep limit.
    ///
    /// Each root is returned as a `(real, imaginary)` pair.
    #[must_use]
    pub fn iterative_roots(&self) -> Vec<(T, T)> {
        self.iterative_roots_with_max(DEFAULT_ITERATIONS)
    }

    /// Calculate the complex roots of the polynomial with the iterative
    /// Aberth-Ehrlich method, performing at most `max_iter` sweeps.
    ///
    /// Zero roots are stripped first and appended at the end of the result.
    /// The remaining polynomial is solved in closed form when its degree is 1
    /// or 2 (`max_iter` is then unused); constants and the zero polynomial
    /// contribute no roots.
    #[must_use]
    pub fn iterative_roots_with_max(&self, max_iter: u32) -> Vec<(T, T)> {
        let (zero_count, reduced) = self.find_zero_roots();
        let nonzero_roots = match reduced.degree() {
            None | Some(0) => Vec::new(),
            Some(1) => reduced.complex_deg1_root(),
            Some(2) => reduced.complex_deg2_roots(),
            Some(_) => RootsFinder::new(reduced, max_iter).roots_finder(),
        };
        let all_roots = extend_roots(nonzero_roots, zero_count);
        all_roots
            .iter()
            .map(|root| (root.re.x.clone(), root.im.x.clone()))
            .collect()
    }
}
/// Append `zeros` copies of `T::zero()` to `roots` and return the vector.
fn extend_roots<T: Clone + Zero>(mut roots: Vec<T>, zeros: usize) -> Vec<T> {
    let target_len = roots.len() + zeros;
    roots.resize(target_len, T::zero());
    roots
}
impl<T: Clone + PartialEq + Zero> Poly<T> {
    /// Split off the roots at zero.
    ///
    /// Returns the number of leading zero coefficients together with a copy of
    /// the polynomial without them. The zero polynomial yields `(0, zero)`.
    fn find_zero_roots(&self) -> (usize, Self) {
        if self.is_zero() {
            (0, Poly::zero())
        } else {
            let count = self.zero_roots_count();
            let coeffs = self.coeffs().split_off(count);
            (count, Self { coeffs })
        }
    }

    /// In-place variant of `find_zero_roots`: remove the leading zero
    /// coefficients and return how many were removed (0 for the zero polynomial).
    // `dead_code` is allowed because, within this file, only the tests call it.
    #[allow(dead_code)]
    fn find_zero_roots_mut(&mut self) -> usize {
        let count = if self.is_zero() {
            0
        } else {
            self.zero_roots_count()
        };
        self.coeffs.drain(..count);
        count
    }

    /// Number of consecutive zero coefficients starting from the constant term.
    fn zero_roots_count(&self) -> usize {
        self.coeffs
            .iter()
            .position(|c| !c.is_zero())
            .unwrap_or(self.coeffs.len())
    }
}
impl<T> Poly<T>
where
    T: Abs
        + Add<Output = T>
        + Clone
        + Div<Output = T>
        + Inv
        + Mul<Output = T>
        + Neg<Output = T>
        + One
        + PartialOrd
        + Pow
        + Sign
        + Sqrt
        + Sub<Output = T>
        + Zero,
{
    /// Root `-a0 / a1` of a degree-one polynomial, as a complex number.
    pub(super) fn complex_deg1_root(&self) -> Vec<Complex<T>> {
        let root = -self[0].clone() / self[1].clone();
        vec![Complex::new(root, T::zero())]
    }

    /// Both roots of a degree-two polynomial, as complex numbers.
    pub(super) fn complex_deg2_roots(&self) -> Vec<Complex<T>> {
        let leading = self[2].clone();
        let linear = self[1].clone() / leading.clone();
        let constant = self[0].clone() / leading;
        let (first, second) = complex_quadratic_roots_impl(linear, constant);
        vec![first, second]
    }

    /// Root `-a0 / a1` of a degree-one polynomial.
    pub(super) fn real_deg1_root(&self) -> Vec<T> {
        let root = -self[0].clone() / self[1].clone();
        vec![root]
    }

    /// Both roots of a degree-two polynomial, or `None` if they are not real.
    pub(super) fn real_deg2_roots(&self) -> Option<Vec<T>> {
        let leading = self[2].clone();
        let linear = self[1].clone() / leading.clone();
        let constant = self[0].clone() / leading;
        real_quadratic_roots_impl(linear, constant).map(|(first, second)| vec![first, second])
    }
}
/// Roots of the monic quadratic `x^2 + b x + c`.
///
/// A zero discriminant gives a double root `-b/2`; a negative one gives the
/// conjugate pair `-b/2 ∓ i sqrt(-d)` (negative imaginary part first). Otherwise
/// the real roots are computed as `h = -(b/2 + sign(b) sqrt(d))` and `c / h`,
/// which avoids subtracting nearly equal quantities.
#[allow(clippy::many_single_char_names)]
pub(super) fn complex_quadratic_roots_impl<T>(b: T, c: T) -> (Complex<T>, Complex<T>)
where
    T: Abs
        + Add<Output = T>
        + Clone
        + Div<Output = T>
        + Inv
        + Mul<Output = T>
        + Neg<Output = T>
        + One
        + PartialOrd
        + Pow
        + Sign
        + Sqrt
        + Sub<Output = T>
        + Zero,
{
    let half_b = b.clone() / (T::one() + T::one());
    let discriminant = half_b.powi(2) - c.clone();

    if discriminant.is_zero() {
        let re = -half_b;
        return (
            Complex::new(re.clone(), T::zero()),
            Complex::new(re, T::zero()),
        );
    }

    if discriminant.is_sign_negative() {
        let im = (-discriminant).sqrt();
        let re = -half_b;
        return (Complex::new(re.clone(), -im.clone()), Complex::new(re, im));
    }

    let h = -(half_b + b.signum() * discriminant.sqrt());
    (
        Complex::new(c / h.clone(), T::zero()),
        Complex::new(h, T::zero()),
    )
}
/// Real roots of the monic quadratic `x^2 + b x + c`, or `None` when the
/// discriminant `(b/2)^2 - c` is negative.
///
/// Distinct roots are computed as `h = -(b/2 + sign(b) sqrt(d))` and `c / h`,
/// which avoids subtracting nearly equal quantities.
#[allow(clippy::many_single_char_names)]
pub(super) fn real_quadratic_roots_impl<T>(b: T, c: T) -> Option<(T, T)>
where
    T: Add<Output = T>
        + Clone
        + Div<Output = T>
        + Mul<Output = T>
        + Neg<Output = T>
        + One
        + Pow
        + Sign
        + Sqrt
        + Sub<Output = T>
        + Zero,
{
    let half_b = b.clone() / (T::one() + T::one());
    let discriminant = half_b.powi(2) - c.clone();

    if discriminant.is_zero() {
        let double_root = -half_b;
        return Some((double_root.clone(), double_root));
    }

    if discriminant.is_sign_negative() {
        return None;
    }

    let h = -(half_b + b.signum() * discriminant.sqrt());
    Some((c / h.clone(), h))
}
/// Unit tests for the root finding routines of this module.
#[cfg(test)]
mod tests {
    use crate::poly;

    use super::*;

    // The zero polynomial has no companion matrix.
    #[test]
    fn failing_companion() {
        let p = Poly::<f32>::zero();
        assert_eq!(None, p.companion());
    }

    // Closed-form real roots of x^2 + b x + c: distinct, complex (None) and
    // double-root cases; the order of the returned pair is pinned.
    #[test]
    fn quadratic_roots_with_real_values() {
        let root1 = -1.;
        let root2 = -2.;
        assert_eq!(Some((root1, root2)), real_quadratic_roots_impl(3., 2.));
        let root3 = 1.;
        let root4 = 2.;
        assert_eq!(Some((root3, root4)), real_quadratic_roots_impl(-3., 2.));
        assert_eq!(None, real_quadratic_roots_impl(-6., 10.));
        let root5 = 3.;
        assert_eq!(Some((root5, root5)), real_quadratic_roots_impl(-6., 9.));
    }

    // The zero polynomial has no real roots; a constant has no complex roots.
    #[test]
    fn none_roots_eigen() {
        let p: Poly<f32> = Poly::zero();
        let res = p.real_roots();
        assert!(res.is_none());
        let p = poly!(5.3_f64);
        let res = p.complex_roots();
        assert_eq!(0, res.len());
        assert!(res.is_empty());
    }

    // Degree one: 10 - 2x has root 5.
    #[test]
    fn real_1_root_eigen() {
        let p = poly!(10.0_f32, -2.);
        let r = p.real_roots().unwrap();
        assert_eq!(r.len(), 1);
        assert_relative_eq!(5., r[0]);
    }

    // Three real roots including zero; results are sorted before comparison.
    #[test]
    fn real_3_roots_eigen() {
        let roots = &[-1.0_f64, 0., 1.];
        let p = Poly::new_from_roots(roots);
        let mut sorted_roots = p.real_roots().unwrap();
        sorted_roots.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
        for (r, rr) in roots.iter().zip(&sorted_roots) {
            assert_relative_eq!(*r, *rr);
        }
    }

    // Degree one through the complex path: root (5, 0).
    #[test]
    fn complex_1_root_eigen() {
        let p = poly!(10.0_f64, -2.);
        let r = p.complex_roots();
        assert_eq!(r.len(), 1);
        assert_eq!((5., 0.), r[0]);
    }

    // Degree three through the companion matrix: only the count is checked.
    #[test]
    fn complex_3_roots_eigen() {
        let p = Poly::new_from_coeffs(&[1.0_f32, 0., 1.]) * poly!(2., 1.);
        assert_eq!(p.complex_roots().len(), 3);
    }

    // Closed-form complex roots of x^2 + b x + c, including conjugate pairs
    // (negative imaginary part first) and a double root.
    #[test]
    fn complex_2_roots() {
        let root1 = Complex::<f64>::new(-1., 0.);
        let root2 = Complex::<f64>::new(-2., 0.);
        assert_eq!((root1, root2), complex_quadratic_roots_impl(3., 2.));
        let root1 = Complex::<f64>::new(1., 0.);
        let root2 = Complex::<f64>::new(2., 0.);
        assert_eq!((root1, root2), complex_quadratic_roots_impl(-3., 2.));
        let root1 = Complex::<f64>::new(-0., -1.);
        let root2 = Complex::<f64>::new(-0., 1.);
        assert_eq!((root1, root2), complex_quadratic_roots_impl(0., 1.));
        let root1 = Complex::<f64>::new(3., -1.);
        let root2 = Complex::<f64>::new(3., 1.);
        assert_eq!((root1, root2), complex_quadratic_roots_impl(-6., 10.));
        let root1 = Complex::<f64>::new(3., 0.);
        assert_eq!((root1, root1), complex_quadratic_roots_impl(-6., 9.));
    }

    // Zero polynomial and constants yield no iterative roots.
    #[test]
    fn none_roots_iterative() {
        let p: Poly<f32> = Poly::zero();
        let res = p.iterative_roots();
        assert_eq!(0, res.len());
        assert!(res.is_empty());
        let p = poly!(5.3);
        let res = p.iterative_roots();
        assert_eq!(0, res.len());
        assert!(res.is_empty());
    }

    // Degree one is solved exactly in closed form.
    #[test]
    fn complex_1_roots_iterative() {
        let root = -12.4;
        let p = poly!(3.0 * root, 3.0);
        let res = p.iterative_roots();
        assert_eq!(1, res.len());
        let expected = (-root, 0.);
        assert_eq!(expected, res[0]);
    }

    // Degree two is solved in closed form; the root order is pinned.
    #[test]
    fn complex_2_roots_iterative() {
        let p = poly!(6., 5., 1.);
        let res = p.iterative_roots();
        assert_eq!(2, res.len());
        let expected1 = (-3., 0.);
        let expected2 = (-2., 0.);
        assert_eq!(expected2, res[0]);
        assert_eq!(expected1, res[1]);
    }

    // Degree three goes through RootsFinder: only the count is checked.
    #[test]
    fn complex_3_roots_iterative() {
        let p = Poly::new_from_coeffs(&[1.0_f32, 0., 1.]) * poly!(2., 1.);
        assert_eq!(p.iterative_roots().len(), 3);
    }

    // Zero roots are appended at the end of the result.
    #[test]
    fn complex_3_roots_with_zeros_iterative() {
        let p = Poly::new_from_coeffs(&[0.0_f32, 0., 1.]) * poly!(2., 1.);
        let mut roots = p.iterative_roots();
        assert_eq!(roots.len(), 3);
        assert_eq!(*roots.last().unwrap(), (0., 0.));
        roots.pop();
        assert_eq!(*roots.last().unwrap(), (0., 0.));
    }

    // Same as `none_roots_iterative` with an explicit iteration limit.
    #[test]
    fn none_roots_iterative_with_max() {
        let p: Poly<f32> = Poly::zero();
        let res = p.iterative_roots_with_max(5);
        assert_eq!(0, res.len());
        assert!(res.is_empty());
        let p = poly!(5.3);
        let res = p.iterative_roots_with_max(6);
        assert_eq!(0, res.len());
        assert!(res.is_empty());
    }

    // Same as `complex_1_roots_iterative` with an explicit iteration limit.
    #[test]
    fn complex_1_roots_iterative_with_max() {
        let root = -12.4;
        let p = poly!(3.0 * root, 3.0);
        let res = p.iterative_roots_with_max(5);
        assert_eq!(1, res.len());
        let expected = (-root, 0.);
        assert_eq!(expected, res[0]);
    }

    // Same as `complex_2_roots_iterative` with an explicit iteration limit.
    #[test]
    fn complex_2_roots_iterative_with_max() {
        let p = poly!(6., 5., 1.);
        let res = p.iterative_roots_with_max(6);
        assert_eq!(2, res.len());
        let expected1 = (-3., 0.);
        let expected2 = (-2., 0.);
        assert_eq!(expected2, res[0]);
        assert_eq!(expected1, res[1]);
    }

    // Same as `complex_3_roots_iterative` with an explicit iteration limit.
    #[test]
    fn complex_3_roots_iterative_with_max() {
        let p = Poly::new_from_coeffs(&[1.0_f32, 0., 1.]) * poly!(2., 1.);
        assert_eq!(p.iterative_roots_with_max(7).len(), 3);
    }

    // Leading zero coefficients are counted and stripped.
    #[test]
    fn remove_zero_roots() {
        let p = Poly::new_from_coeffs(&[0, 0, 1, 0, 2]);
        let (z, p2) = p.find_zero_roots();
        assert_eq!(2, z);
        assert_eq!(Poly::new_from_coeffs(&[1, 0, 2]), p2);
    }

    // In-place variant; the zero polynomial reports no zero roots.
    #[test]
    fn remove_zero_roots_mut() {
        let mut p = Poly::new_from_coeffs(&[0, 0, 1, 0, 2]);
        let z = p.find_zero_roots_mut();
        assert_eq!(2, z);
        assert_eq!(Poly::new_from_coeffs(&[1, 0, 2]), p);
        assert_eq!(0, Poly::<i8>::zero().find_zero_roots_mut());
    }

    // RootsFinder returns one approximation per root.
    #[test]
    fn iterative_roots_finder() {
        let roots = &[10.0_f32, 10. / 323.4, 1., -2., 3.];
        let poly = Poly::new_from_roots(roots);
        let rf = RootsFinder::new(poly, DEFAULT_ITERATIONS);
        let actual = rf.roots_finder();
        assert_eq!(roots.len(), actual.len());
    }

    // RootsFinder derives Debug.
    #[test]
    fn roots_finder_debug_string() {
        let poly = Poly::new_from_coeffs(&[1., 2.]);
        let rf = RootsFinder::new(poly, DEFAULT_ITERATIONS);
        let debug_str = format!("{:?}", &rf);
        assert!(
            !debug_str.is_empty(),
            "RootsFinder<T> structure must be debuggable if T: Debug."
        );
    }

    // Point2D accessors map to tuple fields 1 and 2.
    #[allow(clippy::float_cmp)]
    #[test]
    fn coeffpoint_implementation() {
        let cp = &CoeffPoint(1, 2., -3.);
        assert_eq!(2., cp.x());
        assert_eq!(-3., cp.y());
    }
}