use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use crate::error::CoreError;
use crate::{Float, Scalar};
use super::Tensor;
/// Applies an element-wise binary operation to `a` and `b` and returns the
/// results in a freshly allocated `Vec`.
///
/// Dispatch is decided by comparing `TypeId`s:
/// * `T == f64` — `f64_kernel` is called on the slices viewed as `f64`.
/// * `T == f32` — `f32_kernel` is called on the slices viewed as `f32`.
/// * anything else — `scalar_op` is applied pair by pair; this path zips the
///   inputs, so it stops at the shorter of the two slices.
///
/// On the kernel paths the output buffer always has `a.len()` elements and is
/// zero-initialised before the kernel runs, so the kernel never receives a
/// reference to uninitialised memory. Callers in this file check tensor shapes
/// before calling; no length check is performed here.
#[cfg(feature = "simd")]
fn simd_binop<T: Scalar>(
    a: &[T],
    b: &[T],
    f64_kernel: fn(&[f64], &[f64], &mut [f64]),
    f32_kernel: fn(&[f32], &[f32], &mut [f32]),
    scalar_op: fn(T, T) -> T,
) -> Vec<T> {
    use std::any::TypeId;

    if TypeId::of::<T>() == TypeId::of::<f64>() {
        // SAFETY: the `TypeId` comparison above proves `T` is exactly `f64`,
        // which is the condition under which this file uses `slice_as_f64`.
        let a_f64 = unsafe { crate::simd::slice_as_f64(a) };
        // SAFETY: as above, `T` is `f64`.
        let b_f64 = unsafe { crate::simd::slice_as_f64(b) };

        // Zero-fill instead of `with_capacity` + `set_len`: `set_len` requires
        // the elements to be initialised, and the kernel takes `&mut [f64]`.
        let mut out = vec![0.0_f64; a.len()];
        f64_kernel(a_f64, b_f64, &mut out);

        // Hand the allocation over as `Vec<T>` without copying. `ManuallyDrop`
        // keeps the `Vec<f64>` from freeing the buffer we are about to re-own.
        let mut out = core::mem::ManuallyDrop::new(out);
        // SAFETY: pointer, length and capacity come from a live `Vec<f64>` that
        // will not be dropped, and `T` is `f64`, so element layout matches.
        unsafe { Vec::from_raw_parts(out.as_mut_ptr().cast::<T>(), out.len(), out.capacity()) }
    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
        // SAFETY: the `TypeId` comparison above proves `T` is exactly `f32`.
        let a_f32 = unsafe { crate::simd::slice_as_f32(a) };
        // SAFETY: as above, `T` is `f32`.
        let b_f32 = unsafe { crate::simd::slice_as_f32(b) };

        let mut out = vec![0.0_f32; a.len()];
        f32_kernel(a_f32, b_f32, &mut out);

        let mut out = core::mem::ManuallyDrop::new(out);
        // SAFETY: same reasoning as the `f64` branch, with `T` being `f32`.
        unsafe { Vec::from_raw_parts(out.as_mut_ptr().cast::<T>(), out.len(), out.capacity()) }
    } else {
        a.iter().zip(b.iter()).map(|(&x, &y)| scalar_op(x, y)).collect()
    }
}
/// Runs `kernel` over `a` and `b` one chunk at a time, writing each chunk's
/// result back into `a`.
///
/// The left-hand input for every kernel call is a copy of the current chunk
/// held in a stack scratch buffer, so the kernel never sees a shared slice and
/// a mutable slice that point at the same memory.
///
/// This relies on `kernel` being element-wise (output `i` depends only on
/// inputs `i`), which is the same contract the scalar fallback of
/// `simd_binop_inplace` implements.
#[cfg(feature = "simd")]
fn apply_kernel_in_chunks<F: Copy + Default>(
    a: &mut [F],
    b: &[F],
    kernel: fn(&[F], &[F], &mut [F]),
) {
    // Scratch size in elements; small enough to live on the stack.
    const CHUNK: usize = 512;

    let mut scratch = [F::default(); CHUNK];
    for (dst, rhs) in a.chunks_mut(CHUNK).zip(b.chunks(CHUNK)) {
        let lhs = &mut scratch[..dst.len()];
        lhs.copy_from_slice(dst);
        kernel(lhs, rhs, dst);
    }
}

