use crate::dtype::Float;
use crate::error::FerrotorchResult;
use crate::ops::elementwise::{binary_map, unary_map};
use crate::tensor::Tensor;
#[inline(always)]
fn nt_zero<T: num_traits::Zero>() -> T {
<T as num_traits::Zero>::zero()
}
#[inline(always)]
fn nt_one<T: num_traits::One>() -> T {
<T as num_traits::One>::one()
}
// Coefficients of the rational/polynomial approximation evaluated by
// `erf_scalar`:
//
//   erf(x) ~= 1 - (A1*t + A2*t^2 + A3*t^3 + A4*t^4 + A5*t^5) * exp(-x^2),
//   t = 1 / (1 + P*x),  x >= 0
//
// The values match Abramowitz & Stegun formula 7.1.26, whose stated maximum
// absolute error is about 1.5e-7.
const ERF_A1: f64 = 0.254829592;
const ERF_A2: f64 = -0.284496736;
const ERF_A3: f64 = 1.421413741;
const ERF_A4: f64 = -1.453152027;
const ERF_A5: f64 = 1.061405429;
// Scale factor `P` inside `t = 1 / (1 + P*x)`.
const ERF_P: f64 = 0.3275911;
// Parameters of the Lanczos series used by `lgamma_scalar`.
//
// `LANCZOS_G` is the shift `g` added to the argument (`t = z + g + 0.5`).
const LANCZOS_G: f64 = 7.0;
// Series coefficients: index 0 is the constant term, and coefficient `i`
// (for `i >= 1`) is divided by `z + i` when the series is summed.
const LANCZOS_COEFFICIENTS: [f64; 9] = [
0.999_999_999_999_809_9,
676.5203681218851,
-1259.1392167224028,
771.323_428_777_653_1,
-176.615_029_162_140_6,
12.507343278686905,
-0.13857109526572012,
9.984_369_578_019_572e-6,
1.5056327351493116e-7,
];
/// Approximates the error function `erf(x)` for a single value.
///
/// Evaluates `1 - poly(t) * t * exp(-x^2)` with `t = 1 / (1 + ERF_P * |x|)` on
/// the magnitude of `x` and restores the sign afterwards (erf is odd).
/// An input that compares equal to zero returns exactly zero.
fn erf_scalar<T: Float>(x: T) -> T {
    let zero = nt_zero::<T>();
    if x == zero {
        // Short-circuit so the result is exactly 0 instead of the small
        // residual the polynomial leaves at the origin.
        return zero;
    }

    let one = nt_one::<T>();
    let cast = |value: f64| T::from(value).unwrap();

    let magnitude = x.abs();
    let t = one / (one + cast(ERF_P) * magnitude);

    // Horner scheme, starting from the innermost coefficient (A5) and working
    // outwards to A1.
    let poly = [ERF_A4, ERF_A3, ERF_A2, ERF_A1]
        .iter()
        .fold(cast(ERF_A5), |inner, &coeff| cast(coeff) + t * inner);

    let unsigned = one - poly * t * (-magnitude * magnitude).exp();
    if x < zero {
        -unsigned
    } else {
        unsigned
    }
}
/// Approximates the inverse error function for a single value.
///
/// Uses a closed-form, low-precision approximation (the constant `0.147`
/// matches Winitzki's formula); it is not refined iteratively.
///
/// Special cases:
/// * an input equal to zero returns zero;
/// * `1` returns `+inf` and `-1` returns `-inf`;
/// * inputs with `|x| > 1` are outside the domain and return NaN;
/// * NaN propagates through the arithmetic.
fn erfinv_scalar<T: Float>(x: T) -> T {
    let zero = nt_zero::<T>();
    let one = nt_one::<T>();
    if x == zero {
        return zero;
    }

    let ax = x.abs();
    // The inverse is undefined outside [-1, 1]; such inputs must not be
    // conflated with the ±1 endpoints, which legitimately map to ±inf.
    if ax > one {
        return T::nan();
    }
    if ax == one {
        return if x < zero {
            T::neg_infinity()
        } else {
            T::infinity()
        };
    }

    let sign = if x < zero { -one } else { one };
    let a = T::from(0.147).unwrap();
    let two = T::from(2.0).unwrap();
    let pi = T::from(std::f64::consts::PI).unwrap();

    // ln(1 - x^2) is shared by both terms of the closed-form expression.
    let ln_term = (one - ax * ax).ln();
    let b = two / (pi * a) + ln_term / two;
    let c = ln_term / a;
    sign * (-b + (b * b - c).sqrt()).sqrt()
}
/// Computes `ln|Gamma(x)|` for a single value using the Lanczos series.
///
/// * For `x >= 0.5` the series defined by `LANCZOS_G` and
///   `LANCZOS_COEFFICIENTS` is evaluated directly.
/// * For `x < 0.5` the reflection formula
///   `lgamma(x) = ln(pi / |sin(pi x)|) - lgamma(1 - x)` is applied; the
///   recursive call always lands in the `x >= 0.5` branch.
///
/// Special cases:
/// * zero and negative integers are poles of Gamma and return `+inf`;
/// * `+inf` and `-inf` return `+inf`;
/// * NaN propagates through the arithmetic.
fn lgamma_scalar<T: Float>(x: T) -> T {
    let zero = nt_zero::<T>();
    let one = nt_one::<T>();
    let half = T::from(0.5).unwrap();

    // The formulas below would produce `inf - inf = NaN` for infinite input.
    if x.is_infinite() {
        return T::infinity();
    }
    // Detect poles by testing for a non-positive integer directly:
    // `sin(pi * x)` is not exactly zero for x = -1, -2, ... in floating point,
    // so a test on the sine alone misses them.
    if x <= zero && x == x.floor() {
        return T::infinity();
    }

    if x < half {
        let pi = T::from(std::f64::consts::PI).unwrap();
        let sin_pi_x = (pi * x).sin();
        return (pi / sin_pi_x.abs()).ln() - lgamma_scalar(one - x);
    }

    // 0.5 * ln(2 * pi)
    let half_ln_2pi = T::from(0.9189385332046727).unwrap();
    let g = T::from(LANCZOS_G).unwrap();

    let z = x - one;
    let mut sum = T::from(LANCZOS_COEFFICIENTS[0]).unwrap();
    for (i, &coeff) in LANCZOS_COEFFICIENTS.iter().enumerate().skip(1) {
        sum += T::from(coeff).unwrap() / (z + T::from(i as f64).unwrap());
    }
    let t = z + g + half;
    half_ln_2pi + t.ln() * (z + half) - t + sum.ln()
}
/// Computes the digamma function `psi(x) = d/dx ln Gamma(x)` for one value.
///
/// * For `x < 0.5` the reflection formula
///   `psi(x) = psi(1 - x) - pi * cot(pi x)` is applied.
/// * Otherwise the recurrence `psi(z) = psi(z + 1) - 1/z` shifts the argument
///   up to at least 6, where the asymptotic expansion
///   `ln z - 1/(2z) - 1/(12 z^2) + 1/(120 z^4) - 1/(252 z^6)` is evaluated.
///
/// Special cases:
/// * negative integers are poles with no well-defined sign and return NaN;
/// * at zero the cotangent term diverges, giving `-inf` for `+0.0`;
/// * NaN propagates through the arithmetic.
// `a = a - b` is kept in place of compound assignment, as in the code this
// replaces (presumably `Float` does not guarantee `SubAssign` — TODO confirm).
#[allow(clippy::assign_op_pattern)]
fn digamma_scalar<T: Float>(x: T) -> T {
    let zero = nt_zero::<T>();
    let one = nt_one::<T>();
    let half = T::from(0.5).unwrap();

    // `sin(pi * x)` is tiny but non-zero at negative integers in floating
    // point, so the reflection below would return a huge finite value instead
    // of signalling the pole.
    if x < zero && x == x.floor() {
        return T::nan();
    }

    if x < half {
        let pi = T::from(std::f64::consts::PI).unwrap();
        let pi_x = pi * x;
        let cot = pi_x.cos() / pi_x.sin();
        return digamma_scalar(one - x) - pi * cot;
    }

    // Shift the argument upwards until the asymptotic series is accurate.
    let six = T::from(6.0).unwrap();
    let mut result = zero;
    let mut z = x;
    while z < six {
        result = result - one / z;
        z = z + one;
    }

    let z2 = z * z;
    let z4 = z2 * z2;
    let z6 = z4 * z2;
    result + z.ln()
        - one / (T::from(2.0).unwrap() * z)
        - one / (T::from(12.0).unwrap() * z2)
        + one / (T::from(120.0).unwrap() * z4)
        - one / (T::from(252.0).unwrap() * z6)
}
/// Computes the normalized sinc function `sin(pi x) / (pi x)` for one value.
///
/// An input that compares equal to zero returns exactly one, avoiding the
/// `0 / 0` the formula would otherwise produce.
fn sinc_scalar<T: Float>(x: T) -> T {
    if x == nt_zero::<T>() {
        nt_one::<T>()
    } else {
        let scaled = T::from(std::f64::consts::PI).unwrap() * x;
        scaled.sin() / scaled
    }
}
/// Computes `x * ln(y)` for one pair of values, with `xlogy(0, y) = 0`.
///
/// Evaluation order of the special cases:
/// 1. if `y` is NaN the result is NaN, even when `x` is zero;
/// 2. otherwise, if `x` compares equal to zero the result is zero (this covers
///    `y = 0` and `y = inf`, where `0 * ln(y)` would be NaN);
/// 3. otherwise the result is `x * ln(y)`.
fn xlogy_scalar<T: Float>(x: T, y: T) -> T {
    // A NaN in `y` must not be masked by the `x == 0` shortcut.
    if y.is_nan() {
        return y;
    }
    if x == nt_zero::<T>() {
        nt_zero::<T>()
    } else {
        x * y.ln()
    }
}
/// Applies [`erf_scalar`] to every element of `input` via `unary_map`.
///
/// # Errors
///
/// Returns whatever error `unary_map` reports.
pub fn erf<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    unary_map(input, erf_scalar::<T>)
}
/// Computes the complementary error function `1 - erf(x)` for every element
/// of `input` via `unary_map`.
///
/// The value is formed by subtracting [`erf_scalar`] from one, so precision is
/// lost where `erf(x)` is close to one.
///
/// # Errors
///
/// Returns whatever error `unary_map` reports.
pub fn erfc<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let complement = |value: T| nt_one::<T>() - erf_scalar(value);
    unary_map(input, complement)
}
/// Applies [`erfinv_scalar`] to every element of `input` via `unary_map`.
///
/// # Errors
///
/// Returns whatever error `unary_map` reports.
pub fn erfinv<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    unary_map(input, erfinv_scalar::<T>)
}
/// Applies [`lgamma_scalar`] (log of the absolute value of the gamma function)
/// to every element of `input` via `unary_map`.
///
/// # Errors
///
/// Returns whatever error `unary_map` reports.
pub fn lgamma<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    unary_map(input, lgamma_scalar::<T>)
}
/// Applies [`digamma_scalar`] to every element of `input` via `unary_map`.
///
/// # Errors
///
/// Returns whatever error `unary_map` reports.
pub fn digamma<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    unary_map(input, digamma_scalar::<T>)
}
/// Computes `ln(1 + x)` for every element of `input` using the scalar type's
/// `ln_1p`, which stays accurate for `x` close to zero.
///
/// # Errors
///
/// Returns whatever error `unary_map` reports.
pub fn log1p<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let log_one_plus = |value: T| value.ln_1p();
    unary_map(input, log_one_plus)
}
/// Computes `exp(x) - 1` for every element of `input` using the scalar type's
/// `exp_m1`, which stays accurate for `x` close to zero.
///
/// # Errors
///
/// Returns whatever error `unary_map` reports.
pub fn expm1<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let exp_minus_one = |value: T| value.exp_m1();
    unary_map(input, exp_minus_one)
}
/// Applies [`sinc_scalar`] (normalized sinc, `sin(pi x) / (pi x)`) to every
/// element of `input` via `unary_map`.
///
/// # Errors
///
/// Returns whatever error `unary_map` reports.
pub fn sinc<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    unary_map(input, sinc_scalar::<T>)
}
/// Combines `x` and `y` with [`xlogy_scalar`] (`x * ln(y)`, zero when `x` is
/// zero) via `binary_map`.
///
/// # Errors
///
/// Returns whatever error `binary_map` reports.
pub fn xlogy<T: Float>(x: &Tensor<T>, y: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    binary_map(x, y, xlogy_scalar::<T>)
}
// Unit tests for the special-function wrappers defined in this file.
// Unless noted otherwise they run on `f64` tensors built by the helper `t`.
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::TensorStorage;
// Builds an `f64` CPU tensor from `data` with the given `shape`; the final
// `false` argument is passed straight to `Tensor::from_storage`.
fn t(data: &[f64], shape: &[usize]) -> Tensor<f64> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
}
// erf(0) must be zero.
#[test]
fn erf_zero() {
let input = t(&[0.0], &[1]);
let result = erf(&input).unwrap();
assert!((result.data().unwrap()[0]).abs() < 1e-10);
}
// erf is odd: erf(x) + erf(-x) must cancel.
#[test]
fn erf_symmetry() {
let input = t(&[0.5, 1.0, 2.0], &[3]);
let neg_input = t(&[-0.5, -1.0, -2.0], &[3]);
let pos = erf(&input).unwrap();
let neg = erf(&neg_input).unwrap();
let pd = pos.data().unwrap();
let nd = neg.data().unwrap();
for i in 0..3 {
assert!(
(pd[i] + nd[i]).abs() < 1e-6,
"erf({}) + erf({}) = {} (expected 0)",
input.data().unwrap()[i],
neg_input.data().unwrap()[i],
pd[i] + nd[i],
);
}
}
// erf(+inf) must be 1.
#[test]
fn erf_large_value() {
let input = t(&[f64::INFINITY], &[1]);
let result = erf(&input).unwrap();
assert!((result.data().unwrap()[0] - 1.0).abs() < 1e-6);
}
// erf(1) against a reference value; the 2e-7 tolerance reflects that the
// implementation is an approximation, not an exact evaluation.
#[test]
fn erf_known_values() {
let input = t(&[1.0], &[1]);
let result = erf(&input).unwrap();
assert!(
(result.data().unwrap()[0] - 0.8427007929).abs() < 2e-7,
"erf(1) = {}",
result.data().unwrap()[0]
);
}
// erf(x) + erfc(x) must equal 1 for positive, negative and zero inputs.
#[test]
fn erfc_is_one_minus_erf() {
let input = t(&[0.0, 0.5, 1.0, -0.5, 2.0], &[5]);
let erf_result = erf(&input).unwrap();
let erfc_result = erfc(&input).unwrap();
let ed = erf_result.data().unwrap();
let cd = erfc_result.data().unwrap();
for i in 0..5 {
assert!(
(ed[i] + cd[i] - 1.0).abs() < 1e-10,
"erf({0}) + erfc({0}) = {1} (expected 1.0)",
input.data().unwrap()[i],
ed[i] + cd[i],
);
}
}
// erfinv(0) must be zero.
#[test]
fn erfinv_zero() {
let input = t(&[0.0], &[1]);
let result = erfinv(&input).unwrap();
assert!(result.data().unwrap()[0].abs() < 1e-10);
}
// erfinv(erf(x)) must recover x within 0.01; the loose tolerance reflects
// the low-precision erfinv approximation.
#[test]
fn erfinv_roundtrip() {
let xs = t(&[0.1, 0.5, 1.0, -0.3, -1.5], &[5]);
let erf_xs = erf(&xs).unwrap();
let roundtrip = erfinv(&erf_xs).unwrap();
let orig = xs.data().unwrap();
let rt = roundtrip.data().unwrap();
for i in 0..5 {
assert!(
(orig[i] - rt[i]).abs() < 0.01,
"erfinv(erf({})) = {} (expected {})",
orig[i],
rt[i],
orig[i],
);
}
}
// The domain endpoints map to signed infinities.
#[test]
fn erfinv_boundary() {
let input = t(&[1.0, -1.0], &[2]);
let result = erfinv(&input).unwrap();
let d = result.data().unwrap();
assert!(d[0].is_infinite() && d[0] > 0.0, "erfinv(1) should be +inf");
assert!(
d[1].is_infinite() && d[1] < 0.0,
"erfinv(-1) should be -inf"
);
}
// Gamma(1) = Gamma(2) = 1, so lgamma must be zero at both points.
#[test]
fn lgamma_at_one_and_two() {
let input = t(&[1.0, 2.0], &[2]);
let result = lgamma(&input).unwrap();
let d = result.data().unwrap();
assert!(d[0].abs() < 1e-10, "lgamma(1) = {} (expected 0)", d[0]);
assert!(d[1].abs() < 1e-10, "lgamma(2) = {} (expected 0)", d[1]);
}
// lgamma(0.5) = ln(sqrt(pi)).
#[test]
fn lgamma_known_values() {
let input = t(&[0.5], &[1]);
let result = lgamma(&input).unwrap();
let expected = 0.5723649429247001;
assert!(
(result.data().unwrap()[0] - expected).abs() < 1e-8,
"lgamma(0.5) = {} (expected {})",
result.data().unwrap()[0],
expected,
);
}
// Gamma(6) = 5! = 120, so lgamma(6) = ln(120).
#[test]
fn lgamma_factorial() {
let input = t(&[6.0], &[1]);
let result = lgamma(&input).unwrap();
let expected = (120.0f64).ln();
assert!(
(result.data().unwrap()[0] - expected).abs() < 1e-8,
"lgamma(6) = {} (expected {})",
result.data().unwrap()[0],
expected,
);
}
// digamma(1) equals minus the Euler-Mascheroni constant.
#[test]
fn digamma_known_values() {
let input = t(&[1.0], &[1]);
let result = digamma(&input).unwrap();
let expected = -0.5772156649015329;
assert!(
(result.data().unwrap()[0] - expected).abs() < 1e-6,
"digamma(1) = {} (expected {})",
result.data().unwrap()[0],
expected,
);
}
// Recurrence relation: psi(x + 1) - psi(x) = 1 / x.
#[test]
fn digamma_recurrence() {
let x_val = 2.5;
let input_x = t(&[x_val], &[1]);
let input_x1 = t(&[x_val + 1.0], &[1]);
let psi_x = digamma(&input_x).unwrap().data().unwrap()[0];
let psi_x1 = digamma(&input_x1).unwrap().data().unwrap()[0];
assert!(
(psi_x1 - psi_x - 1.0 / x_val).abs() < 1e-8,
"psi({}) - psi({}) = {} (expected {})",
x_val + 1.0,
x_val,
psi_x1 - psi_x,
1.0 / x_val,
);
}
// log1p(0) must be zero.
#[test]
fn log1p_zero() {
let input = t(&[0.0], &[1]);
let result = log1p(&input).unwrap();
assert!(result.data().unwrap()[0].abs() < 1e-15);
}
// For tiny x, log1p(x) ~= x; checks that precision is kept near zero.
#[test]
fn log1p_small() {
let small = 1e-10;
let input = t(&[small], &[1]);
let result = log1p(&input).unwrap();
assert!(
(result.data().unwrap()[0] - small).abs() < 1e-15,
"log1p({small}) = {} (expected ~{small})",
result.data().unwrap()[0],
);
}
// log1p(1) = ln(2).
#[test]
fn log1p_known() {
let input = t(&[1.0], &[1]);
let result = log1p(&input).unwrap();
assert!((result.data().unwrap()[0] - std::f64::consts::LN_2).abs() < 1e-15,);
}
// expm1(0) must be zero.
#[test]
fn expm1_zero() {
let input = t(&[0.0], &[1]);
let result = expm1(&input).unwrap();
assert!(result.data().unwrap()[0].abs() < 1e-15);
}
// For tiny x, expm1(x) ~= x; checks that precision is kept near zero.
#[test]
fn expm1_small() {
let small = 1e-10;
let input = t(&[small], &[1]);
let result = expm1(&input).unwrap();
assert!(
(result.data().unwrap()[0] - small).abs() < 1e-15,
"expm1({small}) = {} (expected ~{small})",
result.data().unwrap()[0],
);
}
// expm1(1) = e - 1.
#[test]
fn expm1_known() {
let input = t(&[1.0], &[1]);
let result = expm1(&input).unwrap();
let expected = std::f64::consts::E - 1.0;
assert!((result.data().unwrap()[0] - expected).abs() < 1e-14,);
}
// sinc(0) is defined as 1 (the limit of sin(pi x) / (pi x)).
#[test]
fn sinc_zero() {
let input = t(&[0.0], &[1]);
let result = sinc(&input).unwrap();
assert!(
(result.data().unwrap()[0] - 1.0).abs() < 1e-15,
"sinc(0) = {} (expected 1)",
result.data().unwrap()[0],
);
}
// Normalized sinc vanishes at every non-zero integer.
#[test]
fn sinc_integer() {
let input = t(&[1.0, 2.0, -1.0, -3.0], &[4]);
let result = sinc(&input).unwrap();
let d = result.data().unwrap();
for i in 0..4 {
assert!(
d[i].abs() < 1e-15,
"sinc({}) = {} (expected 0)",
input.data().unwrap()[i],
d[i],
);
}
}
// sinc(0.5) = sin(pi / 2) / (pi / 2) = 2 / pi.
#[test]
fn sinc_half() {
let input = t(&[0.5], &[1]);
let result = sinc(&input).unwrap();
let expected = 2.0 / std::f64::consts::PI;
assert!(
(result.data().unwrap()[0] - expected).abs() < 1e-15,
"sinc(0.5) = {} (expected {})",
result.data().unwrap()[0],
expected,
);
}
// xlogy(0, y) must be exactly 0, including y = 0 and y = inf where the
// naive product 0 * ln(y) would be NaN.
#[test]
fn xlogy_zero_x() {
let x = t(&[0.0, 0.0, 0.0], &[3]);
let y = t(&[1.0, 0.0, f64::INFINITY], &[3]);
let result = xlogy(&x, &y).unwrap();
let d = result.data().unwrap();
for i in 0..3 {
assert!(
d[i] == 0.0,
"xlogy(0, {}) = {} (expected 0)",
y.data().unwrap()[i],
d[i],
);
}
}
// xlogy(2, e) = 2 * ln(e) = 2.
#[test]
fn xlogy_normal() {
let x = t(&[2.0], &[1]);
let y = t(&[std::f64::consts::E], &[1]);
let result = xlogy(&x, &y).unwrap();
assert!(
(result.data().unwrap()[0] - 2.0).abs() < 1e-14,
"xlogy(2, e) = {} (expected 2)",
result.data().unwrap()[0],
);
}
// NOTE(review): despite the name, both operands have shape [2], so this
// checks an elementwise call on equal shapes rather than broadcasting.
#[test]
fn xlogy_broadcast() {
let x = t(&[2.0, 3.0], &[2]);
let y = t(&[std::f64::consts::E, std::f64::consts::E], &[2]);
let result = xlogy(&x, &y).unwrap();
let d = result.data().unwrap();
assert!((d[0] - 2.0).abs() < 1e-14);
assert!((d[1] - 3.0).abs() < 1e-14);
}
// Same erf checks on an `f32` tensor, with tolerances loosened accordingly.
#[test]
fn erf_f32() {
let input =
Tensor::from_storage(TensorStorage::cpu(vec![0.0f32, 1.0, -1.0]), vec![3], false)
.unwrap();
let result = erf(&input).unwrap();
let d = result.data().unwrap();
assert!(d[0].abs() < 1e-6);
assert!((d[1] - 0.8427008).abs() < 1e-5);
assert!((d[2] + 0.8427008).abs() < 1e-5);
}
// erf on a 2x3 tensor must keep the shape; a few elements are spot-checked
// (index 0 is erf(0), index 2 is erf(1), index 3 is erf(-0.5)).
#[test]
fn erf_2d() {
let input = t(&[0.0, 0.5, 1.0, -0.5, -1.0, 2.0], &[2, 3]);
let result = erf(&input).unwrap();
assert_eq!(result.shape(), &[2, 3]);
let d = result.data().unwrap();
assert!(d[0].abs() < 1e-10); assert!(d[2] > 0.8); assert!(d[3] < 0.0); }
}