use scirs2_core::numeric::Float;
use std::f64::consts::PI;
use crate::error::{LinalgError, LinalgResult};
/// Return type of `zolotarev_filter`: a vector of complex shifts stored as `(re, im)`
/// pairs, and a vector of the corresponding real weights (one per shift).
type ZolotarevFilterResult<F> = LinalgResult<(Vec<(F, F)>, Vec<F>)>;
/// A rational approximation `r(x)` intended to approximate `sign(x)` for
/// `delta <= |x| <= 1`, as built by `zolotarev_sign`.
///
/// The same function is stored in two equivalent forms:
/// * product form (used by `evaluate_rational`):
///   `r(x) = scale · x · ∏_i (x² + numerator_sq[i]) / ∏_j (x² + denominator_sq[j])`;
/// * partial-fraction form: `r(x) = Σ_j residues[j].0 · x / (x² + denominator_sq[j])`.
#[derive(Clone, Debug)]
pub struct ZolotarevApproximation<F: Float> {
/// Poles as `(re, im)` pairs; `zolotarev_sign` stores `(0, sqrt(denominator_sq[j]))`,
/// i.e. the upper-half-plane member of each conjugate pair `±i·sqrt(denominator_sq[j])`.
pub poles: Vec<(F, F)>,
/// Partial-fraction coefficients as `(re, im)` pairs; `zolotarev_sign` produces purely
/// real values (`im = 0`).
pub residues: Vec<(F, F)>,
/// Squared numerator constants `a_i²` (`degree - 1` entries when built by `zolotarev_sign`).
pub numerator_sq: Vec<F>,
/// Squared denominator constants `b_j²` (`degree` entries when built by `zolotarev_sign`).
pub denominator_sq: Vec<F>,
/// Overall multiplicative factor; `zolotarev_sign` chooses it so that `r(1) = 1`.
pub scale: F,
/// Degree `n` of the approximation (number of denominator factors / poles).
pub degree: usize,
/// Lower edge of the approximation interval `[delta, 1]` (mirrored for negative `x`).
pub delta: F,
/// A-priori error estimate computed by `zolotarev_sign` as
/// `4·exp(-degree·π·K(delta) / K(sqrt(1 - delta²)))`.
pub max_error: F,
}
/// Complete elliptic integral of the first kind,
/// `K(k) = ∫₀^{π/2} dθ / √(1 − k² sin² θ)`, for a modulus `k ∈ [0, 1)`.
///
/// Evaluated with the arithmetic–geometric mean identity
/// `K(k) = π / (2 · AGM(1, √(1 − k²)))`. The AGM is iterated (at most 100 steps)
/// until the arithmetic and geometric means agree to a relative `1e-15`.
///
/// # Errors
/// Returns `LinalgError::DomainError` when `k` lies outside `[0, 1)`, NaN included.
fn complete_elliptic_k(k: f64) -> LinalgResult<f64> {
    if !(0.0..1.0).contains(&k) {
        return Err(LinalgError::DomainError(format!(
            "Elliptic integral modulus k must be in [0, 1), got {}",
            k
        )));
    }
    if k == 0.0 {
        return Ok(PI / 2.0);
    }

    // (arithmetic mean, geometric mean), seeded with (1, k').
    let mut means = (1.0, (1.0 - k * k).sqrt());
    let mut steps_left = 100;
    while steps_left > 0 {
        let (arith, geom) = means;
        let next_arith = (arith + geom) / 2.0;
        let next_geom = (arith * geom).sqrt();
        if (next_arith - next_geom).abs() < 1e-15 * next_arith {
            return Ok(PI / (2.0 * next_arith));
        }
        means = (next_arith, next_geom);
        steps_left -= 1;
    }

    // Iteration budget exhausted: use the midpoint of the final pair.
    let (arith, geom) = means;
    Ok(PI / (2.0 * ((arith + geom) / 2.0)))
}
fn jacobi_elliptic(u: f64, k: f64) -> LinalgResult<(f64, f64, f64)> {
if k.abs() < 1e-15 {
return Ok((u.sin(), u.cos(), 1.0));
}
if (k - 1.0).abs() < 1e-15 {
let sn = u.tanh();
let cn = 1.0 / u.cosh();
return Ok((sn, cn, cn));
}
let mut a = vec![1.0];
let mut b = vec![(1.0 - k * k).sqrt()]; let mut c = vec![k];
for _ in 0..50 {
let a_prev = *a.last().unwrap_or(&1.0);
let b_prev = *b.last().unwrap_or(&1.0);
let a_new = (a_prev + b_prev) / 2.0;
let b_new = (a_prev * b_prev).sqrt();
let c_new = (a_prev - b_prev) / 2.0;
a.push(a_new);
b.push(b_new);
c.push(c_new);
if c_new.abs() < 1e-15 {
break;
}
}
let n = a.len() - 1;
let mut phi = (1u64 << n.min(62)) as f64 * a[n] * u;
for j in (0..n).rev() {
phi = (phi + (c[j + 1] / a[j + 1] * phi.sin()).asin()) / 2.0;
}
let sn = phi.sin();
let cn = phi.cos();
let dn = (1.0 - k * k * sn * sn).max(0.0).sqrt();
Ok((sn, cn, dn))
}
/// Jacobi elliptic sine `sn(u, k)`: the first component of `jacobi_elliptic(u, k)`,
/// with any error from that call propagated unchanged.
fn jacobi_sn(u: f64, k: f64) -> LinalgResult<f64> {
    jacobi_elliptic(u, k).map(|(sn, _cn, _dn)| sn)
}
/// Jacobi elliptic cosine `cn(u, k)`: the second component of `jacobi_elliptic(u, k)`,
/// with any error from that call propagated unchanged.
fn jacobi_cn(u: f64, k: f64) -> LinalgResult<f64> {
    jacobi_elliptic(u, k).map(|(_sn, cn, _dn)| cn)
}
/// Jacobi delta amplitude `dn(u, k)`: the third component of `jacobi_elliptic(u, k)`,
/// with any error from that call propagated unchanged.
fn jacobi_dn(u: f64, k: f64) -> LinalgResult<f64> {
    jacobi_elliptic(u, k).map(|(_sn, _cn, dn)| dn)
}
/// Build a degree-`degree` rational approximation `r(x) ≈ sign(x)` for
/// `delta <= |x| <= 1`.
///
/// With `K' = K(delta)` and `u_m = m·K' / (2n)` for `m = 1, …, 2n − 1` (`n = degree`),
/// the constants `c_m = cn(u_m)·dn(u_m) / sn(u_m)` (Jacobi functions of modulus
/// `delta`) define
///
/// ```text
/// r(x) = scale · x · ∏_{j=1}^{n-1} (x² + c_{2j}²) / ∏_{j=0}^{n-1} (x² + c_{2j+1}²)
/// ```
///
/// where `scale` is chosen so that `r(1) = 1`. The residues `alpha_j` of the
/// equivalent partial-fraction form `r(x) = Σ_j alpha_j · x / (x² + c_{2j+1}²)` and
/// the poles `±i·c_{2j+1}` (upper-half-plane member stored) are returned as well.
/// `max_error` is the estimate `4·exp(−n·π·K(delta) / K(√(1 − delta²)))`.
///
/// # Errors
/// `DomainError` if `degree == 0` or `delta` is not in the open interval `(0, 1)`
/// (NaN included); any error propagated from the elliptic-function helpers.
pub fn zolotarev_sign<F>(degree: usize, delta: f64) -> LinalgResult<ZolotarevApproximation<F>>
where
    F: Float,
{
    /// `K(k)` evaluated from the complementary modulus `k' = √(1 − k²)` as
    /// `π / (2·AGM(1, k'))`, for `0 < k' <= 1`. Unlike
    /// `complete_elliptic_k(√(1 − k'²))`, this stays finite when `k'` is so small
    /// that `k` would round to exactly 1.
    fn k_from_complementary_modulus(kp: f64) -> f64 {
        let (mut a, mut b) = (1.0_f64, kp);
        for _ in 0..100 {
            let a_next = (a + b) / 2.0;
            let b_next = (a * b).sqrt();
            if (a_next - b_next).abs() < 1e-15 * a_next {
                return PI / (2.0 * a_next);
            }
            a = a_next;
            b = b_next;
        }
        PI / (2.0 * ((a + b) / 2.0))
    }

    if degree == 0 {
        return Err(LinalgError::DomainError(
            "Zolotarev degree must be >= 1".to_string(),
        ));
    }
    // NaN fails every comparison, so reject it explicitly; otherwise it slips through
    // and surfaces later as a misleading elliptic-integral error.
    if delta.is_nan() || delta <= 0.0 || delta >= 1.0 {
        return Err(LinalgError::DomainError(format!(
            "Delta must be in (0, 1), got {}",
            delta
        )));
    }

    let n = degree;
    let kp = delta;
    // K(k) for k = √(1 − delta²). Forming k first rounds it to exactly 1.0 once
    // delta ≲ 1e-8, which `complete_elliptic_k` rejects, so use k' = delta directly.
    let kk = k_from_complementary_modulus(kp);
    let kkp = complete_elliptic_k(kp)?;
    let max_error = 4.0 * (-(n as f64) * PI * (kkp / kk)).exp();

    // c_m for m = 1, …, 2n − 1, stored at index m − 1.
    let c_vals = (1..2 * n)
        .map(|m| -> LinalgResult<f64> {
            let u = m as f64 * kkp / (2.0 * n as f64);
            let (sn_val, cn_val, dn_val) = jacobi_elliptic(u, kp)?;
            Ok(if sn_val.abs() < 1e-30 {
                1e30
            } else {
                cn_val * dn_val / sn_val
            })
        })
        .collect::<LinalgResult<Vec<f64>>>()?;

    // Odd m (indices 0, 2, …) feed the n denominator factors; even m (indices 1, 3, …)
    // feed the n − 1 numerator factors.
    let den_sq: Vec<f64> = c_vals.iter().step_by(2).map(|&c| c * c).collect();
    let num_sq: Vec<f64> = c_vals.iter().skip(1).step_by(2).map(|&c| c * c).collect();

    // Normalise so that r(1) = 1. Every factor 1 + a² is >= 1, hence num_at_1 >= 1.
    let num_at_1 = num_sq.iter().fold(1.0, |acc, &a2| acc * (1.0 + a2));
    let den_at_1 = den_sq.iter().fold(1.0, |acc, &b2| acc * (1.0 + b2));
    let scale = den_at_1 / num_at_1;

    let to_f = |v: f64| F::from(v).unwrap_or(F::zero());

    // x² + b_j² = 0 at x = ±i·b_j; keep the upper-half-plane pole.
    let poles: Vec<(F, F)> = den_sq
        .iter()
        .map(|&b2| (F::zero(), to_f(b2.sqrt())))
        .collect();

    // Residue of r(x)/x (a rational function of t = x²) at t = −b_j²:
    // alpha_j = scale · ∏_i (a_i² − b_j²) / ∏_{i≠j} (b_i² − b_j²).
    let residues: Vec<(F, F)> = den_sq
        .iter()
        .enumerate()
        .map(|(j, &b2_j)| {
            let num_val = num_sq.iter().fold(scale, |acc, &a2| acc * (a2 - b2_j));
            let den_val = den_sq
                .iter()
                .enumerate()
                .filter(|&(i, _)| i != j)
                .fold(1.0, |acc, (_, &b2_i)| acc * (b2_i - b2_j));
            // Guard against (near-)coincident poles.
            let alpha_j = if den_val.abs() > 1e-30 {
                num_val / den_val
            } else {
                0.0
            };
            (to_f(alpha_j), F::zero())
        })
        .collect();

    Ok(ZolotarevApproximation {
        poles,
        residues,
        numerator_sq: num_sq.iter().copied().map(to_f).collect(),
        denominator_sq: den_sq.iter().copied().map(to_f).collect(),
        scale: F::from(scale).unwrap_or(F::one()),
        degree: n,
        delta: to_f(delta),
        max_error: to_f(max_error),
    })
}
/// Evaluate a Zolotarev rational approximation at `x`.
///
/// The approximation is evaluated in its product form
///
/// ```text
/// r(x) = scale · x · ∏_i (x² + numerator_sq[i]) / ∏_j (x² + denominator_sq[j])
/// ```
///
/// The full numerator and denominator polynomials are never formed, because they
/// grow like `|x|^(2·degree)`. With `f32` and degree 8 that already overflows at
/// `|x| ≈ 1000`, giving `inf / inf = NaN`. Instead, `numerator_sq[i]` is paired with
/// `denominator_sq[i]` and the ratios `(x² + a_i²) / (x² + b_i²)` are accumulated, so
/// intermediate values stay close to the size of the result. Leftover denominator
/// factors divide the running value, and leftover numerator factors multiply it.
///
/// `r` is odd in `x`. For approximations produced by `zolotarev_sign`, every
/// `denominator_sq` entry is positive, so no denominator factor vanishes for real `x`.
/// `x²` itself still overflows for `|x|` beyond the square root of the float range.
pub fn evaluate_rational<F>(x: F, approx: &ZolotarevApproximation<F>) -> F
where
    F: Float,
{
    // r is odd, so r(0) = 0; returning early also avoids a 0/0 if a denominator
    // constant is exactly zero.
    if x == F::zero() {
        return approx.scale * x;
    }

    let x2 = x * x;
    let mut numerators = approx.numerator_sq.iter();
    let mut value = x;
    for &b2 in &approx.denominator_sq {
        let den = x2 + b2;
        value = match numerators.next() {
            Some(&a2) => value * ((x2 + a2) / den),
            None => value / den,
        };
    }
    // More numerator than denominator factors (never the case for `zolotarev_sign`
    // output): apply the remaining ones directly.
    for &a2 in numerators {
        value = value * (x2 + a2);
    }
    approx.scale * value
}
/// Smoothed unit step `(1 + r(x)) / 2`, where `r` is the sign approximation
/// evaluated by `evaluate_rational`.
pub fn evaluate_step<F>(x: F, approx: &ZolotarevApproximation<F>) -> F
where
    F: Float,
{
    let sign_estimate = evaluate_rational(x, approx);
    let two = match F::from(2.0) {
        Some(value) => value,
        None => F::one() + F::one(),
    };
    (F::one() + sign_estimate) / two
}
/// Evaluate the "type I" Zolotarev expression used by this module:
///
/// ```text
/// Z(x) = sn(n · K(k_n) · asin(x) / asin(k), k_n)
/// ```
///
/// Here `k_n` is the modulus whose nome is `q^n`, with `q = exp(−π K(k') / K(k))` and
/// `k' = √(1 − k²)`. For `n == 0` the constant `1` is returned.
///
/// # Errors
/// `DomainError` in any of these cases:
/// * `k` is not in `(0, 1)` (NaN included);
/// * `x` is not in `[-1, 1]`, where `asin` is real (NaN included);
/// * `asin(k)` is below `1e-15`.
///
/// Any error from the elliptic/nome helpers is propagated.
pub fn zolotarev_type1(x: f64, n: usize, k: f64) -> LinalgResult<f64> {
    /// `K(k')` for `k' = √(1 − k²)`, computed from `k` itself as `π / (2·AGM(1, k))`.
    /// Forming `k'` first would round it to exactly 1.0 for `k ≲ 1e-8`, which
    /// `complete_elliptic_k` rejects.
    fn k_of_complement(k: f64) -> f64 {
        let (mut a, mut b) = (1.0_f64, k);
        for _ in 0..100 {
            let a_next = (a + b) / 2.0;
            let b_next = (a * b).sqrt();
            if (a_next - b_next).abs() < 1e-15 * a_next {
                return PI / (2.0 * a_next);
            }
            a = a_next;
            b = b_next;
        }
        PI / (2.0 * ((a + b) / 2.0))
    }

    if n == 0 {
        return Ok(1.0);
    }
    // NaN fails both comparisons, so it is rejected explicitly.
    if k.is_nan() || k <= 0.0 || k >= 1.0 {
        return Err(LinalgError::DomainError(format!(
            "Modulus k must be in (0, 1), got {}",
            k
        )));
    }
    // asin is only real on [-1, 1]; outside it the result would silently be NaN.
    if !(-1.0..=1.0).contains(&x) {
        return Err(LinalgError::DomainError(format!(
            "Argument x must be in [-1, 1], got {}",
            x
        )));
    }

    let kk = complete_elliptic_k(k)?;
    let kkp = k_of_complement(k);
    let q = (-PI * kkp / kk).exp();
    // `powi` takes an i32; an `as` cast would wrap for huge `n` into a negative
    // exponent, so fall back to `powf` in that case.
    let q_n = match i32::try_from(n) {
        Ok(exp) => q.powi(exp),
        Err(_) => q.powf(n as f64),
    };
    let k_n = compute_modulus_from_nome(q_n)?;

    let asin_k = k.asin();
    if asin_k.abs() < 1e-15 {
        return Err(LinalgError::DomainError(
            "k too small for Type I evaluation".to_string(),
        ));
    }
    let kk_n = complete_elliptic_k(k_n)?;
    let u = n as f64 * kk_n * x.asin() / asin_k;
    jacobi_sn(u, k_n)
}
/// Evaluate the degree-`n` Zolotarev sign approximation for gap `delta` at `x`.
///
/// This is a convenience wrapper that builds the approximation with
/// `zolotarev_sign::<f64>` and evaluates it with `evaluate_rational`.
///
/// # Errors
/// `DomainError` if `delta <= 0` or `delta >= 1`. Otherwise, any error from
/// `zolotarev_sign` is returned (e.g. `n == 0`).
pub fn zolotarev_type3(x: f64, n: usize, delta: f64) -> LinalgResult<f64> {
    let delta_out_of_range = delta <= 0.0 || delta >= 1.0;
    if delta_out_of_range {
        return Err(LinalgError::DomainError(format!(
            "Delta must be in (0, 1), got {}",
            delta
        )));
    }
    zolotarev_sign::<f64>(n, delta).map(|approx| evaluate_rational(x, &approx))
}
/// Convert an elliptic nome `q` into the corresponding modulus `k`.
///
/// Uses the theta-function identity `k = θ₂(q)² / θ₃(q)²`, with
/// `θ₂(q) = 2 Σ_{n≥0} q^{(n+1/2)²}` and `θ₃(q) = 1 + 2 Σ_{n≥1} q^{n²}`.
/// Each series is summed until a term drops below `1e-16` (at most 99 or 100 terms).
/// For `q < 1e-15` the leading-order value `4√q` is returned directly.
///
/// # Errors
/// `DomainError` if `q` is outside `[0, 1)` (NaN included).
/// `ComputationError` if `|θ₃| < 1e-15`.
fn compute_modulus_from_nome(q: f64) -> LinalgResult<f64> {
    if !(0.0..1.0).contains(&q) {
        return Err(LinalgError::DomainError(format!(
            "Nome q must be in [0, 1), got {}",
            q
        )));
    }
    if q < 1e-15 {
        return Ok(4.0 * q.sqrt());
    }

    // Series terms are accumulated only while they are still significant.
    let significant = |term: &f64| *term >= 1e-16;

    let theta2 = 2.0
        * (0..100)
            .map(|i| {
                let half_step = i as f64 + 0.5;
                q.powf(half_step * half_step)
            })
            .take_while(significant)
            .fold(0.0, |acc, term| acc + term);

    let theta3 = (1..100)
        .map(|i: i32| q.powf((i * i) as f64))
        .take_while(significant)
        .fold(1.0, |acc, term| acc + 2.0 * term);

    if theta3.abs() < 1e-15 {
        return Err(LinalgError::ComputationError(
            "theta3 is too small".to_string(),
        ));
    }
    let sqrt_k = theta2 / theta3;
    Ok(sqrt_k * sqrt_k)
}
/// Build the complex shifts and real weights of a Zolotarev-based spectral filter.
///
/// The target interval `[interval_lower, interval_upper]` is described by its
/// midpoint `center` and its half-width. `range` is the largest distance from `center`
/// to either end of `[spectrum_lower, spectrum_upper]`.
///
/// A sign approximation is built with `zolotarev_sign::<f64>`, using
/// `delta = half_width / range` clamped to `[0.001, 0.999]`. For each pole/residue
/// pair, the shift is `(center, range · pole.im)` and the weight is
/// `residue.re / range`.
///
/// # Errors
/// `DomainError` if either pair of bounds is not strictly increasing, or if `range` is
/// below `1e-15`. Any error from `zolotarev_sign` is propagated.
pub fn zolotarev_filter<F>(
    degree: usize,
    interval_lower: f64,
    interval_upper: f64,
    spectrum_lower: f64,
    spectrum_upper: f64,
) -> ZolotarevFilterResult<F>
where
    F: Float,
{
    if interval_lower >= interval_upper {
        return Err(LinalgError::DomainError(
            "Interval lower bound must be less than upper bound".to_string(),
        ));
    }
    if spectrum_lower >= spectrum_upper {
        return Err(LinalgError::DomainError(
            "Spectrum lower bound must be less than upper bound".to_string(),
        ));
    }

    let center = (interval_lower + interval_upper) / 2.0;
    let half_width = (interval_upper - interval_lower) / 2.0;
    let reach_above = (spectrum_upper - center).abs();
    let reach_below = (center - spectrum_lower).abs();
    let range = reach_above.max(reach_below);
    if range < 1e-15 {
        return Err(LinalgError::DomainError(
            "Spectrum range is too small".to_string(),
        ));
    }

    // zolotarev_sign needs delta strictly inside (0, 1).
    let delta = (half_width / range).min(0.999).max(0.001);
    let approx = zolotarev_sign::<f64>(degree, delta)?;

    let shift_re = F::from(center).unwrap_or(F::zero());
    let (shifts, weights): (Vec<(F, F)>, Vec<F>) = approx
        .poles
        .iter()
        .zip(approx.residues.iter())
        .take(approx.degree)
        .map(|(&(_, pole_im), &(alpha_re, _))| {
            let shift_im = F::from(range * pole_im).unwrap_or(F::zero());
            let weight = F::from(alpha_re / range).unwrap_or(F::zero());
            ((shift_re, shift_im), weight)
        })
        .unzip();

    Ok((shifts, weights))
}
// Unit tests for the Zolotarev sign approximation, its evaluators, the elliptic
// helpers and the filter builder.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
// The a-priori error estimate (`max_error`) must shrink as the degree grows.
#[test]
fn test_zolotarev_sign_accuracy_vs_degree() {
let delta = 0.1;
let approx_4 = zolotarev_sign::<f64>(4, delta).expect("degree 4 failed");
let approx_8 = zolotarev_sign::<f64>(8, delta).expect("degree 8 failed");
let approx_16 = zolotarev_sign::<f64>(16, delta).expect("degree 16 failed");
assert!(
approx_8.max_error < approx_4.max_error,
"degree 8 error ({}) should be less than degree 4 error ({})",
approx_8.max_error,
approx_4.max_error
);
assert!(
approx_16.max_error < approx_8.max_error,
"degree 16 error ({}) should be less than degree 8 error ({})",
approx_16.max_error,
approx_8.max_error
);
}
// One pole and one residue per degree; poles are purely imaginary, lie in the
// upper half plane and are pairwise distinct.
#[test]
fn test_zolotarev_poles_structure() {
let approx = zolotarev_sign::<f64>(6, 0.2).expect("Zolotarev failed");
assert_eq!(approx.poles.len(), 6);
assert_eq!(approx.residues.len(), 6);
for (j, &(re, im)) in approx.poles.iter().enumerate() {
assert!(
re.abs() < 1e-14,
"Pole {} real part should be 0, got {}",
j,
re
);
assert!(
im > 0.0,
"Pole {} imaginary part should be positive, got {}",
j,
im
);
}
for i in 0..approx.poles.len() {
for j in (i + 1)..approx.poles.len() {
let diff = (approx.poles[i].1 - approx.poles[j].1).abs();
assert!(
diff > 1e-10,
"Poles {} and {} should be distinct, diff = {}",
i,
j,
diff
);
}
}
}
// r(x) is close to +1 on sample points of [0.2, 1], close to -1 on their mirror
// images, and odd (r(-x) = -r(x)).
#[test]
fn test_evaluate_rational_on_test_points() {
let approx = zolotarev_sign::<f64>(8, 0.1).expect("Zolotarev failed");
let test_points = [0.2, 0.3, 0.5, 0.7, 0.9, 1.0];
for &x in &test_points {
let val = evaluate_rational(x, &approx);
assert!(
(val - 1.0).abs() < 0.1, "r({}) = {} should be close to 1.0",
x,
val
);
}
for &x in &test_points {
let val = evaluate_rational(-x, &approx);
assert!(
(val + 1.0).abs() < 0.1,
"r(-{}) = {} should be close to -1.0",
x,
val
);
}
for &x in &test_points {
let val_pos = evaluate_rational(x, &approx);
let val_neg = evaluate_rational(-x, &approx);
assert_relative_eq!(val_neg, -val_pos, epsilon = 1e-12);
}
}
// The smoothed step (1 + r) / 2 is close to 1 for positive and close to 0 for
// negative arguments.
#[test]
fn test_evaluate_step_function() {
let approx = zolotarev_sign::<f64>(8, 0.1).expect("Zolotarev failed");
for &x in &[0.2, 0.5, 1.0] {
let val = evaluate_step(x, &approx);
assert!(
(val - 1.0).abs() < 0.1,
"step({}) = {} should be close to 1.0",
x,
val
);
}
for &x in &[0.2, 0.5, 1.0] {
let val = evaluate_step(-x, &approx);
assert!(
val.abs() < 0.1,
"step(-{}) = {} should be close to 0.0",
x,
val
);
}
}
// Degree 0 and any delta outside the open interval (0, 1) are rejected.
#[test]
fn test_zolotarev_sign_invalid_inputs() {
assert!(zolotarev_sign::<f64>(0, 0.5).is_err());
assert!(zolotarev_sign::<f64>(4, 0.0).is_err());
assert!(zolotarev_sign::<f64>(4, 1.0).is_err());
assert!(zolotarev_sign::<f64>(4, -0.1).is_err());
assert!(zolotarev_sign::<f64>(4, 1.5).is_err());
}
// K(0) = π/2 and the lemniscatic value K(1/√2) = Γ(1/4)² / (4√π).
#[test]
fn test_complete_elliptic_k() {
let k0 = complete_elliptic_k(0.0).expect("K(0) failed");
assert_relative_eq!(k0, PI / 2.0, epsilon = 1e-12);
let k_half = complete_elliptic_k(1.0 / 2.0_f64.sqrt()).expect("K(1/sqrt(2)) failed");
assert_relative_eq!(k_half, 1.8540746773013719, epsilon = 1e-8);
}
// sn(0, k) = 0, sn(K(k), k) = 1, and sn(u, k) reduces to sin(u) as k -> 0.
#[test]
fn test_jacobi_sn() {
let val = jacobi_sn(0.0, 0.5).expect("sn(0, 0.5) failed");
assert!(val.abs() < 1e-14, "sn(0, 0.5) = {} should be 0", val);
let kk = complete_elliptic_k(0.5).expect("K(0.5) failed");
let val_k = jacobi_sn(kk, 0.5).expect("sn(K, 0.5) failed");
assert_relative_eq!(val_k, 1.0, epsilon = 1e-8);
let val_sin = jacobi_sn(1.0, 1e-10).expect("sn(1, ~0) failed");
assert_relative_eq!(val_sin, 1.0_f64.sin(), epsilon = 1e-6);
}
// The f64 wrapper gives values near +1 / -1 at ±0.5 and is odd.
#[test]
fn test_zolotarev_type3() {
let delta = 0.2;
let n = 6;
let val_pos = zolotarev_type3(0.5, n, delta).expect("type3 failed at 0.5");
assert!(
(val_pos - 1.0).abs() < 0.1,
"Z^III(0.5) = {} should be ~1.0",
val_pos
);
let val_neg = zolotarev_type3(-0.5, n, delta).expect("type3 failed at -0.5");
assert!(
(val_neg + 1.0).abs() < 0.1,
"Z^III(-0.5) = {} should be ~-1.0",
val_neg
);
assert_relative_eq!(val_neg, -val_pos, epsilon = 1e-10);
}
// One shift and one weight per degree; every shift's real part is the midpoint
// of the target interval [2, 4].
#[test]
fn test_zolotarev_filter() {
let (shifts, weights) =
zolotarev_filter::<f64>(4, 2.0, 4.0, 0.0, 10.0).expect("filter failed");
assert_eq!(shifts.len(), 4);
assert_eq!(weights.len(), 4);
let center = 3.0;
for &(re, _im) in &shifts {
assert_relative_eq!(re, center, epsilon = 1e-10);
}
}
// Partial-fraction residues are purely real.
#[test]
fn test_zolotarev_residues_real() {
let approx = zolotarev_sign::<f64>(6, 0.3).expect("Zolotarev failed");
for (j, &(_re, im)) in approx.residues.iter().enumerate() {
assert!(
im.abs() < 1e-14,
"Residue {} imaginary part should be 0, got {}",
j,
im
);
}
}
}