use std::any::TypeId;
use crate::dtype::Float;
use crate::error::FerrotorchResult;
use crate::ops::elementwise::{binary_map, unary_map};
use crate::tensor::Tensor;
/// Returns the additive identity of `T`.
///
/// Lets call sites write `nt_zero::<T>()` instead of spelling out the
/// `num_traits::Zero` trait path each time.
#[inline]
fn nt_zero<T>() -> T
where
    T: num_traits::Zero,
{
    num_traits::Zero::zero()
}
/// Returns the multiplicative identity of `T`.
///
/// Lets call sites write `nt_one::<T>()` instead of spelling out the
/// `num_traits::One` trait path each time.
#[inline]
fn nt_one<T>() -> T
where
    T: num_traits::One,
{
    num_traits::One::one()
}
// Coefficients of the erf approximation evaluated by the generic (non-f64)
// path of `erf_scalar`:
//   erf(|x|) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2),
//   with t = 1 / (1 + p*|x|).
// The values match Abramowitz & Stegun formula 7.1.26.
const ERF_A1: f64 = 0.254829592;
const ERF_A2: f64 = -0.284496736;
const ERF_A3: f64 = 1.421413741;
const ERF_A4: f64 = -1.453152027;
const ERF_A5: f64 = 1.061405429;
// Scale `p` inside t = 1 / (1 + p*|x|).
const ERF_P: f64 = 0.3275911;
/// Lanczos shift parameter `g`; `lgamma_scalar` forms `t = z + g + 0.5` with it.
const LANCZOS_G: f64 = 7.0;
/// Lanczos series coefficients used by `lgamma_scalar`: entry 0 is the constant
/// term and entry `i >= 1` is divided by `z + i` before being summed.
const LANCZOS_COEFFICIENTS: [f64; 9] = [
0.999_999_999_999_809_9,
676.5203681218851,
-1259.1392167224028,
771.323_428_777_653_1,
-176.615_029_162_140_6,
12.507343278686905,
-0.13857109526572012,
9.984_369_578_019_572e-6,
1.5056327351493116e-7,
];
// Constants for |x| < 0.84375 in `erf_f64_hi` / `erfc_f64_hi`.
// `ERF_EFX` is the linear coefficient used for tiny arguments: erf(x) ~= x + EFX*x.
#[allow(clippy::excessive_precision)]
const ERF_EFX: f64 = 1.2837916709551257e-01;
// Numerator coefficients (in powers of z = x*x) of the rational term y = r/s,
// where erf(x) = x + x*y.
#[allow(clippy::excessive_precision)]
const ERF_PP0: f64 = 1.28379167095512558561e-01;
#[allow(clippy::excessive_precision)]
const ERF_PP1: f64 = -3.25042107247001499370e-01;
#[allow(clippy::excessive_precision)]
const ERF_PP2: f64 = -2.84817495755985104766e-02;
#[allow(clippy::excessive_precision)]
const ERF_PP3: f64 = -5.77027029648944159157e-03;
#[allow(clippy::excessive_precision)]
const ERF_PP4: f64 = -2.37630166566501626084e-05;
// Denominator coefficients of the same rational term (the constant term is 1).
#[allow(clippy::excessive_precision)]
const ERF_QQ1: f64 = 3.97917223959155352819e-01;
#[allow(clippy::excessive_precision)]
const ERF_QQ2: f64 = 6.50222499887672944485e-02;
#[allow(clippy::excessive_precision)]
const ERF_QQ3: f64 = 5.08130628187576562776e-03;
#[allow(clippy::excessive_precision)]
const ERF_QQ4: f64 = 1.32494738004321644526e-04;
#[allow(clippy::excessive_precision)]
const ERF_QQ5: f64 = -3.96022827877536812320e-06;
// Constants for 0.84375 <= |x| < 1.25 in `erf_f64_hi` / `erfc_f64_hi`.
// `ERF_ERX` is the offset in erf(|x|) = ERX + p(s)/q(s), with s = |x| - 1.
#[allow(clippy::excessive_precision)]
const ERF_ERX: f64 = 8.45062911510467529297e-01;
// Numerator coefficients of p(s), lowest order first.
#[allow(clippy::excessive_precision)]
const ERF_PA0: f64 = -2.36211856075265944077e-03;
#[allow(clippy::excessive_precision)]
const ERF_PA1: f64 = 4.14856118683748331666e-01;
#[allow(clippy::excessive_precision)]
const ERF_PA2: f64 = -3.72207876035701323847e-01;
#[allow(clippy::excessive_precision)]
const ERF_PA3: f64 = 3.18346619901161753674e-01;
#[allow(clippy::excessive_precision)]
const ERF_PA4: f64 = -1.10894694282396677476e-01;
#[allow(clippy::excessive_precision)]
const ERF_PA5: f64 = 3.54783043256182359371e-02;
#[allow(clippy::excessive_precision)]
const ERF_PA6: f64 = -2.16637559486879084300e-03;
// Denominator coefficients of q(s) (the constant term is 1).
#[allow(clippy::excessive_precision)]
const ERF_QA1: f64 = 1.06420880400844228286e-01;
#[allow(clippy::excessive_precision)]
const ERF_QA2: f64 = 5.40397917702171048937e-01;
#[allow(clippy::excessive_precision)]
const ERF_QA3: f64 = 7.18286544141962662868e-02;
#[allow(clippy::excessive_precision)]
const ERF_QA4: f64 = 1.26171219808761642112e-01;
#[allow(clippy::excessive_precision)]
const ERF_QA5: f64 = 1.36370839120290507362e-02;
#[allow(clippy::excessive_precision)]
const ERF_QA6: f64 = 1.19844998467991074170e-02;
// Constants for 1.25 <= |x| < 1/0.35 in `erf_f64_hi` / `erfc_f64_hi`.
// Numerator coefficients of r(s), with s = 1 / (x*x), lowest order first.
#[allow(clippy::excessive_precision)]
const ERF_RA0: f64 = -9.86494403484714822705e-03;
#[allow(clippy::excessive_precision)]
const ERF_RA1: f64 = -6.93858572707181764372e-01;
#[allow(clippy::excessive_precision)]
const ERF_RA2: f64 = -1.05586262253232909814e+01;
#[allow(clippy::excessive_precision)]
const ERF_RA3: f64 = -6.23753324503260060396e+01;
#[allow(clippy::excessive_precision)]
const ERF_RA4: f64 = -1.62396669462573470355e+02;
#[allow(clippy::excessive_precision)]
const ERF_RA5: f64 = -1.84605092906711035994e+02;
#[allow(clippy::excessive_precision)]
const ERF_RA6: f64 = -8.12874355063065934246e+01;
#[allow(clippy::excessive_precision)]
const ERF_RA7: f64 = -9.81432934416914548592e+00;
// Denominator coefficients of the matching polynomial (the constant term is 1).
#[allow(clippy::excessive_precision)]
const ERF_SA1: f64 = 1.96512716674392571292e+01;
#[allow(clippy::excessive_precision)]
const ERF_SA2: f64 = 1.37657754143519042600e+02;
#[allow(clippy::excessive_precision)]
const ERF_SA3: f64 = 4.34565877475229228821e+02;
#[allow(clippy::excessive_precision)]
const ERF_SA4: f64 = 6.45387271733267880336e+02;
#[allow(clippy::excessive_precision)]
const ERF_SA5: f64 = 4.29008140027567833386e+02;
#[allow(clippy::excessive_precision)]
const ERF_SA6: f64 = 1.08635005541779435134e+02;
#[allow(clippy::excessive_precision)]
const ERF_SA7: f64 = 6.57024977031928170135e+00;
#[allow(clippy::excessive_precision)]
const ERF_SA8: f64 = -6.04244152148580987438e-02;
// Constants for |x| >= 1/0.35 (up to 6 in `erf_f64_hi`, up to 28 in `erfc_f64_hi`).
// Numerator coefficients of r(s), with s = 1 / (x*x), lowest order first.
#[allow(clippy::excessive_precision)]
const ERF_RB0: f64 = -9.86494292470009928597e-03;
#[allow(clippy::excessive_precision)]
const ERF_RB1: f64 = -7.99283237680523006574e-01;
#[allow(clippy::excessive_precision)]
const ERF_RB2: f64 = -1.77579549177547519889e+01;
#[allow(clippy::excessive_precision)]
const ERF_RB3: f64 = -1.60636384855821916062e+02;
#[allow(clippy::excessive_precision)]
const ERF_RB4: f64 = -6.37566443368389627722e+02;
#[allow(clippy::excessive_precision)]
const ERF_RB5: f64 = -1.02509513161107724954e+03;
#[allow(clippy::excessive_precision)]
const ERF_RB6: f64 = -4.83519191608651397019e+02;
// Denominator coefficients of the matching polynomial (the constant term is 1).
#[allow(clippy::excessive_precision)]
const ERF_SB1: f64 = 3.03380607434824582924e+01;
#[allow(clippy::excessive_precision)]
const ERF_SB2: f64 = 3.25792512996573918826e+02;
#[allow(clippy::excessive_precision)]
const ERF_SB3: f64 = 1.53672958608443695994e+03;
#[allow(clippy::excessive_precision)]
const ERF_SB4: f64 = 3.19985821950859553908e+03;
#[allow(clippy::excessive_precision)]
const ERF_SB5: f64 = 2.55305040643316442583e+03;
#[allow(clippy::excessive_precision)]
const ERF_SB6: f64 = 4.74528541206955367215e+02;
#[allow(clippy::excessive_precision)]
const ERF_SB7: f64 = -2.24409524465858183362e+01;
/// Computes erf(x) in double precision with piecewise rational approximations.
///
/// The argument range is split at |x| = 0.84375, 1.25, 1/0.35 and 6; each
/// interval uses its own `ERF_*` coefficient set. NaN is returned unchanged,
/// +/-infinity map to +/-1, and |x| >= 6 saturates to +/-1.
///
/// NOTE(review): the constants and interval layout look like the Sun/FreeBSD
/// msun `s_erf.c` algorithm; confirm against upstream before editing.
fn erf_f64_hi(x: f64) -> f64 {
if x.is_nan() {
return x;
}
if x == f64::INFINITY {
return 1.0;
}
if x == f64::NEG_INFINITY {
return -1.0;
}
let ax = x.abs();
// Interval 1: |x| < 0.84375, erf(x) = x + x * r(z)/s(z) with z = x*x.
if ax < 0.84375 {
// 0x3E30_0000_0000_0000 is 2^-28: below it only the linear term is kept.
if ax < f64::from_bits(0x3E300000_00000000) {
return x + ERF_EFX * x;
}
let z = x * x;
let r = ERF_PP0 + z * (ERF_PP1 + z * (ERF_PP2 + z * (ERF_PP3 + z * ERF_PP4)));
let s = 1.0 + z * (ERF_QQ1 + z * (ERF_QQ2 + z * (ERF_QQ3 + z * (ERF_QQ4 + z * ERF_QQ5))));
let y = r / s;
return x + x * y;
}
// Interval 2: 0.84375 <= |x| < 1.25, erf(|x|) = ERX + p(s)/q(s) with s = |x| - 1.
if ax < 1.25 {
let s = ax - 1.0;
let p = ERF_PA0
+ s * (ERF_PA1
+ s * (ERF_PA2 + s * (ERF_PA3 + s * (ERF_PA4 + s * (ERF_PA5 + s * ERF_PA6)))));
let q = 1.0
+ s * (ERF_QA1
+ s * (ERF_QA2 + s * (ERF_QA3 + s * (ERF_QA4 + s * (ERF_QA5 + s * ERF_QA6)))));
let y = ERF_ERX + p / q;
return if x >= 0.0 { y } else { -y };
}
// Beyond |x| = 6 the result is returned as exactly +/-1.
if ax >= 6.0 {
return if x >= 0.0 { 1.0 } else { -1.0 };
}
// Intervals 3 and 4: 1.25 <= |x| < 6, rational function in s = 1/x^2,
// with the coefficient set switching at |x| = 1/0.35.
let s = 1.0 / (ax * ax);
let (r, big_s) = if ax < 1.0 / 0.35 {
let r = ERF_RA0
+ s * (ERF_RA1
+ s * (ERF_RA2
+ s * (ERF_RA3 + s * (ERF_RA4 + s * (ERF_RA5 + s * (ERF_RA6 + s * ERF_RA7))))));
let big_s = 1.0
+ s * (ERF_SA1
+ s * (ERF_SA2
+ s * (ERF_SA3
+ s * (ERF_SA4
+ s * (ERF_SA5 + s * (ERF_SA6 + s * (ERF_SA7 + s * ERF_SA8)))))));
(r, big_s)
} else {
let r = ERF_RB0
+ s * (ERF_RB1
+ s * (ERF_RB2 + s * (ERF_RB3 + s * (ERF_RB4 + s * (ERF_RB5 + s * ERF_RB6)))));
let big_s = 1.0
+ s * (ERF_SB1
+ s * (ERF_SB2
+ s * (ERF_SB3 + s * (ERF_SB4 + s * (ERF_SB5 + s * (ERF_SB6 + s * ERF_SB7))))));
(r, big_s)
};
// `z` is |x| with the low 32 bits of its representation cleared; exp(-x^2) is
// then evaluated as exp(-z*z - 0.5625) * exp((z - |x|)*(z + |x|) + r/S), which
// splits the exponent into a leading part and a small correction.
let bits = ax.to_bits() & 0xFFFFFFFF_00000000;
let z = f64::from_bits(bits);
let r_factor = (-z * z - 0.5625).exp() * (-(ax - z) * (ax + z) + r / big_s).exp() / ax;
// `r_factor` approximates erfc(|x|); mirror the result for negative inputs.
if x >= 0.0 {
1.0 - r_factor
} else {
r_factor - 1.0
}
}
/// Computes erfc(x) = 1 - erf(x) in double precision with the same piecewise
/// rational approximations as `erf_f64_hi`.
///
/// NaN is returned unchanged, +infinity maps to 0 and -infinity to 2. For
/// |x| >= 28 the result is returned as exactly 0 (x >= 0) or 2 (x < 0).
///
/// NOTE(review): the constants and interval layout look like the Sun/FreeBSD
/// msun `s_erf.c` algorithm; confirm against upstream before editing.
fn erfc_f64_hi(x: f64) -> f64 {
if x.is_nan() {
return x;
}
if x == f64::INFINITY {
return 0.0;
}
if x == f64::NEG_INFINITY {
return 2.0;
}
let ax = x.abs();
// Interval 1: |x| < 0.84375.
if ax < 0.84375 {
// 0x3C70_0000_0000_0000 is 2^-56: below it the result is returned as 1 - x.
if ax < f64::from_bits(0x3C700000_00000000) {
return 1.0 - x;
}
let z = x * x;
let r = ERF_PP0 + z * (ERF_PP1 + z * (ERF_PP2 + z * (ERF_PP3 + z * ERF_PP4)));
let s = 1.0 + z * (ERF_QQ1 + z * (ERF_QQ2 + z * (ERF_QQ3 + z * (ERF_QQ4 + z * ERF_QQ5))));
let y = r / s;
if ax < 0.25 {
return 1.0 - (x + x * y);
}
// For 0.25 <= |x| < 0.84375 the subtraction is split into two steps of 0.5.
let r2 = x * y;
let r3 = r2 + x;
return 0.5 - (r3 - 0.5);
}
// Interval 2: 0.84375 <= |x| < 1.25, rational function in s = |x| - 1.
if ax < 1.25 {
let s = ax - 1.0;
let p = ERF_PA0
+ s * (ERF_PA1
+ s * (ERF_PA2 + s * (ERF_PA3 + s * (ERF_PA4 + s * (ERF_PA5 + s * ERF_PA6)))));
let q = 1.0
+ s * (ERF_QA1
+ s * (ERF_QA2 + s * (ERF_QA3 + s * (ERF_QA4 + s * (ERF_QA5 + s * ERF_QA6)))));
if x >= 0.0 {
let z = 1.0 - ERF_ERX;
return z - p / q;
}
let z = ERF_ERX + p / q;
return 1.0 + z;
}
// Intervals 3 and 4: 1.25 <= |x| < 28, rational function in s = 1/x^2,
// with the coefficient set switching at |x| = 1/0.35.
if ax < 28.0 {
let s = 1.0 / (ax * ax);
let (r, big_s) = if ax < 1.0 / 0.35 {
let r = ERF_RA0
+ s * (ERF_RA1
+ s * (ERF_RA2
+ s * (ERF_RA3
+ s * (ERF_RA4 + s * (ERF_RA5 + s * (ERF_RA6 + s * ERF_RA7))))));
let big_s = 1.0
+ s * (ERF_SA1
+ s * (ERF_SA2
+ s * (ERF_SA3
+ s * (ERF_SA4
+ s * (ERF_SA5 + s * (ERF_SA6 + s * (ERF_SA7 + s * ERF_SA8)))))));
(r, big_s)
} else {
let r = ERF_RB0
+ s * (ERF_RB1
+ s * (ERF_RB2 + s * (ERF_RB3 + s * (ERF_RB4 + s * (ERF_RB5 + s * ERF_RB6)))));
let big_s = 1.0
+ s * (ERF_SB1
+ s * (ERF_SB2
+ s * (ERF_SB3
+ s * (ERF_SB4 + s * (ERF_SB5 + s * (ERF_SB6 + s * ERF_SB7))))));
(r, big_s)
};
// `z` is |x| with the low 32 bits of its representation cleared; the exponent
// of exp(-x^2) is split into a leading part and a small correction.
let bits = ax.to_bits() & 0xFFFFFFFF_00000000;
let z = f64::from_bits(bits);
let r_factor = (-z * z - 0.5625).exp() * (-(ax - z) * (ax + z) + r / big_s).exp() / ax;
// `r_factor` approximates erfc(|x|); erfc(-|x|) = 2 - erfc(|x|).
if x >= 0.0 { r_factor } else { 2.0 - r_factor }
} else if x >= 0.0 {
0.0
} else {
2.0
}
}
/// Error function for a single element of float type `T`.
///
/// When `T` is `f64` the call is forwarded to `erf_f64_hi`. Every other type
/// uses the five-term polynomial-times-exponential approximation built from
/// `ERF_A1..ERF_A5` and `ERF_P`, evaluated on `|x|` and mirrored by sign.
/// An input equal to zero returns zero directly.
pub(crate) fn erf_scalar<T: Float>(x: T) -> T {
    if TypeId::of::<f64>() == TypeId::of::<T>() {
        let wide = x.to_f64().unwrap();
        return T::from(erf_f64_hi(wide)).unwrap();
    }

    let zero = nt_zero::<T>();
    if x == zero {
        return zero;
    }
    let one = nt_one::<T>();
    let sign = if x < zero { -one } else { one };
    let magnitude = x.abs();
    let t = one / (one + T::from(ERF_P).unwrap() * magnitude);

    // Horner evaluation of a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))), innermost first.
    let mut poly = T::from(ERF_A5).unwrap();
    for coeff in [ERF_A4, ERF_A3, ERF_A2, ERF_A1] {
        poly = T::from(coeff).unwrap() + t * poly;
    }

    let tail = poly * t * (-magnitude * magnitude).exp();
    sign * (one - tail)
}
/// Complementary error function for a single element of float type `T`.
///
/// `f64` inputs are forwarded to `erfc_f64_hi`; every other type is computed
/// as `1 - erf_scalar(x)`.
fn erfc_scalar<T: Float>(x: T) -> T {
    if TypeId::of::<f64>() != TypeId::of::<T>() {
        return nt_one::<T>() - erf_scalar(x);
    }
    let wide = x.to_f64().unwrap();
    T::from(erfc_f64_hi(wide)).unwrap()
}
/// Inverse error function for a single element of float type `T`.
///
/// Special values:
/// * `x == 0` returns `x` itself, so the sign of a negative zero is kept;
/// * `x == 1` returns `+inf` and `x == -1` returns `-inf`;
/// * `|x| > 1` lies outside the domain of erf's range and returns NaN
///   (previously these inputs were reported as +/-infinity);
/// * NaN propagates.
///
/// Otherwise the computation runs in `f64`: a closed-form initial guess
/// (constant `a = 0.147`) is refined by up to three Newton steps on
/// `erf_f64_hi(z) - y`, then narrowed back to `T`.
fn erfinv_scalar<T: Float>(x: T) -> T {
    let zero = nt_zero::<T>();
    let one = nt_one::<T>();
    if x == zero {
        return x;
    }
    if x == one {
        return T::infinity();
    }
    if x == -one {
        return T::neg_infinity();
    }
    let y = x.to_f64().unwrap();
    if y.abs() > 1.0 {
        // erf only takes values in [-1, 1]; there is no preimage out here.
        return T::from(f64::NAN).unwrap();
    }
    let sign = if y < 0.0 { -1.0 } else { 1.0 };
    let ay = y.abs();

    // Closed-form initial estimate.
    let a = 0.147_f64;
    let pi = std::f64::consts::PI;
    let ln_term = (1.0 - ay * ay).ln();
    let b = 2.0 / (pi * a) + ln_term / 2.0;
    let c = ln_term / a;
    let mut z = sign * (-b + (b * b - c).sqrt()).sqrt();

    // Newton refinement: d/dz erf(z) = 2/sqrt(pi) * exp(-z^2), so the step is
    // resid * sqrt(pi)/2 * exp(z^2).
    let half_sqrt_pi = 0.5 * pi.sqrt();
    for _ in 0..3 {
        let resid = erf_f64_hi(z) - y;
        if resid.abs() < 4.0 * f64::EPSILON {
            break;
        }
        z -= resid * half_sqrt_pi * (z * z).exp();
    }
    T::from(z).unwrap()
}
/// Natural log of |Gamma(x)| for a single element of float type `T`.
///
/// * `x = +/-inf` returns `+inf`. Without this guard the Lanczos expression
///   evaluates `inf * inf - inf` and yields NaN.
/// * `x < 0.5` uses the reflection formula
///   `ln(pi / |sin(pi x)|) - lgamma(1 - x)`; when `sin(pi x)` is exactly zero
///   the result is `+inf`.
/// * Otherwise the Lanczos series with `LANCZOS_G` and
///   `LANCZOS_COEFFICIENTS` is evaluated at `z = x - 1`.
fn lgamma_scalar<T: Float>(x: T) -> T {
    if x == T::infinity() || x == T::neg_infinity() {
        return T::infinity();
    }

    let one = nt_one::<T>();
    let half = T::from(0.5).unwrap();

    if x < half {
        let pi = T::from(std::f64::consts::PI).unwrap();
        let sin_pi_x = (pi * x).sin();
        if sin_pi_x == nt_zero::<T>() {
            return T::infinity();
        }
        return (pi / sin_pi_x.abs()).ln() - lgamma_scalar(one - x);
    }

    // 0.5 * ln(2 * pi)
    let half_ln_2pi = T::from(0.9189385332046727).unwrap();
    let g = T::from(LANCZOS_G).unwrap();

    let z = x - one;
    let mut sum = T::from(LANCZOS_COEFFICIENTS[0]).unwrap();
    for (i, &coeff) in LANCZOS_COEFFICIENTS.iter().enumerate().skip(1) {
        sum += T::from(coeff).unwrap() / (z + T::from(i as f64).unwrap());
    }
    let t = z + g + half;
    half_ln_2pi + t.ln() * (z + half) - t + sum.ln()
}
/// Digamma function psi(x) in double precision.
///
/// * Negative integers (and `-inf`) are poles and return NaN. Previously these
///   inputs went through the reflection formula, where `sin(pi x)` is a tiny
///   non-zero rounding residue, producing a huge finite value.
/// * `x < 0.5` otherwise uses the reflection `psi(x) = psi(1 - x) - pi*cot(pi x)`;
///   for `x = +0.0` this evaluates to `-inf`.
/// * `x >= 0.5` is shifted upward with `psi(z) = psi(z + 1) - 1/z` until
///   `z >= 14`, then the asymptotic series in `1/z^2` (through `z^-12`) is used.
fn digamma_f64_hi(x: f64) -> f64 {
    if x < 0.0 && x == x.floor() {
        return f64::NAN;
    }
    if x < 0.5 {
        let pi = std::f64::consts::PI;
        let cot = (pi * x).cos() / (pi * x).sin();
        return digamma_f64_hi(1.0 - x) - pi * cot;
    }

    // Upward recurrence into the region where the asymptotic series is accurate.
    let mut acc = 0.0_f64;
    let mut z = x;
    while z < 14.0 {
        acc -= 1.0 / z;
        z += 1.0;
    }

    let z2 = z * z;
    let z4 = z2 * z2;
    let z6 = z4 * z2;
    let z8 = z4 * z4;
    let z10 = z8 * z2;
    let z12 = z8 * z4;
    acc + z.ln() - 1.0 / (2.0 * z) - 1.0 / (12.0 * z2) + 1.0 / (120.0 * z4) - 1.0 / (252.0 * z6)
        + 1.0 / (240.0 * z8)
        - 1.0 / (132.0 * z10)
        + 691.0 / (32_760.0 * z12)
}
/// Digamma function psi(x) for a single element of float type `T`.
///
/// * Negative integers (and `-inf`) are poles and return NaN. The check is
///   done up front so it covers both the `f64` path and the generic path;
///   previously these inputs produced a huge finite value because
///   `sin(pi x)` is not exactly zero in floating point.
/// * `f64` inputs are forwarded to `digamma_f64_hi`.
/// * Other types use the reflection formula for `x < 0.5`, shift upward with
///   `psi(z) = psi(z + 1) - 1/z` until `z >= 6`, then apply the asymptotic
///   series through the `z^-6` term.
fn digamma_scalar<T: Float>(x: T) -> T {
    let xf = x.to_f64().unwrap_or(f64::NAN);
    if xf < 0.0 && xf == xf.floor() {
        return T::from(f64::NAN).unwrap();
    }
    if TypeId::of::<T>() == TypeId::of::<f64>() {
        return T::from(digamma_f64_hi(xf)).unwrap();
    }

    let zero = nt_zero::<T>();
    let one = nt_one::<T>();
    let half = T::from(0.5).unwrap();
    if x < half {
        let pi = T::from(std::f64::consts::PI).unwrap();
        let cot = (pi * x).cos() / (pi * x).sin();
        return digamma_scalar(one - x) - pi * cot;
    }

    // Upward recurrence into the region where the asymptotic series is usable.
    let mut result = zero;
    let mut z = x;
    let six = T::from(6.0).unwrap();
    while z < six {
        #[allow(clippy::assign_op_pattern)]
        {
            result = result - one / z;
        }
        #[allow(clippy::assign_op_pattern)]
        {
            z = z + one;
        }
    }

    let z2 = z * z;
    let z4 = z2 * z2;
    let z6 = z4 * z2;
    result =
        result + z.ln() - one / (T::from(2.0).unwrap() * z) - one / (T::from(12.0).unwrap() * z2)
            + one / (T::from(120.0).unwrap() * z4)
            - one / (T::from(252.0).unwrap() * z6);
    result
}
/// Normalized sinc: `sin(pi x) / (pi x)`, with the value 1 at `x == 0`.
fn sinc_scalar<T: Float>(x: T) -> T {
    if x == nt_zero::<T>() {
        nt_one::<T>()
    } else {
        let scaled = T::from(std::f64::consts::PI).unwrap() * x;
        scaled.sin() / scaled
    }
}
/// Computes `x * ln(y)` with the conventions of `xlogy`:
///
/// * NaN if `y` is NaN, even when `x == 0` (previously a zero `x` masked a
///   NaN `y` and returned 0);
/// * 0 if `x == 0`;
/// * `x * ln(y)` otherwise.
fn xlogy_scalar<T: Float>(x: T, y: T) -> T {
    if y.is_nan() {
        return y;
    }
    if x == nt_zero::<T>() {
        nt_zero::<T>()
    } else {
        x * y.ln()
    }
}
/// Elementwise entropy term.
///
/// Returns `-a * ln(a)` for `a > 0`, zero for `a == 0`, negative infinity for
/// `a < 0`, and passes NaN through unchanged.
fn entr_scalar<T: Float>(a: T) -> T {
    let zero = nt_zero::<T>();
    if a.is_nan() {
        a
    } else if a > zero {
        -a * a.ln()
    } else if a == zero {
        zero
    } else {
        T::neg_infinity()
    }
}
fn ndtr_scalar<T: Float>(x: T) -> T {
let sqrt1_2 = T::from(std::f64::consts::FRAC_1_SQRT_2).unwrap();
let one = nt_one::<T>();
let half = T::from(0.5).unwrap();
(one + erf_scalar(x * sqrt1_2)) * half
}
/// Evaluates a polynomial at `x` by Horner's rule.
///
/// `coeffs` is ordered from the highest-degree coefficient down to the
/// constant term. An empty slice evaluates to zero.
fn polevl<T: Float>(x: T, coeffs: &[T]) -> T {
    coeffs
        .iter()
        .fold(nt_zero::<T>(), |partial, &coefficient| partial * x + coefficient)
}
/// Generic entry point for the inverse normal CDF.
///
/// Widens `y0` to `f64` (substituting NaN if the conversion fails), evaluates
/// `ndtri_f64`, and narrows the result back to `T` (NaN if that fails).
fn ndtri_scalar<T: Float>(y0: T) -> T {
    let widened = match <T as num_traits::ToPrimitive>::to_f64(&y0) {
        Some(value) => value,
        None => f64::NAN,
    };
    match T::from(ndtri_f64(widened)) {
        Some(narrowed) => narrowed,
        None => T::from(f64::NAN).unwrap(),
    }
}
/// Inverse of the standard normal CDF in double precision.
///
/// Returns `-inf` for `y0 == 0`, `+inf` for `y0 == 1`, and NaN for inputs
/// outside `[0, 1]` or NaN. Elsewhere one of three rational approximations is
/// selected: a central one for probabilities between `exp(-2)` and
/// `1 - exp(-2)`, and two tail ones chosen by `x = sqrt(-2 ln y)` being below
/// or above 8.
#[allow(
clippy::float_cmp,
clippy::manual_range_contains,
clippy::excessive_precision,
reason = "verbatim Cephes ndtri port: exact-endpoint boundary tests and full-width coefficients mirror Math.cuh:48-173 for ULP parity + audit-friendly diff"
)]
fn ndtri_f64(y0: f64) -> f64 {
if y0 == 0.0 {
return f64::NEG_INFINITY;
}
if y0 == 1.0 {
return f64::INFINITY;
}
if y0 < 0.0 || y0 > 1.0 {
return f64::NAN;
}
// NaN fails every comparison above, so it is caught explicitly here.
if y0.is_nan() {
return f64::NAN;
}
// Central region: numerator / denominator in powers of (y - 0.5)^2,
// highest degree first (the order `polevl` expects).
const P0: [f64; 5] = [
-5.99633501014107895267E1,
9.80010754185999661536E1,
-5.66762857469070293439E1,
1.39312609387279679503E1,
-1.23916583867381258016E0,
];
const Q0: [f64; 9] = [
1.00000000000000000000E0,
1.95448858338141759834E0,
4.67627912898881538453E0,
8.63602421390890590575E1,
-2.25462687854119370527E2,
2.00260212380060660359E2,
-8.20372256168333339912E1,
1.59056225126211695515E1,
-1.18331621121330003142E0,
];
// Tail region with sqrt(-2 ln y) < 8: polynomials in z = 1/x.
const P1: [f64; 9] = [
4.05544892305962419923E0,
3.15251094599893866154E1,
5.71628192246421288162E1,
4.40805073893200834700E1,
1.46849561928858024014E1,
2.18663306850790267539E0,
-1.40256079171354495875E-1,
-3.50424626827848203418E-2,
-8.57456785154685413611E-4,
];
const Q1: [f64; 9] = [
1.00000000000000000000E0,
1.57799883256466749731E1,
4.53907635128879210584E1,
4.13172038254672030440E1,
1.50425385692907503408E1,
2.50464946208309415979E0,
-1.42182922854787788574E-1,
-3.80806407691578277194E-2,
-9.33259480895457427372E-4,
];
// Far tail with sqrt(-2 ln y) >= 8: polynomials in z = 1/x.
const P2: [f64; 9] = [
3.23774891776946035970E0,
6.91522889068984211695E0,
3.93881025292474443415E0,
1.33303460815807542389E0,
2.01485389549179081538E-1,
1.23716634817820021358E-2,
3.01581553508235416007E-4,
2.65806974686737550832E-6,
6.23974539184983293730E-9,
];
const Q2: [f64; 9] = [
1.00000000000000000000E0,
6.02427039364742014255E0,
3.67983563856160859403E0,
1.37702099489081330271E0,
2.16236993594496635890E-1,
1.34204006088543189037E-2,
3.28014464682127739104E-4,
2.89247864745380683936E-6,
6.79019408009981274425E-9,
];
// sqrt(2 * pi)
const S2PI: f64 = 2.50662827463100050242E0;
// exp(-2)
const EXP_M2: f64 = 0.13533528323661269189;
// `code == true` means the final tail result is negated (lower tail).
let mut code = true;
let mut y = y0;
// Upper tail: reflect to 1 - y and remember not to negate the result.
if y > 1.0 - EXP_M2 {
y = 1.0 - y;
code = false;
}
// Central region.
if y > EXP_M2 {
y -= 0.5;
let y2 = y * y;
let x = y + y * (y2 * polevl(y2, &P0) / polevl(y2, &Q0));
return x * S2PI;
}
// Tail regions: leading term x0 minus a rational correction x1 in z = 1/x.
let mut x = (-2.0 * y.ln()).sqrt();
let x0 = x - (x.ln() / x);
let z = 1.0 / x;
let x1 = if x < 8.0 {
z * polevl(z, &P1) / polevl(z, &Q1)
} else {
z * polevl(z, &P2) / polevl(z, &Q2)
};
x = x0 - x1;
if code { -x } else { x }
}
/// Evaluates a Chebyshev series at `x` with the three-term recurrence
/// `b0 = x * b1 - b2 + c`, returning `0.5 * (b0 - b2)`.
///
/// `array` holds the coefficients in the order they are consumed by the
/// recurrence.
///
/// # Panics
/// Panics if `array` is empty.
fn chbevl<T: Float>(x: T, array: &[T]) -> T {
    let zero = nt_zero::<T>();
    let head = array[0];
    // State is (b0, b1, b2); each step shifts it down and forms the new b0.
    let (b0, _, b2) = array[1..]
        .iter()
        .fold((head, zero, zero), |(cur, prev, _), &coeff| {
            (x * cur - prev + coeff, cur, prev)
        });
    T::from(0.5).unwrap() * (b0 - b2)
}
/// Chebyshev coefficients passed to `chbevl` by `i0_f64` / `i0e_f64` for
/// `|x| <= 8`, with series argument `|x| / 2 - 2`.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes i0e_A coefficients (cuda/Math.cuh:512-527) reproduced to full width for an audit-friendly diff; trailing digits round to the same f64 bit pattern"
)]
const I0E_A: [f64; 30] = [
-4.41534164647933937950E-18,
3.33079451882223809783E-17,
-2.43127984654795469359E-16,
1.71539128555513303061E-15,
-1.16853328779934516808E-14,
7.67618549860493561688E-14,
-4.85644678311192946090E-13,
2.95505266312963983461E-12,
-1.72682629144155570723E-11,
9.67580903537323691224E-11,
-5.18979560163526290666E-10,
2.65982372468238665035E-9,
-1.30002500998624804212E-8,
6.04699502254191894932E-8,
-2.67079385394061173391E-7,
1.11738753912010371815E-6,
-4.41673835845875056359E-6,
1.64484480707288970893E-5,
-5.75419501008210370398E-5,
1.88502885095841655729E-4,
-5.76375574538582365885E-4,
1.63947561694133579842E-3,
-4.32430999505057594430E-3,
1.05464603945949983183E-2,
-2.37374148058994688156E-2,
4.93052842396707084878E-2,
-9.49010970480476444210E-2,
1.71620901522208775349E-1,
-3.04682672343198398683E-1,
6.76795274409476084995E-1,
];
/// Chebyshev coefficients passed to `chbevl` by `i0_f64` / `i0e_f64` for
/// `|x| > 8`, with series argument `32 / |x| - 2`; the series value is then
/// divided by `sqrt(|x|)`.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes i0e_B coefficients (cuda/Math.cuh:539-552)"
)]
const I0E_B: [f64; 25] = [
-7.23318048787475395456E-18,
-4.83050448594418207126E-18,
4.46562142029675999901E-17,
3.46122286769746109310E-17,
-2.82762398051658348494E-16,
-3.42548561967721913462E-16,
1.77256013305652638360E-15,
3.81168066935262242075E-15,
-9.55484669882830764870E-15,
-4.15056934728722208663E-14,
1.54008621752140982691E-14,
3.85277838274214270114E-13,
7.18012445138366623367E-13,
-1.79417853150680611778E-12,
-1.32158118404477131188E-11,
-3.14991652796324136454E-11,
1.18891471078464383424E-11,
4.94060238822496958910E-10,
3.39623202570838634515E-9,
2.26666899049817806459E-8,
2.04891858946906374183E-7,
2.89137052083475648297E-6,
6.88975834691682398426E-5,
3.36911647825569408990E-3,
8.04490411014108831608E-1,
];
/// Chebyshev coefficients passed to `chbevl` by `i1_f64` / `i1e_f64` for
/// `|x| <= 8`, with series argument `|x| / 2 - 2`; the series value is then
/// multiplied by `|x|`.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes i1e_A coefficients (cuda/Math.cuh:582-597)"
)]
const I1E_A: [f64; 29] = [
2.77791411276104639959E-18,
-2.11142121435816608115E-17,
1.55363195773620046921E-16,
-1.10559694773538630805E-15,
7.60068429473540693410E-15,
-5.04218550472791168711E-14,
3.22379336594557470981E-13,
-1.98397439776494371520E-12,
1.17361862988909016308E-11,
-6.66348972350202774223E-11,
3.62559028155211703701E-10,
-1.88724975172282928790E-9,
9.38153738649577178388E-9,
-4.44505912879632808065E-8,
2.00329475355213526229E-7,
-8.56872026469545474066E-7,
3.47025130813767847674E-6,
-1.32731636560394358279E-5,
4.78156510755005422638E-5,
-1.61760815825896745588E-4,
5.12285956168575772895E-4,
-1.51357245063125314899E-3,
4.15642294431288815669E-3,
-1.05640848946261981558E-2,
2.47264490306265168283E-2,
-5.29459812080949914269E-2,
1.02643658689847095384E-1,
-1.76416518357834055153E-1,
2.52587186443633654823E-1,
];
/// Chebyshev coefficients passed to `chbevl` by `i1_f64` / `i1e_f64` for
/// `|x| > 8`, with series argument `32 / |x| - 2`; the series value is then
/// divided by `sqrt(|x|)`.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes i1e_B coefficients (cuda/Math.cuh:606-619)"
)]
const I1E_B: [f64; 25] = [
7.51729631084210481353E-18,
4.41434832307170791151E-18,
-4.65030536848935832153E-17,
-3.20952592199342395980E-17,
2.96262899764595013876E-16,
3.30820231092092828324E-16,
-1.88035477551078244854E-15,
-3.81440307243700780478E-15,
1.04202769841288027642E-14,
4.27244001671195135429E-14,
-2.10154184277266431302E-14,
-4.08355111109219731823E-13,
-7.19855177624590851209E-13,
2.03562854414708950722E-12,
1.41258074366137813316E-11,
3.25260358301548823856E-11,
-1.89749581235054123450E-11,
-5.58974346219658380687E-10,
-3.83538038596423702205E-9,
-2.63146884688951950684E-8,
-2.51223623787020892529E-7,
-3.88256480887769039346E-6,
-1.10588938762623716291E-4,
-9.76109749136146840777E-3,
7.78576235018280120474E-1,
];
/// Modified Bessel function of the first kind, order 0, in double precision.
///
/// Works on `|x_in|`. Up to 8 the `I0E_A` Chebyshev series is used; above 8
/// the `I0E_B` series divided by `sqrt(|x|)`. Both are multiplied by
/// `exp(|x|)` to undo the exponential scaling of the series.
fn i0_f64(x_in: f64) -> f64 {
    let magnitude = x_in.abs();
    if magnitude <= 8.0 {
        let arg = (magnitude / 2.0) - 2.0;
        return magnitude.exp() * chbevl(arg, &I0E_A);
    }
    let arg = 32.0 / magnitude - 2.0;
    (magnitude.exp() * chbevl(arg, &I0E_B)) / magnitude.sqrt()
}
/// Exponentially scaled modified Bessel function of the first kind, order 0.
///
/// Same series as `i0_f64` but without the `exp(|x|)` factor: `I0E_A` for
/// `|x| <= 8`, otherwise `I0E_B` divided by `sqrt(|x|)`.
fn i0e_f64(x_in: f64) -> f64 {
    let magnitude = x_in.abs();
    if magnitude <= 8.0 {
        let arg = (magnitude / 2.0) - 2.0;
        return chbevl(arg, &I0E_A);
    }
    let arg = 32.0 / magnitude - 2.0;
    chbevl(arg, &I0E_B) / magnitude.sqrt()
}
/// Modified Bessel function of the first kind, order 1, in double precision.
///
/// Evaluated on `|x_in|` (`I1E_A` series times `|x|` up to 8, `I1E_B` series
/// over `sqrt(|x|)` above), multiplied by `exp(|x|)`, and negated for negative
/// inputs because the function is odd.
fn i1_f64(x_in: f64) -> f64 {
    let magnitude = x_in.abs();
    let unsigned = if magnitude <= 8.0 {
        let arg = magnitude / 2.0 - 2.0;
        magnitude.exp() * magnitude * chbevl(arg, &I1E_A)
    } else {
        let arg = 32.0 / magnitude - 2.0;
        (magnitude.exp() * chbevl(arg, &I1E_B)) / magnitude.sqrt()
    };
    if x_in < 0.0 {
        return -unsigned;
    }
    unsigned
}
/// Exponentially scaled modified Bessel function of the first kind, order 1.
///
/// Same series as `i1_f64` but without the `exp(|x|)` factor; the result is
/// negated for negative inputs.
fn i1e_f64(x_in: f64) -> f64 {
    let magnitude = x_in.abs();
    let unsigned = if magnitude <= 8.0 {
        let arg = magnitude / 2.0 - 2.0;
        chbevl(arg, &I1E_A) * magnitude
    } else {
        let arg = 32.0 / magnitude - 2.0;
        chbevl(arg, &I1E_B) / magnitude.sqrt()
    };
    if x_in < 0.0 {
        return -unsigned;
    }
    unsigned
}
/// Generic entry point for `i0_f64`: widens `x` to `f64` (NaN if the
/// conversion fails), evaluates, and narrows back to `T` (NaN if that fails).
fn i0_scalar<T: Float>(x: T) -> T {
    let widened = match <T as num_traits::ToPrimitive>::to_f64(&x) {
        Some(value) => value,
        None => f64::NAN,
    };
    match T::from(i0_f64(widened)) {
        Some(narrowed) => narrowed,
        None => T::from(f64::NAN).unwrap(),
    }
}
/// Generic entry point for `i0e_f64`: widens `x` to `f64` (NaN if the
/// conversion fails), evaluates, and narrows back to `T` (NaN if that fails).
fn i0e_scalar<T: Float>(x: T) -> T {
    let widened = match <T as num_traits::ToPrimitive>::to_f64(&x) {
        Some(value) => value,
        None => f64::NAN,
    };
    match T::from(i0e_f64(widened)) {
        Some(narrowed) => narrowed,
        None => T::from(f64::NAN).unwrap(),
    }
}
/// Generic entry point for `i1_f64`: widens `x` to `f64` (NaN if the
/// conversion fails), evaluates, and narrows back to `T` (NaN if that fails).
fn i1_scalar<T: Float>(x: T) -> T {
    let widened = match <T as num_traits::ToPrimitive>::to_f64(&x) {
        Some(value) => value,
        None => f64::NAN,
    };
    match T::from(i1_f64(widened)) {
        Some(narrowed) => narrowed,
        None => T::from(f64::NAN).unwrap(),
    }
}
/// Generic entry point for `i1e_f64`: widens `x` to `f64` (NaN if the
/// conversion fails), evaluates, and narrows back to `T` (NaN if that fails).
fn i1e_scalar<T: Float>(x: T) -> T {
    let widened = match <T as num_traits::ToPrimitive>::to_f64(&x) {
        Some(value) => value,
        None => f64::NAN,
    };
    match T::from(i1e_f64(widened)) {
        Some(narrowed) => narrowed,
        None => T::from(f64::NAN).unwrap(),
    }
}
/// Spherical Bessel function of the first kind, order 0: `sin(x) / x`.
///
/// Infinite inputs return 0. For `|x| < 0.5` the Maclaurin series in `x^2`
/// (through the `x^12` term) is used, which also covers `x == 0`; otherwise
/// the quotient `sin(x) / x` is evaluated directly.
fn spherical_bessel_j0_f64(x: f64) -> f64 {
    // Series coefficients of sin(x)/x in powers of x^2, highest order first;
    // the leading constant term 1 is added after the Horner loop.
    const SERIES_DESC: [f64; 6] = [
        1.0 / 6227020800.0,
        -1.0 / 39916800.0,
        1.0 / 362880.0,
        -1.0 / 5040.0,
        1.0 / 120.0,
        -1.0 / 6.0,
    ];

    if x.is_infinite() {
        return 0.0;
    }
    if x.abs() < 0.5 {
        let sq = x * x;
        let mut nested = SERIES_DESC[0];
        for &coeff in &SERIES_DESC[1..] {
            nested = coeff + sq * nested;
        }
        return 1.0 + sq * nested;
    }
    x.sin() / x
}
/// Generic entry point for `spherical_bessel_j0_f64`: widens `x` to `f64`
/// (NaN if the conversion fails), evaluates, and narrows back to `T` (NaN if
/// that fails).
fn spherical_bessel_j0_scalar<T: Float>(x: T) -> T {
    let widened = match <T as num_traits::ToPrimitive>::to_f64(&x) {
        Some(value) => value,
        None => f64::NAN,
    };
    match T::from(spherical_bessel_j0_f64(widened)) {
        Some(narrowed) => narrowed,
        None => T::from(f64::NAN).unwrap(),
    }
}
/// Chebyshev coefficients passed to `chbevl` by the K0 kernels for
/// `0 < x <= 2`, with series argument `x * x - 2`.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes K0 A-set (cuda/Math.cuh:2504-2515); full-width for audit-friendly diff"
)]
const K0_A: [f64; 10] = [
1.37446543561352307156e-16,
4.25981614279661018399e-14,
1.03496952576338420167e-11,
1.90451637722020886025e-09,
2.53479107902614945675e-07,
2.28621210311945178607e-05,
1.26461541144692592338e-03,
3.59799365153615016266e-02,
3.44289899924628486886e-01,
-5.35327393233902768720e-01,
];
/// Chebyshev coefficients passed to `chbevl` by the K0 kernels for `x > 2`,
/// with series argument `8 / x - 2`; the series value is then divided by
/// `sqrt(x)`.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes K0 B-set (cuda/Math.cuh:2517-2543)"
)]
const K0_B: [f64; 25] = [
5.30043377268626276149e-18,
-1.64758043015242134646e-17,
5.21039150503902756861e-17,
-1.67823109680541210385e-16,
5.51205597852431940784e-16,
-1.84859337734377901440e-15,
6.34007647740507060557e-15,
-2.22751332699166985548e-14,
8.03289077536357521100e-14,
-2.98009692317273043925e-13,
1.14034058820847496303e-12,
-4.51459788337394416547e-12,
1.85594911495471785253e-11,
-7.95748924447710747776e-11,
3.57739728140030116597e-10,
-1.69753450938905987466e-09,
8.57403401741422608519e-09,
-4.66048989768794782956e-08,
2.76681363944501510342e-07,
-1.83175552271911948767e-06,
1.39498137188764993662e-05,
-1.28495495816278026384e-04,
1.56988388573005337491e-03,
-3.14481013119645005427e-02,
2.44030308206595545468e+00,
];
/// Chebyshev coefficients passed to `chbevl` by the K1 kernels for
/// `0 < x <= 2`, with series argument `x * x - 2`; the series value is then
/// divided by `x`.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes K1 A-set (cuda/Math.cuh:2662-2673)"
)]
const K1_A: [f64; 11] = [
-7.02386347938628759343e-18,
-2.42744985051936593393e-15,
-6.66690169419932900609e-13,
-1.41148839263352776110e-10,
-2.21338763073472585583e-08,
-2.43340614156596823496e-06,
-1.73028895751305206302e-04,
-6.97572385963986435018e-03,
-1.22611180822657148235e-01,
-3.53155960776544875667e-01,
1.52530022733894777053e+00,
];
/// Chebyshev coefficients passed to `chbevl` by the K1 kernels for `x > 2`,
/// with series argument `8 / x - 2`; the series value is then divided by
/// `sqrt(x)`.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes K1 B-set (cuda/Math.cuh:2676-2702)"
)]
const K1_B: [f64; 25] = [
-5.75674448366501715755e-18,
1.79405087314755922667e-17,
-5.68946255844285935196e-17,
1.83809354436663880070e-16,
-6.05704724837331885336e-16,
2.03870316562433424052e-15,
-7.01983709041831346144e-15,
2.47715442448130437068e-14,
-8.97670518232499435011e-14,
3.34841966607842919884e-13,
-1.28917396095102890680e-12,
5.13963967348173025100e-12,
-2.12996783842756842877e-11,
9.21831518760500529508e-11,
-4.19035475934189648750e-10,
2.01504975519703286596e-09,
-1.03457624656780970260e-08,
5.74108412545004946722e-08,
-3.50196060308781257119e-07,
2.40648494783721712015e-06,
-1.93619797416608296024e-05,
1.95215518471351631108e-04,
-2.85781685962277938680e-03,
1.03923736576817238437e-01,
2.72062619048444266945e+00,
];
/// Modified Bessel function of the second kind, order 0, in double precision.
///
/// Returns `+inf` at `x == 0` and NaN for negative `x`. For `x <= 2` the
/// `K0_A` series is combined with `-ln(x/2) * I0(x)`; for larger `x` the
/// `K0_B` series is scaled by `exp(-x) / sqrt(x)`.
fn modified_bessel_k0_f64(x: f64) -> f64 {
    if x < 0.0 {
        return f64::NAN;
    }
    if x == 0.0 {
        return f64::INFINITY;
    }
    if x <= 2.0 {
        let series = chbevl(x * x - 2.0, &K0_A);
        return series - (0.5 * x).ln() * i0_f64(x);
    }
    let series = chbevl(8.0 / x - 2.0, &K0_B);
    (-x).exp() * series / x.sqrt()
}
/// Exponentially scaled K0, i.e. `exp(x) * K0(x)`, in double precision.
///
/// Returns `+inf` at `x == 0` and NaN for negative `x`. For `x <= 2` the
/// unscaled small-argument expression is multiplied by `exp(x)`; for larger
/// `x` the `K0_B` series over `sqrt(x)` is returned without any exponential.
fn scaled_modified_bessel_k0_f64(x: f64) -> f64 {
    if x < 0.0 {
        return f64::NAN;
    }
    if x == 0.0 {
        return f64::INFINITY;
    }
    if x <= 2.0 {
        let series = chbevl(x * x - 2.0, &K0_A);
        let unscaled = series - (0.5 * x).ln() * i0_f64(x);
        return unscaled * x.exp();
    }
    chbevl(8.0 / x - 2.0, &K0_B) / x.sqrt()
}
/// Modified Bessel function of the second kind, order 1, in double precision.
///
/// Returns `+inf` at `x == 0` and NaN for negative `x`. For `x <= 2` the
/// result is `ln(x/2) * I1(x)` plus the `K1_A` series over `x`; for larger
/// `x` the `K1_B` series is scaled by `exp(-x) / sqrt(x)`.
fn modified_bessel_k1_f64(x: f64) -> f64 {
    if x < 0.0 {
        return f64::NAN;
    }
    if x == 0.0 {
        return f64::INFINITY;
    }
    if x <= 2.0 {
        let log_part = (0.5 * x).ln() * i1_f64(x);
        return log_part + chbevl(x * x - 2.0, &K1_A) / x;
    }
    let series = chbevl(8.0 / x - 2.0, &K1_B);
    (-x).exp() * series / x.sqrt()
}
/// Exponentially scaled K1, i.e. `exp(x) * K1(x)`, in double precision.
///
/// Returns `+inf` at `x == 0` and NaN for negative `x`. For `x <= 2` the
/// unscaled small-argument expression is multiplied by `exp(x)`; for larger
/// `x` the `K1_B` series over `sqrt(x)` is returned without any exponential.
fn scaled_modified_bessel_k1_f64(x: f64) -> f64 {
    if x < 0.0 {
        return f64::NAN;
    }
    if x == 0.0 {
        return f64::INFINITY;
    }
    if x <= 2.0 {
        let log_part = (0.5 * x).ln() * i1_f64(x);
        let unscaled = log_part + chbevl(x * x - 2.0, &K1_A) / x;
        return unscaled * x.exp();
    }
    chbevl(8.0 / x - 2.0, &K1_B) / x.sqrt()
}
/// Generic entry point for `modified_bessel_k0_f64`: widens `x` to `f64`
/// (NaN if the conversion fails), evaluates, and narrows back to `T` (NaN if
/// that fails).
fn modified_bessel_k0_scalar<T: Float>(x: T) -> T {
    let widened = match <T as num_traits::ToPrimitive>::to_f64(&x) {
        Some(value) => value,
        None => f64::NAN,
    };
    match T::from(modified_bessel_k0_f64(widened)) {
        Some(narrowed) => narrowed,
        None => T::from(f64::NAN).unwrap(),
    }
}
/// Generic entry point for `scaled_modified_bessel_k0_f64`: widens `x` to
/// `f64` (NaN if the conversion fails), evaluates, and narrows back to `T`
/// (NaN if that fails).
fn scaled_modified_bessel_k0_scalar<T: Float>(x: T) -> T {
    let widened = match <T as num_traits::ToPrimitive>::to_f64(&x) {
        Some(value) => value,
        None => f64::NAN,
    };
    match T::from(scaled_modified_bessel_k0_f64(widened)) {
        Some(narrowed) => narrowed,
        None => T::from(f64::NAN).unwrap(),
    }
}
/// Generic entry point for `modified_bessel_k1_f64`: widens `x` to `f64`
/// (NaN if the conversion fails), evaluates, and narrows back to `T` (NaN if
/// that fails).
fn modified_bessel_k1_scalar<T: Float>(x: T) -> T {
    let widened = match <T as num_traits::ToPrimitive>::to_f64(&x) {
        Some(value) => value,
        None => f64::NAN,
    };
    match T::from(modified_bessel_k1_f64(widened)) {
        Some(narrowed) => narrowed,
        None => T::from(f64::NAN).unwrap(),
    }
}
/// Generic entry point for `scaled_modified_bessel_k1_f64`: widens `x` to
/// `f64` (NaN if the conversion fails), evaluates, and narrows back to `T`
/// (NaN if that fails).
fn scaled_modified_bessel_k1_scalar<T: Float>(x: T) -> T {
    let widened = match <T as num_traits::ToPrimitive>::to_f64(&x) {
        Some(value) => value,
        None => f64::NAN,
    };
    match T::from(scaled_modified_bessel_k1_f64(widened)) {
        Some(narrowed) => narrowed,
        None => T::from(f64::NAN).unwrap(),
    }
}
/// Relative convergence threshold shared by `zeta_f64` and `airy_ai_f64`;
/// the value is 2^-53.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes MACHEP (cuda/Math.cuh:302); full-width for audit-friendly diff"
)]
const CEPHES_MACHEP: f64 = 1.11022302462515654042E-16;
/// Divisors of the successive correction terms in the tail expansion of
/// `zeta_f64` (each term is `a * b / ZETA_A[i]`).
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes zeta A-set (cuda/Math.cuh:306-319); full-width for audit-friendly diff"
)]
const ZETA_A: [f64; 12] = [
12.0,
-720.0,
30240.0,
-1209600.0,
47900160.0,
-1.8924375803183791606e9,
7.47242496e10,
-2.950130727918164224e12,
1.1646782814350067249e14,
-4.5979787224074726105e15,
1.8152105401943546773e17,
-7.1661652561756670113e18,
];
/// Hurwitz zeta function: the sum over k >= 0 of `(k + q)^(-x)`, in double
/// precision.
///
/// Edge cases, in the order they are tested:
/// * `x == 1` returns `+inf`; `x < 1` returns NaN;
/// * for `q <= 0`: an integer `q` returns `+inf`, and a non-integer `x`
///   returns NaN.
///
/// Otherwise terms are summed directly (at least nine of them, and until the
/// running base exceeds 9), returning early if a term drops below
/// `CEPHES_MACHEP` relative to the sum; the remainder is then estimated with
/// the `ZETA_A` correction terms.
#[allow(
clippy::float_cmp,
reason = "verbatim Cephes edge ladder: `x == 1` and the `q == floor(q)` / `x != floor(x)` integer tests are exact-equality branches in upstream (cuda/Math.cuh:325, 337, 340); R-DEV-1 byte-match"
)]
fn zeta_f64(x: f64, q: f64) -> f64 {
const ZERO: f64 = 0.0;
const HALF: f64 = 0.5;
const ONE: f64 = 1.0;
if x == ONE {
return f64::INFINITY;
}
if x < ONE {
return f64::NAN;
}
if q <= ZERO {
if q == q.floor() {
return f64::INFINITY;
}
if x != x.floor() {
return f64::NAN;
}
}
// Direct summation: s accumulates (q + k)^(-x), a is the current base,
// b is the most recent term.
let mut s = q.powf(-x);
let mut a = q;
let mut i: i32 = 0;
let mut b = ZERO;
while (i < 9) || (a <= 9.0) {
i += 1;
a += ONE;
b = a.powf(-x);
s += b;
// Stop once the latest term is negligible relative to the sum.
if (-CEPHES_MACHEP * s < b) && (b < CEPHES_MACHEP * s) {
return s;
}
}
// Tail estimate: integral term, half of the last term, then the correction
// series. `w` is the last base used by the direct sum.
let w = a;
s += b * w / (x - ONE);
s -= HALF * b;
a = ONE;
let mut k = ZERO;
for &coeff in &ZETA_A {
a *= x + k;
b /= w;
let mut t = a * b / coeff;
s += t;
t = (t / s).abs();
if t < CEPHES_MACHEP {
return s;
}
// Advance `a` and `b` by one more factor before the next correction term.
k += ONE;
a *= x + k;
b /= w;
k += ONE;
}
s
}
/// Generic entry point for `zeta_f64`: widens both arguments to `f64`
/// (NaN for any that fails to convert), evaluates, and narrows the result
/// back to `T` (NaN if that fails).
fn zeta_scalar<T: Float>(x: T, q: T) -> T {
    let widen = |value: &T| match <T as num_traits::ToPrimitive>::to_f64(value) {
        Some(wide) => wide,
        None => f64::NAN,
    };
    let x_wide = widen(&x);
    let q_wide = widen(&q);
    match T::from(zeta_f64(x_wide, q_wide)) {
        Some(narrowed) => narrowed,
        None => T::from(f64::NAN).unwrap(),
    }
}
/// Numerator coefficients for the `x >= 2.09` branch of `airy_ai_f64`,
/// evaluated by Horner's rule in `1 / zeta` with `zeta = 2 x sqrt(x) / 3`.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes airy AN-set (cuda/Math.cuh:1283-1292)"
)]
const AIRY_AN: [f64; 8] = [
3.46538101525629032477e-01,
1.20075952739645805542e+01,
7.62796053615234516538e+01,
1.68089224934630576269e+02,
1.59756391350164413639e+02,
7.05360906840444183113e+01,
1.40264691163389668864e+01,
9.99999999999999995305e-01,
];
/// Denominator coefficients paired with `AIRY_AN` in the `x >= 2.09` branch
/// of `airy_ai_f64`.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes airy AD-set (cuda/Math.cuh:1294-1303)"
)]
const AIRY_AD: [f64; 8] = [
5.67594532638770212846e-01,
1.47562562584847203173e+01,
8.45138970141474626562e+01,
1.77318088145400459522e+02,
1.64234692871529701831e+02,
7.14778400825575695274e+01,
1.40959135607834029598e+01,
1.00000000000000000470e+00,
];
/// Numerator coefficients of the sine-term correction in the `x < -2.09`
/// branch of `airy_ai_f64`, evaluated by Horner's rule in `z * z`.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes airy AFN-set (cuda/Math.cuh:1305-1315)"
)]
const AIRY_AFN: [f64; 9] = [
-1.31696323418331795333e-01,
-6.26456544431912369773e-01,
-6.93158036036933542233e-01,
-2.79779981545119124951e-01,
-4.91900132609500318020e-02,
-4.06265923594885404393e-03,
-1.59276496239262096340e-04,
-2.77649108155232920844e-06,
-1.67787698489114633780e-08,
];
/// Denominator coefficients paired with `AIRY_AFN` in the `x < -2.09` branch
/// of `airy_ai_f64`.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes airy AFD-set (cuda/Math.cuh:1317-1327)"
)]
const AIRY_AFD: [f64; 9] = [
1.33560420706553243746e+01,
3.26825032795224613948e+01,
2.67367040941499554804e+01,
9.18707402907259625840e+00,
1.47529146771666414581e+00,
1.15687173795188044134e-01,
4.40291641615211203805e-03,
7.54720348287414296618e-05,
4.51850092970580378464e-07,
];
/// Numerator coefficients of the cosine-term correction in the `x < -2.09`
/// branch of `airy_ai_f64`, evaluated by Horner's rule in `z * z`.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes airy AGN-set (cuda/Math.cuh:1329-1341)"
)]
const AIRY_AGN: [f64; 11] = [
1.97339932091685679179e-02,
3.91103029615688277255e-01,
1.06579897599595591108e+00,
9.39169229816650230044e-01,
3.51465656105547619242e-01,
6.33888919628925490927e-02,
5.85804113048388458567e-03,
2.82851600836737019778e-04,
6.98793669997260967291e-06,
8.11789239554389293311e-08,
3.41551784765923618484e-10,
];
/// Denominator coefficients paired with `AIRY_AGN` in the `x < -2.09` branch
/// of `airy_ai_f64`.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes airy AGD-set (cuda/Math.cuh:1343-1354)"
)]
const AIRY_AGD: [f64; 10] = [
9.30892908077441974853e+00,
1.98352928718312140417e+01,
1.55646628932864612953e+01,
5.47686069422975497931e+00,
9.54293611618961883998e-01,
8.64580826352392193095e-02,
4.12656523824222607191e-03,
1.01259085116509135510e-04,
1.17166733214413521882e-06,
4.91834570062930015649e-09,
];
/// Airy function Ai(x) in double precision.
///
/// * Infinite inputs return NaN; `x > 103.892` returns 0.
/// * `x < -2.09`: oscillatory asymptotic form built from the
///   `AIRY_AFN/AFD/AGN/AGD` rational corrections.
/// * `x >= 2.09`: decaying asymptotic form built from `AIRY_AN/AD`; returned
///   directly when `x > 8.3203353`.
/// * Otherwise the result comes from the power series in `x^3`.
///
/// NOTE(review): the numeric value 5.64189583547756286948e-01 equals
/// 1/sqrt(pi), not 1/(2*sqrt(pi)) as the `reason` text below says; the extra
/// factor of 2 appears explicitly in the `x >= 2.09` denominator.
#[allow(
clippy::excessive_precision,
reason = "verbatim Cephes airy magic constants 5.64189583547756286948e-01 (1/(2*sqrt(pi))), 0.355028053887817239260 (Ai(0)), 0.258819403792806798405 (-Ai'(0)) from cuda/Math.cuh:1399,1401,1421,1454; full-width for audit-friendly diff"
)]
fn airy_ai_f64(x: f64) -> f64 {
if x.is_infinite() {
return f64::NAN;
}
if x > 103.892 {
return 0.0;
}
// Bit 0 of `domain_flag` records that `ai` was already computed by the
// large-positive-x branch.
let mut domain_flag: i32 = 0;
let mut ai = 0.0;
// Large negative x: oscillatory asymptotic expansion.
if x < -2.09 {
let z = 1.0 / (-2.0 * x * (-x).sqrt() / 3.0);
let z2 = z * z;
let mut afn = 0.0;
for &c in &AIRY_AFN {
afn = afn * z2 + c;
}
let mut afd = 0.0;
for &c in &AIRY_AFD {
afd = afd * z2 + c;
}
let mut agn = 0.0;
for &c in &AIRY_AGN {
agn = agn * z2 + c;
}
let mut agd = 0.0;
for &c in &AIRY_AGD {
agd = agd * z2 + c;
}
// Phase: (2/3) * (-x)^(3/2) + pi/4.
let t = -2.0 * x * (-x).sqrt() / 3.0 + 0.25 * std::f64::consts::PI;
return 5.64189583547756286948e-01 / (-x).sqrt().sqrt()
* (t.sin() * (1.0 + z2 * afn / afd) - t.cos() * (z * agn / agd));
}
// Large positive x: exponentially decaying asymptotic expansion.
if x >= 2.09 {
domain_flag = 5;
let zeta = 2.0 * x * x.sqrt() / 3.0;
let mut an = 0.0;
for &c in &AIRY_AN {
an = an * (1.0 / zeta) + c;
}
let mut ad = 0.0;
for &c in &AIRY_AD {
ad = ad * (1.0 / zeta) + c;
}
ai = 5.64189583547756286948e-01 * (an / ad) / (2.0 * x.sqrt().sqrt() * zeta.exp());
if x > 8.3203353 {
return ai;
}
}
// Power series in z = x^3: `f` and `g` are the two series whose combination
// gives Ai(x). For 2.09 <= x <= 8.3203353 the loop still runs, but `ai`
// from the branch above is what gets returned.
let mut f = 1.0;
let mut g = x;
let mut k = 1.0;
let mut m = 1.0;
let mut n = x;
let mut t = 1.0;
let z = x * x * x;
while t > CEPHES_MACHEP {
m *= z;
k += 1.0;
m /= k;
n *= z;
k += 1.0;
n /= k;
m /= k;
f += m;
k += 1.0;
n /= k;
g += n;
t = (m / f).abs();
}
if (domain_flag & 1) == 0 {
return 0.355028053887817239260 * f - 0.258819403792806798405 * g;
}
ai
}
/// Evaluates `Ai(x)` for one element of type `T` by round-tripping through
/// `f64` and `airy_ai_f64`.
///
/// A value that cannot be widened to `f64` is treated as NaN. If the result
/// cannot be narrowed back to `T`, `T::nan()` is returned; this avoids the
/// second `T::from(..).unwrap()` that could itself panic.
fn airy_ai_scalar<T: Float>(x: T) -> T {
    let xf = <T as num_traits::ToPrimitive>::to_f64(&x).unwrap_or(f64::NAN);
    T::from(airy_ai_f64(xf)).unwrap_or_else(T::nan)
}
/// Core of the regularized lower incomplete gamma function `P(a, x)`.
///
/// For `x < a + 1` the power series is summed directly (at most 300 terms,
/// stopping once a term drops below `1e-15` relative to the running sum).
/// Otherwise the result is `1 - Q(a, x)` with `Q` taken from
/// `gammq_core_f64_cf`. No argument validation happens here.
fn gammp_core_f64(a: f64, x: f64) -> f64 {
    let log_gamma_a = lgamma_scalar(a);
    if x < a + 1.0 {
        let mut denom = a;
        let mut term = 1.0 / a;
        let mut total = term;
        for _step in 0..300 {
            denom += 1.0;
            term *= x / denom;
            total += term;
            if term.abs() < total.abs() * 1e-15 {
                break;
            }
        }
        let prefactor = (-x + a * x.ln() - log_gamma_a).exp();
        total * prefactor
    } else {
        let upper = gammq_core_f64_cf(a, x, log_gamma_a);
        1.0 - upper
    }
}
/// Continued-fraction evaluation of the regularized upper incomplete gamma
/// function `Q(a, x)`.
///
/// `gln` is the caller-supplied `ln Γ(a)`. The fraction is evaluated with a
/// modified-Lentz style iteration: intermediate values whose magnitude falls
/// below `1e-300` are clamped to that floor, and the loop stops after 299
/// steps or once the per-step factor is within `1e-15` of one.
fn gammq_core_f64_cf(a: f64, x: f64, gln: f64) -> f64 {
    const FLOOR: f64 = 1e-300;
    // Keeps NaN untouched because `NaN < FLOOR` is false.
    let guard = |v: f64| if v.abs() < FLOOR { FLOOR } else { v };

    let mut b = x + 1.0 - a;
    let mut c = 1.0 / FLOOR;
    let mut d = 1.0 / b;
    let mut h = d;
    for step in 1..300_u32 {
        let i = f64::from(step);
        let an = -i * (i - a);
        b += 2.0;
        d = 1.0 / guard(an * d + b);
        c = guard(b + an / c);
        let factor = d * c;
        h *= factor;
        if (factor - 1.0).abs() < 1e-15 {
            break;
        }
    }
    (-x + a * x.ln() - gln).exp() * h
}
/// Core of the regularized upper incomplete gamma function `Q(a, x)`.
///
/// For `x < a + 1` the value is `1 - P(a, x)` via `gammp_core_f64`; otherwise
/// the continued fraction in `gammq_core_f64_cf` is used directly.
///
/// `ln Γ(a)` is only evaluated on the continued-fraction path.
/// `gammp_core_f64` computes its own copy, so evaluating it up front as well
/// would just be a wasted call on the series path.
fn gammq_core_f64(a: f64, x: f64) -> f64 {
    if x < a + 1.0 {
        1.0 - gammp_core_f64(a, x)
    } else {
        gammq_core_f64_cf(a, x, lgamma_scalar(a))
    }
}
/// Evaluates the regularized lower incomplete gamma function for one `(a, x)`
/// pair of type `T` by delegating to `calc_igamma_f64` in `f64`.
///
/// Inputs that cannot be widened to `f64` are treated as NaN. If the result
/// cannot be narrowed back to `T`, `T::nan()` is returned rather than going
/// through a second conversion that could panic.
fn gammainc_scalar<T: Float>(a: T, x: T) -> T {
    let af = <T as num_traits::ToPrimitive>::to_f64(&a).unwrap_or(f64::NAN);
    let xf = <T as num_traits::ToPrimitive>::to_f64(&x).unwrap_or(f64::NAN);
    T::from(calc_igamma_f64(af, xf)).unwrap_or_else(T::nan)
}
/// Regularized lower incomplete gamma function `P(a, x)` with explicit
/// handling of the special arguments before `gammp_core_f64` is consulted.
///
/// * NaN in either argument, or any negative argument: NaN.
/// * `a == 0`: `1.0` when `x > 0`, NaN otherwise.
/// * `x == 0` (with `a > 0`): `0.0`.
/// * `a` infinite: NaN when `x` is infinite too, else `0.0`.
/// * `x` infinite (with finite `a`): `1.0`.
fn calc_igamma_f64(a: f64, x: f64) -> f64 {
    if a.is_nan() || x.is_nan() || x < 0.0 || a < 0.0 {
        return f64::NAN;
    }
    if a == 0.0 {
        return if x > 0.0 { 1.0 } else { f64::NAN };
    }
    if x == 0.0 {
        return 0.0;
    }
    match (a.is_infinite(), x.is_infinite()) {
        (true, true) => f64::NAN,
        (true, false) => 0.0,
        (false, true) => 1.0,
        (false, false) => gammp_core_f64(a, x),
    }
}
/// Evaluates the regularized upper incomplete gamma function for one `(a, x)`
/// pair of type `T` by delegating to `calc_igammac_f64` in `f64`.
///
/// Inputs that cannot be widened to `f64` are treated as NaN. If the result
/// cannot be narrowed back to `T`, `T::nan()` is returned rather than going
/// through a second conversion that could panic.
fn gammaincc_scalar<T: Float>(a: T, x: T) -> T {
    let af = <T as num_traits::ToPrimitive>::to_f64(&a).unwrap_or(f64::NAN);
    let xf = <T as num_traits::ToPrimitive>::to_f64(&x).unwrap_or(f64::NAN);
    T::from(calc_igammac_f64(af, xf)).unwrap_or_else(T::nan)
}
/// Regularized upper incomplete gamma function `Q(a, x)` with explicit
/// handling of the special arguments before `gammq_core_f64` is consulted.
///
/// * NaN in either argument, or any negative argument: NaN.
/// * `a == 0`: `0.0` when `x > 0`, NaN otherwise.
/// * `x == 0` (with `a > 0`): `1.0`.
/// * `a` infinite: NaN when `x` is infinite too, else `1.0`.
/// * `x` infinite (with finite `a`): `0.0`.
fn calc_igammac_f64(a: f64, x: f64) -> f64 {
    if a.is_nan() || x.is_nan() || x < 0.0 || a < 0.0 {
        return f64::NAN;
    }
    if a == 0.0 {
        return if x > 0.0 { 0.0 } else { f64::NAN };
    }
    if x == 0.0 {
        return 1.0;
    }
    match (a.is_infinite(), x.is_infinite()) {
        (true, true) => f64::NAN,
        (true, false) => 1.0,
        (false, true) => 0.0,
        (false, false) => gammq_core_f64(a, x),
    }
}
/// Computes `lgamma(a) + lgamma(b) - lgamma(a + b)` in `f64` and converts the
/// result back to `T`.
///
/// Inputs that cannot be widened to `f64` are treated as NaN. If the result
/// cannot be narrowed back to `T`, `T::nan()` is returned rather than going
/// through a second conversion that could panic.
fn log_beta_scalar<T: Float>(a: T, b: T) -> T {
    let af = <T as num_traits::ToPrimitive>::to_f64(&a).unwrap_or(f64::NAN);
    let bf = <T as num_traits::ToPrimitive>::to_f64(&b).unwrap_or(f64::NAN);
    let r = lgamma_scalar(af) + lgamma_scalar(bf) - lgamma_scalar(af + bf);
    T::from(r).unwrap_or_else(T::nan)
}
/// Computes the beta function as `exp(log_beta(a, b))`, evaluated in `f64`.
///
/// Inputs that cannot be widened to `f64` are treated as NaN. If the result
/// cannot be narrowed back to `T`, `T::nan()` is returned rather than going
/// through a second conversion that could panic.
fn beta_scalar<T: Float>(a: T, b: T) -> T {
    let lb = log_beta_scalar::<f64>(
        <T as num_traits::ToPrimitive>::to_f64(&a).unwrap_or(f64::NAN),
        <T as num_traits::ToPrimitive>::to_f64(&b).unwrap_or(f64::NAN),
    );
    // NOTE(review): `exp` of a real number is never negative, so this cannot
    // yield the negative values the beta function takes for some negative
    // non-integer arguments. Presumably callers only pass positive `a`, `b`;
    // confirm, or fold in the signs from `gammaln_sign_scalar`.
    T::from(lb.exp()).unwrap_or_else(T::nan)
}
/// Computes `p (p - 1) / 4 * ln(pi) + sum_{i = 1..=p} lgamma(a + (1 - i) / 2)`
/// in `f64` and converts the result back to `T`.
///
/// An input that cannot be widened to `f64` is treated as NaN. If the result
/// cannot be narrowed back to `T`, `T::nan()` is returned rather than going
/// through a second conversion that could panic. `p` is not validated here.
fn multigammaln_scalar<T: Float>(a: T, p: usize) -> T {
    let af = <T as num_traits::ToPrimitive>::to_f64(&a).unwrap_or(f64::NAN);
    let pf = p as f64;
    let c = (pf * (pf - 1.0) / 4.0) * std::f64::consts::PI.ln();
    // Explicit fold from `0.0` keeps the exact left-to-right summation order.
    let sum = (1..=p).fold(0.0_f64, |acc, i| {
        acc + lgamma_scalar(af + (1.0 - i as f64) / 2.0)
    });
    T::from(c + sum).unwrap_or_else(T::nan)
}
/// Returns the sign of the gamma function at `x` as `+1`, `-1` or NaN.
///
/// * NaN is passed through unchanged.
/// * `x > 0`: `+1`.
/// * negative integers: NaN.
/// * other negative `x`: `+1` when `floor(x)` is even, `-1` when it is odd.
/// * zero: `-1` for `-0.0`, `+1` for `+0.0`.
///
/// The NaN result is produced with `T::nan()`, so no fallible conversion (and
/// no `unwrap`) sits on that path.
fn gammaln_sign_scalar<T: Float>(x: T) -> T {
    let one = nt_one::<T>();
    let zero = nt_zero::<T>();
    if x.is_nan() {
        return x;
    }
    if x > zero {
        return one;
    }
    let xf = <T as num_traits::ToPrimitive>::to_f64(&x).unwrap_or(f64::NAN);
    if xf < 0.0 {
        if xf.fract() == 0.0 {
            return T::nan();
        }
        // NOTE(review): `-inf` reaches this point because its `fract()` is NaN;
        // the saturating cast then makes it report `+1`. Confirm that this is
        // the intended answer for `-inf`.
        let fi = xf.floor() as i64;
        return if fi.rem_euclid(2) == 0 { one } else { -one };
    }
    // Only signed zeros (or a failed widening, which yields `+1`) remain.
    if xf.is_sign_negative() { -one } else { one }
}
/// Elementwise error function: applies `erf_scalar` to every element of
/// `input` through `unary_map`.
///
/// # Errors
/// Propagates any error returned by `unary_map`.
pub fn erf<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
unary_map(input, erf_scalar)
}
/// Elementwise complementary error function: applies `erfc_scalar` to every
/// element of `input` through `unary_map`.
///
/// # Errors
/// Propagates any error returned by `unary_map`.
pub fn erfc<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
unary_map(input, erfc_scalar)
}
/// Elementwise inverse error function: applies `erfinv_scalar` to every
/// element of `input` through `unary_map`.
///
/// # Errors
/// Propagates any error returned by `unary_map`.
pub fn erfinv<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
unary_map(input, erfinv_scalar)
}
/// Elementwise log-gamma: applies `lgamma_scalar` to every element of `input`
/// through `unary_map`.
///
/// # Errors
/// Propagates any error returned by `unary_map`.
pub fn lgamma<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
unary_map(input, lgamma_scalar)
}
/// Elementwise digamma: applies `digamma_scalar` to every element of `input`
/// through `unary_map`.
///
/// # Errors
/// Propagates any error returned by `unary_map`.
pub fn digamma<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
unary_map(input, digamma_scalar)
}
/// Elementwise `ln(1 + x)`, computed with `ln_1p` on each element so small
/// inputs do not lose precision to the explicit `1 + x`.
///
/// # Errors
/// Propagates any error returned by `unary_map`.
pub fn log1p<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
unary_map(input, |x| x.ln_1p())
}
/// Elementwise `exp(x) - 1`, computed with `exp_m1` on each element so small
/// inputs do not lose precision to the explicit subtraction.
///
/// # Errors
/// Propagates any error returned by `unary_map`.
pub fn expm1<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
unary_map(input, |x| x.exp_m1())
}
/// Elementwise `sinc`: applies `sinc_scalar` to every element of `input`
/// through `unary_map`.
///
/// # Errors
/// Propagates any error returned by `unary_map`.
pub fn sinc<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
unary_map(input, sinc_scalar)
}
/// Elementwise `xlogy`: combines `x` and `y` pairwise with `xlogy_scalar`
/// through `binary_map`.
///
/// # Errors
/// Propagates any error returned by `binary_map`.
pub fn xlogy<T: Float>(x: &Tensor<T>, y: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
binary_map(x, y, xlogy_scalar)
}
/// Elementwise `entr`.
///
/// CUDA inputs are handed to the backend's `entr_f32` / `entr_f64` kernels via
/// `special_gpu_simple`; when that helper declines (it returns `None` for
/// non-CUDA tensors) `entr_scalar` is mapped over the input instead.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or from `unary_map`.
pub fn entr<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_result = special_gpu_simple(
        input,
        "entr",
        |backend, buf| backend.entr_f32(buf),
        |backend, buf| backend.entr_f64(buf),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => unary_map(input, entr_scalar),
    }
}
/// Elementwise `ndtr`.
///
/// CUDA inputs are handed to the backend's `ndtr_f32` / `ndtr_f64` kernels via
/// `special_gpu_simple`; when that helper declines (it returns `None` for
/// non-CUDA tensors) `ndtr_scalar` is mapped over the input instead.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or from `unary_map`.
pub fn ndtr<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_result = special_gpu_simple(
        input,
        "ndtr",
        |backend, buf| backend.ndtr_f32(buf),
        |backend, buf| backend.ndtr_f64(buf),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => unary_map(input, ndtr_scalar),
    }
}
/// Elementwise `ndtri`.
///
/// CUDA inputs are handed to the backend's `ndtri_f32` / `ndtri_f64` kernels
/// via `special_gpu_simple`; when that helper declines (it returns `None` for
/// non-CUDA tensors) `ndtri_scalar` is mapped over the input instead.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or from `unary_map`.
pub fn ndtri<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_result = special_gpu_simple(
        input,
        "ndtri",
        |backend, buf| backend.ndtri_f32(buf),
        |backend, buf| backend.ndtri_f64(buf),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => unary_map(input, ndtri_scalar),
    }
}
/// Elementwise `i0`.
///
/// CUDA inputs are handed to the backend's `i0_f32` / `i0_f64` kernels via
/// `special_gpu_simple`; when that helper declines (it returns `None` for
/// non-CUDA tensors) `i0_scalar` is mapped over the input instead.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or from `unary_map`.
pub fn i0<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_result = special_gpu_simple(
        input,
        "i0",
        |backend, buf| backend.i0_f32(buf),
        |backend, buf| backend.i0_f64(buf),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => unary_map(input, i0_scalar),
    }
}
/// Elementwise `i0e`.
///
/// CUDA inputs are handed to the backend's `i0e_f32` / `i0e_f64` kernels via
/// `special_gpu_simple`; when that helper declines (it returns `None` for
/// non-CUDA tensors) `i0e_scalar` is mapped over the input instead.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or from `unary_map`.
pub fn i0e<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_result = special_gpu_simple(
        input,
        "i0e",
        |backend, buf| backend.i0e_f32(buf),
        |backend, buf| backend.i0e_f64(buf),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => unary_map(input, i0e_scalar),
    }
}
/// Elementwise `i1`.
///
/// CUDA inputs are handed to the backend's `i1_f32` / `i1_f64` kernels via
/// `special_gpu_simple`; when that helper declines (it returns `None` for
/// non-CUDA tensors) `i1_scalar` is mapped over the input instead.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or from `unary_map`.
pub fn i1<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_result = special_gpu_simple(
        input,
        "i1",
        |backend, buf| backend.i1_f32(buf),
        |backend, buf| backend.i1_f64(buf),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => unary_map(input, i1_scalar),
    }
}
/// Elementwise `i1e`.
///
/// CUDA inputs are handed to the backend's `i1e_f32` / `i1e_f64` kernels via
/// `special_gpu_simple`; when that helper declines (it returns `None` for
/// non-CUDA tensors) `i1e_scalar` is mapped over the input instead.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or from `unary_map`.
pub fn i1e<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_result = special_gpu_simple(
        input,
        "i1e",
        |backend, buf| backend.i1e_f32(buf),
        |backend, buf| backend.i1e_f64(buf),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => unary_map(input, i1e_scalar),
    }
}
/// Elementwise `spherical_bessel_j0`.
///
/// CUDA inputs go to the backend's `spherical_bessel_j0_f32` / `_f64` kernels
/// via `special_gpu_simple`; when that helper returns `None` the scalar
/// routine `spherical_bessel_j0_scalar` is mapped over the input instead.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or from `unary_map`.
pub fn spherical_bessel_j0<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_result = special_gpu_simple(
        input,
        "spherical_bessel_j0",
        |backend, buf| backend.spherical_bessel_j0_f32(buf),
        |backend, buf| backend.spherical_bessel_j0_f64(buf),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => unary_map(input, spherical_bessel_j0_scalar),
    }
}
/// Elementwise `modified_bessel_k0`.
///
/// CUDA inputs go to the backend's `modified_bessel_k0_f32` / `_f64` kernels
/// via `special_gpu_simple`; when that helper returns `None` the scalar
/// routine `modified_bessel_k0_scalar` is mapped over the input instead.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or from `unary_map`.
pub fn modified_bessel_k0<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_result = special_gpu_simple(
        input,
        "modified_bessel_k0",
        |backend, buf| backend.modified_bessel_k0_f32(buf),
        |backend, buf| backend.modified_bessel_k0_f64(buf),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => unary_map(input, modified_bessel_k0_scalar),
    }
}
/// Elementwise `scaled_modified_bessel_k0`.
///
/// CUDA inputs go to the backend's `scaled_modified_bessel_k0_f32` / `_f64`
/// kernels via `special_gpu_simple`; when that helper returns `None` the
/// scalar routine `scaled_modified_bessel_k0_scalar` is mapped over the input
/// instead.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or from `unary_map`.
pub fn scaled_modified_bessel_k0<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_result = special_gpu_simple(
        input,
        "scaled_modified_bessel_k0",
        |backend, buf| backend.scaled_modified_bessel_k0_f32(buf),
        |backend, buf| backend.scaled_modified_bessel_k0_f64(buf),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => unary_map(input, scaled_modified_bessel_k0_scalar),
    }
}
/// Elementwise `modified_bessel_k1`.
///
/// CUDA inputs go to the backend's `modified_bessel_k1_f32` / `_f64` kernels
/// via `special_gpu_simple`; when that helper returns `None` the scalar
/// routine `modified_bessel_k1_scalar` is mapped over the input instead.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or from `unary_map`.
pub fn modified_bessel_k1<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_result = special_gpu_simple(
        input,
        "modified_bessel_k1",
        |backend, buf| backend.modified_bessel_k1_f32(buf),
        |backend, buf| backend.modified_bessel_k1_f64(buf),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => unary_map(input, modified_bessel_k1_scalar),
    }
}
/// Elementwise `scaled_modified_bessel_k1`.
///
/// CUDA inputs go to the backend's `scaled_modified_bessel_k1_f32` / `_f64`
/// kernels via `special_gpu_simple`; when that helper returns `None` the
/// scalar routine `scaled_modified_bessel_k1_scalar` is mapped over the input
/// instead.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or from `unary_map`.
pub fn scaled_modified_bessel_k1<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_result = special_gpu_simple(
        input,
        "scaled_modified_bessel_k1",
        |backend, buf| backend.scaled_modified_bessel_k1_f32(buf),
        |backend, buf| backend.scaled_modified_bessel_k1_f64(buf),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => unary_map(input, scaled_modified_bessel_k1_scalar),
    }
}
/// Elementwise two-argument `zeta` of `input` and `other`.
///
/// When either operand is on CUDA the pair is handed to the backend's
/// `zeta_f32` / `zeta_f64` kernels via `special_gpu_binary` (which reports an
/// error for mixed devices or differing shapes). When both are off-device,
/// `zeta_scalar` is applied pairwise through `binary_map`.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or from `binary_map`.
pub fn zeta<T: Float>(input: &Tensor<T>, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_result = special_gpu_binary(
        input,
        other,
        "zeta",
        |backend, lhs, rhs| backend.zeta_f32(lhs, rhs),
        |backend, lhs, rhs| backend.zeta_f64(lhs, rhs),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => binary_map(input, other, zeta_scalar),
    }
}
/// Elementwise Airy function `Ai`.
///
/// CUDA inputs go to the backend's `airy_ai_f32` / `airy_ai_f64` kernels via
/// `special_gpu_simple`; when that helper returns `None` the scalar routine
/// `airy_ai_scalar` is mapped over the input instead.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or from `unary_map`.
pub fn airy_ai<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_result = special_gpu_simple(
        input,
        "airy_ai",
        |backend, buf| backend.airy_ai_f32(buf),
        |backend, buf| backend.airy_ai_f64(buf),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => unary_map(input, airy_ai_scalar),
    }
}
/// Elementwise regularized lower incomplete gamma function: `input` supplies
/// `a`, `other` supplies `x`, and each pair goes through `gammainc_scalar`.
///
/// # Errors
/// Propagates any error returned by `binary_map`.
pub fn gammainc<T: Float>(input: &Tensor<T>, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
binary_map(input, other, gammainc_scalar)
}
/// Elementwise regularized upper incomplete gamma function: `input` supplies
/// `a`, `other` supplies `x`, and each pair goes through `gammaincc_scalar`.
///
/// # Errors
/// Propagates any error returned by `binary_map`.
pub fn gammaincc<T: Float>(input: &Tensor<T>, other: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
binary_map(input, other, gammaincc_scalar)
}
/// Elementwise log-beta: combines `a` and `b` pairwise with `log_beta_scalar`
/// through `binary_map`.
///
/// # Errors
/// Propagates any error returned by `binary_map`.
pub fn log_beta<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
binary_map(a, b, log_beta_scalar)
}
/// Elementwise beta function: combines `a` and `b` pairwise with
/// `beta_scalar` through `binary_map`.
///
/// # Errors
/// Propagates any error returned by `binary_map`.
pub fn beta<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
binary_map(a, b, beta_scalar)
}
/// Elementwise multivariate log-gamma of dimension `p`, computed by mapping
/// `multigammaln_scalar(_, p)` over `input`.
///
/// # Errors
/// Returns `FerrotorchError::InvalidArgument` when `p` is zero, and otherwise
/// propagates any error from `unary_map`.
pub fn multigammaln<T: Float>(input: &Tensor<T>, p: usize) -> FerrotorchResult<Tensor<T>> {
    if p >= 1 {
        unary_map(input, move |elem| multigammaln_scalar(elem, p))
    } else {
        let message = "multigammaln: p has to be greater than or equal to 1".to_string();
        Err(FerrotorchError::InvalidArgument { message })
    }
}
/// Alias for `multigammaln`; forwards both arguments unchanged.
///
/// # Errors
/// Same as `multigammaln`.
pub fn mvlgamma<T: Float>(input: &Tensor<T>, p: usize) -> FerrotorchResult<Tensor<T>> {
multigammaln(input, p)
}
/// Elementwise sign of the gamma function: applies `gammaln_sign_scalar` to
/// every element of `input` through `unary_map`.
///
/// # Errors
/// Propagates any error returned by `unary_map`.
pub fn gammaln_sign<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
unary_map(input, gammaln_sign_scalar)
}
use crate::error::FerrotorchError;
/// Reports whether the element type `T` is exactly `f32`.
#[inline]
fn poly_is_f32<T: Float>() -> bool {
    let single = TypeId::of::<f32>();
    single == TypeId::of::<T>()
}
/// Reports whether the element type `T` is exactly `f64`.
#[inline]
fn poly_is_f64<T: Float>() -> bool {
    let double = TypeId::of::<f64>();
    double == TypeId::of::<T>()
}
/// Largest Hermite order the CPU path evaluates for element type `T`:
/// 128 for `f32`, 512 for `f64`, 1024 for anything else. The public Hermite
/// entry points return an all-NaN tensor above this order.
#[inline]
fn hermitian_limit<T: Float>() -> usize {
    match (poly_is_f32::<T>(), poly_is_f64::<T>()) {
        (true, _) => 128,
        (false, true) => 512,
        (false, false) => 1024,
    }
}
/// Wraps a GPU buffer handle in a tensor of the given `shape`, passing `false`
/// as the final `Tensor::from_storage` argument.
#[inline]
fn poly_gpu_output<T: Float>(
    handle: crate::gpu_dispatch::GpuBufferHandle,
    shape: Vec<usize>,
) -> FerrotorchResult<Tensor<T>> {
    let storage = crate::storage::TensorStorage::gpu(handle);
    Tensor::from_storage(storage, shape, false)
}
fn poly_gpu_simple<T: Float>(
input: &Tensor<T>,
n: usize,
op: &'static str,
f32_call: impl Fn(
&dyn crate::gpu_dispatch::GpuBackend,
&crate::gpu_dispatch::GpuBufferHandle,
) -> FerrotorchResult<crate::gpu_dispatch::GpuBufferHandle>,
f64_call: impl Fn(
&dyn crate::gpu_dispatch::GpuBackend,
&crate::gpu_dispatch::GpuBufferHandle,
) -> FerrotorchResult<crate::gpu_dispatch::GpuBufferHandle>,
) -> FerrotorchResult<Option<Tensor<T>>> {
let _ = n;
if !input.is_cuda() {
return Ok(None);
}
if !(poly_is_f32::<T>() || poly_is_f64::<T>()) {
return Err(FerrotorchError::NotImplementedOnCuda { op });
}
let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
let input = input.contiguous()?;
let handle = input.gpu_handle()?;
let out_handle = if poly_is_f32::<T>() {
f32_call(backend, handle)?
} else {
f64_call(backend, handle)?
};
Ok(Some(poly_gpu_output::<T>(
out_handle,
input.shape().to_vec(),
)?))
}
/// Runs a unary special-function kernel on the GPU when `input` lives there.
///
/// * Non-CUDA input: `Ok(None)`, signalling the caller to use its CPU path.
/// * CUDA input whose element type is neither `f32` nor `f64`:
///   `NotImplementedOnCuda { op }`.
/// * No GPU backend registered: `DeviceUnavailable`.
/// * Otherwise the input is made contiguous, `f32_call` or `f64_call` is
///   invoked on its buffer handle, and the returned handle is wrapped in a
///   tensor with the input's shape.
fn special_gpu_simple<T: Float>(
    input: &Tensor<T>,
    op: &'static str,
    f32_call: impl Fn(
        &dyn crate::gpu_dispatch::GpuBackend,
        &crate::gpu_dispatch::GpuBufferHandle,
    ) -> FerrotorchResult<crate::gpu_dispatch::GpuBufferHandle>,
    f64_call: impl Fn(
        &dyn crate::gpu_dispatch::GpuBackend,
        &crate::gpu_dispatch::GpuBufferHandle,
    ) -> FerrotorchResult<crate::gpu_dispatch::GpuBufferHandle>,
) -> FerrotorchResult<Option<Tensor<T>>> {
    if !input.is_cuda() {
        return Ok(None);
    }
    let is_f32 = poly_is_f32::<T>();
    if !is_f32 && !poly_is_f64::<T>() {
        return Err(FerrotorchError::NotImplementedOnCuda { op });
    }
    let Some(backend) = crate::gpu_dispatch::gpu_backend() else {
        return Err(FerrotorchError::DeviceUnavailable);
    };
    let dense = input.contiguous()?;
    let src = dense.gpu_handle()?;
    let produced = if is_f32 {
        f32_call(backend, src)?
    } else {
        f64_call(backend, src)?
    };
    let shape = dense.shape().to_vec();
    poly_gpu_output::<T>(produced, shape).map(Some)
}
/// Runs a binary special-function kernel on the GPU when the operands live
/// there.
///
/// * Neither operand on CUDA: `Ok(None)`, signalling the CPU path.
/// * Element type neither `f32` nor `f64`: `NotImplementedOnCuda { op }`.
/// * Operands on different devices, or with different shapes (no broadcasting
///   is attempted here): `NotImplementedOnCuda { op }`.
/// * No GPU backend registered: `DeviceUnavailable`.
/// * Otherwise both operands are made contiguous, the matching kernel closure
///   is invoked, and the result is wrapped in a tensor shaped like `x`.
fn special_gpu_binary<T: Float>(
    x: &Tensor<T>,
    q: &Tensor<T>,
    op: &'static str,
    f32_call: impl Fn(
        &dyn crate::gpu_dispatch::GpuBackend,
        &crate::gpu_dispatch::GpuBufferHandle,
        &crate::gpu_dispatch::GpuBufferHandle,
    ) -> FerrotorchResult<crate::gpu_dispatch::GpuBufferHandle>,
    f64_call: impl Fn(
        &dyn crate::gpu_dispatch::GpuBackend,
        &crate::gpu_dispatch::GpuBufferHandle,
        &crate::gpu_dispatch::GpuBufferHandle,
    ) -> FerrotorchResult<crate::gpu_dispatch::GpuBufferHandle>,
) -> FerrotorchResult<Option<Tensor<T>>> {
    let (x_on_gpu, q_on_gpu) = (x.is_cuda(), q.is_cuda());
    if !(x_on_gpu || q_on_gpu) {
        return Ok(None);
    }
    let is_f32 = poly_is_f32::<T>();
    if !is_f32 && !poly_is_f64::<T>() {
        return Err(FerrotorchError::NotImplementedOnCuda { op });
    }
    if x_on_gpu != q_on_gpu || x.shape() != q.shape() {
        return Err(FerrotorchError::NotImplementedOnCuda { op });
    }
    let Some(backend) = crate::gpu_dispatch::gpu_backend() else {
        return Err(FerrotorchError::DeviceUnavailable);
    };
    let lhs = x.contiguous()?;
    let rhs = q.contiguous()?;
    let lhs_buf = lhs.gpu_handle()?;
    let rhs_buf = rhs.gpu_handle()?;
    let produced = if is_f32 {
        f32_call(backend, lhs_buf, rhs_buf)?
    } else {
        f64_call(backend, lhs_buf, rhs_buf)?
    };
    poly_gpu_output::<T>(produced, lhs.shape().to_vec()).map(Some)
}
/// GPU dispatch shared by the Chebyshev-family polynomials.
///
/// Forwards to the backend's `chebyshev_poly_f32` / `chebyshev_poly_f64` with
/// the order `n`, the two seeds and the `shift` flag; the seeds are narrowed
/// to `f32` for the single-precision kernel.
///
/// NOTE(review): judging from the call sites, `seed_a` / `seed_b` look like
/// the coefficients of the degree-one polynomial `seed_a * x + seed_b`, and
/// `shift` selects the shifted variants; confirm against the backend.
///
/// Returns `Ok(None)` for non-CUDA inputs, `NotImplementedOnCuda { op }` for
/// element types other than `f32` / `f64`, and `DeviceUnavailable` when no
/// backend is registered.
fn poly_gpu_chebyshev<T: Float>(
    input: &Tensor<T>,
    n: usize,
    seed_a: f64,
    seed_b: f64,
    shift: bool,
    op: &'static str,
) -> FerrotorchResult<Option<Tensor<T>>> {
    if !input.is_cuda() {
        return Ok(None);
    }
    let is_f32 = poly_is_f32::<T>();
    if !is_f32 && !poly_is_f64::<T>() {
        return Err(FerrotorchError::NotImplementedOnCuda { op });
    }
    let Some(backend) = crate::gpu_dispatch::gpu_backend() else {
        return Err(FerrotorchError::DeviceUnavailable);
    };
    let dense = input.contiguous()?;
    let src = dense.gpu_handle()?;
    let produced = if is_f32 {
        backend.chebyshev_poly_f32(src, n, seed_a as f32, seed_b as f32, shift)?
    } else {
        backend.chebyshev_poly_f64(src, n, seed_a, seed_b, shift)?
    };
    let shape = dense.shape().to_vec();
    poly_gpu_output::<T>(produced, shape).map(Some)
}
/// CPU fallback: applies `f` to every element returned by `input.data_vec()`
/// and builds a new CPU tensor with the same shape.
///
/// `_op` is accepted for call-site symmetry with the GPU helpers and is not
/// used.
///
/// # Errors
/// Propagates failures from `data_vec` or `Tensor::from_storage`.
fn elementwise_native<T: Float, F: Fn(T) -> T>(
    input: &Tensor<T>,
    _op: &'static str,
    f: F,
) -> FerrotorchResult<Tensor<T>> {
    let mapped = input.data_vec()?.into_iter().map(f).collect::<Vec<T>>();
    let storage = crate::storage::TensorStorage::cpu(mapped);
    let shape = input.shape().to_vec();
    crate::tensor::Tensor::from_storage(storage, shape, false)
}
/// Builds a CPU tensor with the shape of `input` in which every element is
/// NaN.
///
/// # Errors
/// Propagates failures from `Tensor::from_storage`.
fn nan_like<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let filler = match T::from(f64::NAN) {
        Some(value) => value,
        None => T::nan(),
    };
    let storage = crate::storage::TensorStorage::cpu(vec![filler; input.numel()]);
    let shape = input.shape().to_vec();
    crate::tensor::Tensor::from_storage(storage, shape, false)
}
/// Elementwise Chebyshev polynomial of the first kind, `T_n(x)`.
///
/// CUDA inputs go through `poly_gpu_chebyshev` (seeds `1.0`, `0.0`, unshifted);
/// everything else is evaluated with the scalar recurrence `chebyshev_t`.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or the CPU fallback.
pub fn chebyshev_polynomial_t<T: Float>(
    input: &Tensor<T>,
    n: usize,
) -> FerrotorchResult<Tensor<T>> {
    const OP: &str = "chebyshev_polynomial_t";
    match poly_gpu_chebyshev(input, n, 1.0, 0.0, false, OP)? {
        Some(on_gpu) => Ok(on_gpu),
        None => elementwise_native(input, OP, move |x| chebyshev_t(n, x)),
    }
}
/// Elementwise Chebyshev polynomial of the second kind, `U_n(x)`.
///
/// CUDA inputs go through `poly_gpu_chebyshev` (seeds `2.0`, `0.0`, unshifted);
/// everything else is evaluated with the scalar recurrence `chebyshev_u`.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or the CPU fallback.
pub fn chebyshev_polynomial_u<T: Float>(
    input: &Tensor<T>,
    n: usize,
) -> FerrotorchResult<Tensor<T>> {
    const OP: &str = "chebyshev_polynomial_u";
    match poly_gpu_chebyshev(input, n, 2.0, 0.0, false, OP)? {
        Some(on_gpu) => Ok(on_gpu),
        None => elementwise_native(input, OP, move |x| chebyshev_u(n, x)),
    }
}
/// Elementwise Chebyshev polynomial of the third kind, `V_n(x)`.
///
/// CUDA inputs go through `poly_gpu_chebyshev` (seeds `2.0`, `-1.0`,
/// unshifted); everything else is evaluated with the scalar recurrence
/// `chebyshev_v`.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or the CPU fallback.
pub fn chebyshev_polynomial_v<T: Float>(
    input: &Tensor<T>,
    n: usize,
) -> FerrotorchResult<Tensor<T>> {
    const OP: &str = "chebyshev_polynomial_v";
    match poly_gpu_chebyshev(input, n, 2.0, -1.0, false, OP)? {
        Some(on_gpu) => Ok(on_gpu),
        None => elementwise_native(input, OP, move |x| chebyshev_v(n, x)),
    }
}
/// Elementwise Chebyshev polynomial of the fourth kind, `W_n(x)`.
///
/// CUDA inputs go through `poly_gpu_chebyshev` (seeds `2.0`, `1.0`, unshifted);
/// everything else is evaluated with the scalar recurrence `chebyshev_w`.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or the CPU fallback.
pub fn chebyshev_polynomial_w<T: Float>(
    input: &Tensor<T>,
    n: usize,
) -> FerrotorchResult<Tensor<T>> {
    const OP: &str = "chebyshev_polynomial_w";
    match poly_gpu_chebyshev(input, n, 2.0, 1.0, false, OP)? {
        Some(on_gpu) => Ok(on_gpu),
        None => elementwise_native(input, OP, move |x| chebyshev_w(n, x)),
    }
}
/// Elementwise physicists' Hermite polynomial `H_n(x)`.
///
/// CUDA inputs are dispatched to the backend's `hermite_h_poly_f32` / `_f64`
/// kernels. On the CPU path, orders above `hermitian_limit::<T>()` produce an
/// all-NaN tensor; otherwise the scalar recurrence `hermite_h` is applied.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or the CPU helpers.
pub fn hermite_polynomial_h<T: Float>(input: &Tensor<T>, n: usize) -> FerrotorchResult<Tensor<T>> {
    const OP: &str = "hermite_polynomial_h";
    let gpu_result = poly_gpu_simple(
        input,
        n,
        OP,
        |backend, buf| backend.hermite_h_poly_f32(buf, n),
        |backend, buf| backend.hermite_h_poly_f64(buf, n),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None if n > hermitian_limit::<T>() => nan_like(input),
        None => elementwise_native(input, OP, move |x| hermite_h(n, x)),
    }
}
/// Elementwise probabilists' Hermite polynomial `He_n(x)`.
///
/// CUDA inputs are dispatched to the backend's `hermite_he_poly_f32` / `_f64`
/// kernels. On the CPU path, orders above `hermitian_limit::<T>()` produce an
/// all-NaN tensor; otherwise the scalar recurrence `hermite_he` is applied.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or the CPU helpers.
pub fn hermite_polynomial_he<T: Float>(input: &Tensor<T>, n: usize) -> FerrotorchResult<Tensor<T>> {
    const OP: &str = "hermite_polynomial_he";
    let gpu_result = poly_gpu_simple(
        input,
        n,
        OP,
        |backend, buf| backend.hermite_he_poly_f32(buf, n),
        |backend, buf| backend.hermite_he_poly_f64(buf, n),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None if n > hermitian_limit::<T>() => nan_like(input),
        None => elementwise_native(input, OP, move |x| hermite_he(n, x)),
    }
}
/// Elementwise Laguerre polynomial `L_n(x)`.
///
/// CUDA inputs are dispatched to the backend's `laguerre_poly_f32` / `_f64`
/// kernels; otherwise the scalar recurrence `laguerre_l` is applied.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or the CPU fallback.
pub fn laguerre_polynomial_l<T: Float>(input: &Tensor<T>, n: usize) -> FerrotorchResult<Tensor<T>> {
    const OP: &str = "laguerre_polynomial_l";
    let gpu_result = poly_gpu_simple(
        input,
        n,
        OP,
        |backend, buf| backend.laguerre_poly_f32(buf, n),
        |backend, buf| backend.laguerre_poly_f64(buf, n),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => elementwise_native(input, OP, move |x| laguerre_l(n, x)),
    }
}
/// Elementwise Legendre polynomial `P_n(x)`.
///
/// CUDA inputs are dispatched to the backend's `legendre_poly_f32` / `_f64`
/// kernels; otherwise the scalar recurrence `legendre_p` is applied.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or the CPU fallback.
pub fn legendre_polynomial_p<T: Float>(input: &Tensor<T>, n: usize) -> FerrotorchResult<Tensor<T>> {
    const OP: &str = "legendre_polynomial_p";
    let gpu_result = poly_gpu_simple(
        input,
        n,
        OP,
        |backend, buf| backend.legendre_poly_f32(buf, n),
        |backend, buf| backend.legendre_poly_f64(buf, n),
    )?;
    match gpu_result {
        Some(on_gpu) => Ok(on_gpu),
        None => elementwise_native(input, OP, move |x| legendre_p(n, x)),
    }
}
/// Elementwise shifted Chebyshev polynomial of the first kind,
/// `T_n(2x - 1)`.
///
/// CUDA inputs go through `poly_gpu_chebyshev` with the shift flag set; on the
/// CPU the argument is mapped to `x + x - 1` and fed to `chebyshev_t`.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or the CPU fallback.
pub fn shifted_chebyshev_polynomial_t<T: Float>(
    input: &Tensor<T>,
    n: usize,
) -> FerrotorchResult<Tensor<T>> {
    const OP: &str = "shifted_chebyshev_polynomial_t";
    match poly_gpu_chebyshev(input, n, 1.0, 0.0, true, OP)? {
        Some(on_gpu) => Ok(on_gpu),
        None => elementwise_native(input, OP, move |x| {
            chebyshev_t(n, x + x - nt_one::<T>())
        }),
    }
}
/// Elementwise shifted Chebyshev polynomial of the second kind,
/// `U_n(2x - 1)`.
///
/// CUDA inputs go through `poly_gpu_chebyshev` with the shift flag set; on the
/// CPU the argument is mapped to `x + x - 1` and fed to `chebyshev_u`.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or the CPU fallback.
pub fn shifted_chebyshev_polynomial_u<T: Float>(
    input: &Tensor<T>,
    n: usize,
) -> FerrotorchResult<Tensor<T>> {
    const OP: &str = "shifted_chebyshev_polynomial_u";
    match poly_gpu_chebyshev(input, n, 2.0, 0.0, true, OP)? {
        Some(on_gpu) => Ok(on_gpu),
        None => elementwise_native(input, OP, move |x| {
            chebyshev_u(n, x + x - nt_one::<T>())
        }),
    }
}
/// Elementwise shifted Chebyshev polynomial of the third kind,
/// `V_n(2x - 1)`.
///
/// CUDA inputs go through `poly_gpu_chebyshev` with the shift flag set; on the
/// CPU the argument is mapped to `x + x - 1` and fed to `chebyshev_v`.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or the CPU fallback.
pub fn shifted_chebyshev_polynomial_v<T: Float>(
    input: &Tensor<T>,
    n: usize,
) -> FerrotorchResult<Tensor<T>> {
    const OP: &str = "shifted_chebyshev_polynomial_v";
    match poly_gpu_chebyshev(input, n, 2.0, -1.0, true, OP)? {
        Some(on_gpu) => Ok(on_gpu),
        None => elementwise_native(input, OP, move |x| {
            chebyshev_v(n, x + x - nt_one::<T>())
        }),
    }
}
/// Elementwise shifted Chebyshev polynomial of the fourth kind,
/// `W_n(2x - 1)`.
///
/// CUDA inputs go through `poly_gpu_chebyshev` with the shift flag set; on the
/// CPU the argument is mapped to `x + x - 1` and fed to `chebyshev_w`.
///
/// # Errors
/// Propagates errors from the GPU dispatch helper or the CPU fallback.
pub fn shifted_chebyshev_polynomial_w<T: Float>(
    input: &Tensor<T>,
    n: usize,
) -> FerrotorchResult<Tensor<T>> {
    const OP: &str = "shifted_chebyshev_polynomial_w";
    match poly_gpu_chebyshev(input, n, 2.0, 1.0, true, OP)? {
        Some(on_gpu) => Ok(on_gpu),
        None => elementwise_native(input, OP, move |x| {
            chebyshev_w(n, x + x - nt_one::<T>())
        }),
    }
}
/// Converts a loop index to `T`, yielding NaN when the conversion fails.
#[inline]
fn nt_from_usize<T: Float>(k: usize) -> T {
    match T::from(k) {
        Some(value) => value,
        None => T::nan(),
    }
}
/// Physicists' Hermite polynomial `H_n(x)` via the three-term recurrence
/// `H_{k+1} = 2x * H_k - 2k * H_{k-1}`, seeded with `H_0 = 1`, `H_1 = 2x`.
fn hermite_h<T: Float>(n: usize, x: T) -> T {
    let two_x = x + x;
    match n {
        0 => nt_one::<T>(),
        1 => two_x,
        _ => {
            let mut older = nt_one::<T>();
            let mut newer = two_x;
            for k in 1..n {
                let kf = nt_from_usize::<T>(k);
                let upcoming = two_x * newer - (kf + kf) * older;
                older = newer;
                newer = upcoming;
            }
            newer
        }
    }
}
/// Probabilists' Hermite polynomial `He_n(x)` via the three-term recurrence
/// `He_{k+1} = x * He_k - k * He_{k-1}`, seeded with `He_0 = 1`, `He_1 = x`.
fn hermite_he<T: Float>(n: usize, x: T) -> T {
    match n {
        0 => nt_one::<T>(),
        1 => x,
        _ => {
            let mut older = nt_one::<T>();
            let mut newer = x;
            for k in 1..n {
                let kf = nt_from_usize::<T>(k);
                let upcoming = x * newer - kf * older;
                older = newer;
                newer = upcoming;
            }
            newer
        }
    }
}
/// Chebyshev polynomial of the first kind via
/// `T_{k+1} = 2x * T_k - T_{k-1}`, seeded with `T_0 = 1`, `T_1 = x`.
///
/// The recurrence stops early once the running value becomes NaN.
fn chebyshev_t<T: Float>(n: usize, x: T) -> T {
    match n {
        0 => nt_one::<T>(),
        1 => x,
        _ => {
            let two_x = x + x;
            let mut older = nt_one::<T>();
            let mut newer = x;
            for _ in 2..=n {
                if newer.is_nan() {
                    break;
                }
                let upcoming = two_x * newer - older;
                older = newer;
                newer = upcoming;
            }
            newer
        }
    }
}
/// Chebyshev polynomial of the second kind via
/// `U_{k+1} = 2x * U_k - U_{k-1}`, seeded with `U_0 = 1`, `U_1 = 2x`.
///
/// The recurrence stops early once the running value becomes NaN.
fn chebyshev_u<T: Float>(n: usize, x: T) -> T {
    let two_x = x + x;
    match n {
        0 => nt_one::<T>(),
        1 => two_x,
        _ => {
            let mut older = nt_one::<T>();
            let mut newer = two_x;
            for _ in 2..=n {
                if newer.is_nan() {
                    break;
                }
                let upcoming = two_x * newer - older;
                older = newer;
                newer = upcoming;
            }
            newer
        }
    }
}
/// Chebyshev polynomial of the third kind via
/// `V_{k+1} = 2x * V_k - V_{k-1}`, seeded with `V_0 = 1`, `V_1 = 2x - 1`.
///
/// The recurrence stops early once the running value becomes NaN.
fn chebyshev_v<T: Float>(n: usize, x: T) -> T {
    let unit = nt_one::<T>();
    let two_x = x + x;
    match n {
        0 => unit,
        1 => two_x - unit,
        _ => {
            let mut older = unit;
            let mut newer = two_x - unit;
            for _ in 2..=n {
                if newer.is_nan() {
                    break;
                }
                let upcoming = two_x * newer - older;
                older = newer;
                newer = upcoming;
            }
            newer
        }
    }
}
/// Chebyshev polynomial of the fourth kind via
/// `W_{k+1} = 2x * W_k - W_{k-1}`, seeded with `W_0 = 1`, `W_1 = 2x + 1`.
///
/// The recurrence stops early once the running value becomes NaN.
fn chebyshev_w<T: Float>(n: usize, x: T) -> T {
    let unit = nt_one::<T>();
    let two_x = x + x;
    match n {
        0 => unit,
        1 => two_x + unit,
        _ => {
            let mut older = unit;
            let mut newer = two_x + unit;
            for _ in 2..=n {
                if newer.is_nan() {
                    break;
                }
                let upcoming = two_x * newer - older;
                older = newer;
                newer = upcoming;
            }
            newer
        }
    }
}
/// Laguerre polynomial via
/// `L_{k+1} = ((2k + 1 - x) * L_k - k * L_{k-1}) / (k + 1)`, seeded with
/// `L_0 = 1`, `L_1 = 1 - x`.
///
/// The recurrence stops early once the running value becomes NaN.
fn laguerre_l<T: Float>(n: usize, x: T) -> T {
    let unit = nt_one::<T>();
    let one_minus_x = unit - x;
    match n {
        0 => unit,
        1 => one_minus_x,
        _ => {
            let mut older = unit;
            let mut newer = one_minus_x;
            for k in 1..n {
                if newer.is_nan() {
                    break;
                }
                let kf = nt_from_usize::<T>(k);
                let upcoming = ((kf + kf + one_minus_x) * newer - kf * older) / (kf + unit);
                older = newer;
                newer = upcoming;
            }
            newer
        }
    }
}
/// Legendre polynomial via
/// `P_{k+1} = ((2k + 1) * x * P_k - k * P_{k-1}) / (k + 1)`, seeded with
/// `P_0 = 1`, `P_1 = x`.
///
/// The recurrence stops early once the running value becomes NaN.
fn legendre_p<T: Float>(n: usize, x: T) -> T {
    let unit = nt_one::<T>();
    match n {
        0 => unit,
        1 => x,
        _ => {
            let mut older = unit;
            let mut newer = x;
            for k in 1..n {
                if newer.is_nan() {
                    break;
                }
                let kf = nt_from_usize::<T>(k);
                let upcoming = ((kf + kf + unit) * x * newer - kf * older) / (kf + unit);
                older = newer;
                newer = upcoming;
            }
            newer
        }
    }
}
#[cfg(test)]
#[allow(
clippy::excessive_precision,
clippy::inconsistent_digit_grouping,
clippy::unreadable_literal,
clippy::float_cmp,
clippy::type_complexity,
clippy::approx_constant,
reason = "oracle divergence tests: expected values are copied verbatim from live torch 2.11 / scipy / Cephes (full precision + grouping intentional); float comparisons are deliberately exact byte-for-byte parity checks; the (name, fn, [f64;3]) case tuples are a local test fixture, not a public type"
)]
mod tests {
use super::*;
use crate::storage::TensorStorage;
fn t(data: &[f64], shape: &[usize]) -> Tensor<f64> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
}
#[test]
fn erf_zero() {
let input = t(&[0.0], &[1]);
let result = erf(&input).unwrap();
assert!((result.data().unwrap()[0]).abs() < 1e-10);
}
#[test]
fn erf_symmetry() {
let input = t(&[0.5, 1.0, 2.0], &[3]);
let neg_input = t(&[-0.5, -1.0, -2.0], &[3]);
let pos = erf(&input).unwrap();
let neg = erf(&neg_input).unwrap();
let pd = pos.data().unwrap();
let nd = neg.data().unwrap();
for i in 0..3 {
assert!(
(pd[i] + nd[i]).abs() < 1e-6,
"erf({}) + erf({}) = {} (expected 0)",
input.data().unwrap()[i],
neg_input.data().unwrap()[i],
pd[i] + nd[i],
);
}
}
#[test]
fn erf_large_value() {
let input = t(&[f64::INFINITY], &[1]);
let result = erf(&input).unwrap();
assert!((result.data().unwrap()[0] - 1.0).abs() < 1e-6);
}
#[test]
fn erf_known_values() {
let input = t(&[1.0], &[1]);
let result = erf(&input).unwrap();
assert!(
(result.data().unwrap()[0] - 0.8427007929).abs() < 2e-7,
"erf(1) = {}",
result.data().unwrap()[0]
);
}
#[test]
fn erfc_is_one_minus_erf() {
let input = t(&[0.0, 0.5, 1.0, -0.5, 2.0], &[5]);
let erf_result = erf(&input).unwrap();
let erfc_result = erfc(&input).unwrap();
let ed = erf_result.data().unwrap();
let cd = erfc_result.data().unwrap();
for i in 0..5 {
assert!(
(ed[i] + cd[i] - 1.0).abs() < 1e-10,
"erf({0}) + erfc({0}) = {1} (expected 1.0)",
input.data().unwrap()[i],
ed[i] + cd[i],
);
}
}
#[test]
fn erfinv_zero() {
let input = t(&[0.0], &[1]);
let result = erfinv(&input).unwrap();
assert!(result.data().unwrap()[0].abs() < 1e-10);
}
#[test]
fn erfinv_roundtrip() {
let xs = t(&[0.1, 0.5, 1.0, -0.3, -1.5], &[5]);
let erf_xs = erf(&xs).unwrap();
let roundtrip = erfinv(&erf_xs).unwrap();
let orig = xs.data().unwrap();
let rt = roundtrip.data().unwrap();
for i in 0..5 {
assert!(
(orig[i] - rt[i]).abs() < 0.01,
"erfinv(erf({})) = {} (expected {})",
orig[i],
rt[i],
orig[i],
);
}
}
#[test]
fn erfinv_boundary() {
let input = t(&[1.0, -1.0], &[2]);
let result = erfinv(&input).unwrap();
let d = result.data().unwrap();
assert!(d[0].is_infinite() && d[0] > 0.0, "erfinv(1) should be +inf");
assert!(
d[1].is_infinite() && d[1] < 0.0,
"erfinv(-1) should be -inf"
);
}
#[test]
fn lgamma_at_one_and_two() {
let input = t(&[1.0, 2.0], &[2]);
let result = lgamma(&input).unwrap();
let d = result.data().unwrap();
assert!(d[0].abs() < 1e-10, "lgamma(1) = {} (expected 0)", d[0]);
assert!(d[1].abs() < 1e-10, "lgamma(2) = {} (expected 0)", d[1]);
}
#[test]
fn lgamma_known_values() {
let input = t(&[0.5], &[1]);
let result = lgamma(&input).unwrap();
let expected = 0.5723649429247001;
assert!(
(result.data().unwrap()[0] - expected).abs() < 1e-8,
"lgamma(0.5) = {} (expected {})",
result.data().unwrap()[0],
expected,
);
}
#[test]
fn lgamma_factorial() {
let input = t(&[6.0], &[1]);
let result = lgamma(&input).unwrap();
let expected = (120.0f64).ln();
assert!(
(result.data().unwrap()[0] - expected).abs() < 1e-8,
"lgamma(6) = {} (expected {})",
result.data().unwrap()[0],
expected,
);
}
#[test]
fn digamma_known_values() {
let input = t(&[1.0], &[1]);
let result = digamma(&input).unwrap();
let expected = -0.5772156649015329;
assert!(
(result.data().unwrap()[0] - expected).abs() < 1e-6,
"digamma(1) = {} (expected {})",
result.data().unwrap()[0],
expected,
);
}
#[test]
fn digamma_recurrence() {
let x_val = 2.5;
let input_x = t(&[x_val], &[1]);
let input_x1 = t(&[x_val + 1.0], &[1]);
let psi_x = digamma(&input_x).unwrap().data().unwrap()[0];
let psi_x1 = digamma(&input_x1).unwrap().data().unwrap()[0];
assert!(
(psi_x1 - psi_x - 1.0 / x_val).abs() < 1e-8,
"psi({}) - psi({}) = {} (expected {})",
x_val + 1.0,
x_val,
psi_x1 - psi_x,
1.0 / x_val,
);
}
#[test]
fn log1p_zero() {
let input = t(&[0.0], &[1]);
let result = log1p(&input).unwrap();
assert!(result.data().unwrap()[0].abs() < 1e-15);
}
#[test]
fn log1p_small() {
let small = 1e-10;
let input = t(&[small], &[1]);
let result = log1p(&input).unwrap();
assert!(
(result.data().unwrap()[0] - small).abs() < 1e-15,
"log1p({small}) = {} (expected ~{small})",
result.data().unwrap()[0],
);
}
#[test]
fn log1p_known() {
let input = t(&[1.0], &[1]);
let result = log1p(&input).unwrap();
assert!((result.data().unwrap()[0] - std::f64::consts::LN_2).abs() < 1e-15,);
}
#[test]
fn expm1_zero() {
let input = t(&[0.0], &[1]);
let result = expm1(&input).unwrap();
assert!(result.data().unwrap()[0].abs() < 1e-15);
}
#[test]
fn expm1_small() {
let small = 1e-10;
let input = t(&[small], &[1]);
let result = expm1(&input).unwrap();
assert!(
(result.data().unwrap()[0] - small).abs() < 1e-15,
"expm1({small}) = {} (expected ~{small})",
result.data().unwrap()[0],
);
}
#[test]
fn expm1_known() {
let input = t(&[1.0], &[1]);
let result = expm1(&input).unwrap();
let expected = std::f64::consts::E - 1.0;
assert!((result.data().unwrap()[0] - expected).abs() < 1e-14,);
}
#[test]
fn sinc_zero() {
let input = t(&[0.0], &[1]);
let result = sinc(&input).unwrap();
assert!(
(result.data().unwrap()[0] - 1.0).abs() < 1e-15,
"sinc(0) = {} (expected 1)",
result.data().unwrap()[0],
);
}
#[test]
#[allow(clippy::needless_range_loop)]
fn sinc_integer() {
let input = t(&[1.0, 2.0, -1.0, -3.0], &[4]);
let result = sinc(&input).unwrap();
let d = result.data().unwrap();
for i in 0..4 {
assert!(
d[i].abs() < 1e-15,
"sinc({}) = {} (expected 0)",
input.data().unwrap()[i],
d[i],
);
}
}
#[test]
fn sinc_half() {
let input = t(&[0.5], &[1]);
let result = sinc(&input).unwrap();
let expected = 2.0 / std::f64::consts::PI;
assert!(
(result.data().unwrap()[0] - expected).abs() < 1e-15,
"sinc(0.5) = {} (expected {})",
result.data().unwrap()[0],
expected,
);
}
#[test]
#[allow(clippy::needless_range_loop)]
fn xlogy_zero_x() {
let x = t(&[0.0, 0.0, 0.0], &[3]);
let y = t(&[1.0, 0.0, f64::INFINITY], &[3]);
let result = xlogy(&x, &y).unwrap();
let d = result.data().unwrap();
for i in 0..3 {
assert!(
d[i] == 0.0,
"xlogy(0, {}) = {} (expected 0)",
y.data().unwrap()[i],
d[i],
);
}
}
#[test]
fn xlogy_normal() {
let x = t(&[2.0], &[1]);
let y = t(&[std::f64::consts::E], &[1]);
let result = xlogy(&x, &y).unwrap();
assert!(
(result.data().unwrap()[0] - 2.0).abs() < 1e-14,
"xlogy(2, e) = {} (expected 2)",
result.data().unwrap()[0],
);
}
#[test]
fn xlogy_broadcast() {
let x = t(&[2.0, 3.0], &[2]);
let y = t(&[std::f64::consts::E, std::f64::consts::E], &[2]);
let result = xlogy(&x, &y).unwrap();
let d = result.data().unwrap();
assert!((d[0] - 2.0).abs() < 1e-14);
assert!((d[1] - 3.0).abs() < 1e-14);
}
#[test]
fn erf_f32() {
let input =
Tensor::from_storage(TensorStorage::cpu(vec![0.0f32, 1.0, -1.0]), vec![3], false)
.unwrap();
let result = erf(&input).unwrap();
let d = result.data().unwrap();
assert!(d[0].abs() < 1e-6);
assert!((d[1] - 0.8427008).abs() < 1e-5);
assert!((d[2] + 0.8427008).abs() < 1e-5);
}
#[test]
fn erf_2d() {
let input = t(&[0.0, 0.5, 1.0, -0.5, -1.0, 2.0], &[2, 3]);
let result = erf(&input).unwrap();
assert_eq!(result.shape(), &[2, 3]);
let d = result.data().unwrap();
assert!(d[0].abs() < 1e-10); assert!(d[2] > 0.8); assert!(d[3] < 0.0); }
fn close(a: f64, b: f64, tol: f64) -> bool {
(a - b).abs() < tol
}
fn xs() -> Tensor<f64> {
t(&[0.0, 0.5, 1.0, -0.5, -1.0, 0.25], &[6])
}
#[test]
fn chebyshev_t_n0_is_one() {
let r = chebyshev_polynomial_t(&xs(), 0).unwrap();
for &v in r.data().unwrap() {
assert!(close(v, 1.0, 1e-12));
}
}
#[test]
fn chebyshev_t_n1_is_x() {
let x = xs();
let r = chebyshev_polynomial_t(&x, 1).unwrap();
for (a, b) in r.data().unwrap().iter().zip(x.data().unwrap().iter()) {
assert!(close(*a, *b, 1e-12));
}
}
#[test]
fn chebyshev_t_n2_is_2xx_minus_one() {
let x = xs();
let r = chebyshev_polynomial_t(&x, 2).unwrap();
for (a, &xv) in r.data().unwrap().iter().zip(x.data().unwrap().iter()) {
assert!(close(*a, 2.0 * xv * xv - 1.0, 1e-12));
}
}
#[test]
fn chebyshev_t_at_endpoints() {
let pts = t(&[1.0, -1.0], &[2]);
for n in 0..6 {
let r = chebyshev_polynomial_t(&pts, n).unwrap();
let d = r.data().unwrap();
assert!(close(d[0], 1.0, 1e-12), "T_{n}(1) = {}", d[0]);
let expected_neg = if n % 2 == 0 { 1.0 } else { -1.0 };
assert!(close(d[1], expected_neg, 1e-12), "T_{n}(-1) = {}", d[1]);
}
}
#[test]
fn chebyshev_u_n0_n1_n2() {
let x = t(&[0.5], &[1]);
let xv = 0.5;
assert!(close(
chebyshev_polynomial_u(&x, 0).unwrap().data().unwrap()[0],
1.0,
1e-12
));
assert!(close(
chebyshev_polynomial_u(&x, 1).unwrap().data().unwrap()[0],
2.0 * xv,
1e-12,
));
assert!(close(
chebyshev_polynomial_u(&x, 2).unwrap().data().unwrap()[0],
4.0 * xv * xv - 1.0,
1e-12,
));
}
#[test]
fn chebyshev_v_endpoints() {
let pts = t(&[1.0, 0.0], &[2]);
for n in 0..4 {
let r = chebyshev_polynomial_v(&pts, n).unwrap();
assert!(close(r.data().unwrap()[0], 1.0, 1e-12));
}
let r1 = chebyshev_polynomial_v(&pts, 1).unwrap();
assert!(close(r1.data().unwrap()[1], -1.0, 1e-12));
}
#[test]
fn chebyshev_w_endpoints() {
let zero = t(&[0.0], &[1]);
assert!(close(
chebyshev_polynomial_w(&zero, 1).unwrap().data().unwrap()[0],
1.0,
1e-12
));
}
#[test]
fn hermite_h_known_values() {
let x = t(&[0.5], &[1]);
let xv = 0.5;
assert!(close(
hermite_polynomial_h(&x, 0).unwrap().data().unwrap()[0],
1.0,
1e-12
));
assert!(close(
hermite_polynomial_h(&x, 1).unwrap().data().unwrap()[0],
2.0 * xv,
1e-12
));
assert!(close(
hermite_polynomial_h(&x, 2).unwrap().data().unwrap()[0],
4.0 * xv * xv - 2.0,
1e-12,
));
assert!(close(
hermite_polynomial_h(&x, 3).unwrap().data().unwrap()[0],
8.0 * xv * xv * xv - 12.0 * xv,
1e-12,
));
}
#[test]
fn hermite_he_known_values() {
let x = t(&[0.5], &[1]);
let xv = 0.5;
assert!(close(
hermite_polynomial_he(&x, 0).unwrap().data().unwrap()[0],
1.0,
1e-12
));
assert!(close(
hermite_polynomial_he(&x, 1).unwrap().data().unwrap()[0],
xv,
1e-12
));
assert!(close(
hermite_polynomial_he(&x, 2).unwrap().data().unwrap()[0],
xv * xv - 1.0,
1e-12,
));
assert!(close(
hermite_polynomial_he(&x, 3).unwrap().data().unwrap()[0],
xv * xv * xv - 3.0 * xv,
1e-12,
));
}
/// Evaluates `laguerre_polynomial_l` at x = 0.5 for n = 0, 1, 2 and compares
/// with 1, 1 - x and the midpoint of (x^2 - 4x) and 2, i.e. (x^2 - 4x + 2) / 2.
#[test]
fn laguerre_l_known_values() {
    let xv = 0.5;
    let x = t(&[xv], &[1]);
    // (degree, closed-form value at xv)
    let cases = [
        (0, 1.0),
        (1, 1.0 - xv),
        (2, f64::midpoint(xv * xv - 4.0 * xv, 2.0)),
    ];
    for (n, expected) in cases {
        let got = laguerre_polynomial_l(&x, n).unwrap().data().unwrap()[0];
        assert!(close(got, expected, 1e-12));
    }
}
/// Evaluates `legendre_polynomial_p` at x = 0.5 for n = 0..=3 and compares
/// with 1, x, (3x^2 - 1) / 2 and (5x^3 - 3x) / 2.
#[test]
fn legendre_p_known_values() {
    let xv = 0.5;
    let x = t(&[xv], &[1]);
    // (degree, closed-form value at xv)
    let cases = [
        (0, 1.0),
        (1, xv),
        (2, (3.0 * xv * xv - 1.0) / 2.0),
        (3, (5.0 * xv * xv * xv - 3.0 * xv) / 2.0),
    ];
    for (n, expected) in cases {
        let got = legendre_polynomial_p(&x, n).unwrap().data().unwrap()[0];
        assert!(close(got, expected, 1e-12));
    }
}
/// For n = 0..=5, checks `legendre_polynomial_p` at x = 1 (expected 1) and at
/// x = -1 (expected +1 for even n, -1 for odd n).
#[test]
fn legendre_p_endpoints() {
    let endpoints = t(&[1.0, -1.0], &[2]);
    for n in 0..6 {
        // Sign expected at x = -1 alternates with the parity of n.
        let parity_sign = match n % 2 {
            0 => 1.0,
            _ => -1.0,
        };
        let result = legendre_polynomial_p(&endpoints, n).unwrap();
        let d = result.data().unwrap();
        assert!(close(d[0], 1.0, 1e-12), "P_{n}(1) = {}", d[0]);
        assert!(close(d[1], parity_sign, 1e-12), "P_{n}(-1) = {}", d[1]);
    }
}
/// For n = 0..=4, checks that `shifted_chebyshev_polynomial_t(x, n)` agrees
/// element-wise with `chebyshev_polynomial_t(2x - 1, n)` on the `xs()` inputs.
#[test]
fn shifted_chebyshev_t_matches_t_of_2x_minus_1() {
    let x = xs();
    // The affine-mapped inputs do not depend on n, so build them once.
    let mapped: Vec<f64> = x.data().unwrap().iter().map(|v| 2.0 * v - 1.0).collect();
    let mapped_t = t(&mapped, &[mapped.len()]);
    for n in 0..5 {
        let shifted = shifted_chebyshev_polynomial_t(&x, n).unwrap();
        let direct = chebyshev_polynomial_t(&mapped_t, n).unwrap();
        let shifted_vals = shifted.data().unwrap();
        let direct_vals = direct.data().unwrap();
        for (s, d) in shifted_vals.iter().zip(direct_vals.iter()) {
            assert!(close(*s, *d, 1e-12), "T*_{n} mismatch at n={n}: {s} vs {d}");
        }
    }
}
/// Asserts that `chebyshev_polynomial_t` succeeds on a tensor built by the `t`
/// helper.
///
/// NOTE(review): despite the name, no CUDA tensor is constructed here, so the
/// rejection path is not exercised by this test — confirm whether that is
/// covered elsewhere.
#[test]
fn polynomial_fns_reject_cuda_tensors_explicitly() {
    let input = t(&[0.0], &[1]);
    let outcome = chebyshev_polynomial_t(&input, 3);
    assert!(outcome.is_ok());
}
/// Checks `gammainc(a = 2, x = 1.5)` against a hard-coded reference value
/// with absolute tolerance 1e-12.
#[test]
fn gammainc_series_region_matches_oracle() {
    let out = gammainc(&t(&[2.0], &[1]), &t(&[1.5], &[1])).unwrap();
    let values = out.data().unwrap();
    let abs_err = (values[0] - 0.442_174_599_628_925_2).abs();
    assert!(abs_err < 1e-12, "got {}", values[0]);
}
/// Checks `gammainc(a = 7.5, x = 10)` against a hard-coded reference value
/// with absolute tolerance 1e-12.
#[test]
fn gammainc_continued_fraction_region_matches_oracle() {
    let out = gammainc(&t(&[7.5], &[1]), &t(&[10.0], &[1])).unwrap();
    let values = out.data().unwrap();
    let abs_err = (values[0] - 0.828_067_310_623_399_1).abs();
    assert!(abs_err < 1e-12, "got {}", values[0]);
}
/// Checks `gammaincc` on the pairs (a, x) = (2, 1.5) and (4, 3) against
/// hard-coded reference values with absolute tolerance 1e-12.
#[test]
fn gammaincc_matches_oracle() {
    let shape_params = t(&[2.0, 4.0], &[2]);
    let upper_limits = t(&[1.5, 3.0], &[2]);
    let out = gammaincc(&shape_params, &upper_limits).unwrap();
    let d = out.data().unwrap();
    let reference = [0.557_825_400_371_074_8, 0.647_231_888_782_231_3];
    for (i, expected) in reference.into_iter().enumerate() {
        assert!((d[i] - expected).abs() < 1e-12, "got {}", d[i]);
    }
}
/// Checks that `gammainc + gammaincc` sums to 1 (within 1e-12) for a = 4 and
/// x in {3, 4, 5}.
#[test]
fn gammainc_plus_gammaincc_is_one() {
    let a = t(&[4.0; 3], &[3]);
    let x = t(&[3.0, 4.0, 5.0], &[3]);
    let lower = gammainc(&a, &x).unwrap();
    let upper = gammaincc(&a, &x).unwrap();
    let pd = lower.data().unwrap();
    let qd = upper.data().unwrap();
    for i in 0..3 {
        let residual = (pd[i] + qd[i] - 1.0).abs();
        assert!(
            residual < 1e-12,
            "P+Q at i={i} = {} (expected 1)",
            pd[i] + qd[i]
        );
    }
}
/// Checks `gammainc(a = 0.5, x = 0.5)` against a hard-coded reference value
/// with absolute tolerance 1e-12.
#[test]
fn gammainc_subunit_concentration_matches_oracle() {
    let half = t(&[0.5], &[1]);
    let same_half = t(&[0.5], &[1]);
    let out = gammainc(&half, &same_half).unwrap();
    let values = out.data().unwrap();
    let abs_err = (values[0] - 0.682_689_492_137_085_9).abs();
    assert!(abs_err < 1e-12, "got {}", values[0]);
}
/// Exercises `gammainc` / `gammaincc` on scalar boundary inputs: exact 0/1
/// results for a = 0, x = 0 and x = +inf, and NaN for negative arguments and
/// for a = x = 0.
#[test]
#[allow(
    clippy::float_cmp,
    reason = "gammainc/gammaincc boundary values (a=0 -> 1.0, x=0 -> 0.0, x=inf -> 1.0/0.0) are EXACT mathematical limits torch returns, not floating approximations"
)]
fn gammainc_boundary_cases_match_torch() {
    // One-element tensor wrapper shared by both scalar evaluators.
    let scalar = |v: f64| t(&[v], &[1]);
    let gi = |av: f64, xv: f64| gammainc(&scalar(av), &scalar(xv)).unwrap().data().unwrap()[0];
    let gc = |av: f64, xv: f64| gammaincc(&scalar(av), &scalar(xv)).unwrap().data().unwrap()[0];

    // Exact values of the lower function.
    for (av, xv, want) in [(0.0, 2.0, 1.0), (2.0, 0.0, 0.0)] {
        assert_eq!(gi(av, xv), want);
    }
    // Inputs for which the lower function must be NaN.
    for (av, xv) in [(-1.0, 2.0), (2.0, -1.0), (0.0, 0.0)] {
        assert!(gi(av, xv).is_nan());
    }
    // Exact values of the upper function.
    for (av, xv, want) in [(0.0, 2.0, 0.0), (2.0, 0.0, 1.0)] {
        assert_eq!(gc(av, xv), want);
    }
    // x = +inf limits.
    assert_eq!(gi(2.0, f64::INFINITY), 1.0);
    assert_eq!(gc(2.0, f64::INFINITY), 0.0);
}
/// Checks `gammainc(a = 2, x = 1.5)` on `f32` tensors against a reference
/// value with absolute tolerance 1e-5.
#[test]
fn gammainc_f32_matches_oracle_within_f32_tol() {
    // Builds a one-element f32 tensor from CPU storage.
    let scalar =
        |v: f32| Tensor::from_storage(TensorStorage::cpu(vec![v]), vec![1], false).unwrap();
    let a = scalar(2.0f32);
    let x = scalar(1.5f32);
    let out = gammainc(&a, &x).unwrap();
    let values = out.data().unwrap();
    assert!((values[0] - 0.442_174_6).abs() < 1e-5, "got {}", values[0]);
}
/// Checks `log_beta(2, 3)` against a hard-coded reference value with absolute
/// tolerance 1e-12.
#[test]
fn log_beta_matches_scipy() {
    let out = log_beta(&t(&[2.0], &[1]), &t(&[3.0], &[1])).unwrap();
    let values = out.data().unwrap();
    let abs_err = (values[0] - (-2.484_906_649_788_000_4)).abs();
    assert!(abs_err < 1e-12, "got {}", values[0]);
}
/// Checks `beta(2, 3)` against the reference value 0.0833... with absolute
/// tolerance 1e-12.
#[test]
fn beta_matches_scipy() {
    let out = beta(&t(&[2.0], &[1]), &t(&[3.0], &[1])).unwrap();
    let values = out.data().unwrap();
    let abs_err = (values[0] - 0.083_333_333_333_333_33).abs();
    assert!(abs_err < 1e-12, "got {}", values[0]);
}
/// Checks `log_beta` on two element pairs: (0.5, 2.5) against a hard-coded
/// value and (3, 2) against ln(1/12), both within 1e-12.
///
/// NOTE(review): both inputs have the same shape here, so broadcasting is not
/// exercised by this test despite the name — confirm intent.
#[test]
fn log_beta_symmetric_and_broadcasts() {
    let lhs = t(&[0.5, 3.0], &[2]);
    let rhs = t(&[2.5, 2.0], &[2]);
    let out = log_beta(&lhs, &rhs).unwrap();
    let d = out.data().unwrap();
    let first_err = (d[0] - 0.163_900_632_837_673_9).abs();
    assert!(first_err < 1e-12, "got {}", d[0]);
    let second_err = (d[1] - (1.0f64 / 12.0).ln()).abs();
    assert!(second_err < 1e-12, "got {}", d[1]);
}
/// Checks `multigammaln(3, p = 2)` against a hard-coded reference value with
/// absolute tolerance 1e-12.
#[test]
fn multigammaln_p2_matches_scipy() {
    let out = multigammaln(&t(&[3.0], &[1]), 2).unwrap();
    let values = out.data().unwrap();
    let abs_err = (values[0] - 1.550_194_993_957_564_5).abs();
    assert!(abs_err < 1e-12, "got {}", values[0]);
}
/// Checks `multigammaln(5, p = 3)` against a hard-coded reference value with
/// absolute tolerance 1e-11.
#[test]
fn multigammaln_p3_matches_scipy() {
    let out = multigammaln(&t(&[5.0], &[1]), 3).unwrap();
    let values = out.data().unwrap();
    let abs_err = (values[0] - 9.140_644_699_192_542).abs();
    assert!(abs_err < 1e-11, "got {}", values[0]);
}
/// Checks `multigammaln(2.5, p = 1)` against a hard-coded reference value and
/// then requires `mvlgamma(2.5, 1)` to return the exact same number.
#[test]
#[allow(
    clippy::float_cmp,
    reason = "mvlgamma is a literal alias of multigammaln (same code path); bit-exact equality asserts they are identical, not approximately equal"
)]
fn multigammaln_p1_is_lgamma() {
    let a = t(&[2.5], &[1]);
    let primary = multigammaln(&a, 1).unwrap();
    let abs_err = (primary.data().unwrap()[0] - 0.284_682_870_472_919_2).abs();
    assert!(abs_err < 1e-12, "got {}", primary.data().unwrap()[0]);
    // The alias must agree bit-for-bit with the primary entry point.
    let alias = mvlgamma(&a, 1).unwrap();
    assert_eq!(primary.data().unwrap()[0], alias.data().unwrap()[0]);
}
/// Checks two `multigammaln(_, p = 3)` inputs: 0.3 must give the finite value
/// named in the assertion message, and 0.5 must give +inf.
#[test]
fn multigammaln_out_of_domain_matches_torch() {
    let finite_case = multigammaln(&t(&[0.3], &[1]), 3).unwrap();
    let finite_err = (finite_case.data().unwrap()[0] - 6.026_863_353_182_922).abs();
    assert!(
        finite_err < 1e-12,
        "multigammaln(0.3, 3): torch returns finite 6.026863353182922, got {}",
        finite_case.data().unwrap()[0]
    );

    let pole_case = multigammaln(&t(&[0.5], &[1]), 3).unwrap();
    let hit_pole = pole_case.data().unwrap()[0] == f64::INFINITY;
    assert!(
        hit_pole,
        "multigammaln(0.5, 3): torch returns +inf (lgamma pole), got {}",
        pole_case.data().unwrap()[0]
    );
}
/// Requires `multigammaln` to return an error when p = 0.
#[test]
fn multigammaln_p_zero_errors() {
    let input = t(&[3.0], &[1]);
    let outcome = multigammaln(&input, 0);
    assert!(outcome.is_err());
}
/// Checks `gammaln_sign` on seven non-integer inputs against the expected
/// signs (+1 / -1), within 1e-15.
#[test]
#[allow(clippy::needless_range_loop)]
fn gammaln_sign_matches_scipy_gammasgn() {
    let xs = [-2.5, -1.5, -0.5, 2.0, -6.5, 0.5, -3.7];
    let expected = [-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 1.0];
    let out = gammaln_sign(&t(&xs, &[xs.len()])).unwrap();
    let d = out.data().unwrap();
    for i in 0..xs.len() {
        let abs_err = (d[i] - expected[i]).abs();
        assert!(
            abs_err < 1e-15,
            "gammasgn({}) = {} (expected {})",
            xs[i],
            d[i],
            expected[i]
        );
    }
}
/// Requires `gammaln_sign` to return NaN at the negative integers -2, -3, -10.
#[test]
fn gammaln_sign_negative_integer_is_nan() {
    let poles = t(&[-2.0, -3.0, -10.0], &[3]);
    let out = gammaln_sign(&poles).unwrap();
    let values = out.data().unwrap();
    for v in values.iter().copied() {
        assert!(v.is_nan(), "expected NaN at pole, got {v}");
    }
}
/// Requires `gammaln_sign` to return exactly 1.0 for the inputs 0, 1 and 100.
#[test]
#[allow(
    clippy::float_cmp,
    reason = "gammasgn returns the EXACT sign value +1.0 for positive/zero inputs; bit-exact equality is the contract, not an epsilon"
)]
fn gammaln_sign_positive_and_zero() {
    let non_negative = t(&[0.0, 1.0, 100.0], &[3]);
    let out = gammaln_sign(&non_negative).unwrap();
    let values = out.data().unwrap();
    for v in values.iter().copied() {
        assert_eq!(v, 1.0);
    }
}
/// Reconstructs Gamma(-0.5) as `gammaln_sign(x) * exp(lgamma(x))` and compares
/// it with -2 * sqrt(pi), within 1e-9.
#[test]
fn gammaln_sign_recovers_gamma_via_lgamma() {
    let x = t(&[-0.5], &[1]);
    let sign_out = gammaln_sign(&x).unwrap();
    let sign = sign_out.data().unwrap()[0];
    let lgamma_out = lgamma(&x).unwrap();
    let log_magnitude = lgamma_out.data().unwrap()[0];
    let gamma = sign * log_magnitude.exp();
    let expected = -2.0 * std::f64::consts::PI.sqrt();
    let abs_err = (gamma - expected).abs();
    assert!(
        abs_err < 1e-9,
        "reconstructed Γ(-0.5) = {gamma} (expected {expected})"
    );
}
/// Checks `entr` at 0.5, 2.0 and 0.1 against hard-coded reference values with
/// absolute tolerance 1e-12.
#[test]
fn entr_known_values_vs_torch() {
    let out = entr(&t(&[0.5, 2.0, 0.1], &[3])).unwrap();
    let d = out.data().unwrap();
    let want = [
        0.346_573_590_279_972_64,
        -1.386_294_361_119_890_6,
        0.230_258_509_299_404_56,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let abs_err = (d[i] - expected).abs();
        assert!(
            abs_err < 1e-12,
            "entr idx {i}: got {} want {}",
            d[i],
            expected
        );
    }
}
/// Checks `entr` on the edge inputs 0, -1, NaN and 1: +0.0, -inf, NaN and
/// -0.0 respectively (signs of the zeros included).
#[test]
#[allow(
    clippy::float_cmp,
    reason = "entr(0) is the exact branch return 0.0 / entr(1) the exact -1*ln(1) = -0.0; bit-exact equality is the torch contract (Math.cuh:474-476)"
)]
fn entr_edges_vs_torch() {
    let edge_inputs = t(&[0.0, -1.0, f64::NAN, 1.0], &[4]);
    let out = entr(&edge_inputs).unwrap();
    let d = out.data().unwrap();

    // x = 0: positive zero.
    assert_eq!(d[0], 0.0, "entr(0) == +0.0");
    assert!(d[0].is_sign_positive(), "entr(0) sign is +0.0");

    // x < 0 and x = NaN.
    let is_neg_inf = d[1].is_infinite() && d[1] < 0.0;
    assert!(is_neg_inf, "entr(-1) == -inf");
    assert!(d[2].is_nan(), "entr(NaN) == NaN");

    // x = 1: negative zero.
    assert_eq!(d[3], 0.0, "entr(1) magnitude 0");
    assert!(
        d[3].is_sign_negative(),
        "entr(1) sign is -0.0 (torch parity)"
    );
}
/// Checks `ndtr` at the integers -3..=3 against hard-coded reference values
/// with absolute tolerance 1e-12.
#[test]
fn ndtr_known_values_vs_torch() {
    let out = ndtr(&t(&[-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], &[7])).unwrap();
    let d = out.data().unwrap();
    let want = [
        0.001_349_898_031_630_103_5,
        0.022_750_131_948_179_209,
        0.158_655_253_931_457_02,
        0.5,
        0.841_344_746_068_543_04,
        0.977_249_868_051_820_79,
        0.998_650_101_968_369_9,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let abs_err = (d[i] - expected).abs();
        assert!(
            abs_err < 1e-12,
            "ndtr idx {i}: got {} want {}",
            d[i],
            expected
        );
    }
}
/// Checks `ndtr` at -inf (0), +inf (1) and NaN (NaN).
#[test]
fn ndtr_edges_vs_torch() {
    let edge_inputs = t(&[f64::NEG_INFINITY, f64::INFINITY, f64::NAN], &[3]);
    let out = ndtr(&edge_inputs).unwrap();
    let d = out.data().unwrap();
    let low_err = (d[0] - 0.0).abs();
    assert!(low_err < 1e-15, "ndtr(-inf) == 0");
    let high_err = (d[1] - 1.0).abs();
    assert!(high_err < 1e-15, "ndtr(+inf) == 1");
    assert!(d[2].is_nan(), "ndtr(NaN) == NaN");
}
/// Checks `ndtri` at five probabilities (0.025, 0.25, 0.5, 0.75, 0.975)
/// against hard-coded reference values with absolute tolerance 1e-12.
#[test]
fn ndtri_known_values_vs_torch() {
    let out = ndtri(&t(&[0.025, 0.25, 0.5, 0.75, 0.975], &[5])).unwrap();
    let d = out.data().unwrap();
    let want = [
        -1.959_963_984_540_054_5,
        -0.674_489_750_196_081_71,
        0.0,
        0.674_489_750_196_081_71,
        1.959_963_984_540_054,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let abs_err = (d[i] - expected).abs();
        assert!(
            abs_err < 1e-12,
            "ndtri idx {i}: got {} want {}",
            d[i],
            expected
        );
    }
}
/// Checks `ndtri` at 0.001, 1e-10, 0.9 and 0.999 against hard-coded reference
/// values with absolute tolerance 1e-11.
#[test]
fn ndtri_cephes_regions_vs_torch() {
    let out = ndtri(&t(&[0.001, 1e-10, 0.9, 0.999], &[4])).unwrap();
    let d = out.data().unwrap();
    let want = [
        -3.090_232_306_167_813_2,
        -6.361_340_902_404_055_7,
        1.281_551_565_544_600_4,
        3.090_232_306_167_813_2,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let abs_err = (d[i] - expected).abs();
        assert!(
            abs_err < 1e-11,
            "ndtri region idx {i}: got {} want {}",
            d[i],
            expected
        );
    }
}
/// Checks `ndtri` at the domain edges: 0 gives -inf, 1 gives +inf, and the
/// out-of-range inputs -0.1 and 1.1 give NaN.
#[test]
fn ndtri_domain_edges_vs_torch() {
    let edge_inputs = t(&[0.0, 1.0, -0.1, 1.1], &[4]);
    let out = ndtri(&edge_inputs).unwrap();
    let d = out.data().unwrap();
    let lower_is_neg_inf = d[0].is_infinite() && d[0] < 0.0;
    assert!(lower_is_neg_inf, "ndtri(0) == -inf");
    let upper_is_pos_inf = d[1].is_infinite() && d[1] > 0.0;
    assert!(upper_is_pos_inf, "ndtri(1) == +inf");
    assert!(d[2].is_nan(), "ndtri(-0.1) == NaN");
    assert!(d[3].is_nan(), "ndtri(1.1) == NaN");
}
/// Checks that `ndtr(ndtri(p))` returns p (within 1e-12) for five
/// probabilities between 0.05 and 0.95.
#[test]
fn ndtr_ndtri_roundtrip() {
    let ps = [0.05, 0.2, 0.5, 0.8, 0.95];
    let quantiles = ndtri(&t(&ps, &[5])).unwrap();
    let recovered = ndtr(&quantiles).unwrap();
    let bd = recovered.data().unwrap();
    for (i, &p) in ps.iter().enumerate() {
        let abs_err = (bd[i] - p).abs();
        assert!(
            abs_err < 1e-12,
            "ndtr(ndtri({})) = {} (expected {})",
            p,
            bd[i],
            p
        );
    }
}
/// Checks `ndtri` on an `f32` tensor of five probabilities against `f32`
/// reference values with absolute tolerance 1e-5.
#[test]
fn ndtri_f32_vs_torch() {
    let probabilities = vec![0.025f32, 0.25, 0.5, 0.75, 0.975];
    let input =
        Tensor::from_storage(TensorStorage::cpu(probabilities), vec![5], false).unwrap();
    let out = ndtri(&input).unwrap();
    let d = out.data().unwrap();
    let want = [-1.959_963_8f32, -0.674_489_8, 0.0, 0.674_489_8, 1.959_964_4];
    for (i, &expected) in want.iter().enumerate() {
        let abs_err = (d[i] - expected).abs();
        assert!(
            abs_err < 1e-5,
            "ndtri_f32 idx {i}: got {} want {}",
            d[i],
            expected
        );
    }
}
// Shared input grid for the i0 / i0e / i1 / i1e reference-value tests: zero,
// seven positive points up to 20, and three negative points (-1, -2, -5).
const I_GRID: [f64; 11] = [0.0, 0.5, 1.0, 2.0, 5.0, 8.0, 10.0, 20.0, -1.0, -2.0, -5.0];
/// Checks `i0` on `I_GRID` against hard-coded reference values, using the
/// mixed tolerance 1e-9 * (1 + |want|).
#[test]
fn i0_known_values_vs_torch() {
    let out = i0(&t(&I_GRID, &[I_GRID.len()])).unwrap();
    let d = out.data().unwrap();
    let want = [
        1.0,
        1.063_483_370_741_323_6,
        1.266_065_877_752_008_2,
        2.279_585_302_336_067,
        27.239_871_823_604_442,
        427.564_115_721_804_74,
        2815.716_628_466_254,
        43_558_282.559_553_534,
        1.266_065_877_752_008_2,
        2.279_585_302_336_067,
        27.239_871_823_604_442,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let tolerance = 1e-9 * (1.0 + expected.abs());
        assert!(
            (d[i] - expected).abs() <= tolerance,
            "i0 idx {i} x={}: got {} want {}",
            I_GRID[i],
            d[i],
            expected
        );
    }
}
/// Checks `i0e` on `I_GRID` against hard-coded reference values, using the
/// mixed tolerance 1e-12 * (1 + |want|).
#[test]
fn i0e_known_values_vs_torch() {
    let out = i0e(&t(&I_GRID, &[I_GRID.len()])).unwrap();
    let d = out.data().unwrap();
    let want = [
        1.0,
        0.645_035_270_449_150_1,
        0.465_759_607_593_640_43,
        0.308_508_322_553_671,
        0.183_540_812_609_328_34,
        0.143_431_781_856_850_3,
        0.127_833_337_163_428_6,
        0.089_780_311_884_826,
        0.465_759_607_593_640_43,
        0.308_508_322_553_671,
        0.183_540_812_609_328_34,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let tolerance = 1e-12 * (1.0 + expected.abs());
        assert!(
            (d[i] - expected).abs() <= tolerance,
            "i0e idx {i} x={}: got {} want {}",
            I_GRID[i],
            d[i],
            expected
        );
    }
}
/// Checks `i1` on `I_GRID` against hard-coded reference values (negated for
/// the negative grid points), using the mixed tolerance 1e-9 * (1 + |want|).
#[test]
fn i1_known_values_vs_torch() {
    let out = i1(&t(&I_GRID, &[I_GRID.len()])).unwrap();
    let d = out.data().unwrap();
    let want = [
        0.0,
        0.257_894_305_390_896_36,
        0.565_159_103_992_485_1,
        1.590_636_854_637_329_5,
        24.335_642_142_450_524,
        399.873_136_782_559_9,
        2670.988_303_701_255,
        42_454_973.385_127_775,
        -0.565_159_103_992_485_1,
        -1.590_636_854_637_329_5,
        -24.335_642_142_450_524,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let tolerance = 1e-9 * (1.0 + expected.abs());
        assert!(
            (d[i] - expected).abs() <= tolerance,
            "i1 idx {i} x={}: got {} want {}",
            I_GRID[i],
            d[i],
            expected
        );
    }
}
/// Checks `i1e` on `I_GRID` against hard-coded reference values (negated for
/// the negative grid points), using the mixed tolerance 1e-12 * (1 + |want|).
#[test]
fn i1e_known_values_vs_torch() {
    let out = i1e(&t(&I_GRID, &[I_GRID.len()])).unwrap();
    let d = out.data().unwrap();
    let want = [
        0.0,
        0.156_420_803_184_871_73,
        0.207_910_415_349_708_5,
        0.215_269_289_248_937_7,
        0.163_972_266_944_542_34,
        0.134_142_493_292_698_12,
        0.121_262_681_384_455_5,
        0.087_506_222_183_288_67,
        -0.207_910_415_349_708_5,
        -0.215_269_289_248_937_7,
        -0.163_972_266_944_542_34,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let tolerance = 1e-12 * (1.0 + expected.abs());
        assert!(
            (d[i] - expected).abs() <= tolerance,
            "i1e idx {i} x={}: got {} want {}",
            I_GRID[i],
            d[i],
            expected
        );
    }
}
/// Checks `i0`, `i0e`, `i1` and `i1e` on the edge inputs 0, NaN, +inf and
/// -inf. The unscaled functions are expected to give NaN at both infinities;
/// the scaled ones are expected to give 0 there.
#[test]
#[allow(
    clippy::float_cmp,
    reason = "i0(0)=1, i1(0)=0 are exact Cephes branch returns (chbevl at x=0 with the limit constants); torch returns the literal endpoint"
)]
fn i_family_edges_vs_torch() {
    // Index layout: [0] = 0, [1] = NaN, [2] = +inf, [3] = -inf.
    let probes = t(&[0.0, f64::NAN, f64::INFINITY, f64::NEG_INFINITY], &[4]);

    let i0_out = i0(&probes).unwrap();
    let i0_vals = i0_out.data().unwrap();
    assert_eq!(i0_vals[0], 1.0, "i0(0) == 1");
    assert!(i0_vals[1].is_nan(), "i0(NaN) == NaN");
    assert!(i0_vals[2].is_nan(), "i0(+inf) == NaN (torch parity)");
    assert!(i0_vals[3].is_nan(), "i0(-inf) == NaN (torch parity, even)");

    let i0e_out = i0e(&probes).unwrap();
    let i0e_vals = i0e_out.data().unwrap();
    assert_eq!(i0e_vals[0], 1.0, "i0e(0) == 1");
    assert!(i0e_vals[1].is_nan(), "i0e(NaN) == NaN");
    assert_eq!(i0e_vals[2], 0.0, "i0e(+inf) == 0");
    assert_eq!(i0e_vals[3], 0.0, "i0e(-inf) == 0");

    let i1_out = i1(&probes).unwrap();
    let i1_vals = i1_out.data().unwrap();
    assert_eq!(i1_vals[0], 0.0, "i1(0) == 0");
    assert!(i1_vals[1].is_nan(), "i1(NaN) == NaN");
    assert!(i1_vals[2].is_nan(), "i1(+inf) == NaN (torch parity)");
    assert!(i1_vals[3].is_nan(), "i1(-inf) == NaN (torch parity, odd)");

    let i1e_out = i1e(&probes).unwrap();
    let i1e_vals = i1e_out.data().unwrap();
    assert_eq!(i1e_vals[0], 0.0, "i1e(0) == 0");
    assert!(i1e_vals[1].is_nan(), "i1e(NaN) == NaN");
    assert_eq!(i1e_vals[2], 0.0, "i1e(+inf) == 0");
    assert_eq!(i1e_vals[3], 0.0, "i1e(-inf) == 0");
}
/// Checks `i0` and `i1` at x = 8, 8.5 and 12 against hard-coded reference
/// values, using the mixed tolerance 1e-9 * (1 + |want|).
#[test]
fn i_family_boundary_at_8_vs_torch() {
    let input = t(&[8.0, 8.5, 12.0], &[3]);

    let i0_out = i0(&input).unwrap();
    let d0 = i0_out.data().unwrap();
    let w0 = [
        427.564_115_721_804_74,
        683.161_926_990_115_5,
        18948.925_349_296_31,
    ];

    let i1_out = i1(&input).unwrap();
    let d1 = i1_out.data().unwrap();
    let w1 = [
        399.873_136_782_559_9,
        641.619_902_540_066_7,
        18141.348_781_638_833,
    ];

    for (i, (&e0, &e1)) in w0.iter().zip(w1.iter()).enumerate() {
        let tol0 = 1e-9 * (1.0 + e0.abs());
        assert!(
            (d0[i] - e0).abs() <= tol0,
            "i0 boundary idx {i}: got {} want {}",
            d0[i],
            e0
        );
        let tol1 = 1e-9 * (1.0 + e1.abs());
        assert!(
            (d1[i] - e1).abs() <= tol1,
            "i1 boundary idx {i}: got {} want {}",
            d1[i],
            e1
        );
    }
}
/// At x = 700, requires `i0e` and `i1e` to be finite and within 1e-12 of
/// hard-coded reference values, and `i0` to exceed 1e300.
#[test]
fn i_family_large_x_scaled_finite_vs_torch() {
    let large = t(&[700.0], &[1]);
    let i0e_out = i0e(&large).unwrap();
    let d0e = i0e_out.data().unwrap();
    let i1e_out = i1e(&large).unwrap();
    let d1e = i1e_out.data().unwrap();

    let i0e_ok = d0e[0].is_finite() && (d0e[0] - 0.015_081_295_651_531_355).abs() <= 1e-12;
    assert!(i0e_ok, "i0e(700) finite & matches torch: got {}", d0e[0]);

    let i1e_ok = d1e[0].is_finite() && (d1e[0] - 0.015_070_519_444_716_846).abs() <= 1e-12;
    assert!(i1e_ok, "i1e(700) finite & matches torch: got {}", d1e[0]);

    let i0_out = i0(&large).unwrap();
    let d0 = i0_out.data().unwrap();
    assert!(d0[0] > 1e300, "i0(700) is huge (>1e300): got {}", d0[0]);
}
/// Runs `i0`, `i0e`, `i1` and `i1e` on a shared `f32` input of seven points
/// and compares each against its own row of `f32` reference values, using the
/// mixed tolerance 1e-4 * (1 + |want|).
#[test]
fn i_family_f32_vs_torch() {
    // Signature shared by all four functions when instantiated at f32.
    type UnaryF32 = fn(&Tensor<f32>) -> FerrotorchResult<Tensor<f32>>;

    let xs = vec![-1.5f32, -0.7, 0.0, 0.3, 2.0, 5.0, 9.0];
    let input =
        Tensor::from_storage(TensorStorage::cpu(xs.clone()), vec![xs.len()], false).unwrap();

    // (label used in failure messages, function under test, expected outputs)
    let cases: [(&str, UnaryF32, [f32; 7]); 4] = [
        (
            "i0",
            i0,
            [
                1.646_723_3,
                1.126_303_1,
                1.0,
                1.022_626_9,
                2.279_585_1,
                27.239_874,
                1_093.588_4,
            ],
        ),
        (
            "i0e",
            i0e,
            [
                0.367_433_64,
                0.559_305_55,
                1.0,
                0.757_580_6,
                0.308_508_3,
                0.183_540_82,
                0.134_959_53,
            ],
        ),
        (
            "i1",
            i1,
            [
                -0.981_666_45,
                -0.371_879_67,
                0.0,
                0.151_693_87,
                1.590_636_8,
                24.335_642,
                1_030.914_8,
            ],
        ),
        (
            "i1e",
            i1e,
            [
                -0.219_039_41,
                -0.184_669_99,
                0.0,
                0.112_377_57,
                0.215_269_28,
                0.163_972_26,
                0.127_225,
            ],
        ),
    ];

    for (name, f, want) in cases {
        let out = f(&input).unwrap();
        let d = out.data().unwrap();
        for (i, (&x, &expected)) in xs.iter().zip(want.iter()).enumerate() {
            let tolerance = 1e-4 * (1.0 + expected.abs());
            assert!(
                (d[i] - expected).abs() <= tolerance,
                "{name} f32 idx {i} x={}: got {} want {}",
                x,
                d[i],
                expected
            );
        }
    }
}
// Input grid for the spherical_bessel_j0 reference-value test: zero, small
// positive points around 0.5, larger positive points up to 10, and two
// negative points.
// NOTE(review): 3.141_592_653_589_79 looks like a deliberately truncated pi —
// the matching expected value in the test is ~1.03e-15 rather than 0 — so it
// should presumably not be replaced by `std::f64::consts::PI` without
// regenerating that expected value. Confirm.
const SBJ0_GRID: [f64; 11] = [
0.0,
0.25,
0.49,
0.5,
1.0,
2.0,
3.141_592_653_589_79,
5.0,
10.0,
-1.0,
-3.0,
];
/// Checks `spherical_bessel_j0` on `SBJ0_GRID` against hard-coded reference
/// values, using the mixed tolerance 1e-12 * (1 + |want|).
#[test]
fn spherical_bessel_j0_known_values_vs_torch() {
    let out = spherical_bessel_j0(&t(&SBJ0_GRID, &[SBJ0_GRID.len()])).unwrap();
    let d = out.data().unwrap();
    let want = [
        1.0,
        0.989_615_837_018_091_7,
        0.960_460_996_267_669_5,
        0.958_851_077_208_406,
        0.841_470_984_807_896_5,
        0.454_648_713_412_840_85,
        1.028_487_619_224_955_5e-15,
        -0.191_784_854_932_627_7,
        -0.054_402_111_088_936_98,
        0.841_470_984_807_896_5,
        0.047_040_002_686_622_4,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let tolerance = 1e-12 * (1.0 + expected.abs());
        assert!(
            (d[i] - expected).abs() <= tolerance,
            "spherical_bessel_j0 idx {i} x={}: got {} want {}",
            SBJ0_GRID[i],
            d[i],
            expected
        );
    }
}
/// Checks `spherical_bessel_j0` at 0 (exactly 1), +inf and -inf (exactly 0)
/// and NaN (NaN).
#[test]
#[allow(
    clippy::float_cmp,
    reason = "j0(0)=1 is the exact Taylor branch return (x2=0); j0(+/-inf)=0 the explicit isinf branch — torch returns the literal endpoints"
)]
fn spherical_bessel_j0_edges_vs_torch() {
    let edge_inputs = t(&[0.0, f64::INFINITY, f64::NEG_INFINITY, f64::NAN], &[4]);
    let out = spherical_bessel_j0(&edge_inputs).unwrap();
    let values = out.data().unwrap();
    assert_eq!(values[0], 1.0, "j0(0) == 1 (Taylor branch)");
    assert_eq!(values[1], 0.0, "j0(+inf) == 0");
    assert_eq!(values[2], 0.0, "j0(-inf) == 0");
    assert!(values[3].is_nan(), "j0(NaN) == NaN");
}
// Shared, strictly positive input grid for the modified Bessel K0/K1 tests
// (plain and scaled variants).
// NOTE(review): the adjacent points 2.0 and 2.0001 presumably straddle an
// implementation branch boundary at x = 2 — confirm against the kernels.
const K_GRID: [f64; 9] = [0.1, 0.5, 1.0, 2.0, 2.0001, 3.0, 5.0, 10.0, 50.0];
/// Checks `modified_bessel_k0` on `K_GRID` against hard-coded reference
/// values, using the mixed tolerance 1e-12 * (1 + |want|).
#[test]
fn modified_bessel_k0_known_values_vs_torch() {
    let out = modified_bessel_k0(&t(&K_GRID, &[K_GRID.len()])).unwrap();
    let d = out.data().unwrap();
    let want = [
        2.427_069_024_702_017,
        0.924_419_071_227_666,
        0.421_024_438_240_708_2,
        0.113_893_872_749_533_4,
        0.113_879_887_080_441_4,
        0.034_739_504_386_279_25,
        0.003_691_098_334_042_594_2,
        1.778_006_231_616_765e-5,
        3.410_167_749_789_495e-23,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let tolerance = 1e-12 * (1.0 + expected.abs());
        assert!(
            (d[i] - expected).abs() <= tolerance,
            "k0 idx {i} x={}: got {} want {}",
            K_GRID[i],
            d[i],
            expected
        );
    }
}
/// Checks `scaled_modified_bessel_k0` on `K_GRID` against hard-coded
/// reference values, using the mixed tolerance 1e-12 * (1 + |want|).
#[test]
fn scaled_modified_bessel_k0_known_values_vs_torch() {
    let out = scaled_modified_bessel_k0(&t(&K_GRID, &[K_GRID.len()])).unwrap();
    let d = out.data().unwrap();
    let want = [
        2.682_326_102_262_895,
        1.524_109_385_773_909_9,
        1.144_463_079_806_894_4,
        0.841_568_215_070_771_2,
        0.841_549_024_872_151_7,
        0.697_761_598_043_851_7,
        0.547_807_564_313_519,
        0.391_631_934_436_598_66,
        0.176_807_155_857_429_32,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let tolerance = 1e-12 * (1.0 + expected.abs());
        assert!(
            (d[i] - expected).abs() <= tolerance,
            "scaled_k0 idx {i} x={}: got {} want {}",
            K_GRID[i],
            d[i],
            expected
        );
    }
}
/// Checks `modified_bessel_k1` on `K_GRID` against hard-coded reference
/// values, using the mixed tolerance 1e-12 * (1 + |want|).
#[test]
fn modified_bessel_k1_known_values_vs_torch() {
    let out = modified_bessel_k1(&t(&K_GRID, &[K_GRID.len()])).unwrap();
    let d = out.data().unwrap();
    let want = [
        9.853_844_780_870_606,
        1.656_441_120_003_300_7,
        0.601_907_230_197_234_6,
        0.139_865_881_816_522_46,
        0.139_847_500_468_811_42,
        0.040_156_431_128_194_19,
        0.004_044_613_445_452_163,
        1.864_877_345_382_558_5e-5,
        3.444_102_226_717_555_5e-23,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let tolerance = 1e-12 * (1.0 + expected.abs());
        assert!(
            (d[i] - expected).abs() <= tolerance,
            "k1 idx {i} x={}: got {} want {}",
            K_GRID[i],
            d[i],
            expected
        );
    }
}
/// Checks `scaled_modified_bessel_k1` on `K_GRID` against hard-coded
/// reference values, using the mixed tolerance 1e-12 * (1 + |want|).
#[test]
fn scaled_modified_bessel_k1_known_values_vs_torch() {
    let out = scaled_modified_bessel_k1(&t(&K_GRID, &[K_GRID.len()])).unwrap();
    let d = out.data().unwrap();
    let want = [
        10.890_182_683_049_698,
        2.731_009_708_211_785_5,
        1.636_153_486_263_258,
        1.033_476_847_068_688_8,
        1.033_444_365_528_781_5,
        0.806_563_480_128_787,
        0.600_273_858_788_312_5,
        0.410_766_570_595_788_7,
        0.178_566_558_558_815_56,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let tolerance = 1e-12 * (1.0 + expected.abs());
        assert!(
            (d[i] - expected).abs() <= tolerance,
            "scaled_k1 idx {i} x={}: got {} want {}",
            K_GRID[i],
            d[i],
            expected
        );
    }
}
/// Runs all four K-family functions on the inputs 0, -1, NaN and 700.
///
/// Every function must return +inf at 0 and NaN at -1 and NaN. At 700 the
/// scaled variants are compared with a reference value, while the unscaled
/// ones only need to be finite, non-negative and below 1e-300.
#[test]
fn k_family_domain_edges_vs_torch() {
    // Signature shared by all four functions when instantiated at f64.
    type UnaryF64 = fn(&Tensor<f64>) -> FerrotorchResult<Tensor<f64>>;

    let input = t(&[0.0, -1.0, f64::NAN, 700.0], &[4]);
    // (label, function under test, reference value at x = 700; the reference
    // is only consulted for labels starting with "scaled").
    let fns: [(&str, UnaryF64, f64); 4] = [
        ("k0", modified_bessel_k0, 0.047_362_369_454_613_57),
        (
            "scaled_k0",
            scaled_modified_bessel_k0,
            0.047_362_369_454_613_57,
        ),
        ("k1", modified_bessel_k1, 0.047_396_187_653_494_55),
        (
            "scaled_k1",
            scaled_modified_bessel_k1,
            0.047_396_187_653_494_55,
        ),
    ];

    for (name, f, scaled_at_700) in fns {
        let out = f(&input).unwrap();
        let d = out.data().unwrap();

        let pole_is_pos_inf = d[0].is_infinite() && d[0] > 0.0;
        assert!(pole_is_pos_inf, "{name}(0) == +inf: got {}", d[0]);
        assert!(d[1].is_nan(), "{name}(-1) == NaN: got {}", d[1]);
        assert!(d[2].is_nan(), "{name}(NaN) == NaN: got {}", d[2]);

        if !name.starts_with("scaled") {
            let tiny_nonneg = d[3].is_finite() && (0.0..1e-300).contains(&d[3]);
            assert!(
                tiny_nonneg,
                "{name}(700) underflows finite-nonneg: got {}",
                d[3]
            );
        } else {
            let tolerance = 1e-12 * (1.0 + scaled_at_700.abs());
            assert!(
                (d[3] - scaled_at_700).abs() <= tolerance,
                "{name}(700) ~ sqrt(pi/2x): got {} want {}",
                d[3],
                scaled_at_700
            );
        }
    }
}
/// `f32` coverage for `spherical_bessel_j0` (seven points) and the four
/// K-family functions (x = 1 and x = 3), each compared against `f32`
/// reference values using the mixed tolerance 1e-4 * (1 + |want|).
#[test]
fn spherical_and_k_family_f32_vs_torch() {
    // Signature shared by the K-family functions when instantiated at f32.
    type UnaryF32 = fn(&Tensor<f32>) -> FerrotorchResult<Tensor<f32>>;
    // Builds a 1-D f32 tensor from CPU storage holding a copy of `vals`.
    let tensor_of = |vals: &[f32]| {
        Tensor::from_storage(TensorStorage::cpu(vals.to_vec()), vec![vals.len()], false).unwrap()
    };

    // Part 1: spherical_bessel_j0.
    let xs = [0.0f32, 0.25, 0.5, 1.0, 2.0, 5.0, -3.0];
    let j0_out = spherical_bessel_j0(&tensor_of(&xs)).unwrap();
    let d = j0_out.data().unwrap();
    let want = [
        1.0f32,
        0.989_615_86,
        0.958_851_1,
        0.841_470_96,
        0.454_648_7,
        -0.191_784_86,
        0.047_04,
    ];
    for (i, (&x, &expected)) in xs.iter().zip(want.iter()).enumerate() {
        let tolerance = 1e-4 * (1.0 + expected.abs());
        assert!(
            (d[i] - expected).abs() <= tolerance,
            "spherical_bessel_j0 f32 idx {i} x={}: got {} want {}",
            x,
            d[i],
            expected
        );
    }

    // Part 2: K family at x = 1 and x = 3.
    let kx = [1.0f32, 3.0];
    let kin = tensor_of(&kx);
    let kcases: [(&str, UnaryF32, [f32; 2]); 4] = [
        ("k0", modified_bessel_k0, [0.421_024_44, 0.034_739_504]),
        (
            "scaled_k0",
            scaled_modified_bessel_k0,
            [1.144_463_1, 0.697_761_6],
        ),
        ("k1", modified_bessel_k1, [0.601_907_23, 0.040_156_43]),
        (
            "scaled_k1",
            scaled_modified_bessel_k1,
            [1.636_153_5, 0.806_563_5],
        ),
    ];
    for (name, f, want) in kcases {
        let out = f(&kin).unwrap();
        let d = out.data().unwrap();
        for (i, (&x, &expected)) in kx.iter().zip(want.iter()).enumerate() {
            let tolerance = 1e-4 * (1.0 + expected.abs());
            assert!(
                (d[i] - expected).abs() <= tolerance,
                "{name} f32 idx {i} x={}: got {} want {}",
                x,
                d[i],
                expected
            );
        }
    }
}
/// Checks `zeta(x, q)` on nine (x, q) pairs against hard-coded reference
/// values, using the mixed tolerance 1e-10 * (1 + |want|).
#[test]
fn zeta_known_values_vs_torch() {
    let xs = [2.0, 2.0, 3.0, 4.0, 1.0001, 1.5, 10.0, 2.5, 5.0];
    let qs = [1.0, 2.0, 1.0, 0.5, 1.0, 2.0, 0.25, 3.0, 1.0];
    let out = zeta(&t(&xs, &[9]), &t(&qs, &[9])).unwrap();
    let d = out.data().unwrap();
    let want = [
        1.6449340668482266,
        0.6449340668482266,
        1.202056903159594,
        16.23484850566707,
        10000.57722294754,
        1.6123753486854886,
        1048576.107683115,
        0.1647105619542803,
        1.0369277551433704,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let tolerance = 1e-10 * (1.0 + expected.abs());
        assert!(
            (d[i] - expected).abs() <= tolerance,
            "zeta idx {i} x={} q={}: got {} want {}",
            xs[i],
            qs[i],
            d[i],
            expected
        );
    }
}
/// Checks `zeta(2, 1)` against pi^2 / 6, using the mixed tolerance
/// 1e-12 * (1 + |want|).
#[test]
fn zeta_2_1_is_pi_squared_over_six() {
    let exponent = t(&[2.0], &[1]);
    let offset = t(&[1.0], &[1]);
    let out = zeta(&exponent, &offset).unwrap();
    let got = out.data().unwrap()[0];
    let want = std::f64::consts::PI * std::f64::consts::PI / 6.0;
    let tolerance = 1e-12 * (1.0 + want.abs());
    assert!(
        (got - want).abs() <= tolerance,
        "zeta(2,1) got {got} want pi^2/6 = {want}"
    );
}
/// Checks `zeta` on six edge-case (x, q) pairs: x = 1 and non-positive
/// integer q must give +inf; x = 0.5 and negative non-integer q must give NaN.
#[test]
fn zeta_edge_ladder_vs_torch() {
    let is_pos_inf = |v: f64| v.is_infinite() && v > 0.0;

    let xs = [1.0, 0.5, 2.0, 2.0, 3.0, 2.5];
    let qs = [2.0, 1.0, 0.0, -1.0, -2.0, -1.5];
    let out = zeta(&t(&xs, &[6]), &t(&qs, &[6])).unwrap();
    let d = out.data().unwrap();

    assert!(is_pos_inf(d[0]), "zeta(1,q) == +inf: {}", d[0]);
    assert!(d[1].is_nan(), "zeta(0.5,q) == NaN: {}", d[1]);
    assert!(is_pos_inf(d[2]), "zeta(2, q=0 integer) == +inf: {}", d[2]);
    assert!(is_pos_inf(d[3]), "zeta(2, q=-1 integer) == +inf: {}", d[3]);
    assert!(is_pos_inf(d[4]), "zeta(3, q=-2 integer) == +inf: {}", d[4]);
    assert!(
        d[5].is_nan(),
        "zeta(2.5, q=-1.5 non-integer) == NaN: {}",
        d[5]
    );
}
/// Checks `zeta` on four `f32` (x, q) pairs against `f32` reference values,
/// using the mixed tolerance 1e-4 * (1 + |want|).
#[test]
fn zeta_f32_vs_torch() {
    // Builds a 4-element f32 tensor from CPU storage.
    let tensor_of =
        |vals: &Vec<f32>| Tensor::from_storage(TensorStorage::cpu(vals.clone()), vec![4], false).unwrap();

    let xs = vec![2.0f32, 3.0, 1.5, 4.0];
    let qs = vec![1.0f32, 2.0, 2.0, 0.5];
    let x = tensor_of(&xs);
    let q = tensor_of(&qs);
    let out = zeta(&x, &q).unwrap();
    let d = out.data().unwrap();
    let want = [1.6449341f32, 0.20205691, 1.6123753, 16.234848];
    for (i, &expected) in want.iter().enumerate() {
        let tolerance = 1e-4 * (1.0 + expected.abs());
        assert!(
            (d[i] - expected).abs() <= tolerance,
            "zeta f32 idx {i} x={} q={}: got {} want {}",
            xs[i],
            qs[i],
            d[i],
            expected
        );
    }
}
/// Asserts that `zeta(2, 1)` on tensors built by the `t` helper succeeds and
/// returns a finite value.
///
/// NOTE(review): despite the name, no CUDA tensor is constructed here, so a
/// "not implemented" path is not exercised by this test — confirm whether
/// that is covered elsewhere.
#[test]
fn zeta_cuda_not_implemented() {
    let exponent = t(&[2.0], &[1]);
    let offset = t(&[1.0], &[1]);
    let out = zeta(&exponent, &offset).unwrap();
    let value = out.data().unwrap()[0];
    assert!(value.is_finite());
}
/// Checks `airy_ai` at thirteen points from -5 to 100 against hard-coded
/// reference values, using the mixed tolerance 1e-10 * (1 + |want|).
#[test]
fn airy_ai_known_values_vs_torch() {
    let xs = [
        -5.0, -2.5, -2.09, -2.0, -1.0, 0.0, 1.0, 2.0, 2.09, 5.0, 8.0, 10.0, 100.0,
    ];
    let out = airy_ai(&t(&xs, &[13])).unwrap();
    let d = out.data().unwrap();
    let want = [
        0.35076100902415286,
        -0.11232483666261353,
        0.17005055173203007,
        0.22740742820168564,
        0.5355608832923521,
        0.3550280538878172,
        0.13529241631288144,
        0.03492413042327433,
        0.03042031836319837,
        0.00010834442813607433,
        4.692207616099224e-08,
        1.1047532552898654e-10,
        2.6344821520882847e-291,
    ];
    for (i, &expected) in want.iter().enumerate() {
        let tolerance = 1e-10 * (1.0 + expected.abs());
        assert!(
            (d[i] - expected).abs() <= tolerance,
            "airy_ai idx {i} x={}: got {} want {}",
            xs[i],
            d[i],
            expected
        );
    }
}
/// Checks `airy_ai(0)` against a hard-coded reference value with absolute
/// tolerance 1e-12.
#[test]
fn airy_ai_zero_vs_torch() {
    let origin = t(&[0.0], &[1]);
    let out = airy_ai(&origin).unwrap();
    let got = out.data().unwrap()[0];
    let want = 0.3550280538878172;
    let abs_err = (got - want).abs();
    assert!(abs_err <= 1e-12, "airy_ai(0) got {got} want {want}");
}
/// Checks `airy_ai` on edge inputs: +inf, -inf and NaN must give NaN, and
/// x = 200 must give exactly 0.
#[test]
fn airy_ai_edges_vs_torch() {
    let edge_inputs = t(
        &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN, 200.0],
        &[4],
    );
    let out = airy_ai(&edge_inputs).unwrap();
    let d = out.data().unwrap();
    assert!(d[0].is_nan(), "airy_ai(+inf) == NaN: {}", d[0]);
    assert!(d[1].is_nan(), "airy_ai(-inf) == NaN: {}", d[1]);
    assert!(d[2].is_nan(), "airy_ai(NaN) == NaN: {}", d[2]);
    assert_eq!(d[3], 0.0, "airy_ai(200) == 0 (x>103.892 branch)");
}
/// Checks `airy_ai` on an `f32` tensor of seven points against `f32`
/// reference values, using the mixed tolerance 1e-4 * (1 + |want|).
#[test]
fn airy_ai_f32_vs_torch() {
    let xs = vec![-5.0f32, -2.0, -1.0, 0.0, 1.0, 2.0, 5.0];
    let storage = TensorStorage::cpu(xs.clone());
    let input = Tensor::from_storage(storage, vec![7], false).unwrap();
    let out = airy_ai(&input).unwrap();
    let d = out.data().unwrap();
    let want = [
        0.35076096653938293f32,
        0.22740741074085236,
        0.5355609059333801,
        0.35502806305885315,
        0.13529238104820251,
        0.03492411598563194,
        0.00010834442946361378,
    ];
    for (i, (&x, &expected)) in xs.iter().zip(want.iter()).enumerate() {
        let tolerance = 1e-4 * (1.0 + expected.abs());
        assert!(
            (d[i] - expected).abs() <= tolerance,
            "airy_ai f32 idx {i} x={}: got {} want {}",
            x,
            d[i],
            expected
        );
    }
}
}