1use super::context::WorkflowContext;
6use super::def::{ExecutionMode, FailureStrategy, NodeDef, NodeType, WorkflowDef};
7use super::executors::{ExecutorFactory, NodeExecutor};
8use super::rule_engine::evaluate_expression;
9use super::template::TemplateRenderer;
10use crate::tools::toolproxy::{ProxyToolDef, ProxyToolExecutor};
11use anyhow::{Context, Result};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::time::timeout;
16
17#[async_trait::async_trait]
19pub trait TaskExecutor: Send + Sync {
20 async fn execute(
22 &self,
23 task_name: &str,
24 params: &HashMap<String, serde_json::Value>,
25 context: &WorkflowContext,
26 ) -> Result<serde_json::Value>;
27}
28
29#[derive(Debug, Clone)]
31pub enum WorkflowEvent {
32 Started,
34 NodeStarted { node_id: String },
36 NodeCompleted {
38 node_id: String,
39 output: Option<serde_json::Value>,
40 },
41 NodeFailed { node_id: String, error: String },
43 NodeSkipped { node_id: String, reason: String },
45 Completed,
47 Failed { error: String },
49 Paused,
51 Resumed,
53}
54
55pub trait EventListener: Send + Sync {
57 fn on_event(&self, event: WorkflowEvent);
58}
59
60pub struct WorkflowEngine {
62 definition: WorkflowDef,
64 executor: Option<Arc<dyn TaskExecutor>>,
66 node_executors: HashMap<String, Arc<dyn NodeExecutor>>,
68 executor_factory: Option<ExecutorFactory>,
70 proxy_executor: Option<Arc<dyn ProxyToolExecutor>>,
72 proxy_tool_defs: Vec<ProxyToolDef>,
74 listeners: Vec<Box<dyn EventListener>>,
76 template_renderer: TemplateRenderer,
78}
79
80impl WorkflowEngine {
81 pub fn new(definition: WorkflowDef) -> Result<Self> {
83 definition
84 .validate()
85 .with_context(|| "Invalid workflow definition")?;
86
87 Ok(Self {
88 definition,
89 executor: None,
90 node_executors: HashMap::new(),
91 executor_factory: None,
92 proxy_executor: None,
93 proxy_tool_defs: Vec::new(),
94 listeners: Vec::new(),
95 template_renderer: TemplateRenderer::new(),
96 })
97 }
98
99 pub fn with_executor(mut self, executor: Arc<dyn TaskExecutor>) -> Self {
101 self.executor = Some(executor);
102 self
103 }
104
105 pub fn with_executor_factory(mut self, factory: ExecutorFactory) -> Self {
107 self.executor_factory = Some(factory);
108 self
109 }
110
111 pub fn with_proxy_executor(
113 mut self,
114 executor: Arc<dyn ProxyToolExecutor>,
115 tool_defs: Vec<ProxyToolDef>,
116 ) -> Self {
117 self.proxy_executor = Some(executor);
118 self.proxy_tool_defs = tool_defs;
119 self
120 }
121
122 pub fn register_node_executor(
124 mut self,
125 task_type: &str,
126 executor: Arc<dyn NodeExecutor>,
127 ) -> Self {
128 self.node_executors.insert(task_type.to_string(), executor);
129 self
130 }
131
132 pub fn add_listener(&mut self, listener: Box<dyn EventListener>) {
134 self.listeners.push(listener);
135 }
136
137 fn emit_event(&self, event: WorkflowEvent) {
139 for listener in &self.listeners {
140 listener.on_event(event.clone());
141 }
142 }
143
144 fn get_node_executor(&self, node: &NodeDef) -> Option<Arc<dyn NodeExecutor>> {
146 if let Some(task) = &node.task
148 && let Some(executor) = self.node_executors.get(task)
149 {
150 return Some(executor.clone());
151 }
152
153 if let Some(task) = &node.task
155 && self
156 .proxy_tool_defs
157 .iter()
158 .any(|t| t.definition.name == *task)
159 && let Some(executor) = &self.proxy_executor
160 {
161 return Some(Arc::new(super::executors::ProxyExecutor::new(
162 executor.clone(),
163 self.proxy_tool_defs.clone(),
164 )));
165 }
166
167 match node.node_type {
169 NodeType::Task => {
170 if let Some(factory) = &self.executor_factory
172 && let Some(task) = &node.task
173 {
174 let task_lower = task.to_lowercase();
177 if task_lower == "ai"
178 || task_lower.starts_with("ai_")
179 || task_lower.starts_with("claude")
180 || task_lower.starts_with("gpt")
181 {
182 return factory.create_ai_executor().ok();
183 }
184 return Some(factory.create_tool_executor());
186 }
187 }
188 NodeType::Condition => {
189 if let Some(factory) = &self.executor_factory {
190 return Some(factory.create_condition_executor());
191 }
192 }
193 NodeType::Approval => {
194 if let Some(factory) = &self.executor_factory {
196 return Some(factory.create_validate_executor());
197 }
198 }
199 _ => {}
200 }
201
202 None
203 }
204
205 pub async fn run(&self, inputs: HashMap<String, serde_json::Value>) -> Result<WorkflowContext> {
207 let mut context = WorkflowContext::new(self.definition.id.clone(), inputs.clone());
209
210 self.validate_inputs(&context)?;
212
213 for (key, value) in inputs {
215 context.set_variable(key.clone(), value.clone());
216 }
217
218 let renderer = crate::workflow::template::TemplateRenderer::new();
220 for (key, value) in &self.definition.variables {
221 let rendered_value = if let serde_json::Value::String(s) = value {
223 match renderer.render(s, &context.variables) {
224 Ok(rendered) => serde_json::Value::String(rendered),
225 Err(_) => value.clone(), }
227 } else {
228 value.clone()
229 };
230 context.set_variable(key.clone(), rendered_value);
231 }
232
233 context.start();
235 self.emit_event(WorkflowEvent::Started);
236
237 let start_node = self
239 .definition
240 .get_start_node()
241 .ok_or_else(|| anyhow::anyhow!("No start node found"))?;
242
243 match self.execute_from_node(start_node, &mut context).await {
245 Ok(()) => {
246 context.complete();
247 self.emit_event(WorkflowEvent::Completed);
248 }
249 Err(e) => {
250 context.fail(e.to_string());
251 self.emit_event(WorkflowEvent::Failed {
252 error: e.to_string(),
253 });
254 }
255 }
256
257 Ok(context)
258 }
259
260 async fn execute_from_node(&self, node: &NodeDef, context: &mut WorkflowContext) -> Result<()> {
262 let mut current_node = Some(node);
263
264 while let Some(node) = current_node {
265 if !context.can_continue() {
267 break;
268 }
269
270 match self.execute_node(node, context).await {
272 Ok(next_node_id) => {
273 current_node = next_node_id
274 .as_ref()
275 .and_then(|id| self.definition.get_node(id));
276 }
277 Err(e) => {
278 match &node.on_failure {
280 FailureStrategy::Retry {
281 max_attempts,
282 interval_ms,
283 } => {
284 let exec = context.get_or_create_node_execution(&node.id);
285 if exec.retry_count < *max_attempts {
286 exec.increment_retry();
287 if let Some(interval) = interval_ms {
288 tokio::time::sleep(Duration::from_millis(*interval)).await;
289 }
290 continue; } else {
292 return Err(e);
293 }
294 }
295 FailureStrategy::Ignore => {
296 let exec = context.get_or_create_node_execution(&node.id);
298 exec.skip();
299 self.emit_event(WorkflowEvent::NodeSkipped {
300 node_id: node.id.clone(),
301 reason: e.to_string(),
302 });
303 let next = self.get_next_node(node, context)?;
304 current_node =
305 next.as_ref().and_then(|id| self.definition.get_node(id));
306 }
307 FailureStrategy::Abort => {
308 return Err(e);
309 }
310 FailureStrategy::Goto { target } => {
311 current_node = self.definition.get_node(target);
312 }
313 }
314 }
315 }
316 }
317
318 Ok(())
319 }
320
321 async fn execute_node(
323 &self,
324 node: &NodeDef,
325 context: &mut WorkflowContext,
326 ) -> Result<Option<String>> {
327 let execution = context.get_or_create_node_execution(&node.id);
329 execution.start();
330 self.emit_event(WorkflowEvent::NodeStarted {
331 node_id: node.id.clone(),
332 });
333
334 context.set_current_node(node.id.clone());
336
337 let result = if let Some(timeout_ms) = node.timeout_ms {
339 timeout(
340 Duration::from_millis(timeout_ms),
341 self.execute_node_inner(node, context),
342 )
343 .await
344 .with_context(|| format!("Node '{}' timed out after {}ms", node.id, timeout_ms))?
345 } else {
346 self.execute_node_inner(node, context).await
347 };
348
349 match result {
350 Ok(output) => {
351 let exec = context.get_or_create_node_execution(&node.id);
352 exec.complete(output.clone());
353 self.emit_event(WorkflowEvent::NodeCompleted {
354 node_id: node.id.clone(),
355 output,
356 });
357
358 self.get_next_node(node, context)
360 }
361 Err(e) => {
362 let exec = context.get_or_create_node_execution(&node.id);
363 exec.fail(e.to_string());
364 self.emit_event(WorkflowEvent::NodeFailed {
365 node_id: node.id.clone(),
366 error: e.to_string(),
367 });
368 Err(e)
369 }
370 }
371 }
372
373 async fn execute_node_inner(
375 &self,
376 node: &NodeDef,
377 context: &mut WorkflowContext,
378 ) -> Result<Option<serde_json::Value>> {
379 match &node.node_type {
380 NodeType::Start => Ok(None),
381 NodeType::End => Ok(None),
382 NodeType::Task => self.execute_task(node, context).await,
383 NodeType::Condition => self.execute_condition(node, context).await,
384 NodeType::Parallel => self.execute_parallel(node, context).await,
385 NodeType::Pipeline => self.execute_pipeline(node, context).await,
386 NodeType::SubWorkflow => self.execute_subworkflow(node, context).await,
387 NodeType::Wait => self.execute_wait(node, context).await,
388 NodeType::Approval => self.execute_approval(node, context).await,
389 }
390 }
391
392 async fn execute_task(
394 &self,
395 node: &NodeDef,
396 context: &mut WorkflowContext,
397 ) -> Result<Option<serde_json::Value>> {
398 let task_name = node
399 .task
400 .as_ref()
401 .ok_or_else(|| anyhow::anyhow!("Task node '{}' has no task name", node.id))?;
402
403 let mut rendered_params = HashMap::new();
405 for (key, value) in &node.params {
406 if let serde_json::Value::String(s) = value {
407 let rendered = self.template_renderer.render(s, &context.variables)?;
408 rendered_params.insert(key.clone(), serde_json::Value::String(rendered));
409 } else {
410 rendered_params.insert(key.clone(), value.clone());
411 }
412 }
413
414 if let Some(node_executor) = self.get_node_executor(node) {
416 let output = node_executor.execute(node, context).await?;
417 return Ok(Some(output));
418 }
419
420 if let Some(executor) = &self.executor {
422 let output = executor
423 .execute(task_name, &rendered_params, context)
424 .await?;
425 Ok(Some(output))
426 } else {
427 Ok(Some(
429 serde_json::json!({ "task": task_name, "status": "completed" }),
430 ))
431 }
432 }
433
434 async fn execute_condition(
436 &self,
437 node: &NodeDef,
438 context: &mut WorkflowContext,
439 ) -> Result<Option<serde_json::Value>> {
440 let branches = node
441 .branches
442 .as_ref()
443 .ok_or_else(|| anyhow::anyhow!("Condition node '{}' has no branches", node.id))?;
444
445 for branch in branches {
446 if evaluate_expression(&branch.condition, &context.variables)? {
447 return Ok(Some(serde_json::Value::String(branch.target.clone())));
449 }
450 }
451
452 Ok(None)
454 }
455
456 async fn execute_parallel(
458 &self,
459 node: &NodeDef,
460 _context: &mut WorkflowContext,
461 ) -> Result<Option<serde_json::Value>> {
462 let branches = node
463 .parallel_branches
464 .as_ref()
465 .ok_or_else(|| anyhow::anyhow!("Parallel node '{}' has no branches", node.id))?;
466
467 let mut outputs = Vec::new();
469 for branch in branches {
470 outputs.push(serde_json::json!({
472 "branch": branch.name,
473 "status": "completed"
474 }));
475 }
476
477 Ok(Some(serde_json::Value::Array(outputs)))
478 }
479
480 async fn execute_pipeline(
487 &self,
488 node: &NodeDef,
489 _context: &mut WorkflowContext,
490 ) -> Result<Option<serde_json::Value>> {
491 let branches = node
492 .parallel_branches
493 .as_ref()
494 .ok_or_else(|| anyhow::anyhow!("Pipeline node '{}' has no branches", node.id))?;
495
496 let mut outputs = Vec::new();
499 for branch in branches {
500 if branch.mode != ExecutionMode::Pipeline {
502 log::warn!(
503 "Pipeline node '{}' branch '{}' has mode '{}', expected Pipeline",
504 node.id, branch.name, branch.mode.display_name()
505 );
506 }
507
508 outputs.push(serde_json::json!({
510 "branch": branch.name,
511 "mode": "pipeline",
512 "status": "streaming",
513 "has_barrier": false
514 }));
515 }
516
517 log::info!(
518 "Pipeline node '{}' executing {} branches in streaming mode (no barrier)",
519 node.id,
520 branches.len()
521 );
522
523 Ok(Some(serde_json::Value::Array(outputs)))
524 }
525
526 async fn execute_subworkflow(
528 &self,
529 node: &NodeDef,
530 _context: &mut WorkflowContext,
531 ) -> Result<Option<serde_json::Value>> {
532 let workflow_name = node.workflow.as_ref().ok_or_else(|| {
533 anyhow::anyhow!("SubWorkflow node '{}' has no workflow name", node.id)
534 })?;
535
536 Ok(Some(serde_json::json!({
538 "workflow": workflow_name,
539 "status": "completed"
540 })))
541 }
542
543 async fn execute_wait(
545 &self,
546 node: &NodeDef,
547 _context: &mut WorkflowContext,
548 ) -> Result<Option<serde_json::Value>> {
549 let wait_ms = node.wait_ms.unwrap_or(0);
550 if wait_ms > 0 {
551 tokio::time::sleep(Duration::from_millis(wait_ms)).await;
552 }
553 Ok(None)
554 }
555
556 async fn execute_approval(
558 &self,
559 node: &NodeDef,
560 _context: &mut WorkflowContext,
561 ) -> Result<Option<serde_json::Value>> {
562 let approvers = node
563 .approvers
564 .as_ref()
565 .ok_or_else(|| anyhow::anyhow!("Approval node '{}' has no approvers", node.id))?;
566
567 Ok(Some(serde_json::json!({
569 "approvers": approvers,
570 "status": "pending_approval"
571 })))
572 }
573
574 fn get_next_node(&self, node: &NodeDef, context: &WorkflowContext) -> Result<Option<String>> {
576 if node.node_type == NodeType::End {
578 return Ok(None);
579 }
580
581 let edges = self.definition.get_outgoing_edges(&node.id);
583
584 if edges.is_empty() {
585 return Ok(None);
586 }
587
588 if node.node_type == NodeType::Condition {
590 let exec = context.get_node_execution(&node.id);
591 if let Some(exec) = exec
592 && let Some(serde_json::Value::String(target)) = &exec.output
593 {
594 return Ok(Some(target.clone()));
595 }
596 }
597
598 for edge in edges {
600 if let Some(condition) = &edge.condition {
601 if evaluate_expression(condition, &context.variables)? {
602 return Ok(Some(edge.to.clone()));
603 }
604 } else {
605 return Ok(Some(edge.to.clone()));
607 }
608 }
609
610 Ok(None)
612 }
613
614 fn validate_inputs(&self, context: &WorkflowContext) -> Result<()> {
616 for input_def in &self.definition.inputs {
617 if input_def.required
618 && context.get_input(&input_def.name).is_none()
619 && input_def.default.is_none()
620 {
621 anyhow::bail!("Required input '{}' is missing", input_def.name);
622 }
623 }
624 Ok(())
625 }
626
627 pub fn definition(&self) -> &WorkflowDef {
629 &self.definition
630 }
631}
632
633pub struct DefaultTaskExecutor;
635
636#[async_trait::async_trait]
637impl TaskExecutor for DefaultTaskExecutor {
638 async fn execute(
639 &self,
640 task_name: &str,
641 _params: &HashMap<String, serde_json::Value>,
642 _context: &WorkflowContext,
643 ) -> Result<serde_json::Value> {
644 Ok(serde_json::json!({
645 "task": task_name,
646 "status": "completed",
647 "output": null
648 }))
649 }
650}
651
#[cfg(test)]
mod tests {
    use super::super::context::WorkflowStatus;
    use super::super::def::EdgeDef;
    use super::*;

