redstone_ml/tensor/
ops.rs1use crate::{Tensor, TensorDataType};
2use std::ops::{Add, Div, Mul, Neg, Sub};
3
4use crate::add_backwards::*;
5use crate::div_backwards::*;
6use crate::mul_backwards::*;
7use crate::neg_backwards::*;
8use crate::none_backwards::*;
9use crate::sub_backwards::*;
10use paste::paste;
11
12impl<T: TensorDataType> Neg for Tensor<'_, T> {
13 type Output = Tensor<'static, T>;
14
15 fn neg(self) -> Self::Output { -&self }
16}
17
18impl<T: TensorDataType> Neg for &Tensor<'_, T> {
19 type Output = Tensor<'static, T>;
20
21 fn neg(self) -> Self::Output {
22 let requires_grad = self.requires_grad();
23 let grad_fn = if requires_grad { NegBackwards::new(self) } else { NoneBackwards::new() };
24
25 unsafe { Tensor::from_raw_parts(-self.array.as_ref(), requires_grad, grad_fn) }
26 }
27}
28
/// Generates the binary-operator trait impls for `Tensor`.
///
/// Each `;`-terminated entry supplies, in order: the operator trait, the operator token,
/// the trait's method name, the backwards type used for tensor-tensor operands, and the
/// backwards type used for tensor-scalar operands.
///
/// Six impls are emitted per entry: the four owned/borrowed tensor-tensor combinations and
/// the owned/borrowed tensor-scalar pair. Every impl that takes an owned tensor only
/// forwards to the all-borrowed impl, which is where the result is actually built.
macro_rules! implement_binary_ops {
    ($($op_trait:ident, $op:tt, $func:ident, $grad:ident, $scalar_grad:ident;)*) => {
        $(
            // owned (op) owned: borrow both sides and forward.
            impl<T: TensorDataType> $op_trait<Tensor<'_, T>> for Tensor<'_, T> {
                type Output = Tensor<'static, T>;

                fn $func(self, rhs: Tensor<'_, T>) -> Self::Output {
                    $op_trait::$func(&self, &rhs)
                }
            }

            // owned (op) borrowed: borrow the left side and forward.
            impl<T: TensorDataType> $op_trait<&Tensor<'_, T>> for Tensor<'_, T> {
                type Output = Tensor<'static, T>;

                fn $func(self, rhs: &Tensor<'_, T>) -> Self::Output {
                    $op_trait::$func(&self, rhs)
                }
            }

            // borrowed (op) owned: borrow the right side and forward.
            impl<T: TensorDataType> $op_trait<Tensor<'_, T>> for &Tensor<'_, T> {
                type Output = Tensor<'static, T>;

                fn $func(self, rhs: Tensor<'_, T>) -> Self::Output {
                    $op_trait::$func(self, &rhs)
                }
            }

            // borrowed (op) borrowed: the tensor-tensor implementation proper.
            impl<T: TensorDataType> $op_trait<&Tensor<'_, T>> for &Tensor<'_, T> {
                type Output = Tensor<'static, T>;

                fn $func(self, rhs: &Tensor<'_, T>) -> Self::Output {
                    // The result tracks gradients if either operand does.
                    let requires_grad = self.requires_grad() || rhs.requires_grad();
                    let grad_fn = if requires_grad {
                        $grad::new(self, rhs)
                    } else {
                        NoneBackwards::new()
                    };

                    // SAFETY: NOTE(review): the contracts of `array.as_ref()` and
                    // `Tensor::from_raw_parts` are defined outside this file; presumably
                    // `grad_fn` must be consistent with `requires_grad` -- confirm.
                    unsafe {
                        let data = self.array.as_ref() $op rhs.array.as_ref();
                        Tensor::from_raw_parts(data, requires_grad, grad_fn)
                    }
                }
            }

            // owned (op) scalar: borrow the tensor and forward.
            impl<T: TensorDataType> $op_trait<T> for Tensor<'_, T> {
                type Output = Tensor<'static, T>;

                fn $func(self, rhs: T) -> Self::Output {
                    // NOTE(review): no `[< ... >]` paste segments appear here, so `paste!` looks
                    // like a plain passthrough; kept so the file-level import stays in use.
                    paste! { $op_trait::$func(&self, rhs) }
                }
            }

            // borrowed (op) scalar: the tensor-scalar implementation proper.
            impl<T: TensorDataType> $op_trait<T> for &Tensor<'_, T> {
                type Output = Tensor<'static, T>;

                fn $func(self, rhs: T) -> Self::Output {
                    // Only the tensor operand can carry a gradient requirement.
                    let requires_grad = self.requires_grad();
                    let grad_fn = if requires_grad {
                        $scalar_grad::new(self, rhs)
                    } else {
                        NoneBackwards::new()
                    };

                    // SAFETY: NOTE(review): same externally-defined contracts as the
                    // tensor-tensor impl above in this macro -- confirm.
                    unsafe {
                        let data = self.array.as_ref() $op rhs;
                        Tensor::from_raw_parts(data, requires_grad, grad_fn)
                    }
                }
            }
        )*
    };
}
78
79implement_binary_ops!(
80 Add, +, add, AddBackwards, AddScalarBackwards;
81 Sub, -, sub, SubBackwards, AddScalarBackwards;
82 Mul, *, mul, MulBackwards, MulScalarBackwards;
83 Div, /, div, DivBackwards, DivScalarBackwards;
84);