/// Applies an element-wise binary operation in place: `a[i] = a[i] op b[i]`.
///
/// Dispatch is decided by comparing `TypeId`s:
/// * `T == f64` — `f64_kernel` is used on the slices viewed as `f64`.
/// * `T == f32` — `f32_kernel` is used on the slices viewed as `f32`.
/// * anything else — `scalar_op` is applied pair by pair; this path zips the
///   inputs, so it stops at the shorter of the two slices.
///
/// The kernel paths go through `apply_kernel_in_chunks`, which avoids passing
/// the same memory to a kernel as both `&[_]` and `&mut [_]` (that aliasing is
/// undefined behaviour in Rust even when the kernel reads before it writes).
/// Callers in this file check tensor shapes before calling; no length check is
/// performed here.
#[cfg(feature = "simd")]
fn simd_binop_inplace<T: Scalar>(
    a: &mut [T],
    b: &[T],
    f64_kernel: fn(&[f64], &[f64], &mut [f64]),
    f32_kernel: fn(&[f32], &[f32], &mut [f32]),
    scalar_op: fn(T, T) -> T,
) {
    use std::any::TypeId;

    if TypeId::of::<T>() == TypeId::of::<f64>() {
        // SAFETY: the `TypeId` comparison above proves `T` is exactly `f64`.
        let b_f64 = unsafe { crate::simd::slice_as_f64(b) };
        // SAFETY: as above, `T` is `f64`.
        let a_f64 = unsafe { crate::simd::slice_as_f64_mut(a) };
        apply_kernel_in_chunks(a_f64, b_f64, f64_kernel);
    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
        // SAFETY: the `TypeId` comparison above proves `T` is exactly `f32`.
        let b_f32 = unsafe { crate::simd::slice_as_f32(b) };
        // SAFETY: as above, `T` is `f32`.
        let a_f32 = unsafe { crate::simd::slice_as_f32_mut(a) };
        apply_kernel_in_chunks(a_f32, b_f32, f32_kernel);
    } else {
        for (x, &y) in a.iter_mut().zip(b.iter()) {
            *x = scalar_op(*x, y);
        }
    }
}
// Generates the element-wise `Tensor op Tensor` operator impls for one
// operator trait, once for owned operands and once for borrowed operands.
//
// Arguments: operator trait, its method name, the operator token, and the
// f64 / f32 kernels handed to `simd_binop` when the `simd` feature is enabled.
// Both impls panic (via `assert_eq!`) when the operand shapes differ.
macro_rules! impl_tensor_binop {
    ($op_trait:ident, $op_fn:ident, $op:tt, $kernel_f64:path, $kernel_f32:path) => {
        // Owned operands: the result takes over the left operand's shape and strides.
        impl<T: Scalar> $op_trait for Tensor<T> {
            type Output = Tensor<T>;

            fn $op_fn(self, rhs: Tensor<T>) -> Tensor<T> {
                assert_eq!(
                    self.shape, rhs.shape,
                    "shape mismatch in element-wise {}: {:?} vs {:?}",
                    stringify!($op_fn), self.shape, rhs.shape,
                );

                #[cfg(not(feature = "simd"))]
                let data: Vec<T> = self
                    .data
                    .iter()
                    .copied()
                    .zip(rhs.data.iter().copied())
                    .map(|(x, y)| x $op y)
                    .collect();

                #[cfg(feature = "simd")]
                let data = simd_binop(
                    self.data.as_slice(),
                    rhs.data.as_slice(),
                    $kernel_f64,
                    $kernel_f32,
                    |x, y| x $op y,
                );

                // Move the layout metadata out of `self`; its old buffer is dropped.
                let Tensor { shape, strides, .. } = self;
                Tensor { data, shape, strides }
            }
        }

        // Borrowed operands: shape and strides are cloned from the left operand.
        impl<T: Scalar> $op_trait for &Tensor<T> {
            type Output = Tensor<T>;

            fn $op_fn(self, rhs: &Tensor<T>) -> Tensor<T> {
                assert_eq!(
                    self.shape, rhs.shape,
                    "shape mismatch in element-wise {}: {:?} vs {:?}",
                    stringify!($op_fn), self.shape, rhs.shape,
                );

                #[cfg(not(feature = "simd"))]
                let data: Vec<T> = self
                    .data
                    .iter()
                    .copied()
                    .zip(rhs.data.iter().copied())
                    .map(|(x, y)| x $op y)
                    .collect();

                #[cfg(feature = "simd")]
                let data = simd_binop(
                    self.data.as_slice(),
                    rhs.data.as_slice(),
                    $kernel_f64,
                    $kernel_f32,
                    |x, y| x $op y,
                );

                let shape = self.shape.clone();
                let strides = self.strides.clone();
                Tensor { data, shape, strides }
            }
        }
    };
}

