use ::ndarray::{Array, ArrayView, Dimension, IxDyn, ShapeBuilder};
use crate::ufuncs::core::{UFunc, UFuncKind, apply_unary, register_ufunc};
use std::sync::Once;
/// Guard used by `init_math_ufuncs` so the registration body runs only once.
static INIT: Once = Once::new();

/// Registers every math ufunc defined in this file via `register_ufunc`.
///
/// May be called any number of times; the registrations are performed only
/// on the first call because they are routed through `INIT.call_once`.
#[allow(dead_code)]
fn init_math_ufuncs() {
    /// Performs the registrations in a fixed order. Whatever each
    /// `register_ufunc` call returns is deliberately discarded.
    fn register_all() {
        let _ = register_ufunc(Box::new(SinUFunc));
        let _ = register_ufunc(Box::new(CosUFunc));
        let _ = register_ufunc(Box::new(TanUFunc));
        let _ = register_ufunc(Box::new(ExpUFunc));
        let _ = register_ufunc(Box::new(LogUFunc));
        let _ = register_ufunc(Box::new(SqrtUFunc));
        let _ = register_ufunc(Box::new(AbsUFunc));
    }

    INIT.call_once(register_all);
}
/// Unary ufunc that computes the sine of `f64` values.
pub struct SinUFunc;

impl UFunc for SinUFunc {
    /// Returns the name of this ufunc, `"sin"`.
    fn name(&self) -> &str {
        "sin"
    }

    /// Reports this ufunc as `UFuncKind::Unary`.
    fn kind(&self) -> UFuncKind {
        UFuncKind::Unary
    }

    /// Runs the sine operation on the single input array, writing into `output`.
    ///
    /// The work is delegated to `apply_unary`, which receives a closure mapping
    /// an `&f64` to its sine (presumably applied element-wise; confirm against
    /// the definition of `apply_unary`).
    ///
    /// # Errors
    ///
    /// Returns an error when `inputs` does not contain exactly one array;
    /// otherwise returns whatever `apply_unary` returns.
    fn apply<D>(
        &self,
        inputs: &[&crate::ndarray::ArrayBase<crate::ndarray::Data, D>],
        output: &mut crate::ndarray::ArrayBase<crate::ndarray::Data, D>,
    ) -> Result<(), &'static str>
    where
        D: Dimension,
    {
        match inputs {
            [input] => apply_unary(*input, output, |value: &f64| value.sin()),
            _ => Err("Sin requires exactly one input array"),
        }
    }
}
/// Unary ufunc that computes the cosine of `f64` values.
pub struct CosUFunc;

impl UFunc for CosUFunc {
    /// Returns the name of this ufunc, `"cos"`.
    fn name(&self) -> &str {
        "cos"
    }

    /// Reports this ufunc as `UFuncKind::Unary`.
    fn kind(&self) -> UFuncKind {
        UFuncKind::Unary
    }

    /// Runs the cosine operation on the single input array, writing into `output`.
    ///
    /// The work is delegated to `apply_unary`, which receives a closure mapping
    /// an `&f64` to its cosine (presumably applied element-wise; confirm against
    /// the definition of `apply_unary`).
    ///
    /// # Errors
    ///
    /// Returns an error when `inputs` does not contain exactly one array;
    /// otherwise returns whatever `apply_unary` returns.
    fn apply<D>(
        &self,
        inputs: &[&crate::ndarray::ArrayBase<crate::ndarray::Data, D>],
        output: &mut crate::ndarray::ArrayBase<crate::ndarray::Data, D>,
    ) -> Result<(), &'static str>
    where
        D: Dimension,
    {
        match inputs {
            [input] => apply_unary(*input, output, |value: &f64| value.cos()),
            _ => Err("Cos requires exactly one input array"),
        }
    }
}
/// Unary ufunc that computes the tangent of `f64` values.
pub struct TanUFunc;

impl UFunc for TanUFunc {
    /// Returns the name of this ufunc, `"tan"`.
    fn name(&self) -> &str {
        "tan"
    }

    /// Reports this ufunc as `UFuncKind::Unary`.
    fn kind(&self) -> UFuncKind {
        UFuncKind::Unary
    }

