use {
crate::StopReason,
::core::{fmt::Debug, iter::zip},
::num_complex::{Complex, ComplexFloat},
::num_traits::{
cast,
float::{Float, FloatConst},
identities::{One, Zero},
MulAdd,
},
};
/// Refines root approximations of `polynomial` with the Aberth–Ehrlich iteration.
///
/// * `polynomial` — coefficients in ascending order (`polynomial[i]` multiplies `x^i`).
/// * `dydx` — coefficients of the derivative of `polynomial`, same ordering.
/// * `initial_guesses` — one starting approximation per root. The slice is reused as
///   a scratch buffer and is overwritten.
/// * `out` — must have the same length as `initial_guesses`. When the function
///   returns `Converged` or `MaxIteration`, it holds the most recent approximations.
/// * `max_iterations` — upper bound on the number of sweeps over all approximations.
/// * `epsilon` — a sweep counts as converged when every approximation moved by
///   less than `epsilon` in both its real and its imaginary part.
///
/// Returns `Converged(i)` when sweep `i` converged, `Failed(i)` as soon as sweep `i`
/// produces a NaN or infinite approximation (both buffers are then left partially
/// updated), and `MaxIteration(max_iterations)` otherwise.
///
/// # Panics
///
/// Panics if `out.len() != initial_guesses.len()`.
pub fn aberth_raw<F: Float + MulAdd<Output = F>>(
    polynomial: &[Complex<F>],
    dydx: &[Complex<F>],
    initial_guesses: &mut [Complex<F>],
    out: &mut [Complex<F>],
    max_iterations: u32,
    epsilon: F,
) -> StopReason {
    // Seed `out` so it is meaningful even when `max_iterations == 0`.
    out.copy_from_slice(initial_guesses);

    // The two buffers alternate between "current" (`zs`) and "next" (`new_zs`)
    // approximations. `next_is_out` records whether `new_zs` currently aliases `out`,
    // so the newest results can be moved into `out` before returning.
    let mut zs = initial_guesses;
    let mut new_zs = out;
    let mut next_is_out = true;

    for iteration in 1..=max_iterations {
        let mut converged = true;
        for i in 0..zs.len() {
            let z = zs[i];
            let p_of_z = sample_polynomial(polynomial, z);
            let dydx_of_z = sample_polynomial(dydx, z);
            // Σ_{k ≠ i} 1 / (z_i - z_k), accumulated in index order.
            let sum = zs
                .iter()
                .enumerate()
                .filter(|&(k, _)| k != i)
                .fold(Complex::<F>::zero(), |acc, (_, &z_k)| {
                    acc + Complex::<F>::one() / (z - z_k)
                });
            // Aberth correction: z - (p/p') / (1 - (p/p')·sum) == z + p / (p·sum - p').
            let new_z = z + p_of_z / (p_of_z * sum - dydx_of_z);
            new_zs[i] = new_z;
            if new_z.re.is_nan()
                || new_z.im.is_nan()
                || new_z.re.is_infinite()
                || new_z.im.is_infinite()
            {
                return StopReason::Failed(iteration);
            }
            if !new_z.approx_eq(z, epsilon) {
                converged = false;
            }
        }
        if converged {
            // This sweep's results are in `new_zs`; if that is `initial_guesses`,
            // `zs` is `out`, so copy them across.
            if !next_is_out {
                zs.copy_from_slice(new_zs);
            }
            return StopReason::Converged(iteration);
        }
        core::mem::swap(&mut zs, &mut new_zs);
        next_is_out = !next_is_out;
    }

    // After the final swap the newest approximations live in `zs`; if `zs` is
    // `initial_guesses` (i.e. `new_zs` is `out`), copy them into `out`.
    if next_is_out {
        new_zs.copy_from_slice(zs);
    }
    StopReason::MaxIteration(max_iterations)
}
/// Evaluates a polynomial at `x` with Horner's scheme.
///
/// `coefficients` are in ascending order (`coefficients[i]` multiplies `x^i`) and
/// must be non-empty; this is only checked in debug builds.
// clippy prefers `!is_empty()`; the explicit length comparison is kept as written.
#[allow(clippy::len_zero)]
pub(crate) fn sample_polynomial<F: Float + MulAdd<Output = F>>(
    coefficients: &[Complex<F>],
    x: Complex<F>,
) -> Complex<F> {
    debug_assert!(coefficients.len() != 0);
    // Walk from the highest-degree coefficient down: acc <- acc * x + c.
    coefficients
        .iter()
        .rev()
        .fold(Complex::zero(), |acc, &c| acc.mul_add(x, c))
}
/// Writes the coefficients of the derivative of `polynomial` into `out`.
///
/// Coefficients are in ascending order (`polynomial[i]` multiplies `x^i`), so for
/// every `i >= 1` this sets `out[i - 1] = i * polynomial[i]`; the constant term
/// has no contribution.
///
/// # Panics
///
/// Panics if `out` is shorter than `polynomial.len() - 1`, or if a coefficient
/// index cannot be represented in `F` (cannot happen for `f32`/`f64`).
pub(crate) fn derivative<F: Float>(polynomial: &[Complex<F>], out: &mut [Complex<F>]) {
    for (index, coefficient) in polynomial.iter().enumerate().skip(1) {
        // `NumCast::from` is fallible for an arbitrary `Float` implementation;
        // `unwrap_unchecked` on `None` would be undefined behaviour, so check it.
        let power = F::from(index).expect("coefficient index must be representable in F");
        out[index - 1] = coefficient * power;
    }
}
/// Writes Aberth starting points for the roots of `polynomial` into `out[..n]`,
/// where `n = polynomial.len() - 1` is the degree.
///
/// `polynomial` holds coefficients in ascending order (`polynomial[i]` multiplies
/// `x^i`). `out` doubles as scratch space and must hold at least `polynomial.len()`
/// elements; only its first `n` entries are meaningful on return.
///
/// The guesses are spread evenly on a circle centred at `a = -c_{n-1} / n` (the mean
/// of the roots of the monic polynomial), starting at angle `π / (2n)`. The radius is
/// the smallest positive integer `r` with `s(r) > 0`, where
/// `s(w) = w^n - Σ_{i<n} |b_i| w^i` and `b_i` are the coefficients of the monic
/// polynomial shifted so that `a` becomes the origin (Cauchy's root bound).
///
/// Polynomials of degree < 1 have no roots; `out` is then left untouched.
///
/// # Panics
///
/// Panics if `out.len() < polynomial.len()`.
pub(crate) fn initial_guesses<F: Float + FloatConst + MulAdd<Output = F> + Debug>(
    polynomial: &[Complex<F>],
    out: &mut [Complex<F>],
) {
    // Without this guard `polynomial.len() - 1` / `n - 1` below would underflow.
    if polynomial.len() < 2 {
        return;
    }
    let n = polynomial.len() - 1;
    let n_f: F = cast(n).expect("polynomial degree must be representable in F");

    // 1. Normalise to a monic polynomial, using `out[..=n]` as scratch space.
    let monic = out;
    for (i, c) in polynomial.iter().enumerate() {
        monic[i] = c / polynomial[n];
    }

    // 2. Centre of the guesses: by Vieta, -c_{n-1} / n is the mean of the roots.
    let a: Complex<F> = -monic[n - 1] / n_f;

    // 3. In-place Taylor shift: p(w + a) = Σ_j c_j Σ_i C(j, i) a^(j-i) w^i.
    //    Processing j in increasing order is safe: step j reads monic[j] before
    //    touching it and only writes indices <= j.
    for coefficient_index in 0..=n {
        let c = monic[coefficient_index];
        monic[coefficient_index] = Complex::zero();
        for ((index, power), pascal) in zip(
            zip(0..=coefficient_index, (0..=coefficient_index).rev()),
            PascalRowIter::new(coefficient_index as u32),
        ) {
            let pascal: Complex<F> =
                cast(pascal).expect("binomial coefficient must be representable in F");
            monic[index] = MulAdd::mul_add(c, pascal * a.powi(power as i32), monic[index]);
        }
    }
    let p_of_w = monic;

    // 4. Comparison polynomial s(w) = w^n - Σ_{i<n} |b_i| w^i (leading term kept).
    let s_of_w = p_of_w;
    for coefficient in s_of_w.iter_mut().take(n) {
        *coefficient = Complex::from(-coefficient.abs());
    }

    // 5. Radius: smallest positive integer r with s(r) > 0.
    let r_0 = smallest_positive_integer_radius(s_of_w);

    // 6. Place n guesses evenly on the circle |z - a| = r_0.
    let guesses = s_of_w;
    let frac_2pi_n = F::TAU() / n_f;
    let frac_pi_2n = F::FRAC_PI_2() / n_f;
    for (k, guess) in guesses.iter_mut().enumerate().take(n) {
        let k_f: F = cast(k).expect("root index must be representable in F");
        let theta = MulAdd::mul_add(frac_2pi_n, k_f, frac_pi_2n);
        *guess = Complex::new(r_0 * theta.cos(), r_0 * theta.sin()) + a;
    }
}

