1use super::context::WorkflowContext;
6use super::def::{FailureStrategy, NodeDef, NodeType, WorkflowDef};
7use super::executors::{ExecutorFactory, NodeExecutor};
8use super::rule_engine::evaluate_expression;
9use super::template::TemplateRenderer;
10use crate::tools::toolproxy::{ProxyToolDef, ProxyToolExecutor};
11use anyhow::{Context, Result};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::time::timeout;
16
17#[async_trait::async_trait]
19pub trait TaskExecutor: Send + Sync {
20 async fn execute(
22 &self,
23 task_name: &str,
24 params: &HashMap<String, serde_json::Value>,
25 context: &WorkflowContext,
26 ) -> Result<serde_json::Value>;
27}
28
29#[derive(Debug, Clone)]
31pub enum WorkflowEvent {
32 Started,
34 NodeStarted { node_id: String },
36 NodeCompleted {
38 node_id: String,
39 output: Option<serde_json::Value>,
40 },
41 NodeFailed { node_id: String, error: String },
43 NodeSkipped { node_id: String, reason: String },
45 Completed,
47 Failed { error: String },
49 Paused,
51 Resumed,
53}
54
55pub trait EventListener: Send + Sync {
57 fn on_event(&self, event: WorkflowEvent);
58}
59
60pub struct WorkflowEngine {
62 definition: WorkflowDef,
64 executor: Option<Arc<dyn TaskExecutor>>,
66 node_executors: HashMap<String, Arc<dyn NodeExecutor>>,
68 executor_factory: Option<ExecutorFactory>,
70 proxy_executor: Option<Arc<dyn ProxyToolExecutor>>,
72 proxy_tool_defs: Vec<ProxyToolDef>,
74 listeners: Vec<Box<dyn EventListener>>,
76 template_renderer: TemplateRenderer,
78}
79
80impl WorkflowEngine {
81 pub fn new(definition: WorkflowDef) -> Result<Self> {
83 definition
84 .validate()
85 .with_context(|| "Invalid workflow definition")?;
86
87 Ok(Self {
88 definition,
89 executor: None,
90 node_executors: HashMap::new(),
91 executor_factory: None,
92 proxy_executor: None,
93 proxy_tool_defs: Vec::new(),
94 listeners: Vec::new(),
95 template_renderer: TemplateRenderer::new(),
96 })
97 }
98
99 pub fn with_executor(mut self, executor: Arc<dyn TaskExecutor>) -> Self {
101 self.executor = Some(executor);
102 self
103 }
104
105 pub fn with_executor_factory(mut self, factory: ExecutorFactory) -> Self {
107 self.executor_factory = Some(factory);
108 self
109 }
110
111 pub fn with_proxy_executor(
113 mut self,
114 executor: Arc<dyn ProxyToolExecutor>,
115 tool_defs: Vec<ProxyToolDef>,
116 ) -> Self {
117 self.proxy_executor = Some(executor);
118 self.proxy_tool_defs = tool_defs;
119 self
120 }
121
122 pub fn register_node_executor(
124 mut self,
125 task_type: &str,
126 executor: Arc<dyn NodeExecutor>,
127 ) -> Self {
128 self.node_executors.insert(task_type.to_string(), executor);
129 self
130 }
131
132 pub fn add_listener(&mut self, listener: Box<dyn EventListener>) {
134 self.listeners.push(listener);
135 }
136
137 fn emit_event(&self, event: WorkflowEvent) {
139 for listener in &self.listeners {
140 listener.on_event(event.clone());
141 }
142 }
143
144 fn get_node_executor(&self, node: &NodeDef) -> Option<Arc<dyn NodeExecutor>> {
146 if let Some(task) = &node.task
148 && let Some(executor) = self.node_executors.get(task)
149 {
150 return Some(executor.clone());
151 }
152
153 if let Some(task) = &node.task
155 && self
156 .proxy_tool_defs
157 .iter()
158 .any(|t| t.definition.name == *task)
159 && let Some(executor) = &self.proxy_executor
160 {
161 return Some(Arc::new(super::executors::ProxyExecutor::new(
162 executor.clone(),
163 self.proxy_tool_defs.clone(),
164 )));
165 }
166
167 match node.node_type {
169 NodeType::Task => {
170 if let Some(factory) = &self.executor_factory
172 && let Some(task) = &node.task
173 {
174 let task_lower = task.to_lowercase();
177 if task_lower == "ai"
178 || task_lower.starts_with("ai_")
179 || task_lower.starts_with("claude")
180 || task_lower.starts_with("gpt")
181 {
182 return factory.create_ai_executor().ok();
183 }
184 return Some(factory.create_tool_executor());
186 }
187 }
188 NodeType::Condition => {
189 if let Some(factory) = &self.executor_factory {
190 return Some(factory.create_condition_executor());
191 }
192 }
193 NodeType::Approval => {
194 if let Some(factory) = &self.executor_factory {
196 return Some(factory.create_validate_executor());
197 }
198 }
199 _ => {}
200 }
201
202 None
203 }
204
205 pub async fn run(&self, inputs: HashMap<String, serde_json::Value>) -> Result<WorkflowContext> {
207 let mut context = WorkflowContext::new(self.definition.id.clone(), inputs.clone());
209
210 self.validate_inputs(&context)?;
212
213 for (key, value) in inputs {
215 context.set_variable(key.clone(), value.clone());
216 }
217
218 let renderer = crate::workflow::template::TemplateRenderer::new();
220 for (key, value) in &self.definition.variables {
221 let rendered_value = if let serde_json::Value::String(s) = value {
223 match renderer.render(s, &context.variables) {
224 Ok(rendered) => serde_json::Value::String(rendered),
225 Err(_) => value.clone(), }
227 } else {
228 value.clone()
229 };
230 context.set_variable(key.clone(), rendered_value);
231 }
232
233 context.start();
235 self.emit_event(WorkflowEvent::Started);
236
237 let start_node = self
239 .definition
240 .get_start_node()
241 .ok_or_else(|| anyhow::anyhow!("No start node found"))?;
242
243 match self.execute_from_node(start_node, &mut context).await {
245 Ok(()) => {
246 context.complete();
247 self.emit_event(WorkflowEvent::Completed);
248 }
249 Err(e) => {
250 context.fail(e.to_string());
251 self.emit_event(WorkflowEvent::Failed {
252 error: e.to_string(),
253 });
254 }
255 }
256
257 Ok(context)
258 }
259
260 async fn execute_from_node(&self, node: &NodeDef, context: &mut WorkflowContext) -> Result<()> {
262 let mut current_node = Some(node);
263
264 while let Some(node) = current_node {
265 if !context.can_continue() {
267 break;
268 }
269
270 match self.execute_node(node, context).await {
272 Ok(next_node_id) => {
273 current_node = next_node_id
274 .as_ref()
275 .and_then(|id| self.definition.get_node(id));
276 }
277 Err(e) => {
278 match &node.on_failure {
280 FailureStrategy::Retry {
281 max_attempts,
282 interval_ms,
283 } => {
284 let exec = context.get_or_create_node_execution(&node.id);
285 if exec.retry_count < *max_attempts {
286 exec.increment_retry();
287 if let Some(interval) = interval_ms {
288 tokio::time::sleep(Duration::from_millis(*interval)).await;
289 }
290 continue; } else {
292 return Err(e);
293 }
294 }
295 FailureStrategy::Ignore => {
296 let exec = context.get_or_create_node_execution(&node.id);
298 exec.skip();
299 self.emit_event(WorkflowEvent::NodeSkipped {
300 node_id: node.id.clone(),
301 reason: e.to_string(),
302 });
303 let next = self.get_next_node(node, context)?;
304 current_node =
305 next.as_ref().and_then(|id| self.definition.get_node(id));
306 }
307 FailureStrategy::Abort => {
308 return Err(e);
309 }
310 FailureStrategy::Goto { target } => {
311 current_node = self.definition.get_node(target);
312 }
313 }
314 }
315 }
316 }
317
318 Ok(())
319 }
320
321 async fn execute_node(
323 &self,
324 node: &NodeDef,
325 context: &mut WorkflowContext,
326 ) -> Result<Option<String>> {
327 let execution = context.get_or_create_node_execution(&node.id);
329 execution.start();
330 self.emit_event(WorkflowEvent::NodeStarted {
331 node_id: node.id.clone(),
332 });
333
334 context.set_current_node(node.id.clone());
336
337 let result = if let Some(timeout_ms) = node.timeout_ms {
339 timeout(
340 Duration::from_millis(timeout_ms),
341 self.execute_node_inner(node, context),
342 )
343 .await
344 .with_context(|| format!("Node '{}' timed out after {}ms", node.id, timeout_ms))?
345 } else {
346 self.execute_node_inner(node, context).await
347 };
348
349 match result {
350 Ok(output) => {
351 let exec = context.get_or_create_node_execution(&node.id);
352 exec.complete(output.clone());
353 self.emit_event(WorkflowEvent::NodeCompleted {
354 node_id: node.id.clone(),
355 output,
356 });
357
358 self.get_next_node(node, context)
360 }
361 Err(e) => {
362 let exec = context.get_or_create_node_execution(&node.id);
363 exec.fail(e.to_string());
364 self.emit_event(WorkflowEvent::NodeFailed {
365 node_id: node.id.clone(),
366 error: e.to_string(),
367 });
368 Err(e)
369 }
370 }
371 }
372
373 async fn execute_node_inner(
375 &self,
376 node: &NodeDef,
377 context: &mut WorkflowContext,
378 ) -> Result<Option<serde_json::Value>> {
379 match &node.node_type {
380 NodeType::Start => Ok(None),
381 NodeType::End => Ok(None),
382 NodeType::Task => self.execute_task(node, context).await,
383 NodeType::Condition => self.execute_condition(node, context).await,
384 NodeType::Parallel => self.execute_parallel(node, context).await,
385 NodeType::SubWorkflow => self.execute_subworkflow(node, context).await,
386 NodeType::Wait => self.execute_wait(node, context).await,
387 NodeType::Approval => self.execute_approval(node, context).await,
388 }
389 }
390
391 async fn execute_task(
393 &self,
394 node: &NodeDef,
395 context: &mut WorkflowContext,
396 ) -> Result<Option<serde_json::Value>> {
397 let task_name = node
398 .task
399 .as_ref()
400 .ok_or_else(|| anyhow::anyhow!("Task node '{}' has no task name", node.id))?;
401
402 let mut rendered_params = HashMap::new();
404 for (key, value) in &node.params {
405 if let serde_json::Value::String(s) = value {
406 let rendered = self.template_renderer.render(s, &context.variables)?;
407 rendered_params.insert(key.clone(), serde_json::Value::String(rendered));
408 } else {
409 rendered_params.insert(key.clone(), value.clone());
410 }
411 }
412
413 if let Some(node_executor) = self.get_node_executor(node) {
415 let output = node_executor.execute(node, context).await?;
416 return Ok(Some(output));
417 }
418
419 if let Some(executor) = &self.executor {
421 let output = executor
422 .execute(task_name, &rendered_params, context)
423 .await?;
424 Ok(Some(output))
425 } else {
426 Ok(Some(
428 serde_json::json!({ "task": task_name, "status": "completed" }),
429 ))
430 }
431 }
432
433 async fn execute_condition(
435 &self,
436 node: &NodeDef,
437 context: &mut WorkflowContext,
438 ) -> Result<Option<serde_json::Value>> {
439 let branches = node
440 .branches
441 .as_ref()
442 .ok_or_else(|| anyhow::anyhow!("Condition node '{}' has no branches", node.id))?;
443
444 for branch in branches {
445 if evaluate_expression(&branch.condition, &context.variables)? {
446 return Ok(Some(serde_json::Value::String(branch.target.clone())));
448 }
449 }
450
451 Ok(None)
453 }
454
455 async fn execute_parallel(
457 &self,
458 node: &NodeDef,
459 _context: &mut WorkflowContext,
460 ) -> Result<Option<serde_json::Value>> {
461 let branches = node
462 .parallel_branches
463 .as_ref()
464 .ok_or_else(|| anyhow::anyhow!("Parallel node '{}' has no branches", node.id))?;
465
466 let mut outputs = Vec::new();
468 for branch in branches {
469 outputs.push(serde_json::json!({
471 "branch": branch.name,
472 "status": "completed"
473 }));
474 }
475
476 Ok(Some(serde_json::Value::Array(outputs)))
477 }
478
479 async fn execute_subworkflow(
481 &self,
482 node: &NodeDef,
483 _context: &mut WorkflowContext,
484 ) -> Result<Option<serde_json::Value>> {
485 let workflow_name = node.workflow.as_ref().ok_or_else(|| {
486 anyhow::anyhow!("SubWorkflow node '{}' has no workflow name", node.id)
487 })?;
488
489 Ok(Some(serde_json::json!({
491 "workflow": workflow_name,
492 "status": "completed"
493 })))
494 }
495
496 async fn execute_wait(
498 &self,
499 node: &NodeDef,
500 _context: &mut WorkflowContext,
501 ) -> Result<Option<serde_json::Value>> {
502 let wait_ms = node.wait_ms.unwrap_or(0);
503 if wait_ms > 0 {
504 tokio::time::sleep(Duration::from_millis(wait_ms)).await;
505 }
506 Ok(None)
507 }
508
509 async fn execute_approval(
511 &self,
512 node: &NodeDef,
513 _context: &mut WorkflowContext,
514 ) -> Result<Option<serde_json::Value>> {
515 let approvers = node
516 .approvers
517 .as_ref()
518 .ok_or_else(|| anyhow::anyhow!("Approval node '{}' has no approvers", node.id))?;
519
520 Ok(Some(serde_json::json!({
522 "approvers": approvers,
523 "status": "pending_approval"
524 })))
525 }
526
527 fn get_next_node(&self, node: &NodeDef, context: &WorkflowContext) -> Result<Option<String>> {
529 if node.node_type == NodeType::End {
531 return Ok(None);
532 }
533
534 let edges = self.definition.get_outgoing_edges(&node.id);
536
537 if edges.is_empty() {
538 return Ok(None);
539 }
540
541 if node.node_type == NodeType::Condition {
543 let exec = context.get_node_execution(&node.id);
544 if let Some(exec) = exec
545 && let Some(serde_json::Value::String(target)) = &exec.output
546 {
547 return Ok(Some(target.clone()));
548 }
549 }
550
551 for edge in edges {
553 if let Some(condition) = &edge.condition {
554 if evaluate_expression(condition, &context.variables)? {
555 return Ok(Some(edge.to.clone()));
556 }
557 } else {
558 return Ok(Some(edge.to.clone()));
560 }
561 }
562
563 Ok(None)
565 }
566
567 fn validate_inputs(&self, context: &WorkflowContext) -> Result<()> {
569 for input_def in &self.definition.inputs {
570 if input_def.required
571 && context.get_input(&input_def.name).is_none()
572 && input_def.default.is_none()
573 {
574 anyhow::bail!("Required input '{}' is missing", input_def.name);
575 }
576 }
577 Ok(())
578 }
579
580 pub fn definition(&self) -> &WorkflowDef {
582 &self.definition
583 }
584}
585
/// Stateless [`TaskExecutor`] that reports every task as completed without
/// doing any work; useful as a default or in tests.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultTaskExecutor;
588
589#[async_trait::async_trait]
590impl TaskExecutor for DefaultTaskExecutor {
591 async fn execute(
592 &self,
593 task_name: &str,
594 _params: &HashMap<String, serde_json::Value>,
595 _context: &WorkflowContext,
596 ) -> Result<serde_json::Value> {
597 Ok(serde_json::json!({
598 "task": task_name,
599 "status": "completed",
600 "output": null
601 }))
602 }
603}
604
#[cfg(test)]
mod tests {
    use super::super::context::WorkflowStatus;
    use super::super::def::EdgeDef;
    use super::*;

