use super::super::core::Tensor;
use crate::error::{RusTorchError, RusTorchResult};
use ndarray::ArrayD;
use num_traits::Float;
use std::ops::{Add, Div, Mul, Neg, Sub};
impl<T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive> Tensor<T> {
pub fn add(&self, other: &Tensor<T>) -> RusTorchResult<Self>
where
T: ndarray::ScalarOperand + Copy,
{
if self.shape() == other.shape() {
let result_data: Vec<T> = self
.data
.iter()
.zip(other.data.iter())
.map(|(&a, &b)| a + b)
.collect();
return Ok(Tensor::from_vec(result_data, self.shape().to_vec()));
}
if !self.can_broadcast_with(other) {
return Err(RusTorchError::shape_mismatch(self.shape(), other.shape()));
}
let (broadcasted_self, broadcasted_other) = self.broadcast_with(other)?;
let result_data: Vec<T> = broadcasted_self
.data
.iter()
.zip(broadcasted_other.data.iter())
.map(|(&a, &b)| a + b)
.collect();
Ok(Tensor::from_vec(
result_data,
broadcasted_self.shape().to_vec(),
))
}
pub fn sub(&self, other: &Tensor<T>) -> RusTorchResult<Self>
where
T: ndarray::ScalarOperand + Copy,
{
if self.shape() == other.shape() {
let result_data: Vec<T> = self
.data
.iter()
.zip(other.data.iter())
.map(|(&a, &b)| a - b)
.collect();
return Ok(Tensor::from_vec(result_data, self.shape().to_vec()));
}
if !self.can_broadcast_with(other) {
return Err(RusTorchError::shape_mismatch(self.shape(), other.shape()));
}
let (broadcasted_self, broadcasted_other) = self.broadcast_with(other)?;
let result_data: Vec<T> = broadcasted_self
.data
.iter()
.zip(broadcasted_other.data.iter())
.map(|(&a, &b)| a - b)
.collect();
Ok(Tensor::from_vec(
result_data,
broadcasted_self.shape().to_vec(),
))
}
pub fn mul(&self, other: &Tensor<T>) -> RusTorchResult<Self>
where
T: ndarray::ScalarOperand + Copy,
{
if self.shape() == other.shape() {
let result_data: Vec<T> = self
.data
.iter()
.zip(other.data.iter())
.map(|(&a, &b)| a * b)
.collect();
return Ok(Tensor::from_vec(result_data, self.shape().to_vec()));
}
if !self.can_broadcast_with(other) {
return Err(RusTorchError::shape_mismatch(self.shape(), other.shape()));
}
let (broadcasted_self, broadcasted_other) = self.broadcast_with(other)?;
let result_data: Vec<T> = broadcasted_self
.data
.iter()
.zip(broadcasted_other.data.iter())
.map(|(&a, &b)| a * b)
.collect();
Ok(Tensor::from_vec(
result_data,
broadcasted_self.shape().to_vec(),
))
}
pub fn div(&self, other: &Tensor<T>) -> RusTorchResult<Self>
where
T: ndarray::ScalarOperand + Copy,
{
if self.shape() == other.shape() {
let result_data: Vec<T> = self
.data
.iter()
.zip(other.data.iter())
.map(|(&a, &b)| {
if b == T::zero() {
T::infinity() } else {
a / b
}
})
.collect();
return Ok(Tensor::from_vec(result_data, self.shape().to_vec()));
}
if !self.can_broadcast_with(other) {
return Err(RusTorchError::shape_mismatch(self.shape(), other.shape()));
}
let (broadcasted_self, broadcasted_other) = self.broadcast_with(other)?;
let result_data: Vec<T> = broadcasted_self
.data
.iter()
.zip(broadcasted_other.data.iter())
.map(
|(&a, &b)| {
if b == T::zero() {
T::infinity()
} else {
a / b
}
},
)
.collect();
Ok(Tensor::from_vec(
result_data,
broadcasted_self.shape().to_vec(),
))
}
pub fn add_scalar(&self, scalar: T) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x + scalar).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn sub_scalar(&self, scalar: T) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x - scalar).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn mul_scalar(&self, scalar: T) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x * scalar).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn div_scalar(&self, scalar: T) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| x / scalar).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn neg(&self) -> Self {
let result_data: Vec<T> = self.data.iter().map(|&x| -x).collect();
Tensor::from_vec(result_data, self.shape().to_vec())
}
pub fn maximum(&self, other: &Tensor<T>) -> RusTorchResult<Self>
where
T: ndarray::ScalarOperand + Copy,
{
if self.shape() == other.shape() {
let result_data: Vec<T> = self
.data
.iter()
.zip(other.data.iter())
.map(|(&a, &b)| if a > b { a } else { b })
.collect();
return Ok(Tensor::from_vec(result_data, self.shape().to_vec()));
}
if !self.can_broadcast_with(other) {
return Err(RusTorchError::shape_mismatch(self.shape(), other.shape()));
}
let (broadcasted_self, broadcasted_other) = self.broadcast_with(other)?;
let result_data: Vec<T> = broadcasted_self
.data
.iter()
.zip(broadcasted_other.data.iter())
.map(|(&a, &b)| if a > b { a } else { b })
.collect();
Ok(Tensor::from_vec(
result_data,
broadcasted_self.shape().to_vec(),
))
}
pub fn minimum(&self, other: &Tensor<T>) -> RusTorchResult<Self>
where
T: ndarray::ScalarOperand + Copy,
{
if self.shape() == other.shape() {
let result_data: Vec<T> = self
.data
.iter()
.zip(other.data.iter())
.map(|(&a, &b)| if a < b { a } else { b })
.collect();
return Ok(Tensor::from_vec(result_data, self.shape().to_vec()));
}
if !self.can_broadcast_with(other) {
return Err(RusTorchError::shape_mismatch(self.shape(), other.shape()));
}
let (broadcasted_self, broadcasted_other) = self.broadcast_with(other)?;
let result_data: Vec<T> = broadcasted_self
.data
.iter()
.zip(broadcasted_other.data.iter())
.map(|(&a, &b)| if a < b { a } else { b })
.collect();
Ok(Tensor::from_vec(
result_data,
broadcasted_self.shape().to_vec(),
))
}
pub fn sum(&self) -> T {
self.data.iter().fold(T::zero(), |acc, &x| acc + x)
}
pub fn mean(&self) -> T {
let sum = self.sum();
let count = T::from(self.data.len()).unwrap_or(T::one());
sum / count
}
pub fn sum_axis(&self, axis: usize) -> RusTorchResult<Tensor<T>> {
use ndarray::Axis;
let result_array = self.data.sum_axis(Axis(axis));
Ok(Tensor::new(result_array))
}
}
impl<T> Add for &Tensor<T>
where
    T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    type Output = Tensor<T>;

    /// Adds two borrowed tensors element-wise via the inherent `Tensor::add`.
    ///
    /// # Panics
    ///
    /// Panics with "Addition failed" if `Tensor::add` returns an error, for example when
    /// the shapes cannot be broadcast together.
    fn add(self, rhs: Self) -> Tensor<T> {
        let outcome = Tensor::add(self, rhs);
        outcome.expect("Addition failed")
    }
}
impl<T> Sub for &Tensor<T>
where
    T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    type Output = Tensor<T>;

    /// Subtracts two borrowed tensors element-wise via the inherent `Tensor::sub`.
    ///
    /// # Panics
    ///
    /// Panics with "Subtraction failed" if `Tensor::sub` returns an error, for example when
    /// the shapes cannot be broadcast together.
    fn sub(self, rhs: Self) -> Tensor<T> {
        let outcome = Tensor::sub(self, rhs);
        outcome.expect("Subtraction failed")
    }
}
impl<T> Mul for &Tensor<T>
where
    T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    type Output = Tensor<T>;

    /// Multiplies two borrowed tensors element-wise via the inherent `Tensor::mul`.
    ///
    /// # Panics
    ///
    /// Panics with "Multiplication failed" if `Tensor::mul` returns an error, for example
    /// when the shapes cannot be broadcast together.
    fn mul(self, rhs: Self) -> Tensor<T> {
        let outcome = Tensor::mul(self, rhs);
        outcome.expect("Multiplication failed")
    }
}
impl<T> Div for &Tensor<T>
where
    T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    type Output = Tensor<T>;

    /// Divides two borrowed tensors element-wise via the inherent `Tensor::div`.
    ///
    /// # Panics
    ///
    /// Panics with "Division failed" if `Tensor::div` returns an error, for example when
    /// the shapes cannot be broadcast together.
    fn div(self, rhs: Self) -> Tensor<T> {
        let outcome = Tensor::div(self, rhs);
        outcome.expect("Division failed")
    }
}
impl<T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive> Mul<T>
for &Tensor<T>
{
type Output = Tensor<T>;
fn mul(self, scalar: T) -> Self::Output {
self.mul_scalar(scalar)
}
}
impl<T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive> Div<T>
for &Tensor<T>
{
type Output = Tensor<T>;
fn div(self, scalar: T) -> Self::Output {
self.div_scalar(scalar)
}
}
impl<T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive> Add<T>
for &Tensor<T>
{
type Output = Tensor<T>;
fn add(self, scalar: T) -> Self::Output {
self.add_scalar(scalar)
}
}
impl<T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive> Sub<T>
for &Tensor<T>
{
type Output = Tensor<T>;
fn sub(self, scalar: T) -> Self::Output {
self.sub_scalar(scalar)
}
}
impl<T> Neg for &Tensor<T>
where
    T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    type Output = Tensor<T>;

    /// Negates every element of a borrowed tensor via the inherent `Tensor::neg`.
    fn neg(self) -> Tensor<T> {
        Tensor::<T>::neg(self)
    }
}
impl<T> Add for Tensor<T>
where
    T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    type Output = Tensor<T>;

    /// Adds two owned tensors by delegating to the `&Tensor + &Tensor` implementation.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as the borrowed implementation.
    fn add(self, rhs: Self) -> Tensor<T> {
        <&Tensor<T> as Add>::add(&self, &rhs)
    }
}
impl<T> Sub for Tensor<T>
where
    T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    type Output = Tensor<T>;

    /// Subtracts two owned tensors by delegating to the `&Tensor - &Tensor` implementation.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as the borrowed implementation.
    fn sub(self, rhs: Self) -> Tensor<T> {
        <&Tensor<T> as Sub>::sub(&self, &rhs)
    }
}
impl<T> Mul for Tensor<T>
where
    T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    type Output = Tensor<T>;

    /// Multiplies two owned tensors by delegating to the `&Tensor * &Tensor` implementation.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as the borrowed implementation.
    fn mul(self, rhs: Self) -> Tensor<T> {
        <&Tensor<T> as Mul>::mul(&self, &rhs)
    }
}
impl<T> Div for Tensor<T>
where
    T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    type Output = Tensor<T>;

    /// Divides two owned tensors by delegating to the `&Tensor / &Tensor` implementation.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as the borrowed implementation.
    fn div(self, rhs: Self) -> Tensor<T> {
        <&Tensor<T> as Div>::div(&self, &rhs)
    }
}
impl<T> Mul<T> for Tensor<T>
where
    T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    type Output = Tensor<T>;

    /// Scales an owned tensor by `factor` via the `&Tensor * T` implementation.
    fn mul(self, factor: T) -> Tensor<T> {
        <&Tensor<T> as Mul<T>>::mul(&self, factor)
    }
}
impl<T> Div<T> for Tensor<T>
where
    T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    type Output = Tensor<T>;

    /// Divides an owned tensor by `divisor` via the `&Tensor / T` implementation.
    fn div(self, divisor: T) -> Tensor<T> {
        <&Tensor<T> as Div<T>>::div(&self, divisor)
    }
}
impl<T> Add<T> for Tensor<T>
where
    T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    type Output = Tensor<T>;

    /// Shifts an owned tensor by `offset` via the `&Tensor + T` implementation.
    fn add(self, offset: T) -> Tensor<T> {
        <&Tensor<T> as Add<T>>::add(&self, offset)
    }
}
impl<T> Sub<T> for Tensor<T>
where
    T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    type Output = Tensor<T>;

    /// Shifts an owned tensor down by `offset` via the `&Tensor - T` implementation.
    fn sub(self, offset: T) -> Tensor<T> {
        <&Tensor<T> as Sub<T>>::sub(&self, offset)
    }
}
impl<T> Neg for Tensor<T>
where
    T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    type Output = Tensor<T>;

    /// Negates an owned tensor via the `-&Tensor` implementation.
    fn neg(self) -> Tensor<T> {
        <&Tensor<T> as Neg>::neg(&self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0], vec![3]);
        let result = &a + &b;
        assert_eq!(result.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_scalar() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        let result = a.add_scalar(10.0);
        assert_eq!(result.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_shape_mismatch() {
        let a = Tensor::from_vec(vec![1.0, 2.0], vec![2]);
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        // `(&a).add` resolves to the fallible inherent method, not the panicking `Add` operator.
        let result = (&a).add(&b);
        assert!(result.is_err());
    }

    #[test]
    fn test_sub_mul_div_same_shape() {
        let a = Tensor::from_vec(vec![6.0, 8.0, 9.0], vec![3]);
        let b = Tensor::from_vec(vec![2.0, 4.0, 3.0], vec![3]);
        let diff = &a - &b;
        assert_eq!(diff.as_slice().unwrap(), &[4.0, 4.0, 6.0]);
        let prod = &a * &b;
        assert_eq!(prod.as_slice().unwrap(), &[12.0, 32.0, 27.0]);
        let quot = &a / &b;
        assert_eq!(quot.as_slice().unwrap(), &[3.0, 2.0, 3.0]);
    }

    #[test]
    fn test_scalar_ops_and_neg() {
        let a = Tensor::from_vec(vec![2.0, -4.0], vec![2]);
        let shifted = a.sub_scalar(1.0);
        assert_eq!(shifted.as_slice().unwrap(), &[1.0, -5.0]);
        let scaled = a.mul_scalar(3.0);
        assert_eq!(scaled.as_slice().unwrap(), &[6.0, -12.0]);
        let halved = a.div_scalar(2.0);
        assert_eq!(halved.as_slice().unwrap(), &[1.0, -2.0]);
        let negated = -&a;
        assert_eq!(negated.as_slice().unwrap(), &[-2.0, 4.0]);
    }

    #[test]
    fn test_maximum_minimum() {
        let a = Tensor::from_vec(vec![1.0, 5.0, 3.0], vec![3]);
        let b = Tensor::from_vec(vec![4.0, 2.0, 3.0], vec![3]);
        let hi = a.maximum(&b).unwrap();
        assert_eq!(hi.as_slice().unwrap(), &[4.0, 5.0, 3.0]);
        let lo = a.minimum(&b).unwrap();
        assert_eq!(lo.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_sum_and_mean() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 6.0], vec![4]);
        assert_eq!(a.sum(), 12.0);
        assert_eq!(a.mean(), 3.0);
    }
}