use numra_core::Scalar;
/// Riemann zeta function ζ(s) for a generic scalar type.
///
/// The argument is converted to `f64`, evaluated by `zeta_f64`, and the
/// result converted back to `S`; the computation itself is always carried
/// out in `f64`.
pub fn zeta<S: Scalar>(s: S) -> S {
    let value = zeta_f64(s.to_f64());
    S::from_f64(value)
}
/// Hurwitz zeta function ζ(s, a) = Σ_{k≥0} (k + a)^(-s) for a generic scalar type.
///
/// Both arguments are converted to `f64` and the result converted back to
/// `S`. Only `s > 1` and `a > 0` are supported; if `s <= 1` or `a <= 0`
/// the function returns `S::NAN`.
pub fn hurwitz_zeta<S: Scalar>(s: S, a: S) -> S {
    let (s_val, a_val) = (s.to_f64(), a.to_f64());
    let unsupported = s_val <= 1.0 || a_val <= 0.0;
    if unsupported {
        S::NAN
    } else {
        S::from_f64(euler_maclaurin_zeta(s_val, a_val))
    }
}
/// Riemann zeta function ζ(s) for real `s`, evaluated in `f64`.
///
/// * `s == 1` is the pole and returns `+∞`.
/// * Negative even integers (the trivial zeros) return exactly `0.0`.
/// * `s` so close to zero that `1 - s` rounds to `1.0` (including `±0.0`)
///   uses the first-order expansion ζ(s) ≈ -1/2 - s·ln(2π)/2.
/// * Any other `s < 0.5` uses the functional equation
///   ζ(s) = 2^s · π^(s-1) · sin(πs/2) · Γ(1-s) · ζ(1-s).
/// * `s >= 0.5` is summed directly with Euler–Maclaurin (`a = 1`).
fn zeta_f64(s: f64) -> f64 {
    let pi = core::f64::consts::PI;
    if s == 1.0 {
        return f64::INFINITY;
    }
    // Trivial zeros at s = -2, -4, …: return an exact 0 rather than the
    // rounding residue of sin(πs/2) in the reflection formula.
    if s < 0.0 && s == s.floor() && (s as i64) % 2 == 0 {
        return 0.0;
    }
    // The reflection formula needs ζ(1 - s). Once `1 - s` rounds to exactly
    // 1.0 that lands on the pole, and the product becomes 0·∞ = NaN (s = ±0)
    // or ±∞ (tiny non-zero s). In this range the linear Taylor term is exact
    // to f64 precision; ζ(0) = -1/2 and ζ'(0) = -ln(2π)/2.
    if 1.0 - s == 1.0 {
        return -0.5 - 0.5 * (2.0 * pi).ln() * s;
    }
    if s < 0.5 {
        let one_minus_s = 1.0 - s;
        let zeta_1ms = zeta_f64(one_minus_s);
        let prefactor = 2.0_f64.powf(s)
            * pi.powf(s - 1.0)
            * (pi * s / 2.0).sin()
            * libm::tgamma(one_minus_s);
        return prefactor * zeta_1ms;
    }
    euler_maclaurin_zeta(s, 1.0)
}
/// Hurwitz zeta ζ(s, a) = Σ_{k≥0} (k + a)^(-s) by Euler–Maclaurin summation.
///
/// The first `N = 20` terms are summed directly and the tail is replaced by
///
/// ```text
/// (N+a)^(1-s)/(s-1) + (N+a)^(-s)/2
///   + Σ_{j=1}^{7} B_{2j}/(2j)! · s(s+1)…(s+2j-2) · (N+a)^(-s-2j+1)
/// ```
///
/// `zeta_f64` calls this with `a = 1` (for `s >= 0.5`) and `hurwitz_zeta`
/// with `s > 1`, `a > 0`; no argument validation is done here.
fn euler_maclaurin_zeta(s: f64, a: f64) -> f64 {
    // Number of terms summed explicitly before switching to the tail formula.
    const N: usize = 20;
    // Bernoulli numbers B_2, B_4, …, B_14.
    const BERNOULLI_2J: [f64; 7] = [
        1.0 / 6.0,
        -1.0 / 30.0,
        1.0 / 42.0,
        -1.0 / 30.0,
        5.0 / 66.0,
        -691.0 / 2730.0,
        7.0 / 6.0,
    ];

    let mut sum = 0.0;
    for k in 0..N {
        sum += 1.0 / (k as f64 + a).powf(s);
    }

    let na = N as f64 + a;
    let integral = na.powf(1.0 - s) / (s - 1.0);
    let boundary = 0.5 * na.powf(-s);

    // Running factors of the j-th correction: `rising` = s(s+1)…(s+2j-2)
    // and `inv_na_pow` = (N+a)^(-s-2j+1), both advanced incrementally.
    let mut rising = s;
    let mut inv_na_pow = na.powf(-s - 1.0);
    let mut corr = 0.0;
    for (idx, &b2j) in BERNOULLI_2J.iter().enumerate() {
        let j = idx + 1;
        if j > 1 {
            let jf = j as f64;
            rising *= (s + 2.0 * jf - 3.0) * (s + 2.0 * jf - 2.0);
            inv_na_pow /= na * na;
        }
        // For large `s` the power underflows to 0 while `rising` keeps growing
        // and can overflow to ∞; 0·∞ would turn the whole result into NaN
        // (e.g. ζ(1e30)). The power only shrinks from here on, so every
        // remaining term is zero and the loop can stop.
        if inv_na_pow == 0.0 {
            break;
        }
        corr += b2j / factorial_2j(j) * rising * inv_na_pow;
    }

    sum + integral + boundary + corr
}
/// Returns (2j)! as an `f64`.
///
/// Values for `j = 0..=7` come from a lookup table; larger `j` is computed as
/// the direct product 2·3·…·(2j).
fn factorial_2j(j: usize) -> f64 {
    // EVEN_FACTORIALS[j] == (2j)!
    const EVEN_FACTORIALS: [f64; 8] = [
        1.0,              // 0!
        2.0,              // 2!
        24.0,             // 4!
        720.0,            // 6!
        40_320.0,         // 8!
        3_628_800.0,      // 10!
        479_001_600.0,    // 12!
        87_178_291_200.0, // 14!
    ];
    match EVEN_FACTORIALS.get(j) {
        Some(&value) => value,
        None => (2..=2 * j).map(|i| i as f64).product(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_zeta_2() {
        let expected = core::f64::consts::PI * core::f64::consts::PI / 6.0;
        assert_relative_eq!(zeta(2.0_f64), expected, epsilon = 1e-8);
    }

    #[test]
    fn test_zeta_4() {
        let pi = core::f64::consts::PI;
        let expected = pi.powi(4) / 90.0;
        assert_relative_eq!(zeta(4.0_f64), expected, epsilon = 1e-8);
    }

    #[test]
    fn test_zeta_3() {
        assert_relative_eq!(zeta(3.0_f64), 1.2020569031595942, epsilon = 1e-8);
    }

    #[test]
    fn test_zeta_negative_even() {
        assert!((zeta(-2.0_f64).to_f64()).abs() < 1e-10);
        assert!((zeta(-4.0_f64).to_f64()).abs() < 1e-10);
    }

    #[test]
    fn test_zeta_negative_one() {
        assert_relative_eq!(zeta(-1.0_f64), -1.0 / 12.0, epsilon = 1e-6);
    }

    /// Exercises the reflection branch at another negative odd integer.
    #[test]
    fn test_zeta_negative_three() {
        // ζ(-3) = 1/120.
        assert_relative_eq!(zeta(-3.0_f64), 1.0 / 120.0, epsilon = 1e-10);
    }

    /// Consistency check only: with a = 1 both sides go through the same
    /// Euler–Maclaurin evaluation, so this cannot catch errors in it.
    #[test]
    fn test_hurwitz_zeta_riemann() {
        assert_relative_eq!(
            hurwitz_zeta(2.0_f64, 1.0_f64),
            zeta(2.0_f64),
            epsilon = 1e-4
        );
    }

    /// Independent closed form: ζ(s, 1/2) = (2^s - 1)·ζ(s), so ζ(2, 1/2) = π²/2.
    #[test]
    fn test_hurwitz_zeta_half() {
        let pi = core::f64::consts::PI;
        assert_relative_eq!(hurwitz_zeta(2.0_f64, 0.5_f64), pi * pi / 2.0, epsilon = 1e-10);
    }

    /// Shift identity ζ(s, a + 1) = ζ(s, a) - a^(-s), so ζ(3, 2) = ζ(3) - 1.
    #[test]
    fn test_hurwitz_zeta_shift() {
        assert_relative_eq!(
            hurwitz_zeta(3.0_f64, 2.0_f64),
            1.2020569031595942 - 1.0,
            epsilon = 1e-10
        );
    }

    /// Arguments outside `s > 1`, `a > 0` are rejected with NaN.
    #[test]
    fn test_hurwitz_zeta_domain() {
        assert!(hurwitz_zeta(2.0_f64, 0.0_f64).is_nan());
        assert!(hurwitz_zeta(2.0_f64, -1.0_f64).is_nan());
        assert!(hurwitz_zeta(1.0_f64, 1.0_f64).is_nan());
        assert!(hurwitz_zeta(0.5_f64, 1.0_f64).is_nan());
    }

    #[test]
    fn test_zeta_f32() {
        let pi = core::f64::consts::PI;
        let expected = pi * pi / 6.0;
        assert!((zeta(2.0_f32).to_f64() - expected).abs() < 0.001);
    }
}