#![allow(dead_code)]
use crate::error::{SpecialError, SpecialResult};
use rug::{float::Constant, ops::Pow, Complex, Float};
/// Working precision, in bits, used by `PrecisionContext::default` (2^8 bits).
pub const DEFAULT_PRECISION: u32 = 1 << 8;
/// Largest precision, in bits, accepted by `PrecisionContext::new` (2^12 bits).
pub const MAX_PRECISION: u32 = 1 << 12;
/// Settings shared by the arbitrary-precision routines in this file.
///
/// `precision` is the mantissa size (in bits) used when the context builds
/// `Float`/`Complex` values; `rounding` is the MPFR rounding mode stored with it.
/// Field order is significant for the derived `Debug` output.
#[derive(Debug, Clone)]
pub struct PrecisionContext {
// Mantissa precision in bits; `PrecisionContext::new` only accepts 1..=MAX_PRECISION.
precision: u32,
// Rounding mode; both `new` and `default` initialise it to round-to-nearest.
rounding: rug::float::Round,
}
impl Default for PrecisionContext {
fn default() -> Self {
Self {
precision: DEFAULT_PRECISION,
rounding: rug::float::Round::Nearest,
}
}
}
impl PrecisionContext {
pub fn new(precision: u32) -> SpecialResult<Self> {
if precision == 0 || precision > MAX_PRECISION {
return Err(SpecialError::DomainError(format!(
"Precision must be between 1 and {} bits",
MAX_PRECISION
)));
}
Ok(Self {
precision,
rounding: rug::float::Round::Nearest,
})
}
pub fn with_rounding(mut self, rounding: rug::float::Round) -> Self {
self.rounding = rounding;
self
}
pub fn precision(&self) -> u32 {
self.precision
}
pub fn rounding(&self) -> rug::float::Round {
self.rounding
}
pub fn float(&self, value: f64) -> Float {
Float::with_val(self.precision, value)
}
pub fn complex(&self, real: f64, imag: f64) -> Complex {
Complex::with_val(self.precision, (real, imag))
}
pub fn pi(&self) -> Float {
Float::with_val(self.precision, Constant::Pi)
}
pub fn e(&self) -> Float {
Float::with_val(self.precision, 1).exp()
}
pub fn ln2(&self) -> Float {
Float::with_val(self.precision, Constant::Log2)
}
pub fn euler_gamma(&self) -> Float {
Float::with_val(self.precision, Constant::Euler)
}
pub fn catalan(&self) -> Float {
Float::with_val(self.precision, Constant::Catalan)
}
}
pub mod gamma {
use super::*;
pub fn gamma_ap(x: f64, ctx: &PrecisionContext) -> SpecialResult<Float> {
let x_mp = ctx.float(x);
gamma_mp(&x_mp, ctx)
}
pub fn gamma_mp(x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
if x.is_zero() || (x.is_finite() && *x < 0.0 && x.is_integer()) {
return Err(SpecialError::DomainError(
"Gamma function undefined at non-positive integers".to_string(),
));
}
if *x > 20.0 {
stirling_gamma(x, ctx)
} else if *x > 0.0 {
lanczos_gamma(x, ctx)
} else {
reflection_gamma(x, ctx)
}
}
fn stirling_gamma(x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
let two_pi = ctx.pi() * Float::with_val(ctx.precision, 2.0);
let sqrt_2pi = two_pi.sqrt();
let e = ctx.e();
let term1 = sqrt_2pi / x.clone().sqrt();
let term2 = (x.clone() / e).pow(x);
let mut correction = ctx.float(1.0);
let x2 = Float::with_val(ctx.precision, x.clone() * x);
let x3 = Float::with_val(ctx.precision, &x2 * x);
let x4 = Float::with_val(ctx.precision, &x2 * &x2);
correction += ctx.float(1.0) / (ctx.float(12.0) * x);
correction += ctx.float(1.0) / (ctx.float(288.0) * &x2);
let denom1 = Float::with_val(ctx.precision, ctx.float(51840.0) * &x3);
correction -= ctx.float(139.0) / denom1;
let denom2 = Float::with_val(ctx.precision, ctx.float(2488320.0) * &x4);
correction -= ctx.float(571.0) / denom2;
Ok(term1 * term2 * correction)
}
fn lanczos_gamma(x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
const LANCZOS_G: f64 = 7.0;
const LANCZOS_COEFFS: &[f64] = &[
0.99999999999980993,
676.5203681218851,
-1259.1392167224028,
771.32342877765313,
-176.61502916214059,
12.507343278686905,
-0.13857109526572012,
9.9843695780195716e-6,
1.5056327351493116e-7,
];
let g = ctx.float(LANCZOS_G);
let sqrt_2pi = (ctx.pi() * ctx.float(2.0)).sqrt();
let mut ag = ctx.float(LANCZOS_COEFFS[0]);
for i in 1..LANCZOS_COEFFS.len() {
ag += ctx.float(LANCZOS_COEFFS[i]) / (x.clone() + i as f64);
}
let tmp = x.clone() + &g + ctx.float(0.5);
let result = sqrt_2pi * ag * tmp.clone().pow(x.clone() + 0.5) * (-tmp).exp();
Ok(result / x)
}
fn reflection_gamma(x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
let pi = ctx.pi();
let sin_pi_x = (pi.clone() * x).sin();
if sin_pi_x.is_zero() {
return Err(SpecialError::DomainError(
"Gamma function has poles at negative integers".to_string(),
));
}
let pos_gamma = gamma_mp(&(ctx.float(1.0) - x), ctx)?;
Ok(pi / (sin_pi_x * pos_gamma))
}
pub fn log_gamma_ap(x: f64, ctx: &PrecisionContext) -> SpecialResult<Float> {
let x_mp = ctx.float(x);
log_gamma_mp(&x_mp, ctx)
}
pub fn log_gamma_mp(x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
if x.is_zero() || (x.is_finite() && *x < 0.0) {
return Err(SpecialError::DomainError(
"log_gamma undefined for non-positive values".to_string(),
));
}
if *x > 10.0 {
stirling_log_gamma(x, ctx)
} else {
let gamma_x = gamma_mp(x, ctx)?;
Ok(gamma_x.ln())
}
}
fn stirling_log_gamma(x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
let two_pi = ctx.pi() * Float::with_val(ctx.precision, 2.0);
let ln_2pi = two_pi.ln();
let mut result = (x.clone() - 0.5) * x.clone().ln() - x.clone() + ln_2pi / 2.0;
let x2 = Float::with_val(ctx.precision, x.clone() * x);
let x3 = Float::with_val(ctx.precision, &x2 * x);
let x5 = Float::with_val(ctx.precision, &x3 * &x2);
let x7 = Float::with_val(ctx.precision, &x5 * &x2);
result += ctx.float(1.0) / (ctx.float(12.0) * x);
let denom3 = Float::with_val(ctx.precision, ctx.float(360.0) * &x3);
result -= ctx.float(1.0) / denom3;
let denom4 = ctx.float(1260.0) * &x5;
result += ctx.float(1.0) / denom4;
let denom5 = ctx.float(1680.0) * &x7;
result -= ctx.float(1.0) / denom5;
Ok(result)
}
}
pub mod bessel {
use super::*;
pub fn bessel_j_ap(n: i32, x: f64, ctx: &PrecisionContext) -> SpecialResult<Float> {
let x_mp = ctx.float(x);
bessel_j_mp(n, &x_mp, ctx)
}
pub fn bessel_j_mp(n: i32, x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
if x.is_zero() {
return Ok(if n == 0 {
ctx.float(1.0)
} else {
ctx.float(0.0)
});
}
if x.clone().abs() < 10.0 {
bessel_j_series(n, x, ctx)
} else {
bessel_j_asymptotic(n, x, ctx)
}
}
fn bessel_j_series(n: i32, x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
let mut sum = ctx.float(0.0);
let x_half = Float::with_val(ctx.precision(), x.clone() / 2.0);
let x2_quarter = Float::with_val(ctx.precision(), &x_half * &x_half);
let mut term = x_half.pow(n) / factorial_mp(n.abs() as u32, ctx);
let sign = if n < 0 && n % 2 != 0 { -1.0 } else { 1.0 };
term *= sign;
sum += &term;
for k in 1..200 {
let divisor = Float::with_val(ctx.precision(), k as f64 * (k as f64 + n.abs() as f64));
let neg_x2_quarter = Float::with_val(ctx.precision(), -&x2_quarter);
term *= neg_x2_quarter / divisor;
sum += &term;
if term.clone().abs()
< sum.clone().abs()
* Float::with_val(ctx.precision(), 10.0).pow(-(ctx.precision() as i32) / 10)
{
break;
}
}
Ok(sum)
}
fn bessel_j_asymptotic(n: i32, x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
let pi = ctx.pi();
let pi_x = Float::with_val(ctx.precision(), &pi * x);
let sqrt_2_pi_x = (ctx.float(2.0) / pi_x).sqrt();
let phase_coefficient = Float::with_val(ctx.precision(), n as f64 + 0.5);
let phase_pi_mult = Float::with_val(ctx.precision(), &phase_coefficient * &pi);
let phase_offset = Float::with_val(ctx.precision(), phase_pi_mult / 2.0);
let phase = x.clone() - phase_offset;
let cos_phase = phase.cos();
let mut correction = ctx.float(1.0);
let n2 = (n * n) as f64;
let x2 = x.clone() * x;
let x_mult = Float::with_val(ctx.precision(), 8.0 * x);
correction -= (4.0 * n2 - 1.0) / x_mult;
let x2_mult = Float::with_val(ctx.precision(), 128.0 * &x2);
correction += (4.0 * n2 - 1.0) * (4.0 * n2 - 9.0) / x2_mult;
Ok(sqrt_2_pi_x * cos_phase * correction)
}
pub fn bessel_y_ap(n: i32, x: f64, ctx: &PrecisionContext) -> SpecialResult<Float> {
let x_mp = ctx.float(x);
bessel_y_mp(n, &x_mp, ctx)
}
pub fn bessel_y_mp(n: i32, x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
if *x <= 0.0 {
return Err(SpecialError::DomainError(
"Bessel Y function undefined for non-positive arguments".to_string(),
));
}
if *x > 10.0 {
bessel_y_asymptotic(n, x, ctx)
} else {
bessel_y_relation(n, x, ctx)
}
}
fn bessel_y_relation(n: i32, x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
let pi = ctx.pi();
if n >= 0 {
let jn = bessel_j_mp(n, x, ctx)?;
let jn_neg = bessel_j_mp(-n, x, ctx)?;
let cos_n_pi = if n % 2 == 0 { 1.0 } else { -1.0 };
let n_pi = Float::with_val(ctx.precision(), n as f64) * π
Ok((jn * cos_n_pi - jn_neg) / n_pi.sin())
} else {
let yn_pos = bessel_y_mp(-n, x, ctx)?;
Ok(if n % 2 == 0 { yn_pos } else { -yn_pos })
}
}
fn bessel_y_asymptotic(n: i32, x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
let pi = ctx.pi();
let pi_x = Float::with_val(ctx.precision(), &pi * x);
let sqrt_2_pi_x = (ctx.float(2.0) / pi_x).sqrt();
let phase_coefficient = Float::with_val(ctx.precision(), n as f64 + 0.5);
let phase_pi_mult = Float::with_val(ctx.precision(), &phase_coefficient * &pi);
let phase_offset = Float::with_val(ctx.precision(), phase_pi_mult / 2.0);
let phase = x.clone() - phase_offset;
let sin_phase = phase.sin();
let mut correction = ctx.float(1.0);
let n2 = (n * n) as f64;
let x2 = x.clone() * x;
let x_mult = Float::with_val(ctx.precision(), 8.0 * x);
correction -= (4.0 * n2 - 1.0) / x_mult;
let x2_mult = Float::with_val(ctx.precision(), 128.0 * &x2);
correction += (4.0 * n2 - 1.0) * (4.0 * n2 - 9.0) / x2_mult;
Ok(sqrt_2_pi_x * sin_phase * correction)
}
}
pub mod error_function {
use super::*;
pub fn erf_ap(x: f64, ctx: &PrecisionContext) -> SpecialResult<Float> {
let x_mp = ctx.float(x);
erf_mp(&x_mp, ctx)
}
pub fn erf_mp(x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
if x.is_zero() {
return Ok(ctx.float(0.0));
}
let abs_x = x.clone().abs();
if abs_x < 2.0 {
erf_series(x, ctx)
} else {
let erfc_val = erfc_asymptotic(&abs_x, ctx)?;
Ok(if *x > 0.0 {
ctx.float(1.0) - erfc_val
} else {
erfc_val - ctx.float(1.0)
})
}
}
fn erf_series(x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
let sqrt_pi = ctx.pi().sqrt();
let x2 = x.clone() * x;
let mut sum = x.clone();
let mut term = x.clone();
for n in 1..200 {
let neg_x2 = Float::with_val(ctx.precision(), -&x2);
term *= neg_x2 / (n as f64);
let new_term = Float::with_val(ctx.precision(), &term / (2 * n + 1) as f64);
sum += &new_term;
if new_term.abs()
< sum.clone().abs()
* Float::with_val(ctx.precision(), 10.0).pow(-(ctx.precision() as i32) / 10)
{
break;
}
}
Ok(2.0 * sum / sqrt_pi)
}
fn erfc_asymptotic(x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
let sqrt_pi = ctx.pi().sqrt();
let x2 = x.clone() * x;
let neg_x2 = Float::with_val(ctx.precision(), -&x2);
let exp_neg_x2 = neg_x2.exp();
let mut sum = ctx.float(1.0);
let mut term = ctx.float(1.0);
for n in 1..50 {
let x2_mult = Float::with_val(ctx.precision(), 2.0 * &x2);
term *= -(2 * n - 1) as f64 / x2_mult;
sum += &term;
if term.clone().abs()
< sum.clone().abs()
* Float::with_val(ctx.precision(), 10.0).pow(-(ctx.precision() as i32) / 10)
{
break;
}
}
Ok(exp_neg_x2 * sum / (x.clone() * sqrt_pi))
}
pub fn erfc_ap(x: f64, ctx: &PrecisionContext) -> SpecialResult<Float> {
let x_mp = ctx.float(x);
erfc_mp(&x_mp, ctx)
}
pub fn erfc_mp(x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
if x.is_zero() {
return Ok(ctx.float(1.0));
}
let abs_x = x.clone().abs();
if abs_x < 2.0 {
let erf_val = erf_mp(x, ctx)?;
Ok(ctx.float(1.0) - erf_val)
} else {
if *x > 0.0 {
erfc_asymptotic(x, ctx)
} else {
let erfc_pos = erfc_asymptotic(&abs_x, ctx)?;
Ok(ctx.float(2.0) - erfc_pos)
}
}
}
}
mod utils {
use super::*;
pub fn factorial_mp(n: u32, ctx: &PrecisionContext) -> Float {
if n == 0 || n == 1 {
return ctx.float(1.0);
}
let mut result = ctx.float(1.0);
for i in 2..=n {
result *= i as f64;
}
result
}
pub fn binomial_mp(n: u32, k: u32, ctx: &PrecisionContext) -> Float {
if k > n {
return ctx.float(0.0);
}
if k == 0 || k == n {
return ctx.float(1.0);
}
let k = k.min(n - k);
let mut result = ctx.float(1.0);
for i in 0..k {
result *= (n - i) as f64;
result /= (i + 1) as f64;
}
result
}
pub fn pochhammer_mp(x: &Float, n: u32, ctx: &PrecisionContext) -> SpecialResult<Float> {
if n == 0 {
return Ok(ctx.float(1.0));
}
let mut result = x.clone();
for i in 1..n {
result *= x.clone() + i as f64;
}
Ok(result)
}
}
pub use utils::*;
pub mod hypergeometric {
use super::*;
#[allow(dead_code)]
pub fn hyp2f1_ap(
a: f64,
b: f64,
c: f64,
z: f64,
ctx: &PrecisionContext,
) -> SpecialResult<Float> {
let a_mp = ctx.float(a);
let b_mp = ctx.float(b);
let c_mp = ctx.float(c);
let z_mp = ctx.float(z);
hyp2f1_mp(&a_mp, &b_mp, &c_mp, &z_mp, ctx)
}
#[allow(dead_code)]
pub fn hyp2f1_mp(
a: &Float,
b: &Float,
c: &Float,
z: &Float,
ctx: &PrecisionContext,
) -> SpecialResult<Float> {
if c.is_zero() || (c.is_finite() && *c < 0.0 && c.is_integer()) {
return Err(SpecialError::DomainError(
"c must not be 0 or a negative integer".to_string(),
));
}
if z.is_zero() {
return Ok(ctx.float(1.0));
}
let mut sum = ctx.float(1.0);
let mut term = ctx.float(1.0);
let tol = Float::with_val(ctx.precision(), 10.0).pow(-(ctx.precision() as i32) / 3);
for n in 1..500 {
let n_f = ctx.float(n as f64);
let n_minus_1 = ctx.float((n - 1) as f64);
let numerator = Float::with_val(ctx.precision(), a.clone() + &n_minus_1)
* Float::with_val(ctx.precision(), b.clone() + &n_minus_1);
let denominator = Float::with_val(ctx.precision(), c.clone() + &n_minus_1)
* Float::with_val(ctx.precision(), &n_f);
term *= Float::with_val(ctx.precision(), numerator / denominator);
term *= z;
sum += &term;
if term.clone().abs() < sum.clone().abs() * &tol {
break;
}
}
Ok(sum)
}
#[allow(dead_code)]
pub fn hyp1f1_ap(a: f64, b: f64, z: f64, ctx: &PrecisionContext) -> SpecialResult<Float> {
let a_mp = ctx.float(a);
let b_mp = ctx.float(b);
let z_mp = ctx.float(z);
hyp1f1_mp(&a_mp, &b_mp, &z_mp, ctx)
}
#[allow(dead_code)]
pub fn hyp1f1_mp(
a: &Float,
b: &Float,
z: &Float,
ctx: &PrecisionContext,
) -> SpecialResult<Float> {
if b.is_zero() || (b.is_finite() && *b < 0.0 && b.is_integer()) {
return Err(SpecialError::DomainError(
"b must not be 0 or a negative integer".to_string(),
));
}
if z.is_zero() {
return Ok(ctx.float(1.0));
}
if *z < -20.0 {
let exp_z = z.clone().exp();
let b_minus_a = Float::with_val(ctx.precision(), b.clone() - a);
let neg_z = Float::with_val(ctx.precision(), -z.clone());
let transformed = hyp1f1_mp(&b_minus_a, b, &neg_z, ctx)?;
return Ok(exp_z * transformed);
}
let mut sum = ctx.float(1.0);
let mut term = ctx.float(1.0);
let tol = Float::with_val(ctx.precision(), 10.0).pow(-(ctx.precision() as i32) / 3);
for n in 1..500 {
let n_f = ctx.float(n as f64);
let n_minus_1 = ctx.float((n - 1) as f64);
let a_plus_n = Float::with_val(ctx.precision(), a.clone() + &n_minus_1);
let b_plus_n = Float::with_val(ctx.precision(), b.clone() + &n_minus_1);
let factor = Float::with_val(ctx.precision(), a_plus_n / (b_plus_n * &n_f));
term *= Float::with_val(ctx.precision(), factor * z);
sum += &term;
if term.clone().abs() < sum.clone().abs() * &tol {
break;
}
}
Ok(sum)
}
}
pub mod incomplete_gamma {
    //! Incomplete gamma functions built on the power series for γ(a, x).
    use super::*;

