use anyhow::{Context, Result};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::timeout;
use super::context::{WorkflowContext};
use super::def::{FailureStrategy, NodeDef, NodeType, WorkflowDef};
use super::rule_engine::evaluate_expression;
use super::template::TemplateRenderer;
use super::executors::{NodeExecutor, ExecutorFactory};
use crate::tools::toolproxy::{ProxyToolExecutor, ProxyToolDef};
/// Fallback executor that runs `Task` nodes by task name.
///
/// [`WorkflowEngine`] only calls it for a `Task` node when no [`NodeExecutor`] was resolved
/// for that node. It receives the node's `params` after their string values have been
/// rendered as templates against the workflow variables.
#[async_trait::async_trait]
pub trait TaskExecutor: Send + Sync {
    /// Runs the task named `task_name` with the already-rendered `params`.
    ///
    /// The returned value is recorded as the node's output.
    ///
    /// # Errors
    ///
    /// Any error fails the node; the node's `on_failure` strategy decides how the run
    /// proceeds.
    async fn execute(
        &self,
        task_name: &str,
        params: &HashMap<String, serde_json::Value>,
        context: &WorkflowContext,
    ) -> Result<serde_json::Value>;
}
/// Progress notification delivered to every registered [`EventListener`] during a run.
///
/// Derives `PartialEq` so callers (and tests) can compare events directly.
#[derive(Debug, Clone, PartialEq)]
pub enum WorkflowEvent {
    /// The run is starting; emitted before any node executes.
    Started,
    /// Node `node_id` is about to execute.
    NodeStarted { node_id: String },
    /// Node `node_id` finished; `output` is whatever its executor produced, if anything.
    NodeCompleted {
        node_id: String,
        output: Option<serde_json::Value>,
    },
    /// Node `node_id` returned an error. Emitted before the node's failure strategy
    /// (retry / ignore / abort / goto) is applied.
    NodeFailed { node_id: String, error: String },
    /// Node `node_id` failed, but its `Ignore` strategy let the run continue;
    /// `reason` is the error text.
    NodeSkipped { node_id: String, reason: String },
    /// The run finished without an unhandled error.
    Completed,
    /// The run stopped because of an unhandled error.
    Failed { error: String },
    /// Execution was paused. `WorkflowEngine` in this file does not emit it.
    Paused,
    /// Execution resumed after a pause. `WorkflowEngine` in this file does not emit it.
    Resumed,
}
/// Observer for the [`WorkflowEvent`]s a [`WorkflowEngine`] emits while it runs.
///
/// The engine invokes listeners synchronously, in registration order, and hands each one
/// its own clone of the event.
pub trait EventListener: Send + Sync {
    /// Handles a single event.
    ///
    /// This is called inline from the engine, so a slow listener delays execution.
    fn on_event(&self, event: WorkflowEvent);
}
/// Interpreter for a validated [`WorkflowDef`].
///
/// It walks the nodes from the start node, dispatches each node to an executor, and
/// reports progress to the registered listeners.
pub struct WorkflowEngine {
    /// The workflow being executed (validated in `WorkflowEngine::new`).
    definition: WorkflowDef,
    /// Fallback executor for `Task` nodes that resolve no [`NodeExecutor`].
    executor: Option<Arc<dyn TaskExecutor>>,
    /// Executors registered under an exact task name; consulted first.
    node_executors: HashMap<String, Arc<dyn NodeExecutor>>,
    /// Builds executors by node type when nothing more specific matches.
    executor_factory: Option<ExecutorFactory>,
    /// Executor for the proxy tools listed in `proxy_tool_defs`.
    proxy_executor: Option<Arc<dyn ProxyToolExecutor>>,
    /// Tools whose names are routed to `proxy_executor`.
    proxy_tool_defs: Vec<ProxyToolDef>,
    /// Observers notified of every [`WorkflowEvent`], in registration order.
    listeners: Vec<Box<dyn EventListener>>,
    /// Renders string templates (task params, definition variables) against workflow variables.
    template_renderer: TemplateRenderer,
}
impl WorkflowEngine {
    /// Creates an engine for `definition`.
    ///
    /// No executors or listeners are attached. Configure them with the `with_*` builders,
    /// [`WorkflowEngine::register_node_executor`] and [`WorkflowEngine::add_listener`].
    ///
    /// # Errors
    ///
    /// Fails with the context "Invalid workflow definition" when `definition.validate()`
    /// rejects the definition.
    pub fn new(definition: WorkflowDef) -> Result<Self> {
        definition
            .validate()
            .context("Invalid workflow definition")?;
        Ok(Self {
            definition,
            executor: None,
            node_executors: HashMap::new(),
            executor_factory: None,
            proxy_executor: None,
            proxy_tool_defs: Vec::new(),
            listeners: Vec::new(),
            template_renderer: TemplateRenderer::new(),
        })
    }