impl_tensor_binop!(Add, add, +, crate::simd::f64_ops::add_f64, crate::simd::f32_ops::add_f32);
impl_tensor_binop!(Sub, sub, -, crate::simd::f64_ops::sub_f64, crate::simd::f32_ops::sub_f32);
impl_tensor_binop!(Mul, mul, *, crate::simd::f64_ops::mul_f64, crate::simd::f32_ops::mul_f32);
impl_tensor_binop!(Div, div, /, crate::simd::f64_ops::div_f64, crate::simd::f32_ops::div_f32);
// Generates the compound-assignment operator impls (`+=`, `-=`, ...) between
// tensors, for a borrowed and for an owned right-hand side.
//
// Arguments: operator trait, its method name, the operator token, and the
// f64 / f32 kernels handed to `simd_binop_inplace` when the `simd` feature is
// enabled. The impls panic (via `assert_eq!`) when the operand shapes differ.
macro_rules! impl_tensor_assign_op {
    ($op_trait:ident, $op_fn:ident, $op:tt, $kernel_f64:path, $kernel_f32:path) => {
        impl<T: Scalar> $op_trait<&Tensor<T>> for Tensor<T> {
            fn $op_fn(&mut self, rhs: &Tensor<T>) {
                assert_eq!(
                    self.shape, rhs.shape,
                    "shape mismatch in element-wise {}: {:?} vs {:?}",
                    stringify!($op_fn), self.shape, rhs.shape,
                );

                // Exactly one of the two statements below survives `cfg` stripping.
                #[cfg(not(feature = "simd"))]
                self.data
                    .iter_mut()
                    .zip(rhs.data.iter().copied())
                    .for_each(|(slot, y)| *slot = *slot $op y);

                #[cfg(feature = "simd")]
                simd_binop_inplace(
                    self.data.as_mut_slice(),
                    rhs.data.as_slice(),
                    $kernel_f64,
                    $kernel_f32,
                    |x, y| x $op y,
                );
            }
        }

        // Owned right-hand side: forwards to the by-reference impl above.
        impl<T: Scalar> $op_trait<Tensor<T>> for Tensor<T> {
            fn $op_fn(&mut self, rhs: Tensor<T>) {
                let rhs_ref = &rhs;
                $op_trait::$op_fn(self, rhs_ref);
            }
        }
    };
}

