use super::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{One, Zero};
use scirs2_core::ndarray::Array1;
use scirs2_core::parallel_ops::*;
use scirs2_core::simd_ops::SimdUnifiedOps;
use std::ops::{Add, Div, Mul, Sub};
impl<T: Clone> Array<T> {
/// Multiplies every element by `scalar` and returns the results as a new array.
///
/// The work is delegated to [`Array::map`]; `scalar` is cloned once per
/// element because `Mul` takes its operands by value.
pub fn scalar_mul(&self, scalar: T) -> Self
where
    T: Clone + Mul<Output = T>,
{
    let factor = scalar;
    self.map(|elem| elem * factor.clone())
}
/// Divides every element by `scalar` and returns the results as a new array.
///
/// The work is delegated to [`Array::map`]; `scalar` is cloned once per
/// element because `Div` takes its operands by value.
pub fn scalar_div(&self, scalar: T) -> Self
where
    T: Clone + Div<Output = T>,
{
    let divisor = scalar;
    self.map(|elem| elem / divisor.clone())
}
/// Adds up every element of the array, starting from `T::zero()`.
///
/// Elements are visited in the iteration order of the underlying `data`
/// and cloned before being added to the running total.
pub fn sum_all(&self) -> T
where
    T: Clone + Add<Output = T> + Zero,
{
    let mut total = T::zero();
    for item in self.data.iter() {
        total = total + item.clone();
    }
    total
}
/// Sums the elements along `axis`, returning an array whose shape is the
/// original shape with that axis removed.
///
/// The flat offsets are computed with the last axis varying fastest, so this
/// assumes `to_vec` yields elements in that (row-major) order — TODO confirm
/// against `to_vec`.
///
/// # Errors
///
/// Returns `NumRs2Error::DimensionMismatch` when `axis >= self.ndim()`, and
/// propagates any error reported by `set` while writing a result cell.
pub fn sum_axis(&self, axis: usize) -> Result<Self>
where
    T: Clone + Add<Output = T> + Zero,
{
    if axis >= self.ndim() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axis {} out of bounds for array of dimension {}",
            axis,
            self.ndim()
        )));
    }

    let shape = self.shape();
    let axis_size = shape[axis];
    let mut result_shape = shape.clone();
    result_shape.remove(axis);

    // Strides depend only on the shape, so compute them once up front
    // instead of rebuilding them for every element that is read.
    let mut strides = vec![1usize; shape.len()];
    for j in (0..shape.len().saturating_sub(1)).rev() {
        strides[j] = strides[j + 1] * shape[j + 1];
    }
    let axis_stride = strides[axis];
    // Strides of the axes that survive in the result, in result-axis order.
    let kept_strides: Vec<usize> = strides
        .iter()
        .enumerate()
        .filter(|&(j, _)| j != axis)
        .map(|(_, &stride)| stride)
        .collect();

    let mut result = Self::zeros(&result_shape);
    let data = self.to_vec();
    let mut result_indices = vec![0; result_shape.len()];
    let result_size = result.size();

    for i in 0..result_size {
        // Decompose the flat result position into a multi-index.
        let mut remainder = i;
        for j in (0..result_shape.len()).rev() {
            result_indices[j] = remainder % result_shape[j];
            remainder /= result_shape[j];
        }

        // Offset of the first element of this lane (index 0 along `axis`).
        let base: usize = result_indices
            .iter()
            .zip(&kept_strides)
            .map(|(idx, stride)| idx * stride)
            .sum();

        // Walk the lane along `axis` and accumulate.
        let sum = (0..axis_size).fold(T::zero(), |acc, k| {
            acc + data[base + k * axis_stride].clone()
        });

        result.set(&result_indices, sum)?;
    }

    Ok(result)
}
/// Applies `f` to every element using a parallel iterator and returns the
/// results as a new array with the same shape as `self`.
///
/// Each element is cloned before being handed to `f`.
pub fn par_map<F, U>(&self, f: F) -> Array<U>
where
    T: Send + Sync + Clone,
    U: Send + Clone,
    F: Fn(T) -> U + Send + Sync,
{
    let target_shape = self.shape();
    let elements = self.to_vec();
    let mapped: Vec<U> = elements
        .par_iter()
        .map(|elem| f(elem.clone()))
        .collect();
    Array::from_vec(mapped).reshape(&target_shape)
}
/// Applies `f` to every element and returns the results as a new array with
/// the same shape as `self`.
///
/// `to_vec` already hands back owned elements, so they are moved straight
/// into `f` rather than being cloned a second time.
pub fn map<F, U>(&self, f: F) -> Array<U>
where
    U: Clone,
    F: Fn(T) -> U,
    T: Clone,
{
    let mapped: Vec<U> = self.to_vec().into_iter().map(f).collect();
    Array::from_vec(mapped).reshape(&self.shape())
}
/// Combines `self` and `other` element by element with `f`.
///
/// When both arrays have the same shape the elements are paired directly and
/// the result keeps that shape. Otherwise both operands are first expanded
/// with `broadcast_to` to the shape computed by `broadcast_shape`, and the
/// result takes that broadcast shape.
///
/// # Errors
///
/// Propagates any error from `broadcast_shape` or `broadcast_to` when the
/// shapes differ.
pub fn zip_with<F, U, V>(&self, other: &Array<U>, f: F) -> Result<Array<V>>
where
    T: Clone,
    U: Clone,
    V: Clone,
    F: Fn(T, U) -> V,
{
    // Shared pairing step: the vectors are already owned copies, so their
    // elements are moved into `f` instead of being cloned again.
    let combine = |lhs: Vec<T>, rhs: Vec<U>| -> Vec<V> {
        lhs.into_iter().zip(rhs).map(|(a, b)| f(a, b)).collect()
    };

    let a_shape = self.shape();
    let b_shape = other.shape();
    if a_shape == b_shape {
        let combined = combine(self.to_vec(), other.to_vec());
        return Ok(Array::from_vec(combined).reshape(&a_shape));
    }

    let broadcast_shape = Self::broadcast_shape(&a_shape, &b_shape)?;
    let self_broadcast = self.broadcast_to(&broadcast_shape)?;
    let other_broadcast = other.broadcast_to(&broadcast_shape)?;
    let combined = combine(self_broadcast.to_vec(), other_broadcast.to_vec());
    Ok(Array::from_vec(combined).reshape(&broadcast_shape))
}
/// Runs the whole-array operation `op` on `self` and `other`, broadcasting
/// both operands to a common shape first when their shapes differ.
///
/// If the shapes are already equal, `op` receives the original arrays
/// untouched.
///
/// # Errors
///
/// Propagates any error from `broadcast_shape` or `broadcast_to`.
pub fn broadcast_op<F, U, V>(&self, other: &Array<U>, op: F) -> Result<Array<V>>
where
    T: Clone,
    U: Clone,
    V: Clone,
    F: Fn(&Array<T>, &Array<U>) -> Array<V>,
{
    let lhs_shape = self.shape();
    let rhs_shape = other.shape();

    if lhs_shape != rhs_shape {
        let target = Self::broadcast_shape(&lhs_shape, &rhs_shape)?;
        let lhs = self.broadcast_to(&target)?;
        let rhs = other.broadcast_to(&target)?;
        Ok(op(&lhs, &rhs))
    } else {
        Ok(op(self, other))
    }
}
}
impl<T> Array<T>
where
T: Clone + Add<Output = T> + Zero + Mul<Output = T> + One + 'static,
{
/// Returns the sum of all elements.
///
/// For arrays with at least 64 elements whose element type is exactly
/// `f64` or `f32`, the values are copied into an `Array1` and reduced with
/// `simd_sum`. Every other case folds the elements sequentially, starting
/// from `T::zero()`.
pub fn sum(&self) -> T {
    use std::any::TypeId;

    let elem_type = TypeId::of::<T>();
    let use_simd = self.len() >= 64;

    if use_simd && elem_type == TypeId::of::<f64>() {
        let values: Vec<f64> = self
            .to_vec()
            .iter()
            .map(|v| {
                // SAFETY: the TypeId check above proves `T` is `f64`, so the
                // pointer refers to a valid, aligned `f64`.
                unsafe { *(v as *const T).cast::<f64>() }
            })
            .collect();
        let total = f64::simd_sum(&Array1::from_vec(values).view());
        // SAFETY: `T` is `f64` (checked above), so source and destination
        // have identical size and layout.
        return unsafe { std::mem::transmute_copy::<f64, T>(&total) };
    }

    if use_simd && elem_type == TypeId::of::<f32>() {
        let values: Vec<f32> = self
            .to_vec()
            .iter()
            .map(|v| {
                // SAFETY: the TypeId check above proves `T` is `f32`, so the
                // pointer refers to a valid, aligned `f32`.
                unsafe { *(v as *const T).cast::<f32>() }
            })
            .collect();
        let total = f32::simd_sum(&Array1::from_vec(values).view());
        // SAFETY: `T` is `f32` (checked above), so source and destination
        // have identical size and layout.
        return unsafe { std::mem::transmute_copy::<f32, T>(&total) };
    }

    // Generic path: plain left-to-right accumulation.
    let mut total = T::zero();
    for item in self.to_vec().iter() {
        total = total + item.clone();
    }
    total
}
/// Returns the product of all elements, starting from `T::one()`.
pub fn product(&self) -> T {
    let mut total = T::one();
    for item in self.to_vec().iter() {
        total = total * item.clone();
    }
    total
}
}
impl<T: Clone + Add<Output = T>> Array<T> {
    /// Adds `other` to `self` by applying `+` to references of the two
    /// underlying `data` containers.
    ///
    /// NOTE(review): behaviour for mismatched shapes is whatever the `+`
    /// implementation of the `data` type does — confirm before relying on it.
    /// Use [`Array::add_broadcast`] for an explicit, fallible broadcast.
    pub fn add(&self, other: &Array<T>) -> Array<T> {
        Array {
            data: &self.data + &other.data,
        }
    }

