aprender/autograd/tensor.rs

//! Tensor with automatic differentiation support.
//!
//! This module provides the core `Tensor` type that tracks gradients
//! through computational operations.
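//!
//! # Example
//!
//! A minimal usage sketch (this assumes the crate re-exports the type as
//! `aprender::autograd::Tensor`):
//!
//! ```
//! use aprender::autograd::Tensor;
//!
//! // Build a 2x2 tensor and opt in to gradient tracking.
//! let t = Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).requires_grad();
//! assert_eq!(t.shape(), &[2, 2]);
//! assert!(t.requires_grad_enabled());
//! ```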

use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use crate::primitives::Vector;

use super::grad_fn::GradFn;
use super::with_graph;

/// Unique identifier for tensors in the computation graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorId(u64);

impl TensorId {
    /// Generate a new unique tensor ID.
    pub fn new() -> Self {
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        TensorId(COUNTER.fetch_add(1, Ordering::SeqCst))
    }
}

impl Default for TensorId {
    fn default() -> Self {
        Self::new()
    }
}

/// A tensor with optional gradient tracking for automatic differentiation.
///
/// # Design
///
/// The tensor stores:
/// - `data`: The actual numerical values (backed by aprender's `Vector`)
/// - `shape`: Dimensions of the tensor
/// - `grad`: Accumulated gradient (populated after `backward()`)
/// - `requires_grad`: Whether this tensor participates in gradient computation
/// - `grad_fn`: The operation that created this tensor (for backprop)
/// - `id`: Unique identifier for graph tracking
///
/// # Thread Safety
///
/// Tensors use `Arc` internally for shared ownership of gradient functions,
/// making them safe to share across threads for inference (but not training).
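///
/// # Examples
///
/// A small sketch of the leaf/gradient bookkeeping (crate path assumed to be
/// `aprender::autograd::Tensor`):
///
/// ```
/// use aprender::autograd::Tensor;
///
/// let w = Tensor::from_slice(&[0.5, -0.5]).requires_grad();
/// assert!(w.is_leaf());        // created by the user, not by an operation
/// assert!(w.grad().is_none()); // no gradient until backward() has run
/// ```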
#[derive(Clone)]
pub struct Tensor {
    /// Underlying data storage
    data: Vector<f32>,

    /// Shape of the tensor
    shape: Vec<usize>,

    /// Gradient (populated after backward())
    grad: Option<Box<Tensor>>,

    /// Whether this tensor requires gradient computation
    requires_grad: bool,

    /// Whether this is a leaf tensor (created by user, not by operation)
    is_leaf: bool,

    /// Function that computes gradients during backward pass
    grad_fn: Option<Arc<dyn GradFn>>,

    /// Unique identifier for graph construction
    id: TensorId,
}

impl Tensor {
    /// Create a new tensor from a slice with the given shape.
    ///
    /// By default, gradient tracking is disabled.
    ///
    /// # Panics
    ///
    /// Panics if the data length doesn't match the product of shape dimensions.
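    ///
    /// # Examples
    ///
    /// A minimal sketch (crate path assumed):
    ///
    /// ```
    /// use aprender::autograd::Tensor;
    ///
    /// // 6 values arranged as a 2x3 tensor.
    /// let t = Tensor::new(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
    /// assert_eq!(t.numel(), 6);
    /// assert_eq!(t.ndim(), 2);
    /// ```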
    pub fn new(data: &[f32], shape: &[usize]) -> Self {
        let expected_len: usize = shape.iter().product();
        assert_eq!(
            data.len(),
            expected_len,
            "Data length {} doesn't match shape {:?} (expected {})",
            data.len(),
            shape,
            expected_len
        );

        Self {
            data: Vector::from_slice(data),
            shape: shape.to_vec(),
            grad: None,
            requires_grad: false,
            is_leaf: true,
            grad_fn: None,
            id: TensorId::new(),
        }
    }

    /// Create a tensor from a 1D slice (vector).
    pub fn from_slice(data: &[f32]) -> Self {
        Self::new(data, &[data.len()])
    }

    /// Create a tensor filled with zeros.
    pub fn zeros(shape: &[usize]) -> Self {
        let len: usize = shape.iter().product();
        Self::new(&vec![0.0; len], shape)
    }

    /// Create a tensor filled with ones.
    pub fn ones(shape: &[usize]) -> Self {
        let len: usize = shape.iter().product();
        Self::new(&vec![1.0; len], shape)
    }

    /// Create a tensor with the same shape as another, filled with zeros.
    pub fn zeros_like(other: &Tensor) -> Self {
        Self::zeros(&other.shape)
    }

    /// Create a tensor with the same shape as another, filled with ones.
    pub fn ones_like(other: &Tensor) -> Self {
        Self::ones(&other.shape)
    }

    /// Enable gradient tracking for this tensor.
    ///
    /// Returns self for method chaining.
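    ///
    /// # Examples
    ///
    /// A sketch of the builder-style chaining (crate path assumed):
    ///
    /// ```
    /// use aprender::autograd::Tensor;
    ///
    /// let t = Tensor::zeros(&[3]).requires_grad();
    /// assert!(t.requires_grad_enabled());
    /// ```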
    pub fn requires_grad(mut self) -> Self {
        self.requires_grad = true;
        self
    }

    /// Enable or disable gradient tracking (in-place).
    pub fn requires_grad_(&mut self, requires: bool) -> &mut Self {
        self.requires_grad = requires;
        self
    }

    /// Check if this tensor requires gradient computation.
    pub fn requires_grad_enabled(&self) -> bool {
        self.requires_grad
    }

    /// Check if this is a leaf tensor (not created by an operation).
    pub fn is_leaf(&self) -> bool {
        self.is_leaf
    }

    /// Get the tensor's unique identifier.
    pub fn id(&self) -> TensorId {
        self.id
    }

    /// Get the shape of the tensor.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Get the total number of elements.
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Get the number of dimensions.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Get a reference to the underlying data.
    pub fn data(&self) -> &[f32] {
        self.data.as_slice()
    }

    /// Get a mutable reference to the underlying data.
    ///
    /// # Warning
    ///
    /// Modifying data directly may invalidate gradients.
    pub fn data_mut(&mut self) -> &mut [f32] {
        self.data.as_mut_slice()
    }

    /// Get the gradient tensor (if computed).
    pub fn grad(&self) -> Option<&Tensor> {
        self.grad.as_deref()
    }

    /// Zero out the gradient.
    pub fn zero_grad_(&mut self) {
        self.grad = None;
    }

    /// Clear the gradient (alias for zero_grad_).
    pub fn clear_grad(&mut self) {
        self.grad = None;
    }

    /// Accumulate gradient (used during backward pass).
    pub(crate) fn accumulate_grad(&mut self, grad: Tensor) {
        match &mut self.grad {
            Some(existing) => {
                // Accumulate gradients
                let new_data: Vec<f32> = existing
                    .data()
                    .iter()
                    .zip(grad.data().iter())
                    .map(|(a, b)| a + b)
                    .collect();
                **existing = Tensor::new(&new_data, &self.shape);
            }
            None => {
                self.grad = Some(Box::new(grad));
            }
        }
    }

    /// Set the gradient function (used internally by operations).
    pub(crate) fn set_grad_fn(&mut self, grad_fn: Arc<dyn GradFn>) {
        self.grad_fn = Some(grad_fn);
        self.is_leaf = false;
    }

    /// Get the gradient function.
    #[allow(dead_code)]
    pub(crate) fn grad_fn(&self) -> Option<&Arc<dyn GradFn>> {
        self.grad_fn.as_ref()
    }

    /// Detach tensor from computation graph.
    ///
    /// Returns a new tensor with the same data but no gradient tracking.
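    ///
    /// # Examples
    ///
    /// A sketch showing that only the tracking state changes (crate path assumed):
    ///
    /// ```
    /// use aprender::autograd::Tensor;
    ///
    /// let t = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
    /// let d = t.detach();
    /// assert_eq!(d.data(), t.data());      // same values
    /// assert!(!d.requires_grad_enabled()); // but no gradient tracking
    /// ```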
    pub fn detach(&self) -> Tensor {
        Tensor {
            data: self.data.clone(),
            shape: self.shape.clone(),
            grad: None,
            requires_grad: false,
            is_leaf: true,
            grad_fn: None,
            id: TensorId::new(),
        }
    }

    /// Get a scalar value (for 0-d or 1-element tensors).
    ///
    /// # Panics
    ///
    /// Panics if the tensor has more than one element.
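    ///
    /// # Examples
    ///
    /// A minimal sketch (crate path assumed):
    ///
    /// ```
    /// use aprender::autograd::Tensor;
    ///
    /// let loss = Tensor::new(&[0.25], &[1]);
    /// assert_eq!(loss.item(), 0.25);
    /// ```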
    pub fn item(&self) -> f32 {
        assert_eq!(
            self.numel(),
            1,
            "item() only works on tensors with exactly 1 element, got {}",
            self.numel()
        );
        self.data[0]
    }

    /// Compute gradients via backpropagation.
    ///
    /// This implements the reverse-mode automatic differentiation algorithm
    /// described in Rumelhart et al. (1986).
    ///
    /// # Panics
    ///
    /// Panics if called on a tensor with more than one element
    /// (use `backward_with_grad` for non-scalar outputs).
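    ///
    /// # Examples
    ///
    /// A sketch only; in practice `loss` is produced by autograd-tracked
    /// operations, so the example is not executed here (crate path assumed):
    ///
    /// ```no_run
    /// use aprender::autograd::Tensor;
    ///
    /// let loss = Tensor::from_slice(&[0.5]).requires_grad();
    /// loss.backward(); // seeds the backward pass with d(loss)/d(loss) = 1.0
    /// ```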
    pub fn backward(&self) {
        assert_eq!(
            self.numel(),
            1,
            "backward() requires scalar output, got shape {:?}. Use backward_with_grad() instead.",
            self.shape
        );

        self.backward_with_grad(Tensor::ones(&self.shape));
    }

    /// Compute gradients with a specified output gradient.
    ///
    /// # Arguments
    ///
    /// * `grad_output` - Gradient of the loss with respect to this tensor
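    ///
    /// # Examples
    ///
    /// A sketch for a non-scalar output; as above, not executed here because a
    /// real `y` would come from tracked operations (crate path assumed):
    ///
    /// ```no_run
    /// use aprender::autograd::Tensor;
    ///
    /// let y = Tensor::from_slice(&[1.0, 2.0, 3.0]).requires_grad();
    /// y.backward_with_grad(Tensor::ones(&[3]));
    /// ```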
    pub fn backward_with_grad(&self, grad_output: Tensor) {
        with_graph(|graph| {
            graph.backward(self.id, grad_output);
        });
    }
}

impl fmt::Debug for Tensor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Tensor")
            .field("shape", &self.shape)
            .field("requires_grad", &self.requires_grad)
            .field("is_leaf", &self.is_leaf)
            .field("has_grad", &self.grad.is_some())
            .field("id", &self.id)
            .finish_non_exhaustive()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_creation() {
        let t = Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        assert_eq!(t.shape(), &[2, 2]);
        assert_eq!(t.numel(), 4);
        assert_eq!(t.ndim(), 2);
    }

    #[test]
    fn test_tensor_from_slice() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0]);
        assert_eq!(t.shape(), &[3]);
        assert_eq!(t.numel(), 3);
    }

    #[test]
    fn test_tensor_zeros_ones() {
        let z = Tensor::zeros(&[2, 3]);
        assert!(z.data().iter().all(|&x| x == 0.0));

        let o = Tensor::ones(&[2, 3]);
        assert!(o.data().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_requires_grad() {
        let t = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
        assert!(t.requires_grad_enabled());

        let t2 = Tensor::from_slice(&[1.0, 2.0]);
        assert!(!t2.requires_grad_enabled());
    }

    #[test]
    fn test_detach() {
        let t = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
        let d = t.detach();

        assert!(t.requires_grad_enabled());
        assert!(!d.requires_grad_enabled());
        assert!(d.is_leaf());
    }

    #[test]
    fn test_item() {
        let t = Tensor::new(&[42.0], &[1]);
        assert_eq!(t.item(), 42.0);

        let t2 = Tensor::new(&[42.0], &[]);
        assert_eq!(t2.item(), 42.0);
    }

    #[test]
    #[should_panic(expected = "item() only works on tensors with exactly 1 element")]
    fn test_item_panics_multi_element() {
        let t = Tensor::from_slice(&[1.0, 2.0]);
        t.item();
    }

    #[test]
    fn test_tensor_id_unique() {
        let t1 = Tensor::from_slice(&[1.0]);
        let t2 = Tensor::from_slice(&[1.0]);
        assert_ne!(t1.id(), t2.id());
    }

    #[test]
    fn test_gradient_accumulation() {
        let mut t = Tensor::from_slice(&[1.0, 2.0, 3.0]).requires_grad();

        t.accumulate_grad(Tensor::from_slice(&[0.1, 0.2, 0.3]));

        let grad1 = t
            .grad()
            .expect("grad should exist after accumulate")
            .data()
            .to_vec();
        assert_eq!(grad1, vec![0.1, 0.2, 0.3]);

        t.accumulate_grad(Tensor::from_slice(&[0.1, 0.2, 0.3]));
        let grad2 = t
            .grad()
            .expect("grad should exist after second accumulate")
            .data()
            .to_vec();
        assert_eq!(grad2, vec![0.2, 0.4, 0.6]);
    }
}