1use std::rc::Rc;
2use std::cell::RefCell;
3use super::storage::*;
4use crate::{grad_storage, GradientFunction, CpuStorage};
5use std::collections::HashSet;
6
7
/// Shared, interiorly mutable gradient buffer.
/// `Rc<RefCell<..>>` gives single-threaded shared ownership so a tensor and
/// its views/clones can accumulate into the same gradient storage.
pub type GradientStorage = Rc<RefCell<Storage>>;
9
/// A tensor participating (optionally) in automatic differentiation.
///
/// Note: `#[derive(Clone)]` clones `storage` but *shares* `grad` and
/// `grad_fn` (both are `Rc`-based), so clones accumulate into the same
/// gradient buffer.
#[derive(Clone)]
pub struct Tensor {
    /// Underlying data buffer.
    pub storage: Storage,
    /// Device this tensor lives on.
    device: Device,
    /// Whether this tensor takes part in gradient computation.
    requires_grad: bool,
    /// Backward node that produced this tensor, if any (leaf tensors: `None`).
    grad_fn: Option<Rc<dyn GradientFunction>>,
    /// Shared gradient accumulator; allocated in `new` when `requires_grad`.
    grad: Option<GradientStorage>,
}
18
19impl Tensor {
20 pub fn new(storage: Storage, device: Device, requires_grad: bool) -> Self {
21 let grad = if requires_grad {
22 Some(Rc::new(RefCell::new(Storage::zeros(storage.shape().clone(), Some(device), None))))
23 } else {
24 None
25 };
26
27 Tensor {
28 storage: storage,
29 device: device,
30 requires_grad: requires_grad,
31 grad_fn: None,
32 grad: grad,
33 }
34 }
35
36 pub fn view(&self, tensor: Storage) -> Self {
37 Tensor {
38 storage: tensor,
39 device: self.device,
40 requires_grad: self.requires_grad,
41 grad_fn: self.grad_fn.clone(),
42 grad: self.grad.clone(),
43 }
44 }
45
46 pub fn tensor(&self) -> &Storage {
47 &self.storage
48 }
49
50 pub fn tensor_mut(&mut self) -> &mut Storage {
51 &mut self.storage
52 }
53
54 pub fn device(&self) -> Device {
55 self.device
56 }
57
58 pub fn requires_grad(&self) -> &bool {
59 &self.requires_grad
60 }
61
62 pub fn grad_fn(&self) -> Option<Rc<dyn GradientFunction>> {
63 self.grad_fn.clone()
64 }
65
66 pub fn set_grad_fn(&mut self, grad_fn: Option<Rc<dyn GradientFunction>>) {
67 self.grad_fn = grad_fn;
68 }
69
70 pub fn grad(&self) -> Option<GradientStorage> {
71 self.grad.clone()
72 }
73
74 pub fn grad_mut(&mut self) -> GradientStorage {
75 self.grad.clone().expect("Grad can't be empty")
76 }
77
78 pub fn shape(&self) -> &Vec<usize> {
79 &self.tensor().shape()
80 }
81
82 pub fn backward(&mut self) {
83 if self.tensor().shape().len() != 1 || self.tensor().shape()[0] != 1 {
85 panic!("backward() can only be called on scalar tensors");
86 }
87
88 if let Some(grad) = &self.grad {
90 grad.borrow_mut().set_data(vec![1.0]);
91 } else {
92 panic!("Called backward on tensor that doesn't require grad");
93 }
94
95 let mut topo = Vec::new();
97 let mut visited = HashSet::new();
98
99 fn build_topo(
100 node: &Tensor,
101 topo: &mut Vec<Rc<dyn GradientFunction>>,
102 visited: &mut HashSet<*const dyn GradientFunction>
103 ) {
104 if let Some(grad_fn) = &node.grad_fn {
105 let ptr = Rc::as_ptr(grad_fn) as *const dyn GradientFunction;
106 if !visited.contains(&ptr) {
107 visited.insert(ptr);
108 for parent in grad_fn.prev() {
109 build_topo(parent, topo, visited);
110 }
111 topo.push(grad_fn.clone());
112 }
113 }
114 }
115
116 build_topo(self, &mut topo, &mut visited);
117
118 for grad_fn in topo.iter().rev() {
120 grad_fn.backward();
121 }
122 }
123}