redstone_ml/tensor/autograd.rs

use crate::accumulate_grad::AccumulateGrad;
use crate::gradient_function::GradientFunction;
use crate::ndarray::flags::NdArrayFlags;
use crate::{Constructors, NdArray, StridedMemory, Tensor, TensorDataType};
use crate::none_backwards::NoneBackwards;

impl<'a, T: TensorDataType> Tensor<'a, T> {
    /// Checks if the tensor is a leaf.
    ///
    /// A tensor is considered a leaf node if `requires_grad = true`
    /// and it was explicitly created by the user, or if `requires_grad = false`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut tensor = Tensor::new([1.0, 2.0, 3.0]);
    /// tensor.set_requires_grad(true);
    /// assert!(tensor.is_leaf());
    ///
    /// let tensor2 = -tensor;
    /// assert!(!tensor2.is_leaf());
    /// ```
    #[inline]
    pub fn is_leaf(&self) -> bool {
        if self.requires_grad() {
            self.flags.contains(NdArrayFlags::UserCreated)
        } else {
            true
        }
    }

    /// Returns whether gradients must be computed for this tensor.
    ///
    /// A tensor is marked with the `requires_grad` flag if the user explicitly enabled it
    /// through the `set_requires_grad()` method, or if the tensor was created by operations
    /// on other tensors that were themselves marked `requires_grad`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut tensor = Tensor::new([1.0, 2.0, 3.0]);
    /// tensor.set_requires_grad(true);
    ///
    /// let tensor2 = -tensor;
    /// assert!(tensor2.requires_grad());
    /// ```
    #[inline]
    pub fn requires_grad(&self) -> bool {
        self.flags.contains(NdArrayFlags::RequiresGrad)
    }

    /// Sets whether gradients must be computed for this tensor.
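    ///
    /// Enabling gradients on a tensor installs an `AccumulateGrad` gradient function;
    /// disabling them resets the gradient function to `NoneBackwards`.
    ///
    /// # Examples
    ///
    /// A minimal sketch of toggling the flag, reusing only constructors and methods
    /// shown in the other examples in this file:
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut tensor = Tensor::new([1.0, 2.0, 3.0]);
    ///
    /// tensor.set_requires_grad(true);
    /// assert!(tensor.requires_grad());
    /// assert!(tensor.is_leaf());
    ///
    /// tensor.set_requires_grad(false);
    /// assert!(!tensor.requires_grad());
    /// ```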
    pub fn set_requires_grad(&mut self, requires_grad: bool) -> &mut Self {
        let previously_required = self.requires_grad();

        if requires_grad {
            self.flags |= NdArrayFlags::RequiresGrad;
        } else {
            self.flags -= NdArrayFlags::RequiresGrad;
        }

        // Install a gradient accumulator when gradients are newly enabled,
        // and reset the gradient function when they are newly disabled.
        if !previously_required && requires_grad {
            self.grad_fn = AccumulateGrad::new(self.shape().to_vec());
        } else if previously_required && !requires_grad {
            self.grad_fn = NoneBackwards::new();
        }

        self
    }

    /// Retrieves the gradient function associated with this tensor.
    ///
    /// This is `NoneBackwards` if the tensor has `requires_grad = false`,
    /// or `AccumulateGrad` if the tensor is a leaf node.
    pub(crate) fn grad_fn(&self) -> GradientFunction<T> {
        self.grad_fn.clone()
    }

    /// Returns the gradient of the differentiated tensor with respect to `self`.
    ///
    /// This method returns a view into the gradient.
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut a = Tensor::scalar(2.0f32);
    /// let b = Tensor::scalar(3.0);
    ///
    /// a.set_requires_grad(true);
    ///
    /// let c = &a * &b;
    /// c.backward();
    ///
    /// // dc/da = b
    /// assert_eq!(a.gradient().unwrap(), b);
    /// ```
    pub fn gradient(&'a self) -> Option<NdArray<'a, T>> {
        unsafe { (*self.grad_fn.as_ptr()).gradient() }
    }

    /// Sets the gradient of this tensor to zero.
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut a = Tensor::scalar(2.0f32);
    /// let b = Tensor::scalar(3.0);
    ///
    /// a.set_requires_grad(true);
    ///
    /// let c = &a * &b;
    /// c.backward();
    ///
    /// a.zero_gradient();
    /// assert_eq!(a.gradient().unwrap(), Tensor::scalar(0.0));
    /// ```
    pub fn zero_gradient(&self) {
        self.grad_fn.borrow_mut().zero_gradient();
    }

    /// Computes the gradient of `self` with respect to its leaf tensors.
    ///
    /// # Parameters
    ///
    /// - `gradient`: the gradient of the tensor being differentiated with respect to `self`.
    ///
    /// # Panics
    ///
    /// Panics if `gradient` does not have the same shape as `self`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut a = Tensor::full(2.0, [3]);  // [2, 2, 2]
    /// let b = Tensor::new([3.0, 1.0, -1.0]);
    ///
    /// a.set_requires_grad(true);
    ///
    /// let c = &a * &b;
    /// c.backward_with(NdArray::new([2.0, 1.0, 1.0]));
    ///
    /// // dc/da = b, so the accumulated gradient is gradient * b
    /// assert_eq!(a.gradient().unwrap(), Tensor::new([6.0, 1.0, -1.0]));
    /// ```
    pub fn backward_with(&self, gradient: impl AsRef<NdArray<'a, T>>) {
        let gradient = gradient.as_ref();
        assert_eq!(gradient.shape(), self.shape());

        self.grad_fn.borrow_mut().backward(gradient);
    }

    /// Computes the gradient of `self` with respect to its leaf tensors.
    ///
    /// This is equivalent to calling [`Self::backward_with`] with a gradient of ones.
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut a = Tensor::full(2.0, [3]);  // [2, 2, 2]
    /// let b = Tensor::new([3.0, 1.0, -1.0]);
    ///
    /// a.set_requires_grad(true);
    ///
    /// let c = &a * &b;
    /// c.backward();
    ///
    /// // dc/da = b
    /// assert_eq!(a.gradient().unwrap(), Tensor::new([3.0, 1.0, -1.0]));
    /// ```
    pub fn backward(&self) {
        self.backward_with(NdArray::ones(self.shape()))
    }

    /// Detaches the tensor from the computation graph and returns an `NdArray`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use redstone_ml::*;
    ///
    /// let mut a = Tensor::full(2.0, [3]);
    /// a.set_requires_grad(true);
    ///
    /// let c = &a * 5.0;
    ///
    /// let d = c.detach();
    /// assert_eq!(d, NdArray::new([10.0, 10.0, 10.0]));
    /// ```
    pub fn detach(&self) -> NdArray<'static, T> {
        self.array.as_ref().clone()
    }
}
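
// A minimal end-to-end sketch combining the methods above, using only the
// constructors and operators that appear in this file's doc examples. The
// module and test names are illustrative, not part of the crate's API.
#[cfg(test)]
mod autograd_usage_sketch {
    use crate::*;

    #[test]
    fn gradients_flow_to_leaves_and_can_be_reset() {
        let mut a = Tensor::new([1.0, 2.0, 3.0]);
        let b = Tensor::new([4.0, 5.0, 6.0]);
        a.set_requires_grad(true);

        // An operation on a `requires_grad` tensor produces a non-leaf
        // tensor that also requires gradients.
        let c = &a * &b;
        assert!(c.requires_grad());
        assert!(!c.is_leaf());

        // Differentiate c and read the gradient accumulated in the leaf: dc/da = b.
        c.backward();
        assert_eq!(a.gradient().unwrap(), b);

        // Zeroing resets the accumulated gradient to zeros of the same shape.
        a.zero_gradient();
        assert_eq!(a.gradient().unwrap(), Tensor::new([0.0, 0.0, 0.0]));
    }
}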