    /// Runs the tangent operation on the single input array, writing into `output`.
    ///
    /// The work is delegated to `apply_unary`, which receives a closure mapping
    /// an `&f64` to its tangent (presumably applied element-wise; confirm against
    /// the definition of `apply_unary`).
    ///
    /// # Errors
    ///
    /// Returns an error when `inputs` does not contain exactly one array;
    /// otherwise returns whatever `apply_unary` returns.
    fn apply<D>(
        &self,
        inputs: &[&crate::ndarray::ArrayBase<crate::ndarray::Data, D>],
        output: &mut crate::ndarray::ArrayBase<crate::ndarray::Data, D>,
    ) -> Result<(), &'static str>
    where
        D: Dimension,
    {
        match inputs {
            [input] => apply_unary(*input, output, |value: &f64| value.tan()),
            _ => Err("Tan requires exactly one input array"),
        }
    }
}
/// Unary ufunc that computes the exponential (`e^x`) of `f64` values.
pub struct ExpUFunc;

impl UFunc for ExpUFunc {
    /// Returns the name of this ufunc, `"exp"`.
    fn name(&self) -> &str {
        "exp"
    }

    /// Reports this ufunc as `UFuncKind::Unary`.
    fn kind(&self) -> UFuncKind {
        UFuncKind::Unary
    }

    /// Runs the exponential operation on the single input array, writing into `output`.
    ///
    /// The work is delegated to `apply_unary`, which receives a closure mapping
    /// an `&f64` to `f64::exp` of it (presumably applied element-wise; confirm
    /// against the definition of `apply_unary`).
    ///
    /// # Errors
    ///
    /// Returns an error when `inputs` does not contain exactly one array;
    /// otherwise returns whatever `apply_unary` returns.
    fn apply<D>(
        &self,
        inputs: &[&crate::ndarray::ArrayBase<crate::ndarray::Data, D>],
        output: &mut crate::ndarray::ArrayBase<crate::ndarray::Data, D>,
    ) -> Result<(), &'static str>
    where
        D: Dimension,
    {
        match inputs {
            [input] => apply_unary(*input, output, |value: &f64| value.exp()),
            _ => Err("Exp requires exactly one input array"),
        }
    }
}
/// Unary ufunc that computes the natural logarithm (`f64::ln`) of `f64` values.
pub struct LogUFunc;

impl UFunc for LogUFunc {
    /// Returns the name of this ufunc, `"log"`.
    fn name(&self) -> &str {
        "log"
    }

    /// Reports this ufunc as `UFuncKind::Unary`.
    fn kind(&self) -> UFuncKind {
        UFuncKind::Unary
    }

    /// Runs the natural-logarithm operation on the single input array, writing
    /// into `output`.
    ///
    /// The work is delegated to `apply_unary`, which receives a closure mapping
    /// an `&f64` to `f64::ln` of it (presumably applied element-wise; confirm
    /// against the definition of `apply_unary`).
    ///
    /// # Errors
    ///
    /// Returns an error when `inputs` does not contain exactly one array;
    /// otherwise returns whatever `apply_unary` returns.
    fn apply<D>(
        &self,
        inputs: &[&crate::ndarray::ArrayBase<crate::ndarray::Data, D>],
        output: &mut crate::ndarray::ArrayBase<crate::ndarray::Data, D>,
    ) -> Result<(), &'static str>
    where
        D: Dimension,
    {
        match inputs {
            [input] => apply_unary(*input, output, |value: &f64| value.ln()),
            _ => Err("Log requires exactly one input array"),
        }
    }
}
/// Unary ufunc that computes the square root of `f64` values.
pub struct SqrtUFunc;

impl UFunc for SqrtUFunc {
    /// Returns the name of this ufunc, `"sqrt"`.
    fn name(&self) -> &str {
        "sqrt"
    }

    /// Reports this ufunc as `UFuncKind::Unary`.
    fn kind(&self) -> UFuncKind {
        UFuncKind::Unary
    }

    /// Runs the square-root operation on the single input array, writing into
    /// `output`.
    ///
    /// The work is delegated to `apply_unary`, which receives a closure mapping
    /// an `&f64` to its square root (presumably applied element-wise; confirm
    /// against the definition of `apply_unary`).
    ///
    /// # Errors
    ///
    /// Returns an error when `inputs` does not contain exactly one array;
    /// otherwise returns whatever `apply_unary` returns.
    fn apply<D>(
        &self,
        inputs: &[&crate::ndarray::ArrayBase<crate::ndarray::Data, D>],
        output: &mut crate::ndarray::ArrayBase<crate::ndarray::Data, D>,
    ) -> Result<(), &'static str>
    where
        D: Dimension,
    {
        match inputs {
            [input] => apply_unary(*input, output, |value: &f64| value.sqrt()),
            _ => Err("Sqrt requires exactly one input array"),
        }
    }
}
/// Unary ufunc that computes the absolute value of `f64` values.
pub struct AbsUFunc;

