use ferray_core::Array;
use ferray_core::dimension::broadcast::broadcast_shapes;
use ferray_core::dimension::{Dimension, IxDyn};
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
#[inline]
pub fn unary_float_op<T, D>(input: &Array<T, D>, f: impl Fn(T) -> T) -> FerrayResult<Array<T, D>>
where
T: Element + Copy,
D: Dimension,
{
if let Some(slice) = input.as_slice() {
let n = slice.len();
let mut data = Vec::with_capacity(n);
#[allow(clippy::uninit_vec)]
unsafe {
data.set_len(n);
}
for (o, &x) in data.iter_mut().zip(slice.iter()) {
*o = f(x);
}
Array::from_vec(input.dim().clone(), data)
} else {
let data: Vec<T> = input.iter().map(|&x| f(x)).collect();
Array::from_vec(input.dim().clone(), data)
}
}
#[inline]
pub fn unary_slice_op_f64<D>(
input: &Array<f64, D>,
kernel: fn(&[f64], &mut [f64]),
scalar_fallback: fn(f64) -> f64,
) -> FerrayResult<Array<f64, D>>
where
D: Dimension,
{
let n = input.size();
if let Some(slice) = input.as_slice() {
let mut data = Vec::with_capacity(n);
#[allow(clippy::uninit_vec)]
unsafe {
data.set_len(n);
}
kernel(slice, &mut data);
Array::from_vec(input.dim().clone(), data)
} else {
let data: Vec<f64> = input.iter().map(|&x| scalar_fallback(x)).collect();
Array::from_vec(input.dim().clone(), data)
}
}
#[inline]
pub fn unary_slice_op_f32<D>(
input: &Array<f32, D>,
kernel: fn(&[f32], &mut [f32]),
scalar_fallback: fn(f32) -> f32,
) -> FerrayResult<Array<f32, D>>
where
D: Dimension,
{
let n = input.size();
if let Some(slice) = input.as_slice() {
let mut data = Vec::with_capacity(n);
#[allow(clippy::uninit_vec)]
unsafe {
data.set_len(n);
}
kernel(slice, &mut data);
Array::from_vec(input.dim().clone(), data)
} else {
let data: Vec<f32> = input.iter().map(|&x| scalar_fallback(x)).collect();
Array::from_vec(input.dim().clone(), data)
}
}
#[inline]
pub fn try_simd_f64_unary<T, D>(
input: &Array<T, D>,
kernel: fn(&[f64], &mut [f64]),
) -> Option<FerrayResult<Array<T, D>>>
where
T: Element + Copy,
D: Dimension,
{
use std::any::TypeId;
if TypeId::of::<T>() != TypeId::of::<f64>() {
return None;
}
let slice = input.as_slice()?;
let n = slice.len();
let f64_slice: &[f64] = unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const f64, n) };
let mut output = Vec::with_capacity(n);
#[allow(clippy::uninit_vec)]
unsafe {
output.set_len(n);
}
kernel(f64_slice, &mut output);
let t_vec: Vec<T> = unsafe {
let mut md = std::mem::ManuallyDrop::new(output);
Vec::from_raw_parts(md.as_mut_ptr() as *mut T, n, n)
};
Some(Array::from_vec(input.dim().clone(), t_vec))
}
#[inline]
pub fn unary_map_op<T, U, D>(input: &Array<T, D>, f: impl Fn(T) -> U) -> FerrayResult<Array<U, D>>
where
T: Element + Copy,
U: Element,
D: Dimension,
{
let data: Vec<U> = input.iter().map(|&x| f(x)).collect();
Array::from_vec(input.dim().clone(), data)
}
#[inline]
pub fn binary_float_op<T, D>(
a: &Array<T, D>,
b: &Array<T, D>,
f: impl Fn(T, T) -> T,
) -> FerrayResult<Array<T, D>>
where
T: Element + Copy,
D: Dimension,
{
if a.shape() != b.shape() {
return Err(FerrayError::shape_mismatch(format!(
"binary op: shapes {:?} and {:?} do not match",
a.shape(),
b.shape()
)));
}
let data: Vec<T> = a.iter().zip(b.iter()).map(|(&x, &y)| f(x, y)).collect();
Array::from_vec(a.dim().clone(), data)
}
#[inline]
pub fn binary_map_op<T, U, D>(
a: &Array<T, D>,
b: &Array<T, D>,
f: impl Fn(T, T) -> U,
) -> FerrayResult<Array<U, D>>
where
T: Element + Copy,
U: Element,
D: Dimension,
{
if a.shape() != b.shape() {
return Err(FerrayError::shape_mismatch(format!(
"binary op: shapes {:?} and {:?} do not match",
a.shape(),
b.shape()
)));
}
let data: Vec<U> = a.iter().zip(b.iter()).map(|(&x, &y)| f(x, y)).collect();
Array::from_vec(a.dim().clone(), data)
}
pub fn binary_broadcast_op<T, D1, D2>(
a: &Array<T, D1>,
b: &Array<T, D2>,
f: impl Fn(T, T) -> T,
) -> FerrayResult<Array<T, IxDyn>>
where
T: Element + Copy,
D1: Dimension,
D2: Dimension,
{
let shape = broadcast_shapes(a.shape(), b.shape())?;
let a_view = a.broadcast_to(&shape)?;
let b_view = b.broadcast_to(&shape)?;
let data: Vec<T> = a_view
.iter()
.zip(b_view.iter())
.map(|(&x, &y)| f(x, y))
.collect();
Array::from_vec(IxDyn::from(&shape[..]), data)
}
pub fn binary_broadcast_map_op<T, U, D1, D2>(
a: &Array<T, D1>,
b: &Array<T, D2>,
f: impl Fn(T, T) -> U,
) -> FerrayResult<Array<U, IxDyn>>
where
T: Element + Copy,
U: Element,
D1: Dimension,
D2: Dimension,
{
let shape = broadcast_shapes(a.shape(), b.shape())?;
let a_view = a.broadcast_to(&shape)?;
let b_view = b.broadcast_to(&shape)?;
let data: Vec<U> = a_view
.iter()
.zip(b_view.iter())
.map(|(&x, &y)| f(x, y))
.collect();
Array::from_vec(IxDyn::from(&shape[..]), data)
}
#[cfg(feature = "f16")]
#[inline]
pub fn unary_f16_op<D>(
input: &Array<half::f16, D>,
f: impl Fn(f32) -> f32,
) -> FerrayResult<Array<half::f16, D>>
where
D: Dimension,
{
let data: Vec<half::f16> = input
.iter()
.map(|&x| half::f16::from_f32(f(x.to_f32())))
.collect();
Array::from_vec(input.dim().clone(), data)
}
#[cfg(feature = "f16")]
#[inline]
pub fn unary_f16_to_bool_op<D>(
input: &Array<half::f16, D>,
f: impl Fn(f32) -> bool,
) -> FerrayResult<Array<bool, D>>
where
D: Dimension,
{
let data: Vec<bool> = input.iter().map(|&x| f(x.to_f32())).collect();
Array::from_vec(input.dim().clone(), data)
}
#[cfg(feature = "f16")]
#[inline]
pub fn binary_f16_op<D>(
a: &Array<half::f16, D>,
b: &Array<half::f16, D>,
f: impl Fn(f32, f32) -> f32,
) -> FerrayResult<Array<half::f16, D>>
where
D: Dimension,
{
if a.shape() != b.shape() {
return Err(FerrayError::shape_mismatch(format!(
"binary op: shapes {:?} and {:?} do not match",
a.shape(),
b.shape()
)));
}
let data: Vec<half::f16> = a
.iter()
.zip(b.iter())
.map(|(&x, &y)| half::f16::from_f32(f(x.to_f32(), y.to_f32())))
.collect();
Array::from_vec(a.dim().clone(), data)
}
#[cfg(feature = "f16")]
#[inline]
pub fn binary_f16_to_bool_op<D>(
a: &Array<half::f16, D>,
b: &Array<half::f16, D>,
f: impl Fn(f32, f32) -> bool,
) -> FerrayResult<Array<bool, D>>
where
D: Dimension,
{
if a.shape() != b.shape() {
return Err(FerrayError::shape_mismatch(format!(
"binary op: shapes {:?} and {:?} do not match",
a.shape(),
b.shape()
)));
}
let data: Vec<bool> = a
.iter()
.zip(b.iter())
.map(|(&x, &y)| f(x.to_f32(), y.to_f32()))
.collect();
Array::from_vec(a.dim().clone(), data)
}
#[cfg(test)]
mod tests {
use super::*;
use ferray_core::dimension::{Ix1, Ix2};
#[test]
fn unary_op_works() {
let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 4.0, 9.0]).unwrap();
let r = unary_float_op(&a, f64::sqrt).unwrap();
assert_eq!(r.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
}
#[test]
fn binary_op_same_shape() {
let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![4.0, 5.0, 6.0]).unwrap();
let r = binary_float_op(&a, &b, |x, y| x + y).unwrap();
assert_eq!(r.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
}
#[test]
fn binary_op_shape_mismatch() {
let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
let b = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![4.0, 5.0]).unwrap();
assert!(binary_float_op(&a, &b, |x, y| x + y).is_err());
}
#[test]
fn binary_broadcast_works() {
let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![1.0, 2.0]).unwrap();
let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
let r = binary_broadcast_op(&a, &b, |x, y| x + y).unwrap();
assert_eq!(r.shape(), &[2, 3]);
let s: Vec<f64> = r.iter().copied().collect();
assert_eq!(s, vec![11.0, 21.0, 31.0, 12.0, 22.0, 32.0]);
}
}