god_graph/transformer/autograd/tensor.rs

//! Differentiable tensor wrapper with gradient support

use std::sync::Arc;
use crate::tensor::DenseTensor;
use crate::tensor::traits::{TensorOps, TensorBase};
use super::compute_graph::{ComputeGraph, OpId, TensorId, OpType};

/// A differentiable tensor that tracks gradients and compute graph information
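///
/// Illustrative usage (a sketch; constructors follow the unit tests at the
/// bottom of this file):
///
/// ```ignore
/// let data = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]);
/// let t = DifferentiableTensor::new(data, true);
/// assert!(t.requires_grad());
/// assert!(t.grad().is_none());
/// ```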
#[derive(Debug, Clone)]
pub struct DifferentiableTensor {
    /// The underlying tensor data
    data: DenseTensor,
    /// Gradient of this tensor (computed during backward pass)
    grad: Option<DenseTensor>,
    /// ID of the operation that produced this tensor
    op_id: Option<OpId>,
    /// Tensor ID in the compute graph
    tensor_id: TensorId,
    /// Whether this tensor requires gradient
    requires_grad: bool,
    /// Reference to the compute graph. Stored as an owned `Arc` snapshot;
    /// a `Weak` reference would be needed to truly avoid reference cycles.
    #[allow(dead_code)]
    graph: Option<Arc<ComputeGraph>>,
}

impl DifferentiableTensor {
    /// Create a new differentiable tensor
    pub fn new(data: DenseTensor, requires_grad: bool) -> Self {
        Self {
            data,
            grad: None,
            op_id: None,
            tensor_id: TensorId(0),
            requires_grad,
            graph: None,
        }
    }

    /// Create a differentiable tensor with compute graph tracking
    pub fn with_graph(data: DenseTensor, requires_grad: bool, graph: &mut ComputeGraph) -> Self {
        let tensor_id = graph.next_tensor_id();
        graph.store_value(tensor_id, data.clone());

        Self {
            data,
            grad: None,
            op_id: None,
            tensor_id,
            requires_grad,
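            // Note: this clones the entire graph into a fresh `Arc`, so the
            // field holds a snapshot rather than a shared handle.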
            graph: Some(Arc::new(graph.clone())),
        }
    }

    /// Get the underlying data
    pub fn data(&self) -> &DenseTensor {
        &self.data
    }

    /// Get mutable reference to data
    pub fn data_mut(&mut self) -> &mut DenseTensor {
        &mut self.data
    }

    /// Get the gradient (if computed)
    pub fn grad(&self) -> Option<&DenseTensor> {
        self.grad.as_ref()
    }

    /// Get mutable reference to gradient
    pub fn grad_mut(&mut self) -> Option<&mut DenseTensor> {
        self.grad.as_mut()
    }

    /// Set the gradient
    pub fn set_grad(&mut self, grad: DenseTensor) {
        self.grad = Some(grad);
    }

    /// Clear the gradient
    pub fn zero_grad(&mut self) {
        self.grad = None;
    }

    /// Check if this tensor requires gradient
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Get the tensor ID
    pub fn tensor_id(&self) -> TensorId {
        self.tensor_id
    }

    /// Get the operation ID that produced this tensor
    pub fn op_id(&self) -> Option<OpId> {
        self.op_id
    }

    /// Set operation ID (called by compute graph)
    #[allow(dead_code)]
    pub(crate) fn set_op_id(&mut self, op_id: OpId) {
        self.op_id = Some(op_id);
    }

    /// Get shape of the tensor
    pub fn shape(&self) -> &[usize] {
        self.data.shape()
    }

    /// Matrix multiplication
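    ///
    /// Records a `MatMul` op on `graph` and returns a new tracked tensor
    /// whose `requires_grad` is the OR of the operands' flags.
    /// Sketch (shapes as in the unit tests below):
    ///
    /// ```ignore
    /// let out = x_diff.matmul(&w_diff, &mut graph); // [1, 2] x [2, 2] -> [1, 2]
    /// assert_eq!(out.shape(), &[1, 2]);
    /// ```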
    pub fn matmul(&self, other: &DifferentiableTensor, graph: &mut ComputeGraph) -> DifferentiableTensor {
        let output_id = graph.next_tensor_id();
        let result_data = self.data.matmul(&other.data);

        graph.record_op(OpType::MatMul, &[self.tensor_id, other.tensor_id], output_id);
        graph.store_value(output_id, result_data.clone());

        DifferentiableTensor {
            data: result_data,
            grad: None,
            op_id: None,
            tensor_id: output_id,
            requires_grad: self.requires_grad || other.requires_grad,
            graph: Some(Arc::new(graph.clone())),
        }
    }

    /// Element-wise addition
    pub fn add(&self, other: &DifferentiableTensor, graph: &mut ComputeGraph) -> DifferentiableTensor {
        let output_id = graph.next_tensor_id();
        let result_data = self.data.add(&other.data);

        graph.record_op(OpType::Add, &[self.tensor_id, other.tensor_id], output_id);
        graph.store_value(output_id, result_data.clone());

        DifferentiableTensor {
            data: result_data,
            grad: None,
            op_id: None,
            tensor_id: output_id,
            requires_grad: self.requires_grad || other.requires_grad,
            graph: Some(Arc::new(graph.clone())),
        }
    }

    /// Element-wise subtraction
    pub fn sub(&self, other: &DifferentiableTensor, graph: &mut ComputeGraph) -> DifferentiableTensor {
        let output_id = graph.next_tensor_id();
        let result_data = self.data.sub(&other.data);

        graph.record_op(OpType::Sub, &[self.tensor_id, other.tensor_id], output_id);
        graph.store_value(output_id, result_data.clone());

        DifferentiableTensor {
            data: result_data,
            grad: None,
            op_id: None,
            tensor_id: output_id,
            requires_grad: self.requires_grad || other.requires_grad,
            graph: Some(Arc::new(graph.clone())),
        }
    }

    /// Element-wise multiplication
    pub fn mul(&self, other: &DifferentiableTensor, graph: &mut ComputeGraph) -> DifferentiableTensor {
        let output_id = graph.next_tensor_id();
        let result_data = self.data.mul(&other.data);

        graph.record_op(OpType::Mul, &[self.tensor_id, other.tensor_id], output_id);
        graph.store_value(output_id, result_data.clone());

        DifferentiableTensor {
            data: result_data,
            grad: None,
            op_id: None,
            tensor_id: output_id,
            requires_grad: self.requires_grad || other.requires_grad,
            graph: Some(Arc::new(graph.clone())),
        }
    }

    /// ReLU activation
    pub fn relu(&self, graph: &mut ComputeGraph) -> DifferentiableTensor {
        let output_id = graph.next_tensor_id();
        let result_data = self.data.relu();

        graph.record_op(OpType::ReLU, &[self.tensor_id], output_id);
        graph.store_value(output_id, result_data.clone());

        DifferentiableTensor {
            data: result_data,
            grad: None,
            op_id: None,
            tensor_id: output_id,
            requires_grad: self.requires_grad,
            graph: Some(Arc::new(graph.clone())),
        }
    }

    /// GELU activation
    pub fn gelu(&self, graph: &mut ComputeGraph) -> DifferentiableTensor {
        let output_id = graph.next_tensor_id();
        let result_data = self.data.gelu();

        graph.record_op(OpType::GELU, &[self.tensor_id], output_id);
        graph.store_value(output_id, result_data.clone());

        DifferentiableTensor {
            data: result_data,
            grad: None,
            op_id: None,
            tensor_id: output_id,
            requires_grad: self.requires_grad,
            graph: Some(Arc::new(graph.clone())),
        }
    }

    /// Softmax activation
    pub fn softmax(&self, dim: isize, graph: &mut ComputeGraph) -> DifferentiableTensor {
        let output_id = graph.next_tensor_id();
        let result_data = self.data.softmax(dim);

        graph.record_op(OpType::Softmax, &[self.tensor_id], output_id);
        graph.store_value(output_id, result_data.clone());

        DifferentiableTensor {
            data: result_data,
            grad: None,
            op_id: None,
            tensor_id: output_id,
            requires_grad: self.requires_grad,
            graph: Some(Arc::new(graph.clone())),
        }
    }

    /// Transpose
    pub fn transpose(&self, graph: &mut ComputeGraph) -> DifferentiableTensor {
        let output_id = graph.next_tensor_id();
        let result_data = self.data.transpose(None);

        graph.record_op(OpType::Transpose, &[self.tensor_id], output_id);
        graph.store_value(output_id, result_data.clone());

        DifferentiableTensor {
            data: result_data,
            grad: None,
            op_id: None,
            tensor_id: output_id,
            requires_grad: self.requires_grad,
            graph: Some(Arc::new(graph.clone())),
        }
    }

    /// Detach from compute graph (for inference)
    pub fn detach(&self) -> DifferentiableTensor {
        DifferentiableTensor {
            data: self.data.clone(),
            grad: None,
            op_id: None,
            tensor_id: TensorId(0),
            requires_grad: false,
            graph: None,
        }
    }

    /// Perform backward pass
    pub fn backward(&mut self, graph: &mut ComputeGraph) {
        if self.requires_grad {
            graph.backward(self.tensor_id);

            // Collect this tensor's gradient from the graph, if one was produced
            if let Some(grad) = graph.get_gradient(self.tensor_id).cloned() {
                self.grad = Some(grad);
            }
        }
    }
}

/// Convert from DenseTensor
impl From<DenseTensor> for DifferentiableTensor {
    fn from(tensor: DenseTensor) -> Self {
        Self::new(tensor, false)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::DenseTensor;

    #[test]
    fn test_differentiable_tensor_creation() {
        let data = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]);
        let tensor = DifferentiableTensor::new(data.clone(), true);

        assert!(tensor.requires_grad());
        assert_eq!(tensor.data(), &data);
        assert!(tensor.grad().is_none());
    }
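
    // `detach` and `From<DenseTensor>` both produce untracked tensors:
    // no gradient, no graph, and `requires_grad == false`. This is a sketch
    // relying only on constructors already used in this file.
    #[test]
    fn test_detach_and_from() {
        let data = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]);

        let tracked = DifferentiableTensor::new(data.clone(), true);
        let detached = tracked.detach();
        assert!(!detached.requires_grad());
        assert!(detached.grad().is_none());

        let converted: DifferentiableTensor = data.clone().into();
        assert!(!converted.requires_grad());
        assert_eq!(converted.data(), &data);
    }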

    #[test]
    fn test_differentiable_matmul() {
        let mut graph = ComputeGraph::new();

        let x = DenseTensor::new(vec![1.0, 2.0], vec![1, 2]);
        let w = DenseTensor::new(vec![0.1, 0.2, 0.3, 0.4], vec![2, 2]);

        let x_diff = DifferentiableTensor::with_graph(x, true, &mut graph);
        let w_diff = DifferentiableTensor::with_graph(w, true, &mut graph);

        let out = x_diff.matmul(&w_diff, &mut graph);

        assert!(out.requires_grad());
        assert_eq!(out.shape(), &[1, 2]);
    }
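
    // Element-wise `add` mirrors the matmul test above: the output keeps the
    // operand shape and `requires_grad` propagates as the logical OR of the
    // operands' flags.
    #[test]
    fn test_differentiable_add() {
        let mut graph = ComputeGraph::new();

        let a = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]);
        let b = DenseTensor::new(vec![4.0, 5.0, 6.0], vec![1, 3]);

        let a_diff = DifferentiableTensor::with_graph(a, true, &mut graph);
        let b_diff = DifferentiableTensor::with_graph(b, false, &mut graph);

        let out = a_diff.add(&b_diff, &mut graph);

        assert!(out.requires_grad());
        assert_eq!(out.shape(), &[1, 3]);
    }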

    #[test]
    fn test_differentiable_relu() {
        let mut graph = ComputeGraph::new();

        let data = DenseTensor::new(vec![-1.0, 2.0, -3.0, 4.0], vec![1, 4]);
        let tensor = DifferentiableTensor::with_graph(data, true, &mut graph);

        let out = tensor.relu(&mut graph);

        // ReLU zeroes out negative values; here we only assert that the shape is preserved
        assert_eq!(out.shape(), &[1, 4]);
    }
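
    // Smoke test for the `backward` entry point. This exercises only the
    // plumbing in `DifferentiableTensor::backward`; whether a gradient is
    // actually populated depends on `ComputeGraph::backward` seeding the
    // output gradient, which is assumed rather than verified here.
    #[test]
    fn test_backward_smoke() {
        let mut graph = ComputeGraph::new();

        let x = DenseTensor::new(vec![1.0, 2.0], vec![1, 2]);
        let w = DenseTensor::new(vec![0.1, 0.2, 0.3, 0.4], vec![2, 2]);

        let x_diff = DifferentiableTensor::with_graph(x, true, &mut graph);
        let w_diff = DifferentiableTensor::with_graph(w, true, &mut graph);

        let mut out = x_diff.matmul(&w_diff, &mut graph);
        out.backward(&mut graph);

        // If the graph produced a gradient for the output, its shape must
        // match the output's shape.
        if let Some(g) = out.grad() {
            assert_eq!(g.shape(), out.data().shape());
        }
    }
}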
329}