use crate::error::{LinalgError, LinalgResult};
use scirs2_core::num_complex::Complex;
use scirs2_core::ndarray::{Array2, ScalarOperand};
/// Computes the complex roots of a real polynomial.
///
/// `coeffs` lists coefficients from the highest-degree term down to the
/// constant term: `[a_n, ..., a_1, a_0]` represents
/// `a_n x^n + ... + a_1 x + a_0`.
///
/// Leading coefficients with `|c| <= f64::EPSILON` (an absolute, not
/// scale-relative, threshold) are stripped first. Degree 1 is solved
/// directly, degree 2 with a cancellation-resistant quadratic formula, and
/// higher degrees via the eigenvalues of the companion matrix (see
/// [`companion_matrix`]). No particular ordering of the roots is guaranteed.
///
/// # Errors
///
/// Returns `LinalgError::InvalidInputError` when:
/// * fewer than two coefficients are supplied,
/// * any coefficient is NaN or infinite,
/// * every coefficient is (near) zero, or
/// * only a constant remains after stripping leading zeros.
///
/// Errors from the companion-matrix construction or the eigenvalue solver are
/// propagated unchanged.
pub fn poly_roots(coeffs: &[f64]) -> LinalgResult<Vec<Complex<f64>>> {
    if coeffs.len() < 2 {
        return Err(LinalgError::InvalidInputError(
            "Polynomial must have at least degree 1 (≥ 2 coefficients)".to_string(),
        ));
    }
    // NaN fails every `>` comparison, so without this check a NaN leading
    // coefficient would be silently stripped below as if it were zero, and
    // NaN/inf anywhere else would yield meaningless roots.
    if coeffs.iter().any(|c| !c.is_finite()) {
        return Err(LinalgError::InvalidInputError(
            "Polynomial coefficients must be finite (no NaN or infinity)".to_string(),
        ));
    }
    let start = coeffs
        .iter()
        .position(|&c| c.abs() > f64::EPSILON)
        .ok_or_else(|| LinalgError::InvalidInputError("All coefficients are zero".to_string()))?;
    let coeffs = &coeffs[start..];
    if coeffs.len() < 2 {
        return Err(LinalgError::InvalidInputError(
            "Polynomial (after stripping leading zeros) has degree 0; no roots".to_string(),
        ));
    }
    match coeffs.len() {
        // a x + b = 0  =>  x = -b / a
        2 => Ok(vec![Complex::new(-(coeffs[1] / coeffs[0]), 0.0)]),
        3 => quadratic_roots_f64(coeffs[0], coeffs[1], coeffs[2]),
        _ => {
            let comp = companion_matrix(coeffs)?;
            companion_eigenvalues_f64(&comp)
        }
    }
}
/// Builds the companion matrix of a polynomial.
///
/// For `coeffs = [a_n, ..., a_1, a_0]` (highest degree first) the result is an
/// `n x n` matrix with ones on the sub-diagonal and a last column whose row
/// `i` holds `-a_i / a_n`. All other entries are zero.
///
/// # Errors
///
/// `LinalgError::InvalidInputError` if fewer than two coefficients are given,
/// or if `|a_n| < f64::EPSILON` (an absolute threshold, independent of the
/// scale of the other coefficients).
pub fn companion_matrix(coeffs: &[f64]) -> LinalgResult<Array2<f64>> {
    let (lead, tail) = match coeffs.split_first() {
        Some((&first, rest)) if !rest.is_empty() => (first, rest),
        _ => {
            return Err(LinalgError::InvalidInputError(
                "Polynomial must have degree >= 1".to_string(),
            ))
        }
    };
    if lead.abs() < f64::EPSILON {
        return Err(LinalgError::InvalidInputError(
            "Leading coefficient must be non-zero".to_string(),
        ));
    }

    let size = tail.len();
    let mut matrix = Array2::<f64>::zeros((size, size));

    // Ones on the sub-diagonal.
    for row in 1..size {
        matrix[[row, row - 1]] = 1.0;
    }

    // Walking the tail backwards visits a_0, a_1, ..., a_{n-1}, which is
    // exactly the row order of the last column.
    let last_col = size - 1;
    for (row, &coef) in tail.iter().rev().enumerate() {
        matrix[[row, last_col]] = -(coef / lead);
    }

    Ok(matrix)
}
/// Evaluates a real-coefficient polynomial at a complex point using Horner's scheme.
///
/// `coeffs` are ordered from the highest-degree term to the constant term.
/// An empty coefficient slice evaluates to `0 + 0i`.
pub fn poly_eval_complex(coeffs: &[f64], z: Complex<f64>) -> Complex<f64> {
    match coeffs.split_first() {
        None => Complex::new(0.0, 0.0),
        Some((&leading, rest)) => rest
            .iter()
            .fold(Complex::new(leading, 0.0), |partial, &coef| {
                partial * z + Complex::new(coef, 0.0)
            }),
    }
}
/// Multiplies two polynomials given as coefficient slices (discrete convolution).
///
/// Both inputs must use the same coefficient ordering; the product uses that
/// ordering too. The result has `a.len() + b.len() - 1` entries, or is empty
/// when either input is empty.
pub fn poly_mul(a: &[f64], b: &[f64]) -> Vec<f64> {
    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }
    let mut product = vec![0.0f64; a.len() + b.len() - 1];
    for (shift, &lhs) in a.iter().enumerate() {
        // Each term of `a` contributes a scaled copy of `b`, offset by `shift`.
        product[shift..]
            .iter_mut()
            .zip(b)
            .for_each(|(slot, &rhs)| *slot += lhs * rhs);
    }
    product
}
/// Expands `(x - r_1)(x - r_2)...(x - r_k)` into monic coefficients, highest degree first.
///
/// The product is formed in complex arithmetic and only the real part of
/// each coefficient is returned. Imaginary parts are discarded, so the
/// result is only meaningful when the roots are closed under complex
/// conjugation (up to rounding). An empty `roots` slice yields `[1.0]`.
pub fn char_poly_from_roots(roots: &[Complex<f64>]) -> Vec<f64> {
    let zero = Complex::new(0.0, 0.0);
    let expanded = roots
        .iter()
        .fold(vec![Complex::new(1.0, 0.0)], |poly, &root| {
            // Multiply the running polynomial by (x - root).
            let mut next = vec![zero; poly.len() + 1];
            for (k, &coef) in poly.iter().enumerate() {
                next[k] = next[k] + coef;
                next[k + 1] = next[k + 1] - coef * root;
            }
            next
        });
    expanded.into_iter().map(|coef| coef.re).collect()
}
/// Polishes approximate polynomial roots with Laguerre's method.
///
/// Each entry of `initial_roots` is refined independently. Previously
/// converged roots are not deflated, so two starting points may converge to
/// the same root.
///
/// # Arguments
///
/// * `coeffs` – real coefficients, highest degree first. The degree used in
///   Laguerre's formula is `coeffs.len() - 1`; leading zeros are not stripped.
/// * `initial_roots` – starting points, e.g. the output of [`poly_roots`].
/// * `max_iter` – maximum number of Laguerre steps per root.
/// * `tol` – iteration stops when a step satisfies `|dz| < tol`, or when the
///   residual satisfies `|p(z)| < tol * 1e-2` (or is exactly zero).
///
/// # Returns
///
/// One value per initial root, in the same order. Non-convergence within
/// `max_iter` steps is not an error. In that case, and when the Laguerre
/// denominator vanishes or a step would be NaN/infinite, the last finite
/// iterate is returned.
///
/// # Errors
///
/// `LinalgError::InvalidInputError` if `coeffs` has fewer than two entries.
pub fn refine_roots_laguerre(
    coeffs: &[f64],
    initial_roots: &[Complex<f64>],
    max_iter: usize,
    tol: f64,
) -> LinalgResult<Vec<Complex<f64>>> {
    if coeffs.len() < 2 {
        return Err(LinalgError::InvalidInputError(
            "Polynomial must have at least degree 1".to_string(),
        ));
    }
    let degree = coeffs.len() - 1;
    let n_complex = Complex::new(degree as f64, 0.0);
    // Loop-invariant factor (n - 1) of Laguerre's square-root term.
    let n_minus_one = n_complex - Complex::new(1.0, 0.0);

    let coeffs_c: Vec<Complex<f64>> = coeffs.iter().map(|&c| Complex::new(c, 0.0)).collect();
    // Coefficients of p'(z): a_i scaled by its power (degree - i), dropping the constant term.
    let deriv1: Vec<Complex<f64>> = coeffs_c[..degree]
        .iter()
        .enumerate()
        .map(|(i, &c)| c * Complex::new((degree - i) as f64, 0.0))
        .collect();
    // Coefficients of p''(z), derived the same way from p'(z).
    let deriv2: Vec<Complex<f64>> = if degree >= 2 {
        deriv1[..degree - 1]
            .iter()
            .enumerate()
            .map(|(i, &c)| c * Complex::new((degree - 1 - i) as f64, 0.0))
            .collect()
    } else {
        vec![]
    };

    let refined: Vec<Complex<f64>> = initial_roots
        .iter()
        .map(|&z0| {
            let mut z = z0;
            for _ in 0..max_iter {
                let pz = poly_eval_complex_c(&coeffs_c, z);
                let residual = pz.norm();
                // An exact zero must stop here: with tol <= 0 the tolerance test
                // never fires, and dividing by pz below would turn an exact root
                // into NaN.
                if residual == 0.0 || residual < tol * 1e-2 {
                    break;
                }
                let dpz = poly_eval_complex_c(&deriv1, z);
                let d2pz = poly_eval_complex_c(&deriv2, z);

                // G = p'/p, H = G^2 - p''/p (standard Laguerre notation).
                let g = dpz / pz;
                let h = g * g - d2pz / pz;
                let sqrt_inner = complex_sqrt(n_minus_one * (n_complex * h - g * g));

                // Choose the sign that maximises |denominator| for the smallest step.
                let denom_plus = g + sqrt_inner;
                let denom_minus = g - sqrt_inner;
                let denom = if denom_plus.norm() >= denom_minus.norm() {
                    denom_plus
                } else {
                    denom_minus
                };
                if denom.norm() < f64::EPSILON {
                    break;
                }

                let dz = n_complex / denom;
                // Overflow or NaN in the step would poison z for every later
                // iteration; keep the last finite iterate instead.
                if !(dz.re.is_finite() && dz.im.is_finite()) {
                    break;
                }
                z = z - dz;
                if dz.norm() < tol {
                    break;
                }
            }
            z
        })
        .collect();

    Ok(refined)
}
/// Roots of `a x^2 + b x + c`; the caller must ensure `a != 0`.
///
/// Real roots use the cancellation-resistant form
/// `q = -(b + sign(b) * sqrt(disc)) / 2`, `x1 = q / a`, `x2 = c / q`.
/// A negative discriminant yields a complex-conjugate pair.
/// When `q` is zero, `a x^2 + c = 0` is solved directly. This happens only
/// when `b == 0` and `disc == 0`, i.e. `c == 0` or `4ac` underflowed.
/// Always returns `Ok` with exactly two roots.
fn quadratic_roots_f64(a: f64, b: f64, c: f64) -> LinalgResult<Vec<Complex<f64>>> {
    let disc = b * b - 4.0 * a * c;
    if disc >= 0.0 {
        let sqrt_disc = disc.sqrt();
        let sign = if b >= 0.0 { 1.0 } else { -1.0 };
        let q = -(b + sign * sqrt_disc) / 2.0;
        if q == 0.0 {
            // `c / q` would be 0/0 = NaN (or ±inf after underflow), e.g. for x^2 = 0.
            // With b == 0 the roots are simply ±sqrt(-c / a).
            let ratio = -c / a;
            return Ok(if ratio >= 0.0 {
                let r = ratio.sqrt();
                vec![Complex::new(r, 0.0), Complex::new(-r, 0.0)]
            } else {
                let r = (-ratio).sqrt();
                vec![Complex::new(0.0, r), Complex::new(0.0, -r)]
            });
        }
        let r1 = q / a;
        let r2 = c / q;
        Ok(vec![Complex::new(r1, 0.0), Complex::new(r2, 0.0)])
    } else {
        let sqrt_disc = (-disc).sqrt();
        let re = -b / (2.0 * a);
        let im = sqrt_disc / (2.0 * a);
        Ok(vec![Complex::new(re, im), Complex::new(re, -im)])
    }
}
/// Returns the eigenvalues of `comp` as a `Vec`, i.e. the roots of the
/// polynomial whose companion matrix `comp` is.
///
/// Delegates to `crate::eigen::eig`. Only the first element of its result is
/// kept; the second (presumably eigenvectors — confirm in `crate::eigen`) is
/// dropped. Errors from `eig` are propagated with `?`.
///
/// NOTE(review): the `f64: ScalarOperand` bound looks trivially satisfied and
/// could likely be removed, along with the `ScalarOperand` import.
fn companion_eigenvalues_f64(comp: &Array2<f64>) -> LinalgResult<Vec<Complex<f64>>>
where
    f64: ScalarOperand,
{
    use crate::eigen::eig;
    let matrix_view = comp.view();
    let (values, _) = eig(&matrix_view, None)?;
    Ok(values.to_vec())
}
/// Horner evaluation of a complex-coefficient polynomial (highest degree first).
///
/// An empty coefficient slice evaluates to `0 + 0i`.
fn poly_eval_complex_c(coeffs: &[Complex<f64>], z: Complex<f64>) -> Complex<f64> {
    match coeffs.split_first() {
        None => Complex::new(0.0, 0.0),
        Some((&leading, rest)) => rest.iter().fold(leading, |partial, &coef| partial * z + coef),
    }
}
/// Square root of `z` computed in polar form: `sqrt(|z|) * e^{i * arg(z) / 2}`.
///
/// Assuming `arg` is the atan2-based principal argument, this is the
/// principal square root.
fn complex_sqrt(z: Complex<f64>) -> Complex<f64> {
    let half_angle = z.arg() / 2.0;
    let radius = z.norm().sqrt();
    let real_part = radius * half_angle.cos();
    let imag_part = radius * half_angle.sin();
    Complex::new(real_part, imag_part)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Sorts roots by ascending real part (test inputs never produce NaN).
    fn sort_by_re(roots: &mut [Complex<f64>]) {
        roots.sort_by(|lhs, rhs| lhs.re.partial_cmp(&rhs.re).expect("cmp"));
    }

    #[test]
    fn test_poly_roots_linear() {
        // 2x - 6 = 0 has the single root x = 3.
        let found = poly_roots(&[2.0, -6.0]).expect("linear root");
        assert_eq!(found.len(), 1);
        let root = found[0];
        assert_relative_eq!(root.re, 3.0, epsilon = 1e-12);
        assert_relative_eq!(root.im, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_poly_roots_quadratic_real() {
        // x^2 - 1 = 0 has roots -1 and 1.
        let mut found = poly_roots(&[1.0, 0.0, -1.0]).expect("quadratic roots");
        sort_by_re(&mut found);
        assert_eq!(found.len(), 2);
        let expected = [-1.0, 1.0];
        for (root, &want) in found.iter().zip(expected.iter()) {
            assert_relative_eq!(root.re, want, epsilon = 1e-12);
        }
        for root in &found {
            assert_relative_eq!(root.im, 0.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_poly_roots_cubic_unity() {
        // x^3 - 1: every returned root must make the polynomial vanish.
        let cubic = [1.0, 0.0, 0.0, -1.0];
        let found = poly_roots(&cubic).expect("cubic roots");
        assert_eq!(found.len(), 3);
        for root in &found {
            let pval = poly_eval_complex(&cubic, *root);
            assert!(
                pval.norm() < 1e-8,
                "root {} does not satisfy p(r)=0: p(r) = {}",
                root,
                pval
            );
        }
    }

    #[test]
    fn test_poly_roots_quadratic_complex() {
        // x^2 + 1 = 0 has roots ±i.
        let found = poly_roots(&[1.0, 0.0, 1.0]).expect("complex roots");
        assert_eq!(found.len(), 2);
        for root in &found {
            assert_relative_eq!(root.re, 0.0, epsilon = 1e-12);
            assert_relative_eq!(root.im.abs(), 1.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_poly_roots_cubic_real() {
        // (x - 1)(x - 2)(x - 3) = x^3 - 6x^2 + 11x - 6.
        let mut found = poly_roots(&[1.0, -6.0, 11.0, -6.0]).expect("cubic roots");
        sort_by_re(&mut found);
        assert_eq!(found.len(), 3);
        for root in &found {
            assert_relative_eq!(root.im.abs(), 0.0, epsilon = 1e-6);
        }
        let expected = [1.0, 2.0, 3.0];
        for (root, &want) in found.iter().zip(expected.iter()) {
            assert_relative_eq!(root.re, want, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_poly_roots_all_zero_error() {
        let zeros = [0.0, 0.0, 0.0];
        assert!(poly_roots(&zeros).is_err());
    }

    #[test]
    fn test_poly_roots_too_few_coeffs_error() {
        let constant = [1.0];
        assert!(poly_roots(&constant).is_err());
    }

    #[test]
    fn test_companion_matrix_quadratic() {
        // x^2 - 5x + 6: ones on the sub-diagonal, last column [-6, 5].
        let comp = companion_matrix(&[1.0, -5.0, 6.0]).expect("companion");
        assert_eq!(comp.shape(), &[2, 2]);
        let expected: [((usize, usize), f64); 4] =
            [((0, 0), 0.0), ((1, 0), 1.0), ((0, 1), -6.0), ((1, 1), 5.0)];
        for &((row, col), want) in &expected {
            assert_relative_eq!(comp[[row, col]], want, epsilon = 1e-14);
        }
    }

    #[test]
    fn test_companion_matrix_non_monic() {
        // 2x^2 - 10x + 12 normalises to the same last column as x^2 - 5x + 6.
        let comp = companion_matrix(&[2.0, -10.0, 12.0]).expect("companion non-monic");
        assert_eq!(comp.shape(), &[2, 2]);
        let expected: [((usize, usize), f64); 2] = [((0, 1), -6.0), ((1, 1), 5.0)];
        for &((row, col), want) in &expected {
            assert_relative_eq!(comp[[row, col]], want, epsilon = 1e-14);
        }
    }

    #[test]
    fn test_poly_eval_at_root() {
        let one = Complex::new(1.0, 0.0);
        let value = poly_eval_complex(&[1.0, 0.0, -1.0], one);
        assert_relative_eq!(value.re, 0.0, epsilon = 1e-14);
        assert_relative_eq!(value.im, 0.0, epsilon = 1e-14);
    }

    #[test]
    fn test_poly_eval_at_complex_root() {
        let imag_unit = Complex::new(0.0, 1.0);
        let val = poly_eval_complex(&[1.0, 0.0, 1.0], imag_unit);
        assert!(val.norm() < 1e-14, "p(i) should be 0, got {}", val);
    }

    #[test]
    fn test_poly_eval_constant() {
        let point = Complex::new(3.0, 2.0);
        let value = poly_eval_complex(&[5.0], point);
        assert_relative_eq!(value.re, 5.0, epsilon = 1e-14);
        assert_relative_eq!(value.im, 0.0, epsilon = 1e-14);
    }

    #[test]
    fn test_poly_mul_linear_factors() {
        // (x - 1)(x - 2) = x^2 - 3x + 2.
        let product = poly_mul(&[1.0, -1.0], &[1.0, -2.0]);
        assert_eq!(product.len(), 3);
        let expected = [1.0, -3.0, 2.0];
        for (&got, &want) in product.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 1e-14);
        }
    }

    #[test]
    fn test_poly_mul_by_constant() {
        // 3 * (x + 1) = 3x + 3.
        let product = poly_mul(&[3.0], &[1.0, 1.0]);
        assert_eq!(product.len(), 2);
        for &got in &product {
            assert_relative_eq!(got, 3.0, epsilon = 1e-14);
        }
    }

    #[test]
    fn test_poly_mul_empty() {
        let product = poly_mul(&[], &[1.0, 2.0]);
        assert!(product.is_empty());
    }

    #[test]
    fn test_char_poly_two_real_roots() {
        // Roots 1 and 2 give x^2 - 3x + 2.
        let roots = vec![Complex::new(1.0, 0.0), Complex::new(2.0, 0.0)];
        let poly = char_poly_from_roots(&roots);
        assert_eq!(poly.len(), 3);
        let expected = [1.0, -3.0, 2.0];
        for (&got, &want) in poly.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_char_poly_complex_conjugate_pair() {
        // Roots ±i give x^2 + 1.
        let roots = vec![Complex::new(0.0, 1.0), Complex::new(0.0, -1.0)];
        let poly = char_poly_from_roots(&roots);
        assert_eq!(poly.len(), 3);
        let expected = [1.0, 0.0, 1.0];
        for (&got, &want) in poly.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_char_poly_single_root() {
        // Root 5 gives x - 5.
        let roots = vec![Complex::new(5.0, 0.0)];
        let poly = char_poly_from_roots(&roots);
        assert_eq!(poly.len(), 2);
        let expected = [1.0, -5.0];
        for (&got, &want) in poly.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_laguerre_refine_quadratic() {
        let coeffs = [1.0_f64, 0.0, -1.0];
        let initial = poly_roots(&coeffs).expect("initial");
        let refined = refine_roots_laguerre(&coeffs, &initial, 50, 1e-14).expect("refined");
        let residuals = refined.iter().map(|&r| (r, poly_eval_complex(&coeffs, r)));
        for (r, pval) in residuals {
            assert!(
                pval.norm() < 1e-12,
                "Laguerre-refined root does not satisfy p(r)=0: p({}) = {}",
                r,
                pval
            );
        }
    }

    #[test]
    fn test_laguerre_refine_cubic() {
        let coeffs = [1.0_f64, -6.0, 11.0, -6.0];
        let initial = poly_roots(&coeffs).expect("initial");
        let refined = refine_roots_laguerre(&coeffs, &initial, 100, 1e-14).expect("refined");
        assert_eq!(refined.len(), 3);
        let residuals = refined.iter().map(|&r| (r, poly_eval_complex(&coeffs, r)));
        for (r, pval) in residuals {
            assert!(
                pval.norm() < 1e-10,
                "Refined root does not satisfy p(r)=0: p({}) = {}",
                r,
                pval
            );
        }
    }

    #[test]
    fn test_laguerre_refine_complex_roots() {
        // x^2 + 4 has roots ±2i.
        let coeffs = [1.0_f64, 0.0, 4.0];
        let initial = poly_roots(&coeffs).expect("initial");
        let refined = refine_roots_laguerre(&coeffs, &initial, 50, 1e-14).expect("refined");
        let residuals = refined.iter().map(|&r| (r, poly_eval_complex(&coeffs, r)));
        for (r, pval) in residuals {
            assert!(
                pval.norm() < 1e-12,
                "Laguerre-refined root does not satisfy p(r)=0: p({}) = {}",
                r,
                pval
            );
        }
    }

    #[test]
    fn test_roundtrip_char_poly_and_roots() {
        // Expand (x-1)(x-2)(x-3)(x-4), then recover the roots from the coefficients.
        let expected = [1.0, 2.0, 3.0, 4.0];
        let known_roots: Vec<Complex<f64>> =
            expected.iter().map(|&re| Complex::new(re, 0.0)).collect();
        let coeffs = char_poly_from_roots(&known_roots);
        let mut found_roots = poly_roots(&coeffs).expect("roots from char poly");
        sort_by_re(&mut found_roots);
        assert_eq!(found_roots.len(), 4);
        for (r, &e) in found_roots.iter().zip(expected.iter()) {
            assert_relative_eq!(r.re, e, epsilon = 1e-6);
            assert_relative_eq!(r.im.abs(), 0.0, epsilon = 1e-6);
        }
    }
}