use crate::error::{SpecialError, SpecialResult};
use crate::mathieu::advanced::{tridiag_eigenvalues, tridiag_eigenvector};
use crate::spheroidal::wave_functions::associated_legendre;
/// Selects which family of spheroidal wave functions a routine works with.
///
/// Within this module the variant only decides the sign applied to `c²`
/// when the tridiagonal system is assembled: `+1` for [`SpheroidalKind::Prolate`]
/// and `-1` for [`SpheroidalKind::Oblate`].
///
/// The enum is a plain tag, so it is `Copy` and `Hash`: callers can reuse one
/// value across several calls without cloning and can use it as a map key.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpheroidalKind {
    /// Prolate family (`c²` enters with a positive sign).
    Prolate,
    /// Oblate family (`c²` enters with a negative sign).
    Oblate,
}
/// Result record returned by `spheroidal_eigenvalue_mn`.
///
/// It echoes the inputs of the computation next to the eigenvalue so that a
/// value can be passed around without losing track of what it belongs to.
/// `PartialEq` is derived so records can be compared directly (e.g. with
/// `assert_eq!`); the comparison on `c` and `lambda` is exact `f64` equality.
#[derive(Debug, Clone, PartialEq)]
pub struct SpheroidalEigenvalue {
    /// Order `m` of the function.
    pub m: usize,
    /// Degree `n` of the function; `spheroidal_eigenvalue_mn` requires `n >= m`.
    pub n: usize,
    /// Spheroidal parameter `c`; `spheroidal_eigenvalue_mn` requires `c >= 0`.
    pub c: f64,
    /// Family (prolate or oblate) the eigenvalue was computed for.
    pub kind: SpheroidalKind,
    /// The eigenvalue selected for `(m, n, c, kind)`.
    pub lambda: f64,
}
/// Returns the dimension used for the truncated tridiagonal system.
///
/// The result is `n + 20` (saturating at `usize::MAX`), but never less
/// than 40.
fn n_max(n: usize) -> usize {
    const PADDING: usize = 20;
    const FLOOR: usize = 40;
    std::cmp::max(FLOOR, n.saturating_add(PADDING))
}
/// Assembles the truncated tridiagonal system used by the eigenvalue and
/// angular-function routines in this file.
///
/// Row `p` corresponds to `ell = m + 2p + parity`, where `parity = (n - m) % 2`,
/// so only degrees with the same parity as `n - m` appear.
///
/// # Parameters
/// * `m`, `n` – order and degree; callers must guarantee `n >= m`.
/// * `c` – spheroidal parameter; only `c * c * sign_c2` enters the matrix.
/// * `sign_c2` – `+1.0` or `-1.0`, as chosen by the callers from the kind.
///
/// # Returns
/// `(diag, off)` where `diag` has `n_max(n)` entries and `off` has one fewer.
/// `diag[p] = ell(ell + 1) + c2 * (b_down + b_up)` and `off[p] = c2 * b_up`.
///
/// NOTE(review): the coupling terms `b_down` / `b_up` do not obviously match
/// the standard three-term recurrence for spheroidal expansion coefficients
/// (see DLMF §30.8). They are reproduced unchanged here; please verify them
/// against a reference before relying on results for `c != 0`.
fn build_swf_tridiag(m: usize, n: usize, c: f64, sign_c2: f64) -> (Vec<f64>, Vec<f64>) {
    debug_assert!(
        n >= m,
        "build_swf_tridiag: requires n >= m (got m={m}, n={n})"
    );
    let size = n_max(n);
    let parity = (n - m) % 2;
    let c2 = c * c * sign_c2;
    let m_f = m as f64;

    let mut diag = Vec::with_capacity(size);
    let mut off = Vec::with_capacity(size.saturating_sub(1));

    for p in 0..size {
        let k = 2 * p + parity;
        let ell = (m + k) as f64;

        // `ell` is a non-negative integer, so every denominator below is at
        // least 3 and no zero-division guard is required.
        let b_down = if k >= 2 {
            -(ell + m_f) * (ell + m_f - 1.0) / ((2.0 * ell - 1.0) * (2.0 * ell + 1.0))
        } else {
            0.0
        };
        let b_up = -(ell - m_f + 1.0) * (ell - m_f + 2.0) / ((2.0 * ell + 1.0) * (2.0 * ell + 3.0));

        diag.push(ell * (ell + 1.0) + c2 * (b_down + b_up));
        // The last row has no neighbour to couple to.
        if p + 1 < size {
            off.push(c2 * b_up);
        }
    }
    (diag, off)
}
/// Computes the spheroidal eigenvalue for order `m`, degree `n` and
/// parameter `c`.
///
/// The truncated tridiagonal system from `build_swf_tridiag` is handed to
/// `tridiag_eigenvalues`, and the entry at index `(n - m) / 2` of the returned
/// list is reported as `lambda`.
/// NOTE(review): this indexing presumes `tridiag_eigenvalues` returns its
/// values in ascending order — confirm against that helper.
///
/// # Parameters
/// * `m` – order.
/// * `n` – degree, must satisfy `n >= m`.
/// * `c` – spheroidal parameter, must be finite and `>= 0`.
/// * `kind` – prolate uses `+c²`, oblate uses `-c²`.
///
/// # Errors
/// * `SpecialError::DomainError` if `n < m`, if `c` is NaN or infinite, or if
///   `c < 0`.
/// * `SpecialError::ComputationError` if the eigenvalue list has no entry at
///   the target index.
pub fn spheroidal_eigenvalue_mn(
    m: usize,
    n: usize,
    c: f64,
    kind: SpheroidalKind,
) -> SpecialResult<SpheroidalEigenvalue> {
    if n < m {
        return Err(SpecialError::DomainError(format!(
            "spheroidal_eigenvalue_mn: degree n={n} must be >= order m={m}"
        )));
    }
    // A plain `c < 0.0` test is false for NaN and +inf, which would otherwise
    // flow into the matrix and produce a meaningless eigenvalue.
    if !c.is_finite() {
        return Err(SpecialError::DomainError(format!(
            "spheroidal_eigenvalue_mn: parameter c={c} must be finite"
        )));
    }
    if c < 0.0 {
        return Err(SpecialError::DomainError(format!(
            "spheroidal_eigenvalue_mn: parameter c={c} must be >= 0"
        )));
    }

    let sign_c2 = match kind {
        SpheroidalKind::Prolate => 1.0,
        SpheroidalKind::Oblate => -1.0,
    };
    let (diag, off) = build_swf_tridiag(m, n, c, sign_c2);
    let eigenvalues = tridiag_eigenvalues(&diag, &off);

    // Only degrees of matching parity are in the matrix, hence the halving.
    let target = (n - m) / 2;
    let lambda = eigenvalues.get(target).copied().ok_or_else(|| {
        SpecialError::ComputationError(format!(
            "spheroidal_eigenvalue_mn: failed to locate eigenvalue for m={m}, n={n}, c={c}"
        ))
    })?;

    Ok(SpheroidalEigenvalue {
        m,
        n,
        c,
        kind,
        lambda,
    })
}
/// Evaluates the angular spheroidal function of order `m` and degree `n` at
/// `x` as a series of associated Legendre functions.
///
/// The expansion coefficients `d_p` come from `tridiag_eigenvector` for the
/// eigenvalue at index `(n - m) / 2`; the result is
/// `sign * Σ_p d_p · P(m + 2p + parity, m, x)` with `parity = (n - m) % 2`.
/// `sign` is chosen so that the coefficient of largest magnitude counts as
/// non-negative, which fixes the otherwise arbitrary sign of the eigenvector.
/// The overall scale is whatever `tridiag_eigenvector` returns.
///
/// # Parameters
/// * `m`, `n` – order and degree, `n >= m`.
/// * `c` – spheroidal parameter, finite and `>= 0`.
/// * `x` – argument; values within `1e-10` outside `[-1, 1]` are accepted and
///   clamped onto the interval.
/// * `kind` – prolate uses `+c²`, oblate uses `-c²`.
///
/// # Errors
/// * `SpecialError::DomainError` if `n < m`, if `x` is NaN or
///   `|x| > 1 + 1e-10`, if `c` is NaN or infinite, or if `c < 0`.
/// * `SpecialError::ComputationError` if no eigenvalue exists at the target
///   index.
pub fn spheroidal_ps(
    m: usize,
    n: usize,
    c: f64,
    x: f64,
    kind: SpheroidalKind,
) -> SpecialResult<f64> {
    if n < m {
        return Err(SpecialError::DomainError(format!(
            "spheroidal_ps: degree n={n} must be >= order m={m}"
        )));
    }
    // NaN compares false against every bound, so it needs its own test.
    if x.is_nan() || x.abs() > 1.0 + 1e-10 {
        return Err(SpecialError::DomainError(format!(
            "spheroidal_ps: argument x={x} must satisfy |x| <= 1"
        )));
    }
    if !c.is_finite() {
        return Err(SpecialError::DomainError(format!(
            "spheroidal_ps: parameter c={c} must be finite"
        )));
    }
    if c < 0.0 {
        return Err(SpecialError::DomainError(format!(
            "spheroidal_ps: parameter c={c} must be >= 0"
        )));
    }

    // Absorb round-off that pushed `x` marginally outside the interval.
    let x_clamped = x.clamp(-1.0, 1.0);
    let sign_c2 = match kind {
        SpheroidalKind::Prolate => 1.0,
        SpheroidalKind::Oblate => -1.0,
    };

    let (diag, off) = build_swf_tridiag(m, n, c, sign_c2);
    let eigenvalues = tridiag_eigenvalues(&diag, &off);
    let target = (n - m) / 2;
    let eigenval = eigenvalues.get(target).copied().ok_or_else(|| {
        SpecialError::ComputationError(format!(
            "spheroidal_ps: no eigenvalue for m={m}, n={n}, c={c}"
        ))
    })?;
    let coeffs = tridiag_eigenvector(&diag, &off, eigenval);

    // Sign convention: the dominant coefficient is treated as non-negative.
    let sign = coeffs
        .iter()
        .copied()
        .max_by(|a, b| {
            a.abs()
                .partial_cmp(&b.abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .map_or(1.0_f64, |dominant| if dominant >= 0.0 { 1.0 } else { -1.0 });

    let parity = (n - m) % 2;
    let series = coeffs
        .iter()
        .enumerate()
        .fold(0.0_f64, |acc, (p, &d)| {
            let l = m + 2 * p + parity;
            acc + d * associated_legendre(l, m, x_clamped)
        });

    Ok(sign * series)
}
/// Returns `(1 - x²) · S'(x) / S(x)` for the prolate angular function `S`
/// computed by `spheroidal_ps`, using a central finite difference for `S'`.
///
/// When `|S(x)| < 1e-15` the division is skipped and `(1 - x²) · S'(x)` is
/// returned instead, so the two regimes are not the same quantity.
///
/// # Parameters
/// * `m`, `n` – order and degree, `n >= m`.
/// * `c` – spheroidal parameter; validated by `spheroidal_ps`.
/// * `x` – evaluation point, must satisfy `|x| < 1 - 1e-8`.
///
/// # Errors
/// * `SpecialError::DomainError` if `x` is NaN or too close to `±1`, or if
///   `n < m`.
/// * Any error returned by `spheroidal_ps` is propagated.
pub fn spheroidal_wronskian(m: usize, n: usize, c: f64, x: f64) -> SpecialResult<f64> {
    // NaN compares false against the bound, so it is rejected explicitly.
    if x.is_nan() || x.abs() >= 1.0 - 1e-8 {
        return Err(SpecialError::DomainError(format!(
            "spheroidal_wronskian: x={x} must satisfy |x| < 1 (strict); too close to boundary"
        )));
    }
    if n < m {
        return Err(SpecialError::DomainError(format!(
            "spheroidal_wronskian: n={n} must be >= m={m}"
        )));
    }

    // Nominal step 1e-5, shrunk near the boundary so that both stencil points
    // stay strictly inside (-1, 1): |x| + h <= (1 + |x|) / 2 < 1.
    let h = 1e-5_f64.min(0.5 * (1.0 - x.abs()));
    let x_p = x + h;
    let x_m = x - h;

    let s_p = spheroidal_ps(m, n, c, x_p, SpheroidalKind::Prolate)?;
    let s_m = spheroidal_ps(m, n, c, x_m, SpheroidalKind::Prolate)?;
    let s_x = spheroidal_ps(m, n, c, x, SpheroidalKind::Prolate)?;

    // Divide by the spacing actually used rather than the nominal `2h`.
    let ds_dx = (s_p - s_m) / (x_p - x_m);
    let one_minus_x2 = 1.0 - x * x;

    if s_x.abs() < 1e-15 {
        Ok(one_minus_x2 * ds_dx)
    } else {
        Ok(one_minus_x2 * ds_dx / s_x)
    }
}
// Unit tests for the eigenvalue, angular-function and logarithmic-derivative
// routines of this module. Most checks are deliberately coarse: they verify
// the c = 0 limit, finiteness, and error paths rather than reference values
// for c != 0.
#[cfg(test)]
mod tests {
use super::*;
// With c = 0 every c²-weighted term in the matrix vanishes, so the expected
// eigenvalue for m = n = 0 is 0. The tolerance of 0.5 is loose.
#[test]
fn test_eigenvalue_c0_n0m0_prolate() {
let ev = spheroidal_eigenvalue_mn(0, 0, 0.0, SpheroidalKind::Prolate)
.expect("m=0,n=0,c=0 eigenvalue should not fail");
assert!(
ev.lambda.abs() < 0.5,
"λ_{{00}}(0) should be 0, got {}",
ev.lambda
);
}
// c = 0, m = n = 1: expected eigenvalue 1·(1+1) = 2, again within 0.5.
#[test]
fn test_eigenvalue_c0_n1m1_prolate() {
let ev = spheroidal_eigenvalue_mn(1, 1, 0.0, SpheroidalKind::Prolate)
.expect("m=1,n=1,c=0 eigenvalue should not fail");
assert!(
(ev.lambda - 2.0).abs() < 0.5,
"λ_{{11}}(0) should be ≈ 2, got {}",
ev.lambda
);
}
// Table-driven c = 0 check: each case is (m, n, expected) with
// expected = n(n+1). Covers both parities of n - m. Tolerance is 1.0.
#[test]
fn test_eigenvalue_c0_matches_nn1() {
let cases = [
(0usize, 2usize, 6.0),
(0, 3, 12.0),
(1, 2, 6.0),
(2, 3, 12.0),
];
for (m, n, expected) in cases {
let ev = spheroidal_eigenvalue_mn(m, n, 0.0, SpheroidalKind::Prolate)
.unwrap_or_else(|_| panic!("eigenvalue for m={m},n={n} should not fail"));
assert!(
(ev.lambda - expected).abs() < 1.0,
"λ_{{m={m},n={n}}}(0) should be ≈ {expected}, got {}",
ev.lambda
);
}
}
// For m = 0, c = 1 and n = 0, 2, 4 the eigenvalues must not decrease.
// Note the assertion tolerates equality (down to -1e-8) even though the
// failure message says "should exceed".
#[test]
fn test_eigenvalue_increases_with_n() {
let c = 1.0;
let mut prev = f64::NEG_INFINITY;
for n in [0usize, 2, 4] {
let ev = spheroidal_eigenvalue_mn(0, n, c, SpheroidalKind::Prolate)
.unwrap_or_else(|_| panic!("eigenvalue for n={n} should not fail"));
assert!(
ev.lambda > prev - 1e-8,
"eigenvalue for n={n} ({}) should exceed previous ({})",
ev.lambda,
prev
);
prev = ev.lambda;
}
}
// The oblate sign only multiplies c², so at c = 0 the oblate eigenvalue is
// expected to equal n(n+1) as well (here 3·4 = 12, within 0.5).
#[test]
fn test_eigenvalue_oblate_c0_equals_nn1() {
let ev = spheroidal_eigenvalue_mn(0, 3, 0.0, SpheroidalKind::Oblate)
.expect("oblate c=0 eigenvalue should not fail");
assert!(
(ev.lambda - 12.0).abs() < 0.5,
"oblate λ_{{03}}(0) should be ≈ 12, got {}",
ev.lambda
);
}
// n < m must be rejected with an error rather than computed.
#[test]
fn test_eigenvalue_error_n_less_than_m() {
let result = spheroidal_eigenvalue_mn(3, 1, 1.0, SpheroidalKind::Prolate);
assert!(result.is_err(), "n < m should return an error");
}
// Smoke test: the angular function at an interior point is finite.
#[test]
fn test_spheroidal_ps_c0_finite() {
let val = spheroidal_ps(0, 0, 0.0, 0.5, SpheroidalKind::Prolate)
.expect("spheroidal_ps m=0,n=0,c=0 should not fail");
assert!(
val.is_finite(),
"S_{{00}}(0, 0.5) should be finite, got {val}"
);
}
// At c = 0 the result for m = 0, n = 2 is compared with the Legendre
// polynomial P_2(x) = (3x² - 1)/2. Only the sign is checked, and only when
// both magnitudes exceed 1e-10; a (near-)zero result passes vacuously.
#[test]
fn test_spheroidal_ps_c0_legendre_proportional() {
let x = 0.7_f64;
let val = spheroidal_ps(0, 2, 0.0, x, SpheroidalKind::Prolate)
.expect("spheroidal_ps should not fail for c=0");
let p2 = (3.0 * x * x - 1.0) / 2.0; if p2.abs() > 1e-10 && val.abs() > 1e-10 {
let ratio = val / p2;
assert!(
ratio > 0.0,
"S_{{02}}(0, {x}) and P_2({x}) should have the same sign; ratio = {ratio}"
);
}
}
// Finiteness at several interior points for m = 1, n = 2, c = 2.
#[test]
fn test_spheroidal_ps_interior_finite() {
for x in [-0.8, -0.3, 0.0, 0.3, 0.8] {
let val = spheroidal_ps(1, 2, 2.0, x, SpheroidalKind::Prolate)
.unwrap_or_else(|_| panic!("spheroidal_ps at x={x} should not fail"));
assert!(
val.is_finite(),
"S_{{12}}(2, {x}) should be finite, got {val}"
);
}
}
// |x| > 1 (beyond the 1e-10 slack) must yield an error.
#[test]
fn test_spheroidal_ps_domain_error() {
let result = spheroidal_ps(0, 2, 1.0, 1.5, SpheroidalKind::Prolate);
assert!(result.is_err(), "|x| > 1 should return a domain error");
}
// Finiteness of spheroidal_wronskian at interior points, including x = 0.
#[test]
fn test_wronskian_finite_interior() {
for x in [-0.7, -0.2, 0.0, 0.4, 0.6] {
let w = spheroidal_wronskian(0, 2, 1.0, x)
.unwrap_or_else(|_| panic!("wronskian at x={x} should not fail"));
assert!(
w.is_finite(),
"wronskian at x={x} should be finite, got {w}"
);
}
}
// x = 1 lies on the boundary, which spheroidal_wronskian rejects.
#[test]
fn test_wronskian_boundary_error() {
let result = spheroidal_wronskian(0, 2, 1.0, 1.0);
assert!(result.is_err(), "wronskian at boundary x=1 should error");
}
// The returned record must echo the inputs (m, n, c, kind) unchanged and
// carry a finite eigenvalue.
#[test]
fn test_spheroidal_eigenvalue_metadata() {
let ev = spheroidal_eigenvalue_mn(1, 3, 2.0, SpheroidalKind::Prolate)
.expect("eigenvalue should not fail");
assert_eq!(ev.m, 1);
assert_eq!(ev.n, 3);
assert!((ev.c - 2.0).abs() < 1e-14);
assert_eq!(ev.kind, SpheroidalKind::Prolate);
assert!(ev.lambda.is_finite());
}
}