    /// Adds `other` to `self`, letting `broadcast_op` bring both operands to
    /// a common shape first.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `broadcast_op`.
    pub fn add_broadcast(&self, other: &Array<T>) -> Result<Array<T>> {
        self.broadcast_op(other, |lhs, rhs| Array {
            data: &lhs.data + &rhs.data,
        })
    }
}
impl<T: Clone + Sub<Output = T>> Array<T> {
    /// Subtracts `other` from `self` by applying `-` to references of the
    /// two underlying `data` containers.
    ///
    /// NOTE(review): behaviour for mismatched shapes is whatever the `-`
    /// implementation of the `data` type does — confirm before relying on it.
    /// Use [`Array::subtract_broadcast`] for an explicit, fallible broadcast.
    pub fn subtract(&self, other: &Array<T>) -> Array<T> {
        Array {
            data: &self.data - &other.data,
        }
    }

    /// Subtracts `other` from `self`, letting `broadcast_op` bring both
    /// operands to a common shape first.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `broadcast_op`.
    pub fn subtract_broadcast(&self, other: &Array<T>) -> Result<Array<T>> {
        self.broadcast_op(other, |lhs, rhs| Array {
            data: &lhs.data - &rhs.data,
        })
    }
}
impl<T: Clone + Mul<Output = T>> Array<T> {
    /// Multiplies `self` by `other` by applying `*` to references of the two
    /// underlying `data` containers.
    ///
    /// NOTE(review): behaviour for mismatched shapes is whatever the `*`
    /// implementation of the `data` type does — confirm before relying on it.
    /// Use [`Array::multiply_broadcast`] for an explicit, fallible broadcast.
    pub fn multiply(&self, other: &Array<T>) -> Array<T> {
        Array {
            data: &self.data * &other.data,
        }
    }

