use crate::tensor::Tensor;
use num_traits::Float;
use std::fmt::Debug;
pub mod bessel;
pub mod error;
pub mod gamma;
pub mod utils;
pub use bessel::{bessel_i, bessel_j, bessel_k, bessel_y};
pub use error::{erf, erfc, erfcinv, erfinv};
pub use gamma::{beta, digamma, gamma, lbeta, lgamma};
/// Element-wise special mathematical functions on tensors.
///
/// Every method returns a new tensor and leaves `self` untouched. For the
/// `Tensor<T>` implementation in this module, the result has the same shape
/// as `self`. Each element comes from the matching scalar routine in the
/// `gamma`, `error` or `bessel` submodule.
///
/// # Errors
///
/// Every method returns a `RusTorchResult`. Implementations may report an
/// error when the underlying scalar computation fails for some element.
/// Which inputs fail is decided by the scalar routines, not by this trait.
pub trait SpecialFunctions<T: Float> {
    // ---- Gamma family -------------------------------------------------

    /// Gamma function, applied element-wise.
    fn gamma(&self) -> crate::error::RusTorchResult<Tensor<T>>;

    /// Log-gamma function, applied element-wise.
    fn lgamma(&self) -> crate::error::RusTorchResult<Tensor<T>>;

    /// Digamma function, applied element-wise.
    fn digamma(&self) -> crate::error::RusTorchResult<Tensor<T>>;

    // ---- Error-function family ----------------------------------------

    /// Error function, applied element-wise.
    fn erf(&self) -> crate::error::RusTorchResult<Tensor<T>>;

    /// Complementary error function, applied element-wise.
    fn erfc(&self) -> crate::error::RusTorchResult<Tensor<T>>;

    /// Inverse error function, applied element-wise.
    fn erfinv(&self) -> crate::error::RusTorchResult<Tensor<T>>;

    // ---- Bessel family (order `n` shared by every element) ------------

    /// Bessel function `J` of order `n`, applied element-wise.
    fn bessel_j(&self, n: T) -> crate::error::RusTorchResult<Tensor<T>>;

    /// Bessel function `Y` of order `n`, applied element-wise.
    fn bessel_y(&self, n: T) -> crate::error::RusTorchResult<Tensor<T>>;

    /// Modified Bessel function `I` of order `n`, applied element-wise.
    fn bessel_i(&self, n: T) -> crate::error::RusTorchResult<Tensor<T>>;