    /// Sets the fallback [`TaskExecutor`], used for `Task` nodes that resolve no
    /// [`NodeExecutor`].
    pub fn with_executor(mut self, executor: Arc<dyn TaskExecutor>) -> Self {
        self.executor = Some(executor);
        self
    }

    /// Sets the factory that builds executors by node type (AI, tool, condition, validate).
    pub fn with_executor_factory(mut self, factory: ExecutorFactory) -> Self {
        self.executor_factory = Some(factory);
        self
    }

    /// Routes tasks whose name matches one of `tool_defs` through `executor`.
    ///
    /// Replaces any previously configured proxy tool definitions.
    pub fn with_proxy_executor(
        mut self,
        executor: Arc<dyn ProxyToolExecutor>,
        tool_defs: Vec<ProxyToolDef>,
    ) -> Self {
        self.proxy_executor = Some(executor);
        self.proxy_tool_defs = tool_defs;
        self
    }

    /// Registers `executor` for nodes whose `task` is exactly `task_type`.
    ///
    /// These registrations take precedence over proxy tools and the executor factory.
    /// Registering the same `task_type` twice keeps the later executor.
    pub fn register_node_executor(
        mut self,
        task_type: &str,
        executor: Arc<dyn NodeExecutor>,
    ) -> Self {
        self.node_executors.insert(task_type.to_string(), executor);
        self
    }

    /// Adds a listener that receives every subsequently emitted [`WorkflowEvent`].
    pub fn add_listener(&mut self, listener: Box<dyn EventListener>) {
        self.listeners.push(listener);
    }

    /// Delivers `event` synchronously to each listener, in registration order.
    fn emit_event(&self, event: WorkflowEvent) {
        for listener in &self.listeners {
            listener.on_event(event.clone());
        }
    }

    /// Picks the [`NodeExecutor`] for `node`, if any.
    ///
    /// Resolution order:
    /// 1. an executor registered via [`WorkflowEngine::register_node_executor`] under the
    ///    node's exact task name;
    /// 2. a proxy executor, when one is configured and the task name matches a proxy tool
    ///    definition;
    /// 3. the executor factory, by node type:
    ///    - `Task` nodes get an AI executor when [`WorkflowEngine::is_ai_task`] matches the
    ///      task name, and a tool executor otherwise;
    ///    - `Condition` nodes get a condition executor;
    ///    - `Approval` nodes get a validate executor.
    ///
    /// Returns `None` when nothing applies. This includes the case where the factory fails
    /// to build an AI executor: that error is discarded and the caller falls back.
    ///
    /// Within this file the method is only called from `execute_task`, so the
    /// `Condition`/`Approval` arms are not reached from here.
    fn get_node_executor(&self, node: &NodeDef) -> Option<Arc<dyn NodeExecutor>> {
        if let Some(task) = &node.task {
            if let Some(executor) = self.node_executors.get(task) {
                return Some(Arc::clone(executor));
            }
            if let Some(proxy) = &self.proxy_executor
                && self.proxy_tool_defs.iter().any(|t| t.definition.name == *task)
            {
                return Some(Arc::new(super::executors::ProxyExecutor::new(
                    Arc::clone(proxy),
                    self.proxy_tool_defs.clone(),
                )));
            }
        }

        let factory = self.executor_factory.as_ref()?;
        match node.node_type {
            NodeType::Task => {
                let task = node.task.as_ref()?;
                if Self::is_ai_task(task) {
                    factory.create_ai_executor().ok()
                } else {
                    Some(factory.create_tool_executor())
                }
            }
            NodeType::Condition => Some(factory.create_condition_executor()),
            NodeType::Approval => Some(factory.create_validate_executor()),
            _ => None,
        }
    }

