use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{Float, One, Zero};
use scirs2_core::Complex;
use std::fmt::Debug;
use std::ops::{Add, Div, Mul, Sub};
use super::core::Polynomial;
/// Computes the roots of a polynomial of degree at most 2.
///
/// Coefficients are ordered from the highest-degree term down to the constant
/// term, so `[a, b, c]` represents `a*x^2 + b*x + c`. Leading zero coefficients
/// are skipped before the degree is determined. An all-zero polynomial is
/// treated as the constant `0`.
///
/// # Returns
///
/// * Degree 0 (a constant), or no coefficients at all: an empty array.
/// * Degree 1 (`a*x + b`): the single real root `-b / a`.
/// * Degree 2 with a non-negative discriminant `d`: two real roots, the
///   `(-b + sqrt(d)) / 2a` root first. They are computed in a form that does not
///   lose precision when `b*b` dominates `4ac`.
/// * Degree 2 with a negative discriminant: a complex-conjugate pair. The root
///   whose imaginary part is `sqrt(-d) / 2a` comes first.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] for degree > 2. Compute those roots
/// as the eigenvalues of the companion matrix instead.
pub fn roots<T>(p: &Polynomial<T>) -> Result<Array<Complex<T>>>
where
    T: Clone
        + Zero
        + One
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + PartialEq
        + Debug
        + std::ops::Neg<Output = T>
        + Float,
{
    let all = p.coefficients().to_vec();
    // Index of the first non-zero coefficient. If every coefficient is zero, keep
    // only the last one so the polynomial is treated as degree 0. The saturating
    // subtraction keeps an empty coefficient list empty instead of underflowing.
    let start = all
        .iter()
        .position(|c| *c != T::zero())
        .unwrap_or_else(|| all.len().saturating_sub(1));

    match &all[start..] {
        // Constant (or empty) polynomial: no roots.
        [] | [_] => Ok(Array::from_vec(Vec::new())),
        // a*x + b = 0  =>  x = -b / a
        [a, b] => Ok(Array::from_vec(vec![Complex::new(-*b / *a, T::zero())])),
        [a, b, c] => Ok(Array::from_vec(quadratic_roots(*a, *b, *c))),
        _ => Err(NumRs2Error::InvalidOperation(
            "For polynomial root-finding for degree > 2, please use the eigenvalues module to compute the eigenvalues of the companion matrix".to_string()
        )),
    }
}

