use torsh_core::Result as TorshResult;
use torsh_tensor::{creation::full_like, Tensor};
pub use torsh_special::{
acosh,
asinh,
atanh,
bessel_i0_scirs2 as bessel_i0,
bessel_i1_scirs2 as bessel_i1,
bessel_j0_scirs2 as bessel_j0,
bessel_j1_scirs2 as bessel_j1,
bessel_jn_scirs2 as bessel_jn,
bessel_k0_scirs2 as bessel_k0,
bessel_k1_scirs2 as bessel_k1,
bessel_y0_scirs2 as bessel_y0,
bessel_y1_scirs2 as bessel_y1,
bessel_yn_scirs2 as bessel_yn,
beta,
digamma,
erf,
erfc,
erfcx,
erfinv,
expm1,
fresnel,
fresnel_c,
fresnel_s,
gamma,
lgamma,
log1p,
polygamma,
sinc,
};
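/// Spherical Bessel function of the first kind, order 0: j0(x) = sin(x)/x,
/// with the removable singularity patched so that j0(0) = 1.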
pub fn spherical_j0(input: &Tensor) -> TorshResult<Tensor> {
let sin_x = input.sin()?;
let result = sin_x.div(input)?;
let zeros_mask = input.eq_scalar(0.0)?;
let ones = input.ones_like()?;
let final_result = ones.where_tensor(&zeros_mask, &result)?;
Ok(final_result)
}
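/// Spherical Bessel function of the first kind, order 1:
/// j1(x) = sin(x)/x^2 - cos(x)/x, with j1(0) = 0.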
pub fn spherical_j1(input: &Tensor) -> TorshResult<Tensor> {
let sin_x = input.sin()?;
let cos_x = input.cos()?;
let x_squared = input.mul_op(input)?;
let term1 = sin_x.div(&x_squared)?;
let term2 = cos_x.div(input)?;
let result = term1.sub(&term2)?;
let zeros_mask = input.eq_scalar(0.0)?;
let zeros = input.zeros_like()?;
let final_result = zeros.where_tensor(&zeros_mask, &result)?;
Ok(final_result)
}
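/// Spherical Bessel function of the second kind, order 0: y0(x) = -cos(x)/x.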
pub fn spherical_y0(input: &Tensor) -> TorshResult<Tensor> {
let cos_x = input.cos()?;
let result = cos_x.neg()?.div(input)?;
Ok(result)
}
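/// Spherical Bessel function of the second kind, order 1:
/// y1(x) = -cos(x)/x^2 - sin(x)/x.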
pub fn spherical_y1(input: &Tensor) -> TorshResult<Tensor> {
let sin_x = input.sin()?;
let cos_x = input.cos()?;
let x_squared = input.mul_op(input)?;
let term1 = cos_x.neg()?.div(&x_squared)?;
let term2 = sin_x.div(input)?;
let result = term1.sub(&term2)?;
Ok(result)
}
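/// Spherical Bessel function of the first kind of integer order n, built by
/// the upward recurrence f_{k+1}(x) = ((2k + 1)/x) f_k(x) - f_{k-1}(x).
/// Note: upward recurrence can lose accuracy for j_n when n >> |x|.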
pub fn spherical_jn(n: i32, input: &Tensor) -> TorshResult<Tensor> {
match n {
0 => spherical_j0(input),
1 => spherical_j1(input),
_ if n > 1 => {
let mut j_prev = spherical_j0(input)?;
let mut j_curr = spherical_j1(input)?;
for k in 1..n {
let factor_scalar = (2 * k + 1) as f32;
            let factor_tensor = full_like(input, factor_scalar)?.div(input)?;
let j_next = factor_tensor.mul_op(&j_curr)?.sub(&j_prev)?;
j_prev = j_curr;
j_curr = j_next;
}
Ok(j_curr)
}
        _ => {
            // Reflection formula for negative order: j_{-m}(x) = (-1)^m y_{m-1}(x).
            let m = -n;
            let y_term = spherical_yn(m - 1, input)?;
            if m % 2 == 0 {
                Ok(y_term)
            } else {
                Ok(y_term.neg()?)
            }
        }
}
}
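/// Spherical Bessel function of the second kind of integer order n, using the
/// same upward recurrence as `spherical_jn` (numerically stable for y_n).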
pub fn spherical_yn(n: i32, input: &Tensor) -> TorshResult<Tensor> {
match n {
0 => spherical_y0(input),
1 => spherical_y1(input),
_ if n > 1 => {
let mut y_prev = spherical_y0(input)?;
let mut y_curr = spherical_y1(input)?;
for k in 1..n {
let factor_scalar = (2 * k + 1) as f32;
            let factor_tensor = full_like(input, factor_scalar)?.div(input)?;
let y_next = factor_tensor.mul_op(&y_curr)?.sub(&y_prev)?;
y_prev = y_curr;
y_curr = y_next;
}
Ok(y_curr)
}
        _ => {
            // Reflection formula for negative order: y_{-m}(x) = (-1)^(m+1) j_{m-1}(x).
            let m = -n;
            let j_term = spherical_jn(m - 1, input)?;
            if m % 2 == 0 {
                Ok(j_term.neg()?)
            } else {
                Ok(j_term)
            }
        }
}
}
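/// Numerically stable log-sum-exp: shifts by the maximum so that
/// logsumexp(x) = max(x) + ln(sum(exp(x - max(x)))) never overflows.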
pub fn logsumexp(input: &Tensor, dim: Option<i32>, keepdim: bool) -> TorshResult<Tensor> {
let max_vals = if let Some(d) = dim {
input.max_dim(d, true)?
} else {
input.max(None, false)?
};
let shifted = input.sub(&max_vals)?;
let exp_shifted = shifted.exp()?;
let sum_exp = if let Some(d) = dim {
exp_shifted.sum_dim(&[d], keepdim)?
} else {
exp_shifted.sum()?
};
let log_sum = sum_exp.log()?;
if keepdim || dim.is_none() {
max_vals.add_op(&log_sum)
} else {
let max_squeezed = max_vals
.squeeze(dim.expect("dim should be Some in else branch of dim.is_none() check"))?;
max_squeezed.add_op(&log_sum)
}
}
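/// Multivariate log-gamma:
/// ln Gamma_p(a) = (p(p - 1)/4) ln(pi) + sum_{j=0}^{p-1} ln Gamma(a - j/2).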
pub fn multigammaln(input: &Tensor, p: i32) -> TorshResult<Tensor> {
use std::f32::consts::PI;
let p_f = p as f32;
let log_pi_term = (p_f * (p_f - 1.0) / 4.0) * PI.ln();
    let mut result = full_like(input, log_pi_term)?;
for j in 0..p {
let offset = (j as f32) / 2.0;
let adjusted_input = input.sub(&full_like(input, offset)?)?;
let lgamma_term = lgamma(&adjusted_input)?;
result = result.add_op(&lgamma_term)?;
}
Ok(result)
}
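/// Inverse complementary error function via erfcinv(x) = erfinv(1 - x).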
pub fn erfcinv(input: &Tensor) -> TorshResult<Tensor> {
    let one_minus_input = full_like(input, 1.0)?.sub(input)?;
erfinv(&one_minus_input)
}
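/// Standard normal CDF: Phi(x) = (1 + erf(x / sqrt(2))) / 2.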
pub fn normal_cdf(input: &Tensor) -> TorshResult<Tensor> {
let sqrt_two = (2.0f32).sqrt();
let normalized = input.div_scalar(sqrt_two)?;
let erf_result = erf(&normalized)?;
let one_plus_erf = erf_result.add_scalar(1.0)?;
one_plus_erf.div_scalar(2.0)
}
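/// Standard normal inverse CDF (probit): Phi^{-1}(p) = sqrt(2) * erfinv(2p - 1).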
pub fn normal_icdf(input: &Tensor) -> TorshResult<Tensor> {
let sqrt_two = (2.0f32).sqrt();
let two_p_minus_one = input.mul_scalar(2.0)?.sub(&full_like(input, 1.0)?)?;
let erf_inv_result = erfinv(&two_p_minus_one)?;
erf_inv_result.mul_scalar(sqrt_two)
}
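/// Rough approximation to the regularized incomplete beta function I_x(a, b);
/// only the leading term of the ascending series, accurate for small x.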
pub fn betainc(x: &Tensor, a: f32, b: f32) -> TorshResult<Tensor> {
    // Leading-order approximation to the regularized incomplete beta:
    // I_x(a, b) ~ x^a (1 - x)^b / (a * B(a, b)); accurate only for small x.
    let beta_ab = beta(&full_like(x, a)?, &full_like(x, b)?)?;
    let numerator = x.pow_scalar(a)?.mul_op(&full_like(x, 1.0)?.sub(x)?.pow_scalar(b)?)?;
    numerator.div(&beta_ab)?.div_scalar(a)
}
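/// Modified Bessel function of the first kind of real order v. Orders 0 and 1
/// delegate to the exact kernels; other orders fall back to a small-|x|
/// leading-term approximation.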
pub fn bessel_iv(v: f32, x: &Tensor) -> TorshResult<Tensor> {
    if v == 0.0 {
        bessel_i0(x)
    } else if v == 1.0 {
        bessel_i1(x)
    } else {
        // Leading term of the ascending series, I_v(x) ~ (x/2)^v / Gamma(v + 1);
        // a small-|x| approximation only.
        let v_tensor = full_like(x, v)?;
        let gamma_v_plus_1 = gamma(&v_tensor.add_scalar(1.0)?)?;
        let x_over_2_pow_v = x.div_scalar(2.0)?.pow_tensor(&v_tensor)?;
        x_over_2_pow_v.div(&gamma_v_plus_1)
    }
}
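/// Confluent hypergeometric function 1F1(a; b; x) via a 20-term truncation of
/// the Kummer series sum_n (a)_n / (b)_n * x^n / n!; intended for moderate |x|.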
pub fn hypergeometric_1f1(a: f32, b: f32, x: &Tensor) -> TorshResult<Tensor> {
    let mut result = x.ones_like()?;
    let mut term = x.ones_like()?; // holds (a)_n x^n / ((b)_n n!)
for n in 1..20 {
let n_f = n as f32;
let a_rising = a + n_f - 1.0;
let b_rising = b + n_f - 1.0;
let coeff = (a_rising / b_rising) / n_f;
term = term.mul_op(x)?.mul_scalar(coeff)?;
result = result.add_op(&term)?;
if coeff.abs() < 1e-10 {
break;
}
}
Ok(result)
}
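/// Exponential integral Ei(x) for x > 0 via its ascending series.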
pub fn expint(x: &Tensor) -> TorshResult<Tensor> {
    // Ascending series Ei(x) = gamma + ln(x) + sum_{n>=1} x^n / (n * n!),
    // valid for x > 0.
    let euler_gamma = 0.5772156649015329f32; // Euler-Mascheroni constant
    let mut result = x.log()?.add_scalar(euler_gamma)?;
    let mut term = x.ones_like()?; // holds x^n / n!
    for n in 1..50 {
        let n_f = n as f32;
        term = term.mul_op(x)?.div_scalar(n_f)?; // x^n / n!
        let series_term = term.div_scalar(n_f)?; // x^n / (n * n!)
        result = result.add_op(&series_term)?;
    }
    Ok(result)
}
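/// Approximate Voigt profile with Gaussian width `sigma` and Lorentzian width
/// `gamma`; see the body comment for the simplification used.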
pub fn voigt_profile(x: &Tensor, sigma: f32, gamma: f32) -> TorshResult<Tensor> {
    // Pseudo-Voigt approximation: a normalized Gaussian damped by a constant
    // Lorentzian factor. A faithful Voigt profile would be Re[w(z)] / (sigma * sqrt(2 pi))
    // with w the Faddeeva function; this cheap surrogate avoids complex arithmetic.
    let sigma_sqrt_2 = sigma * (2.0f32).sqrt();
    let z_imag = gamma / sigma_sqrt_2;
    let exp_term = x
        .pow_scalar(2.0)?
        .div_scalar(2.0 * sigma * sigma)?
        .neg()?
        .exp()?;
    let gaussian_part =
        exp_term.mul_scalar(1.0 / (sigma * (2.0 * std::f32::consts::PI).sqrt()))?;
    let lorentzian_factor = 1.0 / (1.0 + z_imag * z_imag);
    gaussian_part.mul_scalar(lorentzian_factor)
}
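/// Airy function Ai(x) via its Maclaurin series; accurate for moderate |x|.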
pub fn airy_ai(x: &Tensor) -> TorshResult<Tensor> {
    // Maclaurin series Ai(x) = c1 * f(x) - c2 * g(x), where f and g are the
    // two standard auxiliary series; 20 terms suffice for moderate |x|.
    let c1 = 1.0 / (3.0f32.powf(2.0 / 3.0) * 1.354117939426400); // 1 / (3^(2/3) Gamma(2/3))
    let c2 = 1.0 / (3.0f32.powf(1.0 / 3.0) * 2.678938534707747); // 1 / (3^(1/3) Gamma(1/3))
    let x_cubed = x.pow_scalar(3.0)?;
    let mut f_term = x.ones_like()?;
    let mut f_series = f_term.clone();
    let mut g_term = x.clone();
    let mut g_series = g_term.clone();
    for k in 1..20 {
        let k_f = k as f32;
        // Term ratios: f multiplies by x^3/((3k)(3k-1)), g by x^3/((3k)(3k+1)).
        f_term = f_term
            .mul_op(&x_cubed)?
            .mul_scalar(1.0 / (3.0 * k_f * (3.0 * k_f - 1.0)))?;
        f_series = f_series.add_op(&f_term)?;
        g_term = g_term
            .mul_op(&x_cubed)?
            .mul_scalar(1.0 / (3.0 * k_f * (3.0 * k_f + 1.0)))?;
        g_series = g_series.add_op(&g_term)?;
    }
    let g_part = g_series.mul_scalar(c2)?;
    f_series.mul_scalar(c1)?.sub(&g_part)
}
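/// Kelvin function ber(x) via its ascending power series; accurate for
/// moderate |x|.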
pub fn kelvin_ber(x: &Tensor) -> TorshResult<Tensor> {
    // ber(x) = sum_{k>=0} (-1)^k (x/2)^(4k) / ((2k)!)^2.
    let mut result = x.ones_like()?;
    let x_pow_4 = x.pow_scalar(4.0)?;
    let mut term = x.ones_like()?; // holds x^(4k)
    for k in 1..15 {
        let k_f = k as f32;
        let factorial_2k = (1..=(2 * k)).map(|i| i as f32).product::<f32>();
        // The 16^k accounts for the halving in (x/2)^(4k).
        let coeff = (-1.0f32).powf(k_f) / (16.0f32.powf(k_f) * factorial_2k * factorial_2k);
        term = term.mul_op(&x_pow_4)?;
        let series_term = term.mul_scalar(coeff)?;
        result = result.add_op(&series_term)?;
    }
    Ok(result)
}
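/// Dawson function F(x) = exp(-x^2) * integral_0^x exp(t^2) dt, via a
/// small-|x| Maclaurin series.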
pub fn dawson(x: &Tensor) -> TorshResult<Tensor> {
    // Maclaurin series F(x) = sum_{n>=0} (-2)^n x^(2n+1) / (2n+1)!!, accurate
    // for small |x|; a large-|x| branch (F(x) ~ 1/(2x)) would be needed for
    // full-range accuracy.
    let x_squared = x.pow_scalar(2.0)?;
    let mut result = x.clone();
    let mut term = x.clone(); // holds the n-th series term
    for n in 1..10 {
        let n_f = n as f32;
        // Successive terms satisfy term_n = term_{n-1} * (-2 x^2) / (2n + 1).
        term = term.mul_op(&x_squared)?.mul_scalar(-2.0 / (2.0 * n_f + 1.0))?;
        result = result.add_op(&term)?;
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use torsh_core::DeviceType;
#[test]
fn test_spherical_bessel_j0() {
let input = Tensor::from_data(
vec![0.0f32, 1.0, std::f32::consts::PI],
vec![3],
DeviceType::Cpu,
)
.unwrap();
let result = spherical_j0(&input).unwrap();
let data = result.data().expect("tensor should have data");
assert_relative_eq!(data[0], 1.0, epsilon = 1e-6);
assert!(data[2].abs() < 1e-6);
}
#[test]
fn test_spherical_bessel_j1() {
let input = Tensor::from_data(vec![0.0f32, 1.0], vec![2], DeviceType::Cpu).unwrap();
let result = spherical_j1(&input).unwrap();
let data = result.data().expect("tensor should have data");
assert_relative_eq!(data[0], 0.0, epsilon = 1e-6);
assert_relative_eq!(data[1], 0.30116866, epsilon = 1e-6);
}
#[test]
fn test_logsumexp() {
let input = Tensor::from_data(vec![1.0f32, 2.0, 3.0], vec![3], DeviceType::Cpu).unwrap();
let result = logsumexp(&input, None, false).unwrap();
let expected = 3.0 + ((-2.0f32).exp() + (-1.0f32).exp() + 1.0).ln();
let data = result.data().expect("tensor should have data");
assert_relative_eq!(data[0], expected, epsilon = 1e-6);
}
#[test]
fn test_normal_cdf() {
let input = Tensor::from_data(vec![0.0f32, 1.0, -1.0], vec![3], DeviceType::Cpu).unwrap();
let result = normal_cdf(&input).unwrap();
let data = result.data().expect("tensor should have data");
assert_relative_eq!(data[0], 0.5, epsilon = 1e-6);
assert_relative_eq!(data[1], 0.8413447, epsilon = 1e-6);
assert_relative_eq!(data[2], 0.15865526, epsilon = 1e-6);
}
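    #[test]
    fn test_normal_icdf_median() {
        // Sanity check: the probit of 0.5 is exactly 0 (the median of the
        // standard normal), since erfinv(0) = 0.
        let input = Tensor::from_data(vec![0.5f32], vec![1], DeviceType::Cpu).unwrap();
        let result = normal_icdf(&input).unwrap();
        let data = result.data().expect("tensor should have data");
        assert_relative_eq!(data[0], 0.0, epsilon = 1e-5);
    }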
}