// god_graph/transformer/autograd/tensor.rs

use std::sync::Arc;

use crate::tensor::DenseTensor;
use crate::tensor::traits::{TensorOps, TensorBase};
use super::compute_graph::{ComputeGraph, OpId, TensorId, OpType};
/// A tensor wrapper that participates in reverse-mode automatic
/// differentiation by recording its producing operations into a
/// [`ComputeGraph`].
#[derive(Debug, Clone)]
pub struct DifferentiableTensor {
    // Forward-pass values.
    data: DenseTensor,
    // Gradient populated by `backward`; `None` until then or after `zero_grad`.
    grad: Option<DenseTensor>,
    // Id of the op that produced this tensor, if any (set via `set_op_id`).
    op_id: Option<OpId>,
    // Id under which `data` is stored in the compute graph.
    // NOTE(review): detached tensors reuse `TensorId(0)` as a placeholder —
    // confirm the graph never assigns id 0 to a real tensor.
    tensor_id: TensorId,
    // Whether gradients should flow through this tensor.
    requires_grad: bool,
    // Snapshot of the graph taken at construction time; currently never read
    // (hence the `dead_code` allow).
    #[allow(dead_code)]
    graph: Option<Arc<ComputeGraph>>,
}
25
26impl DifferentiableTensor {
27 pub fn new(data: DenseTensor, requires_grad: bool) -> Self {
29 Self {
30 data,
31 grad: None,
32 op_id: None,
33 tensor_id: TensorId(0),
34 requires_grad,
35 graph: None,
36 }
37 }
38
39 pub fn with_graph(data: DenseTensor, requires_grad: bool, graph: &mut ComputeGraph) -> Self {
41 let tensor_id = graph.next_tensor_id();
42 graph.store_value(tensor_id, data.clone());
43
44 Self {
45 data,
46 grad: None,
47 op_id: None,
48 tensor_id,
49 requires_grad,
50 graph: Some(Arc::new(graph.clone())),
51 }
52 }
53
54 pub fn data(&self) -> &DenseTensor {
56 &self.data
57 }
58
59 pub fn data_mut(&mut self) -> &mut DenseTensor {
61 &mut self.data
62 }
63
64 pub fn grad(&self) -> Option<&DenseTensor> {
66 self.grad.as_ref()
67 }
68
69 pub fn grad_mut(&mut self) -> Option<&mut DenseTensor> {
71 self.grad.as_mut()
72 }
73
74 pub fn set_grad(&mut self, grad: DenseTensor) {
76 self.grad = Some(grad);
77 }
78
79 pub fn zero_grad(&mut self) {
81 self.grad = None;
82 }
83
84 pub fn requires_grad(&self) -> bool {
86 self.requires_grad
87 }
88
89 pub fn tensor_id(&self) -> TensorId {
91 self.tensor_id
92 }
93
94 pub fn op_id(&self) -> Option<OpId> {
96 self.op_id
97 }
98
99 #[allow(dead_code)]
101 pub(crate) fn set_op_id(&mut self, op_id: OpId) {
102 self.op_id = Some(op_id);
103 }
104
105 pub fn shape(&self) -> &[usize] {
107 self.data.shape()
108 }
109
110 pub fn matmul(&self, other: &DifferentiableTensor, graph: &mut ComputeGraph) -> DifferentiableTensor {
112 let output_id = graph.next_tensor_id();
113 let result_data = self.data.matmul(&other.data);
114
115 graph.record_op(OpType::MatMul, &[self.tensor_id, other.tensor_id], output_id);
116 graph.store_value(output_id, result_data.clone());
117
118 DifferentiableTensor {
119 data: result_data,
120 grad: None,
121 op_id: None,
122 tensor_id: output_id,
123 requires_grad: self.requires_grad || other.requires_grad,
124 graph: Some(Arc::new(graph.clone())),
125 }
126 }
127
128 pub fn add(&self, other: &DifferentiableTensor, graph: &mut ComputeGraph) -> DifferentiableTensor {
130 let output_id = graph.next_tensor_id();
131 let result_data = self.data.add(&other.data);
132
133 graph.record_op(OpType::Add, &[self.tensor_id, other.tensor_id], output_id);
134 graph.store_value(output_id, result_data.clone());
135
136 DifferentiableTensor {
137 data: result_data,
138 grad: None,
139 op_id: None,
140 tensor_id: output_id,
141 requires_grad: self.requires_grad || other.requires_grad,
142 graph: Some(Arc::new(graph.clone())),
143 }
144 }
145
146 pub fn sub(&self, other: &DifferentiableTensor, graph: &mut ComputeGraph) -> DifferentiableTensor {
148 let output_id = graph.next_tensor_id();
149 let result_data = self.data.sub(&other.data);
150
151 graph.record_op(OpType::Sub, &[self.tensor_id, other.tensor_id], output_id);
152 graph.store_value(output_id, result_data.clone());
153
154 DifferentiableTensor {
155 data: result_data,
156 grad: None,
157 op_id: None,
158 tensor_id: output_id,
159 requires_grad: self.requires_grad || other.requires_grad,
160 graph: Some(Arc::new(graph.clone())),
161 }
162 }
163
164 pub fn mul(&self, other: &DifferentiableTensor, graph: &mut ComputeGraph) -> DifferentiableTensor {
166 let output_id = graph.next_tensor_id();
167 let result_data = self.data.mul(&other.data);
168
169 graph.record_op(OpType::Mul, &[self.tensor_id, other.tensor_id], output_id);
170 graph.store_value(output_id, result_data.clone());
171
172 DifferentiableTensor {
173 data: result_data,
174 grad: None,
175 op_id: None,
176 tensor_id: output_id,
177 requires_grad: self.requires_grad || other.requires_grad,
178 graph: Some(Arc::new(graph.clone())),
179 }
180 }
181
182 pub fn relu(&self, graph: &mut ComputeGraph) -> DifferentiableTensor {
184 let output_id = graph.next_tensor_id();
185 let result_data = self.data.relu();
186
187 graph.record_op(OpType::ReLU, &[self.tensor_id], output_id);
188 graph.store_value(output_id, result_data.clone());
189
190 DifferentiableTensor {
191 data: result_data,
192 grad: None,
193 op_id: None,
194 tensor_id: output_id,
195 requires_grad: self.requires_grad,
196 graph: Some(Arc::new(graph.clone())),
197 }
198 }
199
200 pub fn gelu(&self, graph: &mut ComputeGraph) -> DifferentiableTensor {
202 let output_id = graph.next_tensor_id();
203 let result_data = self.data.gelu();
204
205 graph.record_op(OpType::GELU, &[self.tensor_id], output_id);
206 graph.store_value(output_id, result_data.clone());
207
208 DifferentiableTensor {
209 data: result_data,
210 grad: None,
211 op_id: None,
212 tensor_id: output_id,
213 requires_grad: self.requires_grad,
214 graph: Some(Arc::new(graph.clone())),
215 }
216 }
217
218 pub fn softmax(&self, dim: isize, graph: &mut ComputeGraph) -> DifferentiableTensor {
220 let output_id = graph.next_tensor_id();
221 let result_data = self.data.softmax(dim);
222
223 graph.record_op(OpType::Softmax, &[self.tensor_id], output_id);
224 graph.store_value(output_id, result_data.clone());
225
226 DifferentiableTensor {
227 data: result_data,
228 grad: None,
229 op_id: None,
230 tensor_id: output_id,
231 requires_grad: self.requires_grad,
232 graph: Some(Arc::new(graph.clone())),
233 }
234 }
235
236 pub fn transpose(&self, graph: &mut ComputeGraph) -> DifferentiableTensor {
238 let output_id = graph.next_tensor_id();
239 let result_data = self.data.transpose(None);
240
241 graph.record_op(OpType::Transpose, &[self.tensor_id], output_id);
242 graph.store_value(output_id, result_data.clone());
243
244 DifferentiableTensor {
245 data: result_data,
246 grad: None,
247 op_id: None,
248 tensor_id: output_id,
249 requires_grad: self.requires_grad,
250 graph: Some(Arc::new(graph.clone())),
251 }
252 }
253
254 pub fn detach(&self) -> DifferentiableTensor {
256 DifferentiableTensor {
257 data: self.data.clone(),
258 grad: None,
259 op_id: None,
260 tensor_id: TensorId(0),
261 requires_grad: false,
262 graph: None,
263 }
264 }
265
266 pub fn backward(&mut self, graph: &mut ComputeGraph) {
268 if self.requires_grad {
269 graph.backward(self.tensor_id);
270
271 if let Some(grad) = graph.get_gradient(self.tensor_id).cloned() {
273 self.grad = Some(grad);
274 }
275 }
276 }
277}
278
279impl From<DenseTensor> for DifferentiableTensor {
281 fn from(tensor: DenseTensor) -> Self {
282 Self::new(tensor, false)
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::DenseTensor;

    // A freshly constructed tensor keeps its data, honors the requires_grad
    // flag, and starts with no gradient.
    #[test]
    fn test_differentiable_tensor_creation() {
        let payload = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]);
        let wrapped = DifferentiableTensor::new(payload.clone(), true);

        assert!(wrapped.grad().is_none());
        assert!(wrapped.requires_grad());
        assert_eq!(wrapped.data(), &payload);
    }

    // matmul of (1x2) by (2x2) yields a (1x2) result that propagates
    // gradient tracking from its inputs.
    #[test]
    fn test_differentiable_matmul() {
        let mut graph = ComputeGraph::new();

        let lhs = DifferentiableTensor::with_graph(
            DenseTensor::new(vec![1.0, 2.0], vec![1, 2]),
            true,
            &mut graph,
        );
        let rhs = DifferentiableTensor::with_graph(
            DenseTensor::new(vec![0.1, 0.2, 0.3, 0.4], vec![2, 2]),
            true,
            &mut graph,
        );

        let product = lhs.matmul(&rhs, &mut graph);

        assert_eq!(product.shape(), &[1, 2]);
        assert!(product.requires_grad());
    }

    // ReLU is element-wise, so the output shape matches the input shape.
    #[test]
    fn test_differentiable_relu() {
        let mut graph = ComputeGraph::new();
        let input = DifferentiableTensor::with_graph(
            DenseTensor::new(vec![-1.0, 2.0, -3.0, 4.0], vec![1, 4]),
            true,
            &mut graph,
        );

        assert_eq!(input.relu(&mut graph).shape(), &[1, 4]);
    }
}