1use crate::error::{WorkflowError, WorkflowResult};
4use crate::models::Workflow;
5use std::collections::{HashMap, HashSet, VecDeque};
6
/// Stateless namespace for dependency-graph queries over a [`Workflow`]'s steps.
///
/// Every operation is an associated function taking the workflow explicitly;
/// the type itself carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct DependencyResolver;
14
15impl DependencyResolver {
16 pub fn resolve_execution_order(workflow: &Workflow) -> WorkflowResult<Vec<String>> {
21 Self::topological_sort(workflow)
22 }
23
24 fn topological_sort(workflow: &Workflow) -> WorkflowResult<Vec<String>> {
29 let mut order = Vec::new();
30 let mut completed = HashSet::new();
31 let mut queue = VecDeque::new();
32
33 for step in &workflow.steps {
35 if step.dependencies.is_empty() {
36 queue.push_back(step.id.clone());
37 }
38 }
39
40 let step_map: HashMap<_, _> = workflow.steps.iter().map(|s| (&s.id, s)).collect();
42
43 while let Some(step_id) = queue.pop_front() {
45 if completed.contains(&step_id) {
46 continue;
47 }
48
49 if let Some(step) = step_map.get(&step_id) {
51 let all_deps_completed =
52 step.dependencies.iter().all(|dep| completed.contains(dep));
53
54 if all_deps_completed {
55 order.push(step_id.clone());
56 completed.insert(step_id.clone());
57
58 for other_step in &workflow.steps {
60 if other_step.dependencies.contains(&step_id)
61 && !completed.contains(&other_step.id)
62 {
63 queue.push_back(other_step.id.clone());
64 }
65 }
66 } else {
67 queue.push_back(step_id);
69 }
70 }
71 }
72
73 if order.len() != workflow.steps.len() {
74 return Err(WorkflowError::Invalid(
75 "Could not determine execution order for all steps".to_string(),
76 ));
77 }
78
79 Ok(order)
80 }
81
82 pub fn detect_circular_dependencies(workflow: &Workflow) -> WorkflowResult<()> {
87 let step_map: HashMap<&String, &crate::models::WorkflowStep> =
88 workflow.steps.iter().map(|s| (&s.id, s)).collect();
89
90 for start_step in &workflow.steps {
92 let mut visited = HashSet::new();
93 let mut rec_stack = HashSet::new();
94
95 Self::dfs_detect_cycle(&step_map, &start_step.id, &mut visited, &mut rec_stack)?;
96 }
97
98 Ok(())
99 }
100
101 fn dfs_detect_cycle(
103 step_map: &HashMap<&String, &crate::models::WorkflowStep>,
104 step_id: &String,
105 visited: &mut HashSet<String>,
106 rec_stack: &mut HashSet<String>,
107 ) -> WorkflowResult<()> {
108 visited.insert(step_id.clone());
109 rec_stack.insert(step_id.clone());
110
111 if let Some(step) = step_map.get(step_id) {
112 for dep in &step.dependencies {
113 if !visited.contains(dep) {
114 Self::dfs_detect_cycle(step_map, dep, visited, rec_stack)?;
115 } else if rec_stack.contains(dep) {
116 return Err(WorkflowError::Invalid(format!(
117 "Circular dependency detected: {} -> {}",
118 step_id, dep
119 )));
120 }
121 }
122 }
123
124 rec_stack.remove(step_id);
125 Ok(())
126 }
127
128 pub fn get_all_dependencies(
133 workflow: &Workflow,
134 step_id: &str,
135 ) -> WorkflowResult<HashSet<String>> {
136 let mut all_deps = HashSet::new();
137 let mut queue = VecDeque::new();
138
139 let step = workflow
141 .steps
142 .iter()
143 .find(|s| s.id == step_id)
144 .ok_or_else(|| WorkflowError::NotFound(format!("Step not found: {}", step_id)))?;
145
146 for dep in &step.dependencies {
148 queue.push_back(dep.clone());
149 }
150
151 let step_map: HashMap<_, _> = workflow.steps.iter().map(|s| (&s.id, s)).collect();
153
154 while let Some(dep_id) = queue.pop_front() {
156 if all_deps.contains(&dep_id) {
157 continue;
158 }
159
160 all_deps.insert(dep_id.clone());
161
162 if let Some(dep_step) = step_map.get(&dep_id) {
164 for transitive_dep in &dep_step.dependencies {
165 if !all_deps.contains(transitive_dep) {
166 queue.push_back(transitive_dep.clone());
167 }
168 }
169 }
170 }
171
172 Ok(all_deps)
173 }
174
175 pub fn get_dependent_steps(
179 workflow: &Workflow,
180 step_id: &str,
181 ) -> WorkflowResult<HashSet<String>> {
182 let mut dependents = HashSet::new();
183
184 for step in &workflow.steps {
186 if step.dependencies.contains(&step_id.to_string()) {
187 dependents.insert(step.id.clone());
188
189 if let Ok(transitive) = Self::get_dependent_steps(workflow, &step.id) {
191 dependents.extend(transitive);
192 }
193 }
194 }
195
196 Ok(dependents)
197 }
198
199 pub fn can_execute_step(
203 workflow: &Workflow,
204 completed_steps: &[String],
205 step_id: &str,
206 ) -> WorkflowResult<bool> {
207 let step = workflow
208 .steps
209 .iter()
210 .find(|s| s.id == step_id)
211 .ok_or_else(|| WorkflowError::NotFound(format!("Step not found: {}", step_id)))?;
212
213 for dep in &step.dependencies {
215 if !completed_steps.contains(dep) {
216 return Ok(false);
217 }
218 }
219
220 Ok(true)
221 }
222
223 pub fn get_ready_steps(
227 workflow: &Workflow,
228 completed_steps: &[String],
229 in_progress_steps: &[String],
230 ) -> WorkflowResult<Vec<String>> {
231 let mut ready = Vec::new();
232
233 for step in &workflow.steps {
234 if completed_steps.contains(&step.id) || in_progress_steps.contains(&step.id) {
236 continue;
237 }
238
239 if Self::can_execute_step(workflow, completed_steps, &step.id)? {
241 ready.push(step.id.clone());
242 }
243 }
244
245 Ok(ready)
246 }
247
248 pub fn validate_dependencies(workflow: &Workflow) -> WorkflowResult<()> {
255 let mut step_ids = HashSet::new();
257 for step in &workflow.steps {
258 if !step_ids.insert(&step.id) {
259 return Err(WorkflowError::Invalid(format!(
260 "Duplicate step id: {}",
261 step.id
262 )));
263 }
264 }
265
266 for step in &workflow.steps {
268 for dep in &step.dependencies {
269 if !step_ids.contains(dep) {
270 return Err(WorkflowError::Invalid(format!(
271 "Step {} depends on non-existent step {}",
272 step.id, dep
273 )));
274 }
275 }
276 }
277
278 Self::detect_circular_dependencies(workflow)?;
280
281 Ok(())
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{
        AgentStep, ErrorAction, RiskFactors, StepConfig, StepType, WorkflowConfig, WorkflowStep,
    };

    /// Builds an agent step bound to the shared test agent/task with the given
    /// dependency ids; every other field uses the fixture defaults.
    fn agent_step(id: &str, name: &str, dependencies: &[&str]) -> WorkflowStep {
        WorkflowStep {
            id: id.to_string(),
            name: name.to_string(),
            step_type: StepType::Agent(AgentStep {
                agent_id: "test-agent".to_string(),
                task: "test-task".to_string(),
            }),
            config: StepConfig {
                config: serde_json::json!({}),
            },
            dependencies: dependencies.iter().map(|d| d.to_string()).collect(),
            approval_required: false,
            on_error: ErrorAction::Fail,
            risk_score: None,
            risk_factors: RiskFactors::default(),
        }
    }

    /// Three-step fixture: `step1` has no dependencies, `step2` needs `step1`,
    /// and `step3` needs both `step1` and `step2`.
    fn create_workflow_with_deps() -> Workflow {
        Workflow {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            description: "A test workflow".to_string(),
            parameters: vec![],
            steps: vec![
                agent_step("step1", "Step 1", &[]),
                agent_step("step2", "Step 2", &["step1"]),
                agent_step("step3", "Step 3", &["step1", "step2"]),
            ],
            config: WorkflowConfig {
                timeout_ms: None,
                max_parallel: None,
            },
        }
    }

    #[test]
    fn test_resolve_execution_order() {
        let workflow = create_workflow_with_deps();
        let order = DependencyResolver::resolve_execution_order(&workflow)
            .expect("acyclic workflow must have an execution order");

        assert_eq!(order, vec!["step1", "step2", "step3"]);
    }

    #[test]
    fn test_detect_circular_dependency() {
        let mut workflow = create_workflow_with_deps();
        // step1 -> step2 -> step1 closes a cycle.
        workflow.steps[0].dependencies.push("step2".to_string());

        assert!(DependencyResolver::detect_circular_dependencies(&workflow).is_err());
    }

    #[test]
    fn test_get_all_dependencies() {
        let workflow = create_workflow_with_deps();
        let deps = DependencyResolver::get_all_dependencies(&workflow, "step3").unwrap();

        assert_eq!(deps.len(), 2);
        for expected in ["step1", "step2"].iter() {
            assert!(deps.contains(*expected), "missing dependency {}", expected);
        }
    }

    #[test]
    fn test_get_dependent_steps() {
        let workflow = create_workflow_with_deps();
        let dependents = DependencyResolver::get_dependent_steps(&workflow, "step1").unwrap();

        for expected in ["step2", "step3"].iter() {
            assert!(dependents.contains(*expected), "missing dependent {}", expected);
        }
    }

    #[test]
    fn test_can_execute_step() {
        let workflow = create_workflow_with_deps();
        let step1_done = vec!["step1".to_string()];

        // A root step is always runnable; step2 waits on step1.
        assert!(DependencyResolver::can_execute_step(&workflow, &[], "step1").unwrap());
        assert!(!DependencyResolver::can_execute_step(&workflow, &[], "step2").unwrap());
        assert!(DependencyResolver::can_execute_step(&workflow, &step1_done, "step2").unwrap());
    }

    #[test]
    fn test_get_ready_steps() {
        let workflow = create_workflow_with_deps();
        let cases: Vec<(Vec<String>, &str)> = vec![
            (vec![], "step1"),
            (vec!["step1".to_string()], "step2"),
            (vec!["step1".to_string(), "step2".to_string()], "step3"),
        ];

        for (completed, expected) in cases {
            let ready = DependencyResolver::get_ready_steps(&workflow, &completed, &[]).unwrap();
            assert_eq!(ready, vec![expected]);
        }
    }

    #[test]
    fn test_validate_dependencies() {
        let workflow = create_workflow_with_deps();

        assert!(DependencyResolver::validate_dependencies(&workflow).is_ok());
    }

    #[test]
    fn test_validate_missing_dependency() {
        let mut workflow = create_workflow_with_deps();
        workflow.steps[1].dependencies.push("non-existent".to_string());

        assert!(DependencyResolver::validate_dependencies(&workflow).is_err());
    }
}