use crate::error::{SpecialError, SpecialResult};
use crate::mathieu::advanced::{tridiag_eigenvalues, tridiag_eigenvector};
/// Selects which spheroidal family (prolate or oblate) a computation refers to.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SpheroidalParity {
    /// Prolate family; [`SpheroidalParity::sign`] yields `+1.0`.
    Prolate,
    /// Oblate family; [`SpheroidalParity::sign`] yields `-1.0`.
    Oblate,
}

impl SpheroidalParity {
    /// Returns the sign associated with the family: `+1.0` for prolate and
    /// `-1.0` for oblate. Callers in this file multiply `c²` by it.
    #[inline]
    pub fn sign(self) -> f64 {
        match self {
            Self::Prolate => 1.0,
            Self::Oblate => -1.0,
        }
    }
}
/// Outcome of a successful continued-fraction evaluation by [`cf_modified_lentz`].
#[derive(Clone, Copy, Debug)]
pub struct LentzResult {
/// Value of the continued fraction at the iteration where the convergence test passed.
pub value: f64,
/// Iteration index (counted from 1) at which the convergence test passed.
pub iterations: usize,
}
// Floor substituted for near-zero quantities in `cf_modified_lentz` before they are divided by.
const LENTZ_TINY: f64 = 1.0e-30;
// `cf_modified_lentz` stops once the per-step multiplier differs from 1 by less than this.
const LENTZ_TOLERANCE: f64 = 1.0e-14;
// Maximum number of `cf_modified_lentz` iterations before a convergence error is reported.
const LENTZ_MAX_ITER: usize = 1000;
// `scaled_recurrence_step` scales down when the largest magnitude exceeds this bound.
const SCALE_HI: f64 = 1.0e150;
// `scaled_recurrence_step` scales up when the largest (non-zero) magnitude is below this bound.
const SCALE_LO: f64 = 1.0e-150;
// Factor applied by `scaled_recurrence_step`; its natural log is added to / subtracted from the
// running log-scale.
const SCALE_FACTOR: f64 = 1.0e150;
/// Evaluates the continued fraction
/// `b(0) + a(1) / (b(1) + a(2) / (b(2) + ...))` with the modified Lentz algorithm.
///
/// `a_fn(n)` and `b_fn(n)` supply the partial numerators and denominators.
/// `a_fn` is only called with `n >= 1`; `b_fn` is called with `n >= 0`.
///
/// Iteration stops as soon as the per-step multiplier `delta = C_n * D_n`
/// satisfies `|delta - 1| < LENTZ_TOLERANCE`; the returned
/// [`LentzResult`] carries the value and the 1-based iteration count.
///
/// # Errors
/// Returns `SpecialError::ConvergenceError` when the tolerance is not reached
/// within `LENTZ_MAX_ITER` iterations (this also covers NaN coefficients,
/// for which the convergence test can never pass).
pub fn cf_modified_lentz<A, B>(a_fn: A, b_fn: B) -> SpecialResult<LentzResult>
where
    A: Fn(usize) -> f64,
    B: Fn(usize) -> f64,
{
    // Substitute a tiny positive number for any quantity that is about to be
    // used as a divisor (or multiplier that must stay recoverable).
    let floor = |v: f64| if v.abs() < LENTZ_TINY { LENTZ_TINY } else { v };

    let mut f = floor(b_fn(0));
    let mut c = f;
    let mut d = 0.0_f64;
    for n in 1..=LENTZ_MAX_ITER {
        let a_n = a_fn(n);
        let b_n = b_fn(n);
        d = 1.0 / floor(b_n + a_n * d);
        // `c` from the previous step is already floored, so the division is safe.
        // The new `C_n` must be floored as well: an exact zero would give
        // `delta == 0`, pin `f` at zero and make every later update a no-op.
        c = floor(b_n + a_n / c);
        let delta = c * d;
        f *= delta;
        if (delta - 1.0).abs() < LENTZ_TOLERANCE {
            return Ok(LentzResult {
                value: f,
                iterations: n,
            });
        }
    }
    Err(SpecialError::ConvergenceError(format!(
        "Modified-Lentz CF failed to converge within {LENTZ_MAX_ITER} iterations"
    )))
}
/// Rescales `values` in place when their largest magnitude leaves the window
/// `[SCALE_LO, SCALE_HI]`, recording the applied factor in `log_scale`.
///
/// * Largest magnitude above `SCALE_HI`: every entry is divided by
///   `SCALE_FACTOR` and `ln(SCALE_FACTOR)` is added to `*log_scale`.
/// * Largest magnitude non-zero but below `SCALE_LO`: every entry is
///   multiplied by `SCALE_FACTOR` and `ln(SCALE_FACTOR)` is subtracted.
///
/// Returns `true` when a rescale happened, `false` otherwise (including for an
/// empty or all-zero slice).
pub fn scaled_recurrence_step(values: &mut [f64], log_scale: &mut f64) -> bool {
    let mut peak = 0.0_f64;
    for v in values.iter() {
        peak = peak.max(v.abs());
    }

    let too_large = peak > SCALE_HI;
    let too_small = !too_large && peak > 0.0 && peak < SCALE_LO;
    if !too_large && !too_small {
        return false;
    }

    if too_large {
        values.iter_mut().for_each(|v| *v /= SCALE_FACTOR);
        *log_scale += SCALE_FACTOR.ln();
    } else {
        values.iter_mut().for_each(|v| *v *= SCALE_FACTOR);
        *log_scale -= SCALE_FACTOR.ln();
    }
    true
}
/// Default truncation length of the expansion-coefficient vector, used by
/// `d_coefficients`, `angular_function` and `radial_function`.
pub const DEFAULT_D_LEN: usize = 80;
/// Diagonal coefficient of the three-term recurrence at index `r`:
///
/// `(m+r)(m+r+1) + c² · [2(m+r)(m+r+1) − 2m² − 1] / [(2m+2r−1)(2m+2r+3)]`
///
/// where `c2_signed` plays the role of `c²` (sign already applied by the
/// caller). If the denominator is exactly zero only the first term is returned.
#[inline]
fn flammer_beta(m: f64, r: f64, c2_signed: f64) -> f64 {
    let ell = m + r;
    let legendre_part = ell * (ell + 1.0);
    let twice_ell = 2.0 * m + 2.0 * r;
    let denom = (twice_ell - 1.0) * (twice_ell + 3.0);
    if denom != 0.0 {
        let bracket = 2.0 * ell * (ell + 1.0) - 2.0 * m * m - 1.0;
        legendre_part + c2_signed * bracket / denom
    } else {
        legendre_part
    }
}
/// Upward-coupling coefficient of the three-term recurrence at index `r`:
///
/// `c² · (2m+r+1)(2m+r+2) / [(2m+2r+3)(2m+2r+5)]`
///
/// where `c2_signed` plays the role of `c²`. Returns `0.0` if the denominator
/// is exactly zero.
#[inline]
fn flammer_alpha_up(m: f64, r: f64, c2_signed: f64) -> f64 {
    let twice_ell = 2.0 * m + 2.0 * r;
    let denom = (twice_ell + 3.0) * (twice_ell + 5.0);
    if denom != 0.0 {
        c2_signed * (2.0 * m + r + 1.0) * (2.0 * m + r + 2.0) / denom
    } else {
        0.0
    }
}
/// Downward-coupling coefficient of the three-term recurrence at index `r`:
///
/// `c² · r(r−1) / [(2m+2r−3)(2m+2r−1)]`
///
/// where `c2_signed` plays the role of `c²`. Returns `0.0` if the denominator
/// is exactly zero.
#[inline]
fn flammer_gamma_down(m: f64, r: f64, c2_signed: f64) -> f64 {
    let twice_ell = 2.0 * m + 2.0 * r;
    let denom = (twice_ell - 3.0) * (twice_ell - 1.0);
    if denom != 0.0 {
        c2_signed * (r - 1.0) * r / denom
    } else {
        0.0
    }
}
/// Builds the symmetrised tridiagonal system for the expansion coefficients.
///
/// Row `p` corresponds to recurrence index `r = 2p + parity`, where `parity`
/// is `(n − m) mod 2`. The returned tuple is
///
/// 1. `diag` (`len` entries): `flammer_beta` at each `r`;
/// 2. `off_sym` (`len − 1` entries): `sqrt(|α_p · γ_{p+1}|)` carrying the sign
///    of `α_p`, with `α_p = flammer_alpha_up(r_p)` and
///    `γ_{p+1} = flammer_gamma_down(r_{p+1})`;
/// 3. `scale_factors` (`len` entries): cumulative products of
///    `sqrt(|γ_{p+1} / α_p|)` starting from `1.0`; a step is skipped (factor
///    carried over unchanged) when either coefficient is zero.
///    `d_coefficients_with_len` multiplies the symmetric eigenvector by these.
fn build_flammer_tridiag(
    parity_kind: SpheroidalParity,
    m: i32,
    n: i32,
    c: f64,
    len: usize,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let m_f = m as f64;
    let parity_f = (((n - m) % 2 + 2) % 2) as f64;
    let c2_signed = c * c * parity_kind.sign();
    let r_of = |p: usize| 2.0 * p as f64 + parity_f;

    let mut diag = Vec::with_capacity(len);
    let mut off_sym = Vec::with_capacity(len.saturating_sub(1));
    let mut scale_factors = Vec::with_capacity(len);
    let mut running_scale = 1.0_f64;

    for p in 0..len {
        diag.push(flammer_beta(m_f, r_of(p), c2_signed));
        scale_factors.push(running_scale);
        if p + 1 == len {
            break;
        }

        let alpha_up = flammer_alpha_up(m_f, r_of(p), c2_signed);
        let gamma_down_next = flammer_gamma_down(m_f, r_of(p + 1), c2_signed);

        // Geometric mean of the two couplings, signed like the upward one.
        let prod = alpha_up * gamma_down_next;
        let magnitude = if prod >= 0.0 {
            prod.sqrt()
        } else {
            (-prod).sqrt()
        };
        off_sym.push(if alpha_up >= 0.0 { magnitude } else { -magnitude });

        // Diagonal similarity factor mapping symmetric eigenvectors back.
        if alpha_up != 0.0 && gamma_down_next != 0.0 {
            running_scale *= (gamma_down_next / alpha_up).abs().sqrt();
        }
    }
    (diag, off_sym, scale_factors)
}
/// Computes the expansion coefficients with the default truncation length.
///
/// Thin wrapper that forwards all arguments to [`d_coefficients_with_len`]
/// with `len = DEFAULT_D_LEN`; see that function for the meaning of the
/// arguments, the normalisation and the error conditions.
pub fn d_coefficients(
parity_kind: SpheroidalParity,
m: i32,
n: i32,
c: f64,
lambda: f64,
) -> SpecialResult<Vec<f64>> {
d_coefficients_with_len(parity_kind, m, n, c, lambda, DEFAULT_D_LEN)
}
/// Computes `len` expansion coefficients `d[p]` (recurrence index
/// `r = 2p + (n − m) mod 2`), normalised so that the principal coefficient
/// `d[k_target]`, `k_target = (n − m − parity) / 2`, equals `1`.
///
/// For `c == 0` the result is the unit vector at `k_target`. Otherwise the
/// symmetrised tridiagonal system from `build_flammer_tridiag` is solved: the
/// `k_target`-th entry returned by `tridiag_eigenvalues` is taken as the
/// eigenvalue, its eigenvector is obtained from `tridiag_eigenvector`, and the
/// result is mapped back with the scale factors before normalising.
///
/// `lambda` is accepted for interface compatibility but does not influence
/// the result: the eigenvalue recovered from the tridiagonal system is always
/// the one used.
///
/// # Errors
/// * `DomainError` if `m < 0`, `n < m`, `len < 4`, or `k_target >= len`.
/// * `ComputationError` if too few eigenvalues are returned, the eigenvector
///   is shorter than `k_target + 1`, or the principal coefficient is
///   non-finite or smaller than `1e-30` in magnitude.
pub fn d_coefficients_with_len(
    parity_kind: SpheroidalParity,
    m: i32,
    n: i32,
    c: f64,
    lambda: f64,
    len: usize,
) -> SpecialResult<Vec<f64>> {
    if m < 0 || n < m {
        return Err(SpecialError::DomainError(format!(
            "d_coefficients require 0 ≤ m ≤ n, got m={m}, n={n}"
        )));
    }
    if len < 4 {
        return Err(SpecialError::DomainError(
            "d_coefficients length must be ≥ 4".to_string(),
        ));
    }
    let parity = ((n - m) % 2 + 2) % 2;
    let k_target = ((n - m - parity) / 2) as usize;
    if k_target >= len {
        return Err(SpecialError::DomainError(format!(
            "d_coefficients truncation len={len} too small for n={n}, m={m} (need k_target={k_target})"
        )));
    }
    if c == 0.0 {
        let mut d = vec![0.0_f64; len];
        d[k_target] = 1.0;
        return Ok(d);
    }

    // The caller-supplied eigenvalue is advisory only. A disagreement with the
    // recovered eigenvalue was previously "checked" by an empty `if` and never
    // acted upon; the recovered value is used so that the eigenvector solve is
    // consistent with the matrix built here.
    let _ = lambda;

    let (diag, off_sym, scale_factors) = build_flammer_tridiag(parity_kind, m, n, c, len);
    let eigs = tridiag_eigenvalues(&diag, &off_sym);
    if eigs.len() <= k_target {
        return Err(SpecialError::ComputationError(format!(
            "d_coefficients: only {} eigenvalues found, need k_target={}",
            eigs.len(),
            k_target
        )));
    }
    let lam_recovered = eigs[k_target];

    let u = tridiag_eigenvector(&diag, &off_sym, lam_recovered);
    // Undo the symmetrising similarity transform.
    let mut d: Vec<f64> = u
        .iter()
        .zip(scale_factors.iter())
        .map(|(&ui, &sf)| ui * sf)
        .collect();

    // `zip` truncates to the shorter input, so guard the index instead of
    // panicking if the eigenvector came back short.
    let main = match d.get(k_target) {
        Some(&value) => value,
        None => {
            return Err(SpecialError::ComputationError(format!(
                "d_coefficients: eigenvector has only {} entries, need k_target={}",
                d.len(),
                k_target
            )));
        }
    };
    // A NaN would slip through the magnitude test below and poison every entry.
    if !main.is_finite() {
        return Err(SpecialError::ComputationError(format!(
            "d_coefficients: principal coefficient d[k_target={k_target}] is not finite ({main:.3e})"
        )));
    }
    if main.abs() < 1.0e-30 {
        return Err(SpecialError::ComputationError(format!(
            "d_coefficients: principal coefficient d[k_target={k_target}] is too small ({main:.3e}) to normalise"
        )));
    }
    for di in d.iter_mut() {
        *di /= main;
    }
    Ok(d)
}
/// Associated Legendre function `P_l^m(x)` including the `(-1)^m` phase
/// factor, evaluated by upward recurrence in the degree.
///
/// Returns `0.0` when `m < 0` or `l < m`. The seed is
/// `P_m^m = (-1)^m (2m−1)!! (1−x²)^{m/2}` (with `1 − x²` clamped at zero),
/// followed by `P_{m+1}^m = x (2m+1) P_m^m` and
/// `(k−m) P_k^m = x (2k−1) P_{k−1}^m − (k+m−1) P_{k−2}^m`.
fn legendre_assoc_cs(l: i32, m: i32, x: f64) -> f64 {
    if m < 0 || l < m {
        return 0.0;
    }
    if l == 0 && m == 0 {
        return 1.0;
    }

    let sin_sq = (1.0 - x * x).max(0.0);
    let seed = if m > 0 {
        let double_factorial = (1..=m).fold(1.0_f64, |acc, k| acc * (2 * k - 1) as f64);
        let phase = if m % 2 == 0 { 1.0 } else { -1.0 };
        phase * double_factorial * sin_sq.powf(0.5 * m as f64)
    } else {
        1.0_f64
    };
    if l == m {
        return seed;
    }

    let second = x * (2 * m + 1) as f64 * seed;
    if l == m + 1 {
        return second;
    }

    let m_f = m as f64;
    let (_, top) = ((m + 2)..=l).fold((seed, second), |(below, here), k| {
        let k_f = k as f64;
        let above = (x * (2.0 * k_f - 1.0) * here - (k_f + m_f - 1.0) * below) / (k_f - m_f);
        (here, above)
    });
    top
}
/// Derivative with respect to `x` of [`legendre_assoc_cs`]`(l, m, x)`.
///
/// Returns `0.0` when `m < 0`, `l < m` or `l == 0`. Away from the endpoints
/// (`1 − x² > 1e-8`) the closed form
/// `[(l+m) P_{l−1}^m − l x P_l^m] / (1 − x²)` is used. Otherwise a five-point
/// central difference with step `1e-4` is taken around `x` clamped to
/// `[-1 + 8h, 1 − 8h]`, so near the endpoints the result is the derivative at
/// the clamped point rather than at `x` itself.
fn legendre_assoc_cs_prime(l: i32, m: i32, x: f64) -> f64 {
    if m < 0 || l < m || l == 0 {
        return 0.0;
    }

    let degree = l as f64;
    let order = m as f64;
    let sin_sq = 1.0 - x * x;
    if sin_sq > 1.0e-8 {
        let upper = legendre_assoc_cs(l, m, x);
        let lower = legendre_assoc_cs(l - 1, m, x);
        return ((degree + order) * lower - degree * x * upper) / sin_sq;
    }

    // Closed form is singular here; fall back to a finite-difference stencil.
    let h = 1.0e-4_f64;
    let centre = x.clamp(-1.0 + 8.0 * h, 1.0 - 8.0 * h);
    let sample = |shift: f64| legendre_assoc_cs(l, m, centre + shift);
    let far_right = sample(2.0 * h);
    let right = sample(h);
    let left = sample(-h);
    let far_left = sample(-2.0 * h);
    (-far_right + 8.0 * right - 8.0 * left + far_left) / (12.0 * h)
}
pub fn angular_function(
parity_kind: SpheroidalParity,
m: i32,
n: i32,
c: f64,
eta: f64,
) -> SpecialResult<(f64, f64)> {
if m < 0 || n < m {
return Err(SpecialError::DomainError(format!(
"angular_function requires 0 ≤ m ≤ n, got m={m}, n={n}"
)));
}
if !(-1.0..=1.0).contains(&eta) {
return Err(SpecialError::DomainError(format!(
"angular_function requires |η| ≤ 1, got η={eta}"
)));
}
let parity = ((n - m) % 2 + 2) % 2;
let parity_us = parity as usize;
if c == 0.0 {
let p_val = legendre_assoc_cs(n, m, eta);
let p_der = legendre_assoc_cs_prime(n, m, eta);
let sign_cs = if m % 2 == 0 { 1.0 } else { -1.0 };
return Ok((sign_cs * p_val, sign_cs * p_der));
}
let lambda = flammer_eigenvalue(parity_kind, m, n, c, DEFAULT_D_LEN)?;
let d = d_coefficients(parity_kind, m, n, c, lambda)?;
let len = d.len();
let mut s_raw = 0.0_f64;
let mut s_raw_prime = 0.0_f64;
for (p, &dp) in d.iter().enumerate().take(len) {
if dp.abs() < 1.0e-30 {
continue;
}
let r = 2 * p + parity_us;
let l = m + r as i32;
let p_val = legendre_assoc_cs(l, m, eta);
let p_der = legendre_assoc_cs_prime(l, m, eta);
s_raw += dp * p_val;
s_raw_prime += dp * p_der;
}
let k_factor = if parity == 0 {
let mut s_at_zero = 0.0_f64;
for (p, &dp) in d.iter().enumerate().take(len) {
if dp.abs() < 1.0e-30 {
continue;
}
let r = 2 * p + parity_us;
let l = m + r as i32;
s_at_zero += dp * legendre_assoc_cs(l, m, 0.0_f64);
}
if s_at_zero.abs() < 1.0e-30 {
return Err(SpecialError::ComputationError(format!(
"angular_function: Meixner–Schäfke (even) anchor S(0)={s_at_zero:.3e} too small for m={m}, n={n}, c={c}"
)));
}
let target = legendre_assoc_cs(n, m, 0.0_f64);
target / s_at_zero
} else {
let mut sp_at_zero = 0.0_f64;
for (p, &dp) in d.iter().enumerate().take(len) {
if dp.abs() < 1.0e-30 {
continue;
}
let r = 2 * p + parity_us;
let l = m + r as i32;
sp_at_zero += dp * legendre_assoc_cs_prime(l, m, 0.0_f64);
}
if sp_at_zero.abs() < 1.0e-30 {
return Err(SpecialError::ComputationError(format!(
"angular_function: Meixner–Schäfke (odd) anchor S'(0)={sp_at_zero:.3e} too small for m={m}, n={n}, c={c}"
)));
}
let target = legendre_assoc_cs_prime(n, m, 0.0_f64);
target / sp_at_zero
};
let sign_cs = if m % 2 == 0 { 1.0 } else { -1.0 };
let scale = sign_cs * k_factor;
Ok((scale * s_raw, scale * s_raw_prime))
}
/// Selects which spherical Bessel function is used in the radial expansion.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SphericalBesselKind {
/// Evaluated with `crate::spherical_jn`.
First,
/// Evaluated with `crate::spherical_yn`.
Second,
}
/// Returns `(z_l(x), z_l'(x))` for the spherical Bessel function selected by
/// `kind` (`crate::spherical_jn` or `crate::spherical_yn`).
///
/// * `l == 0`: the derivative is `-z_1(x)`.
/// * `x == 0`, `l >= 1`: the recurrence below would divide by zero, so the
///   limit is returned instead. For the first kind `j_l(x) ~ x^l / (2l+1)!!`,
///   hence `j_1'(0) = 1/3` and `j_l'(0) = 0` for `l >= 2`. For the second kind
///   the function is singular at the origin; the derivative slot is left at
///   `0.0` and callers are expected to detect the non-finite value.
/// * otherwise: `z_l' = z_{l-1} − (l+1)/x · z_l`.
fn spherical_bessel_pair(kind: SphericalBesselKind, l: i32, x: f64) -> (f64, f64) {
    let eval = |order: i32| match kind {
        SphericalBesselKind::First => crate::spherical_jn::<f64>(order, x),
        SphericalBesselKind::Second => crate::spherical_yn::<f64>(order, x),
    };

    let z_l = eval(l);
    if l == 0 {
        return (z_l, -eval(1));
    }
    if x == 0.0 {
        // Previously 0.0 for every l >= 1, which dropped the non-zero
        // first-order slope j_1'(0) = 1/3.
        let der = match (kind, l) {
            (SphericalBesselKind::First, 1) => 1.0 / 3.0,
            _ => 0.0,
        };
        return (z_l, der);
    }
    let der = eval(l - 1) - ((l + 1) as f64) / x * z_l;
    (z_l, der)
}
/// Product of the integers `r+1, r+2, …, 2m+r`, i.e. `(2m+r)! / r!`, as `f64`.
/// An empty range (for example `m == 0`) yields `1.0`.
fn kappa_r(m: i32, r: i32) -> f64 {
    ((r + 1)..=(2 * m + r)).fold(1.0_f64, |acc, k| acc * k as f64)
}
/// Evaluates the radial function and its `ξ`-derivative, returned as
/// `(R, dR/dξ)`, from the spherical-Bessel expansion
///
/// `R(ξ) = (1 − s/ξ²)^{m/2} · [Σ_p (−1)^{p−k} κ_p d_p z_{m+r}(cξ)] / [Σ_p κ_p d_p]`
///
/// with `s = parity_kind.sign()`, `r = 2p + (n − m) mod 2`,
/// `k = (n − m − parity) / 2`, `κ_p = kappa_r(m, r)`, `d_p` from
/// `d_coefficients`, and `z` the spherical Bessel function selected by
/// `bessel_kind`.
///
/// The series is truncated at `2n + 8` terms (at most the coefficient count)
/// for the second kind, stops early when a Bessel value or derivative is
/// non-finite, and stops once four consecutive terms past index `k + 6` fall
/// below `1e-15` of the largest term seen.
///
/// At a point where `1 − s/ξ²` vanishes, the prefactor derivative is unbounded
/// for `m == 1`; in that case its contribution is reported as `0.0`.
///
/// # Errors
/// * `DomainError` if `m < 0`, `n < m`, `c < 0`, `ξ < 1` (prolate), `ξ < 0`
///   (oblate), `ξ ≈ 0` with `m > 0`, or the prefactor base is negative.
/// * `ComputationError` if the normalisation sum is below `1e-30` in magnitude.
/// * Any error propagated from `flammer_eigenvalue` or `d_coefficients`.
pub fn radial_function(
    parity_kind: SpheroidalParity,
    bessel_kind: SphericalBesselKind,
    m: i32,
    n: i32,
    c: f64,
    xi: f64,
) -> SpecialResult<(f64, f64)> {
    if m < 0 || n < m {
        return Err(SpecialError::DomainError(format!(
            "radial_function requires 0 ≤ m ≤ n, got m={m}, n={n}"
        )));
    }
    if c < 0.0 {
        return Err(SpecialError::DomainError(format!(
            "radial_function requires c ≥ 0, got c={c}"
        )));
    }
    match parity_kind {
        SpheroidalParity::Prolate => {
            if xi < 1.0 {
                return Err(SpecialError::DomainError(format!(
                    "prolate radial_function requires ξ ≥ 1, got ξ={xi}"
                )));
            }
        }
        SpheroidalParity::Oblate => {
            if xi < 0.0 {
                return Err(SpecialError::DomainError(format!(
                    "oblate radial_function requires ξ ≥ 0, got ξ={xi}"
                )));
            }
        }
    }

    let parity = ((n - m) % 2 + 2) % 2;
    let parity_us = parity as usize;
    let lambda = flammer_eigenvalue(parity_kind, m, n, c, DEFAULT_D_LEN)?;
    let d = d_coefficients(parity_kind, m, n, c, lambda)?;
    let len = d.len();
    let target_us = ((n - m - parity) / 2) as usize;
    let s_pref = parity_kind.sign();
    let half_m = 0.5 * m as f64;

    // Prefactor ((ξ² − s)/ξ²)^{m/2}; identically 1 for m = 0.
    let prefactor = if m == 0 {
        1.0_f64
    } else {
        if xi.abs() < 1.0e-30 {
            return Err(SpecialError::DomainError(
                "radial_function: prefactor singular at ξ=0 for m > 0".to_string(),
            ));
        }
        let xi2 = xi * xi;
        let arg = (xi2 - s_pref) / xi2;
        if arg < 0.0 {
            return Err(SpecialError::DomainError(format!(
                "radial_function: prefactor (ξ²-s)/ξ²={arg} < 0 for ξ={xi}"
            )));
        }
        arg.powf(half_m)
    };

    let z_arg = c * xi;

    // κ_p · d_p for every retained coefficient, and their sum as normalisation.
    let kdr: Vec<f64> = d
        .iter()
        .enumerate()
        .map(|(p, &dp)| kappa_r(m, (2 * p + parity_us) as i32) * dp)
        .collect();
    let h_norm: f64 = kdr.iter().sum();
    if h_norm.abs() < 1.0e-30 {
        return Err(SpecialError::ComputationError(format!(
            "radial_function: normalisation H_n^m(c)={h_norm:.3e} too small for m={m}, n={n}, c={c}"
        )));
    }

    let max_iter = match bessel_kind {
        SphericalBesselKind::First => len,
        SphericalBesselKind::Second => (2 * n as usize + 8).min(len),
    };

    let mut numerator = 0.0_f64;
    let mut numerator_prime = 0.0_f64;
    let term_floor: f64 = 1.0e-15;
    let mut max_seen: f64 = 0.0;
    let mut consecutive_below = 0;
    for (p, &kd) in kdr.iter().enumerate().take(max_iter) {
        if kd == 0.0 {
            continue;
        }
        let r = (2 * p + parity_us) as i32;
        let l = m + r;
        // Alternating sign relative to the principal index.
        let phase = if (p as i32 - target_us as i32).rem_euclid(2) == 0 {
            1.0
        } else {
            -1.0
        };
        let coef = phase * kd;
        let (z_val, z_der) = spherical_bessel_pair(bessel_kind, l, z_arg);
        if !z_val.is_finite() || !z_der.is_finite() {
            // Stop before a non-finite Bessel value contaminates the sums.
            break;
        }
        let term = coef * z_val;
        // Chain rule: d/dξ z(cξ) = c · z'(cξ).
        let term_p = coef * c * z_der;
        let abs_term = term.abs();
        if abs_term > max_seen {
            max_seen = abs_term;
        }
        numerator += term;
        numerator_prime += term_p;
        if max_seen > 0.0 && abs_term < term_floor * max_seen && p >= target_us + 6 {
            consecutive_below += 1;
            if consecutive_below >= 4 {
                break;
            }
        } else {
            consecutive_below = 0;
        }
    }

    let r_value = prefactor * numerator / h_norm;

    // d/dξ (1 − s/ξ²)^{m/2} = (m/2) (1 − s/ξ²)^{m/2 − 1} · 2s/ξ³.
    let prefactor_prime = if m == 0 {
        0.0
    } else {
        let xi2 = xi * xi;
        let base = 1.0 - s_pref / xi2;
        if m == 1 && base.abs() < 1.0e-30 {
            // Exponent m/2 − 1 is negative only for m = 1, where the
            // derivative is unbounded at base = 0; report 0.0 there.
            0.0
        } else {
            // For m = 2 the power is base⁰ = 1 and the derivative is the
            // finite 2s/ξ³ even at base = 0; for m ≥ 3 it vanishes naturally.
            half_m * base.powf(half_m - 1.0) * (2.0 * s_pref / (xi2 * xi))
        }
    };
    let r_derivative = prefactor_prime * numerator / h_norm + prefactor * numerator_prime / h_norm;
    Ok((r_value, r_derivative))
}
/// Returns the separation eigenvalue for indices `(m, n)` and parameter `c`,
/// using a tridiagonal system truncated to `len` rows.
///
/// For `c == 0` the value is `n (n + 1)`. Otherwise the system from
/// `build_flammer_tridiag` is passed to `tridiag_eigenvalues` and the entry at
/// position `k = (n − m − parity) / 2`, `parity = (n − m) mod 2`, is returned.
///
/// # Errors
/// * `DomainError` if `m < 0`, `n < m`, `len < 4`, or (for `c != 0`) `k >= len`.
/// * `ComputationError` if fewer than `k + 1` eigenvalues are returned.
pub fn flammer_eigenvalue(
    parity_kind: SpheroidalParity,
    m: i32,
    n: i32,
    c: f64,
    len: usize,
) -> SpecialResult<f64> {
    if m < 0 || n < m {
        return Err(SpecialError::DomainError(format!(
            "flammer_eigenvalue requires 0 ≤ m ≤ n, got m={m}, n={n}"
        )));
    }
    if len < 4 {
        return Err(SpecialError::DomainError(
            "flammer_eigenvalue length must be ≥ 4".to_string(),
        ));
    }
    if c == 0.0 {
        let degree = n as f64;
        return Ok(degree * (degree + 1.0));
    }

    let parity = ((n - m) % 2 + 2) % 2;
    let k_target = ((n - m - parity) / 2) as usize;
    if k_target >= len {
        return Err(SpecialError::DomainError(format!(
            "flammer_eigenvalue truncation len={len} too small for n={n}, m={m}"
        )));
    }

    let (diag, off_sym, _scale) = build_flammer_tridiag(parity_kind, m, n, c, len);
    let eigs = tridiag_eigenvalues(&diag, &off_sym);
    match eigs.get(k_target) {
        Some(&eigenvalue) => Ok(eigenvalue),
        None => Err(SpecialError::ComputationError(format!(
            "flammer_eigenvalue: target eigenvalue {k_target} not found"
        ))),
    }
}
/// Evaluates a tail continued fraction of the three-term recurrence with
/// [`cf_modified_lentz`] and returns `−CF / α(start_k)`.
///
/// With `r_k = 2k + (n − m) mod 2`, `β_k = flammer_beta(r_k) − lambda`,
/// `α_k = flammer_alpha_up(r_k)` and `γ_k = flammer_gamma_down(r_k)`, the
/// continued fraction has `b(0) = 0` and, for `idx ≥ 1` with
/// `L = start_k + idx − 1`, partial numerator `−γ_{L+1} · α_L` and partial
/// denominator `β_L`.
///
/// Returns `Ok(0.0)` immediately when `c² · sign` is zero.
///
/// NOTE(review): for a recurrence `α_k d_{k+1} + β_k d_k + γ_k d_{k−1} = 0`
/// (the form `build_flammer_tridiag` uses), the usual minimal-solution ratio
/// pairs the numerator `α_L γ_{L+1}` with the denominator `β_{L+1}`, not
/// `β_L` — confirm the intended index convention before relying on the value.
///
/// # Errors
/// * `ConvergenceError` propagated from `cf_modified_lentz`, or raised when
///   `|α(start_k)|` is below `f64::MIN_POSITIVE · 1e6`.
pub fn tail_ratio_lentz(
    parity_kind: SpheroidalParity,
    m: i32,
    n: i32,
    c: f64,
    lambda: f64,
    start_k: usize,
) -> SpecialResult<f64> {
    let m_f = m as f64;
    let parity = ((n - m) % 2 + 2) % 2;
    let parity_f = parity as f64;
    let s = parity_kind.sign();
    let c2_signed = c * c * s;
    if c2_signed == 0.0 {
        return Ok(0.0);
    }

    let r_at = |k: usize| -> f64 { 2.0 * k as f64 + parity_f };
    let beta_at = |k: usize| -> f64 { flammer_beta(m_f, r_at(k), c2_signed) - lambda };
    let alpha_at = |k: usize| -> f64 { flammer_alpha_up(m_f, r_at(k), c2_signed) };
    let gamma_at = |k: usize| -> f64 { flammer_gamma_down(m_f, r_at(k), c2_signed) };

    let result = cf_modified_lentz(
        |idx: usize| {
            if idx == 0 {
                0.0
            } else {
                let level = start_k + idx - 1;
                -gamma_at(level + 1) * alpha_at(level)
            }
        },
        |idx: usize| {
            if idx == 0 {
                0.0
            } else {
                beta_at(start_k + idx - 1)
            }
        },
    )?;

    let alpha_start = alpha_at(start_k);
    if alpha_start.abs() < f64::MIN_POSITIVE * 1.0e6 {
        return Err(SpecialError::ConvergenceError(
            "tail_ratio_lentz: α_up(start_k) too small".to_string(),
        ));
    }
    Ok(-result.value / alpha_start)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// `tan(x) = x / (1 − x²/(3 − x²/(5 − …)))`, evaluated through the Lentz routine.
    #[test]
    fn cf_lentz_evaluates_tan_via_cf() {
        let x: f64 = 0.5;
        let neg_x_sq = -(x * x);
        let numerators = |n: usize| if n == 0 { 0.0 } else { neg_x_sq };
        let denominators = |n: usize| if n == 0 { 1.0 } else { (2 * n + 1) as f64 };
        let cf = cf_modified_lentz(numerators, denominators).expect("Lentz CF failed for tan");
        assert_abs_diff_eq!(x / cf.value, x.tan(), epsilon = 1e-12);
    }

    /// At `c = 0` the coefficient vector is the unit vector at the principal index.
    #[test]
    fn d_coefficients_c_zero_matches_legendre_limit() {
        for n in 0..6 {
            let degree = n as f64;
            let lam = degree * (degree + 1.0);
            let d = d_coefficients(SpheroidalParity::Prolate, 0, n, 0.0, lam).expect("d coef c=0");
            let parity = (n % 2) as usize;
            let target = ((n as usize) - parity) / 2;
            for (k, &dk) in d.iter().enumerate() {
                if k != target {
                    assert!(dk.abs() < 1e-10, "d[{k}]={dk} should be 0 at c=0");
                } else {
                    assert!((dk - 1.0).abs() < 1e-12, "d[target]=1 expected, got {dk}");
                }
            }
        }
    }

    /// The tail ratio short-circuits to zero when `c = 0`.
    #[test]
    fn tail_ratio_lentz_c_zero_is_zero() {
        let ratio = tail_ratio_lentz(SpheroidalParity::Prolate, 0, 1, 0.0, 2.0, 5)
            .expect("tail ratio should succeed at c=0");
        assert_abs_diff_eq!(ratio, 0.0, epsilon = 1e-15);
    }

    /// Values above the upper bound are scaled down and the log-scale grows.
    #[test]
    fn scaled_recurrence_step_handles_overflow() {
        let mut values = [1.0e160_f64, 5.0e159, -2.0e160];
        let mut log_scale = 0.0_f64;
        assert!(scaled_recurrence_step(&mut values, &mut log_scale));
        for v in values.iter() {
            assert!(v.abs() < 1.0e150);
        }
        assert!(log_scale > 0.0);
    }

    /// Values below the lower bound are scaled up and the log-scale shrinks.
    #[test]
    fn scaled_recurrence_step_handles_underflow() {
        let mut values = [1.0e-160_f64, -3.0e-160, 2.0e-161];
        let mut log_scale = 0.0_f64;
        assert!(scaled_recurrence_step(&mut values, &mut log_scale));
        for v in values.iter() {
            assert!(v.abs() > 1.0e-160);
        }
        assert!(log_scale < 0.0);
    }

    /// Spot-check of the eigenvalue against stored reference numbers.
    #[test]
    fn flammer_eigenvalue_matches_scipy_reference() {
        let prolate = flammer_eigenvalue(SpheroidalParity::Prolate, 0, 2, 1.0, 60)
            .expect("flammer_eigenvalue prolate");
        assert_abs_diff_eq!(prolate, 6.5334718005, epsilon = 1.0e-7);

        let oblate = flammer_eigenvalue(SpheroidalParity::Oblate, 0, 2, 1.0, 60)
            .expect("flammer_eigenvalue oblate");
        assert_abs_diff_eq!(oblate, 5.4868000164, epsilon = 1.0e-7);
    }
}