    /// Cap on how much `x` may add to the series term budget, so that a huge `x`
    /// cannot turn a single call into an effectively unbounded loop.
    const MAX_X_TERM_BUDGET: f64 = 1.0e6;

    /// γ(a, x) for `f64` arguments, evaluated at the precision of `ctx`.
    ///
    /// # Errors
    ///
    /// See [`gammainc_lower_mp`].
    #[allow(dead_code)]
    pub fn gammainc_lower_ap(a: f64, x: f64, ctx: &PrecisionContext) -> SpecialResult<Float> {
        let a_mp = ctx.float(a);
        let x_mp = ctx.float(x);
        gammainc_lower_mp(&a_mp, &x_mp, ctx)
    }

    /// Lower incomplete gamma γ(a, x) at `ctx.precision()` bits.
    ///
    /// Sums γ(a, x) = x^a e^(−x) Σₙ xⁿ / (a (a+1) ⋯ (a+n)). For non-integer `a < 0`
    /// the same series gives the analytic continuation. γ(a, +∞) is returned as Γ(a)
    /// and NaN inputs propagate as NaN.
    ///
    /// # Errors
    ///
    /// Returns `SpecialError::DomainError` when `x < 0`, or when `a` is zero or a
    /// negative integer (one of the series denominators would be zero).
    #[allow(dead_code)]
    pub fn gammainc_lower_mp(a: &Float, x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
        if *x < 0.0 {
            return Err(SpecialError::DomainError(
                "x must be non-negative for lower incomplete gamma".to_string(),
            ));
        }
        if x.is_zero() {
            return Ok(ctx.float(0.0));
        }
        if a.is_integer() && *a <= 0.0 {
            return Err(SpecialError::DomainError(
                "a must not be 0 or a negative integer for lower incomplete gamma".to_string(),
            ));
        }
        if a.is_nan() || x.is_nan() {
            return Ok(ctx.float(f64::NAN));
        }
        if x.is_infinite() {
            // Only +∞ reaches here (negative x was rejected above): γ(a, ∞) = Γ(a).
            return super::gamma::gamma_mp(a, ctx);
        }
        let prec = ctx.precision();
        let x_pow_a = x.clone().pow(a);
        let neg_x = Float::with_val(prec, -x.clone());
        let exp_neg_x = neg_x.exp();
        let mut sum = ctx.float(0.0);
        let mut term = ctx.float(1.0) / a;
        let tol = Float::with_val(prec, 10.0).pow(-(prec as i32) / 3);
        sum += &term;
        // Terms only start shrinking once n exceeds x − a, so the budget grows with x;
        // the precision share covers the geometric tail at high precision.
        let x_budget = x.to_f64().min(MAX_X_TERM_BUDGET).ceil() as usize;
        let max_terms = 500 + prec as usize + 2 * x_budget;
        for n in 1..max_terms {
            let n_f = ctx.float(n as f64);
            let a_plus_n = Float::with_val(prec, a.clone() + &n_f);
            term *= Float::with_val(prec, x.clone() / a_plus_n);
            sum += &term;
            if term.clone().abs() < sum.clone().abs() * &tol {
                break;
            }
        }
        Ok(x_pow_a * exp_neg_x * sum)
    }

