Skip to main content

god_graph/transformer/autograd/
compute_graph.rs

1//! Compute graph for tracking operations during forward pass
2
3use std::collections::HashMap;
4use crate::tensor::DenseTensor;
5use crate::tensor::traits::{TensorBase, TensorOps};
6
/// Identifier of a single operation node inside a compute graph.
///
/// This is a thin newtype over `usize`; the wrapped index is public.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct OpId(pub usize);
10
/// Identifier of a tensor (an operation input or output) tracked by a compute graph.
///
/// This is a thin newtype over `usize`; the wrapped index is public.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TensorId(pub usize);
14
15/// Reference to an operation
16#[derive(Debug, Clone)]
17pub struct OpRef {
18    /// Operation ID
19    pub id: OpId,
20    /// Operation type
21    pub op_type: OpType,
22    /// Input tensor IDs
23    pub inputs: Vec<TensorId>,
24    /// Output tensor ID
25    pub output: TensorId,
26}
27
/// Kind of operation recorded in the compute graph.
///
/// Every variant is a plain tag without payload, so values can be compared
/// with `==` and used as keys in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OpType {
    /// Element-wise addition
    Add,
    /// Element-wise subtraction
    Sub,
    /// Element-wise multiplication
    Mul,
    /// Element-wise division
    Div,
    /// Matrix multiplication
    MatMul,
    /// Transpose operation
    Transpose,
    /// Sum reduction
    Sum,
    /// Mean reduction
    Mean,
    /// ReLU activation
    ReLU,
    /// GELU activation
    GELU,
    /// Sigmoid activation
    Sigmoid,
    /// Tanh activation
    Tanh,
    /// SiLU/Swish activation
    SiLU,
    /// Softmax activation
    Softmax,
    /// Layer normalization
    LayerNorm,
    /// RMS normalization
    RMSNorm,
    /// Linear/fully connected layer
    Linear,
    /// Embedding lookup
    Embedding,
    /// Rotary position embedding
    RoPE,
    /// Scaled dot-product attention
    ScaledDotProduct,
}
72
73/// Operation node in the compute graph
74#[derive(Debug, Clone)]
75pub struct OpNode {
76    /// Operation ID
77    pub id: OpId,
78    /// Operation type
79    pub op_type: OpType,
80    /// Input tensor IDs
81    pub inputs: Vec<TensorId>,
82    /// Output tensor ID
83    pub output: TensorId,
84}
85
86/// Data edge representing tensor dependencies
87#[derive(Debug, Clone)]
88pub struct DataEdge {
89    /// Source operation ID
90    pub from: OpId,
91    /// Destination operation ID
92    pub to: OpId,
93    /// Tensor ID
94    pub tensor_id: TensorId,
95}
96
97/// Checkpoint for gradient checkpointing
98#[derive(Debug, Clone)]
99pub struct Checkpoint {
100    /// Tensor values
101    pub tensors: HashMap<TensorId, DenseTensor>,
102}
103
104/// Compute graph for tracking operations
105#[derive(Debug, Default, Clone)]
106pub struct ComputeGraph {
107    /// Operation nodes
108    nodes: Vec<OpNode>,
109    /// Data flow edges
110    edges: Vec<DataEdge>,
111    /// Gradient storage
112    gradients: HashMap<TensorId, DenseTensor>,
113    /// Tensor values (for forward pass)
114    values: HashMap<TensorId, DenseTensor>,
115    /// Checkpoint for memory optimization
116    checkpoint: Option<Checkpoint>,
117    /// Next operation ID
118    next_op_id: usize,
119    /// Next tensor ID
120    next_tensor_id: usize,
121    /// Whether to record operations (disable during eval mode)
122    recording: bool,
123}
124
125impl ComputeGraph {
126    /// Create a new compute graph
127    pub fn new() -> Self {
128        Self {
129            nodes: Vec::new(),
130            edges: Vec::new(),
131            gradients: HashMap::new(),
132            values: HashMap::new(),
133            checkpoint: None,
134            next_op_id: 0,
135            next_tensor_id: 0,
136            recording: true,
137        }
138    }
139
140    /// Generate a new operation ID
141    pub fn next_op_id(&mut self) -> OpId {
142        let id = OpId(self.next_op_id);
143        self.next_op_id += 1;
144        id
145    }
146
147    /// Generate a new tensor ID
148    pub fn next_tensor_id(&mut self) -> TensorId {
149        let id = TensorId(self.next_tensor_id);
150        self.next_tensor_id += 1;
151        id
152    }
153
154    /// Record an operation in the compute graph
155    pub fn record_op(&mut self, op_type: OpType, inputs: &[TensorId], output: TensorId) {
156        if !self.recording {
157            return;
158        }
159
160        let op_id = self.next_op_id();
161        let node = OpNode {
162            id: op_id,
163            op_type: op_type.clone(),
164            inputs: inputs.to_vec(),
165            output,
166        };
167        self.nodes.push(node);
168
169        // Create edges from input producers to this operation
170        for &input_id in inputs {
171            // Find the operation that produced this input
172            if let Some(producer_op) = self.nodes.iter().rev().find(|n| {
173                n.output == input_id
174            }) {
175                let edge = DataEdge {
176                    from: producer_op.id,
177                    to: op_id,
178                    tensor_id: input_id,
179                };
180                self.edges.push(edge);
181            }
182        }
183    }
184
185    /// Store a tensor value
186    pub fn store_value(&mut self, tensor_id: TensorId, value: DenseTensor) {
187        self.values.insert(tensor_id, value);
188    }
189
190    /// Get a tensor value
191    pub fn get_value(&self, tensor_id: TensorId) -> Option<&DenseTensor> {
192        self.values.get(&tensor_id)
193    }
194
195    /// Get a mutable reference to a tensor value
196    pub fn get_value_mut(&mut self, tensor_id: TensorId) -> Option<&mut DenseTensor> {
197        self.values.get_mut(&tensor_id)
198    }
199
200    /// Store a gradient
201    pub fn store_gradient(&mut self, tensor_id: TensorId, gradient: DenseTensor) {
202        self.gradients.insert(tensor_id, gradient);
203    }
204
205    /// Get a gradient
206    pub fn get_gradient(&self, tensor_id: TensorId) -> Option<&DenseTensor> {
207        self.gradients.get(&tensor_id)
208    }
209
210    /// Perform backward pass to compute gradients
211    ///
212    /// # Arguments
213    /// * `loss` - The tensor ID of the loss value
214    ///
215    /// # Returns
216    /// A HashMap mapping tensor IDs to their computed gradients
217    pub fn backward(&mut self, loss: TensorId) -> HashMap<TensorId, DenseTensor> {
218        // Initialize gradient of loss as 1.0
219        if let Some(loss_tensor) = self.values.get(&loss) {
220            let shape = loss_tensor.shape().to_vec();
221            let ones = DenseTensor::ones(shape);
222            self.gradients.insert(loss, ones);
223        }
224
225        // Get topological order of operations (reverse)
226        let topo_order = self.topological_sort();
227
228        // Backpropagate in reverse topological order
229        for op_id in topo_order.into_iter().rev() {
230            // Clone node info to avoid borrow checker issues
231            let (node_op_type, node_inputs, node_output) = if let Some(node) = self.nodes.iter().find(|n| n.id == op_id) {
232                (node.op_type.clone(), node.inputs.clone(), node.output)
233            } else {
234                continue;
235            };
236
237            let grad_output = self.gradients.get(&node_output).cloned();
238
239            if let Some(grad) = grad_output {
240                // Compute gradients for inputs based on operation type
241                let input_grads = self.compute_gradients(&node_op_type, &node_inputs, &grad);
242
243                // Accumulate gradients for inputs
244                for (i, &input_id) in node_inputs.iter().enumerate() {
245                    if let Some(input_grad) = input_grads.get(&i) {
246                        self.accumulate_gradient(input_id, input_grad.clone());
247                    }
248                }
249            }
250        }
251
252        self.gradients.clone()
253    }
254
255    /// Compute gradients for a specific operation
256    fn compute_gradients(
257        &self,
258        op_type: &OpType,
259        inputs: &[TensorId],
260        grad_output: &DenseTensor,
261    ) -> HashMap<usize, DenseTensor> {
262        let mut grads = HashMap::new();
263
264        match op_type {
265            OpType::Add => {
266                // d(x+y)/dx = 1, d(x+y)/dy = 1
267                for (i, _) in inputs.iter().enumerate() {
268                    grads.insert(i, grad_output.clone());
269                }
270            }
271            OpType::Sub => {
272                // d(x-y)/dx = 1, d(x-y)/dy = -1
273                for (i, _) in inputs.iter().enumerate() {
274                    if i == 0 {
275                        grads.insert(i, grad_output.clone());
276                    } else {
277                        grads.insert(i, grad_output.neg());
278                    }
279                }
280            }
281            OpType::Mul => {
282                // Element-wise multiplication: d(x*y)/dx = y, d(x*y)/dy = x
283                if inputs.len() >= 2 {
284                    if let (Some(x), Some(y)) = (
285                        self.values.get(&inputs[0]),
286                        self.values.get(&inputs[1]),
287                    ) {
288                        grads.insert(0, grad_output.mul(y));
289                        grads.insert(1, grad_output.mul(x));
290                    }
291                }
292            }
293            OpType::MatMul => {
294                // Matrix multiplication: d(X@W)/dX = d_out @ W.T, d(X@W)/dW = X.T @ d_out
295                if inputs.len() >= 2 {
296                    if let (Some(x), Some(w)) = (
297                        self.values.get(&inputs[0]),
298                        self.values.get(&inputs[1]),
299                    ) {
300                        // Gradient w.r.t. input
301                        let w_t = w.transpose(None);
302                        let grad_x = grad_output.matmul(&w_t);
303                        grads.insert(0, grad_x);
304
305                        // Gradient w.r.t. weights
306                        let x_t = x.transpose(None);
307                        let grad_w = x_t.matmul(grad_output);
308                        grads.insert(1, grad_w);
309                    }
310                }
311            }
312            OpType::ReLU => {
313                // ReLU gradient: 1 if x > 0, else 0
314                if let Some(x) = inputs.first().and_then(|id| self.values.get(id)) {
315                    let mask = x.gt(0.0);
316                    let grad = grad_output.mul(&mask);
317                    grads.insert(0, grad);
318                }
319            }
320            OpType::GELU => {
321                // GELU gradient approximation
322                if let Some(x) = inputs.first().and_then(|id| self.values.get(id)) {
323                    let gelu_grad = x.gelu_derivative();
324                    let grad = grad_output.mul(&gelu_grad);
325                    grads.insert(0, grad);
326                }
327            }
328            OpType::Softmax => {
329                // Softmax gradient: s * (grad - sum(grad * s))
330                if let Some(softmax_out) = inputs.first().and_then(|id| self.values.get(id)) {
331                    let sum_grad_dot_s = grad_output.mul(softmax_out).sum(None);
332                    let ones = DenseTensor::ones(softmax_out.shape().to_vec());
333                    let ones_scaled = ones.scale(sum_grad_dot_s.data()[0]);
334                    let diff = grad_output.sub(&ones_scaled);
335                    let grad = softmax_out.mul(&diff);
336                    grads.insert(0, grad);
337                }
338            }
339            OpType::Transpose => {
340                // Transpose gradient is just transpose of gradient
341                if !inputs.is_empty() {
342                    grads.insert(0, grad_output.transpose(None));
343                }
344            }
345            OpType::LayerNorm | OpType::RMSNorm => {
346                // Normalization gradients (simplified)
347                if inputs.first().and_then(|id| self.values.get(id)).is_some() {
348                    // Placeholder - actual implementation needs more careful handling
349                    grads.insert(0, grad_output.clone());
350                }
351            }
352            _ => {
353                // For unimplemented operations, pass gradient through
354                for (i, _) in inputs.iter().enumerate() {
355                    grads.insert(i, grad_output.clone());
356                }
357            }
358        }
359
360        grads
361    }
362
363    /// Accumulate gradient for a tensor
364    pub fn accumulate_gradient(&mut self, tensor_id: TensorId, gradient: DenseTensor) {
365        self.gradients
366            .entry(tensor_id)
367            .and_modify(|existing| {
368                *existing = existing.add(&gradient);
369            })
370            .or_insert(gradient);
371    }
372
373    /// Perform topological sort of operations
374    pub fn topological_sort(&self) -> Vec<OpId> {
375        let mut result = Vec::new();
376        let mut visited = std::collections::HashSet::new();
377
378        fn visit(
379            node: &OpNode,
380            nodes: &[OpNode],
381            visited: &mut std::collections::HashSet<OpId>,
382            result: &mut Vec<OpId>,
383        ) {
384            if visited.contains(&node.id) {
385                return;
386            }
387            visited.insert(node.id);
388
389            // Visit producers first
390            for &input_id in &node.inputs {
391                if let Some(producer) = nodes.iter().find(|n| n.output == input_id) {
392                    visit(producer, nodes, visited, result);
393                }
394            }
395
396            result.push(node.id);
397        }
398
399        for node in &self.nodes {
400            visit(node, &self.nodes, &mut visited, &mut result);
401        }
402
403        result
404    }
405
406    /// Clear the compute graph (call after backward pass)
407    pub fn clear(&mut self) {
408        self.nodes.clear();
409        self.edges.clear();
410        self.gradients.clear();
411        self.values.clear();
412        self.checkpoint = None;
413    }
414
415    /// Enable/disable operation recording
416    pub fn set_recording(&mut self, recording: bool) {
417        self.recording = recording;
418    }
419
420    /// Check if recording is enabled
421    pub fn is_recording(&self) -> bool {
422        self.recording
423    }
424
425    /// Get number of recorded operations
426    pub fn num_ops(&self) -> usize {
427        self.nodes.len()
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Recording one MatMul over two stored tensors yields exactly one node.
    #[test]
    fn test_compute_graph_basic() {
        let mut graph = ComputeGraph::new();

        // Input tensors
        let x_id = graph.next_tensor_id();
        let w_id = graph.next_tensor_id();
        graph.store_value(x_id, DenseTensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]));
        graph.store_value(w_id, DenseTensor::new(vec![0.1, 0.2, 0.3], vec![3, 1]));

        // Record the MatMul that combines them
        let out_id = graph.next_tensor_id();
        graph.record_op(OpType::MatMul, &[x_id, w_id], out_id);

        // Compute and store the output when both operands are available
        let product = match (graph.get_value(x_id), graph.get_value(w_id)) {
            (Some(x), Some(w)) => Some(x.matmul(w)),
            _ => None,
        };
        if let Some(out) = product {
            graph.store_value(out_id, out);
        }

        assert_eq!(graph.num_ops(), 1);
    }

    /// In the chain x -> MatMul -> ReLU the MatMul must be ordered first.
    #[test]
    fn test_topological_sort() {
        let mut graph = ComputeGraph::new();

        let x_id = graph.next_tensor_id();
        let w_id = graph.next_tensor_id();
        let matmul_out = graph.next_tensor_id();
        let relu_out = graph.next_tensor_id();

        graph.store_value(x_id, DenseTensor::new(vec![1.0, 2.0], vec![1, 2]));
        graph.store_value(w_id, DenseTensor::new(vec![0.1, 0.2], vec![2, 1]));

        graph.record_op(OpType::MatMul, &[x_id, w_id], matmul_out);
        graph.record_op(OpType::ReLU, &[matmul_out], relu_out);

        let order = graph.topological_sort();
        assert_eq!(order.len(), 2);

        let matmul_pos = order
            .iter()
            .position(|&id| {
                graph
                    .nodes
                    .iter()
                    .any(|n| n.id == id && matches!(n.op_type, OpType::MatMul))
            })
            .unwrap();
        let relu_pos = order
            .iter()
            .position(|&id| {
                graph
                    .nodes
                    .iter()
                    .any(|n| n.id == id && matches!(n.op_type, OpType::ReLU))
            })
            .unwrap();

        // MatMul should come before ReLU
        assert!(matmul_pos < relu_pos);
    }
}