1use super::graph::WorkflowGraph;
6use super::node::{NodeType, WorkflowNode};
7use super::state::{
8 ExecutionRecord, NodeExecutionRecord, NodeResult, NodeStatus, WorkflowContext, WorkflowStatus,
9 WorkflowValue,
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::Instant;
14use tokio::sync::{RwLock, Semaphore, mpsc, oneshot};
15use tracing::{error, info, warn};
16
/// Tuning options for [`WorkflowExecutor`].
#[derive(Debug, Clone)]
pub struct ExecutorConfig {
    /// Number of permits the executor's semaphore is created with.
    /// NOTE(review): the semaphore is built from this in `new` but no permit
    /// is ever acquired in this file — confirm the cap is enforced elsewhere.
    pub max_parallelism: usize,
    /// When true, the first node failure aborts the run (unless the failing
    /// node has a registered error handler).
    pub stop_on_failure: bool,
    /// Enables automatic checkpoint creation during sequential execution.
    pub enable_checkpoints: bool,
    /// A checkpoint is created whenever the number of recorded node
    /// executions is a multiple of this interval.
    pub checkpoint_interval: usize,
    /// Overall execution timeout in milliseconds.
    /// NOTE(review): never read in this file — waits use the per-node
    /// `node.config.timeout` instead; confirm intended use.
    pub execution_timeout_ms: Option<u64>,
}
31
32impl Default for ExecutorConfig {
33 fn default() -> Self {
34 Self {
35 max_parallelism: 10,
36 stop_on_failure: true,
37 enable_checkpoints: true,
38 checkpoint_interval: 5,
39 execution_timeout_ms: None,
40 }
41 }
42}
43
/// Events emitted on the executor's optional event channel during a run.
#[derive(Debug, Clone)]
pub enum ExecutionEvent {
    /// Emitted once when `execute` begins a run.
    WorkflowStarted {
        workflow_id: String,
        execution_id: String,
    },
    /// Emitted once when a run ends, carrying its terminal status.
    WorkflowCompleted {
        workflow_id: String,
        execution_id: String,
        status: WorkflowStatus,
    },
    /// A node has been marked as running.
    NodeStarted { node_id: String },
    /// A node finished executing (successfully or not) with its full result.
    NodeCompleted { node_id: String, result: NodeResult },
    /// NOTE(review): declared but never emitted in this file — failures are
    /// reported via `NodeCompleted`; confirm whether another module uses this.
    NodeFailed { node_id: String, error: String },
    /// An automatic checkpoint was written during sequential execution.
    CheckpointCreated { label: String },
    /// NOTE(review): declared but never emitted in this file — external
    /// events are delivered to waiters directly; confirm intended use.
    ExternalEvent {
        event_type: String,
        data: WorkflowValue,
    },
}
72
/// Drives execution of workflow graphs, either sequentially from the start
/// node (`execute`) or layer-by-layer (`execute_parallel_workflow`).
pub struct WorkflowExecutor {
    // Tuning knobs: failure policy, checkpointing, parallelism cap.
    config: ExecutorConfig,
    // Optional channel on which ExecutionEvents are emitted (best-effort).
    event_tx: Option<mpsc::Sender<ExecutionEvent>>,
    // Sub-workflow graphs registered by id; looked up by SubWorkflow nodes.
    sub_workflows: Arc<RwLock<HashMap<String, Arc<WorkflowGraph>>>>,
    // Pending Wait-node receivers keyed by event type; resolved (and removed)
    // by send_external_event.
    event_waiters: Arc<RwLock<HashMap<String, Vec<oneshot::Sender<WorkflowValue>>>>>,
    // Sized from max_parallelism. NOTE(review): no permit is acquired
    // anywhere in this file — confirm it is used elsewhere or remove.
    semaphore: Arc<Semaphore>,
}
86
87impl WorkflowExecutor {
88 pub fn new(config: ExecutorConfig) -> Self {
89 let semaphore = Arc::new(Semaphore::new(config.max_parallelism));
90 Self {
91 config,
92 event_tx: None,
93 sub_workflows: Arc::new(RwLock::new(HashMap::new())),
94 event_waiters: Arc::new(RwLock::new(HashMap::new())),
95 semaphore,
96 }
97 }
98
99 pub fn with_event_sender(mut self, tx: mpsc::Sender<ExecutionEvent>) -> Self {
101 self.event_tx = Some(tx);
102 self
103 }
104
105 pub async fn register_sub_workflow(&self, id: &str, graph: WorkflowGraph) {
107 let mut workflows = self.sub_workflows.write().await;
108 workflows.insert(id.to_string(), Arc::new(graph));
109 }
110
111 async fn emit_event(&self, event: ExecutionEvent) {
113 if let Some(ref tx) = self.event_tx {
114 let _ = tx.send(event).await;
115 }
116 }
117
118 pub async fn send_external_event(&self, event_type: &str, data: WorkflowValue) {
120 let mut waiters = self.event_waiters.write().await;
121 if let Some(senders) = waiters.remove(event_type) {
122 for sender in senders {
123 let _ = sender.send(data.clone());
124 }
125 }
126 }
127
128 pub async fn execute(
130 &self,
131 graph: &WorkflowGraph,
132 input: WorkflowValue,
133 ) -> Result<ExecutionRecord, String> {
134 let start_time = Instant::now();
135 let ctx = WorkflowContext::new(&graph.id);
136 ctx.set_input(input.clone()).await;
137
138 self.emit_event(ExecutionEvent::WorkflowStarted {
140 workflow_id: graph.id.clone(),
141 execution_id: ctx.execution_id.clone(),
142 })
143 .await;
144
145 info!(
146 "Starting workflow execution: {} ({})",
147 graph.name, ctx.execution_id
148 );
149
150 if let Err(errors) = graph.validate() {
152 let error_msg = errors.join("; ");
153 error!("Workflow validation failed: {}", error_msg);
154 return Err(error_msg);
155 }
156
157 let start_node_id = graph
159 .start_node()
160 .ok_or_else(|| "No start node".to_string())?;
161
162 let mut execution_record = ExecutionRecord {
164 execution_id: ctx.execution_id.clone(),
165 workflow_id: graph.id.clone(),
166 started_at: std::time::SystemTime::now()
167 .duration_since(std::time::UNIX_EPOCH)
168 .unwrap_or_default()
169 .as_millis() as u64,
170 ended_at: None,
171 status: WorkflowStatus::Running,
172 node_records: Vec::new(),
173 };
174
175 let result = self
177 .execute_from_node(graph, &ctx, start_node_id, input, &mut execution_record)
178 .await;
179
180 let duration = start_time.elapsed();
181 execution_record.ended_at = Some(
182 std::time::SystemTime::now()
183 .duration_since(std::time::UNIX_EPOCH)
184 .unwrap_or_default()
185 .as_millis() as u64,
186 );
187
188 match result {
189 Ok(_) => {
190 execution_record.status = WorkflowStatus::Completed;
191 info!("Workflow {} completed in {:?}", graph.name, duration);
192 }
193 Err(ref e) => {
194 execution_record.status = WorkflowStatus::Failed(e.clone());
195 error!("Workflow {} failed: {}", graph.name, e);
196 }
197 }
198
199 self.emit_event(ExecutionEvent::WorkflowCompleted {
201 workflow_id: graph.id.clone(),
202 execution_id: ctx.execution_id.clone(),
203 status: execution_record.status.clone(),
204 })
205 .await;
206
207 Ok(execution_record)
208 }
209
210 async fn execute_from_node(
212 &self,
213 graph: &WorkflowGraph,
214 ctx: &WorkflowContext,
215 start_node_id: &str,
216 initial_input: WorkflowValue,
217 record: &mut ExecutionRecord,
218 ) -> Result<WorkflowValue, String> {
219 let mut current_node_id = start_node_id.to_string();
220 let mut current_input = initial_input;
221
222 loop {
223 let node = graph
224 .get_node(¤t_node_id)
225 .ok_or_else(|| format!("Node {} not found", current_node_id))?;
226
227 ctx.set_node_status(¤t_node_id, NodeStatus::Running)
229 .await;
230 self.emit_event(ExecutionEvent::NodeStarted {
231 node_id: current_node_id.clone(),
232 })
233 .await;
234
235 let start_time = std::time::SystemTime::now()
236 .duration_since(std::time::UNIX_EPOCH)
237 .unwrap_or_default()
238 .as_millis() as u64;
239
240 let result = match node.node_type() {
242 NodeType::Parallel => {
243 self.execute_parallel(graph, ctx, node, current_input.clone(), record)
244 .await
245 }
246 NodeType::Join => self.execute_join(graph, ctx, node, record).await,
247 NodeType::SubWorkflow => {
248 self.execute_sub_workflow(graph, ctx, node, current_input.clone(), record)
249 .await
250 }
251 NodeType::Wait => self.execute_wait(ctx, node, current_input.clone()).await,
252 _ => {
253 let result = node.execute(ctx, current_input.clone()).await;
255 ctx.set_node_output(¤t_node_id, result.output.clone())
256 .await;
257 ctx.set_node_status(¤t_node_id, result.status.clone())
258 .await;
259
260 self.emit_event(ExecutionEvent::NodeCompleted {
262 node_id: current_node_id.clone(),
263 result: result.clone(),
264 })
265 .await;
266
267 if result.status.is_success() {
268 Ok(result.output)
269 } else {
270 Err(result.error.unwrap_or_else(|| "Unknown error".to_string()))
271 }
272 }
273 };
274
275 let end_time = std::time::SystemTime::now()
276 .duration_since(std::time::UNIX_EPOCH)
277 .unwrap_or_default()
278 .as_millis() as u64;
279
280 record.node_records.push(NodeExecutionRecord {
282 node_id: current_node_id.clone(),
283 started_at: start_time,
284 ended_at: end_time,
285 status: ctx
286 .get_node_status(¤t_node_id)
287 .await
288 .unwrap_or(NodeStatus::Pending),
289 retry_count: 0,
290 });
291
292 if self.config.enable_checkpoints
294 && record
295 .node_records
296 .len()
297 .is_multiple_of(self.config.checkpoint_interval)
298 {
299 let label = format!("auto_checkpoint_{}", record.node_records.len());
300 ctx.create_checkpoint(&label).await;
301 self.emit_event(ExecutionEvent::CheckpointCreated { label })
302 .await;
303 }
304
305 match result {
307 Ok(output) => {
308 let next = self.determine_next_node(graph, node, &output).await;
310
311 match next {
312 Some(next_node_id) => {
313 current_node_id = next_node_id;
315 current_input = output;
316 }
318 None => {
319 return Ok(output);
321 }
322 }
323 }
324 Err(e) => {
325 if let Some(error_handler) = graph.get_error_handler(¤t_node_id) {
327 warn!(
328 "Node {} failed, executing error handler: {}",
329 current_node_id, error_handler
330 );
331 let error_input = WorkflowValue::Map({
332 let mut m = HashMap::new();
333 m.insert("error".to_string(), WorkflowValue::String(e.clone()));
334 m.insert(
335 "node_id".to_string(),
336 WorkflowValue::String(current_node_id.clone()),
337 );
338 m
339 });
340 current_node_id = error_handler.to_string();
341 current_input = error_input;
342 } else if self.config.stop_on_failure {
344 return Err(e);
345 } else {
346 warn!("Node {} failed but continuing: {}", current_node_id, e);
347 if let Some(next_node_id) = graph.get_next_node(¤t_node_id, None) {
349 current_node_id = next_node_id.to_string();
350 current_input = WorkflowValue::Null;
351 } else {
353 return Err(e);
354 }
355 }
356 }
357 }
358 }
359 }
360
361 async fn determine_next_node(
363 &self,
364 graph: &WorkflowGraph,
365 node: &WorkflowNode,
366 output: &WorkflowValue,
367 ) -> Option<String> {
368 let node_id = node.id();
369
370 match node.node_type() {
371 NodeType::Condition => {
372 let condition = output.as_str().unwrap_or("false");
374 graph
375 .get_next_node(node_id, Some(condition))
376 .map(|s| s.to_string())
377 }
378 NodeType::End => {
379 None
381 }
382 _ => {
383 graph.get_next_node(node_id, None).map(|s| s.to_string())
385 }
386 }
387 }
388
389 async fn execute_parallel(
391 &self,
392 graph: &WorkflowGraph,
393 ctx: &WorkflowContext,
394 node: &WorkflowNode,
395 input: WorkflowValue,
396 record: &mut ExecutionRecord,
397 ) -> Result<WorkflowValue, String> {
398 let branches = node.parallel_branches();
399
400 if branches.is_empty() {
401 let edges = graph.get_outgoing_edges(node.id());
403 let branch_ids: Vec<String> = edges.iter().map(|e| e.to.clone()).collect();
404
405 if branch_ids.is_empty() {
406 return Ok(input);
407 }
408
409 return self
410 .execute_branches_parallel(graph, ctx, &branch_ids, input, record)
411 .await;
412 }
413
414 self.execute_branches_parallel(graph, ctx, branches, input, record)
415 .await
416 }
417
    /// Runs each branch node with the same `input` and collects their outputs
    /// into a map keyed by branch id.
    ///
    /// NOTE(review): despite the name, branches execute SEQUENTIALLY (one
    /// `.await` at a time); neither `semaphore` nor `max_parallelism` is
    /// consulted here. True concurrency would require joinable futures or
    /// spawned tasks — confirm whether sequential execution is intentional.
    async fn execute_branches_parallel(
        &self,
        graph: &WorkflowGraph,
        ctx: &WorkflowContext,
        branches: &[String],
        input: WorkflowValue,
        _record: &mut ExecutionRecord,
    ) -> Result<WorkflowValue, String> {
        let mut results = HashMap::new();
        let mut errors = Vec::new();

        for branch_id in branches {
            if let Some(node) = graph.get_node(branch_id) {
                let result = node.execute(ctx, input.clone()).await;
                ctx.set_node_output(branch_id, result.output.clone()).await;
                ctx.set_node_status(branch_id, result.status.clone()).await;

                if result.status.is_success() {
                    results.insert(branch_id.clone(), result.output);
                } else {
                    // Collect failures; whether they abort depends on config.
                    errors.push(format!(
                        "{}: {}",
                        branch_id,
                        result.error.unwrap_or_else(|| "Unknown error".to_string())
                    ));
                }
            } else {
                errors.push(format!("Node {} not found", branch_id));
            }
        }

        // In fail-fast mode any branch error fails the whole fan-out;
        // otherwise failed branches are simply absent from the result map.
        if !errors.is_empty() && self.config.stop_on_failure {
            return Err(errors.join("; "));
        }

        Ok(WorkflowValue::Map(results))
    }
458
    /// Waits until every node listed by the join node's `join_nodes()` has
    /// reached a terminal status, then runs the join node itself with a map
    /// of the joined nodes' outputs as its input.
    ///
    /// NOTE(review): this busy-polls statuses every 10 ms with a hard cap of
    /// 1000 attempts (~10 s total); a notification-based wait would avoid
    /// both the polling and the arbitrary timeout.
    async fn execute_join(
        &self,
        _graph: &WorkflowGraph,
        ctx: &WorkflowContext,
        node: &WorkflowNode,
        _record: &mut ExecutionRecord,
    ) -> Result<WorkflowValue, String> {
        let wait_for = node.join_nodes();

        let mut all_completed = false;
        let mut attempts = 0;
        const MAX_ATTEMPTS: u32 = 1000;

        while !all_completed && attempts < MAX_ATTEMPTS {
            all_completed = true;
            for node_id in wait_for {
                match ctx.get_node_status(node_id).await {
                    Some(status) if status.is_terminal() => {}
                    _ => {
                        // At least one dependency is still pending/running.
                        all_completed = false;
                        break;
                    }
                }
            }
            if !all_completed {
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                attempts += 1;
            }
        }

        if !all_completed {
            return Err("Join timeout waiting for nodes".to_string());
        }

        // Feed the join node a map of all dependency outputs.
        let outputs = ctx
            .get_node_outputs(&wait_for.iter().map(|s| s.as_str()).collect::<Vec<_>>())
            .await;

        let result = node.execute(ctx, WorkflowValue::Map(outputs)).await;

        ctx.set_node_output(node.id(), result.output.clone()).await;
        ctx.set_node_status(node.id(), result.status.clone()).await;

        if result.status.is_success() {
            Ok(result.output)
        } else {
            Err(result.error.unwrap_or_else(|| "Join failed".to_string()))
        }
    }
512
    /// Runs the registered sub-workflow referenced by `node`, then records
    /// the sub-workflow's end-node output as this node's output.
    async fn execute_sub_workflow(
        &self,
        _graph: &WorkflowGraph,
        ctx: &WorkflowContext,
        node: &WorkflowNode,
        input: WorkflowValue,
        _record: &mut ExecutionRecord,
    ) -> Result<WorkflowValue, String> {
        let sub_workflow_id = node
            .sub_workflow_id()
            .ok_or_else(|| "No sub-workflow specified".to_string())?;

        // Clone the Arc and release the read lock before the long-running
        // execution below.
        let workflows = self.sub_workflows.read().await;
        let sub_graph = workflows
            .get(sub_workflow_id)
            .ok_or_else(|| format!("Sub-workflow {} not found", sub_workflow_id))?
            .clone();
        drop(workflows);

        info!("Executing sub-workflow: {}", sub_workflow_id);

        // The sub-run's per-node records are discarded here.
        let _sub_record = self.execute_parallel_workflow(&sub_graph, input).await?;

        // NOTE(review): `execute_parallel_workflow` runs the sub-workflow in
        // its OWN WorkflowContext, but the end-node output is read from the
        // PARENT `ctx` — so this is likely always Null (or stale parent data
        // if node ids collide between the graphs). Fixing it needs the sub
        // context or its outputs exposed by the execute API — confirm.
        let output = if let Some(end_node) = sub_graph.end_nodes().first() {
            ctx.get_node_output(end_node)
                .await
                .unwrap_or(WorkflowValue::Null)
        } else {
            WorkflowValue::Null
        };

        ctx.set_node_output(node.id(), output.clone()).await;
        ctx.set_node_status(node.id(), NodeStatus::Completed).await;

        Ok(output)
    }
554
555 async fn execute_wait(
557 &self,
558 ctx: &WorkflowContext,
559 node: &WorkflowNode,
560 _input: WorkflowValue,
561 ) -> Result<WorkflowValue, String> {
562 let event_type = node
563 .wait_event_type()
564 .ok_or_else(|| "No event type specified".to_string())?;
565
566 info!("Waiting for event: {}", event_type);
567
568 let (tx, rx) = oneshot::channel();
570
571 {
572 let mut waiters = self.event_waiters.write().await;
573 waiters.entry(event_type.to_string()).or_default().push(tx);
574 }
575
576 let timeout = node.config.timeout.execution_timeout_ms;
578 let result = if timeout > 0 {
579 tokio::time::timeout(std::time::Duration::from_millis(timeout), rx)
580 .await
581 .map_err(|_| "Wait timeout".to_string())?
582 .map_err(|_| "Wait cancelled".to_string())?
583 } else {
584 rx.await.map_err(|_| "Wait cancelled".to_string())?
585 };
586
587 ctx.set_node_output(node.id(), result.clone()).await;
588 ctx.set_node_status(node.id(), NodeStatus::Completed).await;
589
590 Ok(result)
591 }
592
593 pub async fn execute_parallel_workflow(
597 &self,
598 graph: &WorkflowGraph,
599 input: WorkflowValue,
600 ) -> Result<ExecutionRecord, String> {
601 let ctx = WorkflowContext::new(&graph.id);
602 ctx.set_input(input.clone()).await;
603
604 let start_time = Instant::now();
605
606 info!(
607 "Starting layered workflow execution: {} ({})",
608 graph.name, ctx.execution_id
609 );
610
611 let groups = graph.get_parallel_groups();
613
614 let mut execution_record = ExecutionRecord {
615 execution_id: ctx.execution_id.clone(),
616 workflow_id: graph.id.clone(),
617 started_at: std::time::SystemTime::now()
618 .duration_since(std::time::UNIX_EPOCH)
619 .unwrap_or_default()
620 .as_millis() as u64,
621 ended_at: None,
622 status: WorkflowStatus::Running,
623 node_records: Vec::new(),
624 };
625
626 for group in groups {
628 for node_id in group {
629 let node_start_time = std::time::SystemTime::now()
630 .duration_since(std::time::UNIX_EPOCH)
631 .unwrap_or_default()
632 .as_millis() as u64;
633
634 let result = if let Some(node) = graph.get_node(&node_id) {
635 let predecessors = graph.get_predecessors(&node_id);
637 let node_input = if predecessors.is_empty() {
638 ctx.get_input().await
639 } else if predecessors.len() == 1 {
640 ctx.get_node_output(predecessors[0])
641 .await
642 .unwrap_or(WorkflowValue::Null)
643 } else {
644 let outputs = ctx.get_node_outputs(&predecessors).await;
645 WorkflowValue::Map(outputs)
646 };
647
648 let result = node.execute(&ctx, node_input).await;
649 ctx.set_node_output(&node_id, result.output.clone()).await;
650 ctx.set_node_status(&node_id, result.status.clone()).await;
651 result
652 } else {
653 NodeResult::failed(&node_id, "Node not found", 0)
654 };
655
656 let node_end_time = std::time::SystemTime::now()
657 .duration_since(std::time::UNIX_EPOCH)
658 .unwrap_or_default()
659 .as_millis() as u64;
660
661 let record_entry = NodeExecutionRecord {
662 node_id: node_id.clone(),
663 started_at: node_start_time,
664 ended_at: node_end_time,
665 status: result.status.clone(),
666 retry_count: result.retry_count,
667 };
668 execution_record.node_records.push(record_entry);
669
670 if !result.status.is_success() && self.config.stop_on_failure {
671 execution_record.status = WorkflowStatus::Failed(
672 result.error.unwrap_or_else(|| "Unknown error".to_string()),
673 );
674 execution_record.ended_at = Some(
675 std::time::SystemTime::now()
676 .duration_since(std::time::UNIX_EPOCH)
677 .unwrap_or_default()
678 .as_millis() as u64,
679 );
680 return Ok(execution_record);
681 }
682 }
683 }
684
685 let duration = start_time.elapsed();
686 execution_record.status = WorkflowStatus::Completed;
687 execution_record.ended_at = Some(
688 std::time::SystemTime::now()
689 .duration_since(std::time::UNIX_EPOCH)
690 .unwrap_or_default()
691 .as_millis() as u64,
692 );
693
694 info!(
695 "Layered workflow {} completed in {:?}",
696 graph.name, duration
697 );
698
699 Ok(execution_record)
700 }
701}
702
#[cfg(test)]
mod tests {
    use super::*;

    // Linear start -> double -> add_ten -> end chain: verifies the
    // sequential executor walks a simple pipeline to Completed.
    #[tokio::test]
    async fn test_simple_workflow_execution() {
        let mut graph = WorkflowGraph::new("test", "Simple Workflow");

        graph.add_node(WorkflowNode::start("start"));
        graph.add_node(WorkflowNode::task(
            "double",
            "Double",
            |_ctx, input| async move {
                let value = input.as_i64().unwrap_or(0);
                Ok(WorkflowValue::Int(value * 2))
            },
        ));
        graph.add_node(WorkflowNode::task(
            "add_ten",
            "Add Ten",
            |_ctx, input| async move {
                let value = input.as_i64().unwrap_or(0);
                Ok(WorkflowValue::Int(value + 10))
            },
        ));
        graph.add_node(WorkflowNode::end("end"));

        graph.connect("start", "double");
        graph.connect("double", "add_ten");
        graph.connect("add_ten", "end");

        let executor = WorkflowExecutor::new(ExecutorConfig::default());
        let result = executor
            .execute(&graph, WorkflowValue::Int(5))
            .await
            .unwrap();

        // Only the terminal status is asserted; node outputs live in the
        // run's internal context.
        assert!(matches!(result.status, WorkflowStatus::Completed));
    }

    // Branching via a Condition node: input 20 > 10 should route down the
    // "true" edge to the "high" task and still complete.
    #[tokio::test]
    async fn test_conditional_workflow() {
        let mut graph = WorkflowGraph::new("test", "Conditional Workflow");

        graph.add_node(WorkflowNode::start("start"));
        graph.add_node(WorkflowNode::condition(
            "check",
            "Check Value",
            |_ctx, input| async move { input.as_i64().unwrap_or(0) > 10 },
        ));
        graph.add_node(WorkflowNode::task(
            "high",
            "High Path",
            |_ctx, _input| async move { Ok(WorkflowValue::String("high".to_string())) },
        ));
        graph.add_node(WorkflowNode::task(
            "low",
            "Low Path",
            |_ctx, _input| async move { Ok(WorkflowValue::String("low".to_string())) },
        ));
        graph.add_node(WorkflowNode::end("end"));

        graph.connect("start", "check");
        // Condition edges are labeled with the stringified branch value.
        graph.connect_conditional("check", "high", "true");
        graph.connect_conditional("check", "low", "false");
        graph.connect("high", "end");
        graph.connect("low", "end");

        let executor = WorkflowExecutor::new(ExecutorConfig::default());

        let result = executor
            .execute(&graph, WorkflowValue::Int(20))
            .await
            .unwrap();
        assert!(matches!(result.status, WorkflowStatus::Completed));
    }
}