    /// Γ(a, x) for `f64` arguments, evaluated at the precision of `ctx`.
    ///
    /// # Errors
    ///
    /// See [`gammainc_upper_mp`].
    #[allow(dead_code)]
    pub fn gammainc_upper_ap(a: f64, x: f64, ctx: &PrecisionContext) -> SpecialResult<Float> {
        let a_mp = ctx.float(a);
        let x_mp = ctx.float(x);
        gammainc_upper_mp(&a_mp, &x_mp, ctx)
    }

    /// Upper incomplete gamma Γ(a, x), computed as Γ(a) − γ(a, x).
    ///
    /// When Γ(a, x) is much smaller than Γ(a) (large `x`), the subtraction cancels
    /// leading digits, so the relative accuracy of small results degrades.
    ///
    /// # Errors
    ///
    /// Returns `SpecialError::DomainError` for `x < 0` and propagates errors from
    /// the gamma and lower incomplete gamma evaluations.
    #[allow(dead_code)]
    pub fn gammainc_upper_mp(a: &Float, x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
        if *x < 0.0 {
            return Err(SpecialError::DomainError(
                "x must be non-negative for upper incomplete gamma".to_string(),
            ));
        }
        let gamma_a = super::gamma::gamma_mp(a, ctx)?;
        let lower = gammainc_lower_mp(a, x, ctx)?;
        Ok(gamma_a - lower)
    }

