use crate::array::Array;
use crate::autograd::{ComputeNode, NodeId};
use std::rc::Rc;
use std::cell::RefCell;
use anyhow::Result;

#[derive(Clone)]
pub struct Tensor {
    /// The tensor's values and shape.
    pub data: Array,

    /// Accumulated gradient, shared across clones via `Rc<RefCell<_>>`.
    /// `None` when `requires_grad` is false.
    pub grad: Option<Rc<RefCell<Array>>>,

    /// Whether gradients should be computed for this tensor.
    pub requires_grad: bool,

    /// Node in the autograd graph that produced this tensor, if any.
    pub compute_node: Option<Rc<ComputeNode>>,

    /// True for tensors created directly rather than by an operation.
    pub is_leaf: bool,
}

impl Tensor {
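    /// Creates a new leaf tensor. When `requires_grad` is true, a
    /// zero-filled gradient buffer with the same shape is allocated.
    ///
    /// Illustrative sketch; assumes the `Array::new(shape, data)`
    /// constructor used throughout this module.
    ///
    /// ```ignore
    /// let t = Tensor::new(Array::new(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]), true);
    /// assert!(t.is_leaf);
    /// assert_eq!(t.shape(), &[2, 2]);
    /// ```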
    pub fn new(data: Array, requires_grad: bool) -> Self {
        // Gradients start at zero so backward passes can accumulate into them.
        let grad = if requires_grad {
            Some(Rc::new(RefCell::new(Array::new(
                data.shape.clone(),
                vec![0.0; data.data.len()],
            ))))
        } else {
            None
        };

        Self {
            data,
            grad,
            requires_grad,
            compute_node: None,
            is_leaf: true,
        }
    }

    /// Creates a non-leaf tensor produced by an operation, recording the
    /// compute node so gradients can flow back through it.
    pub fn from_operation(
        data: Array,
        compute_node: ComputeNode,
        requires_grad: bool,
    ) -> Self {
        let grad = if requires_grad {
            Some(Rc::new(RefCell::new(Array::new(
                data.shape.clone(),
                vec![0.0; data.data.len()],
            ))))
        } else {
            None
        };

        Self {
            data,
            grad,
            requires_grad,
            compute_node: Some(Rc::new(compute_node)),
            is_leaf: false,
        }
    }

    /// Returns the tensor's shape.
    pub fn shape(&self) -> &[usize] {
        &self.data.shape
    }

    /// Returns the underlying values as a flat slice.
    pub fn values(&self) -> &[f32] {
        &self.data.data
    }

    /// Returns a clone of the current gradient, if one exists.
    pub fn gradient(&self) -> Option<Array> {
        self.grad.as_ref().map(|g| g.borrow().clone())
    }

    /// Returns the first element, or 0.0 for an empty tensor.
    /// Intended for scalar (single-element) tensors.
    pub fn item(&self) -> f32 {
        if self.data.data.is_empty() {
            0.0
        } else {
            self.data.data[0]
        }
    }

    /// Replaces the stored gradient with `grad`.
    pub fn set_gradient(&self, grad: Array) -> Result<()> {
        if let Some(grad_cell) = &self.grad {
            *grad_cell.borrow_mut() = grad;
            Ok(())
        } else {
            Err(anyhow::anyhow!("Tensor does not require gradients"))
        }
    }

    /// Adds `incoming_grad` element-wise into the stored gradient.
    pub fn accumulate_gradient(&self, incoming_grad: &Array) -> Result<()> {
        if let Some(grad_cell) = &self.grad {
            let mut grad = grad_cell.borrow_mut();
            // Guard against silent truncation: zip would otherwise stop at
            // the shorter of the two buffers on a shape mismatch.
            if grad.data.len() != incoming_grad.data.len() {
                return Err(anyhow::anyhow!(
                    "Gradient size mismatch: expected {} elements, got {}",
                    grad.data.len(),
                    incoming_grad.data.len()
                ));
            }
            for (g, &ig) in grad.data.iter_mut().zip(incoming_grad.data.iter()) {
                *g += ig;
            }
            Ok(())
        } else {
            Err(anyhow::anyhow!("Tensor does not require gradients"))
        }
    }

    /// Resets the stored gradient to zero. A no-op when the tensor
    /// does not track gradients.
    pub fn zero_grad(&self) {
        if let Some(grad_cell) = &self.grad {
            let mut grad = grad_cell.borrow_mut();
            for g in grad.data.iter_mut() {
                *g = 0.0;
            }
        }
    }

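    /// Returns a copy of this tensor's data disconnected from the autograd
    /// graph: no gradient buffer, no compute node, leaf status.
    ///
    /// Illustrative sketch; assumes `Array::new(shape, data)` as above.
    ///
    /// ```ignore
    /// let x = Tensor::new(Array::new(vec![2], vec![1.0, 2.0]), true);
    /// let d = x.detach();
    /// assert!(!d.requires_grad && d.is_leaf && d.gradient().is_none());
    /// ```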
    pub fn detach(&self) -> Self {
        Self {
            data: self.data.clone(),
            grad: None,
            requires_grad: false,
            compute_node: None,
            is_leaf: true,
        }
    }

    /// Returns a clone of the underlying array.
    pub fn to_array(&self) -> Array {
        self.data.clone()
    }

    /// Returns the id of this tensor's compute node, if it has one.
    pub fn node_id(&self) -> Option<NodeId> {
        self.compute_node.as_ref().map(|node| node.id)
    }

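    /// Runs reverse-mode autodiff from this tensor: seeds its gradient with
    /// ones, then walks the graph in reverse topological order, calling each
    /// node's backward function and accumulating into the inputs' gradients.
    /// Because clones share their gradient cell, the original leaf observes
    /// the accumulated result.
    ///
    /// Illustrative sketch; assumes `Array::new(shape, data)` and the
    /// `transpose` op defined below.
    ///
    /// ```ignore
    /// let x = Tensor::new(Array::new(vec![2, 3], vec![1.0; 6]), true);
    /// let y = x.transpose()?;
    /// y.backward()?;
    /// // Transpose only permutes elements, so every input element gets grad 1.
    /// assert_eq!(x.gradient().unwrap().data, vec![1.0; 6]);
    /// ```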
    pub fn backward(&self) -> Result<()> {
        if !self.requires_grad {
            return Err(anyhow::anyhow!("Tensor does not require gradients"));
        }

        // Seed the output gradient with ones (d(out)/d(out) = 1).
        let initial_grad = Array::new(
            self.data.shape.clone(),
            vec![1.0; self.data.data.len()],
        );
        self.set_gradient(initial_grad)?;

        // Topologically sort the graph, then reverse it so every node is
        // processed only after the node that consumed its output.
        let mut topo_order = Vec::new();
        let mut visited = std::collections::HashSet::new();
        self.build_topo(self, &mut topo_order, &mut visited);
        topo_order.reverse();

        for tensor in topo_order {
            if let Some(node) = &tensor.compute_node {
                if let Some(backward_fn) = &node.backward_fn {
                    let grad_output = tensor.gradient()
                        .ok_or_else(|| anyhow::anyhow!("Missing gradient"))?;

                    let grad_inputs = backward_fn(&grad_output, &node.inputs, tensor)?;

                    // Accumulate (rather than overwrite) so tensors consumed by
                    // several operations sum their gradient contributions.
                    for (input, grad_input) in node.inputs.iter().zip(grad_inputs.iter()) {
                        if input.requires_grad {
                            input.accumulate_gradient(grad_input)?;
                        }
                    }
                }
            }
        }

        Ok(())
    }

    /// Depth-first post-order traversal of the compute graph. Each tensor is
    /// pushed after all of its inputs, yielding a topological order.
    fn build_topo<'a>(
        &'a self,
        node: &'a Tensor,
        topo_order: &mut Vec<&'a Tensor>,
        visited: &mut std::collections::HashSet<NodeId>,
    ) {
        // Only tensors with compute nodes carry ids to deduplicate on; leaves
        // have no backward function, so revisiting them is harmless.
        if let Some(id) = node.node_id() {
            if visited.contains(&id) {
                return;
            }
            visited.insert(id);
        }

        if let Some(compute_node) = &node.compute_node {
            for input in &compute_node.inputs {
                self.build_topo(input, topo_order, visited);
            }
        }

        topo_order.push(node);
    }

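    /// Transposes the tensor, registering a backward node when gradients are
    /// tracked so `backward` can route gradients through the permutation.
    ///
    /// Illustrative sketch; assumes `Array::new(shape, data)` and that
    /// `ops::transpose` with `None` performs a full transpose.
    ///
    /// ```ignore
    /// let x = Tensor::new(Array::new(vec![2, 3], vec![1.0; 6]), false);
    /// let y = x.transpose()?;
    /// assert_eq!(y.shape(), &[3, 2]);
    /// ```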
    pub fn transpose(&self) -> Result<Tensor> {
        use crate::ops::transpose as op_transpose;
        let transposed_data = op_transpose(&self.data, None)?;

        if self.requires_grad {
            use crate::autograd::{backward::transpose_backward, OpKind};
            let backward_fn = Box::new(transpose_backward);
            let node = ComputeNode::new(
                OpKind::Transpose,
                vec![self.clone()],
                Some(backward_fn),
            );
            Ok(Tensor::from_operation(transposed_data, node, true))
        } else {
            Ok(Tensor::new(transposed_data, false))
        }
    }
}

impl std::fmt::Debug for Tensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Tensor")
            .field("shape", &self.data.shape)
            .field("requires_grad", &self.requires_grad)
            .field("is_leaf", &self.is_leaf)
            .field("has_grad", &self.grad.is_some())
            .field("node_id", &self.node_id())
            .finish()
    }
}

impl std::fmt::Display for Tensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Tensor(shape={:?}, requires_grad={}, data=[",
            self.data.shape, self.requires_grad)?;

        // Print at most the first five elements, then an ellipsis.
        let n = self.data.data.len().min(5);
        for i in 0..n {
            write!(f, "{:.4}", self.data.data[i])?;
            if i < n - 1 {
                write!(f, ", ")?;
            }
        }
        if self.data.data.len() > 5 {
            write!(f, ", ...")?;
        }
        write!(f, "])")
    }
}
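
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal sketch of the leaf-tensor API. Assumes `Array::new(shape, data)`
    // and field access as used throughout this module.
    #[test]
    fn leaf_gradient_accumulates_and_resets() {
        let t = Tensor::new(Array::new(vec![2], vec![1.0, 2.0]), true);
        assert!(t.is_leaf);
        assert_eq!(t.gradient().unwrap().data, vec![0.0, 0.0]);

        let g = Array::new(vec![2], vec![0.5, 0.5]);
        t.accumulate_gradient(&g).unwrap();
        t.accumulate_gradient(&g).unwrap();
        assert_eq!(t.gradient().unwrap().data, vec![1.0, 1.0]);

        t.zero_grad();
        assert_eq!(t.gradient().unwrap().data, vec![0.0, 0.0]);
    }

    #[test]
    fn detach_drops_autograd_state() {
        let t = Tensor::new(Array::new(vec![1], vec![3.0]), true);
        let d = t.detach();
        assert!(!d.requires_grad);
        assert!(d.gradient().is_none());
        assert_eq!(d.item(), 3.0);
    }
}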