    /// Heuristic that routes a `Task` node to the AI executor.
    ///
    /// Matches when the task name, compared case-insensitively, is `ai` or starts with
    /// `ai_`, `claude` or `gpt`.
    fn is_ai_task(task: &str) -> bool {
        let task = task.to_lowercase();
        task == "ai"
            || ["ai_", "claude", "gpt"]
                .into_iter()
                .any(|prefix| task.starts_with(prefix))
    }

    /// Executes the workflow with `inputs` and returns the final context.
    ///
    /// Steps:
    /// 1. Inputs are validated and copied into the context variables.
    /// 2. The definition's own variables are added; string values are rendered as
    ///    templates first.
    /// 3. Execution starts at the start node and follows successors until none remain or
    ///    `context.can_continue()` returns false.
    ///
    /// `definition.timeout_ms` and `definition.default_failure_strategy` are not consulted
    /// by this implementation.
    ///
    /// # Errors
    ///
    /// Returns `Err` only for setup problems: a missing required input, or a definition
    /// without a start node. Neither case emits any event. Failures during execution are
    /// recorded on the returned context via `context.fail(..)` and reported with a `Failed`
    /// event; the call still returns `Ok`.
    pub async fn run(&self, inputs: HashMap<String, serde_json::Value>) -> Result<WorkflowContext> {
        let mut context = WorkflowContext::new(self.definition.id.clone(), inputs.clone());
        self.validate_inputs(&context)?;

        // Inputs double as workflow variables, so templates and expressions can refer to them.
        for (key, value) in inputs {
            context.set_variable(key, value);
        }

        // Definition variables are set after (and therefore override) same-named inputs.
        // A render failure deliberately keeps the raw string. `variables` is a HashMap, so
        // iteration order is unspecified: a definition variable that references another
        // definition variable may be rendered before that one has been set.
        for (key, value) in &self.definition.variables {
            let rendered_value = match value {
                serde_json::Value::String(s) => self
                    .template_renderer
                    .render(s, &context.variables)
                    .map(serde_json::Value::String)
                    .unwrap_or_else(|_| value.clone()),
                other => other.clone(),
            };
            context.set_variable(key.clone(), rendered_value);
        }

        // Resolve the entry point before announcing the run. Otherwise a definition without
        // a start node would leave listeners with a `Started` that is never closed.
        let start_node = self
            .definition
            .get_start_node()
            .ok_or_else(|| anyhow::anyhow!("No start node found"))?;

        context.start();
        self.emit_event(WorkflowEvent::Started);

        match self.execute_from_node(start_node, &mut context).await {
            Ok(()) => {
                context.complete();
                self.emit_event(WorkflowEvent::Completed);
            }
            Err(e) => {
                context.fail(e.to_string());
                self.emit_event(WorkflowEvent::Failed { error: e.to_string() });
            }
        }
        Ok(context)
    }

    /// Looks up the node for an optional successor id.
    ///
    /// `Ok(None)` means "no successor", and the run ends normally. An id that is not in the
    /// definition is an error, so a dangling edge, branch target or `Goto` target fails the
    /// run instead of silently completing it.
    fn resolve_node(&self, id: Option<&String>) -> Result<Option<&NodeDef>> {
        id.map(|id| {
            self.definition
                .get_node(id)
                .ok_or_else(|| anyhow::anyhow!("Node '{}' not found in workflow definition", id))
        })
        .transpose()
    }

