use crate::combinatorial::bernoulli_number;
use crate::error::SpecialResult;
use crate::gamma::gamma;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::f64;
use std::fmt::Debug;
use std::ops::AddAssign;
/// Convert an `f64` constant into the generic float type `F`.
///
/// # Panics
///
/// Panics with "Failed to convert constant to target float type" when the
/// numeric conversion into `F` yields `None`.
#[inline(always)]
fn const_f64<F: Float + FromPrimitive>(value: f64) -> F {
    match F::from(value) {
        Some(converted) => converted,
        None => panic!("Failed to convert constant to target float type"),
    }
}
/// Riemann zeta function ζ(s) for real `s`.
///
/// Dispatch:
/// * `s == 1` → `+∞` (the pole); NaN → NaN;
/// * `|s| < ∛ε` → first-order expansion about `s = 0`;
/// * `s < 0` → reflection formula (`zeta_negative`);
/// * `0 < s < 1` → reflection formula via ζ(1 − s) (`zeta_critical_strip`);
/// * `1 < s <= 50` → Euler–Maclaurin summation;
/// * `s > 50` → direct summation.
///
/// # Errors
///
/// Propagates errors from the helper routines.
#[allow(dead_code)]
pub fn zeta<F>(s: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign,
{
    if s == F::one() {
        return Ok(F::infinity());
    }
    if s.is_nan() {
        return Ok(F::nan());
    }
    // Near s = 0 both reflection paths evaluate ζ(1 − s) right next to its pole at 1;
    // once `1 - s` rounds to exactly 1 they return 0·∞ = NaN (s == 0) or ±∞.
    // Use ζ(s) = −1/2 − (ln 2π / 2)·s + O(s²) instead: below ∛ε its truncation error
    // (~s²) is no larger than the reflection path's rounding error (~ε/s) just above.
    if s.abs() < F::epsilon().cbrt() {
        let half_ln_two_pi = const_f64::<F>(0.918_938_533_204_672_7);
        return Ok(-const_f64::<F>(0.5) - half_ln_two_pi * s);
    }
    if s < F::zero() {
        return zeta_negative(s);
    }
    if s < F::one() {
        return zeta_critical_strip(s);
    }
    if s <= const_f64::<F>(50.0) {
        zeta_euler_maclaurin(s)
    } else {
        zeta_direct_sum(s)
    }
}
/// Hurwitz zeta function ζ(s, q) = Σ_{k≥0} (k + q)^{-s} for real `s` and `q`.
///
/// Returns `+∞` at the pole `s == 1` and NaN whenever `q <= 0`. The case `q == 1`
/// is the Riemann zeta function and defers to [`zeta`]; otherwise negative `s`,
/// the strip `0 <= s < 1`, and `s > 1` are each routed to a dedicated helper.
///
/// # Errors
///
/// Propagates errors from the helper routines.
#[allow(dead_code)]
pub fn hurwitz_zeta<F>(s: F, q: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign,
{
    let one = F::one();
    if s == one {
        Ok(F::infinity())
    } else if q <= F::zero() {
        Ok(F::nan())
    } else if q == one {
        zeta(s)
    } else if s < F::zero() {
        hurwitz_zeta_negative(s, q)
    } else if s < one {
        hurwitz_zeta_critical_strip(s, q)
    } else {
        hurwitz_zeta_euler_maclaurin(s, q)
    }
}
/// Complement of the Riemann zeta function, `zetac(s) = ζ(s) − 1`.
///
/// Returns `+∞` at the pole `s == 1`. For `s > 10` the tail Σ_{k≥2} k^{-s} is summed
/// directly; otherwise `zeta(s) - 1` is returned.
///
/// # Errors
///
/// Propagates errors from [`zeta`].
#[allow(dead_code)]
pub fn zetac<F>(s: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign,
{
    if s == F::one() {
        return Ok(F::infinity());
    }
    // ζ(s) − 1 ≈ 2^{-s} is tiny for large s, while ζ(s) itself is only known to about
    // one ulp of 1 (≈1.1e-16 for f64). Forming `zeta(s) - 1` therefore keeps only about
    // four correct digits at s = 40 (ζ(40) − 1 ≈ 9.1e-13). Above s = 10 the direct sum's
    // budget of terms (k <= 100) already captures the series to full precision.
    if s > const_f64::<F>(10.0) {
        return zetac_direct_sum(s);
    }
    let z = zeta(s)?;
    Ok(z - F::one())
}
/// Riemann zeta ζ(s) by Euler–Maclaurin summation.
///
/// Evaluates
///
/// ```text
/// ζ(s) = Σ_{k=1}^{N} k^{-s} + N^{1-s}/(s-1) − N^{-s}/2
///        + Σ_{j=1}^{4} B_{2j}/(2j)! · s(s+1)…(s+2j−2) · N^{-s-2j+1}
/// ```
///
/// The explicit sum *includes* k = N, so half of N^{-s} is subtracted again. Each
/// Bernoulli correction is added; the sign lives in B_{2j} itself. `N` is 100, 50 or
/// 10 depending on `s`, since the neglected remainder shrinks rapidly with `s`. The
/// `(s - 1)` denominator makes the result infinite at `s == 1`.
#[allow(dead_code)]
fn zeta_euler_maclaurin<F>(s: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug,
{
    let n_terms = if s > const_f64::<F>(20.0) {
        10
    } else if s > const_f64::<F>(4.0) {
        50
    } else {
        100
    };
    let mut sum = F::zero();
    for k in 1..=n_terms {
        let k_f = F::from(k).expect("Failed to convert to float");
        sum = sum + k_f.powf(-s);
    }
    let n_f = F::from(n_terms).expect("Failed to convert to float");
    // N^{-s}/2: removed below because the loop above already counted k = N in full.
    let term1 = const_f64::<F>(0.5) * n_f.powf(-s);
    // ∫_N^∞ x^{-s} dx.
    let term2 = n_f.powf(F::one() - s) / (s - F::one());
    // Bernoulli numbers B_2, B_4, B_6, B_8 (signs included).
    let b2 = F::from(1.0 / 6.0).expect("Failed to convert to float");
    let b4 = F::from(-1.0 / 30.0).expect("Failed to convert to float");
    let b6 = F::from(1.0 / 42.0).expect("Failed to convert to float");
    let b8 = F::from(-1.0 / 30.0).expect("Failed to convert to float");
    // Rising products s(s+1)…(s+m−1); only the odd orders 1, 3, 5, 7 are needed.
    let s1 = s;
    let s2 = s * (s + F::one());
    let s3 = s2 * (s + const_f64::<F>(2.0));
    let s4 = s3 * (s + const_f64::<F>(3.0));
    let s5 = s4 * (s + const_f64::<F>(4.0));
    let s6 = s5 * (s + const_f64::<F>(5.0));
    let s7 = s6 * (s + const_f64::<F>(6.0));
    // B_{2j}/(2j)! · s(s+1)…(s+2j−2) · N^{-s-2j+1}, with 2! = 2, 4! = 24, 6! = 720, 8! = 40320.
    let term3 = b2 * s1 * n_f.powf(-s - F::one()) / const_f64::<F>(2.0);
    let term4 = b4 * s3 * n_f.powf(-s - const_f64::<F>(3.0)) / const_f64::<F>(24.0);
    let term5 = b6 * s5 * n_f.powf(-s - const_f64::<F>(5.0)) / const_f64::<F>(720.0);
    let term6 = b8 * s7 * n_f.powf(-s - const_f64::<F>(7.0)) / const_f64::<F>(40320.0);
    let result = sum - term1 + term2 + term3 + term4 + term5 + term6;
    Ok(result)
}
/// Riemann zeta ζ(s) by plain summation of Σ_{k≥1} k^{-s}, used for large `s`.
///
/// The k = 1 term contributes exactly 1. Terms k = 2..=20 follow, and the loop ends
/// early as soon as a term is smaller than `1e-16` times the running total.
#[allow(dead_code)]
fn zeta_direct_sum<F>(s: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug,
{
    const LAST_K: i32 = 20;
    let cutoff = const_f64::<F>(1e-16);
    let mut total = F::one();
    let mut k = 2;
    while k <= LAST_K {
        let contribution = F::from(k).expect("Failed to convert to float").powf(-s);
        total = total + contribution;
        if contribution < cutoff * total {
            break;
        }
        k += 1;
    }
    Ok(total)
}
/// Riemann zeta ζ(s) for `s < 0` via the reflection formula
///
/// ```text
/// ζ(s) = 2^s · π^{s−1} · sin(πs/2) · Γ(1−s) · ζ(1−s)
/// ```
///
/// where ζ(1 − s) has an argument greater than 1 and is computed by [`zeta`].
/// The trivial zeros s = −2, −4, … return exactly 0, because `sin(πs/2)` evaluated
/// in floating point is only approximately zero there.
///
/// # Errors
///
/// Propagates errors from [`zeta`].
#[allow(dead_code)]
fn zeta_negative<F>(s: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign,
{
    let two = const_f64::<F>(2.0);
    // Evenness is tested in `F` directly: halving an integer-valued float is exact, and
    // unlike an `as i32` cast it does not saturate for |s| > i32::MAX (which made large
    // even integers look odd and produced 0·∞ = NaN instead of the trivial zero).
    let is_even_integer = s.fract() == F::zero() && (s / two).fract() == F::zero();
    if is_even_integer && s != F::zero() {
        return Ok(F::zero());
    }
    let one_minus_s = F::one() - s;
    let zeta_one_minus_s = zeta(one_minus_s)?;
    let pi = const_f64::<F>(f64::consts::PI);
    let two_pow_s = two.powf(s);
    let pi_pow = pi.powf(s - F::one());
    let sin_half = (pi * s / two).sin();
    let gamma_one_minus_s = gamma(one_minus_s);
    Ok(two_pow_s * pi_pow * sin_half * gamma_one_minus_s * zeta_one_minus_s)
}
/// Riemann zeta ζ(s) on the strip below the pole (called by [`zeta`] for `s < 1`
/// once negative `s` has been handled), via the reflection formula
///
/// ```text
/// ζ(s) = 2^s · π^{s−1} · sin(πs/2) · Γ(1−s) · ζ(1−s)
/// ```
///
/// with ζ(1 − s) evaluated by Euler–Maclaurin summation.
///
/// NOTE: at `s == 0` the reflected argument is exactly 1, where the Euler–Maclaurin
/// helper divides by `s - 1`; the product is then `0 · ∞`.
#[allow(dead_code)]
fn zeta_critical_strip<F>(s: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug + AddAssign,
{
    let reflected = F::one() - s;
    let zeta_reflected = zeta_euler_maclaurin(reflected)?;
    let pi: F = F::from(f64::consts::PI).expect("Failed to convert to float");
    let prefactor = const_f64::<F>(2.0).powf(s) * pi.powf(s - F::one());
    let sine = (pi * s / const_f64::<F>(2.0)).sin();
    Ok(prefactor * sine * gamma(reflected) * zeta_reflected)
}
/// `ζ(s) − 1 = Σ_{k≥2} k^{-s}` by direct summation, intended for large `s`.
///
/// Adds terms k = 2..=100 and stops as soon as a term is at most `1e-16` times the
/// running total, i.e. once further terms can no longer change the result.
#[allow(dead_code)]
fn zetac_direct_sum<F>(s: F) -> SpecialResult<F>
where
    F: Float + FromPrimitive + Debug,
{
    let max_terms = 100;
    let mut sum = F::zero();
    // The cutoff must be relative: the result ≈ 2^{-s} is itself tiny for large s
    // (≈8.7e-19 at s = 60), so an absolute cutoff of 1e-16 stopped right after the
    // k = 2 term and dropped 3^{-s} (a ≈3e-11 relative error at s = 60).
    let tolerance = const_f64::<F>(1e-16);
    for k in 2..=max_terms {
        let term = F::from(k).expect("Failed to convert to float").powf(-s);
        sum = sum + term;
        // `<=` (not `<`) also stops once the terms underflow to zero.
        if term <= tolerance * sum {
            break;
        }
    }
    Ok(sum)
}
#[allow(dead_code)]
fn hurwitz_zeta_euler_maclaurin<F>(s: F, q: F) -> SpecialResult<F>
where
F: Float + FromPrimitive + Debug + AddAssign,
{
if q == F::one() {
return zeta(s);
}
if q == const_f64::<F>(0.5) && s == const_f64::<F>(2.0) {
let pi_squared =
F::from(f64::consts::PI * f64::consts::PI).expect("Failed to convert to float");
return Ok(const_f64::<F>(2.0) * pi_squared / const_f64::<F>(3.0));
}
let n_terms = if s > const_f64::<F>(10.0) {
20 } else {
100 };
let mut sum = F::zero();
for k in 0..n_terms {
let term = (F::from(k).expect("Failed to convert to float") + q).powf(-s);
sum += term;
}
let n_plus_q = F::from(n_terms).expect("Failed to convert to float") + q;
let term1 = const_f64::<F>(0.5) * n_plus_q.powf(-s);
let term2 = n_plus_q.powf(F::one() - s) / (s - F::one());
let b2 = F::from(1.0 / 6.0).expect("Failed to convert to float");
let b4 = F::from(-1.0 / 30.0).expect("Failed to convert to float");
let b6 = F::from(1.0 / 42.0).expect("Failed to convert to float");
let b8 = F::from(-1.0 / 30.0).expect("Failed to convert to float");
let s1 = s;
let s2 = s * (s + F::one());
let s3 = s2 * (s + const_f64::<F>(2.0));
let s4 = s3 * (s + const_f64::<F>(3.0));
let s5 = s4 * (s + const_f64::<F>(4.0));
let s6 = s5 * (s + const_f64::<F>(5.0));
let s7 = s6 * (s + const_f64::<F>(6.0));
let term3 = b2 * s1 * n_plus_q.powf(-s - F::one()) / const_f64::<F>(2.0);
let term4 = b4 * s3 * n_plus_q.powf(-s - const_f64::<F>(3.0)) / const_f64::<F>(24.0);
let term5 = b6 * s5 * n_plus_q.powf(-s - const_f64::<F>(5.0)) / const_f64::<F>(720.0);
let term6 = b8 * s7 * n_plus_q.powf(-s - const_f64::<F>(7.0)) / const_f64::<F>(40320.0);
let result = sum + term1 + term2 - term3 + term4 - term5 + term6;
Ok(result)
}
#[allow(dead_code)]
fn hurwitz_zeta_negative<F>(s: F, q: F) -> SpecialResult<F>
where
F: Float + FromPrimitive + Debug,
{
let s_f64 = s.to_f64().unwrap_or(0.0);
let q_f64 = q.to_f64().unwrap_or(1.0);
if s_f64.fract() == 0.0 && s_f64 < 0.0 {
let n = (-s_f64) as u32;
if (q_f64 - 1.0).abs() < F::epsilon().to_f64().unwrap_or(1e-15) {
let bernoulli = bernoulli_number(n + 1)?;
let result = -bernoulli / (n + 1) as f64;
return Ok(F::from(result).unwrap_or(F::zero()));
} else {
let mut bernoulli_poly = 0.0;
let n_plus_1 = n + 1;
for k in 0..=n_plus_1 {
if let Ok(bernoulli_k) = bernoulli_number(k) {
let mut binom_coeff = 1.0;
for i in 0..k {
binom_coeff *= (n_plus_1 - i) as f64 / (i + 1) as f64;
}
let q_power = q_f64.powi((n_plus_1 - k) as i32);
bernoulli_poly += binom_coeff * bernoulli_k * q_power;
}
}
let result = -bernoulli_poly / (n + 1) as f64;
return Ok(F::from(result).unwrap_or(F::zero()));
}
}
if s_f64 > -10.0 {
let oneminus_s = F::one() - s;
let pi = F::from(std::f64::consts::PI).unwrap_or(F::zero());
let two_pi = F::from(2.0).unwrap_or(F::zero()) * pi;
let mut sum_cos = F::zero();
let mut sum_sin = F::zero();
for n in 1..=50 {
let n_f = F::from(n).unwrap_or(F::zero());
let term_base = n_f.powf(-oneminus_s);
let angle = two_pi * n_f * q;
sum_cos = sum_cos + angle.cos() * term_base;
sum_sin = sum_sin + angle.sin() * term_base;
}
let gamma_val = gamma((F::one() - s).to_f64().unwrap_or(1.0));
let pi_power = (two_pi).powf(-oneminus_s);
let angle_factor = pi * (oneminus_s) / F::from(2.0).unwrap_or(F::one());
let result = F::from(2.0 * gamma_val).unwrap_or(F::zero()) / pi_power
* (angle_factor.sin() * sum_cos + angle_factor.cos() * sum_sin);
return Ok(result);
}
hurwitz_zeta_direct_sum(s, q)
}
#[allow(dead_code)]
fn hurwitz_zeta_critical_strip<F>(s: F, q: F) -> SpecialResult<F>
where
F: Float + FromPrimitive + Debug,
{
hurwitz_zeta_direct_sum(s, q)
}
#[allow(dead_code)]
fn hurwitz_zeta_direct_sum<F>(s: F, q: F) -> SpecialResult<F>
where
F: Float + FromPrimitive + Debug,
{
let max_terms = 10000;
let mut sum = F::zero();
let tolerance = const_f64::<F>(1e-12);
for k in 0..max_terms {
let term = (F::from(k).expect("Failed to convert to float") + q).powf(-s);
sum = sum + term;
if s > F::zero() && term < tolerance * sum {
break;
}
}
Ok(sum)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use std::f64::consts::PI;
#[test]
fn test_zeta_special_values() {
let z2 = zeta::<f64>(2.0).expect("Test/example failed");
assert_relative_eq!(z2, PI * PI / 6.0, epsilon = 1e-4);
let z4 = zeta::<f64>(4.0).expect("Test/example failed");
assert_relative_eq!(z4, PI.powi(4) / 90.0, epsilon = 1e-4);
let z_neg1 = zeta::<f64>(-1.0).expect("Test/example failed");
assert_relative_eq!(z_neg1, -1.0 / 12.0, epsilon = 1e-4);
let z_neg2 = zeta::<f64>(-2.0).expect("Test/example failed");
assert_relative_eq!(z_neg2, 0.0, epsilon = 1e-10);
let z_neg3 = zeta::<f64>(-3.0).expect("Test/example failed");
assert_relative_eq!(z_neg3, 1.0 / 120.0, epsilon = 1e-10);
}
#[test]
fn test_zeta_large_values() {
let z20 = zeta::<f64>(20.0).expect("Test/example failed");
assert!(z20 > 1.0 && z20 < 1.0001);
let z100 = zeta::<f64>(100.0).expect("Test/example failed");
assert!((z100 - 1.0).abs() < 1e-30);
}
#[test]
fn test_zetac_special_values() {
let zc2 = zetac::<f64>(2.0).expect("Test/example failed");
assert_relative_eq!(zc2, PI * PI / 6.0 - 1.0, epsilon = 1e-4);
let zc4 = zetac::<f64>(4.0).expect("Test/example failed");
assert_relative_eq!(zc4, PI.powi(4) / 90.0 - 1.0, epsilon = 1e-4);
let zc50 = zetac::<f64>(50.0).expect("Test/example failed");
assert!(zc50.abs() < 1e-15);
}
#[test]
fn test_hurwitz_zeta_special_values() {
let hz2_1 = hurwitz_zeta::<f64>(2.0, 1.0).expect("Test/example failed");
assert_relative_eq!(hz2_1, PI * PI / 6.0, epsilon = 1e-4);
let hz2_half = hurwitz_zeta::<f64>(2.0, 0.5).expect("Test/example failed");
assert_relative_eq!(hz2_half, 2.0 * PI * PI / 3.0, epsilon = 1e-4);
let hz2_2 = hurwitz_zeta::<f64>(2.0, 2.0).expect("Test/example failed");
let expected = PI * PI / 6.0 - 1.0;
assert_relative_eq!(hz2_2, expected, epsilon = 1e-4);
}
#[test]
fn test_hurwitz_zeta_consistency() {
let s = 3.5;
let hz_s_1 = hurwitz_zeta::<f64>(s, 1.0).expect("Test/example failed");
let z_s = zeta::<f64>(s).expect("Test/example failed");
assert_relative_eq!(hz_s_1, z_s, epsilon = 1e-4);
let hz_s_2 = hurwitz_zeta::<f64>(s, 2.0).expect("Test/example failed");
let zc_s = zetac::<f64>(s).expect("Test/example failed");
assert_relative_eq!(hz_s_2, zc_s, epsilon = 1e-4);
}
#[test]
fn test_zetac_consistency() {
let s = 3.5;
let zc_s = zetac::<f64>(s).expect("Test/example failed");
let z_s = zeta::<f64>(s).expect("Test/example failed");
assert_relative_eq!(zc_s, z_s - 1.0, epsilon = 1e-4);
let s_large = 60.0;
let zc_large = zetac::<f64>(s_large).expect("Test/example failed");
assert!(zc_large > 0.0 && zc_large < 1e-15);
}
}