//! redstone_ml/tensor/ops.rs — operator overloads (`Neg`, `Add`, `Sub`, `Mul`, `Div`) for `Tensor`.

1use crate::{Tensor, TensorDataType};
2use std::ops::{Add, Div, Mul, Neg, Sub};
3
4use crate::add_backwards::*;
5use crate::div_backwards::*;
6use crate::mul_backwards::*;
7use crate::neg_backwards::*;
8use crate::none_backwards::*;
9use crate::sub_backwards::*;
10use paste::paste;
11
12impl<T: TensorDataType> Neg for Tensor<'_, T> {
13    type Output = Tensor<'static, T>;
14
15    fn neg(self) -> Self::Output { -&self }
16}
17
18impl<T: TensorDataType> Neg for &Tensor<'_, T> {
19    type Output = Tensor<'static, T>;
20
21    fn neg(self) -> Self::Output {
22        let requires_grad = self.requires_grad();
23        let grad_fn = if requires_grad { NegBackwards::new(self) } else { NoneBackwards::new() };
24
25        unsafe { Tensor::from_raw_parts(-self.array.as_ref(), requires_grad, grad_fn) }
26    }
27}
28
/// Generates the binary-operator impls for `Tensor`.
///
/// Each `trait, operator, method, tensor-backwards, scalar-backwards;` row
/// expands to six impls:
///   * owned⊕owned, owned⊕&ref, &ref⊕owned — thin forwarders that borrow and
///     delegate to the &ref⊕&ref impl;
///   * &ref⊕&ref — the tensor⊕tensor implementation, which records
///     `$backwards` when either operand requires gradients;
///   * owned⊕scalar and &ref⊕scalar — the tensor⊕scalar implementations,
///     which record `$backwards_scalar` when the tensor requires gradients.
/// Every impl returns an owned `Tensor<'static, T>`.
macro_rules! implement_binary_ops {
    ($($trait_: ident, $operator:tt, $method: ident, $backwards:ident, $backwards_scalar:ident;)* ) => { $(
        // owned ⊕ owned: borrow both and delegate to &ref ⊕ &ref.
        impl<T: TensorDataType> $trait_<Tensor<'_, T>> for Tensor<'_, T> {
            type Output = Tensor<'static, T>;

            fn $method(self, rhs: Tensor<T>) -> Self::Output { &self $operator &rhs }
        }

        // owned ⊕ &ref: borrow the left-hand side and delegate.
        impl<T: TensorDataType> $trait_<&Tensor<'_, T>> for Tensor<'_, T> {
            type Output = Tensor<'static, T>;

            fn $method(self, rhs: &Tensor<T>) -> Self::Output { &self $operator rhs }
        }

        // &ref ⊕ owned: borrow the right-hand side and delegate.
        impl<T: TensorDataType> $trait_<Tensor<'_, T>> for &Tensor<'_, T> {
            type Output = Tensor<'static, T>;

            fn $method(self, rhs: Tensor<T>) -> Self::Output { self $operator &rhs }
        }

        // &ref ⊕ &ref: the real tensor ⊕ tensor implementation.
        impl<T: TensorDataType> $trait_<&Tensor<'_, T>> for &Tensor<'_, T> {
            type Output = Tensor<'static, T>;

            fn $method(self, rhs: &Tensor<T>) -> Self::Output {
                // Gradients flow if either operand participates in autograd.
                let requires_grad = self.requires_grad() || rhs.requires_grad();
                let grad_fn = if requires_grad { $backwards::new(self, rhs) } else { NoneBackwards::new() };

                // NOTE(review): `from_raw_parts`' safety contract is not
                // visible here; presumably the freshly computed array plus a
                // matching grad_fn satisfies it — confirm at its declaration.
                unsafe { Tensor::from_raw_parts(self.array.as_ref() $operator rhs.array.as_ref(), requires_grad, grad_fn) }
            }
        }

        // owned ⊕ scalar: borrow and delegate to &ref ⊕ scalar.
        impl<T: TensorDataType> $trait_<T> for Tensor<'_, T> {
            type Output = Tensor<'static, T>;

            // NOTE(review): `paste!` performs no identifier pasting in this
            // expression, so it appears to be a no-op wrapper — confirm it
            // (and the `paste` import) can be dropped.
            fn $method(self, rhs: T) -> Self::Output { paste! { &self $operator rhs } }
        }

        // &ref ⊕ scalar: the tensor ⊕ scalar implementation.
        impl<T: TensorDataType> $trait_<T> for &Tensor<'_, T> {
            type Output = Tensor<'static, T>;

            fn $method(self, rhs: T) -> Self::Output {
                let requires_grad = self.requires_grad();
                let grad_fn = if requires_grad { $backwards_scalar::new(self, rhs) } else { NoneBackwards::new() };

                // NOTE(review): same unverified `from_raw_parts` contract as
                // the tensor ⊕ tensor impl above.
                unsafe { Tensor::from_raw_parts(self.array.as_ref() $operator rhs, requires_grad, grad_fn) }
            }
        }
    )*};
}
78
// Instantiate +, -, *, / for every tensor/tensor and tensor/scalar pairing.
// NOTE(review): `Sub` reuses `AddScalarBackwards` for its scalar variant —
// presumably because d(x - c)/dx == d(x + c)/dx == 1, so the gradients
// coincide; confirm this is intentional and not a missing
// `SubScalarBackwards`.
implement_binary_ops!(
    Add, +, add, AddBackwards, AddScalarBackwards;
    Sub, -, sub, SubBackwards, AddScalarBackwards;
    Mul, *, mul, MulBackwards, MulScalarBackwards;
    Div, /, div, DivBackwards, DivScalarBackwards;
);