/// Returns the smallest positive integer `r` (as `F`) with `Re s(r) > 0`, where `s`
/// has the shape `w^n - Σ_{i<n} |b_i| w^i` built by `initial_guesses`.
///
/// For `w > 0`, `s(w) / w^n = 1 - Σ |b_i| w^(i-n)` is increasing, so `s` is `<= 0` up
/// to a single positive root and `> 0` beyond it. That lets the answer be bracketed
/// by doubling and then bisected over integers: O(log r) evaluations instead of the
/// O(r) of a unit-step scan.
///
/// The search always terminates. If `s` evaluates to NaN (non-finite coefficients,
/// e.g. a zero leading coefficient upstream), or `r` overflows to infinity, the
/// current `r` is returned and the caller proceeds with non-finite guesses.
fn smallest_positive_integer_radius<F: Float + MulAdd<Output = F>>(s: &[Complex<F>]) -> F {
    let eval = |r: F| sample_polynomial(s, r.into()).re;

    // Exponential phase: `lo` is the last tested radius with s(lo) <= 0 (0 if none),
    // `hi` doubles until s(hi) > 0.
    let mut lo = F::zero();
    let mut hi = F::one();
    loop {
        let value = eval(hi);
        if value > F::zero() {
            break;
        }
        if value.is_nan() || hi.is_infinite() {
            return hi;
        }
        lo = hi;
        hi = hi + hi;
    }

    // Integer bisection, keeping s(lo) <= 0 < s(hi). Stops once lo and hi are
    // adjacent integers, or adjacent floats when integers are no longer exact.
    let two = F::one() + F::one();
    loop {
        let mid = lo + ((hi - lo) / two).floor();
        if mid <= lo || mid >= hi {
            break;
        }
        if eval(mid) > F::zero() {
            hi = mid;
        } else {
            lo = mid;
        }
    }
    hi
}
/// Iterator over row `n` of Pascal's triangle: `C(n, 0), C(n, 1), …, C(n, n)`.
///
/// Each entry is derived from the previous one with the multiplicative recurrence
/// `C(n, k) = C(n, k - 1) * (n - k + 1) / k`, evaluated in `f64`. Large rows are
/// therefore subject to `f64` rounding.
#[derive(Debug, Clone)]
pub(crate) struct PascalRowIter {
    /// Row of the triangle being produced.
    n: u32,
    /// Index of the next entry to yield (`n + 1` once exhausted).
    k: u32,
    /// Most recently yielded entry, `C(n, k - 1)`.
    previous: f64,
}