    /// Runs nodes starting at `node` and follows successors until there are none left.
    ///
    /// When a node fails, its `on_failure` strategy applies:
    /// - `Retry` re-runs the same node, after `interval_ms` if one is set.
    /// - `Ignore` marks the node skipped and moves on to its successor.
    /// - `Abort` propagates the error.
    /// - `Goto` jumps to `target`.
    ///
    /// # Errors
    ///
    /// Returns the node error under `Abort` or exhausted retries. Also fails on a successor
    /// id that is not in the definition, and on edge-condition evaluation errors.
    async fn execute_from_node(&self, node: &NodeDef, context: &mut WorkflowContext) -> Result<()> {
        let mut current_node = Some(node);
        while let Some(node) = current_node {
            if !context.can_continue() {
                break;
            }

            let error = match self.execute_node(node, context).await {
                Ok(next_node_id) => {
                    current_node = self.resolve_node(next_node_id.as_ref())?;
                    continue;
                }
                Err(e) => e,
            };

            match &node.on_failure {
                FailureStrategy::Retry { max_attempts, interval_ms } => {
                    // NOTE(review): assuming `retry_count` starts at 0 and is not reset by
                    // `start()`, a node runs at most `max_attempts + 1` times (one initial
                    // attempt plus `max_attempts` retries). Confirm that matches the field's
                    // intended meaning.
                    let exec = context.get_or_create_node_execution(&node.id);
                    if exec.retry_count >= *max_attempts {
                        return Err(error);
                    }
                    exec.increment_retry();
                    if let Some(interval) = interval_ms {
                        tokio::time::sleep(Duration::from_millis(*interval)).await;
                    }
                    // `current_node` is unchanged, so the next iteration re-runs this node.
                }
                FailureStrategy::Ignore => {
                    context.get_or_create_node_execution(&node.id).skip();
                    self.emit_event(WorkflowEvent::NodeSkipped {
                        node_id: node.id.clone(),
                        reason: error.to_string(),
                    });
                    let next = self.get_next_node(node, context)?;
                    current_node = self.resolve_node(next.as_ref())?;
                }
                FailureStrategy::Abort => return Err(error),
                FailureStrategy::Goto { target } => {
                    current_node = self.resolve_node(Some(target))?;
                }
            }
        }
        Ok(())
    }

    /// Executes one node, records the outcome on its execution record, emits events, and
    /// returns the id of the next node (if any).
    ///
    /// A node with `timeout_ms` is cancelled once that limit is exceeded. The timeout is
    /// treated like any other node failure: it is recorded via `fail(..)` and announced
    /// with `NodeFailed`.
    async fn execute_node(&self, node: &NodeDef, context: &mut WorkflowContext) -> Result<Option<String>> {
        context.get_or_create_node_execution(&node.id).start();
        self.emit_event(WorkflowEvent::NodeStarted { node_id: node.id.clone() });
        context.set_current_node(node.id.clone());

        let result = if let Some(timeout_ms) = node.timeout_ms {
            // Fold the timeout into `result` rather than returning early with `?`. An early
            // return would skip the failure bookkeeping below.
            timeout(
                Duration::from_millis(timeout_ms),
                self.execute_node_inner(node, context),
            )
            .await
            .with_context(|| format!("Node '{}' timed out after {}ms", node.id, timeout_ms))
            .and_then(|inner| inner)
        } else {
            self.execute_node_inner(node, context).await
        };

        match result {
            Ok(output) => {
                context
                    .get_or_create_node_execution(&node.id)
                    .complete(output.clone());
                self.emit_event(WorkflowEvent::NodeCompleted {
                    node_id: node.id.clone(),
                    output,
                });
                self.get_next_node(node, context)
            }
            Err(e) => {
                context.get_or_create_node_execution(&node.id).fail(e.to_string());
                self.emit_event(WorkflowEvent::NodeFailed {
                    node_id: node.id.clone(),
                    error: e.to_string(),
                });
                Err(e)
            }
        }
    }

