use crate::error::{SpecialError, SpecialResult};
use std::f64::consts::PI;
/// Computes the exponential integral `E1(x)` for strictly positive `x`.
///
/// `x = +inf` yields `Ok(0.0)`; every other accepted argument is forwarded to
/// `exp1_impl`.
///
/// # Errors
///
/// Returns [`SpecialError::DomainError`] when `x` is NaN or when `x <= 0`.
pub fn exp1(x: f64) -> SpecialResult<f64> {
    if x.is_nan() {
        Err(SpecialError::DomainError("NaN input to exp1".to_string()))
    } else if x <= 0.0 {
        Err(SpecialError::DomainError("exp1 requires x > 0".to_string()))
    } else if x.is_infinite() {
        // NaN and non-positive values are already excluded, so this is +inf.
        Ok(0.0)
    } else {
        Ok(exp1_impl(x))
    }
}
/// Evaluates the exponential integral `E1(x)`.
///
/// The argument is not validated here; the result is only meaningful for
/// finite `x > 0`.
///
/// Two regimes are used:
///
/// * `x <= 1`: the power series
///   `E1(x) = -gamma - ln(x) - sum_{k>=1} (-x)^k / (k * k!)`.
/// * `x > 1`: the continued fraction
///   `E1(x) = e^{-x} / (x + 1 - 1^2 / (x + 3 - 2^2 / (x + 5 - ...)))`,
///   evaluated with the modified Lentz algorithm.
fn exp1_impl(x: f64) -> f64 {
    const EULER_MASCHERONI: f64 = 0.5772156649015329;

    if x <= 1.0 {
        let mut sum = -EULER_MASCHERONI - x.ln();
        // `term` tracks (-x)^k / k!; dividing by k once more gives the series term.
        let mut term = 1.0;
        for k in 1..=100 {
            let k_f = f64::from(k);
            term *= -x / k_f;
            let contribution = term / k_f;
            sum -= contribution;
            if contribution.abs() < 1e-16 * sum.abs() {
                break;
            }
        }
        sum
    } else {
        // Modified Lentz evaluation. `c` must start "at infinity" (1 / tiny) and
        // `d` at 1 / b_0; seeding both from `x` makes the first ratio exactly 1
        // and stops the iteration before any correction has been applied.
        let tiny = 1.0e-30;
        let mut b = x + 1.0;
        let mut c = 1.0 / tiny;
        let mut d = 1.0 / b;
        let mut h = d;
        for i in 1..=200 {
            let i_f = f64::from(i);
            // Partial numerator -i^2; partial denominators grow by 2 each step.
            let a = -i_f * i_f;
            b += 2.0;
            d = a * d + b;
            if d.abs() < tiny {
                d = tiny;
            }
            c = b + a / c;
            if c.abs() < tiny {
                c = tiny;
            }
            d = 1.0 / d;
            let delta = c * d;
            h *= delta;
            if (delta - 1.0).abs() < 1e-15 {
                break;
            }
        }
        (-x).exp() * h
    }
}
/// Power series for `E_n(x)` with `n >= 2` and `0 < x <= 1`.
fn expn_series(n: i32, x: f64) -> f64 {
    const EULER_MASCHERONI: f64 = 0.5772156649015329;
    let nm1 = n - 1;
    let mut sum = 1.0 / f64::from(nm1);
    // `fact` tracks (-x)^i / i!.
    let mut fact = 1.0;
    for i in 1..=100_i32 {
        fact *= -x / f64::from(i);
        let term = if i == nm1 {
            // The single logarithmic term: psi(n) = -gamma + sum_{j=1}^{n-1} 1/j.
            let psi = (1..=nm1).fold(-EULER_MASCHERONI, |acc, j| acc + 1.0 / f64::from(j));
            fact * (psi - x.ln())
        } else {
            -fact / (f64::from(i) - f64::from(nm1))
        };
        sum += term;
        if term.abs() < f64::EPSILON * sum.abs() {
            break;
        }
    }
    sum
}

/// Modified Lentz evaluation of the continued fraction
/// `E_n(x) = e^{-x} / (x + n - 1*n / (x + n + 2 - 2*(n + 1) / (x + n + 4 - ...)))`
/// for `n >= 1` and finite `x > 1`.
fn expn_continued_fraction(n: i32, x: f64) -> f64 {
    let tiny = 1.0e-30;
    let order = f64::from(n);
    let mut b = x + order;
    let mut c = 1.0 / tiny;
    let mut d = 1.0 / b;
    let mut h = d;
    for i in 1..=200_i32 {
        let i_f = f64::from(i);
        let a = -i_f * (order - 1.0 + i_f);
        b += 2.0;
        d = a * d + b;
        if d.abs() < tiny {
            d = tiny;
        }
        c = b + a / c;
        if c.abs() < tiny {
            c = tiny;
        }
        d = 1.0 / d;
        let delta = c * d;
        h *= delta;
        if (delta - 1.0).abs() < 1e-15 {
            break;
        }
    }
    h * (-x).exp()
}

