use ferray_core::Array;
use ferray_core::dimension::{Dimension, Ix1, IxDyn};
use ferray_core::dtype::Element;
use ferray_core::error::FerrayResult;
use crate::helpers::binary_elementwise_op;
use crate::ufunc_methods::{accumulate_axis, at, outer, reduce_axis};
/// A binary universal function: a named element-wise kernel `op(a, b)` paired
/// with the value it uses as its identity.
///
/// The struct definition itself carries no trait bounds; the bounds needed to
/// actually use a `Ufunc` (`T: Element + Copy`, `F: Fn(T, T) -> T`) are placed
/// on its `impl` block instead. `Clone`/`Copy` are derived, so a `Ufunc` is
/// `Copy` whenever both `T` and `F` are (e.g. `fn(T, T) -> T` pointers).
#[derive(Clone, Copy)]
pub struct Ufunc<T, F> {
    /// Display name, returned by `Ufunc::name`.
    name: &'static str,
    /// Identity value, returned by `Ufunc::identity` and handed to the
    /// reduction helper by `Ufunc::reduce`.
    identity: T,
    /// The binary element-wise kernel.
    op: F,
}
/// Operations available on any `Ufunc` whose kernel is a `Fn(T, T) -> T`.
///
/// Every method is a thin adapter: it forwards the stored kernel (and, for
/// `reduce`, the stored identity) to the matching helper from
/// `crate::helpers` / `crate::ufunc_methods` and returns that helper's result
/// unchanged.
impl<T, F> Ufunc<T, F>
where
    T: Element + Copy,
    F: Fn(T, T) -> T,
{
    /// Creates a ufunc from its display `name`, its `identity` value and the
    /// binary kernel `op`.
    #[inline]
    pub const fn new(name: &'static str, identity: T, op: F) -> Self {
        Self {
            op,
            identity,
            name,
        }
    }

    /// Returns the name this ufunc was constructed with.
    #[inline]
    pub const fn name(&self) -> &'static str {
        self.name
    }

    /// Returns a copy of the identity value this ufunc was constructed with.
    #[inline]
    pub const fn identity(&self) -> T {
        self.identity
    }

    /// Applies the kernel element-wise to `a` and `b`.
    ///
    /// # Errors
    /// Propagates any error returned by `binary_elementwise_op`.
    pub fn call<D: Dimension>(
        &self,
        a: &Array<T, D>,
        b: &Array<T, D>,
    ) -> FerrayResult<Array<T, D>> {
        let kernel = &self.op;
        binary_elementwise_op(a, b, kernel)
    }

    /// Reduces `a` along `axis` with the kernel, passing the stored identity
    /// to the reduction helper. The result is dynamically dimensioned.
    ///
    /// NOTE(review): whether `reduce_axis` uses the identity as the fold seed
    /// or only for empty axes is not visible here. For kernels whose identity
    /// is only a right identity (e.g. subtraction with `0`), confirm the result.
    ///
    /// # Errors
    /// Propagates any error returned by `reduce_axis`.
    pub fn reduce<D: Dimension>(
        &self,
        a: &Array<T, D>,
        axis: usize,
    ) -> FerrayResult<Array<T, IxDyn>> {
        let Self { identity, op, .. } = self;
        reduce_axis(a, axis, *identity, op)
    }

    /// Running application of the kernel along `axis`. For `add` this yields
    /// cumulative sums. The result keeps the input's dimension type `D`.
    ///
    /// # Errors
    /// Propagates any error returned by `accumulate_axis`.
    pub fn accumulate<D: Dimension>(
        &self,
        a: &Array<T, D>,
        axis: usize,
    ) -> FerrayResult<Array<T, D>> {
        let kernel = &self.op;
        accumulate_axis(a, axis, kernel)
    }

    /// Builds the table of `op(a[i], b[j])` over every pair of elements of
    /// the two 1-D inputs.
    ///
    /// # Errors
    /// Propagates any error returned by the `outer` helper.
    pub fn outer(&self, a: &Array<T, Ix1>, b: &Array<T, Ix1>) -> FerrayResult<Array<T, IxDyn>> {
        let kernel = &self.op;
        outer(a, b, kernel)
    }

    /// Applies the kernel in place to `arr` at `indices`, combining each
    /// selected element with the corresponding entry of `values`.
    ///
    /// # Errors
    /// Propagates any error returned by the `at` helper.
    pub fn at(
        &self,
        arr: &mut Array<T, Ix1>,
        indices: &[usize],
        values: &[T],
    ) -> FerrayResult<()> {
        let kernel = &self.op;
        at(arr, indices, values, kernel)
    }
}
/// Returns the element-wise addition ufunc `a + b`, named `"add"`, whose
/// identity is `Element::zero()`.
#[must_use]
pub fn add_ufunc<T>() -> Ufunc<T, fn(T, T) -> T>
where
    T: Element + Copy + std::ops::Add<Output = T>,
{
    // The trait method itself serves as the kernel, coerced to a fn pointer.
    let kernel: fn(T, T) -> T = <T as std::ops::Add>::add;
    let zero = <T as Element>::zero();
    Ufunc::new("add", zero, kernel)
}
/// Returns the element-wise subtraction ufunc `a - b`, named `"subtract"`,
/// whose identity is `Element::zero()`.
///
/// NOTE(review): `0` is only a right identity for subtraction (`a - 0 == a`,
/// but `0 - a == -a`). If `reduce_axis` seeds its fold with the identity,
/// `reduce` will not match a plain left-to-right difference; confirm.
#[must_use]
pub fn subtract_ufunc<T>() -> Ufunc<T, fn(T, T) -> T>
where
    T: Element + Copy + std::ops::Sub<Output = T>,
{
    // A non-capturing closure coerces to a plain fn pointer.
    let kernel: fn(T, T) -> T = |lhs, rhs| lhs - rhs;
    let zero = <T as Element>::zero();
    Ufunc::new("subtract", zero, kernel)
}
/// Returns the element-wise multiplication ufunc `a * b`, named `"multiply"`,
/// whose identity is `Element::one()`.
#[must_use]
pub fn multiply_ufunc<T>() -> Ufunc<T, fn(T, T) -> T>
where
    T: Element + Copy + std::ops::Mul<Output = T>,
{
    // The trait method itself serves as the kernel, coerced to a fn pointer.
    let kernel: fn(T, T) -> T = <T as std::ops::Mul>::mul;
    let one = <T as Element>::one();
    Ufunc::new("multiply", one, kernel)
}
/// Returns the element-wise division ufunc `a / b`, named `"divide"`, whose
/// identity is `Element::one()`.
///
/// NOTE(review): `1` is only a right identity for division (`a / 1 == a`,
/// but `1 / a != a` in general). If `reduce_axis` seeds its fold with the
/// identity, `reduce` will not match a plain left-to-right quotient; confirm.
#[must_use]
pub fn divide_ufunc<T>() -> Ufunc<T, fn(T, T) -> T>
where
    T: Element + Copy + std::ops::Div<Output = T>,
{
    // A non-capturing closure coerces to a plain fn pointer.
    let kernel: fn(T, T) -> T = |numerator, denominator| numerator / denominator;
    let one = <T as Element>::one();
    Ufunc::new("divide", one, kernel)
}
#[cfg(test)]
mod tests {
    use super::*;
    // Used by `add_ufunc_reduce_2d_row_sums`; a single module-level import
    // avoids both shadowing and the need for a dummy use to silence warnings.
    use ferray_core::dimension::Ix2;
    use crate::test_util::arr1;