    /// Dispatches `node` to the handler for its type and returns the node's output.
    async fn execute_node_inner(
        &self,
        node: &NodeDef,
        context: &mut WorkflowContext,
    ) -> Result<Option<serde_json::Value>> {
        match &node.node_type {
            NodeType::Start | NodeType::End => Ok(None),
            NodeType::Task => self.execute_task(node, context).await,
            NodeType::Condition => self.execute_condition(node, context).await,
            NodeType::Parallel => self.execute_parallel(node, context).await,
            NodeType::SubWorkflow => self.execute_subworkflow(node, context).await,
            NodeType::Wait => self.execute_wait(node, context).await,
            NodeType::Approval => self.execute_approval(node, context).await,
        }
    }

    /// Runs a `Task` node.
    ///
    /// Executors are tried in this order:
    /// 1. a resolved [`NodeExecutor`], which receives the raw node;
    /// 2. the configured [`TaskExecutor`], which receives template-rendered params;
    /// 3. failing both, a placeholder `{"task": .., "status": "completed"}` result.
    ///
    /// # Errors
    ///
    /// Fails if the node has no task name, if a param template fails to render (paths 2
    /// and 3 only), or if the chosen executor fails.
    async fn execute_task(
        &self,
        node: &NodeDef,
        context: &mut WorkflowContext,
    ) -> Result<Option<serde_json::Value>> {
        let task_name = node
            .task
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Task node '{}' has no task name", node.id))?;

        // Node executors receive the node itself and never see rendered params, so rendering
        // is deferred until this path has been ruled out.
        if let Some(node_executor) = self.get_node_executor(node) {
            let output = node_executor.execute(node, context).await?;
            return Ok(Some(output));
        }

        // Params are rendered before the fallback below as well, so a bad template still
        // fails the node when no TaskExecutor is configured.
        let mut rendered_params = HashMap::with_capacity(node.params.len());
        for (key, value) in &node.params {
            let rendered = match value {
                serde_json::Value::String(s) => {
                    serde_json::Value::String(self.template_renderer.render(s, &context.variables)?)
                }
                other => other.clone(),
            };
            rendered_params.insert(key.clone(), rendered);
        }

        match &self.executor {
            Some(executor) => {
                let output = executor.execute(task_name, &rendered_params, context).await?;
                Ok(Some(output))
            }
            None => Ok(Some(serde_json::json!({ "task": task_name, "status": "completed" }))),
        }
    }

    /// Evaluates the node's branches in order. The output is the target of the first
    /// branch whose condition holds, as a JSON string, or `None` if none holds.
    ///
    /// # Errors
    ///
    /// Fails if the node has no branches or a condition fails to evaluate.
    async fn execute_condition(
        &self,
        node: &NodeDef,
        context: &mut WorkflowContext,
    ) -> Result<Option<serde_json::Value>> {
        let branches = node
            .branches
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Condition node '{}' has no branches", node.id))?;
        for branch in branches {
            if evaluate_expression(&branch.condition, &context.variables)? {
                return Ok(Some(serde_json::Value::String(branch.target.clone())));
            }
        }
        Ok(None)
    }

    /// Placeholder for `Parallel` nodes.
    ///
    /// The branches are not executed: each configured branch is simply reported as
    /// `{"branch": <name>, "status": "completed"}`.
    async fn execute_parallel(
        &self,
        node: &NodeDef,
        _context: &mut WorkflowContext,
    ) -> Result<Option<serde_json::Value>> {
        let branches = node
            .parallel_branches
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Parallel node '{}' has no branches", node.id))?;
        let mut outputs = Vec::new();
        for branch in branches {
            outputs.push(serde_json::json!({
                "branch": branch.name,
                "status": "completed"
            }));
        }
        Ok(Some(serde_json::Value::Array(outputs)))
    }