/// Computes the generalized exponential integral `E_n(x)` for integer `n >= 0`.
///
/// * `n == 0`: `e^{-x} / x`.
/// * `n == 1`: delegates to [`exp1`].
/// * `n >= 2`: `E_n(0) = 1 / (n - 1)`, `E_n(+inf) = 0`, a power series for
///   `0 < x <= 1` and a continued fraction for `x > 1`. Evaluating `E_n`
///   directly avoids the digit loss of the forward recurrence
///   `E_{k+1} = (e^{-x} - x E_k) / k` when `x` is large compared with `n`.
///
/// # Errors
///
/// Returns [`SpecialError::DomainError`] when `x` is NaN, when `n < 0`, when
/// `x <= 0` with `n <= 1`, or when `x < 0` with `n >= 2`.
pub fn expn(n: i32, x: f64) -> SpecialResult<f64> {
    if x.is_nan() {
        return Err(SpecialError::DomainError("NaN input to expn".to_string()));
    }
    if n < 0 {
        return Err(SpecialError::DomainError(
            "expn requires n >= 0".to_string(),
        ));
    }
    if x <= 0.0 && n <= 1 {
        return Err(SpecialError::DomainError(
            "expn requires x > 0 for n <= 1".to_string(),
        ));
    }
    if x < 0.0 {
        // Previously this fell through and produced Ok(NaN).
        return Err(SpecialError::DomainError(
            "expn requires x >= 0 for n >= 2".to_string(),
        ));
    }
    if n == 0 {
        return Ok((-x).exp() / x);
    }
    if n == 1 {
        return exp1(x);
    }

    // From here on n >= 2 and x >= 0.
    if x == 0.0 {
        return Ok(1.0 / f64::from(n - 1));
    }
    if x.is_infinite() {
        return Ok(0.0);
    }
    if x > 1.0 {
        Ok(expn_continued_fraction(n, x))
    } else {
        Ok(expn_series(n, x))
    }
}
/// Returns `(ln|Gamma(x)|, sign)` where `sign` is `1.0` or `-1.0`.
///
/// Positive arguments are evaluated directly with `lanczos_lgamma` and always
/// carry sign `1.0`. Negative non-integer arguments use the reflection formula
/// `ln|Gamma(x)| = ln(pi) - ln|sin(pi x)| - lgamma(1 - x)`, with the sign taken
/// from `sin(pi x)`.
///
/// # Errors
///
/// Returns [`SpecialError::DomainError`] for NaN, for infinite input, at
/// non-positive integers, and when `|sin(pi x)|` is below `1e-300`.
pub fn loggamma_sign(x: f64) -> SpecialResult<(f64, f64)> {
    let domain_error = |msg: &str| -> SpecialResult<(f64, f64)> {
        Err(SpecialError::DomainError(msg.to_string()))
    };

    if x.is_nan() {
        return domain_error("NaN input to loggamma_sign");
    }
    if !x.is_finite() {
        return domain_error("Infinite input to loggamma_sign");
    }
    let is_non_positive_integer = x <= 0.0 && x.floor() == x;
    if is_non_positive_integer {
        return domain_error("loggamma_sign undefined at non-positive integers");
    }
    if x > 0.0 {
        return Ok((lanczos_lgamma(x), 1.0));
    }

    // Reflection for negative, non-integer x.
    let reflected = lanczos_lgamma(1.0 - x);
    let sin_pi_x = (PI * x).sin();
    if sin_pi_x.abs() < 1e-300 {
        return domain_error("loggamma_sign: near a pole");
    }
    let log_abs_gamma = PI.ln() - sin_pi_x.abs().ln() - reflected;
    let sign = if sin_pi_x > 0.0 { 1.0 } else { -1.0 };
    Ok((log_abs_gamma, sign))
}
/// Computes `ln|Gamma(x)|` with a 9-term Lanczos approximation (`g = 7`).
///
/// * Non-positive integers are poles and return `+inf`; `+inf` input returns
///   `+inf`.
/// * `x < 0.5` uses the reflection formula
///   `ln|Gamma(x)| = ln(pi) - ln|sin(pi x)| - ln|Gamma(1 - x)|`.
/// * Otherwise the Lanczos sum is evaluated directly.
fn lanczos_lgamma(x: f64) -> f64 {
    const LANCZOS_G: f64 = 7.0;
    const LANCZOS_COEFFS: [f64; 9] = [
        0.99999999999980993,
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7,
    ];

    // Poles are detected on `x` itself: sin(pi * x) is not exactly zero at
    // negative integers in floating point, so it cannot be used to spot them.
    if x <= 0.0 && x == x.floor() {
        return f64::INFINITY;
    }
    // The direct formula below would produce inf - inf = NaN here.
    if x == f64::INFINITY {
        return f64::INFINITY;
    }

    if x < 0.5 {
        let sin_pi_x = (PI * x).sin();
        // Only an exact zero is treated as a pole. A magnitude cut-off would
        // wrongly return +inf for tiny positive x, where the result is about
        // -ln(x) and perfectly finite.
        if sin_pi_x == 0.0 {
            return f64::INFINITY;
        }
        return PI.ln() - sin_pi_x.abs().ln() - lanczos_lgamma(1.0 - x);
    }

    let z = x - 1.0;
    let ag = LANCZOS_COEFFS
        .iter()
        .enumerate()
        .skip(1)
        .fold(LANCZOS_COEFFS[0], |acc, (i, &coeff)| {
            acc + coeff / (z + i as f64)
        });
    let tmp = z + LANCZOS_G + 0.5;
    0.5 * (2.0 * PI).ln() + (z + 0.5) * tmp.ln() - tmp + ag.ln()
}
/// Inverse sine of `x`, returned in degrees.
///
/// # Errors
///
/// Returns [`SpecialError::DomainError`] when `x` is NaN or lies outside
/// `[-1, 1]`.
pub fn asindg(x: f64) -> SpecialResult<f64> {
    match x {
        v if v.is_nan() => Err(SpecialError::DomainError("NaN input to asindg".to_string())),
        // NaN is excluded above, so a magnitude test is equivalent to the range test.
        v if v.abs() > 1.0 => Err(SpecialError::DomainError(
            "asindg requires x in [-1, 1]".to_string(),
        )),
        v => {
            let radians = v.asin();
            Ok(radians * 180.0 / PI)
        }
    }
}
/// Inverse cosine of `x`, returned in degrees.
///
/// # Errors
///
/// Returns [`SpecialError::DomainError`] when `x` is NaN or lies outside
/// `[-1, 1]`.
pub fn acosdg(x: f64) -> SpecialResult<f64> {
    match x {
        v if v.is_nan() => Err(SpecialError::DomainError("NaN input to acosdg".to_string())),
        // NaN is excluded above, so a magnitude test is equivalent to the range test.
        v if v.abs() > 1.0 => Err(SpecialError::DomainError(
            "acosdg requires x in [-1, 1]".to_string(),
        )),
        v => {
            let radians = v.acos();
            Ok(radians * 180.0 / PI)
        }
    }
}
/// Inverse tangent of `x`, returned in degrees.
///
/// # Errors
///
/// Returns [`SpecialError::DomainError`] when `x` is NaN.
pub fn atandg(x: f64) -> SpecialResult<f64> {
    if !x.is_nan() {
        let radians = x.atan();
        Ok(radians * 180.0 / PI)
    } else {
        Err(SpecialError::DomainError("NaN input to atandg".to_string()))
    }
}
/// Computes the multinomial coefficient `n! / (k_1! * k_2! * ... * k_m!)`.
///
/// The value is built as a product of binomial coefficients
/// `C(remaining, k_i)`, where `remaining` starts at `n` and shrinks by each
/// group size in turn.
///
/// # Errors
///
/// Returns [`SpecialError::ValueError`] when the group sizes do not sum to
/// `n`, and propagates any error from `crate::binomial`.
pub fn multinomial(n: u32, ks: &[u32]) -> SpecialResult<f64> {
    // Accumulate in u64: a u32 sum can overflow (panicking in debug builds and
    // wrapping in release builds, which could make a bad input pass the check).
    let sum_ks: u64 = ks.iter().map(|&k| u64::from(k)).sum();
    if sum_ks != u64::from(n) {
        return Err(SpecialError::ValueError(format!(
            "Sum of group sizes ({}) must equal n ({})",
            sum_ks, n
        )));
    }
    if n == 0 || ks.len() == 1 {
        return Ok(1.0);
    }

    let mut result = 1.0;
    let mut remaining = n;
    for &k in ks {
        result *= crate::binomial(remaining, k)?;
        // Cannot underflow: the group sizes sum to exactly `n`, so every
        // partial sum is at most `n`.
        remaining -= k;
    }
    Ok(result)
}
/// Evaluates the Bernoulli polynomial
/// `B_n(x) = sum_{k=0}^{n} C(n, k) * B_k * x^(n - k)`.
///
/// The Bernoulli numbers come from `crate::bernoulli_number` and the binomial
/// coefficients from `crate::binomial`.
///
/// # Errors
///
/// Returns [`SpecialError::ValueError`] when `n > 50`, and propagates any error
/// from the two helper functions.
pub fn bernoulli_poly(n: u32, x: f64) -> SpecialResult<f64> {
    if n > 50 {
        return Err(SpecialError::ValueError(
            "bernoulli_poly: n too large (max 50)".to_string(),
        ));
    }

    (0..=n).try_fold(0.0, |acc: f64, k: u32| -> SpecialResult<f64> {
        let bernoulli_k = crate::bernoulli_number(k)?;
        let coefficient = crate::binomial(n, k)?;
        // n <= 50, so the exponent always fits in an i32.
        let power = x.powi((n - k) as i32);
        Ok(acc + coefficient * bernoulli_k * power)
    })
}
/// Evaluates the Euler polynomial `E_n(x)` through Bernoulli polynomials:
/// `E_n(x) = 2 / (n + 1) * (B_{n+1}(x) - 2^(n+1) * B_{n+1}(x / 2))`.
///
/// # Errors
///
/// Returns [`SpecialError::ValueError`] when `n > 49`. The formula needs
/// `bernoulli_poly(n + 1, ..)`, and `bernoulli_poly` accepts degrees up to 50,
/// so 49 is the largest order that can be evaluated. Errors from
/// `bernoulli_poly` are propagated.
pub fn euler_poly(n: u32, x: f64) -> SpecialResult<f64> {
    // Reject n = 50 here so the caller gets an error that names this function
    // and its real limit, rather than a "bernoulli_poly" error for degree 51.
    if n >= 50 {
        return Err(SpecialError::ValueError(
            "euler_poly: n too large (max 49)".to_string(),
        ));
    }
    let n_plus_1 = n + 1;
    let bn1_x = bernoulli_poly(n_plus_1, x)?;
    let bn1_x_half = bernoulli_poly(n_plus_1, x / 2.0)?;
    // n_plus_1 <= 50, so the cast to i32 is lossless.
    let two_pow = 2.0_f64.powi(n_plus_1 as i32);
    Ok(2.0 / f64::from(n_plus_1) * (bn1_x - two_pow * bn1_x_half))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // --- exponential integrals -------------------------------------------

    #[test]
    fn test_exp1_basic() {
        let e1_at_one = exp1(1.0).expect("exp1 failed");
        assert_relative_eq!(e1_at_one, 0.21938393439552, epsilon = 1e-6);
    }

    #[test]
    fn test_expn_basic() {
        // expn with n = 1 must agree with exp1.
        let via_expn = expn(1, 1.0).expect("expn failed");
        let via_exp1 = exp1(1.0).expect("exp1 failed");
        assert_relative_eq!(via_expn, via_exp1, epsilon = 1e-12);
    }

    // --- log-gamma with sign ---------------------------------------------

    #[test]
    fn test_loggamma_sign_positive() {
        // Gamma(5) = 4! = 24.
        let (log_abs, sign) = loggamma_sign(5.0).expect("failed");
        assert_relative_eq!(log_abs, 24.0_f64.ln(), epsilon = 1e-10);
        assert_eq!(sign, 1.0);
    }

    #[test]
    fn test_loggamma_sign_half() {
        // Gamma(1/2) = sqrt(pi).
        let (log_abs, sign) = loggamma_sign(0.5).expect("failed");
        assert_relative_eq!(log_abs, (PI.sqrt()).ln(), epsilon = 1e-10);
        assert_eq!(sign, 1.0);
    }

    #[test]
    fn test_loggamma_sign_negative() {
        // Gamma(-1/2) = -2 * sqrt(pi).
        let (log_abs, sign) = loggamma_sign(-0.5).expect("failed");
        assert_eq!(sign, -1.0);
        let expected = (2.0 * PI.sqrt()).ln();
        assert_relative_eq!(log_abs, expected, epsilon = 1e-8);
    }

    #[test]
    fn test_loggamma_sign_pole() {
        for &pole in &[0.0, -1.0, -2.0] {
            assert!(loggamma_sign(pole).is_err());
        }
    }

    // --- inverse trigonometric functions in degrees ----------------------

    #[test]
    fn test_asindg() {
        for &(input, degrees) in &[(0.0, 0.0), (0.5, 30.0), (1.0, 90.0), (-1.0, -90.0)] {
            assert_relative_eq!(asindg(input).expect("failed"), degrees, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_asindg_domain_error() {
        for &outside in &[1.5, -1.5] {
            assert!(asindg(outside).is_err());
        }
    }

    #[test]
    fn test_acosdg() {
        for &(input, degrees) in &[(1.0, 0.0), (0.5, 60.0), (0.0, 90.0), (-1.0, 180.0)] {
            assert_relative_eq!(acosdg(input).expect("failed"), degrees, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_acosdg_domain_error() {
        assert!(acosdg(1.5).is_err());
    }

    #[test]
    fn test_atandg() {
        for &(input, degrees) in &[(0.0, 0.0), (1.0, 45.0), (-1.0, -45.0)] {
            assert_relative_eq!(atandg(input).expect("failed"), degrees, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_atandg_nan() {
        assert!(atandg(f64::NAN).is_err());
    }

    // --- multinomial coefficients ----------------------------------------

    #[test]
    fn test_multinomial_basic() {
        let coefficient = multinomial(4, &[2, 2]).expect("failed");
        assert_relative_eq!(coefficient, 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_multinomial_triple() {
        let coefficient = multinomial(6, &[2, 2, 2]).expect("failed");
        assert_relative_eq!(coefficient, 90.0, epsilon = 1e-10);
    }

    #[test]
    fn test_multinomial_trivial() {
        let coefficient = multinomial(5, &[5]).expect("failed");
        assert_relative_eq!(coefficient, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_multinomial_zero() {
        let coefficient = multinomial(0, &[]).expect("failed");
        assert_relative_eq!(coefficient, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_multinomial_invalid_sum() {
        assert!(multinomial(5, &[2, 2]).is_err());
    }

    #[test]
    fn test_multinomial_with_ones() {
        let coefficient = multinomial(4, &[1, 1, 1, 1]).expect("failed");
        assert_relative_eq!(coefficient, 24.0, epsilon = 1e-10);
    }

    // --- Bernoulli polynomials -------------------------------------------

    #[test]
    fn test_bernoulli_poly_b0() {
        // B_0(x) = 1 for every x.
        for &x in &[0.0, 0.5, 1.0] {
            assert_relative_eq!(bernoulli_poly(0, x).expect("failed"), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_bernoulli_poly_b1() {
        // B_1(x) = x - 1/2.
        for &(x, expected) in &[(0.0, -0.5), (0.5, 0.0), (1.0, 0.5)] {
            assert_relative_eq!(
                bernoulli_poly(1, x).expect("failed"),
                expected,
                epsilon = 1e-10
            );
        }
    }

    #[test]
    fn test_bernoulli_poly_b2() {
        // B_2(x) = x^2 - x + 1/6.
        for &(x, expected) in &[(0.0, 1.0 / 6.0), (1.0, 1.0 / 6.0), (0.5, -1.0 / 12.0)] {
            let value = bernoulli_poly(2, x).expect("failed");
            assert_relative_eq!(value, expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_bernoulli_poly_at_zero() {
        // B_n(0) must equal the n-th Bernoulli number.
        for n in 0..=8 {
            let from_poly = bernoulli_poly(n, 0.0).expect("failed");
            let from_number = crate::bernoulli_number(n).expect("failed");
            assert_relative_eq!(from_poly, from_number, epsilon = 1e-8);
        }
    }

    // --- Euler polynomials -----------------------------------------------

    #[test]
    fn test_euler_poly_e0() {
        assert_relative_eq!(euler_poly(0, 0.5).expect("failed"), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_euler_poly_e1() {
        // E_1(x) = x - 1/2.
        for &(x, expected) in &[(0.0, -0.5), (0.5, 0.0), (1.0, 0.5)] {
            assert_relative_eq!(euler_poly(1, x).expect("failed"), expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_euler_poly_e2() {
        // E_2(x) = x^2 - x.
        for &(x, expected) in &[(0.0, 0.0), (1.0, 0.0), (0.5, -0.25)] {
            let value = euler_poly(2, x).expect("failed");
            assert_relative_eq!(value, expected, epsilon = 1e-10);
        }
    }
}