    /// Builds a node with no params, abort-on-failure, and every optional field unset.
    fn node(id: &str, node_type: NodeType, name: &str, task: Option<&str>) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: name.to_string(),
            description: None,
            task: task.map(str::to_string),
            params: HashMap::new(),
            on_failure: FailureStrategy::Abort,
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    /// Builds an unconditional, unlabeled edge `from -> to`.
    fn edge(id: &str, from: &str, to: &str) -> EdgeDef {
        EdgeDef {
            id: id.to_string(),
            from: from.to_string(),
            to: to.to_string(),
            condition: None,
            label: None,
        }
    }

    /// Linear three-node workflow: start -> task1 ("do_something") -> end.
    fn create_simple_workflow() -> WorkflowDef {
        WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: vec![],
            outputs: vec![],
            nodes: vec![
                node("start", NodeType::Start, "Start", None),
                node("task1", NodeType::Task, "Task 1", Some("do_something")),
                node("end", NodeType::End, "End", None),
            ],
            edges: vec![edge("e1", "start", "task1"), edge("e2", "task1", "end")],
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::Abort,
            timeout_ms: None,
        }
    }

    #[tokio::test]
    async fn test_engine_run() {
        let engine = WorkflowEngine::new(create_simple_workflow()).unwrap();

        let context = engine.run(HashMap::new()).await.unwrap();

        assert_eq!(context.status, WorkflowStatus::Completed);
        assert_eq!(context.execution_path.len(), 3);
    }

    #[tokio::test]
    async fn test_engine_with_executor() {
        let engine = WorkflowEngine::new(create_simple_workflow())
            .unwrap()
            .with_executor(Arc::new(DefaultTaskExecutor));

        let context = engine.run(HashMap::new()).await.unwrap();

        assert_eq!(context.status, WorkflowStatus::Completed);
    }
}