entrenar/autograd/
tensor.rs1use super::BackwardOp;
4use ndarray::Array1;
5use std::cell::RefCell;
6use std::rc::Rc;
7
/// A one-dimensional `f32` tensor carrying an optional gradient and an
/// optional backward operation.
///
/// `Clone` is derived over `Rc`-wrapped fields, so cloning is cheap and the
/// clone shares both the value buffer and the gradient cell with the
/// original. (`data_mut` un-shares the value buffer through `Rc::make_mut`;
/// the gradient cell always stays shared.)
#[derive(Clone)]
pub struct Tensor {
    /// Element values; behind an `Rc` so clones share a single buffer.
    data: Rc<Array1<f32>>,
    /// Gradient slot, `None` until a gradient is set or accumulated.
    /// Interior-mutable so it can be updated through `&self`, and shared
    /// between clones (also handed out by `grad_cell`).
    grad: Rc<RefCell<Option<Array1<f32>>>>,
    /// Attached backward operation, if any. NOTE(review): presumably the op
    /// that propagates this tensor's gradient to its inputs; its contract is
    /// defined by the `BackwardOp` trait elsewhere — confirm there.
    backward_op: Option<Rc<dyn BackwardOp>>,
    /// Whether this tensor is flagged as needing a gradient.
    requires_grad: bool,
}
25
26impl Tensor {
27 pub fn new(data: Array1<f32>, requires_grad: bool) -> Self {
29 Self {
30 data: Rc::new(data),
31 grad: Rc::new(RefCell::new(None)),
32 backward_op: None,
33 requires_grad,
34 }
35 }
36
37 pub fn from_vec(data: Vec<f32>, requires_grad: bool) -> Self {
39 Self::new(Array1::from(data), requires_grad)
40 }
41
42 pub fn zeros(size: usize, requires_grad: bool) -> Self {
44 Self::new(Array1::zeros(size), requires_grad)
45 }
46
47 pub fn ones(size: usize, requires_grad: bool) -> Self {
49 Self::new(Array1::ones(size), requires_grad)
50 }
51
52 pub fn data(&self) -> &Array1<f32> {
54 contract_pre_data_read!();
55 &self.data
56 }
57
58 pub fn data_mut(&mut self) -> &mut Array1<f32> {
65 contract_pre_data_mut!();
66 Rc::make_mut(&mut self.data)
67 }
68
69 pub fn grad(&self) -> Option<Array1<f32>> {
71 self.grad.borrow().clone()
72 }
73
74 pub fn set_grad(&self, grad: Array1<f32>) {
76 *self.grad.borrow_mut() = Some(grad);
77 }
78
79 pub fn accumulate_grad(&self, grad: Array1<f32>) {
81 let mut grad_ref = self.grad.borrow_mut();
82 if let Some(existing) = grad_ref.as_mut() {
83 *existing = &*existing + &grad;
84 } else {
85 *grad_ref = Some(grad);
86 }
87 }
88
89 pub fn zero_grad(&self) {
91 *self.grad.borrow_mut() = None;
92 }
93
94 pub fn scale_grad(&self, factor: f32) {
102 let mut grad_ref = self.grad.borrow_mut();
103 if let Some(existing) = grad_ref.as_mut() {
104 existing.mapv_inplace(|v| v * factor);
105 }
106 }
107
108 pub fn requires_grad(&self) -> bool {
110 self.requires_grad
111 }
112
113 pub fn set_requires_grad(&mut self, requires_grad: bool) {
117 self.requires_grad = requires_grad;
118 }
119
120 pub fn grad_cell(&self) -> Rc<RefCell<Option<Array1<f32>>>> {
122 self.grad.clone()
123 }
124
125 pub fn set_backward_op(&mut self, op: Rc<dyn BackwardOp>) {
127 self.backward_op = Some(op);
128 }
129
130 pub fn backward_op(&self) -> Option<Rc<dyn BackwardOp>> {
132 self.backward_op.clone()
133 }
134
135 pub fn len(&self) -> usize {
137 self.data.len()
138 }
139
140 pub fn is_empty(&self) -> bool {
142 self.data.is_empty()
143 }
144}
145
146impl std::fmt::Debug for Tensor {
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 f.debug_struct("Tensor")
149 .field("data", &self.data)
150 .field("grad", &self.grad.borrow())
151 .field("requires_grad", &self.requires_grad)
152 .finish_non_exhaustive()
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_new() {
        let data = Array1::from(vec![1.0, 2.0, 3.0]);
        let t = Tensor::new(data.clone(), true);
        assert_eq!(t.data(), &data);
        assert!(t.requires_grad());
    }

    #[test]
    fn test_tensor_from_vec() {
        let t = Tensor::from_vec(vec![1.0, 2.0], false);
        assert_eq!(t.len(), 2);
        assert!(!t.requires_grad());
    }

    #[test]
    fn test_tensor_zeros() {
        let t = Tensor::zeros(5, true);
        assert_eq!(t.len(), 5);
        assert!(t.data().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_tensor_ones() {
        let t = Tensor::ones(3, false);
        assert_eq!(t.len(), 3);
        assert!(t.data().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_tensor_data_mut() {
        let mut t = Tensor::from_vec(vec![1.0, 2.0], true);
        t.data_mut()[0] = 5.0;
        assert_eq!(t.data()[0], 5.0);
    }

    #[test]
    fn test_tensor_data_mut_copy_on_write() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0], true);
        let mut t2 = t1.clone();

        // `data_mut` goes through `Rc::make_mut`, so the shared buffer is
        // cloned before the write and `t1` keeps its original values.
        t2.data_mut()[0] = 9.0;
        assert_eq!(t2.data()[0], 9.0);
        assert_eq!(t1.data()[0], 1.0);
    }

    #[test]
    fn test_tensor_grad_operations() {
        let t = Tensor::from_vec(vec![1.0, 2.0], true);
        assert!(t.grad().is_none());

        t.set_grad(Array1::from(vec![0.1, 0.2]));
        assert!(t.grad().is_some());
        assert_eq!(t.grad().expect("gradient should be available")[0], 0.1);

        t.zero_grad();
        assert!(t.grad().is_none());
    }

    #[test]
    fn test_tensor_accumulate_grad() {
        let t = Tensor::from_vec(vec![1.0, 2.0], true);

        t.accumulate_grad(Array1::from(vec![0.1, 0.2]));
        assert_eq!(t.grad().expect("gradient should be available")[0], 0.1);

        t.accumulate_grad(Array1::from(vec![0.3, 0.4]));
        let grad = t.grad().expect("gradient should be available");
        assert!((grad[0] - 0.4).abs() < 1e-6);
        assert!((grad[1] - 0.6).abs() < 1e-6);
    }

    #[test]
    fn test_tensor_scale_grad() {
        let t = Tensor::from_vec(vec![1.0, 2.0], true);

        // With no gradient stored, scaling is a no-op and must not create one.
        t.scale_grad(3.0);
        assert!(t.grad().is_none());

        t.set_grad(Array1::from(vec![1.0, -2.0]));
        t.scale_grad(0.5);
        let grad = t.grad().expect("gradient should be available");
        assert_eq!(grad[0], 0.5);
        assert_eq!(grad[1], -1.0);
    }

    #[test]
    fn test_tensor_grad_cell() {
        let t = Tensor::from_vec(vec![1.0], true);
        let cell = t.grad_cell();
        assert!(cell.borrow().is_none());
    }

    #[test]
    fn test_tensor_backward_op() {
        let t = Tensor::from_vec(vec![1.0], true);
        assert!(t.backward_op().is_none());
    }

    #[test]
    fn test_tensor_is_empty() {
        let t = Tensor::from_vec(vec![], false);
        assert!(t.is_empty());

        let t2 = Tensor::from_vec(vec![1.0], false);
        assert!(!t2.is_empty());
    }

    #[test]
    fn test_tensor_set_requires_grad() {
        let mut t = Tensor::from_vec(vec![1.0, 2.0], false);
        assert!(!t.requires_grad());

        t.set_requires_grad(true);
        assert!(t.requires_grad());

        t.set_requires_grad(false);
        assert!(!t.requires_grad());
    }

    #[test]
    fn test_tensor_debug() {
        let t = Tensor::from_vec(vec![1.0, 2.0], true);
        let debug_str = format!("{t:?}");
        assert!(debug_str.contains("Tensor"));
        assert!(debug_str.contains("data"));
    }

    #[test]
    fn test_tensor_clone() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0], true);
        t1.set_grad(Array1::from(vec![0.1, 0.2]));
        let t2 = t1.clone();

        assert_eq!(t2.data(), t1.data());
        assert_eq!(t2.requires_grad(), t1.requires_grad());
    }

    #[test]
    fn test_tensor_clone_shares_grad() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0], true);
        let t2 = t1.clone();

        // The gradient cell sits behind a shared `Rc`, so a write through one
        // handle is visible through the other.
        t2.set_grad(Array1::from(vec![0.1, 0.2]));
        assert_eq!(t1.grad().expect("gradient should be shared")[1], 0.2);

        t1.zero_grad();
        assert!(t2.grad().is_none());
    }
}