impl_tensor_assign_op!(AddAssign, add_assign, +, crate::simd::f64_ops::add_f64, crate::simd::f32_ops::add_f32);
impl_tensor_assign_op!(SubAssign, sub_assign, -, crate::simd::f64_ops::sub_f64, crate::simd::f32_ops::sub_f32);
impl_tensor_assign_op!(MulAssign, mul_assign, *, crate::simd::f64_ops::mul_f64, crate::simd::f32_ops::mul_f32);
impl_tensor_assign_op!(DivAssign, div_assign, /, crate::simd::f64_ops::div_f64, crate::simd::f32_ops::div_f32);
#[cfg(feature = "simd")]
impl Tensor<f64> {
    /// Element-wise sum of two `f64` tensors, computed by
    /// `crate::simd::f64_ops::add_f64`.
    ///
    /// The result reuses clones of `self`'s shape and strides.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `other` have different shapes.
    pub fn add_simd(&self, other: &Tensor<f64>) -> Tensor<f64> {
        assert_eq!(self.shape, other.shape, "shape mismatch in simd add");

        let len = self.data.len();
        let mut sums = vec![0.0_f64; len];
        crate::simd::f64_ops::add_f64(
            self.data.as_slice(),
            other.data.as_slice(),
            sums.as_mut_slice(),
        );

        let (shape, strides) = (self.shape.clone(), self.strides.clone());
        Tensor { data: sums, shape, strides }
    }

    /// Element-wise product of two `f64` tensors, computed by
    /// `crate::simd::f64_ops::mul_f64`.
    ///
    /// The result reuses clones of `self`'s shape and strides.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `other` have different shapes.
    pub fn mul_simd(&self, other: &Tensor<f64>) -> Tensor<f64> {
        assert_eq!(self.shape, other.shape, "shape mismatch in simd mul");

        let len = self.data.len();
        let mut products = vec![0.0_f64; len];
        crate::simd::f64_ops::mul_f64(
            self.data.as_slice(),
            other.data.as_slice(),
            products.as_mut_slice(),
        );

        let (shape, strides) = (self.shape.clone(), self.strides.clone());
        Tensor { data: products, shape, strides }
    }
}
#[cfg(feature = "simd")]
impl Tensor<f32> {
    /// Element-wise sum of two `f32` tensors, computed by
    /// `crate::simd::f32_ops::add_f32`.
    ///
    /// The result reuses clones of `self`'s shape and strides.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `other` have different shapes.
    pub fn add_simd(&self, other: &Tensor<f32>) -> Tensor<f32> {
        assert_eq!(self.shape, other.shape, "shape mismatch in simd add");

        let len = self.data.len();
        let mut sums = vec![0.0_f32; len];
        crate::simd::f32_ops::add_f32(
            self.data.as_slice(),
            other.data.as_slice(),
            sums.as_mut_slice(),
        );

        let (shape, strides) = (self.shape.clone(), self.strides.clone());
        Tensor { data: sums, shape, strides }
    }

    /// Element-wise product of two `f32` tensors, computed by
    /// `crate::simd::f32_ops::mul_f32`.
    ///
    /// The result reuses clones of `self`'s shape and strides.
    ///
    /// # Panics
    ///
    /// Panics if `self` and `other` have different shapes.
    pub fn mul_simd(&self, other: &Tensor<f32>) -> Tensor<f32> {
        assert_eq!(self.shape, other.shape, "shape mismatch in simd mul");

        let len = self.data.len();
        let mut products = vec![0.0_f32; len];
        crate::simd::f32_ops::mul_f32(
            self.data.as_slice(),
            other.data.as_slice(),
            products.as_mut_slice(),
        );

        let (shape, strides) = (self.shape.clone(), self.strides.clone());
        Tensor { data: products, shape, strides }
    }
}
// Generates `Tensor op scalar` operator impls: the scalar on the right-hand
// side is combined with every element of the tensor. One impl consumes the
// tensor, the other borrows it and clones the shape and strides.
macro_rules! impl_scalar_binop {
    ($op_trait:ident, $op_fn:ident, $op:tt) => {
        // Owned tensor: shape and strides are moved into the result.
        impl<T: Scalar> $op_trait<T> for Tensor<T> {
            type Output = Tensor<T>;

            fn $op_fn(self, rhs: T) -> Tensor<T> {
                let data: Vec<T> = self
                    .data
                    .iter()
                    .copied()
                    .map(|elem| elem $op rhs)
                    .collect();
                let Tensor { shape, strides, .. } = self;
                Tensor { data, shape, strides }
            }
        }

        // Borrowed tensor: shape and strides are cloned into the result.
        impl<T: Scalar> $op_trait<T> for &Tensor<T> {
            type Output = Tensor<T>;

            fn $op_fn(self, rhs: T) -> Tensor<T> {
                let data: Vec<T> = self
                    .data
                    .iter()
                    .copied()
                    .map(|elem| elem $op rhs)
                    .collect();
                let shape = self.shape.clone();
                let strides = self.strides.clone();
                Tensor { data, shape, strides }
            }
        }
    };
}

