// ferrite/tensor/base.rs
use std::rc::Rc;
use std::cell::RefCell;
use super::storage::*;
use crate::{grad_storage, GradientFunction, CpuStorage};
use std::collections::HashSet;

/// Shared, interior-mutable handle to a gradient buffer.
///
/// `Rc` lets several tensors / graph nodes reference the same gradient
/// (e.g. a view and its base share one buffer), while `RefCell` permits
/// in-place mutation of the gradient during the backward pass.
pub type GradientStorage = Rc<RefCell<Storage>>;

/// A tensor: a data buffer plus the bookkeeping needed for reverse-mode
/// automatic differentiation.
#[derive(Clone)]
pub struct Tensor {
  /// Underlying data buffer. Public so other crate modules can reach it
  /// directly, in addition to the `tensor()` / `tensor_mut()` accessors.
  pub storage: Storage,
  // Device the storage lives on.
  device: Device,
  // True if this tensor participates in gradient tracking.
  requires_grad: bool,
  // Autograd-graph node that produced this tensor (`None` for leaf tensors).
  grad_fn: Option<Rc<dyn GradientFunction>>,
  // Shared gradient buffer; `Some` only when `requires_grad` is true.
  grad: Option<GradientStorage>,
}

19impl Tensor {
20  pub fn new(storage: Storage, device: Device, requires_grad: bool) -> Self {
21    let grad = if requires_grad {
22      Some(Rc::new(RefCell::new(Storage::zeros(storage.shape().clone(), Some(device), None))))
23    } else {
24      None
25    };
26    
27    Tensor {
28      storage: storage,
29      device: device,
30      requires_grad: requires_grad,
31      grad_fn: None,
32      grad: grad,
33    }
34  }
35
36  pub fn view(&self, tensor: Storage) -> Self {
37    Tensor {
38      storage: tensor,
39      device: self.device,
40      requires_grad: self.requires_grad,
41      grad_fn: self.grad_fn.clone(),
42      grad: self.grad.clone(),
43    }
44  }
45
46  pub fn tensor(&self) -> &Storage {
47    &self.storage
48  }
49
50  pub fn tensor_mut(&mut self) -> &mut Storage {
51    &mut self.storage
52  }
53
54  pub fn device(&self) -> Device {
55    self.device
56  }
57
58  pub fn requires_grad(&self) -> &bool {
59    &self.requires_grad
60  }
61
62  pub fn grad_fn(&self) -> Option<Rc<dyn GradientFunction>> {
63    self.grad_fn.clone()
64  }
65
66  pub fn set_grad_fn(&mut self, grad_fn: Option<Rc<dyn GradientFunction>>) {
67    self.grad_fn = grad_fn;
68  }
69
70  pub fn grad(&self) -> Option<GradientStorage> {
71    self.grad.clone()
72  }
73
74  pub fn grad_mut(&mut self) -> GradientStorage {
75    self.grad.clone().expect("Grad can't be empty")
76  }
77
78  pub fn shape(&self) -> &Vec<usize> {
79    &self.tensor().shape()
80  }
81
82  pub fn backward(&mut self) {
83    // Verify we're starting with a scalar
84    if self.tensor().shape().len() != 1 || self.tensor().shape()[0] != 1 {
85      panic!("backward() can only be called on scalar tensors");
86    }
87
88    // Initialize gradient for final output (always 1.0 for scalar outputs)
89    if let Some(grad) = &self.grad {
90      grad.borrow_mut().set_data(vec![1.0]);
91    } else {
92      panic!("Called backward on tensor that doesn't require grad");
93    }
94
95    // Build computation graph in topological order
96    let mut topo = Vec::new();
97    let mut visited = HashSet::new();
98
99    fn build_topo(
100      node: &Tensor, 
101      topo: &mut Vec<Rc<dyn GradientFunction>>, 
102      visited: &mut HashSet<*const dyn GradientFunction>
103    ) {
104      if let Some(grad_fn) = &node.grad_fn {
105        let ptr = Rc::as_ptr(grad_fn) as *const dyn GradientFunction;
106        if !visited.contains(&ptr) {
107          visited.insert(ptr);
108          for parent in grad_fn.prev() {
109            build_topo(parent, topo, visited);
110          }
111          topo.push(grad_fn.clone());
112        }
113      }
114    }
115
116    build_topo(self, &mut topo, &mut visited);
117
118    // Execute backward passes in reverse order
119    for grad_fn in topo.iter().rev() {
120      grad_fn.backward();
121    }
122  }
123}