impl PascalRowIter {
    /// Creates an iterator over row `n`; it yields exactly `n + 1` values.
    pub fn new(n: u32) -> Self {
        Self {
            n,
            k: 0,
            previous: 1.0,
        }
    }
}

impl Iterator for PascalRowIter {
    type Item = f64;

    fn next(&mut self) -> Option<Self::Item> {
        if self.k == 0 {
            self.k = 1;
            self.previous = 1.0;
            return Some(1.0);
        }
        if self.k > self.n {
            return None;
        }
        // Here 1 <= k <= n, so `n - k + 1` cannot overflow. The previous
        // `n + 1 - k` did overflow for `n == u32::MAX`.
        let new = self.previous * f64::from(self.n - self.k + 1) / f64::from(self.k);
        self.k += 1;
        self.previous = new;
        Some(new)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Entries still to come: n + 1 - k (k == 0 before the first call).
        let remaining = u64::from(self.n) + 1 - u64::from(self.k);
        if remaining <= usize::MAX as u64 {
            let remaining = remaining as usize;
            (remaining, Some(remaining))
        } else {
            (usize::MAX, None)
        }
    }
}
/// Helper methods on complex numbers used by the root finder.
pub(crate) trait ComplexExt<F: Float> {
    /// Component-wise closeness test: `true` when both the real parts and the
    /// imaginary parts of `self` and `w` differ by strictly less than `epsilon`.
    /// Any NaN component makes the result `false`.
    fn approx_eq(self, w: Self, epsilon: F) -> bool;
}

impl<F: Float> ComplexExt<F> for Complex<F> {
    #[inline]
    fn approx_eq(self, w: Complex<F>, epsilon: F) -> bool {
        let close = |lhs: F, rhs: F| (lhs - rhs).abs() < epsilon;
        close(self.re, w.re) && close(self.im, w.im)
    }
}
pub(crate) use private::ComplexCoefficient;
mod private {
use super::*;
pub trait ComplexCoefficient<F: Float>: Copy + Into<Complex<F>> {}
impl ComplexCoefficient<f32> for f32 {}
impl ComplexCoefficient<f64> for f64 {}
impl ComplexCoefficient<f32> for Complex<f32> {}
impl ComplexCoefficient<f64> for Complex<f64> {}
}