use crate::error::{SpecialError, SpecialResult};
use crate::gamma::gamma;
use std::f64::consts::PI;
/// Riemann zeta function ζ(s) for a real argument.
///
/// Evaluation strategy:
/// * `s == 1` is the pole and yields `+∞`.
/// * Negative even integers are the trivial zeros and yield exactly `0`.
/// * For `s > -0.5` the Euler–Maclaurin sum in `riemann_zeta_em` is evaluated
///   directly.
/// * For `s <= -0.5` the functional equation
///   `ζ(s) = 2^s π^(s-1) sin(πs/2) Γ(1-s) ζ(1-s)` is applied, so the
///   Euler–Maclaurin sum is only asked for `ζ(1-s)` with `1 - s >= 1.5`.
///
/// # Errors
///
/// Returns `SpecialError::DomainError` when `s` is NaN, and propagates any
/// error reported by `riemann_zeta_em`.
pub fn riemann_zeta(s: f64) -> SpecialResult<f64> {
    // The reflection formula is ill-conditioned around s = 0: ζ(1-s) has a
    // pole there that is cancelled only analytically by the zero of
    // sin(πs/2). In floating point that is 0 * ∞ = NaN at s = 0 and ±∞ for
    // |s| below one ulp of 1. Evaluating the Euler–Maclaurin sum directly on
    // a neighbourhood of zero sidesteps the problem.
    const REFLECTION_CUTOFF: f64 = -0.5;

    if s.is_nan() {
        return Err(SpecialError::DomainError(
            "riemann_zeta: NaN argument".to_string(),
        ));
    }
    if s == 1.0 {
        return Ok(f64::INFINITY);
    }
    // `s % 2.0 == 0.0` holds exactly for even integers (the remainder of a
    // negative even value is -0.0, which compares equal to 0.0). Unlike an
    // integer cast it cannot saturate for very large magnitudes.
    if s < 0.0 && s % 2.0 == 0.0 {
        return Ok(0.0);
    }
    if s > REFLECTION_CUTOFF {
        return riemann_zeta_em(s);
    }

    let t = 1.0 - s;
    let zeta_t = riemann_zeta_em(t)?;
    Ok(2.0_f64.powf(s) * PI.powf(s - 1.0) * (PI * s / 2.0).sin() * gamma(t) * zeta_t)
}
/// Euler–Maclaurin approximation of ζ(s).
///
/// Computes
/// `Σ_{k=1}^{N} k^(-s) + N^(1-s)/(s-1) - N^(-s)/2
///    + Σ_{j=1}^{6} B_{2j}/(2j)! · s(s+1)…(s+2j-2) · N^(-s-2j+1)`
/// where the cut-off `N` is 5, 30 or 80 depending on how large `s` is.
///
/// The caller is responsible for keeping `s` away from the pole at `s = 1`
/// (the `1/(s-1)` term divides by zero there).
fn riemann_zeta_em(s: f64) -> SpecialResult<f64> {
    // (B_{2j}, 2j) pairs for the first six even Bernoulli numbers.
    const BERNOULLI: [(f64, f64); 6] = [
        (1.0 / 6.0, 2.0),
        (-1.0 / 30.0, 4.0),
        (1.0 / 42.0, 6.0),
        (-1.0 / 30.0, 8.0),
        (5.0 / 66.0, 10.0),
        (-691.0 / 2730.0, 12.0),
    ];

    // Larger s converges faster, so fewer directly summed terms are needed.
    let n: usize = if s > 30.0 {
        5
    } else if s > 5.0 {
        30
    } else {
        80
    };
    let nf = n as f64;

    let mut sum = 0.0_f64;
    for k in 1..=n {
        sum += (k as f64).powf(-s);
    }
    sum += nf.powf(1.0 - s) / (s - 1.0);
    sum -= 0.5 * nf.powf(-s);

    // `rising` holds s(s+1)…(s+2j-2) for the current correction term.
    let mut rising = s;
    for &(b2k, two_k) in &BERNOULLI {
        let n_pow = nf.powf(-s - two_k + 1.0);
        // Once the power underflows to zero every later term is zero as well.
        // Stopping here also avoids `inf * 0 = NaN` for very large s, where
        // the rising factorial overflows while the power has already vanished.
        if n_pow == 0.0 {
            break;
        }
        let fact_denom = factorial_f64(two_k as usize);
        sum += b2k * rising * n_pow / fact_denom;

        // Extend the rising factorial by the next two factors.
        let last = two_k - 1.0;
        rising *= (s + last) * (s + last + 1.0);
    }
    Ok(sum)
}
/// Returns `n!` as an `f64`.
///
/// The product runs over `2..=n`, so both `0!` and `1!` evaluate to `1.0`.
fn factorial_f64(n: usize) -> f64 {
    (2..=n).map(|factor| factor as f64).product()
}
/// Riemann zeta function for a complex argument given as `(re, im)`.
///
/// * A zero imaginary part delegates to the real routine `riemann_zeta`.
/// * A positive real part is evaluated by `zeta_complex_em`.
/// * Otherwise `ζ(1-s)` is evaluated by `zeta_complex_em` and mapped back
///   through `zeta_functional_prefactor`.
///
/// # Errors
///
/// Returns `SpecialError::DomainError` when either component is NaN or
/// infinite, and propagates errors from the helpers it calls.
pub fn riemann_zeta_complex(s: (f64, f64)) -> SpecialResult<(f64, f64)> {
    let (sigma, t) = s;

    if !(sigma.is_finite() && t.is_finite()) {
        return Err(SpecialError::DomainError(
            "riemann_zeta_complex: non-finite argument".to_string(),
        ));
    }

    // Purely real input: reuse the real-valued implementation.
    if t == 0.0 {
        return riemann_zeta(sigma).map(|re| (re, 0.0));
    }

    // Non-positive real part: reflect to 1 - s = (1 - sigma) - i t.
    if sigma <= 0.0 {
        let (reflected_re, reflected_im) = zeta_complex_em(1.0 - sigma, -t)?;
        return zeta_functional_prefactor(sigma, t, reflected_re, reflected_im);
    }

    zeta_complex_em(sigma, t)
}
/// Truncated Euler–Maclaurin estimate of ζ(σ + it).
///
/// Evaluates `Σ_{k=1}^{N} k^(-s) + N^(1-s)/(s-1) - N^(-s)/2` in explicit
/// real/imaginary arithmetic, with `N = 10` when `σ > 10` and `N = 60`
/// otherwise. No Bernoulli correction terms are applied here.
///
/// Returns `(∞, 0)` when `|s - 1|²` is below `1e-300`.
fn zeta_complex_em(sigma: f64, t: f64) -> SpecialResult<(f64, f64)> {
    let n: usize = if sigma > 10.0 { 10 } else { 60 };

    // Direct partial sum: k^(-s) = e^(-σ ln k) · (cos(t ln k) - i sin(t ln k)).
    let (mut acc_re, mut acc_im) = (1..=n).fold((0.0_f64, 0.0_f64), |(re, im), k| {
        let log_k = (k as f64).ln();
        let modulus = (-sigma * log_k).exp();
        let angle = t * log_k;
        (re + modulus * angle.cos(), im - modulus * angle.sin())
    });

    let log_n = (n as f64).ln();

    // N^(1-s) in polar form.
    let head_mod = ((1.0 - sigma) * log_n).exp();
    let head_arg = -t * log_n;
    let head_re = head_mod * head_arg.cos();
    let head_im = head_mod * head_arg.sin();

    // 1/(s-1) = conj(s-1) / |s-1|²; bail out right at the pole.
    let norm_sq = (sigma - 1.0) * (sigma - 1.0) + t * t;
    if norm_sq < 1e-300 {
        return Ok((f64::INFINITY, 0.0));
    }
    let recip_re = (sigma - 1.0) / norm_sq;
    let recip_im = -t / norm_sq;

    // Integral tail N^(1-s)/(s-1).
    acc_re += head_re * recip_re - head_im * recip_im;
    acc_im += head_re * recip_im + head_im * recip_re;

    // Half of the last summed term, N^(-s)/2, is removed again.
    let edge_mod = (-sigma * log_n).exp();
    let edge_arg = t * log_n;
    let edge_re = edge_mod * edge_arg.cos();
    let edge_im = -edge_mod * edge_arg.sin();
    acc_re -= 0.5 * edge_re;
    acc_im -= 0.5 * edge_im;

    Ok((acc_re, acc_im))
}
/// Applies the functional equation
/// `ζ(s) = 2^s · π^(s-1) · sin(πs/2) · Γ(1-s) · ζ(1-s)` for `s = σ + it`.
///
/// `zeta_1ms_re` / `zeta_1ms_im` are the real and imaginary parts of an
/// already computed `ζ(1-s)`; the return value is `ζ(s)` as `(re, im)`.
fn zeta_functional_prefactor(
    sigma: f64,
    t: f64,
    zeta_1ms_re: f64,
    zeta_1ms_im: f64,
) -> SpecialResult<(f64, f64)> {
    // Complex product (a.0 + i a.1)(b.0 + i b.1).
    let cmul = |a: (f64, f64), b: (f64, f64)| -> (f64, f64) {
        (a.0 * b.0 - a.1 * b.1, a.0 * b.1 + a.1 * b.0)
    };

    // 2^s = 2^σ · e^(i t ln 2)
    let ln2 = std::f64::consts::LN_2;
    let two_mod = (sigma * ln2).exp();
    let two_pow = (two_mod * (t * ln2).cos(), two_mod * (t * ln2).sin());

    // π^(s-1) = π^(σ-1) · e^(i t ln π)
    let ln_pi = PI.ln();
    let pi_mod = ((sigma - 1.0) * ln_pi).exp();
    let pi_pow = (pi_mod * (t * ln_pi).cos(), pi_mod * (t * ln_pi).sin());

    // sin(πs/2) = sin(a)cosh(b) + i cos(a)sinh(b) with a = πσ/2, b = πt/2.
    let half_sigma = PI * sigma / 2.0;
    let half_t = PI * t / 2.0;
    let sine = (
        half_sigma.sin() * half_t.cosh(),
        half_sigma.cos() * half_t.sinh(),
    );

    // Γ(1 - s)
    let gamma_reflected = complex_gamma(1.0 - sigma, -t);

    // Multiply the factors left to right, then by ζ(1 - s).
    let with_pi = cmul(two_pow, pi_pow);
    let with_sine = cmul(with_pi, sine);
    let with_gamma = cmul(with_sine, gamma_reflected);
    Ok(cmul(with_gamma, (zeta_1ms_re, zeta_1ms_im)))
}
/// Gamma function Γ(x + iy) via the Lanczos approximation (`g = 7`, nine
/// coefficients), returned as `(re, im)`.
///
/// For `x < 0.5` the reflection formula `Γ(z) = π / (sin(πz) · Γ(1 - z))` is
/// used so the Lanczos series is only evaluated for `Re z >= 0.5`.
///
/// `(∞, 0)` is returned whenever a squared magnitude that has to be divided
/// by falls below `1e-300` (e.g. exactly at `z = 0`).
fn complex_gamma(x: f64, y: f64) -> (f64, f64) {
    const G: f64 = 7.0;
    const C: [f64; 9] = [
        0.999_999_999_999_809_93,
        676.520_368_121_885_10,
        -1_259.139_216_722_402_9,
        771.323_428_777_653_07,
        -176.615_029_162_140_60,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_312_5e-7,
    ];
    // Squared magnitudes below this are treated as a division by zero.
    const TINY: f64 = 1e-300;

    if x < 0.5 {
        // Reflection: Γ(z) = π / sin(πz) · 1 / Γ(1 - z).
        let (g1_re, g1_im) = complex_gamma(1.0 - x, -y);

        // sin(π(x + iy)) = sin(πx)cosh(πy) + i cos(πx)sinh(πy)
        let sin_re = (PI * x).sin() * (PI * y).cosh();
        let sin_im = (PI * x).cos() * (PI * y).sinh();
        let sin_norm = sin_re * sin_re + sin_im * sin_im;
        if sin_norm < TINY {
            return (f64::INFINITY, 0.0);
        }
        let inv_re = PI * sin_re / sin_norm;
        let inv_im = -PI * sin_im / sin_norm;

        let g1_norm = g1_re * g1_re + g1_im * g1_im;
        if g1_norm < TINY {
            return (f64::INFINITY, 0.0);
        }
        let inv_g_re = g1_re / g1_norm;
        let inv_g_im = -g1_im / g1_norm;

        return (
            inv_re * inv_g_re - inv_im * inv_g_im,
            inv_re * inv_g_im + inv_im * inv_g_re,
        );
    }

    // Lanczos series in z = (x - 1) + iy: A(z) = C0 + Σ_k C_k / (z + k).
    let zr = x - 1.0;
    let zi = y;
    let mut sum_re = C[0];
    let mut sum_im = 0.0_f64;
    for (k, &coeff) in C.iter().enumerate().skip(1) {
        let shifted = zr + k as f64;
        let d = shifted * shifted + zi * zi;
        if d < TINY {
            return (f64::INFINITY, 0.0);
        }
        sum_re += coeff * shifted / d;
        sum_im -= coeff * zi / d;
    }

    // t = z + g + 1/2, and Γ(z + 1) = √(2π) · t^(z + 1/2) · e^(-t) · A(z).
    let tr = zr + G + 0.5;
    let ti = zi;
    let ln_t_mag = (tr * tr + ti * ti).sqrt().ln();
    let ln_t_arg = ti.atan2(tr);

    // Combine the logarithm of t^(z+1/2) with that of e^(-t) *before*
    // exponentiating. Exponentiating the two factors separately overflows the
    // power (and underflows the exponential) for Re z ≳ 143, producing
    // inf/NaN even though Γ itself is still representable.
    let log_mag = (zr + 0.5) * ln_t_mag - zi * ln_t_arg - tr;
    let angle = (zr + 0.5) * ln_t_arg + zi * ln_t_mag - ti;
    let mag = log_mag.exp();
    let te_re = mag * angle.cos();
    let te_im = mag * angle.sin();

    let sqrt2pi = (2.0 * PI).sqrt();
    (
        sqrt2pi * (sum_re * te_re - sum_im * te_im),
        sqrt2pi * (sum_re * te_im + sum_im * te_re),
    )
}
/// Dirichlet eta function `η(s) = (1 - 2^(1-s)) · ζ(s)` for real `s`.
///
/// Special cases handled without calling the zeta routine:
/// * `|s - 1| < 1e-15` returns `ln 2` (the zeta pole is cancelled by the
///   vanishing factor).
/// * `s == 0` returns `1/2`.
///
/// # Errors
///
/// Returns `SpecialError::DomainError` when `s` is NaN, and propagates any
/// error from `riemann_zeta`.
pub fn dirichlet_eta(s: f64) -> SpecialResult<f64> {
    use std::f64::consts::LN_2;

    if s.is_nan() {
        return Err(SpecialError::DomainError(
            "dirichlet_eta: NaN argument".to_string(),
        ));
    }
    if (s - 1.0).abs() < 1e-15 {
        return Ok(LN_2);
    }
    if s == 0.0 {
        return Ok(0.5);
    }

    // 1 - 2^(1-s) = -expm1((1-s) ln 2). Close to s = 1 the direct form
    // subtracts two nearly equal numbers and loses most of its significant
    // digits, which the large value of ζ(s) then amplifies; `exp_m1` keeps
    // full precision there. Away from s = 1 the original `powf` form is kept
    // so that results (e.g. exact powers of two at integer s) are unchanged.
    let exponent = (1.0 - s) * LN_2;
    let factor = if exponent.abs() < 0.5 {
        -exponent.exp_m1()
    } else {
        1.0 - 2.0_f64.powf(1.0 - s)
    };

    let z = riemann_zeta(s)?;
    Ok(factor * z)
}
/// Evaluates the alternating series `Σ_{k>=0} (-1)^k / (k + 1)^s` by
/// delegating to `lerch_phi_alternating` with `a = 1`.
///
/// Not called from anywhere in this section, hence the `dead_code` allowance.
#[allow(dead_code)]
fn eta_alternating(s: f64) -> SpecialResult<f64> {
lerch_phi_alternating(s, 1.0)
}
/// Lerch transcendent `Φ(z, s, a) = Σ_{n>=0} z^n / (n + a)^s` for real
/// arguments.
///
/// Dispatch, in order:
/// * `z == 1` → `riemann_zeta(s)` when `a` is within `1e-15` of 1, otherwise
///   `hurwitz_zeta(s, a)`.
/// * `z` within `1e-15` of `-1` → `lerch_phi_alternating(s, a)`.
/// * otherwise the power series is summed directly, for at most 10000 terms
///   when `|z| > 0.99` and 500 terms otherwise. Summation stops early once
///   (after the first eleven terms) a term drops below `1e-15` relative to
///   the running sum. If the term budget runs out, the partial sum is
///   returned as-is.
///
/// # Errors
///
/// Returns `SpecialError::DomainError` when any argument is NaN, when
/// `a <= 0`, or when `|z| > 1`.
pub fn lerch_phi(z: f64, s: f64, a: f64) -> SpecialResult<f64> {
    // "Equal up to rounding" as used by the special-case dispatch below.
    let nearly = |value: f64, target: f64| (value - target).abs() < 1e-15;

    if z.is_nan() || s.is_nan() || a.is_nan() {
        return Err(SpecialError::DomainError(
            "lerch_phi: NaN argument".to_string(),
        ));
    }
    if a <= 0.0 {
        return Err(SpecialError::DomainError(
            "lerch_phi: a must be positive".to_string(),
        ));
    }

    if z == 1.0 {
        return if nearly(a, 1.0) {
            riemann_zeta(s)
        } else {
            crate::zeta::hurwitz_zeta(s, a)
        };
    }
    if z.abs() > 1.0 {
        return Err(SpecialError::DomainError(
            "lerch_phi: |z| must be <= 1".to_string(),
        ));
    }
    if nearly(z, -1.0) {
        return lerch_phi_alternating(s, a);
    }

    // Direct summation of the defining series.
    const REL_TOL: f64 = 1e-15;
    let term_budget: usize = if z.abs() > 0.99 { 10000 } else { 500 };

    let mut total = 0.0_f64;
    let mut power = 1.0_f64; // z^idx
    for idx in 0..term_budget {
        let contribution = power / (idx as f64 + a).powf(s);
        total += contribution;
        if idx > 10 && contribution.abs() < REL_TOL * total.abs() {
            break;
        }
        power *= z;
    }
    Ok(total)
}
/// Accelerated evaluation of the alternating series
/// `Σ_{k>=0} (-1)^k / (k + a)^s`.
///
/// Uses a fixed order of 50. With `d_k = Σ_{j=0}^{k} C(50, j)` (cumulative
/// binomial coefficients), the result is
/// `Σ_{k=0}^{50} (-1)^k · (d_50 - d_k) / d_50 · (k + a)^(-s)`.
fn lerch_phi_alternating(s: f64, a: f64) -> SpecialResult<f64> {
    const ORDER: usize = 50;

    // cumulative[k] = C(ORDER, 0) + … + C(ORDER, k)
    let mut cumulative = [0.0_f64; ORDER + 1];
    let mut coeff = 1.0_f64;
    cumulative[0] = coeff;
    for k in 0..ORDER {
        // C(ORDER, k + 1) from C(ORDER, k).
        coeff *= (ORDER - k) as f64 / (k + 1) as f64;
        cumulative[k + 1] = cumulative[k] + coeff;
    }
    let full = cumulative[ORDER];

    let weighted = cumulative
        .iter()
        .enumerate()
        .fold(0.0_f64, |acc, (k, &d_k)| {
            let sign = if k % 2 == 0 { 1.0 } else { -1.0 };
            let inv_pow = 1.0 / (k as f64 + a).powf(s);
            acc + sign * (d_k - full) * inv_pow
        });

    Ok(-weighted / full)
}
/// Thin forwarding wrapper around `crate::gamma::polygamma(n, x)`.
///
/// Only referenced from the test module in this section, hence the
/// `dead_code` allowance for non-test builds.
#[allow(dead_code)]
fn polygamma_local(n: u32, x: f64) -> f64 {
crate::gamma::polygamma(n, x)
}
// Unit tests for the zeta / eta / Lerch routines defined above.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use std::f64::consts::{LN_2, PI};
// Closed forms: zeta(2) = pi^2/6, zeta(4) = pi^4/90, plus a reference value
// for zeta(3).
#[test]
fn test_riemann_zeta_known_values() {
let z2 = riemann_zeta(2.0).expect("zeta(2)");
assert_relative_eq!(z2, PI * PI / 6.0, epsilon = 1e-8);
let z4 = riemann_zeta(4.0).expect("zeta(4)");
assert_relative_eq!(z4, PI.powi(4) / 90.0, epsilon = 1e-8);
let z3 = riemann_zeta(3.0).expect("zeta(3)");
assert_relative_eq!(z3, 1.2020569031595942_f64, epsilon = 1e-8);
}
// Negative integers: zeta(-1) = -1/12, zeta(-3) = 1/120, and the trivial
// zeros at -2 and -4.
#[test]
fn test_riemann_zeta_negative() {
let z_neg1 = riemann_zeta(-1.0).expect("zeta(-1)");
assert_relative_eq!(z_neg1, -1.0 / 12.0, epsilon = 1e-7);
let z_neg3 = riemann_zeta(-3.0).expect("zeta(-3)");
assert_relative_eq!(z_neg3, 1.0 / 120.0, epsilon = 1e-7);
let z_neg2 = riemann_zeta(-2.0).expect("zeta(-2)");
assert_relative_eq!(z_neg2, 0.0, epsilon = 1e-15);
let z_neg4 = riemann_zeta(-4.0).expect("zeta(-4)");
assert_relative_eq!(z_neg4, 0.0, epsilon = 1e-15);
}
// The pole at s = 1 is reported as an infinite value, not as an error.
#[test]
fn test_riemann_zeta_pole() {
let z1 = riemann_zeta(1.0).expect("zeta(1) pole");
assert!(z1.is_infinite());
}
// For large s the value is 1 to within 1e-14.
#[test]
fn test_riemann_zeta_large_s() {
let z50 = riemann_zeta(50.0).expect("zeta(50)");
assert!((z50 - 1.0).abs() < 1e-14);
}
// A zero imaginary part must reproduce the real-valued result.
#[test]
fn test_riemann_zeta_complex_real_axis() {
let (re, im) = riemann_zeta_complex((2.0, 0.0)).expect("zeta_complex(2)");
assert_relative_eq!(re, PI * PI / 6.0, epsilon = 1e-6);
assert!(im.abs() < 1e-10);
}
// Smoke test off the real axis: only finiteness is checked, not the value.
#[test]
fn test_riemann_zeta_complex_nontrivial() {
let (re, im) = riemann_zeta_complex((2.0, 3.0)).expect("zeta_complex(2+3i)");
assert!(re.is_finite());
assert!(im.is_finite());
}
// eta(1) = ln 2, eta(0) = 1/2, eta(2) = pi^2/12.
#[test]
fn test_dirichlet_eta_special_values() {
let eta1 = dirichlet_eta(1.0).expect("eta(1)");
assert_relative_eq!(eta1, LN_2, epsilon = 1e-10);
let eta0 = dirichlet_eta(0.0).expect("eta(0)");
assert_relative_eq!(eta0, 0.5, epsilon = 1e-10);
let eta2 = dirichlet_eta(2.0).expect("eta(2)");
assert_relative_eq!(eta2, PI * PI / 12.0, epsilon = 1e-8);
}
// eta(s) must equal (1 - 2^(1-s)) * zeta(s) for a few integer s.
#[test]
fn test_dirichlet_eta_relation_to_zeta() {
for s in [2.0_f64, 3.0, 4.0, 5.0] {
let eta_val = dirichlet_eta(s).expect("eta");
let zeta_val = riemann_zeta(s).expect("zeta");
let expected = (1.0 - 2.0_f64.powf(1.0 - s)) * zeta_val;
let diff = (eta_val - expected).abs();
assert!(
diff < 1e-8,
"eta({s}) = {eta_val}, expected {expected}, diff={diff}"
);
}
}
// Phi(1, s, 1) reduces to zeta(s).
#[test]
fn test_lerch_phi_reduces_to_zeta() {
let phi = lerch_phi(1.0, 2.0, 1.0).expect("Phi(1,2,1)");
assert_relative_eq!(phi, PI * PI / 6.0, epsilon = 1e-6);
}
// Phi(-1, s, 1) reduces to eta(s); eta(1) = ln 2.
#[test]
fn test_lerch_phi_reduces_to_eta() {
let phi = lerch_phi(-1.0, 1.0, 1.0).expect("Phi(-1,1,1)");
assert_relative_eq!(phi, LN_2, epsilon = 1e-8);
}
// With s = 0 every denominator is 1, so the series is geometric:
// Phi(z, 0, a) = 1 / (1 - z).
#[test]
fn test_lerch_phi_geometric() {
let z = 0.5_f64;
let phi = lerch_phi(z, 0.0, 1.0).expect("Phi(0.5,0,1)");
assert_relative_eq!(phi, 1.0 / (1.0 - z), epsilon = 1e-6);
}
// psi(1) = -gamma (Euler–Mascheroni constant) and psi'(1) = pi^2/6.
#[test]
fn test_polygamma_wrapper() {
let euler_gamma = 0.577_215_664_901_532_9_f64;
let psi0 = polygamma_local(0, 1.0_f64);
assert_relative_eq!(psi0, -euler_gamma, epsilon = 1e-8);
let psi1 = polygamma_local(1, 1.0_f64);
assert_relative_eq!(psi1, PI * PI / 6.0, epsilon = 1e-8);
}
}