    #[test]
    fn add_ufunc_roundtrip() {
        let add = add_ufunc::<f64>();
        assert_eq!(add.name(), "add");
        assert_eq!(add.identity(), 0.0);
        let a = arr1([1.0, 2.0, 3.0]);
        let b = arr1([10.0, 20.0, 30.0]);
        let c = add.call(&a, &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[11.0, 22.0, 33.0]);
    }

    #[test]
    fn add_ufunc_reduce() {
        let add = add_ufunc::<f64>();
        let a = arr1([1.0, 2.0, 3.0, 4.0]);
        let r = add.reduce(&a, 0).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[10.0]);
    }

    #[test]
    fn add_ufunc_accumulate() {
        let add = add_ufunc::<f64>();
        let a = arr1([1.0, 2.0, 3.0, 4.0]);
        let r = add.accumulate(&a, 0).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[1.0, 3.0, 6.0, 10.0]);
    }

    #[test]
    fn multiply_ufunc_reduce_is_product() {
        let mul = multiply_ufunc::<f64>();
        assert_eq!(mul.identity(), 1.0);
        let a = arr1([1.0, 2.0, 3.0, 4.0]);
        let r = mul.reduce(&a, 0).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[24.0]);
    }

    #[test]
    fn multiply_ufunc_outer() {
        let mul = multiply_ufunc::<f64>();
        let a = arr1([1.0, 2.0, 3.0]);
        let b = arr1([10.0, 20.0]);
        let r = mul.outer(&a, &b).unwrap();
        assert_eq!(r.shape(), &[3, 2]);
        assert_eq!(r.as_slice().unwrap(), &[10.0, 20.0, 20.0, 40.0, 30.0, 60.0]);
    }

    #[test]
    fn subtract_ufunc_accumulate_is_running_diff() {
        let sub = subtract_ufunc::<f64>();
        let a = arr1([10.0, 3.0, 2.0, 1.0]);
        let r = sub.accumulate(&a, 0).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[10.0, 7.0, 5.0, 4.0]);
    }

    #[test]
    fn user_defined_max_ufunc() {
        let max = Ufunc::new("max", f64::NEG_INFINITY, f64::max);
        let a = arr1([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]);
        let m = max.reduce(&a, 0).unwrap();
        assert_eq!(m.as_slice().unwrap(), &[9.0]);
    }

    #[test]
    fn add_ufunc_at_unbuffered_duplicates() {
        let add = add_ufunc::<f64>();
        let mut a = arr1([0.0, 0.0, 0.0]);
        add.at(&mut a, &[0, 0, 1, 2], &[1.0, 2.0, 5.0, 10.0]).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[3.0, 5.0, 10.0]);
    }

    #[test]
    fn add_ufunc_reduce_2d_row_sums() {
        let add = add_ufunc::<f64>();
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let r = add.reduce(&a, 1).unwrap();
        assert_eq!(r.shape(), &[2]);
        assert_eq!(r.as_slice().unwrap(), &[6.0, 15.0]);
    }

    #[test]
    fn divide_ufunc_elementwise() {
        let div = divide_ufunc::<f64>();
        assert_eq!(div.identity(), 1.0);
        let a = arr1([10.0, 20.0, 30.0]);
        let b = arr1([2.0, 4.0, 5.0]);
        let c = div.call(&a, &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[5.0, 5.0, 6.0]);
    }

    #[test]
    fn ufunc_is_copy() {
        let add = add_ufunc::<f64>();
        let add_copy = add;
        // `add` is still usable after the move only because `Ufunc` is `Copy`.
        assert_eq!(add.name(), add_copy.name());
    }

    #[test]
    fn builtin_ufunc_names_and_identities() {
        let sub = subtract_ufunc::<f64>();
        assert_eq!(sub.name(), "subtract");
        assert_eq!(sub.identity(), 0.0);
        let mul = multiply_ufunc::<f64>();
        assert_eq!(mul.name(), "multiply");
        let div = divide_ufunc::<f64>();
        assert_eq!(div.name(), "divide");
    }
}