    /// Modified Bessel function `K` of order `n`, applied element-wise.
    fn bessel_k(&self, n: T) -> crate::error::RusTorchResult<Tensor<T>>;
}
impl<T> SpecialFunctions<T> for Tensor<T>
where
T: Float + Debug + 'static,
{
fn gamma(&self) -> crate::error::RusTorchResult<Tensor<T>> {
let mut result = vec![T::zero(); self.data.len()];
for (i, &x) in self.data.iter().enumerate() {
result[i] = gamma::gamma_scalar(x)?;
}
Ok(Tensor::from_vec(result, self.shape().to_vec()))
}
fn lgamma(&self) -> crate::error::RusTorchResult<Tensor<T>> {
let mut result = vec![T::zero(); self.data.len()];
for (i, &x) in self.data.iter().enumerate() {
result[i] = gamma::lgamma_scalar(x)?;
}
Ok(Tensor::from_vec(result, self.shape().to_vec()))
}
fn digamma(&self) -> crate::error::RusTorchResult<Tensor<T>> {
let mut result = vec![T::zero(); self.data.len()];
for (i, &x) in self.data.iter().enumerate() {
result[i] = gamma::digamma_scalar(x)?;
}
Ok(Tensor::from_vec(result, self.shape().to_vec()))
}
fn erf(&self) -> crate::error::RusTorchResult<Tensor<T>> {
let mut result = vec![T::zero(); self.data.len()];
for (i, &x) in self.data.iter().enumerate() {
result[i] = error::erf_scalar(x);
}
Ok(Tensor::from_vec(result, self.shape().to_vec()))
}
fn erfc(&self) -> crate::error::RusTorchResult<Tensor<T>> {
let mut result = vec![T::zero(); self.data.len()];
for (i, &x) in self.data.iter().enumerate() {
result[i] = error::erfc_scalar(x);
}
Ok(Tensor::from_vec(result, self.shape().to_vec()))
}
fn erfinv(&self) -> crate::error::RusTorchResult<Tensor<T>> {
let mut result = vec![T::zero(); self.data.len()];
for (i, &x) in self.data.iter().enumerate() {
result[i] = error::erfinv_scalar(x)?;
}
Ok(Tensor::from_vec(result, self.shape().to_vec()))
}
fn bessel_j(&self, n: T) -> crate::error::RusTorchResult<Tensor<T>> {
let mut result = vec![T::zero(); self.data.len()];
for (i, &x) in self.data.iter().enumerate() {
result[i] = bessel::bessel_j_scalar(n, x)?;
}
Ok(Tensor::from_vec(result, self.shape().to_vec()))
}
fn bessel_y(&self, n: T) -> crate::error::RusTorchResult<Tensor<T>> {
let mut result = vec![T::zero(); self.data.len()];
for (i, &x) in self.data.iter().enumerate() {
result[i] = bessel::bessel_y_scalar(n, x)?;
}
Ok(Tensor::from_vec(result, self.shape().to_vec()))
}
fn bessel_i(&self, n: T) -> crate::error::RusTorchResult<Tensor<T>> {
let mut result = vec![T::zero(); self.data.len()];
for (i, &x) in self.data.iter().enumerate() {
result[i] = bessel::bessel_i_scalar(n, x)?;
}
Ok(Tensor::from_vec(result, self.shape().to_vec()))
}
fn bessel_k(&self, n: T) -> crate::error::RusTorchResult<Tensor<T>> {
let mut result = vec![T::zero(); self.data.len()];
for (i, &x) in self.data.iter().enumerate() {
result[i] = bessel::bessel_k_scalar(n, x)?;
}
Ok(Tensor::from_vec(result, self.shape().to_vec()))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is a 1-D tensor whose elements match `expected` within `tol`.
    fn assert_all_close(actual: &Tensor<f64>, expected: &[f64], tol: f64) {
        assert_eq!(actual.shape().to_vec(), vec![expected.len()]);
        let values: Vec<f64> = actual.data.iter().copied().collect();
        assert_eq!(values.len(), expected.len());
        for (got, want) in values.iter().zip(expected) {
            assert!(
                (got - want).abs() < tol,
                "got {}, expected {} (tol {})",
                got,
                want,
                tol
            );
        }
    }

    #[test]
    fn test_special_functions_module() {
        let x = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]);
        let gamma_result = x.gamma();
        assert!(gamma_result.is_ok());
        let erf_result = x.erf();
        assert!(erf_result.is_ok());
    }

    #[test]
    fn test_gamma_and_erf_values() {
        let x = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]);

        // Gamma(n) = (n - 1)! for positive integers n.
        let gamma_result = x
            .gamma()
            .expect("gamma should succeed for positive integers");
        assert_all_close(&gamma_result, &[1.0, 1.0, 2.0], 1e-6);

        // Reference values of erf(1), erf(2), erf(3). The tolerance is kept loose
        // enough to accept common rational approximations of erf.
        let erf_result = x.erf().expect("erf should succeed for finite inputs");
        assert_all_close(
            &erf_result,
            &[0.842_700_792_949_715, 0.995_322_265_018_953, 0.999_977_909_503_001],
            1e-4,
        );

        // erfc(x) = 1 - erf(x).
        let erfc_result = x.erfc().expect("erfc should succeed for finite inputs");
        assert_all_close(
            &erfc_result,
            &[0.157_299_207_050_285, 0.004_677_734_981_047, 0.000_022_090_496_999],
            1e-4,
        );
    }
}