use crate::Numeric;
use crate::enums::error::KernelError;
use crate::kernels::routing::broadcast::maybe_broadcast_scalar_array;
use crate::{Array, ArrayV, Vec64};
/// Applies a binary function element-wise across two array-like operands.
///
/// Both operands are converted into [`ArrayV`] and handed to
/// `maybe_broadcast_scalar_array`, which is responsible for reconciling their
/// shapes (for example a one-element array paired with a longer one). Each
/// side is then materialised as a typed vector of `T` through
/// `to_typed_vec::<T>()`, so the element type of the output is chosen by the
/// caller via `T` rather than by the stored type of the inputs.
///
/// # Parameters
/// - `lhs`: left operand; anything convertible into [`ArrayV`].
/// - `rhs`: right operand; anything convertible into [`ArrayV`].
/// - `f`: function invoked as `f(left, right)` for every pair of elements.
///
/// # Returns
/// An [`Array`] built from the `Vec64<T>` of results.
///
/// # Errors
/// Propagates any [`KernelError`] raised by the broadcasting step or by either
/// typed conversion. The left operand is converted first.
///
/// # Notes
/// Pairs are formed with `zip`, so iteration stops at the shorter of the two
/// typed vectors.
/// NOTE(review): presumably broadcasting guarantees equal lengths here —
/// confirm against `maybe_broadcast_scalar_array`.
pub fn binary_map<T, F>(
    lhs: impl Into<ArrayV>,
    rhs: impl Into<ArrayV>,
    f: F,
) -> Result<Array, KernelError>
where
    T: Numeric,
    Vec64<T>: Into<Array>,
    F: Fn(T, T) -> T,
{
    // Reconcile the two operands before touching their element data.
    let (left, right) = maybe_broadcast_scalar_array(lhs.into(), rhs.into())?;

    // Materialise both sides as `T`; left first, matching error precedence.
    let left_values = left.to_typed_vec::<T>()?;
    let right_values = right.to_typed_vec::<T>()?;

    let pairs = left_values.iter().zip(right_values.iter());
    let mapped: Vec64<T> = pairs.map(|(&a, &b)| f(a, b)).collect();

    Ok(mapped.into())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NumericArray;

    /// Builds an `Array` from a slice of `f64` values.
    fn make_f64_array(values: &[f64]) -> Array {
        let buffer: Vec64<f64> = Vec64::from(values);
        buffer.into()
    }

    /// Builds an `Array` from a slice of `i32` values.
    fn make_i32_array(values: &[i32]) -> Array {
        let buffer: Vec64<i32> = Vec64::from(values);
        buffer.into()
    }

    /// Asserts that `result` is a Float64 numeric array whose data equals
    /// `expected`; panics for any other variant.
    fn assert_float64(result: Array, expected: &[f64]) {
        match result {
            Array::NumericArray(NumericArray::Float64(a)) => {
                assert_eq!(a.data.as_slice(), expected);
            }
            _ => panic!("Expected Float64 array"),
        }
    }

    /// Two equally sized Float64 arrays are combined element by element.
    #[test]
    fn test_binary_map_two_arrays() {
        let left = make_f64_array(&[1.0, 2.0, 3.0]);
        let right = make_f64_array(&[10.0, 20.0, 30.0]);

        let summed = binary_map::<f64, _>(left, right, |a, b| a + b).unwrap();

        assert_float64(summed, &[11.0, 22.0, 33.0]);
    }

    /// A one-element array is applied against every element of the other operand.
    #[test]
    fn test_binary_map_broadcast_scalar() {
        let values = make_f64_array(&[1.0, 2.0, 3.0]);
        let single = make_f64_array(&[10.0]);

        let scaled = binary_map::<f64, _>(values, single, |a, b| a * b).unwrap();

        assert_float64(scaled, &[10.0, 20.0, 30.0]);
    }

    /// Int32 inputs are processed as `f64` when `T = f64` is requested.
    #[test]
    fn test_binary_map_type_cast() {
        let left = make_i32_array(&[1, 2, 3]);
        let right = make_i32_array(&[10, 20, 30]);

        let summed = binary_map::<f64, _>(left, right, |a, b| a + b).unwrap();

        assert_float64(summed, &[11.0, 22.0, 33.0]);
    }

    /// A `Scalar` operand is accepted directly when the `scalar_type` feature is on.
    #[cfg(feature = "scalar_type")]
    #[test]
    fn test_binary_map_with_scalar() {
        use crate::Scalar;

        let values = make_f64_array(&[1.0, 2.0, 3.0]);
        let offset = Scalar::Float64(10.0);

        let shifted = binary_map::<f64, _>(values, offset, |a, b| a + b).unwrap();

        assert_float64(shifted, &[11.0, 12.0, 13.0]);
    }
}