    /// Regularised P(a, x) for `f64` arguments, evaluated at the precision of `ctx`.
    ///
    /// # Errors
    ///
    /// See [`gammainc_regularized_mp`].
    #[allow(dead_code)]
    pub fn gammainc_regularized_ap(a: f64, x: f64, ctx: &PrecisionContext) -> SpecialResult<Float> {
        let a_mp = ctx.float(a);
        let x_mp = ctx.float(x);
        gammainc_regularized_mp(&a_mp, &x_mp, ctx)
    }

    /// Regularised lower incomplete gamma P(a, x) = γ(a, x) / Γ(a).
    ///
    /// # Errors
    ///
    /// Propagates errors from [`gammainc_lower_mp`] and the gamma function.
    #[allow(dead_code)]
    pub fn gammainc_regularized_mp(
        a: &Float,
        x: &Float,
        ctx: &PrecisionContext,
    ) -> SpecialResult<Float> {
        let lower = gammainc_lower_mp(a, x, ctx)?;
        let gamma_a = super::gamma::gamma_mp(a, ctx)?;
        Ok(lower / gamma_a)
    }
}
pub mod digamma {
use super::*;
#[allow(dead_code)]
pub fn digamma_ap(x: f64, ctx: &PrecisionContext) -> SpecialResult<Float> {
let x_mp = ctx.float(x);
digamma_mp(&x_mp, ctx)
}
#[allow(dead_code)]
pub fn digamma_mp(x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
if x.is_zero() || (x.is_finite() && *x < 0.0 && x.is_integer()) {
return Err(SpecialError::DomainError(
"Digamma undefined at non-positive integers".to_string(),
));
}
if *x < 1.0 {
let pi = ctx.pi();
let pi_x = Float::with_val(ctx.precision(), &pi * x);
let cot_pi_x = pi_x.clone().cos() / pi_x.sin();
let one_minus_x = ctx.float(1.0) - x;
let psi_oneminus_x = digamma_mp(&one_minus_x, ctx)?;
return Ok(psi_oneminus_x - pi * cot_pi_x);
}
let mut result = ctx.float(0.0);
let mut curr_x = x.clone();
while curr_x < 8.0 {
result -= Float::with_val(ctx.precision(), 1.0) / &curr_x;
curr_x += 1.0;
}
let ln_x = curr_x.clone().ln();
let x2 = Float::with_val(ctx.precision(), &curr_x * &curr_x);
let x4 = Float::with_val(ctx.precision(), &x2 * &x2);
let x6 = Float::with_val(ctx.precision(), &x4 * &x2);
let asymp = ln_x
- Float::with_val(ctx.precision(), 1.0)
/ (Float::with_val(ctx.precision(), 2.0) * &curr_x)
- Float::with_val(ctx.precision(), 1.0)
/ (Float::with_val(ctx.precision(), 12.0) * &x2)
+ Float::with_val(ctx.precision(), 1.0)
/ (Float::with_val(ctx.precision(), 120.0) * &x4)
- Float::with_val(ctx.precision(), 1.0)
/ (Float::with_val(ctx.precision(), 252.0) * &x6);
Ok(result + asymp)
}
#[allow(dead_code)]
pub fn trigamma_ap(x: f64, ctx: &PrecisionContext) -> SpecialResult<Float> {
let x_mp = ctx.float(x);
trigamma_mp(&x_mp, ctx)
}
#[allow(dead_code)]
pub fn trigamma_mp(x: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
if x.is_zero() || (x.is_finite() && *x < 0.0 && x.is_integer()) {
return Err(SpecialError::DomainError(
"Trigamma undefined at non-positive integers".to_string(),
));
}
if *x < 1.0 {
let pi = ctx.pi();
let pi_x = Float::with_val(ctx.precision(), &pi * x);
let sin_pi_x = pi_x.sin();
let csc_sq = Float::with_val(ctx.precision(), 1.0)
/ Float::with_val(ctx.precision(), &sin_pi_x * &sin_pi_x);
let one_minus_x = ctx.float(1.0) - x;
let psi1_oneminus_x = trigamma_mp(&one_minus_x, ctx)?;
return Ok(Float::with_val(ctx.precision(), &pi * &pi) * csc_sq - psi1_oneminus_x);
}
let mut result = ctx.float(0.0);
let mut curr_x = x.clone();
while curr_x < 8.0 {
let one_over_x2 = Float::with_val(ctx.precision(), 1.0)
/ Float::with_val(ctx.precision(), &curr_x * &curr_x);
result += one_over_x2;
curr_x += 1.0;
}
let x2 = Float::with_val(ctx.precision(), &curr_x * &curr_x);
let x3 = Float::with_val(ctx.precision(), &x2 * &curr_x);
let x4 = Float::with_val(ctx.precision(), &x2 * &x2);
let x5 = Float::with_val(ctx.precision(), &x4 * &curr_x);
let asymp = Float::with_val(ctx.precision(), 1.0) / &curr_x
+ Float::with_val(ctx.precision(), 1.0) / (Float::with_val(ctx.precision(), 2.0) * &x2)
+ Float::with_val(ctx.precision(), 1.0) / (Float::with_val(ctx.precision(), 6.0) * &x3)
- Float::with_val(ctx.precision(), 1.0)
/ (Float::with_val(ctx.precision(), 30.0) * &x5);
Ok(result + asymp)
}
}
pub mod beta {
    //! Complete and incomplete beta functions.
    use super::*;

