// entrenar/autograd/tensor.rs

1//! Tensor type with gradient tracking
2
3use super::BackwardOp;
4use ndarray::Array1;
5use std::cell::RefCell;
6use std::rc::Rc;
7
8/// Tensor with automatic differentiation support
9///
10/// Data is stored behind `Rc` for O(1) clone.  Backward ops clone input
11/// tensors to hold references for gradient computation — with `Rc`, this
12/// is a reference-count bump instead of a full memcpy.  For Qwen3-4B with
13/// batch_size=8, this eliminates ~16 GB of redundant frozen-weight copies
14/// per training step (KAIZEN-019).
15///
16/// `data_mut()` uses `Rc::make_mut` for copy-on-write: mutation only copies
17/// when there are multiple Rc references to the same data.
18#[derive(Clone)]
19pub struct Tensor {
20    data: Rc<Array1<f32>>,
21    grad: Rc<RefCell<Option<Array1<f32>>>>,
22    backward_op: Option<Rc<dyn BackwardOp>>,
23    requires_grad: bool,
24}
25
26impl Tensor {
27    /// Create a new tensor with data
28    pub fn new(data: Array1<f32>, requires_grad: bool) -> Self {
29        Self {
30            data: Rc::new(data),
31            grad: Rc::new(RefCell::new(None)),
32            backward_op: None,
33            requires_grad,
34        }
35    }
36
37    /// Create a tensor from a vector
38    pub fn from_vec(data: Vec<f32>, requires_grad: bool) -> Self {
39        Self::new(Array1::from(data), requires_grad)
40    }
41
42    /// Create a tensor filled with zeros
43    pub fn zeros(size: usize, requires_grad: bool) -> Self {
44        Self::new(Array1::zeros(size), requires_grad)
45    }
46
47    /// Create a tensor filled with ones
48    pub fn ones(size: usize, requires_grad: bool) -> Self {
49        Self::new(Array1::ones(size), requires_grad)
50    }
51
52    /// Get reference to data
53    pub fn data(&self) -> &Array1<f32> {
54        contract_pre_data_read!();
55        &self.data
56    }
57
58    /// Get mutable reference to data (copy-on-write via `Rc::make_mut`)
59    ///
60    /// If this is the only `Rc` reference, returns a mutable reference to
61    /// the existing data with no copy.  If there are other references
62    /// (e.g. backward ops holding clones), clones the data first so
63    /// mutations don't affect other holders.
64    pub fn data_mut(&mut self) -> &mut Array1<f32> {
65        contract_pre_data_mut!();
66        Rc::make_mut(&mut self.data)
67    }
68
69    /// Get gradient (if computed)
70    pub fn grad(&self) -> Option<Array1<f32>> {
71        self.grad.borrow().clone()
72    }
73
74    /// Set gradient
75    pub fn set_grad(&self, grad: Array1<f32>) {
76        *self.grad.borrow_mut() = Some(grad);
77    }
78
79    /// Accumulate gradient (for when tensor is used multiple times)
80    pub fn accumulate_grad(&self, grad: Array1<f32>) {
81        let mut grad_ref = self.grad.borrow_mut();
82        if let Some(existing) = grad_ref.as_mut() {
83            *existing = &*existing + &grad;
84        } else {
85            *grad_ref = Some(grad);
86        }
87    }
88
89    /// Zero out gradient
90    pub fn zero_grad(&self) {
91        *self.grad.borrow_mut() = None;
92    }
93
94    /// Scale gradient in-place (KAIZEN-037: zero-allocation gradient scaling)
95    ///
96    /// # Contract (C-SCALE-GRAD-001)
97    ///
98    /// - **Precondition**: factor is finite
99    /// - **Postcondition**: grad[i] *= factor for all i
100    /// - **Invariant**: No heap allocation (mutates existing Array1 in-place)
101    pub fn scale_grad(&self, factor: f32) {
102        let mut grad_ref = self.grad.borrow_mut();
103        if let Some(existing) = grad_ref.as_mut() {
104            existing.mapv_inplace(|v| v * factor);
105        }
106    }
107
108    /// Check if requires gradient
109    pub fn requires_grad(&self) -> bool {
110        self.requires_grad
111    }
112
113    /// Set requires gradient flag
114    ///
115    /// Use this to enable gradient tracking on loaded tensors before training.
116    pub fn set_requires_grad(&mut self, requires_grad: bool) {
117        self.requires_grad = requires_grad;
118    }
119
120    /// Get reference to gradient cell (for backward operations)
121    pub fn grad_cell(&self) -> Rc<RefCell<Option<Array1<f32>>>> {
122        self.grad.clone()
123    }
124
125    /// Set backward operation
126    pub fn set_backward_op(&mut self, op: Rc<dyn BackwardOp>) {
127        self.backward_op = Some(op);
128    }
129
130    /// Get backward operation
131    pub fn backward_op(&self) -> Option<Rc<dyn BackwardOp>> {
132        self.backward_op.clone()
133    }
134
135    /// Get size
136    pub fn len(&self) -> usize {
137        self.data.len()
138    }
139
140    /// Check if empty
141    pub fn is_empty(&self) -> bool {
142        self.data.is_empty()
143    }
144}
145
146impl std::fmt::Debug for Tensor {
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148        f.debug_struct("Tensor")
149            .field("data", &self.data)
150            .field("grad", &self.grad.borrow())
151            .field("requires_grad", &self.requires_grad)
152            .finish_non_exhaustive()
153    }
154}
155
/// Unit tests for `Tensor`: constructors, data access, gradient handling,
/// flags, `Debug` output and clone semantics.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_new() {
        let data = Array1::from(vec![1.0, 2.0, 3.0]);
        let t = Tensor::new(data.clone(), true);
        assert_eq!(t.data(), &data);
        assert!(t.requires_grad());
    }

    #[test]
    fn test_tensor_from_vec() {
        let t = Tensor::from_vec(vec![1.0, 2.0], false);
        assert_eq!(t.len(), 2);
        assert!(!t.requires_grad());
    }

    #[test]
    fn test_tensor_zeros() {
        let t = Tensor::zeros(5, true);
        assert_eq!(t.len(), 5);
        assert!(t.data().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_tensor_ones() {
        let t = Tensor::ones(3, false);
        assert_eq!(t.len(), 3);
        assert!(t.data().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_tensor_data_mut() {
        let mut t = Tensor::from_vec(vec![1.0, 2.0], true);
        t.data_mut()[0] = 5.0;
        assert_eq!(t.data()[0], 5.0);
    }

    #[test]
    fn test_tensor_data_mut_copy_on_write() {
        // A clone shares the data buffer; mutating the clone must not be
        // visible through the original.
        let original = Tensor::from_vec(vec![1.0, 2.0], false);
        let mut copy = original.clone();
        copy.data_mut()[0] = 5.0;

        assert_eq!(original.data()[0], 1.0);
        assert_eq!(copy.data()[0], 5.0);
        assert_eq!(copy.data()[1], 2.0);
    }

    #[test]
    fn test_tensor_grad_operations() {
        let t = Tensor::from_vec(vec![1.0, 2.0], true);
        assert!(t.grad().is_none());

        t.set_grad(Array1::from(vec![0.1, 0.2]));
        assert!(t.grad().is_some());
        assert_eq!(t.grad().expect("gradient should be available")[0], 0.1);

        t.zero_grad();
        assert!(t.grad().is_none());
    }

    #[test]
    fn test_tensor_accumulate_grad() {
        let t = Tensor::from_vec(vec![1.0, 2.0], true);

        // First accumulation - should set grad
        t.accumulate_grad(Array1::from(vec![0.1, 0.2]));
        assert_eq!(t.grad().expect("gradient should be available")[0], 0.1);

        // Second accumulation - should add
        t.accumulate_grad(Array1::from(vec![0.3, 0.4]));
        let grad = t.grad().expect("gradient should be available");
        assert!((grad[0] - 0.4).abs() < 1e-6);
        assert!((grad[1] - 0.6).abs() < 1e-6);
    }

    #[test]
    fn test_tensor_scale_grad() {
        let t = Tensor::from_vec(vec![1.0, 2.0], true);

        // Scaling without a stored gradient is a no-op.
        t.scale_grad(0.5);
        assert!(t.grad().is_none());

        t.set_grad(Array1::from(vec![1.0, 2.0]));
        t.scale_grad(0.5);
        let grad = t.grad().expect("gradient should be available");
        assert_eq!(grad[0], 0.5);
        assert_eq!(grad[1], 1.0);
    }

    #[test]
    fn test_tensor_grad_cell() {
        let t = Tensor::from_vec(vec![1.0], true);
        let cell = t.grad_cell();
        assert!(cell.borrow().is_none());
    }

    #[test]
    fn test_tensor_backward_op() {
        let t = Tensor::from_vec(vec![1.0], true);
        assert!(t.backward_op().is_none());
        // Note: Setting backward op requires an actual BackwardOp implementation
    }

    #[test]
    fn test_tensor_is_empty() {
        let t = Tensor::from_vec(vec![], false);
        assert!(t.is_empty());

        let t2 = Tensor::from_vec(vec![1.0], false);
        assert!(!t2.is_empty());
    }

    #[test]
    fn test_tensor_set_requires_grad() {
        let mut t = Tensor::from_vec(vec![1.0, 2.0], false);
        assert!(!t.requires_grad());

        t.set_requires_grad(true);
        assert!(t.requires_grad());

        t.set_requires_grad(false);
        assert!(!t.requires_grad());
    }

    #[test]
    fn test_tensor_debug() {
        let t = Tensor::from_vec(vec![1.0, 2.0], true);
        let debug_str = format!("{t:?}");
        assert!(debug_str.contains("Tensor"));
        assert!(debug_str.contains("data"));
    }

    #[test]
    fn test_tensor_clone() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0], true);
        t1.set_grad(Array1::from(vec![0.1, 0.2]));
        let t2 = t1.clone();

        assert_eq!(t2.data(), t1.data());
        assert_eq!(t2.requires_grad(), t1.requires_grad());
    }

    #[test]
    fn test_tensor_clone_shares_grad() {
        // Clones hold the same gradient cell, so a gradient written through
        // one handle is visible through the other.
        let t1 = Tensor::from_vec(vec![1.0, 2.0], true);
        let t2 = t1.clone();

        t2.accumulate_grad(Array1::from(vec![0.5, 0.25]));
        let grad = t1.grad().expect("gradient should be shared with clone");
        assert_eq!(grad[0], 0.5);
        assert_eq!(grad[1], 0.25);

        t1.zero_grad();
        assert!(t2.grad().is_none());
    }
}
272
273        assert_eq!(t2.data(), t1.data());
274        assert_eq!(t2.requires_grad(), t1.requires_grad());
275    }
276}