Skip to main content

do_memory_mcp/batch/
dependency_graph.rs

1//! Dependency graph for batch operations
2
3use super::types::{BatchOperation, BatchRequest};
4use std::collections::{HashMap, HashSet};
5
6/// Dependency graph for batch operations
7#[derive(Debug)]
8pub struct DependencyGraph {
9    /// Operations indexed by ID
10    operations: HashMap<String, BatchOperation>,
11    /// Operation IDs in insertion order
12    operation_order: Vec<String>,
13    /// Adjacency list (operation -> dependencies)
14    dependencies: HashMap<String, HashSet<String>>,
15    /// Reverse adjacency list (operation -> dependents)
16    dependents: HashMap<String, HashSet<String>>,
17}
18
19impl DependencyGraph {
20    /// Create a new dependency graph from operations
21    pub fn new(operations: Vec<BatchOperation>) -> Result<Self, String> {
22        let mut graph = Self {
23            operations: HashMap::new(),
24            operation_order: Vec::new(),
25            dependencies: HashMap::new(),
26            dependents: HashMap::new(),
27        };
28
29        // Build operation index and preserve order
30        for op in operations {
31            if graph.operations.contains_key(&op.id) {
32                return Err(format!("Duplicate operation ID: {}", op.id));
33            }
34            graph.operation_order.push(op.id.clone());
35            graph.operations.insert(op.id.clone(), op);
36        }
37
38        // Build dependency and dependent relationships
39        for (id, op) in &graph.operations {
40            for dep in &op.depends_on {
41                // Validate dependency exists
42                if !graph.operations.contains_key(dep) {
43                    return Err(format!(
44                        "Operation '{}' depends on unknown operation '{}'",
45                        id, dep
46                    ));
47                }
48
49                // Add to dependencies
50                graph
51                    .dependencies
52                    .entry(id.clone())
53                    .or_default()
54                    .insert(dep.clone());
55
56                // Add to dependents (reverse)
57                graph
58                    .dependents
59                    .entry(dep.clone())
60                    .or_default()
61                    .insert(id.clone());
62            }
63        }
64
65        // Validate no cycles
66        graph.validate_acyclic()?;
67
68        Ok(graph)
69    }
70
71    /// Validate that the graph is acyclic (no circular dependencies)
72    fn validate_acyclic(&self) -> Result<(), String> {
73        let mut visited = HashSet::new();
74        let mut stack = HashSet::new();
75
76        for id in self.operations.keys() {
77            if !visited.contains(id) {
78                self.detect_cycle(id, &mut visited, &mut stack)?;
79            }
80        }
81
82        Ok(())
83    }
84
85    /// Detect cycles using DFS
86    fn detect_cycle(
87        &self,
88        node: &str,
89        visited: &mut HashSet<String>,
90        stack: &mut HashSet<String>,
91    ) -> Result<(), String> {
92        visited.insert(node.to_string());
93        stack.insert(node.to_string());
94
95        if let Some(deps) = self.dependencies.get(node) {
96            for dep in deps {
97                if !visited.contains(dep) {
98                    self.detect_cycle(dep, visited, stack)?;
99                } else if stack.contains(dep) {
100                    return Err(format!("Circular dependency detected: {} -> {}", node, dep));
101                }
102            }
103        }
104
105        stack.remove(node);
106        Ok(())
107    }
108
109    /// Get operations that have no pending dependencies (ready to execute)
110    pub fn get_ready_operations(&self, completed: &HashSet<String>) -> Vec<BatchOperation> {
111        self.operations
112            .values()
113            .filter(|op| {
114                // Check if all dependencies are completed
115                self.dependencies
116                    .get(&op.id)
117                    .map(|deps| deps.iter().all(|dep| completed.contains(dep)))
118                    .unwrap_or(true) // No dependencies means ready
119                    && !completed.contains(&op.id) // Not already completed
120            })
121            .cloned()
122            .collect()
123    }
124
125    /// Get total number of operations
126    pub fn len(&self) -> usize {
127        self.operations.len()
128    }
129
130    /// Check if graph is empty
131    pub fn is_empty(&self) -> bool {
132        self.operations.is_empty()
133    }
134
135    /// Get operations in insertion order
136    pub fn operations_in_order(&self) -> Vec<BatchOperation> {
137        self.operation_order
138            .iter()
139            .filter_map(|id| self.operations.get(id).cloned())
140            .collect()
141    }
142}
143
144impl From<BatchRequest> for DependencyGraph {
145    fn from(request: BatchRequest) -> Self {
146        DependencyGraph::new(request.operations).unwrap_or_else(|_e| {
147            // In case of error, create an empty graph
148            DependencyGraph {
149                operations: HashMap::new(),
150                operation_order: Vec::new(),
151                dependencies: HashMap::new(),
152                dependents: HashMap::new(),
153            }
154        })
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Build an operation with the given ID and dependencies. The tool name
    /// and arguments are irrelevant to graph construction.
    fn op(id: &str, depends_on: &[&str]) -> BatchOperation {
        BatchOperation {
            id: id.to_string(),
            tool: format!("tool_{}", id),
            arguments: Value::Null,
            depends_on: depends_on.iter().map(|dep| dep.to_string()).collect(),
        }
    }

    /// IDs of the ready operations, sorted so assertions don't depend on return order.
    fn ready_ids(graph: &DependencyGraph, completed: &[&str]) -> Vec<String> {
        let completed: HashSet<String> = completed.iter().map(|id| id.to_string()).collect();
        let mut ids: Vec<String> = graph
            .get_ready_operations(&completed)
            .into_iter()
            .map(|operation| operation.id)
            .collect();
        ids.sort();
        ids
    }

    #[test]
    fn test_dependency_graph_simple() {
        let graph = DependencyGraph::new(vec![op("op1", &[]), op("op2", &[])]).unwrap();
        assert_eq!(graph.len(), 2);
        assert!(!graph.is_empty());
    }

    #[test]
    fn test_dependency_graph_empty() {
        let graph = DependencyGraph::new(vec![]).unwrap();
        assert_eq!(graph.len(), 0);
        assert!(graph.is_empty());
        assert!(graph.get_ready_operations(&HashSet::new()).is_empty());
    }

    #[test]
    fn test_dependency_graph_with_dependencies() {
        let graph = DependencyGraph::new(vec![op("op1", &[]), op("op2", &["op1"])]).unwrap();
        assert_eq!(graph.len(), 2);
    }

    #[test]
    fn test_dependency_graph_cycle() {
        let result = DependencyGraph::new(vec![op("op1", &["op2"]), op("op2", &["op1"])]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Circular dependency"));
    }

    #[test]
    fn test_dependency_graph_self_dependency() {
        let result = DependencyGraph::new(vec![op("op1", &["op1"])]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Circular dependency"));
    }

    #[test]
    fn test_dependency_graph_transitive_cycle() {
        let result = DependencyGraph::new(vec![
            op("a", &["c"]),
            op("b", &["a"]),
            op("c", &["b"]),
        ]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Circular dependency"));
    }

    #[test]
    fn test_dependency_graph_unknown_dependency() {
        let result = DependencyGraph::new(vec![op("op1", &["unknown"])]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("unknown operation"));
    }

    #[test]
    fn test_dependency_graph_duplicate_id() {
        let result = DependencyGraph::new(vec![op("op1", &[]), op("op1", &[])]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Duplicate operation ID"));
    }

    #[test]
    fn test_get_ready_operations() {
        let graph = DependencyGraph::new(vec![op("op1", &[]), op("op2", &["op1"])]).unwrap();

        // Initially, only op1 should be ready
        assert_eq!(ready_ids(&graph, &[]), ["op1"]);
        // After op1 completes, op2 should be ready
        assert_eq!(ready_ids(&graph, &["op1"]), ["op2"]);
        // Once everything has completed, nothing is ready
        assert!(ready_ids(&graph, &["op1", "op2"]).is_empty());
    }

    #[test]
    fn test_get_ready_operations_diamond() {
        // d depends on b and c, which both depend on a.
        let graph = DependencyGraph::new(vec![
            op("a", &[]),
            op("b", &["a"]),
            op("c", &["a"]),
            op("d", &["b", "c"]),
        ])
        .unwrap();

        assert_eq!(ready_ids(&graph, &[]), ["a"]);
        assert_eq!(ready_ids(&graph, &["a"]), ["b", "c"]);
        // d waits until *all* of its dependencies have completed.
        assert_eq!(ready_ids(&graph, &["a", "b"]), ["c"]);
        assert_eq!(ready_ids(&graph, &["a", "b", "c"]), ["d"]);
    }

    #[test]
    fn test_operations_in_order_preserves_insertion_order() {
        let graph = DependencyGraph::new(vec![
            op("c", &[]),
            op("a", &["c"]),
            op("b", &[]),
        ])
        .unwrap();
        let ids: Vec<String> = graph
            .operations_in_order()
            .into_iter()
            .map(|operation| operation.id)
            .collect();
        assert_eq!(ids, ["c", "a", "b"]);
    }
}