    /// B(a, b) for `f64` arguments, evaluated at the precision of `ctx`.
    ///
    /// # Errors
    ///
    /// See [`beta_mp`].
    #[allow(dead_code)]
    pub fn beta_ap(a: f64, b: f64, ctx: &PrecisionContext) -> SpecialResult<Float> {
        let a_mp = ctx.float(a);
        let b_mp = ctx.float(b);
        beta_mp(&a_mp, &b_mp, ctx)
    }

    /// B(a, b) = Γ(a) Γ(b) / Γ(a + b) at `ctx.precision()` bits.
    ///
    /// When Γ(a) and Γ(b) are finite but a + b is zero or a negative integer,
    /// Γ(a + b) has a pole and B(a, b) is 0.
    ///
    /// # Errors
    ///
    /// Propagates the gamma-function domain errors for `a` or `b` at its poles.
    #[allow(dead_code)]
    pub fn beta_mp(a: &Float, b: &Float, ctx: &PrecisionContext) -> SpecialResult<Float> {
        let gamma_a = super::gamma::gamma_mp(a, ctx)?;
        let gamma_b = super::gamma::gamma_mp(b, ctx)?;
        let a_plus_b = Float::with_val(ctx.precision(), a + b);
        if a_plus_b.is_integer() && a_plus_b <= 0.0 {
            return Ok(ctx.float(0.0));
        }
        let gamma_aplusb = super::gamma::gamma_mp(&a_plus_b, ctx)?;
        Ok(gamma_a * gamma_b / gamma_aplusb)
    }