impl_scalar_binop!(Add, add, +);
impl_scalar_binop!(Sub, sub, -);
impl_scalar_binop!(Mul, mul, *);
impl_scalar_binop!(Div, div, /);
/// Unary minus for an owned tensor: every element is negated, while the shape
/// and strides are moved into the result unchanged.
impl<T: Float> Neg for Tensor<T> {
    type Output = Tensor<T>;

    fn neg(self) -> Tensor<T> {
        let Tensor { data, shape, strides } = self;
        let data = data.into_iter().map(|value| -value).collect();
        Tensor { data, shape, strides }
    }
}
/// Unary minus for a borrowed tensor: returns a new tensor holding the negated
/// elements, with the shape and strides cloned from the operand.
impl<T: Float> Neg for &Tensor<T> {
    type Output = Tensor<T>;

    fn neg(self) -> Tensor<T> {
        let mut negated = Vec::with_capacity(self.data.len());
        negated.extend(self.data.iter().map(|&value| -value));

        let shape = self.shape.clone();
        let strides = self.strides.clone();
        Tensor { data: negated, shape, strides }
    }
}
/// Non-panicking counterparts of the arithmetic operators.
///
/// Each method forwards to `zip_map` with the matching element-wise closure
/// and returns its `crate::Result` untouched. The tests in this file show
/// `add_checked` yielding `Err` for operands of different shapes; the exact
/// error is decided by `zip_map`, which is defined outside this block.
impl<T: Scalar> Tensor<T> {
    /// Element-wise `self + other` via `zip_map`.
    pub fn add_checked(&self, other: &Tensor<T>) -> crate::Result<Tensor<T>> {
        self.zip_map(other, |x, y| x + y)
    }

    /// Element-wise `self - other` via `zip_map`.
    pub fn sub_checked(&self, other: &Tensor<T>) -> crate::Result<Tensor<T>> {
        self.zip_map(other, |x, y| x - y)
    }

    /// Element-wise `self * other` via `zip_map`.
    pub fn mul_checked(&self, other: &Tensor<T>) -> crate::Result<Tensor<T>> {
        self.zip_map(other, |x, y| x * y)
    }

    /// Element-wise `self / other` via `zip_map`.
    pub fn div_checked(&self, other: &Tensor<T>) -> crate::Result<Tensor<T>> {
        self.zip_map(other, |x, y| x / y)
    }
}
/// Reductions over the tensor's elements.
impl<T: Scalar> Tensor<T> {
    /// Returns the sum of every element.
    ///
    /// With the `simd` feature enabled, `f64` and `f32` tensors are routed to
    /// `simd::f64_ops::sum_f64` / `simd::f32_ops::sum_f32`; every other element
    /// type, and every build without the feature, uses the iterator `sum`.
    pub fn sum(&self) -> T {
        #[cfg(feature = "simd")]
        {
            use crate::simd;
            use std::any::TypeId;

            let elem_ty = TypeId::of::<T>();
            if elem_ty == TypeId::of::<f64>() {
                // `T` is `f64` here; the helpers' own contracts live in `crate::simd`.
                let total =
                    unsafe { simd::f64_ops::sum_f64(simd::slice_as_f64(self.data.as_slice())) };
                return unsafe { simd::f64_to_t(total) };
            }
            if elem_ty == TypeId::of::<f32>() {
                // `T` is `f32` here.
                let total =
                    unsafe { simd::f32_ops::sum_f32(simd::slice_as_f32(self.data.as_slice())) };
                return unsafe { simd::f32_to_t(total) };
            }
        }

        self.data.iter().copied().sum()
    }

