// ghostflow_autograd/dynamic_graph.rs

1//! Dynamic Computation Graph
2//!
3//! Implements dynamic computation graphs (like PyTorch) where the graph
4//! is built on-the-fly during forward pass.
5
6use ghostflow_core::Tensor;
7use std::sync::{Arc, Mutex};
8use std::collections::HashMap;
9
/// A single recorded operation in the dynamic computation graph.
///
/// Nodes are created during the forward pass. `backward_fn`, when present,
/// maps the gradient flowing into this node to gradients for its inputs.
#[derive(Clone)]
pub struct GraphNode {
    /// Unique node ID (assigned sequentially by `DynamicGraph::add_node`).
    pub id: usize,
    /// Operation name (e.g. "input", "add", "mul", "matmul", "relu").
    pub op: String,
    /// IDs of the nodes whose outputs feed this operation.
    pub inputs: Vec<usize>,
    /// Output tensor produced by this operation during the forward pass.
    pub output: Tensor,
    /// Gradient function: given the gradient w.r.t. `output`, returns one
    /// gradient per entry of `inputs` (in order). `None` means gradients are
    /// not propagated past this node (e.g. leaf inputs).
    pub backward_fn: Option<Arc<dyn Fn(&[Tensor]) -> Vec<Tensor> + Send + Sync>>,
}
24
25impl std::fmt::Debug for GraphNode {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        f.debug_struct("GraphNode")
28            .field("id", &self.id)
29            .field("op", &self.op)
30            .field("inputs", &self.inputs)
31            .field("output", &self.output)
32            .field("backward_fn", &self.backward_fn.is_some())
33            .finish()
34    }
35}
36
/// Dynamic computation graph
///
/// All state is behind `Arc<Mutex<..>>`, so a graph wrapped in an `Arc` can
/// be shared between owners while being mutated through `&self` methods.
#[derive(Debug)]
pub struct DynamicGraph {
    /// All nodes in the graph, keyed by node ID.
    nodes: Arc<Mutex<HashMap<usize, GraphNode>>>,
    /// Next node ID to hand out (monotonically increasing until `clear`).
    next_id: Arc<Mutex<usize>>,
    /// Whether `add_node` records operations (toggled by start/stop_recording).
    recording: Arc<Mutex<bool>>,
}
47
48impl DynamicGraph {
49    /// Create a new dynamic graph
50    pub fn new() -> Self {
51        DynamicGraph {
52            nodes: Arc::new(Mutex::new(HashMap::new())),
53            next_id: Arc::new(Mutex::new(0)),
54            recording: Arc::new(Mutex::new(true)),
55        }
56    }
57    
58    /// Start recording operations
59    pub fn start_recording(&self) {
60        *self.recording.lock().unwrap() = true;
61    }
62    
63    /// Stop recording operations
64    pub fn stop_recording(&self) {
65        *self.recording.lock().unwrap() = false;
66    }
67    
68    /// Check if recording
69    pub fn is_recording(&self) -> bool {
70        *self.recording.lock().unwrap()
71    }
72    
73    /// Add a node to the graph
74    pub fn add_node(&self, op: String, inputs: Vec<usize>, output: Tensor) -> usize {
75        if !self.is_recording() {
76            return 0;
77        }
78        
79        let mut next_id = self.next_id.lock().unwrap();
80        let id = *next_id;
81        *next_id += 1;
82        
83        let node = GraphNode {
84            id,
85            op,
86            inputs,
87            output,
88            backward_fn: None,
89        };
90        
91        self.nodes.lock().unwrap().insert(id, node);
92        id
93    }
94    
95    /// Get a node by ID
96    pub fn get_node(&self, id: usize) -> Option<GraphNode> {
97        self.nodes.lock().unwrap().get(&id).cloned()
98    }
99    
100    /// Clear the graph
101    pub fn clear(&self) {
102        self.nodes.lock().unwrap().clear();
103        *self.next_id.lock().unwrap() = 0;
104    }
105    
106    /// Get number of nodes
107    pub fn num_nodes(&self) -> usize {
108        self.nodes.lock().unwrap().len()
109    }
110    
111    /// Perform backward pass from a node
112    pub fn backward(&self, node_id: usize, grad: Tensor) -> HashMap<usize, Tensor> {
113        let mut gradients: HashMap<usize, Tensor> = HashMap::new();
114        gradients.insert(node_id, grad);
115        
116        // Topological sort (simplified - assumes DAG)
117        let nodes = self.nodes.lock().unwrap();
118        let mut sorted_ids: Vec<usize> = nodes.keys().cloned().collect();
119        sorted_ids.sort_by(|a, b| b.cmp(a)); // Reverse order
120        
121        for &id in &sorted_ids {
122            if let Some(grad) = gradients.get(&id) {
123                if let Some(node) = nodes.get(&id) {
124                    // Compute gradients for inputs
125                    if let Some(ref backward_fn) = node.backward_fn {
126                        let input_grads = backward_fn(&[grad.clone()]);
127                        
128                        for (i, input_id) in node.inputs.iter().enumerate() {
129                            if i < input_grads.len() {
130                                let input_grad = &input_grads[i];
131                                gradients.entry(*input_id)
132                                    .and_modify(|g| *g = g.add(input_grad).unwrap())
133                                    .or_insert_with(|| input_grad.clone());
134                            }
135                        }
136                    }
137                }
138            }
139        }
140        
141        gradients
142    }
143}
144
145impl Default for DynamicGraph {
146    fn default() -> Self {
147        Self::new()
148    }
149}
150
/// Context for dynamic graph operations
///
/// Owns a shared [`DynamicGraph`] and provides scoped control of gradient
/// recording via `with_grad` / `no_grad`.
pub struct DynamicContext {
    /// The underlying graph, behind `Arc` so tensors can hold a handle to it.
    graph: Arc<DynamicGraph>,
}
155
156impl DynamicContext {
157    /// Create a new context
158    pub fn new() -> Self {
159        DynamicContext {
160            graph: Arc::new(DynamicGraph::new()),
161        }
162    }
163    
164    /// Get the graph
165    pub fn graph(&self) -> &Arc<DynamicGraph> {
166        &self.graph
167    }
168    
169    /// Execute a function with gradient tracking
170    pub fn with_grad<F, R>(&self, f: F) -> R
171    where
172        F: FnOnce() -> R,
173    {
174        self.graph.start_recording();
175        let result = f();
176        result
177    }
178    
179    /// Execute a function without gradient tracking
180    pub fn no_grad<F, R>(&self, f: F) -> R
181    where
182        F: FnOnce() -> R,
183    {
184        self.graph.stop_recording();
185        let result = f();
186        self.graph.start_recording();
187        result
188    }
189}
190
191impl Default for DynamicContext {
192    fn default() -> Self {
193        Self::new()
194    }
195}
196
/// Wrapper for tensors in dynamic graph
///
/// Pairs a plain [`Tensor`] with its position in a [`DynamicGraph`]. Tensors
/// built with [`DynamicTensor::from_tensor`] carry no graph and therefore do
/// not participate in gradient tracking.
#[derive(Debug, Clone)]
pub struct DynamicTensor {
    /// The actual tensor value.
    pub tensor: Tensor,
    /// Node ID in the graph (`None` when the tensor is untracked).
    pub node_id: Option<usize>,
    /// The graph this tensor was recorded in (`None` when untracked).
    pub graph: Option<Arc<DynamicGraph>>,
}
207
208impl DynamicTensor {
209    /// Create a new dynamic tensor
210    pub fn new(tensor: Tensor, graph: Arc<DynamicGraph>) -> Self {
211        let node_id = graph.add_node("input".to_string(), vec![], tensor.clone());
212        
213        DynamicTensor {
214            tensor,
215            node_id: Some(node_id),
216            graph: Some(graph),
217        }
218    }
219    
220    /// Create from tensor without graph
221    pub fn from_tensor(tensor: Tensor) -> Self {
222        DynamicTensor {
223            tensor,
224            node_id: None,
225            graph: None,
226        }
227    }
228    
229    /// Add two dynamic tensors
230    pub fn add(&self, other: &DynamicTensor) -> DynamicTensor {
231        let result = self.tensor.add(&other.tensor).unwrap();
232        
233        if let (Some(graph), Some(id1), Some(id2)) = (&self.graph, self.node_id, other.node_id) {
234            let node_id = graph.add_node("add".to_string(), vec![id1, id2], result.clone());
235            
236            DynamicTensor {
237                tensor: result,
238                node_id: Some(node_id),
239                graph: Some(graph.clone()),
240            }
241        } else {
242            DynamicTensor::from_tensor(result)
243        }
244    }
245    
246    /// Multiply two dynamic tensors
247    pub fn mul(&self, other: &DynamicTensor) -> DynamicTensor {
248        let result = self.tensor.mul(&other.tensor).unwrap();
249        
250        if let (Some(graph), Some(id1), Some(id2)) = (&self.graph, self.node_id, other.node_id) {
251            let node_id = graph.add_node("mul".to_string(), vec![id1, id2], result.clone());
252            
253            DynamicTensor {
254                tensor: result,
255                node_id: Some(node_id),
256                graph: Some(graph.clone()),
257            }
258        } else {
259            DynamicTensor::from_tensor(result)
260        }
261    }
262    
263    /// Matrix multiplication
264    pub fn matmul(&self, other: &DynamicTensor) -> DynamicTensor {
265        let result = self.tensor.matmul(&other.tensor).unwrap();
266        
267        if let (Some(graph), Some(id1), Some(id2)) = (&self.graph, self.node_id, other.node_id) {
268            let node_id = graph.add_node("matmul".to_string(), vec![id1, id2], result.clone());
269            
270            DynamicTensor {
271                tensor: result,
272                node_id: Some(node_id),
273                graph: Some(graph.clone()),
274            }
275        } else {
276            DynamicTensor::from_tensor(result)
277        }
278    }
279    
280    /// ReLU activation
281    pub fn relu(&self) -> DynamicTensor {
282        let result = self.tensor.relu();
283        
284        if let (Some(graph), Some(id)) = (&self.graph, self.node_id) {
285            let node_id = graph.add_node("relu".to_string(), vec![id], result.clone());
286            
287            DynamicTensor {
288                tensor: result,
289                node_id: Some(node_id),
290                graph: Some(graph.clone()),
291            }
292        } else {
293            DynamicTensor::from_tensor(result)
294        }
295    }
296    
297    /// Compute gradients
298    pub fn backward(&self) -> HashMap<usize, Tensor> {
299        if let (Some(graph), Some(node_id)) = (&self.graph, self.node_id) {
300            let grad = Tensor::ones(self.tensor.dims());
301            graph.backward(node_id, grad)
302        } else {
303            HashMap::new()
304        }
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dynamic_graph() {
        let graph = DynamicGraph::new();
        assert_eq!(graph.num_nodes(), 0);

        let tensor = Tensor::ones(&[2, 2]);
        let node_id = graph.add_node("test".to_string(), Vec::new(), tensor);

        assert_eq!(graph.num_nodes(), 1);
        assert!(graph.get_node(node_id).is_some());
    }

    #[test]
    fn test_dynamic_context() {
        let ctx = DynamicContext::new();

        ctx.with_grad(|| assert!(ctx.graph().is_recording()));
        ctx.no_grad(|| assert!(!ctx.graph().is_recording()));
    }

    #[test]
    fn test_dynamic_tensor() {
        let graph = Arc::new(DynamicGraph::new());

        let a = DynamicTensor::new(Tensor::ones(&[2, 2]), Arc::clone(&graph));
        let b = DynamicTensor::new(Tensor::ones(&[2, 2]), Arc::clone(&graph));

        let sum = a.add(&b);
        assert_eq!(sum.tensor.data_f32()[0], 2.0);
        // Two input nodes plus the "add" node.
        assert_eq!(graph.num_nodes(), 3);
    }
}