    /// B(x; a, b) for `f64` arguments, evaluated at the precision of `ctx`.
    ///
    /// # Errors
    ///
    /// See [`betainc_mp`].
    #[allow(dead_code)]
    pub fn betainc_ap(x: f64, a: f64, b: f64, ctx: &PrecisionContext) -> SpecialResult<Float> {
        let x_mp = ctx.float(x);
        let a_mp = ctx.float(a);
        let b_mp = ctx.float(b);
        betainc_mp(&x_mp, &a_mp, &b_mp, ctx)
    }

    /// Incomplete beta B(x; a, b) = ∫₀ˣ t^(a−1) (1−t)^(b−1) dt.
    ///
    /// Evaluated as x^a / a · ₂F₁(a, 1 − b; a + 1; x). That series converges like
    /// xⁿ, so for x > 1/2 (with a, b > 0) the symmetry
    /// B(x; a, b) = B(a, b) − B(1 − x; b, a) keeps the series argument below 1/2.
    ///
    /// # Errors
    ///
    /// Returns `SpecialError::DomainError` when `x` lies outside [0, 1], and
    /// propagates errors from the beta/hypergeometric evaluations.
    #[allow(dead_code)]
    pub fn betainc_mp(
        x: &Float,
        a: &Float,
        b: &Float,
        ctx: &PrecisionContext,
    ) -> SpecialResult<Float> {
        if *x < 0.0 || *x > 1.0 {
            return Err(SpecialError::DomainError(
                "x must be in [0, 1] for incomplete beta".to_string(),
            ));
        }
        if x.is_zero() {
            return Ok(ctx.float(0.0));
        }
        if *x == 1.0 {
            return beta_mp(a, b, ctx);
        }
        if *x > 0.5 && *a > 0.0 && *b > 0.0 {
            // 1 − x < 1/2 here, so the recursive call takes the direct series path.
            let one_minus_x = ctx.float(1.0) - x;
            let complement = betainc_mp(&one_minus_x, b, a, ctx)?;
            let full = beta_mp(a, b, ctx)?;
            return Ok(full - complement);
        }
        let x_pow_a = x.clone().pow(a);
        let hyp =
            super::hypergeometric::hyp2f1_mp(a, &(ctx.float(1.0) - b), &(a.clone() + 1.0), x, ctx)?;
        Ok(x_pow_a * hyp / a)
    }