    /// Placeholder for `SubWorkflow` nodes.
    ///
    /// The referenced workflow is not run; the node reports it as completed.
    async fn execute_subworkflow(
        &self,
        node: &NodeDef,
        _context: &mut WorkflowContext,
    ) -> Result<Option<serde_json::Value>> {
        let workflow_name = node.workflow.as_ref().ok_or_else(|| {
            anyhow::anyhow!("SubWorkflow node '{}' has no workflow name", node.id)
        })?;
        Ok(Some(serde_json::json!({
            "workflow": workflow_name,
            "status": "completed"
        })))
    }

    /// Sleeps for `wait_ms` milliseconds. A missing or zero value means no wait.
    async fn execute_wait(
        &self,
        node: &NodeDef,
        _context: &mut WorkflowContext,
    ) -> Result<Option<serde_json::Value>> {
        let wait_ms = node.wait_ms.unwrap_or(0);
        if wait_ms > 0 {
            tokio::time::sleep(Duration::from_millis(wait_ms)).await;
        }
        Ok(None)
    }

    /// Placeholder for `Approval` nodes.
    ///
    /// This does not block. It outputs `{"approvers": .., "status": "pending_approval"}`,
    /// and the run then continues with the node's successors.
    async fn execute_approval(
        &self,
        node: &NodeDef,
        _context: &mut WorkflowContext,
    ) -> Result<Option<serde_json::Value>> {
        let approvers = node
            .approvers
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Approval node '{}' has no approvers", node.id))?;
        Ok(Some(serde_json::json!({
            "approvers": approvers,
            "status": "pending_approval"
        })))
    }

    /// Chooses the id of the node that follows `node`.
    ///
    /// `End` nodes and nodes without outgoing edges have no successor. For a `Condition`
    /// node, a string output (the matched branch target) wins. Otherwise the successor is
    /// the first outgoing edge that is unconditional or whose condition evaluates to true.
    ///
    /// # Errors
    ///
    /// Fails if an edge condition fails to evaluate.
    fn get_next_node(&self, node: &NodeDef, context: &WorkflowContext) -> Result<Option<String>> {
        if node.node_type == NodeType::End {
            return Ok(None);
        }
        let edges = self.definition.get_outgoing_edges(&node.id);
        if edges.is_empty() {
            return Ok(None);
        }
        if node.node_type == NodeType::Condition
            && let Some(exec) = context.get_node_execution(&node.id)
            && let Some(serde_json::Value::String(target)) = &exec.output
        {
            return Ok(Some(target.clone()));
        }
        for edge in edges {
            match &edge.condition {
                Some(condition) => {
                    if evaluate_expression(condition, &context.variables)? {
                        return Ok(Some(edge.to.clone()));
                    }
                }
                None => return Ok(Some(edge.to.clone())),
            }
        }
        Ok(None)
    }

    /// Checks that every required input without a default was supplied.
    ///
    /// NOTE(review): this file only checks for defaults; it never copies them into the
    /// context. Confirm that `WorkflowContext::new` or another layer applies them.
    fn validate_inputs(&self, context: &WorkflowContext) -> Result<()> {
        for input_def in &self.definition.inputs {
            if input_def.required
                && context.get_input(&input_def.name).is_none()
                && input_def.default.is_none()
            {
                anyhow::bail!("Required input '{}' is missing", input_def.name);
            }
        }
        Ok(())
    }

    /// Returns the workflow definition this engine executes.
    pub fn definition(&self) -> &WorkflowDef {
        &self.definition
    }
}
/// Trivial [`TaskExecutor`] that performs no work.
///
/// Every call succeeds with `{"task": <task_name>, "status": "completed", "output": null}`.
/// The params and context are ignored.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultTaskExecutor;