    /// Returns the product of every element, starting from `T::one()`.
    pub fn product(&self) -> T {
        let mut acc = T::one();
        for &value in self.data.iter() {
            acc = acc * value;
        }
        acc
    }

    /// Returns the smallest element, or `None` when the tensor holds no data.
    ///
    /// With the `simd` feature enabled, `f64` and `f32` tensors are routed to
    /// `simd::f64_ops::min_f64` / `simd::f32_ops::min_f32`. The generic path
    /// keeps the earliest element unless a later one compares strictly less.
    pub fn min_element(&self) -> Option<T> {
        let (&first, rest) = self.data.split_first()?;

        #[cfg(feature = "simd")]
        {
            use crate::simd;
            use std::any::TypeId;

            let elem_ty = TypeId::of::<T>();
            if elem_ty == TypeId::of::<f64>() {
                let lowest =
                    unsafe { simd::f64_ops::min_f64(simd::slice_as_f64(self.data.as_slice())) };
                return Some(unsafe { simd::f64_to_t(lowest) });
            }
            if elem_ty == TypeId::of::<f32>() {
                let lowest =
                    unsafe { simd::f32_ops::min_f32(simd::slice_as_f32(self.data.as_slice())) };
                return Some(unsafe { simd::f32_to_t(lowest) });
            }
        }

        let lowest = rest
            .iter()
            .fold(first, |best, &value| if value < best { value } else { best });
        Some(lowest)
    }

    /// Returns the largest element, or `None` when the tensor holds no data.
    ///
    /// With the `simd` feature enabled, `f64` and `f32` tensors are routed to
    /// `simd::f64_ops::max_f64` / `simd::f32_ops::max_f32`. The generic path
    /// keeps the earliest element unless a later one compares strictly greater.
    pub fn max_element(&self) -> Option<T> {
        let (&first, rest) = self.data.split_first()?;

        #[cfg(feature = "simd")]
        {
            use crate::simd;
            use std::any::TypeId;

            let elem_ty = TypeId::of::<T>();
            if elem_ty == TypeId::of::<f64>() {
                let highest =
                    unsafe { simd::f64_ops::max_f64(simd::slice_as_f64(self.data.as_slice())) };
                return Some(unsafe { simd::f64_to_t(highest) });
            }
            if elem_ty == TypeId::of::<f32>() {
                let highest =
                    unsafe { simd::f32_ops::max_f32(simd::slice_as_f32(self.data.as_slice())) };
                return Some(unsafe { simd::f32_to_t(highest) });
            }
        }

        let highest = rest
            .iter()
            .fold(first, |best, &value| if value > best { value } else { best });
        Some(highest)
    }

    /// Sums along `axis`, producing a tensor whose shape is `self`'s shape with
    /// that axis removed.
    ///
    /// Reducing the only axis of a tensor yields `Tensor::scalar(self.sum())`.
    /// The indexing below treats `data` as laid out with the last axis varying
    /// fastest (offset `(o * axis_len + k) * inner + i`).
    ///
    /// # Errors
    ///
    /// Returns `CoreError::AxisOutOfBounds` when `axis >= self.ndim()`;
    /// otherwise returns whatever `Tensor::from_vec` returns for the result.
    pub fn sum_axis(&self, axis: usize) -> crate::Result<Tensor<T>> {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(CoreError::AxisOutOfBounds { axis, ndim });
        }

        // Split the shape into the axes before `axis`, the reduced axis itself,
        // and the axes after it.
        let (leading, from_axis) = self.shape.split_at(axis);
        let axis_len = from_axis[0];
        let trailing = &from_axis[1..];
        let reduced_shape: Vec<usize> = leading.iter().chain(trailing).copied().collect();