    /// Builds a node with an `Abort` failure strategy and every optional
    /// field unset.
    fn node(id: &str, node_type: NodeType, name: &str, task: Option<&str>) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: name.to_string(),
            description: None,
            task: task.map(str::to_string),
            params: HashMap::new(),
            on_failure: FailureStrategy::Abort,
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            execution_mode: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    /// Builds an unconditional, unlabeled edge `from -> to`.
    fn edge(id: &str, from: &str, to: &str) -> EdgeDef {
        EdgeDef {
            id: id.to_string(),
            from: from.to_string(),
            to: to.to_string(),
            condition: None,
            label: None,
        }
    }

    /// A linear workflow: `start -> task1 (do_something) -> end`.
    fn create_simple_workflow() -> WorkflowDef {
        let nodes = vec![
            node("start", NodeType::Start, "Start", None),
            node("task1", NodeType::Task, "Task 1", Some("do_something")),
            node("end", NodeType::End, "End", None),
        ];
        let edges = vec![edge("e1", "start", "task1"), edge("e2", "task1", "end")];

        WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: vec![],
            outputs: vec![],
            nodes,
            edges,
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::Abort,
            timeout_ms: None,
        }
    }

    #[tokio::test]
    async fn test_engine_run() {
        let engine = WorkflowEngine::new(create_simple_workflow()).unwrap();

        let context = engine.run(HashMap::new()).await.unwrap();

        assert_eq!(context.status, WorkflowStatus::Completed);
        assert_eq!(context.execution_path.len(), 3);
    }

    #[tokio::test]
    async fn test_engine_with_executor() {
        let engine = WorkflowEngine::new(create_simple_workflow())
            .unwrap()
            .with_executor(Arc::new(DefaultTaskExecutor));

        let context = engine.run(HashMap::new()).await.unwrap();

        assert_eq!(context.status, WorkflowStatus::Completed);
    }
}