#[async_trait::async_trait]
impl TaskExecutor for DefaultTaskExecutor {
    async fn execute(
        &self,
        task_name: &str,
        _params: &HashMap<String, serde_json::Value>,
        _context: &WorkflowContext,
    ) -> Result<serde_json::Value> {
        Ok(serde_json::json!({
            "task": task_name,
            "status": "completed",
            "output": null
        }))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::context::WorkflowStatus;
    use super::super::def::EdgeDef;
    use std::sync::Mutex;

    /// Builds a node that aborts on failure and leaves every optional field unset.
    fn node(id: &str, node_type: NodeType, name: &str, task: Option<&str>) -> NodeDef {
        NodeDef {
            id: id.to_string(),
            node_type,
            name: name.to_string(),
            description: None,
            task: task.map(str::to_string),
            params: HashMap::new(),
            on_failure: FailureStrategy::Abort,
            timeout_ms: None,
            branches: None,
            parallel_branches: None,
            workflow: None,
            wait_ms: None,
            approvers: None,
        }
    }

    /// Builds an unconditional, unlabeled edge `from -> to`.
    fn edge(id: &str, from: &str, to: &str) -> EdgeDef {
        EdgeDef {
            id: id.to_string(),
            from: from.to_string(),
            to: to.to_string(),
            condition: None,
            label: None,
        }
    }

    /// Linear workflow: `start -> task1 (do_something) -> end`.
    fn create_simple_workflow() -> WorkflowDef {
        WorkflowDef {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            version: "1.0.0".to_string(),
            description: None,
            inputs: vec![],
            outputs: vec![],
            nodes: vec![
                node("start", NodeType::Start, "Start", None),
                node("task1", NodeType::Task, "Task 1", Some("do_something")),
                node("end", NodeType::End, "End", None),
            ],
            edges: vec![
                edge("e1", "start", "task1"),
                edge("e2", "task1", "end"),
            ],
            variables: HashMap::new(),
            default_failure_strategy: FailureStrategy::Abort,
            timeout_ms: None,
        }
    }

    /// Listener that stores every event it receives, for later inspection.
    struct RecordingListener(Arc<Mutex<Vec<WorkflowEvent>>>);

    impl EventListener for RecordingListener {
        fn on_event(&self, event: WorkflowEvent) {
            self.0.lock().expect("listener mutex poisoned").push(event);
        }
    }

    #[tokio::test]
    async fn test_engine_run() {
        let engine = WorkflowEngine::new(create_simple_workflow()).unwrap();
        let context = engine.run(HashMap::new()).await.unwrap();
        assert_eq!(context.status, WorkflowStatus::Completed);
        assert_eq!(context.execution_path.len(), 3);
    }

    #[tokio::test]
    async fn test_engine_with_executor() {
        let engine = WorkflowEngine::new(create_simple_workflow())
            .unwrap()
            .with_executor(Arc::new(DefaultTaskExecutor));
        let context = engine.run(HashMap::new()).await.unwrap();
        assert_eq!(context.status, WorkflowStatus::Completed);
    }

    #[tokio::test]
    async fn test_engine_emits_lifecycle_events() {
        let events = Arc::new(Mutex::new(Vec::new()));
        let mut engine = WorkflowEngine::new(create_simple_workflow()).unwrap();
        engine.add_listener(Box::new(RecordingListener(Arc::clone(&events))));

        let context = engine.run(HashMap::new()).await.unwrap();
        assert_eq!(context.status, WorkflowStatus::Completed);

        let events = events.lock().unwrap();
        assert!(matches!(events.first(), Some(WorkflowEvent::Started)));
        assert!(matches!(events.last(), Some(WorkflowEvent::Completed)));
        let completed_nodes = events
            .iter()
            .filter(|event| matches!(event, WorkflowEvent::NodeCompleted { .. }))
            .count();
        assert_eq!(completed_nodes, 3);
        assert!(!events
            .iter()
            .any(|event| matches!(event, WorkflowEvent::NodeFailed { .. })));
    }
}