1use std::collections::{HashMap, HashSet, VecDeque};
2
3use chrono::Utc;
4use uuid::Uuid;
5
6use crate::types::{
7 Edge, EdgeType, ExecutionContext, ExecutionEvent, ExecutionEventType,
8 ExecutionProgress, ExecutionStatus, StepLifecycle, StepState, Workflow,
9 WorkflowError, WorkflowResult,
10};
11
/// Registry of workflow definitions together with the execution contexts
/// started from them.
///
/// All state is held in memory in the two maps below.
pub struct DagEngine {
    /// Registered workflows, keyed by the workflow's `id`.
    workflows: HashMap<String, Workflow>,
    /// Execution contexts, keyed by the execution id handed out by `start_execution`.
    executions: HashMap<String, ExecutionContext>,
}
17
18impl DagEngine {
19 pub fn new() -> Self {
20 Self {
21 workflows: HashMap::new(),
22 executions: HashMap::new(),
23 }
24 }
25
26 pub fn register_workflow(&mut self, workflow: Workflow) -> WorkflowResult<()> {
28 self.validate_dag(&workflow)?;
29 self.workflows.insert(workflow.id.clone(), workflow);
30 Ok(())
31 }
32
33 pub fn get_workflow(&self, id: &str) -> WorkflowResult<&Workflow> {
35 self.workflows
36 .get(id)
37 .ok_or_else(|| WorkflowError::WorkflowNotFound(id.to_string()))
38 }
39
40 pub fn remove_workflow(&mut self, id: &str) -> WorkflowResult<Workflow> {
42 self.workflows
43 .remove(id)
44 .ok_or_else(|| WorkflowError::WorkflowNotFound(id.to_string()))
45 }
46
47 pub fn list_workflows(&self) -> Vec<&Workflow> {
49 self.workflows.values().collect()
50 }
51
52 pub fn validate_dag(&self, workflow: &Workflow) -> WorkflowResult<()> {
54 let step_ids: HashSet<&str> = workflow.steps.iter().map(|s| s.id.as_str()).collect();
55
56 for edge in &workflow.edges {
58 if !step_ids.contains(edge.from.as_str()) {
59 return Err(WorkflowError::StepNotFound(edge.from.clone()));
60 }
61 if !step_ids.contains(edge.to.as_str()) {
62 return Err(WorkflowError::StepNotFound(edge.to.clone()));
63 }
64 }
65
66 self.topological_sort(workflow)?;
68 Ok(())
69 }
70
71 pub fn topological_sort(&self, workflow: &Workflow) -> WorkflowResult<Vec<String>> {
73 let mut in_degree: HashMap<&str, usize> = HashMap::new();
74 let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
75
76 for step in &workflow.steps {
77 in_degree.entry(step.id.as_str()).or_insert(0);
78 adjacency.entry(step.id.as_str()).or_default();
79 }
80
81 for edge in &workflow.edges {
82 *in_degree.entry(edge.to.as_str()).or_insert(0) += 1;
83 adjacency
84 .entry(edge.from.as_str())
85 .or_default()
86 .push(edge.to.as_str());
87 }
88
89 let mut queue: VecDeque<&str> = in_degree
90 .iter()
91 .filter(|(_, °)| deg == 0)
92 .map(|(&id, _)| id)
93 .collect();
94
95 let mut order = Vec::new();
96
97 while let Some(node) = queue.pop_front() {
98 order.push(node.to_string());
99 if let Some(neighbors) = adjacency.get(node) {
100 for &neighbor in neighbors {
101 if let Some(deg) = in_degree.get_mut(neighbor) {
102 *deg -= 1;
103 if *deg == 0 {
104 queue.push_back(neighbor);
105 }
106 }
107 }
108 }
109 }
110
111 if order.len() != workflow.steps.len() {
112 return Err(WorkflowError::CycleDetected(
113 "DAG contains a cycle".to_string(),
114 ));
115 }
116
117 Ok(order)
118 }
119
120 pub fn start_execution(&mut self, workflow_id: &str) -> WorkflowResult<String> {
122 let workflow = self
123 .workflows
124 .get(workflow_id)
125 .ok_or_else(|| WorkflowError::WorkflowNotFound(workflow_id.to_string()))?
126 .clone();
127
128 let execution_id = Uuid::new_v4().to_string();
129 let now = Utc::now();
130
131 let mut step_states = HashMap::new();
132 for step in &workflow.steps {
133 step_states.insert(
134 step.id.clone(),
135 StepState {
136 step_id: step.id.clone(),
137 lifecycle: StepLifecycle::Pending,
138 attempt: 0,
139 started_at: None,
140 completed_at: None,
141 duration_ms: None,
142 output: None,
143 error: None,
144 },
145 );
146 }
147
148 let ctx = ExecutionContext {
149 execution_id: execution_id.clone(),
150 workflow_id: workflow_id.to_string(),
151 status: ExecutionStatus::Running,
152 step_states,
153 variables: HashMap::new(),
154 started_at: now,
155 completed_at: None,
156 trigger_info: None,
157 metadata: HashMap::new(),
158 };
159
160 self.executions.insert(execution_id.clone(), ctx);
161 Ok(execution_id)
162 }
163
164 pub fn get_progress(&self, execution_id: &str) -> WorkflowResult<ExecutionProgress> {
166 let ctx = self
167 .executions
168 .get(execution_id)
169 .ok_or_else(|| WorkflowError::ExecutionNotFound(execution_id.to_string()))?;
170
171 let total = ctx.step_states.len();
172 let completed = ctx.step_states.values().filter(|s| s.lifecycle == StepLifecycle::Success).count();
173 let failed = ctx.step_states.values().filter(|s| s.lifecycle == StepLifecycle::Failed).count();
174 let skipped = ctx.step_states.values().filter(|s| s.lifecycle == StepLifecycle::Skipped).count();
175 let running = ctx.step_states.values().filter(|s| s.lifecycle == StepLifecycle::Running).count();
176 let pending = ctx.step_states.values().filter(|s| s.lifecycle == StepLifecycle::Pending || s.lifecycle == StepLifecycle::Queued).count();
177
178 let percent = if total > 0 {
179 (completed as f64 / total as f64) * 100.0
180 } else {
181 0.0
182 };
183
184 Ok(ExecutionProgress {
185 execution_id: execution_id.to_string(),
186 total_steps: total,
187 completed_steps: completed,
188 failed_steps: failed,
189 skipped_steps: skipped,
190 running_steps: running,
191 pending_steps: pending,
192 estimated_remaining_ms: None,
193 percent_complete: percent,
194 })
195 }
196
197 pub fn pause_execution(&mut self, execution_id: &str) -> WorkflowResult<()> {
199 let ctx = self
200 .executions
201 .get_mut(execution_id)
202 .ok_or_else(|| WorkflowError::ExecutionNotFound(execution_id.to_string()))?;
203
204 if ctx.status != ExecutionStatus::Running {
205 return Err(WorkflowError::Internal(format!(
206 "Cannot pause execution in state {:?}",
207 ctx.status
208 )));
209 }
210
211 ctx.status = ExecutionStatus::Paused;
212 Ok(())
213 }
214
215 pub fn resume_execution(&mut self, execution_id: &str) -> WorkflowResult<()> {
217 let ctx = self
218 .executions
219 .get_mut(execution_id)
220 .ok_or_else(|| WorkflowError::ExecutionNotFound(execution_id.to_string()))?;
221
222 if ctx.status != ExecutionStatus::Paused {
223 return Err(WorkflowError::ExecutionNotPaused(execution_id.to_string()));
224 }
225
226 ctx.status = ExecutionStatus::Running;
227 Ok(())
228 }
229
230 pub fn cancel_execution(&mut self, execution_id: &str) -> WorkflowResult<()> {
232 let ctx = self
233 .executions
234 .get_mut(execution_id)
235 .ok_or_else(|| WorkflowError::ExecutionNotFound(execution_id.to_string()))?;
236
237 ctx.status = ExecutionStatus::Cancelled;
238 ctx.completed_at = Some(Utc::now());
239 Ok(())
240 }
241
242 pub fn get_execution(&self, execution_id: &str) -> WorkflowResult<&ExecutionContext> {
244 self.executions
245 .get(execution_id)
246 .ok_or_else(|| WorkflowError::ExecutionNotFound(execution_id.to_string()))
247 }
248
249 pub fn visualize_mermaid(&self, workflow_id: &str) -> WorkflowResult<String> {
251 let wf = self.get_workflow(workflow_id)?;
252 let mut lines = vec!["graph TD".to_string()];
253
254 for step in &wf.steps {
255 lines.push(format!(" {}[{}]", step.id, step.name));
256 }
257
258 for edge in &wf.edges {
259 let label = match &edge.edge_type {
260 EdgeType::Sequence => "".to_string(),
261 EdgeType::Parallel => "|parallel|".to_string(),
262 EdgeType::Conditional { expression } => format!("|{}|", expression),
263 EdgeType::Loop { .. } => "|loop|".to_string(),
264 };
265 lines.push(format!(" {} -->{} {}", edge.from, label, edge.to));
266 }
267
268 Ok(lines.join("\n"))
269 }
270}
271
272impl Default for DagEngine {
273 fn default() -> Self {
274 Self::new()
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{StepNode, StepType};

    /// Builds a workflow whose steps are chained `steps[0] -> steps[1] -> ...`
    /// with sequence edges. Returns it together with the generated step ids, in
    /// declaration order.
    fn chain(name: &str, steps: &[&str]) -> (Workflow, Vec<String>) {
        let mut wf = Workflow::new(name, "Chained test workflow");
        let mut ids = Vec::with_capacity(steps.len());
        for &step_name in steps {
            let step = StepNode::new(step_name, StepType::Noop);
            ids.push(step.id.clone());
            wf.add_step(step);
        }
        for pair in ids.windows(2) {
            wf.add_edge(Edge {
                from: pair[0].clone(),
                to: pair[1].clone(),
                edge_type: EdgeType::Sequence,
            });
        }
        (wf, ids)
    }

    #[test]
    fn test_create_and_validate_workflow() {
        let mut engine = DagEngine::new();
        let (wf, _) = chain("test-wf", &["Step 1", "Step 2"]);
        let wf_id = wf.id.clone();

        assert!(engine.register_workflow(wf).is_ok());
        assert!(engine.get_workflow(&wf_id).is_ok());
        assert_eq!(engine.list_workflows().len(), 1);
    }

    #[test]
    fn test_cycle_detection() {
        let engine = DagEngine::new();
        let (mut wf, ids) = chain("cyclic", &["A", "B"]);
        wf.add_edge(Edge {
            from: ids[1].clone(),
            to: ids[0].clone(),
            edge_type: EdgeType::Sequence,
        });

        let err = engine.validate_dag(&wf).unwrap_err();
        assert!(matches!(err, WorkflowError::CycleDetected(_)));
    }

    #[test]
    fn test_unknown_edge_endpoint_is_rejected() {
        let engine = DagEngine::new();
        let (mut wf, ids) = chain("dangling", &["A"]);
        wf.add_edge(Edge {
            from: ids[0].clone(),
            to: "missing".to_string(),
            edge_type: EdgeType::Sequence,
        });

        let err = engine.validate_dag(&wf).unwrap_err();
        assert!(matches!(&err, WorkflowError::StepNotFound(id) if id == "missing"));
    }

    #[test]
    fn test_topological_sort_respects_edges() {
        let engine = DagEngine::new();
        let (wf, ids) = chain("linear", &["first", "second", "third"]);

        assert_eq!(engine.topological_sort(&wf).unwrap(), ids);
    }

    #[test]
    fn test_execution_lifecycle() {
        let mut engine = DagEngine::new();
        let wf = Workflow::new("lifecycle", "Test lifecycle");
        let wf_id = wf.id.clone();
        engine.register_workflow(wf).unwrap();

        let exec_id = engine.start_execution(&wf_id).unwrap();
        assert_eq!(
            engine.get_execution(&exec_id).unwrap().status,
            ExecutionStatus::Running
        );
        assert!(
            engine.resume_execution(&exec_id).is_err(),
            "a running execution cannot be resumed"
        );

        assert!(engine.pause_execution(&exec_id).is_ok());
        assert_eq!(
            engine.get_execution(&exec_id).unwrap().status,
            ExecutionStatus::Paused
        );
        assert!(
            engine.pause_execution(&exec_id).is_err(),
            "a paused execution cannot be paused again"
        );

        assert!(engine.resume_execution(&exec_id).is_ok());
        assert_eq!(
            engine.get_execution(&exec_id).unwrap().status,
            ExecutionStatus::Running
        );

        assert!(engine.cancel_execution(&exec_id).is_ok());
        let ctx = engine.get_execution(&exec_id).unwrap();
        assert_eq!(ctx.status, ExecutionStatus::Cancelled);
        assert!(ctx.completed_at.is_some());
    }

    #[test]
    fn test_unknown_ids_are_reported() {
        let mut engine = DagEngine::new();

        assert!(matches!(
            engine.start_execution("nope"),
            Err(WorkflowError::WorkflowNotFound(_))
        ));
        assert!(matches!(
            engine.get_progress("nope"),
            Err(WorkflowError::ExecutionNotFound(_))
        ));
        assert!(matches!(
            engine.cancel_execution("nope"),
            Err(WorkflowError::ExecutionNotFound(_))
        ));
    }

    #[test]
    fn test_progress_of_fresh_execution() {
        let mut engine = DagEngine::new();
        let (wf, _) = chain("progress", &["A", "B"]);
        let wf_id = wf.id.clone();
        engine.register_workflow(wf).unwrap();

        let exec_id = engine.start_execution(&wf_id).unwrap();
        let progress = engine.get_progress(&exec_id).unwrap();
        assert_eq!(progress.total_steps, 2);
        assert_eq!(progress.pending_steps, 2);
        assert_eq!(progress.completed_steps, 0);
        assert_eq!(progress.percent_complete, 0.0);
    }

    #[test]
    fn test_visualize_mermaid_lists_steps_and_edges() {
        let mut engine = DagEngine::new();
        let (wf, ids) = chain("viz", &["A", "B"]);
        let wf_id = wf.id.clone();
        engine.register_workflow(wf).unwrap();

        let diagram = engine.visualize_mermaid(&wf_id).unwrap();
        assert!(diagram.starts_with("graph TD"));
        assert!(diagram.contains(&format!("{} --> {}", ids[0], ids[1])));
    }
}