1use super::types::{BatchOperation, BatchRequest};
4use std::collections::{HashMap, HashSet};
5
/// Dependency graph over the operations of a batch request, built and validated by
/// `DependencyGraph::new`.
#[derive(Debug)]
pub struct DependencyGraph {
    /// Operations keyed by their `id`.
    operations: HashMap<String, BatchOperation>,
    /// Operation IDs in the order they were supplied to `new`.
    operation_order: Vec<String>,
    /// Maps an operation ID to the IDs it depends on. Operations without
    /// dependencies have no entry at all.
    dependencies: HashMap<String, HashSet<String>>,
    /// Reverse of `dependencies`: maps an operation ID to the IDs that depend on it.
    // NOTE(review): populated in `new` but never read in this file — confirm it is still
    // needed (or expose it, e.g. to find operations unblocked by a completion).
    dependents: HashMap<String, HashSet<String>>,
}
18
19impl DependencyGraph {
20 pub fn new(operations: Vec<BatchOperation>) -> Result<Self, String> {
22 let mut graph = Self {
23 operations: HashMap::new(),
24 operation_order: Vec::new(),
25 dependencies: HashMap::new(),
26 dependents: HashMap::new(),
27 };
28
29 for op in operations {
31 if graph.operations.contains_key(&op.id) {
32 return Err(format!("Duplicate operation ID: {}", op.id));
33 }
34 graph.operation_order.push(op.id.clone());
35 graph.operations.insert(op.id.clone(), op);
36 }
37
38 for (id, op) in &graph.operations {
40 for dep in &op.depends_on {
41 if !graph.operations.contains_key(dep) {
43 return Err(format!(
44 "Operation '{}' depends on unknown operation '{}'",
45 id, dep
46 ));
47 }
48
49 graph
51 .dependencies
52 .entry(id.clone())
53 .or_default()
54 .insert(dep.clone());
55
56 graph
58 .dependents
59 .entry(dep.clone())
60 .or_default()
61 .insert(id.clone());
62 }
63 }
64
65 graph.validate_acyclic()?;
67
68 Ok(graph)
69 }
70
71 fn validate_acyclic(&self) -> Result<(), String> {
73 let mut visited = HashSet::new();
74 let mut stack = HashSet::new();
75
76 for id in self.operations.keys() {
77 if !visited.contains(id) {
78 self.detect_cycle(id, &mut visited, &mut stack)?;
79 }
80 }
81
82 Ok(())
83 }
84
85 fn detect_cycle(
87 &self,
88 node: &str,
89 visited: &mut HashSet<String>,
90 stack: &mut HashSet<String>,
91 ) -> Result<(), String> {
92 visited.insert(node.to_string());
93 stack.insert(node.to_string());
94
95 if let Some(deps) = self.dependencies.get(node) {
96 for dep in deps {
97 if !visited.contains(dep) {
98 self.detect_cycle(dep, visited, stack)?;
99 } else if stack.contains(dep) {
100 return Err(format!("Circular dependency detected: {} -> {}", node, dep));
101 }
102 }
103 }
104
105 stack.remove(node);
106 Ok(())
107 }
108
109 pub fn get_ready_operations(&self, completed: &HashSet<String>) -> Vec<BatchOperation> {
111 self.operations
112 .values()
113 .filter(|op| {
114 self.dependencies
116 .get(&op.id)
117 .map(|deps| deps.iter().all(|dep| completed.contains(dep)))
118 .unwrap_or(true) && !completed.contains(&op.id) })
121 .cloned()
122 .collect()
123 }
124
125 pub fn len(&self) -> usize {
127 self.operations.len()
128 }
129
130 pub fn is_empty(&self) -> bool {
132 self.operations.is_empty()
133 }
134
135 pub fn operations_in_order(&self) -> Vec<BatchOperation> {
137 self.operation_order
138 .iter()
139 .filter_map(|id| self.operations.get(id).cloned())
140 .collect()
141 }
142}
143
144impl From<BatchRequest> for DependencyGraph {
145 fn from(request: BatchRequest) -> Self {
146 DependencyGraph::new(request.operations).unwrap_or_else(|_e| {
147 DependencyGraph {
149 operations: HashMap::new(),
150 operation_order: Vec::new(),
151 dependencies: HashMap::new(),
152 dependents: HashMap::new(),
153 }
154 })
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Builds an operation with the given ID and dependency list. The tool name and
    /// arguments play no part in graph construction, so placeholders are used.
    fn op(id: &str, deps: &[&str]) -> BatchOperation {
        BatchOperation {
            id: id.to_string(),
            tool: format!("tool_{}", id),
            arguments: Value::Null,
            depends_on: deps.iter().map(|dep| dep.to_string()).collect(),
        }
    }

    /// Sorted IDs of `ops`, so assertions do not depend on result order.
    fn sorted_ids(ops: &[BatchOperation]) -> Vec<&str> {
        let mut ids: Vec<&str> = ops.iter().map(|o| o.id.as_str()).collect();
        ids.sort();
        ids
    }

    #[test]
    fn test_dependency_graph_simple() {
        let graph = DependencyGraph::new(vec![op("op1", &[]), op("op2", &[])]).unwrap();
        assert_eq!(graph.len(), 2);
        assert!(!graph.is_empty());
    }

    #[test]
    fn test_dependency_graph_with_dependencies() {
        let graph = DependencyGraph::new(vec![op("op1", &[]), op("op2", &["op1"])]).unwrap();
        assert_eq!(graph.len(), 2);
    }

    #[test]
    fn test_empty_graph() {
        let graph = DependencyGraph::new(Vec::new()).unwrap();
        assert_eq!(graph.len(), 0);
        assert!(graph.is_empty());
        assert!(graph.get_ready_operations(&HashSet::new()).is_empty());
        assert!(graph.operations_in_order().is_empty());
    }

    #[test]
    fn test_dependency_graph_cycle() {
        let result = DependencyGraph::new(vec![op("op1", &["op2"]), op("op2", &["op1"])]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Circular dependency"));
    }

    #[test]
    fn test_dependency_graph_self_cycle() {
        let err = DependencyGraph::new(vec![op("op1", &["op1"])]).unwrap_err();
        assert_eq!(err, "Circular dependency detected: op1 -> op1");
    }

    #[test]
    fn test_dependency_graph_transitive_cycle() {
        let err = DependencyGraph::new(vec![
            op("op1", &["op2"]),
            op("op2", &["op3"]),
            op("op3", &["op1"]),
        ])
        .unwrap_err();
        assert!(err.contains("Circular dependency"));
    }

    #[test]
    fn test_dependency_graph_unknown_dependency() {
        let err = DependencyGraph::new(vec![op("op1", &["unknown"])]).unwrap_err();
        assert_eq!(err, "Operation 'op1' depends on unknown operation 'unknown'");
    }

    #[test]
    fn test_dependency_graph_duplicate_id() {
        let err = DependencyGraph::new(vec![op("op1", &[]), op("op1", &[])]).unwrap_err();
        assert_eq!(err, "Duplicate operation ID: op1");
    }

    #[test]
    fn test_dependency_bookkeeping() {
        let graph = DependencyGraph::new(vec![
            op("a", &[]),
            op("b", &["a"]),
            op("c", &["a", "a"]),
        ])
        .unwrap();
        assert!(graph.dependencies.get("a").is_none());
        assert_eq!(graph.dependencies["c"].len(), 1);
        let mut users: Vec<&str> = graph.dependents["a"].iter().map(String::as_str).collect();
        users.sort();
        assert_eq!(users, vec!["b", "c"]);
    }

    #[test]
    fn test_get_ready_operations() {
        let graph = DependencyGraph::new(vec![op("op1", &[]), op("op2", &["op1"])]).unwrap();

        let completed = HashSet::new();
        let ready = graph.get_ready_operations(&completed);
        assert_eq!(ready.len(), 1);
        assert_eq!(ready[0].id, "op1");

        let mut completed = HashSet::new();
        completed.insert("op1".to_string());
        let ready = graph.get_ready_operations(&completed);
        assert_eq!(ready.len(), 1);
        assert_eq!(ready[0].id, "op2");
    }

    #[test]
    fn test_get_ready_operations_diamond() {
        let graph = DependencyGraph::new(vec![
            op("a", &[]),
            op("b", &["a"]),
            op("c", &["a"]),
            op("d", &["b", "c"]),
        ])
        .unwrap();

        let mut done = HashSet::new();
        assert_eq!(sorted_ids(&graph.get_ready_operations(&done)), vec!["a"]);
        done.insert("a".to_string());
        assert_eq!(sorted_ids(&graph.get_ready_operations(&done)), vec!["b", "c"]);
        done.insert("b".to_string());
        assert_eq!(sorted_ids(&graph.get_ready_operations(&done)), vec!["c"]);
        done.insert("c".to_string());
        assert_eq!(sorted_ids(&graph.get_ready_operations(&done)), vec!["d"]);
        done.insert("d".to_string());
        assert!(graph.get_ready_operations(&done).is_empty());
    }

    #[test]
    fn test_operations_in_order_preserves_input_order() {
        let graph =
            DependencyGraph::new(vec![op("z", &[]), op("a", &["z"]), op("m", &[])]).unwrap();
        let order: Vec<String> = graph
            .operations_in_order()
            .into_iter()
            .map(|o| o.id)
            .collect();
        assert_eq!(order, vec!["z", "a", "m"]);
    }
}