redstone_ml/tensor/autograd.rs
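//! Autograd support for `Tensor`: requires-grad bookkeeping, gradient
//! access, and the `backward` entry points.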
use crate::accumulate_grad::AccumulateGrad;
use crate::gradient_function::GradientFunction;
use crate::ndarray::flags::NdArrayFlags;
use crate::none_backwards::NoneBackwards;
use crate::{Constructors, NdArray, StridedMemory, Tensor, TensorDataType};

impl<'a, T: TensorDataType> Tensor<'a, T> {
    /// Checks if the tensor is a leaf.
    ///
    /// Every tensor with `requires_grad = false` is a leaf. A tensor with
    /// `requires_grad = true` is a leaf only if it was created explicitly by
    /// the user, rather than as the result of an operation on other tensors.
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut tensor = Tensor::new([1.0, 2.0, 3.0]);
    /// tensor.set_requires_grad(true);
    /// assert!(tensor.is_leaf());
    ///
    /// let tensor2 = -tensor;
    /// assert!(!tensor2.is_leaf());
    /// ```
    #[inline]
    pub fn is_leaf(&self) -> bool {
        if self.requires_grad() {
            self.flags.contains(NdArrayFlags::UserCreated)
        } else {
            true
        }
    }

    /// Returns whether gradients must be computed for this tensor.
    ///
    /// A tensor carries the `requires_grad` flag if the user enabled it
    /// explicitly through the `set_requires_grad()` method, or if the tensor
    /// was created by operations on other tensors that were themselves marked
    /// `requires_grad`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut tensor = Tensor::new([1.0, 2.0, 3.0]);
    /// tensor.set_requires_grad(true);
    ///
    /// let tensor2 = -tensor;
    /// assert!(tensor2.requires_grad());
    /// ```
    #[inline]
    pub fn requires_grad(&self) -> bool {
        self.flags.contains(NdArrayFlags::RequiresGrad)
    }

    /// Sets whether gradients must be computed for this tensor.
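    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut tensor = Tensor::new([1.0, 2.0, 3.0]);
    /// tensor.set_requires_grad(true);
    /// assert!(tensor.requires_grad());
    ///
    /// tensor.set_requires_grad(false);
    /// assert!(!tensor.requires_grad());
    /// ```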
    pub fn set_requires_grad(&mut self, requires_grad: bool) -> &mut Self {
        let was_requires_grad = self.requires_grad();

        if requires_grad {
            self.flags |= NdArrayFlags::RequiresGrad;
        } else {
            self.flags -= NdArrayFlags::RequiresGrad;
        }

        // Turned on: attach an accumulator so backward passes have
        // somewhere to deposit this leaf's gradient.
        if !was_requires_grad && requires_grad {
            self.grad_fn = AccumulateGrad::new(self.shape().to_vec());
        }
        // Turned off: detach from the graph.
        if was_requires_grad && !requires_grad {
            self.grad_fn = NoneBackwards::new();
        }

        self
    }

    /// Retrieves the gradient function associated with this tensor.
    ///
    /// This is `NoneBackwards` if the tensor has `requires_grad = false`,
    /// or `AccumulateGrad` if the tensor is a leaf with `requires_grad = true`.
    pub(crate) fn grad_fn(&self) -> GradientFunction<T> {
        self.grad_fn.clone()
    }

    /// Returns the gradient of the differentiated tensor with respect to `self`.
    ///
    /// This method returns a view into the gradient.
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut a = Tensor::scalar(2.0f32);
    /// let b = Tensor::scalar(3.0);
    ///
    /// a.set_requires_grad(true);
    ///
    /// let c = &a * &b;
    /// c.backward();
    ///
    /// // dc/da = b
    /// assert_eq!(a.gradient().unwrap(), b);
    /// ```
    pub fn gradient(&'a self) -> Option<NdArray<'a, T>> {
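        // SAFETY (assumed invariant): reading through the raw pointer rather
        // than holding a `Ref` guard lets the returned view borrow from the
        // grad_fn for `'a`; this relies on no mutable borrow of the grad_fn
        // being live across this call.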
        unsafe { (*self.grad_fn.as_ptr()).gradient() }
    }

    /// Sets the gradient of this tensor to zero.
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut a = Tensor::scalar(2.0f32);
    /// let b = Tensor::scalar(3.0);
    ///
    /// a.set_requires_grad(true);
    ///
    /// let c = &a * &b;
    /// c.backward();
    ///
    /// a.zero_gradient();
    /// assert_eq!(a.gradient().unwrap(), Tensor::scalar(0.0));
    /// ```
    pub fn zero_gradient(&self) {
        self.grad_fn.borrow_mut().zero_gradient();
    }

    /// Computes the gradient of `self` with respect to its leaf tensors.
    ///
    /// # Parameters
    ///
    /// - `gradient`: the gradient of the tensor being differentiated with respect to `self`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut a = Tensor::full(2.0, [3]); // [2, 2, 2]
    /// let b = Tensor::new([3.0, 1.0, -1.0]);
    ///
    /// a.set_requires_grad(true);
    ///
    /// let c = &a * &b;
    /// c.backward_with(NdArray::new([2.0, 1.0, 1.0]));
    ///
    /// // dc/da = b, scaled elementwise by the seed gradient [2, 1, 1]
    /// assert_eq!(a.gradient().unwrap(), Tensor::new([6.0, 1.0, -1.0]));
    /// ```
    pub fn backward_with(&self, gradient: impl AsRef<NdArray<'a, T>>) {
        let gradient = gradient.as_ref();
        // The seed gradient must have the same shape as `self`.
        assert_eq!(gradient.shape(), self.shape());

        self.grad_fn.borrow_mut().backward(gradient);
    }

    /// Computes the gradient of `self` with respect to its leaf tensors.
    ///
    /// Equivalent to calling `backward_with` with a gradient of all ones.
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut a = Tensor::full(2.0, [3]); // [2, 2, 2]
    /// let b = Tensor::new([3.0, 1.0, -1.0]);
    ///
    /// a.set_requires_grad(true);
    ///
    /// let c = &a * &b;
    /// c.backward();
    ///
    /// // dc/da = b
    /// assert_eq!(a.gradient().unwrap(), Tensor::new([3.0, 1.0, -1.0]));
    /// ```
    pub fn backward(&self) {
        self.backward_with(NdArray::ones(self.shape()))
    }

    /// Detaches the tensor from the computation graph and returns an `NdArray`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut a = Tensor::full(2.0, [3]);
    /// a.set_requires_grad(true);
    ///
    /// let c = &a * 5.0;
    ///
    /// let d = c.detach();
    /// assert_eq!(d, NdArray::new([10.0, 10.0, 10.0]));
    /// ```
    pub fn detach(&self) -> NdArray<'static, T> {
        self.array.as_ref().clone()
    }
}