impl UFunc for AbsUFunc {
    /// Returns the name of this ufunc, `"abs"`.
    fn name(&self) -> &str {
        "abs"
    }

    /// Reports this ufunc as `UFuncKind::Unary`.
    fn kind(&self) -> UFuncKind {
        UFuncKind::Unary
    }

    /// Runs the absolute-value operation on the single input array, writing
    /// into `output`.
    ///
    /// The work is delegated to `apply_unary`, which receives a closure mapping
    /// an `&f64` to its absolute value (presumably applied element-wise; confirm
    /// against the definition of `apply_unary`).
    ///
    /// # Errors
    ///
    /// Returns an error when `inputs` does not contain exactly one array;
    /// otherwise returns whatever `apply_unary` returns.
    fn apply<D>(
        &self,
        inputs: &[&crate::ndarray::ArrayBase<crate::ndarray::Data, D>],
        output: &mut crate::ndarray::ArrayBase<crate::ndarray::Data, D>,
    ) -> Result<(), &'static str>
    where
        D: Dimension,
    {
        match inputs {
            [input] => apply_unary(*input, output, |value: &f64| value.abs()),
            _ => Err("Abs requires exactly one input array"),
        }
    }
}
/// Computes the sine of the values in `array` by running `SinUFunc` and
/// returns the result as a newly allocated array.
///
/// The math ufuncs are registered first (a no-op after the first call). The
/// output is allocated zero-filled with the same dimensions as `array` and is
/// then passed to `SinUFunc::apply` to be filled in.
///
/// # Panics
///
/// Panics with "Operation failed" if `SinUFunc::apply` returns an error.
#[allow(dead_code)]
pub fn sin<D>(array: &crate::ndarray::ArrayBase<crate::ndarray::Data, D>) -> Array<f64, D>
where
    D: Dimension,
{
    init_math_ufuncs();
    let mut result = Array::zeros(array.raw_dim());
    SinUFunc
        .apply(&[array], &mut result)
        .expect("Operation failed");
    result
}
/// Computes the cosine of the values in `array` by running `CosUFunc` and
/// returns the result as a newly allocated array.
///
/// The math ufuncs are registered first (a no-op after the first call). The
/// output is allocated zero-filled with the same dimensions as `array` and is
/// then passed to `CosUFunc::apply` to be filled in.
///
/// # Panics
///
/// Panics with "Operation failed" if `CosUFunc::apply` returns an error.
#[allow(dead_code)]
pub fn cos<D>(array: &crate::ndarray::ArrayBase<crate::ndarray::Data, D>) -> Array<f64, D>
where
    D: Dimension,
{
    init_math_ufuncs();
    let mut result = Array::zeros(array.raw_dim());
    CosUFunc
        .apply(&[array], &mut result)
        .expect("Operation failed");
    result
}
/// Computes the tangent of the values in `array` by running `TanUFunc` and
/// returns the result as a newly allocated array.
///
/// The math ufuncs are registered first (a no-op after the first call). The
/// output is allocated zero-filled with the same dimensions as `array` and is
/// then passed to `TanUFunc::apply` to be filled in.
///
/// # Panics
///
/// Panics with "Operation failed" if `TanUFunc::apply` returns an error.
#[allow(dead_code)]
pub fn tan<D>(array: &crate::ndarray::ArrayBase<crate::ndarray::Data, D>) -> Array<f64, D>
where
    D: Dimension,
{
    init_math_ufuncs();
    let mut result = Array::zeros(array.raw_dim());
    TanUFunc
        .apply(&[array], &mut result)
        .expect("Operation failed");
    result
}
/// Computes the exponential of the values in `array` by running `ExpUFunc`
/// and returns the result as a newly allocated array.
///
/// The math ufuncs are registered first (a no-op after the first call). The
/// output is allocated zero-filled with the same dimensions as `array` and is
/// then passed to `ExpUFunc::apply` to be filled in.
///
/// # Panics
///
/// Panics with "Operation failed" if `ExpUFunc::apply` returns an error.
#[allow(dead_code)]
pub fn exp<D>(array: &crate::ndarray::ArrayBase<crate::ndarray::Data, D>) -> Array<f64, D>
where
    D: Dimension,
{
    init_math_ufuncs();
    let mut result = Array::zeros(array.raw_dim());
    ExpUFunc
        .apply(&[array], &mut result)
        .expect("Operation failed");
    result
}
/// Computes the natural logarithm of the values in `array` by running
/// `LogUFunc` and returns the result as a newly allocated array.
///
/// The math ufuncs are registered first (a no-op after the first call). The
/// output is allocated zero-filled with the same dimensions as `array` and is
/// then passed to `LogUFunc::apply` to be filled in.
///
/// # Panics
///
/// Panics with "Operation failed" if `LogUFunc::apply` returns an error.
#[allow(dead_code)]
pub fn log<D>(array: &crate::ndarray::ArrayBase<crate::ndarray::Data, D>) -> Array<f64, D>
where
    D: Dimension,
{
    init_math_ufuncs();
    let mut result = Array::zeros(array.raw_dim());
    LogUFunc
        .apply(&[array], &mut result)
        .expect("Operation failed");
    result
}
/// Computes the square root of the values in `array` by running `SqrtUFunc`
/// and returns the result as a newly allocated array.
///
/// The math ufuncs are registered first (a no-op after the first call). The
/// output is allocated zero-filled with the same dimensions as `array` and is
/// then passed to `SqrtUFunc::apply` to be filled in.
///
/// # Panics
///
/// Panics with "Operation failed" if `SqrtUFunc::apply` returns an error.
#[allow(dead_code)]
pub fn sqrt<D>(array: &crate::ndarray::ArrayBase<crate::ndarray::Data, D>) -> Array<f64, D>
where
    D: Dimension,
{
    init_math_ufuncs();
    let mut result = Array::zeros(array.raw_dim());
    SqrtUFunc
        .apply(&[array], &mut result)
        .expect("Operation failed");
    result
}
/// Computes the absolute value of the values in `array` by running `AbsUFunc`
/// and returns the result as a newly allocated array.
///
/// The math ufuncs are registered first (a no-op after the first call). The
/// output is allocated zero-filled with the same dimensions as `array` and is
/// then passed to `AbsUFunc::apply` to be filled in.
///
/// # Panics
///
/// Panics with "Operation failed" if `AbsUFunc::apply` returns an error.
#[allow(dead_code)]
pub fn abs<D>(array: &crate::ndarray::ArrayBase<crate::ndarray::Data, D>) -> Array<f64, D>
where
    D: Dimension,
{
    init_math_ufuncs();
    let mut result = Array::zeros(array.raw_dim());
    AbsUFunc
        .apply(&[array], &mut result)
        .expect("Operation failed");
    result
}
/// Unit tests for the array-level wrapper functions.
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::array;
    // `PI` is used by the trigonometric tests and must be imported explicitly.
    use std::f64::consts::PI;

    /// `sin` at 0, PI/2 and PI matches 0, 1 and 0 within a small tolerance.
    #[test]
    fn test_sin() {
        let a = array![0.0, PI / 2.0, PI];
        let result = sin(&a);
        assert!((result[0] - 0.0).abs() < 1e-10);
        assert!((result[1] - 1.0).abs() < 1e-10);
        assert!((result[2] - 0.0).abs() < 1e-10);
    }

    /// `cos` at 0, PI/2 and PI matches 1, 0 and -1 within a small tolerance.
    #[test]
    fn test_cos() {
        let a = array![0.0, PI / 2.0, PI];
        let result = cos(&a);
        assert!((result[0] - 1.0).abs() < 1e-10);
        assert!((result[1] - 0.0).abs() < 1e-10);
        assert!((result[2] + 1.0).abs() < 1e-10);
    }

    /// `sqrt` of perfect squares yields the exact roots.
    #[test]
    fn test_sqrt() {
        let a = array![1.0, 4.0, 9.0];
        let result = sqrt(&a);
        assert_eq!(result, array![1.0, 2.0, 3.0]);
    }

    /// `abs` maps negative, zero and positive inputs to their magnitudes.
    #[test]
    fn test_abs() {
        let a = array![-1.0, 0.0, 1.0];
        let result = abs(&a);
        assert_eq!(result, array![1.0, 0.0, 1.0]);
    }
}