1use anyhow::{Context, Result};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::time::timeout;
10use super::context::{WorkflowContext};
11use super::def::{FailureStrategy, NodeDef, NodeType, WorkflowDef};
12use super::rule_engine::evaluate_expression;
13use super::template::TemplateRenderer;
14use super::executors::{NodeExecutor, ExecutorFactory};
15use crate::tools::toolproxy::{ProxyToolExecutor, ProxyToolDef};
16
17#[async_trait::async_trait]
19pub trait TaskExecutor: Send + Sync {
20 async fn execute(
22 &self,
23 task_name: &str,
24 params: &HashMap<String, serde_json::Value>,
25 context: &WorkflowContext,
26 ) -> Result<serde_json::Value>;
27}
28
29#[derive(Debug, Clone)]
31pub enum WorkflowEvent {
32 Started,
34 NodeStarted { node_id: String },
36 NodeCompleted { node_id: String, output: Option<serde_json::Value> },
38 NodeFailed { node_id: String, error: String },
40 NodeSkipped { node_id: String, reason: String },
42 Completed,
44 Failed { error: String },
46 Paused,
48 Resumed,
50}
51
52pub trait EventListener: Send + Sync {
54 fn on_event(&self, event: WorkflowEvent);
55}
56
57pub struct WorkflowEngine {
59 definition: WorkflowDef,
61 executor: Option<Arc<dyn TaskExecutor>>,
63 node_executors: HashMap<String, Arc<dyn NodeExecutor>>,
65 executor_factory: Option<ExecutorFactory>,
67 proxy_executor: Option<Arc<dyn ProxyToolExecutor>>,
69 proxy_tool_defs: Vec<ProxyToolDef>,
71 listeners: Vec<Box<dyn EventListener>>,
73 template_renderer: TemplateRenderer,
75}
76
77impl WorkflowEngine {
78 pub fn new(definition: WorkflowDef) -> Result<Self> {
80 definition.validate()
81 .with_context(|| "Invalid workflow definition")?;
82
83 Ok(Self {
84 definition,
85 executor: None,
86 node_executors: HashMap::new(),
87 executor_factory: None,
88 proxy_executor: None,
89 proxy_tool_defs: Vec::new(),
90 listeners: Vec::new(),
91 template_renderer: TemplateRenderer::new(),
92 })
93 }
94
95 pub fn with_executor(mut self, executor: Arc<dyn TaskExecutor>) -> Self {
97 self.executor = Some(executor);
98 self
99 }
100
101 pub fn with_executor_factory(mut self, factory: ExecutorFactory) -> Self {
103 self.executor_factory = Some(factory);
104 self
105 }
106
107 pub fn with_proxy_executor(mut self, executor: Arc<dyn ProxyToolExecutor>, tool_defs: Vec<ProxyToolDef>) -> Self {
109 self.proxy_executor = Some(executor);
110 self.proxy_tool_defs = tool_defs;
111 self
112 }
113
114 pub fn register_node_executor(mut self, task_type: &str, executor: Arc<dyn NodeExecutor>) -> Self {
116 self.node_executors.insert(task_type.to_string(), executor);
117 self
118 }
119
120 pub fn add_listener(&mut self, listener: Box<dyn EventListener>) {
122 self.listeners.push(listener);
123 }
124
125 fn emit_event(&self, event: WorkflowEvent) {
127 for listener in &self.listeners {
128 listener.on_event(event.clone());
129 }
130 }
131
132 fn get_node_executor(&self, node: &NodeDef) -> Option<Arc<dyn NodeExecutor>> {
134 if let Some(task) = &node.task
136 && let Some(executor) = self.node_executors.get(task) {
137 return Some(executor.clone());
138 }
139
140 if let Some(task) = &node.task
142 && self.proxy_tool_defs.iter().any(|t| t.definition.name == *task)
143 && let Some(executor) = &self.proxy_executor {
144 return Some(Arc::new(super::executors::ProxyExecutor::new(
145 executor.clone(),
146 self.proxy_tool_defs.clone(),
147 )));
148 }
149
150 match node.node_type {
152 NodeType::Task => {
153 if let Some(factory) = &self.executor_factory
155 && let Some(task) = &node.task {
156 let task_lower = task.to_lowercase();
159 if task_lower == "ai" || task_lower.starts_with("ai_") || task_lower.starts_with("claude") || task_lower.starts_with("gpt") {
160 return factory.create_ai_executor().ok();
161 }
162 return Some(factory.create_tool_executor());
164 }
165 }
166 NodeType::Condition => {
167 if let Some(factory) = &self.executor_factory {
168 return Some(factory.create_condition_executor());
169 }
170 }
171 NodeType::Approval => {
172 if let Some(factory) = &self.executor_factory {
174 return Some(factory.create_validate_executor());
175 }
176 }
177 _ => {}
178 }
179
180 None
181 }
182
183 pub async fn run(&self, inputs: HashMap<String, serde_json::Value>) -> Result<WorkflowContext> {
185 let mut context = WorkflowContext::new(self.definition.id.clone(), inputs.clone());
187
188 self.validate_inputs(&context)?;
190
191 for (key, value) in inputs {
193 context.set_variable(key.clone(), value.clone());
194 }
195
196 let renderer = crate::workflow::template::TemplateRenderer::new();
198 for (key, value) in &self.definition.variables {
199 let rendered_value = if let serde_json::Value::String(s) = value {
201 match renderer.render(s, &context.variables) {
202 Ok(rendered) => serde_json::Value::String(rendered),
203 Err(_) => value.clone(), }
205 } else {
206 value.clone()
207 };
208 context.set_variable(key.clone(), rendered_value);
209 }
210
211 context.start();
213 self.emit_event(WorkflowEvent::Started);
214
215 let start_node = self.definition.get_start_node()
217 .ok_or_else(|| anyhow::anyhow!("No start node found"))?;
218
219 match self.execute_from_node(start_node, &mut context).await {
221 Ok(()) => {
222 context.complete();
223 self.emit_event(WorkflowEvent::Completed);
224 }
225 Err(e) => {
226 context.fail(e.to_string());
227 self.emit_event(WorkflowEvent::Failed { error: e.to_string() });
228 }
229 }
230
231 Ok(context)
232 }
233
234 async fn execute_from_node(
236 &self,
237 node: &NodeDef,
238 context: &mut WorkflowContext,
239 ) -> Result<()> {
240 let mut current_node = Some(node);
241
242 while let Some(node) = current_node {
243 if !context.can_continue() {
245 break;
246 }
247
248 match self.execute_node(node, context).await {
250 Ok(next_node_id) => {
251 current_node = next_node_id
252 .as_ref()
253 .and_then(|id| self.definition.get_node(id));
254 }
255 Err(e) => {
256 match &node.on_failure {
258 FailureStrategy::Retry { max_attempts, interval_ms } => {
259 let exec = context.get_or_create_node_execution(&node.id);
260 if exec.retry_count < *max_attempts {
261 exec.increment_retry();
262 if let Some(interval) = interval_ms {
263 tokio::time::sleep(Duration::from_millis(*interval)).await;
264 }
265 continue; } else {
267 return Err(e);
268 }
269 }
270 FailureStrategy::Ignore => {
271 let exec = context.get_or_create_node_execution(&node.id);
273 exec.skip();
274 self.emit_event(WorkflowEvent::NodeSkipped {
275 node_id: node.id.clone(),
276 reason: e.to_string(),
277 });
278 let next = self.get_next_node(node, context)?;
279 current_node = next
280 .as_ref()
281 .and_then(|id| self.definition.get_node(id));
282 }
283 FailureStrategy::Abort => {
284 return Err(e);
285 }
286 FailureStrategy::Goto { target } => {
287 current_node = self.definition.get_node(target);
288 }
289 }
290 }
291 }
292 }
293
294 Ok(())
295 }
296
297 async fn execute_node(
299 &self,
300 node: &NodeDef,
301 context: &mut WorkflowContext,
302 ) -> Result<Option<String>> {
303 let execution = context.get_or_create_node_execution(&node.id);
305 execution.start();
306 self.emit_event(WorkflowEvent::NodeStarted { node_id: node.id.clone() });
307
308 context.set_current_node(node.id.clone());
310
311 let result = if let Some(timeout_ms) = node.timeout_ms {
313 timeout(
314 Duration::from_millis(timeout_ms),
315 self.execute_node_inner(node, context),
316 )
317 .await
318 .with_context(|| format!("Node '{}' timed out after {}ms", node.id, timeout_ms))?
319 } else {
320 self.execute_node_inner(node, context).await
321 };
322
323 match result {
324 Ok(output) => {
325 let exec = context.get_or_create_node_execution(&node.id);
326 exec.complete(output.clone());
327 self.emit_event(WorkflowEvent::NodeCompleted {
328 node_id: node.id.clone(),
329 output,
330 });
331
332 self.get_next_node(node, context)
334 }
335 Err(e) => {
336 let exec = context.get_or_create_node_execution(&node.id);
337 exec.fail(e.to_string());
338 self.emit_event(WorkflowEvent::NodeFailed {
339 node_id: node.id.clone(),
340 error: e.to_string(),
341 });
342 Err(e)
343 }
344 }
345 }
346
347 async fn execute_node_inner(
349 &self,
350 node: &NodeDef,
351 context: &mut WorkflowContext,
352 ) -> Result<Option<serde_json::Value>> {
353 match &node.node_type {
354 NodeType::Start => {
355 Ok(None)
356 }
357 NodeType::End => {
358 Ok(None)
359 }
360 NodeType::Task => {
361 self.execute_task(node, context).await
362 }
363 NodeType::Condition => {
364 self.execute_condition(node, context).await
365 }
366 NodeType::Parallel => {
367 self.execute_parallel(node, context).await
368 }
369 NodeType::SubWorkflow => {
370 self.execute_subworkflow(node, context).await
371 }
372 NodeType::Wait => {
373 self.execute_wait(node, context).await
374 }
375 NodeType::Approval => {
376 self.execute_approval(node, context).await
377 }
378 }
379 }
380
381 async fn execute_task(
383 &self,
384 node: &NodeDef,
385 context: &mut WorkflowContext,
386 ) -> Result<Option<serde_json::Value>> {
387 let task_name = node.task.as_ref()
388 .ok_or_else(|| anyhow::anyhow!("Task node '{}' has no task name", node.id))?;
389
390 let mut rendered_params = HashMap::new();
392 for (key, value) in &node.params {
393 if let serde_json::Value::String(s) = value {
394 let rendered = self.template_renderer.render(s, &context.variables)?;
395 rendered_params.insert(key.clone(), serde_json::Value::String(rendered));
396 } else {
397 rendered_params.insert(key.clone(), value.clone());
398 }
399 }
400
401 if let Some(node_executor) = self.get_node_executor(node) {
403 let output = node_executor.execute(node, context).await?;
404 return Ok(Some(output));
405 }
406
407 if let Some(executor) = &self.executor {
409 let output = executor.execute(task_name, &rendered_params, context).await?;
410 Ok(Some(output))
411 } else {
412 Ok(Some(serde_json::json!({ "task": task_name, "status": "completed" })))
414 }
415 }
416
417 async fn execute_condition(
419 &self,
420 node: &NodeDef,
421 context: &mut WorkflowContext,
422 ) -> Result<Option<serde_json::Value>> {
423 let branches = node.branches.as_ref()
424 .ok_or_else(|| anyhow::anyhow!("Condition node '{}' has no branches", node.id))?;
425
426 for branch in branches {
427 if evaluate_expression(&branch.condition, &context.variables)? {
428 return Ok(Some(serde_json::Value::String(branch.target.clone())));
430 }
431 }
432
433 Ok(None)
435 }
436
437 async fn execute_parallel(
439 &self,
440 node: &NodeDef,
441 _context: &mut WorkflowContext,
442 ) -> Result<Option<serde_json::Value>> {
443 let branches = node.parallel_branches.as_ref()
444 .ok_or_else(|| anyhow::anyhow!("Parallel node '{}' has no branches", node.id))?;
445
446 let mut outputs = Vec::new();
448 for branch in branches {
449 outputs.push(serde_json::json!({
451 "branch": branch.name,
452 "status": "completed"
453 }));
454 }
455
456 Ok(Some(serde_json::Value::Array(outputs)))
457 }
458
459 async fn execute_subworkflow(
461 &self,
462 node: &NodeDef,
463 _context: &mut WorkflowContext,
464 ) -> Result<Option<serde_json::Value>> {
465 let workflow_name = node.workflow.as_ref()
466 .ok_or_else(|| anyhow::anyhow!("SubWorkflow node '{}' has no workflow name", node.id))?;
467
468 Ok(Some(serde_json::json!({
470 "workflow": workflow_name,
471 "status": "completed"
472 })))
473 }
474
475 async fn execute_wait(
477 &self,
478 node: &NodeDef,
479 _context: &mut WorkflowContext,
480 ) -> Result<Option<serde_json::Value>> {
481 let wait_ms = node.wait_ms.unwrap_or(0);
482 if wait_ms > 0 {
483 tokio::time::sleep(Duration::from_millis(wait_ms)).await;
484 }
485 Ok(None)
486 }
487
488 async fn execute_approval(
490 &self,
491 node: &NodeDef,
492 _context: &mut WorkflowContext,
493 ) -> Result<Option<serde_json::Value>> {
494 let approvers = node.approvers.as_ref()
495 .ok_or_else(|| anyhow::anyhow!("Approval node '{}' has no approvers", node.id))?;
496
497 Ok(Some(serde_json::json!({
499 "approvers": approvers,
500 "status": "pending_approval"
501 })))
502 }
503
504 fn get_next_node(
506 &self,
507 node: &NodeDef,
508 context: &WorkflowContext,
509 ) -> Result<Option<String>> {
510 if node.node_type == NodeType::End {
512 return Ok(None);
513 }
514
515 let edges = self.definition.get_outgoing_edges(&node.id);
517
518 if edges.is_empty() {
519 return Ok(None);
520 }
521
522 if node.node_type == NodeType::Condition {
524 let exec = context.get_node_execution(&node.id);
525 if let Some(exec) = exec
526 && let Some(serde_json::Value::String(target)) = &exec.output {
527 return Ok(Some(target.clone()));
528 }
529 }
530
531 for edge in edges {
533 if let Some(condition) = &edge.condition {
534 if evaluate_expression(condition, &context.variables)? {
535 return Ok(Some(edge.to.clone()));
536 }
537 } else {
538 return Ok(Some(edge.to.clone()));
540 }
541 }
542
543 Ok(None)
545 }
546
547 fn validate_inputs(&self, context: &WorkflowContext) -> Result<()> {
549 for input_def in &self.definition.inputs {
550 if input_def.required
551 && context.get_input(&input_def.name).is_none()
552 && input_def.default.is_none() {
553 anyhow::bail!("Required input '{}' is missing", input_def.name);
554 }
555 }
556 Ok(())
557 }
558
559 pub fn definition(&self) -> &WorkflowDef {
561 &self.definition
562 }
563}
564
/// A stub [`TaskExecutor`] that does no work and reports every task as completed.
///
/// It carries no state, so it is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultTaskExecutor;
567
568#[async_trait::async_trait]
569impl TaskExecutor for DefaultTaskExecutor {
570 async fn execute(
571 &self,
572 task_name: &str,
573 _params: &HashMap<String, serde_json::Value>,
574 _context: &WorkflowContext,
575 ) -> Result<serde_json::Value> {
576 Ok(serde_json::json!({
577 "task": task_name,
578 "status": "completed",
579 "output": null
580 }))
581 }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::def::EdgeDef;
    use super::super::context::WorkflowStatus;

    /// Builds a node with the given identity, type and optional task name.
    /// All other optional fields are unset, and failures abort the workflow.
    fn make_node(id: &str, node_type: NodeType, name: &str, task: Option<&str>) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: name.to_string(),
            description: None,
            task: task.map(str::to_string),
            params: HashMap::new(),
            on_failure: FailureStrategy::Abort,
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    /// Builds an unconditional, unlabeled edge `from -> to`.
    fn make_edge(id: &str, from: &str, to: &str) -> EdgeDef {
        EdgeDef {
            id: id.to_string(),
            from: from.to_string(),
            to: to.to_string(),
            condition: None,
            label: None,
        }
    }

    /// A linear workflow: start -> task1 (`do_something`) -> end.
    fn create_simple_workflow() -> WorkflowDef {
        let nodes = vec![
            make_node("start", NodeType::Start, "Start", None),
            make_node("task1", NodeType::Task, "Task 1", Some("do_something")),
            make_node("end", NodeType::End, "End", None),
        ];
        let edges = vec![
            make_edge("e1", "start", "task1"),
            make_edge("e2", "task1", "end"),
        ];

        WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: vec![],
            outputs: vec![],
            nodes,
            edges,
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::Abort,
            timeout_ms: None,
        }
    }

    #[tokio::test]
    async fn test_engine_run() {
        let engine = WorkflowEngine::new(create_simple_workflow()).unwrap();

        let context = engine.run(HashMap::new()).await.unwrap();

        assert_eq!(context.status, WorkflowStatus::Completed);
        assert_eq!(context.execution_path.len(), 3);
    }

    #[tokio::test]
    async fn test_engine_with_executor() {
        let engine = WorkflowEngine::new(create_simple_workflow())
            .unwrap()
            .with_executor(Arc::new(DefaultTaskExecutor));

        let context = engine.run(HashMap::new()).await.unwrap();

        assert_eq!(context.status, WorkflowStatus::Completed);
    }
}