    /// I_x(a, b) for `f64` arguments, evaluated at the precision of `ctx`.
    ///
    /// # Errors
    ///
    /// See [`betainc_regularized_mp`].
    #[allow(dead_code)]
    pub fn betainc_regularized_ap(
        x: f64,
        a: f64,
        b: f64,
        ctx: &PrecisionContext,
    ) -> SpecialResult<Float> {
        let x_mp = ctx.float(x);
        let a_mp = ctx.float(a);
        let b_mp = ctx.float(b);
        betainc_regularized_mp(&x_mp, &a_mp, &b_mp, ctx)
    }

    /// Regularised incomplete beta I_x(a, b) = B(x; a, b) / B(a, b).
    ///
    /// # Errors
    ///
    /// Propagates errors from [`betainc_mp`] and [`beta_mp`].
    #[allow(dead_code)]
    pub fn betainc_regularized_mp(
        x: &Float,
        a: &Float,
        b: &Float,
        ctx: &PrecisionContext,
    ) -> SpecialResult<Float> {
        let inc = betainc_mp(x, a, b, ctx)?;
        let full = beta_mp(a, b, ctx)?;
        Ok(inc / full)
    }
}
#[allow(dead_code)]
pub fn to_f64(x: &Float) -> f64 {
x.to_f64()
}
/// Lossy conversion of an arbitrary-precision complex value to `Complex64`.
///
/// Each part is converted independently with `Float::to_f64`.
#[allow(dead_code)]
pub fn to_complex64(z: &Complex) -> scirs2_core::numeric::Complex64 {
    // Borrow the parts instead of cloning the whole arbitrary-precision value.
    let real_part = z.real().to_f64();
    let imag_part = z.imag().to_f64();
    scirs2_core::numeric::Complex64::new(real_part, imag_part)
}
#[allow(dead_code)]
pub fn cleanup_cache() {
rug::float::free_cache(rug::float::FreeCache::All);
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_precision_context() {
        let context = PrecisionContext::new(512).expect("Operation failed");
        assert_eq!(context.precision(), 512);
        let pi = context.pi();
        assert!(pi.prec() >= 512);
        // A 512-bit π renders with many digits.
        assert!(pi.to_string().len() > 20);
    }

    #[test]
    fn test_gamma_ap() {
        let context = PrecisionContext::default();
        // (argument, expected Γ(argument), absolute tolerance)
        let cases = [
            (1.0, 1.0, 1e-15),
            (0.5, std::f64::consts::PI.sqrt(), 1e-15),
            (5.0, 24.0, 1e-13),
        ];
        for &(input, expected, tolerance) in cases.iter() {
            let value = gamma::gamma_ap(input, &context).expect("Operation failed");
            assert_relative_eq!(to_f64(&value), expected, epsilon = tolerance);
        }
    }

    #[test]
    fn test_bessel_ap() {
        let context = PrecisionContext::default();
        // J_0(0) = 1, J_1(0) = 0.
        for &(order, expected) in [(0, 1.0), (1, 0.0)].iter() {
            let value = bessel::bessel_j_ap(order, 0.0, &context).expect("Operation failed");
            assert_relative_eq!(to_f64(&value), expected, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_erf_ap() {
        let context = PrecisionContext::default();
        let erf_zero = error_function::erf_ap(0.0, &context).expect("Operation failed");
        assert_relative_eq!(to_f64(&erf_zero), 0.0, epsilon = 1e-15);
        let erfc_zero = error_function::erfc_ap(0.0, &context).expect("Operation failed");
        assert_relative_eq!(to_f64(&erfc_zero), 1.0, epsilon = 1e-15);
        // erf and erfc are complementary.
        let x = 1.5;
        let erf_x = error_function::erf_ap(x, &context).expect("Operation failed");
        let erfc_x = error_function::erfc_ap(x, &context).expect("Operation failed");
        let total = to_f64(&erf_x) + to_f64(&erfc_x);
        assert_relative_eq!(total, 1.0, epsilon = 1e-15);
    }

    #[test]
    fn test_high_precision() {
        let context = PrecisionContext::new(1024).expect("Operation failed");
        let expected_prefix = "3.141592653589793238462643383279502884197169399375105820974944592307816406286208998628034825342117068";
        let rendered = format!("{:.100}", context.pi());
        assert!(rendered.starts_with(expected_prefix));
    }
}