    /// Multiplies `self` by `other`, letting `broadcast_op` bring both
    /// operands to a common shape first.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `broadcast_op`.
    pub fn multiply_broadcast(&self, other: &Array<T>) -> Result<Array<T>> {
        self.broadcast_op(other, |lhs, rhs| Array {
            data: &lhs.data * &rhs.data,
        })
    }
}
impl<T: Clone + Div<Output = T>> Array<T> {
    /// Divides `self` by `other` by applying `/` to references of the two
    /// underlying `data` containers.
    ///
    /// NOTE(review): behaviour for mismatched shapes is whatever the `/`
    /// implementation of the `data` type does — confirm before relying on it.
    /// Use [`Array::divide_broadcast`] for an explicit, fallible broadcast.
    pub fn divide(&self, other: &Array<T>) -> Array<T> {
        Array {
            data: &self.data / &other.data,
        }
    }

    /// Divides `self` by `other`, letting `broadcast_op` bring both operands
    /// to a common shape first.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `broadcast_op`.
    pub fn divide_broadcast(&self, other: &Array<T>) -> Result<Array<T>> {
        self.broadcast_op(other, |lhs, rhs| Array {
            data: &lhs.data / &rhs.data,
        })
    }
}
impl<T: Clone + Add<Output = T>> Array<T> {
    /// Adds `scalar` to every element and returns the results as a new array.
    ///
    /// `scalar` is cloned once per element because `Add` consumes its operands.
    pub fn add_scalar(&self, scalar: T) -> Self {
        let offset = scalar;
        self.map(|elem| elem + offset.clone())
    }
}
impl<T: Clone + Sub<Output = T>> Array<T> {
    /// Subtracts `scalar` from every element and returns the results as a
    /// new array.
    ///
    /// `scalar` is cloned once per element because `Sub` consumes its operands.
    pub fn subtract_scalar(&self, scalar: T) -> Self {
        let offset = scalar;
        self.map(|elem| elem - offset.clone())
    }
}
impl<T: Clone + Mul<Output = T>> Array<T> {
    /// Multiplies every element by `scalar` and returns the results as a new
    /// array.
    ///
    /// Same operation as [`Array::scalar_mul`], to which this delegates.
    pub fn multiply_scalar(&self, scalar: T) -> Self {
        self.scalar_mul(scalar)
    }
}
impl<T: Clone + Div<Output = T>> Array<T> {
    /// Divides every element by `scalar` and returns the results as a new
    /// array.
    ///
    /// Same operation as [`Array::scalar_div`], to which this delegates.
    pub fn divide_scalar(&self, scalar: T) -> Self {
        self.scalar_div(scalar)
    }
}