ricecoder_workflows/
resolver.rs

1//! Dependency resolution for workflow steps
2
3use crate::error::{WorkflowError, WorkflowResult};
4use crate::models::Workflow;
5use std::collections::{HashMap, HashSet, VecDeque};
6
/// Resolves step dependencies and builds execution order
///
/// Handles:
/// - Building execution order from dependency graph
/// - Detecting and reporting circular dependencies
/// - Determining which steps are ready to run once their dependencies have completed
///
/// All functionality is exposed as associated functions; the type itself carries no state.
pub struct DependencyResolver;
14
15impl DependencyResolver {
16    /// Build execution order from dependency graph
17    ///
18    /// Uses topological sort to determine the order in which steps should execute
19    /// based on their dependencies. Returns error if circular dependencies are detected.
20    pub fn resolve_execution_order(workflow: &Workflow) -> WorkflowResult<Vec<String>> {
21        Self::topological_sort(workflow)
22    }
23
24    /// Perform topological sort on workflow steps
25    ///
26    /// Builds a valid execution order where all dependencies are satisfied
27    /// before a step is executed.
28    fn topological_sort(workflow: &Workflow) -> WorkflowResult<Vec<String>> {
29        let mut order = Vec::new();
30        let mut completed = HashSet::new();
31        let mut queue = VecDeque::new();
32
33        // Find all steps with no dependencies
34        for step in &workflow.steps {
35            if step.dependencies.is_empty() {
36                queue.push_back(step.id.clone());
37            }
38        }
39
40        // Build step map for quick lookup
41        let step_map: HashMap<_, _> = workflow.steps.iter().map(|s| (&s.id, s)).collect();
42
43        // Process queue
44        while let Some(step_id) = queue.pop_front() {
45            if completed.contains(&step_id) {
46                continue;
47            }
48
49            // Check if all dependencies are completed
50            if let Some(step) = step_map.get(&step_id) {
51                let all_deps_completed =
52                    step.dependencies.iter().all(|dep| completed.contains(dep));
53
54                if all_deps_completed {
55                    order.push(step_id.clone());
56                    completed.insert(step_id.clone());
57
58                    // Add steps that depend on this one
59                    for other_step in &workflow.steps {
60                        if other_step.dependencies.contains(&step_id)
61                            && !completed.contains(&other_step.id)
62                        {
63                            queue.push_back(other_step.id.clone());
64                        }
65                    }
66                } else {
67                    // Re-queue if dependencies not met
68                    queue.push_back(step_id);
69                }
70            }
71        }
72
73        if order.len() != workflow.steps.len() {
74            return Err(WorkflowError::Invalid(
75                "Could not determine execution order for all steps".to_string(),
76            ));
77        }
78
79        Ok(order)
80    }
81
82    /// Detect circular dependencies in workflow
83    ///
84    /// Uses depth-first search to detect cycles in the dependency graph.
85    /// Returns error if any circular dependency is found.
86    pub fn detect_circular_dependencies(workflow: &Workflow) -> WorkflowResult<()> {
87        let step_map: HashMap<&String, &crate::models::WorkflowStep> =
88            workflow.steps.iter().map(|s| (&s.id, s)).collect();
89
90        // For each step, perform DFS to detect cycles
91        for start_step in &workflow.steps {
92            let mut visited = HashSet::new();
93            let mut rec_stack = HashSet::new();
94
95            Self::dfs_detect_cycle(&step_map, &start_step.id, &mut visited, &mut rec_stack)?;
96        }
97
98        Ok(())
99    }
100
101    /// Depth-first search to detect cycles
102    fn dfs_detect_cycle(
103        step_map: &HashMap<&String, &crate::models::WorkflowStep>,
104        step_id: &String,
105        visited: &mut HashSet<String>,
106        rec_stack: &mut HashSet<String>,
107    ) -> WorkflowResult<()> {
108        visited.insert(step_id.clone());
109        rec_stack.insert(step_id.clone());
110
111        if let Some(step) = step_map.get(step_id) {
112            for dep in &step.dependencies {
113                if !visited.contains(dep) {
114                    Self::dfs_detect_cycle(step_map, dep, visited, rec_stack)?;
115                } else if rec_stack.contains(dep) {
116                    return Err(WorkflowError::Invalid(format!(
117                        "Circular dependency detected: {} -> {}",
118                        step_id, dep
119                    )));
120                }
121            }
122        }
123
124        rec_stack.remove(step_id);
125        Ok(())
126    }
127
128    /// Get all dependencies for a step (transitive closure)
129    ///
130    /// Returns all steps that must be completed before the given step can execute,
131    /// including transitive dependencies.
132    pub fn get_all_dependencies(
133        workflow: &Workflow,
134        step_id: &str,
135    ) -> WorkflowResult<HashSet<String>> {
136        let mut all_deps = HashSet::new();
137        let mut queue = VecDeque::new();
138
139        // Find the step
140        let step = workflow
141            .steps
142            .iter()
143            .find(|s| s.id == step_id)
144            .ok_or_else(|| WorkflowError::NotFound(format!("Step not found: {}", step_id)))?;
145
146        // Add direct dependencies to queue
147        for dep in &step.dependencies {
148            queue.push_back(dep.clone());
149        }
150
151        // Build step map
152        let step_map: HashMap<_, _> = workflow.steps.iter().map(|s| (&s.id, s)).collect();
153
154        // Process queue to find transitive dependencies
155        while let Some(dep_id) = queue.pop_front() {
156            if all_deps.contains(&dep_id) {
157                continue;
158            }
159
160            all_deps.insert(dep_id.clone());
161
162            // Add dependencies of this dependency
163            if let Some(dep_step) = step_map.get(&dep_id) {
164                for transitive_dep in &dep_step.dependencies {
165                    if !all_deps.contains(transitive_dep) {
166                        queue.push_back(transitive_dep.clone());
167                    }
168                }
169            }
170        }
171
172        Ok(all_deps)
173    }
174
175    /// Get all steps that depend on a given step (reverse dependencies)
176    ///
177    /// Returns all steps that have the given step as a dependency (direct or transitive).
178    pub fn get_dependent_steps(
179        workflow: &Workflow,
180        step_id: &str,
181    ) -> WorkflowResult<HashSet<String>> {
182        let mut dependents = HashSet::new();
183
184        // Find all steps that directly depend on this step
185        for step in &workflow.steps {
186            if step.dependencies.contains(&step_id.to_string()) {
187                dependents.insert(step.id.clone());
188
189                // Recursively find steps that depend on these steps
190                if let Ok(transitive) = Self::get_dependent_steps(workflow, &step.id) {
191                    dependents.extend(transitive);
192                }
193            }
194        }
195
196        Ok(dependents)
197    }
198
199    /// Check if a step can be executed
200    ///
201    /// A step can be executed if all its dependencies have been completed.
202    pub fn can_execute_step(
203        workflow: &Workflow,
204        completed_steps: &[String],
205        step_id: &str,
206    ) -> WorkflowResult<bool> {
207        let step = workflow
208            .steps
209            .iter()
210            .find(|s| s.id == step_id)
211            .ok_or_else(|| WorkflowError::NotFound(format!("Step not found: {}", step_id)))?;
212
213        // Check if all dependencies are completed
214        for dep in &step.dependencies {
215            if !completed_steps.contains(dep) {
216                return Ok(false);
217            }
218        }
219
220        Ok(true)
221    }
222
223    /// Get steps that are ready to execute
224    ///
225    /// Returns all steps whose dependencies are satisfied and haven't been executed yet.
226    pub fn get_ready_steps(
227        workflow: &Workflow,
228        completed_steps: &[String],
229        in_progress_steps: &[String],
230    ) -> WorkflowResult<Vec<String>> {
231        let mut ready = Vec::new();
232
233        for step in &workflow.steps {
234            // Skip if already completed or in progress
235            if completed_steps.contains(&step.id) || in_progress_steps.contains(&step.id) {
236                continue;
237            }
238
239            // Check if all dependencies are completed
240            if Self::can_execute_step(workflow, completed_steps, &step.id)? {
241                ready.push(step.id.clone());
242            }
243        }
244
245        Ok(ready)
246    }
247
248    /// Validate dependency graph
249    ///
250    /// Checks for:
251    /// - Missing dependencies (references to non-existent steps)
252    /// - Circular dependencies
253    /// - Duplicate step IDs
254    pub fn validate_dependencies(workflow: &Workflow) -> WorkflowResult<()> {
255        // Check for duplicate step IDs
256        let mut step_ids = HashSet::new();
257        for step in &workflow.steps {
258            if !step_ids.insert(&step.id) {
259                return Err(WorkflowError::Invalid(format!(
260                    "Duplicate step id: {}",
261                    step.id
262                )));
263            }
264        }
265
266        // Check for missing dependencies
267        for step in &workflow.steps {
268            for dep in &step.dependencies {
269                if !step_ids.contains(dep) {
270                    return Err(WorkflowError::Invalid(format!(
271                        "Step {} depends on non-existent step {}",
272                        step.id, dep
273                    )));
274                }
275            }
276        }
277
278        // Check for circular dependencies
279        Self::detect_circular_dependencies(workflow)?;
280
281        Ok(())
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{
        AgentStep, ErrorAction, RiskFactors, StepConfig, StepType, WorkflowConfig, WorkflowStep,
    };

    /// Build an agent step with the given id, name and dependency ids; every other
    /// field uses the same fixed test values.
    fn make_step(id: &str, name: &str, deps: &[&str]) -> WorkflowStep {
        WorkflowStep {
            id: id.to_string(),
            name: name.to_string(),
            step_type: StepType::Agent(AgentStep {
                agent_id: "test-agent".to_string(),
                task: "test-task".to_string(),
            }),
            config: StepConfig {
                config: serde_json::json!({}),
            },
            dependencies: deps.iter().map(|d| d.to_string()).collect(),
            approval_required: false,
            on_error: ErrorAction::Fail,
            risk_score: None,
            risk_factors: RiskFactors::default(),
        }
    }

    /// step1 <- step2 <- step3, with step3 also depending directly on step1.
    fn create_workflow_with_deps() -> Workflow {
        Workflow {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            description: "A test workflow".to_string(),
            parameters: vec![],
            steps: vec![
                make_step("step1", "Step 1", &[]),
                make_step("step2", "Step 2", &["step1"]),
                make_step("step3", "Step 3", &["step1", "step2"]),
            ],
            config: WorkflowConfig {
                timeout_ms: None,
                max_parallel: None,
            },
        }
    }

    #[test]
    fn test_resolve_execution_order() {
        let workflow = create_workflow_with_deps();
        let order = DependencyResolver::resolve_execution_order(&workflow).unwrap();

        assert_eq!(order.len(), 3);
        assert_eq!(order[0], "step1");
        assert_eq!(order[1], "step2");
        assert_eq!(order[2], "step3");
    }

    #[test]
    fn test_resolve_execution_order_rejects_cycle_without_root() {
        let mut workflow = create_workflow_with_deps();
        workflow.steps = vec![make_step("a", "A", &["b"]), make_step("b", "B", &["a"])];

        assert!(DependencyResolver::resolve_execution_order(&workflow).is_err());
    }

    #[test]
    fn test_detect_circular_dependency() {
        let mut workflow = create_workflow_with_deps();
        // Create a circular dependency: step1 -> step2 -> step1
        workflow.steps[0].dependencies.push("step2".to_string());

        let result = DependencyResolver::detect_circular_dependencies(&workflow);
        assert!(result.is_err());
    }

    #[test]
    fn test_detect_self_dependency() {
        let mut workflow = create_workflow_with_deps();
        workflow.steps[0].dependencies.push("step1".to_string());

        assert!(DependencyResolver::detect_circular_dependencies(&workflow).is_err());
    }

    #[test]
    fn test_get_all_dependencies() {
        let workflow = create_workflow_with_deps();
        let deps = DependencyResolver::get_all_dependencies(&workflow, "step3").unwrap();

        assert_eq!(deps.len(), 2);
        assert!(deps.contains("step1"));
        assert!(deps.contains("step2"));
    }

    #[test]
    fn test_get_all_dependencies_unknown_step() {
        let workflow = create_workflow_with_deps();
        assert!(DependencyResolver::get_all_dependencies(&workflow, "missing").is_err());
    }

    #[test]
    fn test_get_dependent_steps() {
        let workflow = create_workflow_with_deps();
        let dependents = DependencyResolver::get_dependent_steps(&workflow, "step1").unwrap();

        assert_eq!(dependents.len(), 2);
        assert!(dependents.contains("step2"));
        assert!(dependents.contains("step3"));
    }

    #[test]
    fn test_get_dependent_steps_of_leaf_is_empty() {
        let workflow = create_workflow_with_deps();
        let dependents = DependencyResolver::get_dependent_steps(&workflow, "step3").unwrap();

        assert!(dependents.is_empty());
    }

    #[test]
    fn test_can_execute_step() {
        let workflow = create_workflow_with_deps();

        // step1 can execute (no dependencies)
        assert!(DependencyResolver::can_execute_step(&workflow, &[], "step1").unwrap());

        // step2 cannot execute (depends on step1)
        assert!(!DependencyResolver::can_execute_step(&workflow, &[], "step2").unwrap());

        // step2 can execute after step1 is completed
        assert!(
            DependencyResolver::can_execute_step(&workflow, &["step1".to_string()], "step2")
                .unwrap()
        );
    }

    #[test]
    fn test_get_ready_steps() {
        let workflow = create_workflow_with_deps();

        // Initially, only step1 is ready
        let ready = DependencyResolver::get_ready_steps(&workflow, &[], &[]).unwrap();
        assert_eq!(ready.len(), 1);
        assert_eq!(ready[0], "step1");

        // After step1 completes, step2 is ready
        let ready =
            DependencyResolver::get_ready_steps(&workflow, &["step1".to_string()], &[]).unwrap();
        assert_eq!(ready.len(), 1);
        assert_eq!(ready[0], "step2");

        // After step1 and step2 complete, step3 is ready
        let ready = DependencyResolver::get_ready_steps(
            &workflow,
            &["step1".to_string(), "step2".to_string()],
            &[],
        )
        .unwrap();
        assert_eq!(ready.len(), 1);
        assert_eq!(ready[0], "step3");
    }

    #[test]
    fn test_validate_dependencies() {
        let workflow = create_workflow_with_deps();
        let result = DependencyResolver::validate_dependencies(&workflow);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_missing_dependency() {
        let mut workflow = create_workflow_with_deps();
        workflow.steps[1]
            .dependencies
            .push("non-existent".to_string());

        let result = DependencyResolver::validate_dependencies(&workflow);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_duplicate_step_id() {
        let mut workflow = create_workflow_with_deps();
        workflow
            .steps
            .push(make_step("step1", "Step 1 (duplicate)", &[]));

        assert!(DependencyResolver::validate_dependencies(&workflow).is_err());
    }
}