// agentic_workflow/engine/dag.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2
3use chrono::Utc;
4use uuid::Uuid;
5
6use crate::types::{
7    Edge, EdgeType, ExecutionContext, ExecutionEvent, ExecutionEventType,
8    ExecutionProgress, ExecutionStatus, StepLifecycle, StepState, Workflow,
9    WorkflowError, WorkflowResult,
10};
11
12/// DAG execution engine — validates and runs workflow graphs.
13pub struct DagEngine {
14    workflows: HashMap<String, Workflow>,
15    executions: HashMap<String, ExecutionContext>,
16}
17
18impl DagEngine {
19    pub fn new() -> Self {
20        Self {
21            workflows: HashMap::new(),
22            executions: HashMap::new(),
23        }
24    }
25
26    /// Register a workflow definition.
27    pub fn register_workflow(&mut self, workflow: Workflow) -> WorkflowResult<()> {
28        self.validate_dag(&workflow)?;
29        self.workflows.insert(workflow.id.clone(), workflow);
30        Ok(())
31    }
32
33    /// Get a workflow by ID.
34    pub fn get_workflow(&self, id: &str) -> WorkflowResult<&Workflow> {
35        self.workflows
36            .get(id)
37            .ok_or_else(|| WorkflowError::WorkflowNotFound(id.to_string()))
38    }
39
40    /// Remove a workflow.
41    pub fn remove_workflow(&mut self, id: &str) -> WorkflowResult<Workflow> {
42        self.workflows
43            .remove(id)
44            .ok_or_else(|| WorkflowError::WorkflowNotFound(id.to_string()))
45    }
46
47    /// List all registered workflows.
48    pub fn list_workflows(&self) -> Vec<&Workflow> {
49        self.workflows.values().collect()
50    }
51
52    /// Validate the DAG — check for cycles and unsatisfied dependencies.
53    pub fn validate_dag(&self, workflow: &Workflow) -> WorkflowResult<()> {
54        let step_ids: HashSet<&str> = workflow.steps.iter().map(|s| s.id.as_str()).collect();
55
56        // Check all edges reference valid steps
57        for edge in &workflow.edges {
58            if !step_ids.contains(edge.from.as_str()) {
59                return Err(WorkflowError::StepNotFound(edge.from.clone()));
60            }
61            if !step_ids.contains(edge.to.as_str()) {
62                return Err(WorkflowError::StepNotFound(edge.to.clone()));
63            }
64        }
65
66        // Topological sort to detect cycles
67        self.topological_sort(workflow)?;
68        Ok(())
69    }
70
71    /// Topological sort of steps — returns execution order or error if cycle.
72    pub fn topological_sort(&self, workflow: &Workflow) -> WorkflowResult<Vec<String>> {
73        let mut in_degree: HashMap<&str, usize> = HashMap::new();
74        let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
75
76        for step in &workflow.steps {
77            in_degree.entry(step.id.as_str()).or_insert(0);
78            adjacency.entry(step.id.as_str()).or_default();
79        }
80
81        for edge in &workflow.edges {
82            *in_degree.entry(edge.to.as_str()).or_insert(0) += 1;
83            adjacency
84                .entry(edge.from.as_str())
85                .or_default()
86                .push(edge.to.as_str());
87        }
88
89        let mut queue: VecDeque<&str> = in_degree
90            .iter()
91            .filter(|(_, &deg)| deg == 0)
92            .map(|(&id, _)| id)
93            .collect();
94
95        let mut order = Vec::new();
96
97        while let Some(node) = queue.pop_front() {
98            order.push(node.to_string());
99            if let Some(neighbors) = adjacency.get(node) {
100                for &neighbor in neighbors {
101                    if let Some(deg) = in_degree.get_mut(neighbor) {
102                        *deg -= 1;
103                        if *deg == 0 {
104                            queue.push_back(neighbor);
105                        }
106                    }
107                }
108            }
109        }
110
111        if order.len() != workflow.steps.len() {
112            return Err(WorkflowError::CycleDetected(
113                "DAG contains a cycle".to_string(),
114            ));
115        }
116
117        Ok(order)
118    }
119
120    /// Start a new execution of a workflow.
121    pub fn start_execution(&mut self, workflow_id: &str) -> WorkflowResult<String> {
122        let workflow = self
123            .workflows
124            .get(workflow_id)
125            .ok_or_else(|| WorkflowError::WorkflowNotFound(workflow_id.to_string()))?
126            .clone();
127
128        let execution_id = Uuid::new_v4().to_string();
129        let now = Utc::now();
130
131        let mut step_states = HashMap::new();
132        for step in &workflow.steps {
133            step_states.insert(
134                step.id.clone(),
135                StepState {
136                    step_id: step.id.clone(),
137                    lifecycle: StepLifecycle::Pending,
138                    attempt: 0,
139                    started_at: None,
140                    completed_at: None,
141                    duration_ms: None,
142                    output: None,
143                    error: None,
144                },
145            );
146        }
147
148        let ctx = ExecutionContext {
149            execution_id: execution_id.clone(),
150            workflow_id: workflow_id.to_string(),
151            status: ExecutionStatus::Running,
152            step_states,
153            variables: HashMap::new(),
154            started_at: now,
155            completed_at: None,
156            trigger_info: None,
157            metadata: HashMap::new(),
158        };
159
160        self.executions.insert(execution_id.clone(), ctx);
161        Ok(execution_id)
162    }
163
164    /// Get execution progress.
165    pub fn get_progress(&self, execution_id: &str) -> WorkflowResult<ExecutionProgress> {
166        let ctx = self
167            .executions
168            .get(execution_id)
169            .ok_or_else(|| WorkflowError::ExecutionNotFound(execution_id.to_string()))?;
170
171        let total = ctx.step_states.len();
172        let completed = ctx.step_states.values().filter(|s| s.lifecycle == StepLifecycle::Success).count();
173        let failed = ctx.step_states.values().filter(|s| s.lifecycle == StepLifecycle::Failed).count();
174        let skipped = ctx.step_states.values().filter(|s| s.lifecycle == StepLifecycle::Skipped).count();
175        let running = ctx.step_states.values().filter(|s| s.lifecycle == StepLifecycle::Running).count();
176        let pending = ctx.step_states.values().filter(|s| s.lifecycle == StepLifecycle::Pending || s.lifecycle == StepLifecycle::Queued).count();
177
178        let percent = if total > 0 {
179            (completed as f64 / total as f64) * 100.0
180        } else {
181            0.0
182        };
183
184        Ok(ExecutionProgress {
185            execution_id: execution_id.to_string(),
186            total_steps: total,
187            completed_steps: completed,
188            failed_steps: failed,
189            skipped_steps: skipped,
190            running_steps: running,
191            pending_steps: pending,
192            estimated_remaining_ms: None,
193            percent_complete: percent,
194        })
195    }
196
197    /// Pause a running execution.
198    pub fn pause_execution(&mut self, execution_id: &str) -> WorkflowResult<()> {
199        let ctx = self
200            .executions
201            .get_mut(execution_id)
202            .ok_or_else(|| WorkflowError::ExecutionNotFound(execution_id.to_string()))?;
203
204        if ctx.status != ExecutionStatus::Running {
205            return Err(WorkflowError::Internal(format!(
206                "Cannot pause execution in state {:?}",
207                ctx.status
208            )));
209        }
210
211        ctx.status = ExecutionStatus::Paused;
212        Ok(())
213    }
214
215    /// Resume a paused execution.
216    pub fn resume_execution(&mut self, execution_id: &str) -> WorkflowResult<()> {
217        let ctx = self
218            .executions
219            .get_mut(execution_id)
220            .ok_or_else(|| WorkflowError::ExecutionNotFound(execution_id.to_string()))?;
221
222        if ctx.status != ExecutionStatus::Paused {
223            return Err(WorkflowError::ExecutionNotPaused(execution_id.to_string()));
224        }
225
226        ctx.status = ExecutionStatus::Running;
227        Ok(())
228    }
229
230    /// Cancel a running execution.
231    pub fn cancel_execution(&mut self, execution_id: &str) -> WorkflowResult<()> {
232        let ctx = self
233            .executions
234            .get_mut(execution_id)
235            .ok_or_else(|| WorkflowError::ExecutionNotFound(execution_id.to_string()))?;
236
237        ctx.status = ExecutionStatus::Cancelled;
238        ctx.completed_at = Some(Utc::now());
239        Ok(())
240    }
241
242    /// Get execution context.
243    pub fn get_execution(&self, execution_id: &str) -> WorkflowResult<&ExecutionContext> {
244        self.executions
245            .get(execution_id)
246            .ok_or_else(|| WorkflowError::ExecutionNotFound(execution_id.to_string()))
247    }
248
249    /// Generate a Mermaid diagram for a workflow.
250    pub fn visualize_mermaid(&self, workflow_id: &str) -> WorkflowResult<String> {
251        let wf = self.get_workflow(workflow_id)?;
252        let mut lines = vec!["graph TD".to_string()];
253
254        for step in &wf.steps {
255            lines.push(format!("    {}[{}]", step.id, step.name));
256        }
257
258        for edge in &wf.edges {
259            let label = match &edge.edge_type {
260                EdgeType::Sequence => "".to_string(),
261                EdgeType::Parallel => "|parallel|".to_string(),
262                EdgeType::Conditional { expression } => format!("|{}|", expression),
263                EdgeType::Loop { .. } => "|loop|".to_string(),
264            };
265            lines.push(format!("    {} -->{}  {}", edge.from, label, edge.to));
266        }
267
268        Ok(lines.join("\n"))
269    }
270}
271
272impl Default for DagEngine {
273    fn default() -> Self {
274        Self::new()
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{StepNode, StepType};

    /// A two-step linear workflow sorts in edge order, validates, and can be
    /// registered and looked up again.
    #[test]
    fn test_create_and_validate_workflow() {
        let mut engine = DagEngine::new();
        let mut wf = Workflow::new("test-wf", "A test workflow");

        let step1 = StepNode::new("Step 1", StepType::Noop);
        let step2 = StepNode::new("Step 2", StepType::Noop);
        let s1_id = step1.id.clone();
        let s2_id = step2.id.clone();

        wf.add_step(step1);
        wf.add_step(step2);
        wf.add_edge(Edge {
            from: s1_id.clone(),
            to: s2_id.clone(),
            edge_type: EdgeType::Sequence,
        });

        // A single chain has exactly one valid ordering.
        let order = engine.topological_sort(&wf).unwrap();
        let order: Vec<&str> = order.iter().map(String::as_str).collect();
        assert_eq!(order, vec![s1_id.as_str(), s2_id.as_str()]);

        let wf_id = wf.id.clone();
        assert!(engine.register_workflow(wf).is_ok());
        assert!(engine.get_workflow(&wf_id).is_ok());
        assert_eq!(engine.list_workflows().len(), 1);
    }

    /// Two steps that depend on each other are rejected specifically as a
    /// cycle, not with some unrelated error.
    #[test]
    fn test_cycle_detection() {
        let engine = DagEngine::new();
        let mut wf = Workflow::new("cyclic", "Cyclic workflow");

        let s1 = StepNode::new("A", StepType::Noop);
        let s2 = StepNode::new("B", StepType::Noop);
        let s1_id = s1.id.clone();
        let s2_id = s2.id.clone();

        wf.add_step(s1);
        wf.add_step(s2);
        wf.add_edge(Edge {
            from: s1_id.clone(),
            to: s2_id.clone(),
            edge_type: EdgeType::Sequence,
        });
        wf.add_edge(Edge {
            from: s2_id,
            to: s1_id,
            edge_type: EdgeType::Sequence,
        });

        assert!(matches!(
            engine.validate_dag(&wf),
            Err(WorkflowError::CycleDetected(_))
        ));
    }

    /// Start -> pause -> resume -> cancel, checking the status after each
    /// transition and that invalid transitions are refused.
    #[test]
    fn test_execution_lifecycle() {
        let mut engine = DagEngine::new();
        let wf = Workflow::new("lifecycle", "Test lifecycle");
        let wf_id = wf.id.clone();
        engine.register_workflow(wf).unwrap();

        let exec_id = engine.start_execution(&wf_id).unwrap();
        assert_eq!(
            engine.get_execution(&exec_id).unwrap().status,
            ExecutionStatus::Running
        );

        // Resuming something that is not paused must fail.
        assert!(engine.resume_execution(&exec_id).is_err());

        assert!(engine.pause_execution(&exec_id).is_ok());
        assert_eq!(
            engine.get_execution(&exec_id).unwrap().status,
            ExecutionStatus::Paused
        );

        assert!(engine.resume_execution(&exec_id).is_ok());
        assert_eq!(
            engine.get_execution(&exec_id).unwrap().status,
            ExecutionStatus::Running
        );

        assert!(engine.cancel_execution(&exec_id).is_ok());
        let ctx = engine.get_execution(&exec_id).unwrap();
        assert_eq!(ctx.status, ExecutionStatus::Cancelled);
        assert!(ctx.completed_at.is_some());

        // A cancelled execution can no longer be paused.
        assert!(engine.pause_execution(&exec_id).is_err());
    }
}