/// Roots of `a*x^2 + b*x + c`, for `a != 0`.
///
/// For real roots, the textbook `(-b ± sqrt(d)) / 2a` subtracts two nearly equal
/// numbers when `b*b >> |4ac|`. That wipes out most significant digits of the
/// smaller root. Instead, `q = -(b + sign(b) * sqrt(d)) / 2` is formed by adding
/// same-signed terms. One root is `q / a`, and the other follows from Vieta's
/// relation `x1 * x2 = c / a` as `c / q`.
///
/// Roots are returned in the textbook order: the `+sqrt(d)` root first. For a
/// conjugate pair, the root with imaginary part `sqrt(-d) / 2a` comes first.
fn quadratic_roots<T: Float>(a: T, b: T, c: T) -> Vec<Complex<T>> {
    let zero = T::zero();
    let two = T::one() + T::one();
    let discriminant = b * b - two * two * a * c;

    if discriminant >= zero {
        let sqrt_d = discriminant.sqrt();
        // -b and the signed sqrt(d) have the same sign here, so the sum cannot cancel.
        let q = if b < zero {
            (-b + sqrt_d) / two
        } else {
            (-b - sqrt_d) / two
        };
        let (root1, root2) = if q == zero {
            // Only when b and d are (numerically) zero. Use the direct form, which
            // gives the double root -b / 2a without dividing by q.
            ((-b + sqrt_d) / (two * a), (-b - sqrt_d) / (two * a))
        } else if b < zero {
            // q / a is the "+sqrt(d)" root.
            (q / a, c / q)
        } else {
            // q / a is the "-sqrt(d)" root.
            (c / q, q / a)
        };
        vec![Complex::new(root1, zero), Complex::new(root2, zero)]
    } else {
        let real_part = -b / (two * a);
        let imag_part = (-discriminant).sqrt() / (two * a);
        vec![
            Complex::new(real_part, imag_part),
            Complex::new(real_part, -imag_part),
        ]
    }
}
/// Builds the monic polynomial whose roots are the elements of `roots`.
///
/// The result is `(x - r_0)(x - r_1)...(x - r_{n-1})`, with coefficients ordered
/// from the highest-degree term down to the constant term. An empty `roots`
/// array yields the constant polynomial `1`.
///
/// # Errors
///
/// Returns [`NumRs2Error::DimensionMismatch`] if `roots` is not one-dimensional.
pub fn poly<T>(roots: &Array<T>) -> Result<Polynomial<T>>
where
    T: Clone
        + Zero
        + One
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + PartialEq
        + std::ops::Neg<Output = T>,
{
    if roots.ndim() != 1 {
        return Err(NumRs2Error::DimensionMismatch(
            "poly requires 1D array of roots".to_string(),
        ));
    }

    // The linear factor (x - r), stored highest degree first as [1, -r].
    let linear_factor = |r: &T| Polynomial::new(vec![T::one(), -r.clone()]);

    let roots_vec = roots.to_vec();
    match roots_vec.split_first() {
        None => Ok(Polynomial::new(vec![T::one()])),
        // Seed the product with the first factor, then multiply in the rest.
        Some((first, rest)) => Ok(rest
            .iter()
            .fold(linear_factor(first), |product, r| product * linear_factor(r))),
    }
}
/// Returns the coefficient array of the monic polynomial with the given roots.
///
/// This is a thin wrapper around [`poly`]. The coefficients are ordered from the
/// highest-degree term down to the constant term.
///
/// # Errors
///
/// Propagates any error from [`poly`], such as a `roots` array that is not 1D.
pub fn polyfromroots<T>(roots: &Array<T>) -> Result<Array<T>>
where
    T: Clone
        + Zero
        + One
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + PartialEq
        + std::ops::Neg<Output = T>,
{
    poly(roots).map(|polynomial| Array::from_vec(polynomial.coefficients().to_vec()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_roots_linear() {
        let p = Polynomial::new(vec![2.0, -4.0]);
        let r = roots(&p).expect("Linear polynomial root finding should succeed");
        assert_eq!(r.len(), 1);
        assert_relative_eq!(r.to_vec()[0].re, 2.0, epsilon = 1e-10);
        assert_relative_eq!(r.to_vec()[0].im, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_roots_quadratic_real() {
        let p = Polynomial::new(vec![1.0, -3.0, 2.0]);
        let r = roots(&p).expect("Quadratic real roots finding should succeed");
        assert_eq!(r.len(), 2);
        let roots_vec = r.to_vec();
        let (r1, r2) = if roots_vec[0].re < roots_vec[1].re {
            (roots_vec[0], roots_vec[1])
        } else {
            (roots_vec[1], roots_vec[0])
        };
        assert_relative_eq!(r1.re, 1.0, epsilon = 1e-10);
        assert_relative_eq!(r2.re, 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_roots_quadratic_complex() {
        let p = Polynomial::new(vec![1.0, 0.0, 1.0]);
        let r = roots(&p).expect("Quadratic complex roots finding should succeed");
        assert_eq!(r.len(), 2);
        let roots_vec = r.to_vec();
        assert_relative_eq!(roots_vec[0].re, 0.0, epsilon = 1e-10);
        assert_relative_eq!(roots_vec[1].re, 0.0, epsilon = 1e-10);
        assert_relative_eq!(roots_vec[0].im.abs(), 1.0, epsilon = 1e-10);
        assert_relative_eq!(roots_vec[1].im.abs(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_roots_quadratic_double_root() {
        // x^2 - 2x + 1 = (x - 1)^2
        let p = Polynomial::new(vec![1.0, -2.0, 1.0]);
        let r = roots(&p).expect("Quadratic double root finding should succeed");
        assert_eq!(r.len(), 2);
        for root in r.to_vec() {
            assert_relative_eq!(root.re, 1.0, epsilon = 1e-10);
            assert_relative_eq!(root.im, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_roots_strips_leading_zeros() {
        // 0*x^2 + 2x - 4 is really the linear polynomial 2x - 4.
        let p = Polynomial::new(vec![0.0, 2.0, -4.0]);
        let r = roots(&p).expect("Leading zero coefficients should be ignored");
        assert_eq!(r.len(), 1);
        assert_relative_eq!(r.to_vec()[0].re, 2.0, epsilon = 1e-10);
        assert_relative_eq!(r.to_vec()[0].im, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_roots_constant_has_no_roots() {
        let p = Polynomial::new(vec![5.0]);
        let r = roots(&p).expect("Constant polynomial root finding should succeed");
        assert_eq!(r.len(), 0);
    }

    #[test]
    fn test_roots_cubic_is_unsupported() {
        let p = Polynomial::new(vec![1.0, 0.0, 0.0, -1.0]);
        assert!(roots(&p).is_err());
    }

    #[test]
    fn test_poly_from_roots() {
        let roots = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let p = poly(&roots).expect("Polynomial from roots construction should succeed");
        let coeffs = p.coefficients();
        assert_eq!(coeffs.len(), 4);
        assert_relative_eq!(coeffs[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(coeffs[1], -6.0, epsilon = 1e-10);
        assert_relative_eq!(coeffs[2], 11.0, epsilon = 1e-10);
        assert_relative_eq!(coeffs[3], -6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_polyfromroots() {
        let roots = Array::from_vec(vec![1.0, -1.0]);
        let coeffs =
            polyfromroots(&roots).expect("Polynomial coefficients from roots should succeed");
        let coeffs_vec = coeffs.to_vec();
        assert_eq!(coeffs_vec.len(), 3);
        assert_relative_eq!(coeffs_vec[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(coeffs_vec[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(coeffs_vec[2], -1.0, epsilon = 1e-10);
    }
}