use ghostflow_core::Tensor;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

/// A single operation recorded in the dynamic computation graph.
#[derive(Clone)]
pub struct GraphNode {
    /// Unique identifier of this node within its graph.
    pub id: usize,
    /// Name of the operation that produced this node (e.g. "add", "matmul").
    pub op: String,
    /// Ids of the nodes whose outputs feed this operation.
    pub inputs: Vec<usize>,
    /// The tensor produced by this operation.
    pub output: Tensor,
    /// Optional gradient function mapping the output gradient to one
    /// gradient per input. `None` means no gradient has been registered.
    pub backward_fn: Option<Arc<dyn Fn(&[Tensor]) -> Vec<Tensor> + Send + Sync>>,
}

impl std::fmt::Debug for GraphNode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GraphNode")
            .field("id", &self.id)
            .field("op", &self.op)
            .field("inputs", &self.inputs)
            .field("output", &self.output)
            .field("backward_fn", &self.backward_fn.is_some())
            .finish()
    }
}

/// A dynamically built (define-by-run) computation graph.
///
/// All state sits behind `Arc<Mutex<...>>`, so the graph can be shared
/// across threads, typically via `Arc<DynamicGraph>`.
#[derive(Debug)]
pub struct DynamicGraph {
    /// Recorded nodes, keyed by id.
    nodes: Arc<Mutex<HashMap<usize, GraphNode>>>,
    /// The id that will be assigned to the next recorded node.
    next_id: Arc<Mutex<usize>>,
    /// Whether `add_node` currently records anything.
    recording: Arc<Mutex<bool>>,
}
47
48impl DynamicGraph {
49 pub fn new() -> Self {
51 DynamicGraph {
52 nodes: Arc::new(Mutex::new(HashMap::new())),
53 next_id: Arc::new(Mutex::new(0)),
54 recording: Arc::new(Mutex::new(true)),
55 }
56 }
57
58 pub fn start_recording(&self) {
60 *self.recording.lock().unwrap() = true;
61 }
62
63 pub fn stop_recording(&self) {
65 *self.recording.lock().unwrap() = false;
66 }
67
68 pub fn is_recording(&self) -> bool {
70 *self.recording.lock().unwrap()
71 }

    /// Records a new node and returns its id, or `None` when recording is
    /// disabled. Returning an `Option` avoids handing callers a bogus id
    /// that would collide with the first real node.
    pub fn add_node(&self, op: String, inputs: Vec<usize>, output: Tensor) -> Option<usize> {
        if !self.is_recording() {
            return None;
        }

        let mut next_id = self.next_id.lock().unwrap();
        let id = *next_id;
        *next_id += 1;

        let node = GraphNode {
            id,
            op,
            inputs,
            output,
            backward_fn: None,
        };

        self.nodes.lock().unwrap().insert(id, node);
        Some(id)
    }

    /// Returns a clone of the node with the given id, if it exists.
    pub fn get_node(&self, id: usize) -> Option<GraphNode> {
        self.nodes.lock().unwrap().get(&id).cloned()
    }

    /// Removes all nodes and resets id assignment.
    pub fn clear(&self) {
        self.nodes.lock().unwrap().clear();
        *self.next_id.lock().unwrap() = 0;
    }

    /// Returns the number of recorded nodes.
    pub fn num_nodes(&self) -> usize {
        self.nodes.lock().unwrap().len()
    }

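    /// Runs reverse-mode gradient accumulation starting from `node_id`,
    /// seeded with `grad`, and returns a map from node id to its
    /// accumulated gradient.
    ///
    /// A minimal calling sketch (`ignore`d here because `Tensor` lives in
    /// `ghostflow_core`):
    ///
    /// ```ignore
    /// let graph = DynamicGraph::new();
    /// let id = graph
    ///     .add_node("input".to_string(), vec![], Tensor::ones(&[2, 2]))
    ///     .unwrap();
    /// let grads = graph.backward(id, Tensor::ones(&[2, 2]));
    /// assert!(grads.contains_key(&id));
    /// ```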
    pub fn backward(&self, node_id: usize, grad: Tensor) -> HashMap<usize, Tensor> {
        let mut gradients: HashMap<usize, Tensor> = HashMap::new();
        gradients.insert(node_id, grad);

        let nodes = self.nodes.lock().unwrap();

        // Ids are assigned in creation order, so walking them from highest
        // to lowest visits every node after all of its consumers: a valid
        // reverse topological order for this graph.
        let mut sorted_ids: Vec<usize> = nodes.keys().copied().collect();
        sorted_ids.sort_by(|a, b| b.cmp(a));

        for &id in &sorted_ids {
            // Clone the accumulated gradient so `gradients` can be mutated
            // below while the upstream value is handed to `backward_fn`.
            if let Some(grad) = gradients.get(&id).cloned() {
                if let Some(node) = nodes.get(&id) {
                    if let Some(ref backward_fn) = node.backward_fn {
                        let input_grads = backward_fn(&[grad]);

                        // Accumulate per-input gradients, summing when a
                        // node feeds more than one consumer.
                        for (i, input_id) in node.inputs.iter().enumerate() {
                            if i < input_grads.len() {
                                let input_grad = &input_grads[i];
                                gradients.entry(*input_id)
                                    .and_modify(|g| *g = g.add(input_grad).unwrap())
                                    .or_insert_with(|| input_grad.clone());
                            }
                        }
                    }
                }
            }
        }

        gradients
    }
}

impl Default for DynamicGraph {
    fn default() -> Self {
        Self::new()
    }
}

/// Owns a `DynamicGraph` and provides scoped control over recording.
pub struct DynamicContext {
    /// The graph onto which operations are recorded.
    graph: Arc<DynamicGraph>,
}

impl DynamicContext {
    /// Creates a context with a fresh, empty graph.
    pub fn new() -> Self {
        DynamicContext {
            graph: Arc::new(DynamicGraph::new()),
        }
    }

    /// Returns a shared handle to the underlying graph.
    pub fn graph(&self) -> &Arc<DynamicGraph> {
        &self.graph
    }

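    /// Runs `f` with recording enabled, then restores the previous
    /// recording state so nested scopes compose correctly.
    ///
    /// A minimal usage sketch:
    ///
    /// ```ignore
    /// let ctx = DynamicContext::new();
    /// let recorded = ctx.with_grad(|| ctx.graph().is_recording());
    /// assert!(recorded);
    /// ```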
    pub fn with_grad<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        let was_recording = self.graph.is_recording();
        self.graph.start_recording();
        let result = f();
        if !was_recording {
            self.graph.stop_recording();
        }
        result
    }

    /// Runs `f` with recording disabled, then restores the previous
    /// recording state rather than unconditionally re-enabling it, so
    /// nested `no_grad` scopes behave correctly.
    pub fn no_grad<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        let was_recording = self.graph.is_recording();
        self.graph.stop_recording();
        let result = f();
        if was_recording {
            self.graph.start_recording();
        }
        result
    }
}

impl Default for DynamicContext {
    fn default() -> Self {
        Self::new()
    }
}
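
/// A tensor paired with the bookkeeping needed to trace it through a
/// `DynamicGraph`.
///
/// A minimal chaining sketch (`ignore`d because `Tensor` comes from
/// `ghostflow_core`):
///
/// ```ignore
/// let graph = Arc::new(DynamicGraph::new());
/// let a = DynamicTensor::new(Tensor::ones(&[2, 2]), graph.clone());
/// let b = DynamicTensor::new(Tensor::ones(&[2, 2]), graph.clone());
/// let y = a.matmul(&b).relu();
/// let grads = y.backward();
/// ```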
#[derive(Debug, Clone)]
pub struct DynamicTensor {
    /// The underlying tensor value.
    pub tensor: Tensor,
    /// Id of the graph node that produced this tensor, if recorded.
    pub node_id: Option<usize>,
    /// The graph this tensor is attached to, if any.
    pub graph: Option<Arc<DynamicGraph>>,
}

impl DynamicTensor {
    /// Wraps `tensor` and records it on `graph` as an input node.
    pub fn new(tensor: Tensor, graph: Arc<DynamicGraph>) -> Self {
        let node_id = graph.add_node("input".to_string(), vec![], tensor.clone());

        DynamicTensor {
            tensor,
            node_id,
            graph: Some(graph),
        }
    }

    /// Wraps `tensor` without attaching it to any graph; operations on the
    /// result are not recorded.
    pub fn from_tensor(tensor: Tensor) -> Self {
        DynamicTensor {
            tensor,
            node_id: None,
            graph: None,
        }
    }

    /// Element-wise addition; records an "add" node when both operands are
    /// attached to a graph.
    pub fn add(&self, other: &DynamicTensor) -> DynamicTensor {
        let result = self.tensor.add(&other.tensor).unwrap();

        if let (Some(graph), Some(id1), Some(id2)) = (&self.graph, self.node_id, other.node_id) {
            let node_id = graph.add_node("add".to_string(), vec![id1, id2], result.clone());

            DynamicTensor {
                tensor: result,
                node_id,
                graph: Some(graph.clone()),
            }
        } else {
            DynamicTensor::from_tensor(result)
        }
    }

    /// Element-wise multiplication; records a "mul" node when both
    /// operands are attached to a graph.
    pub fn mul(&self, other: &DynamicTensor) -> DynamicTensor {
        let result = self.tensor.mul(&other.tensor).unwrap();

        if let (Some(graph), Some(id1), Some(id2)) = (&self.graph, self.node_id, other.node_id) {
            let node_id = graph.add_node("mul".to_string(), vec![id1, id2], result.clone());

            DynamicTensor {
                tensor: result,
                node_id,
                graph: Some(graph.clone()),
            }
        } else {
            DynamicTensor::from_tensor(result)
        }
    }

    /// Matrix multiplication; records a "matmul" node when both operands
    /// are attached to a graph.
    pub fn matmul(&self, other: &DynamicTensor) -> DynamicTensor {
        let result = self.tensor.matmul(&other.tensor).unwrap();

        if let (Some(graph), Some(id1), Some(id2)) = (&self.graph, self.node_id, other.node_id) {
            let node_id = graph.add_node("matmul".to_string(), vec![id1, id2], result.clone());

            DynamicTensor {
                tensor: result,
                node_id,
                graph: Some(graph.clone()),
            }
        } else {
            DynamicTensor::from_tensor(result)
        }
    }

    /// ReLU activation; records a "relu" node when attached to a graph.
    pub fn relu(&self) -> DynamicTensor {
        let result = self.tensor.relu();

        if let (Some(graph), Some(id)) = (&self.graph, self.node_id) {
            let node_id = graph.add_node("relu".to_string(), vec![id], result.clone());

            DynamicTensor {
                tensor: result,
                node_id,
                graph: Some(graph.clone()),
            }
        } else {
            DynamicTensor::from_tensor(result)
        }
    }

    /// Seeds the output gradient with ones and runs the graph's backward
    /// pass. Returns an empty map for tensors detached from any graph.
    pub fn backward(&self) -> HashMap<usize, Tensor> {
        if let (Some(graph), Some(node_id)) = (&self.graph, self.node_id) {
            let grad = Tensor::ones(self.tensor.dims());
            graph.backward(node_id, grad)
        } else {
            HashMap::new()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dynamic_graph() {
        let graph = DynamicGraph::new();
        assert_eq!(graph.num_nodes(), 0);

        let t1 = Tensor::ones(&[2, 2]);
        let id = graph.add_node("test".to_string(), vec![], t1).unwrap();

        assert_eq!(graph.num_nodes(), 1);
        assert!(graph.get_node(id).is_some());
    }

    #[test]
    fn test_dynamic_context() {
        let ctx = DynamicContext::new();

        ctx.with_grad(|| {
            assert!(ctx.graph().is_recording());
        });

        ctx.no_grad(|| {
            assert!(!ctx.graph().is_recording());
        });
    }
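
    // A sketch of how disabling recording interacts with tensor ops: with
    // recording off, `add_node` returns `None` and the graph is untouched.
    #[test]
    fn test_no_grad_skips_recording() {
        let ctx = DynamicContext::new();
        let graph = ctx.graph().clone();

        let dt1 = DynamicTensor::new(Tensor::ones(&[2, 2]), graph.clone());
        let dt2 = DynamicTensor::new(Tensor::ones(&[2, 2]), graph.clone());
        assert_eq!(graph.num_nodes(), 2);

        ctx.no_grad(|| {
            let result = dt1.add(&dt2);
            assert!(result.node_id.is_none());
        });
        assert_eq!(graph.num_nodes(), 2);
    }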

    #[test]
    fn test_dynamic_tensor() {
        let graph = Arc::new(DynamicGraph::new());
        let t1 = Tensor::ones(&[2, 2]);
        let t2 = Tensor::ones(&[2, 2]);

        let dt1 = DynamicTensor::new(t1, graph.clone());
        let dt2 = DynamicTensor::new(t2, graph.clone());

        let result = dt1.add(&dt2);
        assert_eq!(result.tensor.data_f32()[0], 2.0);
        assert_eq!(graph.num_nodes(), 3);
    }
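
    // A sketch of full backward propagation. The public API never sets
    // `backward_fn`, so this test reaches into the private `nodes` map
    // (visible here because `tests` is a child module) to register one.
    #[test]
    fn test_backward_through_add() {
        let graph = Arc::new(DynamicGraph::new());
        let dt1 = DynamicTensor::new(Tensor::ones(&[2, 2]), graph.clone());
        let dt2 = DynamicTensor::new(Tensor::ones(&[2, 2]), graph.clone());

        let sum = dt1.add(&dt2);
        let sum_id = sum.node_id.unwrap();

        // d(a + b)/da = d(a + b)/db = 1, so addition passes the upstream
        // gradient through to both inputs unchanged.
        {
            let mut nodes = graph.nodes.lock().unwrap();
            let node = nodes.get_mut(&sum_id).unwrap();
            node.backward_fn = Some(Arc::new(|grads: &[Tensor]| {
                vec![grads[0].clone(), grads[0].clone()]
            }));
        }

        let grads = sum.backward();
        assert_eq!(grads.len(), 3);
        assert_eq!(grads[&dt1.node_id.unwrap()].data_f32()[0], 1.0);
    }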
}