        if reduced_shape.is_empty() {
            return Ok(Tensor::scalar(self.sum()));
        }

        let outer: usize = leading.iter().product();
        let inner: usize = trailing.iter().product();
        let out_len: usize = reduced_shape.iter().product();
        let mut sums = vec![T::zero(); out_len];

        for o in 0..outer {
            // One output row of `inner` elements per combination of leading indices.
            let dst = &mut sums[o * inner..(o + 1) * inner];
            for k in 0..axis_len {
                let start = (o * axis_len + k) * inner;
                let src = &self.data[start..start + inner];
                for (acc, &value) in dst.iter_mut().zip(src) {
                    *acc += value;
                }
            }
        }

        Tensor::from_vec(sums, reduced_shape)
    }
}
/// Operations that need a floating-point element type.
impl<T: Float> Tensor<T> {
    /// Returns `self.sum()` divided by the element count.
    ///
    /// There is no guard for an empty tensor: the divisor is then
    /// `T::from_usize(0)`.
    pub fn mean(&self) -> T {
        self.sum() / T::from_usize(self.numel())
    }

    /// Applies the rectifier element-wise and returns a new tensor with the
    /// same shape and strides.
    ///
    /// On the generic path an element is kept when it compares greater than
    /// `T::zero()` and replaced by `T::zero()` otherwise. With the `simd`
    /// feature enabled, `f64` and `f32` tensors are routed to
    /// `simd::f64_ops::relu_f64` / `simd::f32_ops::relu_f32`.
    ///
    /// The SIMD output buffer is zero-initialised before the kernel runs, and
    /// the finished buffer is re-typed with `Vec::from_raw_parts` rather than
    /// `transmute`, which would depend on the unspecified layout of `Vec`.
    pub fn relu(&self) -> Tensor<T> {
        #[cfg(feature = "simd")]
        {
            use crate::simd;
            use std::any::TypeId;

            if TypeId::of::<T>() == TypeId::of::<f64>() {
                // SAFETY: the `TypeId` comparison proves `T` is exactly `f64`.
                let a = unsafe { simd::slice_as_f64(self.data.as_slice()) };
                let mut out = vec![0.0_f64; a.len()];
                simd::f64_ops::relu_f64(a, &mut out);

                // `ManuallyDrop` keeps the `Vec<f64>` from freeing the buffer
                // that is re-owned as `Vec<T>` just below.
                let mut out = core::mem::ManuallyDrop::new(out);
                // SAFETY: pointer, length and capacity come from a live
                // `Vec<f64>` that will not be dropped, and `T` is `f64`.
                let data = unsafe {
                    Vec::from_raw_parts(out.as_mut_ptr().cast::<T>(), out.len(), out.capacity())
                };
                return Tensor {
                    data,
                    shape: self.shape.clone(),
                    strides: self.strides.clone(),
                };
            }
            if TypeId::of::<T>() == TypeId::of::<f32>() {
                // SAFETY: the `TypeId` comparison proves `T` is exactly `f32`.
                let a = unsafe { simd::slice_as_f32(self.data.as_slice()) };
                let mut out = vec![0.0_f32; a.len()];
                simd::f32_ops::relu_f32(a, &mut out);

                let mut out = core::mem::ManuallyDrop::new(out);
                // SAFETY: same reasoning as the `f64` branch, with `T` being `f32`.
                let data = unsafe {
                    Vec::from_raw_parts(out.as_mut_ptr().cast::<T>(), out.len(), out.capacity())
                };
                return Tensor {
                    data,
                    shape: self.shape.clone(),
                    strides: self.strides.clone(),
                };
            }
        }

        let zero = T::zero();
        let data = self.data.iter().map(|&v| if v > zero { v } else { zero }).collect();
        Tensor {
            data,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
        }
    }
}
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    // --- tensor (op) tensor -------------------------------------------------

    #[test]
    fn test_add_tensors() {
        let lhs = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let rhs = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        let total = lhs + rhs;
        assert_eq!(total.as_slice(), &[11.0, 22.0, 33.0]);
    }

    #[test]
    fn test_sub_tensors() {
        let lhs = Tensor::from_vec(vec![10.0, 20.0], vec![2]).unwrap();
        let rhs = Tensor::from_vec(vec![1.0, 2.0], vec![2]).unwrap();
        // Borrowed operands leave both inputs usable afterwards.
        let diff = &lhs - &rhs;
        assert_eq!(diff.as_slice(), &[9.0, 18.0]);
    }

    // --- tensor (op) scalar -------------------------------------------------

    #[test]
    fn test_mul_scalar() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let scaled = input * 10.0;
        assert_eq!(scaled.as_slice(), &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_div_scalar() {
        let input = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        let scaled = &input / 10.0;
        assert_eq!(scaled.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_neg() {
        let input = Tensor::from_vec(vec![1.0_f64, -2.0, 3.0], vec![3]).unwrap();
        let negated = -input;
        assert_eq!(negated.as_slice(), &[-1.0, 2.0, -3.0]);
    }

    // --- checked arithmetic -------------------------------------------------

    #[test]
    fn test_checked_add_mismatch() {
        let short = Tensor::from_vec(vec![1.0, 2.0], vec![2]).unwrap();
        let long = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        assert!(short.add_checked(&long).is_err());
    }

    // --- reductions ---------------------------------------------------------

    #[test]
    fn test_sum() {
        let ints = Tensor::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
        assert_eq!(ints.sum(), 10);
    }

    #[test]
    fn test_product() {
        let ints = Tensor::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
        assert_eq!(ints.product(), 24);
    }

    #[test]
    fn test_min_max() {
        let ints = Tensor::from_vec(vec![3, 1, 4, 1, 5, 9], vec![6]).unwrap();
        assert_eq!(ints.min_element(), Some(1));
        assert_eq!(ints.max_element(), Some(9));
    }

    #[test]
    fn test_mean() {
        let floats = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        assert_eq!(floats.mean(), 2.5);
    }

    #[test]
    fn test_sum_axis() {
        // 2x3 matrix: rows [1, 2, 3] and [4, 5, 6].
        let matrix = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();

        let column_sums = matrix.sum_axis(0).unwrap();
        assert_eq!(column_sums.shape(), &[3]);
        assert_eq!(column_sums.as_slice(), &[5, 7, 9]);

        let row_sums = matrix.sum_axis(1).unwrap();
        assert_eq!(row_sums.shape(), &[2]);
        assert_eq!(row_sums.as_slice(), &[6, 15]);
    }

    #[test]
    fn test_sum_axis_out_of_bounds() {
        let vector = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        assert!(vector.sum_axis(1).is_err());
    }

    // --- compound assignment ------------------------------------------------

    #[test]
    fn test_add_assign() {
        let mut acc = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let delta = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        acc += &delta;
        assert_eq!(acc.as_slice(), &[11.0, 22.0, 33.0]);
    }

    #[test]
    fn test_sub_assign() {
        let mut acc = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        let delta = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        acc -= &delta;
        assert_eq!(acc.as_slice(), &[9.0, 18.0, 27.0]);
    }

    #[test]
    fn test_mul_assign() {
        let mut acc = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let factor = Tensor::from_vec(vec![10.0, 10.0, 10.0], vec![3]).unwrap();
        acc *= &factor;
        assert_eq!(acc.as_slice(), &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_div_assign() {
        let mut acc = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        let divisor = Tensor::from_vec(vec![10.0, 10.0, 10.0], vec![3]).unwrap();
        acc /= &divisor;
        assert_eq!(acc.as_slice(), &[1.0, 2.0, 3.0]);
    }

    // --- panicking operators ------------------------------------------------

    #[test]
    #[should_panic(expected = "shape mismatch")]
    fn test_add_panics_on_mismatch() {
        let short = Tensor::from_vec(vec![1.0, 2.0], vec![2]).unwrap();
        let long = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let _ = short + long;
    }
}