use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use anyhow::Result;
use crate::tools::{ToolCall, ToolRegistry};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowStep {
pub id: String,
pub step_type: WorkflowStepType,
pub depends_on: Vec<String>,
pub parameters: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkflowStepType {
#[serde(rename = "llm")]
LLMGeneration {
prompt: String,
model: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
},
#[serde(rename = "tool")]
ToolCall {
tool_name: String,
arguments: serde_json::Value,
},
#[serde(rename = "data_transform")]
DataTransform {
operation: String, expression: String, },
#[serde(rename = "conditional")]
Conditional {
condition: String,
if_true: Box<WorkflowStep>,
if_false: Option<Box<WorkflowStep>>,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Workflow {
pub id: String,
pub name: String,
pub description: String,
pub steps: Vec<WorkflowStep>,
pub inputs: HashMap<String, serde_json::Value>,
pub outputs: Vec<String>, }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowRequest {
pub workflow: Workflow,
pub context: HashMap<String, serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowResult {
pub workflow_id: String,
pub success: bool,
pub step_results: HashMap<String, StepResult>,
pub outputs: HashMap<String, serde_json::Value>,
pub error: Option<String>,
pub execution_time_ms: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepResult {
pub step_id: String,
pub success: bool,
pub result: serde_json::Value,
pub error: Option<String>,
pub execution_time_ms: u64,
}
pub struct WorkflowEngine {
tool_registry: ToolRegistry,
}
impl WorkflowEngine {
pub fn new(tool_registry: ToolRegistry) -> Self {
Self { tool_registry }
}
pub async fn execute_workflow(&self, request: WorkflowRequest) -> Result<WorkflowResult> {
let start_time = std::time::Instant::now();
let mut step_results = HashMap::new();
let mut context = request.context;
for (key, value) in request.workflow.inputs {
context.insert(key, value);
}
let execution_order = self.calculate_execution_order(&request.workflow.steps)?;
for step_id in execution_order {
let step = request.workflow.steps.iter()
.find(|s| s.id == step_id)
.ok_or_else(|| anyhow::anyhow!("Step {} not found", step_id))?;
let step_start = std::time::Instant::now();
let step_result = match self.execute_step(step, &context, &step_results).await {
Ok(result) => {
context.insert(format!("step_{}", step.id), result.clone());
StepResult {
step_id: step.id.clone(),
success: true,
result,
error: None,
execution_time_ms: step_start.elapsed().as_millis() as u64,
}
}
Err(e) => StepResult {
step_id: step.id.clone(),
success: false,
result: serde_json::Value::Null,
error: Some(e.to_string()),
execution_time_ms: step_start.elapsed().as_millis() as u64,
}
};
step_results.insert(step.id.clone(), step_result);
}
let mut outputs = HashMap::new();
for output_step_id in &request.workflow.outputs {
if let Some(step_result) = step_results.get(output_step_id) {
outputs.insert(output_step_id.clone(), step_result.result.clone());
}
}
let success = step_results.values().all(|result| result.success);
Ok(WorkflowResult {
workflow_id: request.workflow.id,
success,
step_results,
outputs,
error: if success { None } else { Some("One or more steps failed".to_string()) },
execution_time_ms: start_time.elapsed().as_millis() as u64,
})
}
fn execute_step<'a>(
&'a self,
step: &'a WorkflowStep,
context: &'a HashMap<String, serde_json::Value>,
step_results: &'a HashMap<String, StepResult>,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + 'a>> {
Box::pin(async move {
match &step.step_type {
WorkflowStepType::LLMGeneration { prompt, model, max_tokens, temperature } => {
let resolved_prompt = self.substitute_variables(prompt, context)?;
let result = self.call_llm(
&resolved_prompt,
model.as_deref().unwrap_or("default"),
max_tokens.unwrap_or(512),
temperature.unwrap_or(0.7),
).await?;
Ok(serde_json::json!({
"text": result,
"type": "llm_generation"
}))
}
WorkflowStepType::ToolCall { tool_name, arguments } => {
let resolved_args = self.substitute_variables_in_json(arguments, context)?;
let tool_call = ToolCall {
name: tool_name.clone(),
arguments: resolved_args,
};
let tool_result = self.tool_registry.execute_tool(&tool_call)?;
if tool_result.success {
Ok(tool_result.result)
} else {
Err(anyhow::anyhow!("Tool execution failed: {:?}", tool_result.error))
}
}
WorkflowStepType::DataTransform { operation, expression } => {
self.execute_data_transform(operation, expression, context, step_results)
}
WorkflowStepType::Conditional { condition, if_true, if_false } => {
let condition_result = self.evaluate_condition(condition, context)?;
if condition_result {
self.execute_step(if_true, context, step_results).await
} else if let Some(false_step) = if_false {
self.execute_step(false_step, context, step_results).await
} else {
Ok(serde_json::json!({ "skipped": true }))
}
}
}
})
}
fn calculate_execution_order(&self, steps: &[WorkflowStep]) -> Result<Vec<String>> {
let mut order = Vec::new();
let mut visited = std::collections::HashSet::new();
let mut temp_visited = std::collections::HashSet::new();
for step in steps {
if !visited.contains(&step.id) {
self.visit_step(&step.id, steps, &mut visited, &mut temp_visited, &mut order)?;
}
}
Ok(order)
}
#[allow(clippy::only_used_in_recursion)] fn visit_step(
&self,
step_id: &str,
steps: &[WorkflowStep],
visited: &mut std::collections::HashSet<String>,
temp_visited: &mut std::collections::HashSet<String>,
order: &mut Vec<String>,
) -> Result<()> {
if temp_visited.contains(step_id) {
return Err(anyhow::anyhow!("Circular dependency detected involving step {}", step_id));
}
if visited.contains(step_id) {
return Ok(());
}
temp_visited.insert(step_id.to_string());
let step = steps.iter()
.find(|s| s.id == step_id)
.ok_or_else(|| anyhow::anyhow!("Step {} not found", step_id))?;
for dep in &step.depends_on {
self.visit_step(dep, steps, visited, temp_visited, order)?;
}
temp_visited.remove(step_id);
visited.insert(step_id.to_string());
order.push(step_id.to_string());
Ok(())
}
fn substitute_variables(&self, text: &str, context: &HashMap<String, serde_json::Value>) -> Result<String> {
let mut result = text.to_string();
for (key, value) in context {
let placeholder = format!("{{{{{}}}}}", key);
let replacement = match value {
serde_json::Value::String(s) => s.clone(),
other => other.to_string(),
};
result = result.replace(&placeholder, &replacement);
}
Ok(result)
}
fn substitute_variables_in_json(
&self,
json: &serde_json::Value,
context: &HashMap<String, serde_json::Value>,
) -> Result<serde_json::Value> {
match json {
serde_json::Value::String(s) => {
Ok(serde_json::Value::String(self.substitute_variables(s, context)?))
}
serde_json::Value::Object(obj) => {
let mut new_obj = serde_json::Map::new();
for (key, value) in obj {
new_obj.insert(key.clone(), self.substitute_variables_in_json(value, context)?);
}
Ok(serde_json::Value::Object(new_obj))
}
serde_json::Value::Array(arr) => {
let new_arr: Result<Vec<_>> = arr.iter()
.map(|item| self.substitute_variables_in_json(item, context))
.collect();
Ok(serde_json::Value::Array(new_arr?))
}
other => Ok(other.clone()),
}
}
fn execute_data_transform(
&self,
operation: &str,
expression: &str,
context: &HashMap<String, serde_json::Value>,
step_results: &HashMap<String, StepResult>,
) -> Result<serde_json::Value> {
match operation {
"extract" => {
if let Some(value) = context.get(expression) {
Ok(value.clone())
} else if expression.starts_with("step_") {
let step_id = expression.strip_prefix("step_").unwrap();
if let Some(step_result) = step_results.get(step_id) {
Ok(step_result.result.clone())
} else {
Err(anyhow::anyhow!("Step {} not found", step_id))
}
} else {
Err(anyhow::anyhow!("Variable {} not found", expression))
}
}
"filter" => {
Ok(serde_json::json!({ "filtered": true, "expression": expression }))
}
_ => Err(anyhow::anyhow!("Unsupported data transform operation: {}", operation)),
}
}
fn evaluate_condition(
&self,
condition: &str,
context: &HashMap<String, serde_json::Value>,
) -> Result<bool> {
if condition.contains("==") {
let parts: Vec<&str> = condition.split("==").collect();
if parts.len() == 2 {
let left = parts[0].trim();
let right = parts[1].trim();
let left_value = context.get(left);
let right_str = right.trim_matches('"');
match left_value {
Some(serde_json::Value::String(s)) => Ok(s == right_str),
Some(serde_json::Value::Bool(b)) => Ok(b.to_string() == right_str),
_ => Ok(false),
}
} else {
Ok(false)
}
} else {
Ok(true)
}
}
async fn call_llm(
&self,
prompt: &str,
model: &str,
max_tokens: u32,
temperature: f32,
) -> Result<String> {
Ok(format!("LLM response to: {} (model: {}, max_tokens: {}, temp: {})",
prompt, model, max_tokens, temperature))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tools::ToolRegistry;
fn create_test_tool_registry() -> ToolRegistry {
ToolRegistry::new()
}
#[test]
fn test_workflow_step_creation() {
let step = WorkflowStep {
id: "test".to_string(),
step_type: WorkflowStepType::LLMGeneration {
prompt: "hello".to_string(),
model: None,
max_tokens: None,
temperature: None,
},
depends_on: vec![],
parameters: serde_json::json!({}),
};
assert_eq!(step.id, "test");
}
#[test]
fn test_workflow_creation() {
let workflow = Workflow {
id: "test-workflow".to_string(),
name: "Test".to_string(),
description: "Test workflow".to_string(),
steps: vec![],
inputs: HashMap::new(),
outputs: vec![],
};
assert_eq!(workflow.id, "test-workflow");
}
#[test]
fn test_tool_call_step_type() {
let step_type = WorkflowStepType::ToolCall {
tool_name: "calculator".to_string(),
arguments: serde_json::json!({"x": 5}),
};
match step_type {
WorkflowStepType::ToolCall { tool_name, .. } => {
assert_eq!(tool_name, "calculator");
}
_ => panic!("Expected ToolCall"),
}
}
#[tokio::test]
async fn test_workflow_engine_creation() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
assert!(std::ptr::addr_of!(engine.tool_registry) as usize > 0);
}
#[tokio::test]
async fn test_execute_workflow_empty_steps() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let workflow = Workflow {
id: "empty-workflow".to_string(),
name: "Empty Test".to_string(),
description: "Empty workflow for testing".to_string(),
steps: vec![],
inputs: HashMap::new(),
outputs: vec![],
};
let request = WorkflowRequest {
workflow,
context: HashMap::new(),
};
let result = engine.execute_workflow(request).await.unwrap();
assert!(result.success);
assert_eq!(result.step_results.len(), 0);
assert_eq!(result.outputs.len(), 0);
}
#[tokio::test]
async fn test_execute_workflow_with_inputs() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let mut inputs = HashMap::new();
inputs.insert("input_value".to_string(), serde_json::json!("test_input"));
let workflow = Workflow {
id: "input-workflow".to_string(),
name: "Input Test".to_string(),
description: "Workflow with inputs".to_string(),
steps: vec![],
inputs,
outputs: vec![],
};
let mut context = HashMap::new();
context.insert("context_value".to_string(), serde_json::json!("test_context"));
let request = WorkflowRequest {
workflow,
context,
};
let result = engine.execute_workflow(request).await.unwrap();
assert!(result.success);
assert_eq!(result.workflow_id, "input-workflow");
assert!(result.execution_time_ms == result.execution_time_ms); }
#[tokio::test]
async fn test_execute_workflow_with_llm_step() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let step = WorkflowStep {
id: "llm_step".to_string(),
step_type: WorkflowStepType::LLMGeneration {
prompt: "Generate text: {{input_value}}".to_string(),
model: Some("test-model".to_string()),
max_tokens: Some(100),
temperature: Some(0.5),
},
depends_on: vec![],
parameters: serde_json::json!({}),
};
let mut inputs = HashMap::new();
inputs.insert("input_value".to_string(), serde_json::json!("hello world"));
let workflow = Workflow {
id: "llm-workflow".to_string(),
name: "LLM Test".to_string(),
description: "Workflow with LLM step".to_string(),
steps: vec![step],
inputs,
outputs: vec!["llm_step".to_string()],
};
let request = WorkflowRequest {
workflow,
context: HashMap::new(),
};
let result = engine.execute_workflow(request).await.unwrap();
assert!(result.success);
assert_eq!(result.step_results.len(), 1);
assert!(result.step_results.contains_key("llm_step"));
assert_eq!(result.outputs.len(), 1);
assert!(result.outputs.contains_key("llm_step"));
let step_result = result.step_results.get("llm_step").unwrap();
assert!(step_result.success);
assert!(step_result.execution_time_ms == step_result.execution_time_ms); }
#[tokio::test]
async fn test_execute_workflow_with_tool_step() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let step = WorkflowStep {
id: "tool_step".to_string(),
step_type: WorkflowStepType::ToolCall {
tool_name: "calculator".to_string(),
arguments: serde_json::json!({
"expression": "2 + 3"
}),
},
depends_on: vec![],
parameters: serde_json::json!({}),
};
let workflow = Workflow {
id: "tool-workflow".to_string(),
name: "Tool Test".to_string(),
description: "Workflow with tool step".to_string(),
steps: vec![step],
inputs: HashMap::new(),
outputs: vec!["tool_step".to_string()],
};
let request = WorkflowRequest {
workflow,
context: HashMap::new(),
};
let result = engine.execute_workflow(request).await.unwrap();
assert!(result.success);
assert_eq!(result.step_results.len(), 1);
let step_result = result.step_results.get("tool_step").unwrap();
assert!(step_result.success);
assert_eq!(step_result.result, serde_json::json!(5.0));
}
#[tokio::test]
async fn test_execute_workflow_with_failed_tool_step() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let step = WorkflowStep {
id: "failed_tool_step".to_string(),
step_type: WorkflowStepType::ToolCall {
tool_name: "nonexistent_tool".to_string(),
arguments: serde_json::json!({}),
},
depends_on: vec![],
parameters: serde_json::json!({}),
};
let workflow = Workflow {
id: "failed-tool-workflow".to_string(),
name: "Failed Tool Test".to_string(),
description: "Workflow with failing tool step".to_string(),
steps: vec![step],
inputs: HashMap::new(),
outputs: vec!["failed_tool_step".to_string()],
};
let request = WorkflowRequest {
workflow,
context: HashMap::new(),
};
let result = engine.execute_workflow(request).await.unwrap();
assert!(!result.success);
assert!(result.error.is_some());
assert_eq!(result.error.unwrap(), "One or more steps failed");
let step_result = result.step_results.get("failed_tool_step").unwrap();
assert!(!step_result.success);
assert!(step_result.error.is_some());
}
#[tokio::test]
async fn test_execute_workflow_with_data_transform_step() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let step = WorkflowStep {
id: "transform_step".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input_data".to_string(),
},
depends_on: vec![],
parameters: serde_json::json!({}),
};
let mut inputs = HashMap::new();
inputs.insert("input_data".to_string(), serde_json::json!("extracted_value"));
let workflow = Workflow {
id: "transform-workflow".to_string(),
name: "Transform Test".to_string(),
description: "Workflow with data transform step".to_string(),
steps: vec![step],
inputs,
outputs: vec!["transform_step".to_string()],
};
let request = WorkflowRequest {
workflow,
context: HashMap::new(),
};
let result = engine.execute_workflow(request).await.unwrap();
assert!(result.success);
let step_result = result.step_results.get("transform_step").unwrap();
assert!(step_result.success);
assert_eq!(step_result.result, serde_json::json!("extracted_value"));
}
#[tokio::test]
async fn test_execute_workflow_with_conditional_step_true() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let if_true_step = WorkflowStep {
id: "true_branch".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "true_value".to_string(),
},
depends_on: vec![],
parameters: serde_json::json!({}),
};
let step = WorkflowStep {
id: "conditional_step".to_string(),
step_type: WorkflowStepType::Conditional {
condition: "test_condition == \"true\"".to_string(),
if_true: Box::new(if_true_step),
if_false: None,
},
depends_on: vec![],
parameters: serde_json::json!({}),
};
let mut inputs = HashMap::new();
inputs.insert("test_condition".to_string(), serde_json::json!("true"));
inputs.insert("true_value".to_string(), serde_json::json!("condition_met"));
let workflow = Workflow {
id: "conditional-workflow".to_string(),
name: "Conditional Test".to_string(),
description: "Workflow with conditional step".to_string(),
steps: vec![step],
inputs,
outputs: vec!["conditional_step".to_string()],
};
let request = WorkflowRequest {
workflow,
context: HashMap::new(),
};
let result = engine.execute_workflow(request).await.unwrap();
assert!(result.success);
let step_result = result.step_results.get("conditional_step").unwrap();
assert!(step_result.success);
assert_eq!(step_result.result, serde_json::json!("condition_met"));
}
#[tokio::test]
async fn test_execute_workflow_with_conditional_step_false() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let if_true_step = WorkflowStep {
id: "true_branch".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "true_value".to_string(),
},
depends_on: vec![],
parameters: serde_json::json!({}),
};
let step = WorkflowStep {
id: "conditional_step".to_string(),
step_type: WorkflowStepType::Conditional {
condition: "test_condition == \"true\"".to_string(),
if_true: Box::new(if_true_step),
if_false: None,
},
depends_on: vec![],
parameters: serde_json::json!({}),
};
let mut inputs = HashMap::new();
inputs.insert("test_condition".to_string(), serde_json::json!("false"));
let workflow = Workflow {
id: "conditional-false-workflow".to_string(),
name: "Conditional False Test".to_string(),
description: "Workflow with conditional step (false branch)".to_string(),
steps: vec![step],
inputs,
outputs: vec!["conditional_step".to_string()],
};
let request = WorkflowRequest {
workflow,
context: HashMap::new(),
};
let result = engine.execute_workflow(request).await.unwrap();
assert!(result.success);
let step_result = result.step_results.get("conditional_step").unwrap();
assert!(step_result.success);
assert_eq!(step_result.result, serde_json::json!({ "skipped": true }));
}
#[tokio::test]
async fn test_execute_workflow_with_dependencies() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let step1 = WorkflowStep {
id: "step1".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input_value".to_string(),
},
depends_on: vec![],
parameters: serde_json::json!({}),
};
let step2 = WorkflowStep {
id: "step2".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "step_step1".to_string(),
},
depends_on: vec!["step1".to_string()],
parameters: serde_json::json!({}),
};
let mut inputs = HashMap::new();
inputs.insert("input_value".to_string(), serde_json::json!("first_step_output"));
let workflow = Workflow {
id: "dependency-workflow".to_string(),
name: "Dependency Test".to_string(),
description: "Workflow with step dependencies".to_string(),
steps: vec![step2, step1], inputs,
outputs: vec!["step2".to_string()],
};
let request = WorkflowRequest {
workflow,
context: HashMap::new(),
};
let result = engine.execute_workflow(request).await.unwrap();
assert!(result.success);
assert_eq!(result.step_results.len(), 2);
let step1_result = result.step_results.get("step1").unwrap();
assert!(step1_result.success);
assert_eq!(step1_result.result, serde_json::json!("first_step_output"));
let step2_result = result.step_results.get("step2").unwrap();
assert!(step2_result.success);
assert_eq!(step2_result.result, serde_json::json!("first_step_output"));
}
#[tokio::test]
async fn test_execute_workflow_circular_dependency() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let step1 = WorkflowStep {
id: "step1".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input_value".to_string(),
},
depends_on: vec!["step2".to_string()],
parameters: serde_json::json!({}),
};
let step2 = WorkflowStep {
id: "step2".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input_value".to_string(),
},
depends_on: vec!["step1".to_string()],
parameters: serde_json::json!({}),
};
let workflow = Workflow {
id: "circular-dependency-workflow".to_string(),
name: "Circular Dependency Test".to_string(),
description: "Workflow with circular dependencies".to_string(),
steps: vec![step1, step2],
inputs: HashMap::new(),
outputs: vec![],
};
let request = WorkflowRequest {
workflow,
context: HashMap::new(),
};
let result = engine.execute_workflow(request).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Circular dependency"));
}
#[tokio::test]
async fn test_execute_workflow_missing_step_dependency() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let step1 = WorkflowStep {
id: "step1".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input_value".to_string(),
},
depends_on: vec!["nonexistent_step".to_string()],
parameters: serde_json::json!({}),
};
let workflow = Workflow {
id: "missing-dependency-workflow".to_string(),
name: "Missing Dependency Test".to_string(),
description: "Workflow with missing dependency".to_string(),
steps: vec![step1],
inputs: HashMap::new(),
outputs: vec![],
};
let request = WorkflowRequest {
workflow,
context: HashMap::new(),
};
let result = engine.execute_workflow(request).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("not found"));
}
#[test]
fn test_substitute_variables() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let mut context = HashMap::new();
context.insert("name".to_string(), serde_json::json!("World"));
context.insert("number".to_string(), serde_json::json!(42));
context.insert("flag".to_string(), serde_json::json!(true));
let template = "Hello {{name}}! The number is {{number}} and flag is {{flag}}.";
let result = engine.substitute_variables(template, &context).unwrap();
assert_eq!(result, "Hello World! The number is 42 and flag is true.");
}
#[test]
fn test_substitute_variables_in_json_string() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let mut context = HashMap::new();
context.insert("value".to_string(), serde_json::json!("test_value"));
let json = serde_json::json!("Hello {{value}}!");
let result = engine.substitute_variables_in_json(&json, &context).unwrap();
assert_eq!(result, serde_json::json!("Hello test_value!"));
}
#[test]
fn test_substitute_variables_in_json_object() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let mut context = HashMap::new();
context.insert("name".to_string(), serde_json::json!("test"));
context.insert("value".to_string(), serde_json::json!(42));
let json = serde_json::json!({
"greeting": "Hello {{name}}!",
"data": {
"number": "{{value}}"
}
});
let result = engine.substitute_variables_in_json(&json, &context).unwrap();
let expected = serde_json::json!({
"greeting": "Hello test!",
"data": {
"number": "42"
}
});
assert_eq!(result, expected);
}
#[test]
fn test_substitute_variables_in_json_array() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let mut context = HashMap::new();
context.insert("item1".to_string(), serde_json::json!("first"));
context.insert("item2".to_string(), serde_json::json!("second"));
let json = serde_json::json!(["{{item1}}", "{{item2}}", "static"]);
let result = engine.substitute_variables_in_json(&json, &context).unwrap();
assert_eq!(result, serde_json::json!(["first", "second", "static"]));
}
#[test]
fn test_execute_data_transform_extract_from_context() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let mut context = HashMap::new();
context.insert("test_data".to_string(), serde_json::json!("extracted"));
let step_results = HashMap::new();
let result = engine.execute_data_transform("extract", "test_data", &context, &step_results).unwrap();
assert_eq!(result, serde_json::json!("extracted"));
}
#[test]
fn test_execute_data_transform_extract_from_step() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let context = HashMap::new();
let mut step_results = HashMap::new();
step_results.insert("previous_step".to_string(), StepResult {
step_id: "previous_step".to_string(),
success: true,
result: serde_json::json!("step_output"),
error: None,
execution_time_ms: 100,
});
let result = engine.execute_data_transform("extract", "step_previous_step", &context, &step_results).unwrap();
assert_eq!(result, serde_json::json!("step_output"));
}
#[test]
fn test_execute_data_transform_extract_missing_variable() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let context = HashMap::new();
let step_results = HashMap::new();
let result = engine.execute_data_transform("extract", "nonexistent", &context, &step_results);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("not found"));
}
#[test]
fn test_execute_data_transform_filter_operation() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let context = HashMap::new();
let step_results = HashMap::new();
let result = engine.execute_data_transform("filter", "test_expression", &context, &step_results).unwrap();
assert_eq!(result, serde_json::json!({ "filtered": true, "expression": "test_expression" }));
}
#[test]
fn test_execute_data_transform_unsupported_operation() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let context = HashMap::new();
let step_results = HashMap::new();
let result = engine.execute_data_transform("unsupported", "expression", &context, &step_results);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Unsupported data transform operation"));
}
#[test]
fn test_evaluate_condition_string_equality_true() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let mut context = HashMap::new();
context.insert("status".to_string(), serde_json::json!("active"));
let result = engine.evaluate_condition("status == \"active\"", &context).unwrap();
assert!(result);
}
#[test]
fn test_evaluate_condition_string_equality_false() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let mut context = HashMap::new();
context.insert("status".to_string(), serde_json::json!("inactive"));
let result = engine.evaluate_condition("status == \"active\"", &context).unwrap();
assert!(!result);
}
#[test]
fn test_evaluate_condition_bool_equality() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let mut context = HashMap::new();
context.insert("flag".to_string(), serde_json::json!(true));
let result = engine.evaluate_condition("flag == \"true\"", &context).unwrap();
assert!(result);
}
#[test]
fn test_evaluate_condition_missing_variable() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let context = HashMap::new();
let result = engine.evaluate_condition("nonexistent == \"value\"", &context).unwrap();
assert!(!result);
}
#[test]
fn test_evaluate_condition_invalid_format() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let context = HashMap::new();
let result = engine.evaluate_condition("invalid condition format", &context).unwrap();
assert!(result); }
#[test]
fn test_evaluate_condition_malformed_equality() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let context = HashMap::new();
let result = engine.evaluate_condition("a == b == c", &context).unwrap();
assert!(!result); }
#[tokio::test]
async fn test_call_llm() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let result = engine.call_llm("test prompt", "test-model", 100, 0.5).await.unwrap();
assert!(result.contains("LLM response to: test prompt"));
assert!(result.contains("model: test-model"));
assert!(result.contains("max_tokens: 100"));
assert!(result.contains("temp: 0.5"));
}
#[test]
fn test_calculate_execution_order_simple() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let steps = vec![
WorkflowStep {
id: "step1".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input".to_string(),
},
depends_on: vec![],
parameters: serde_json::json!({}),
},
WorkflowStep {
id: "step2".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input".to_string(),
},
depends_on: vec!["step1".to_string()],
parameters: serde_json::json!({}),
},
];
let order = engine.calculate_execution_order(&steps).unwrap();
assert_eq!(order, vec!["step1".to_string(), "step2".to_string()]);
}
#[test]
fn test_calculate_execution_order_circular_dependency() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let steps = vec![
WorkflowStep {
id: "step1".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input".to_string(),
},
depends_on: vec!["step2".to_string()],
parameters: serde_json::json!({}),
},
WorkflowStep {
id: "step2".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input".to_string(),
},
depends_on: vec!["step1".to_string()],
parameters: serde_json::json!({}),
},
];
let result = engine.calculate_execution_order(&steps);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Circular dependency"));
}
#[test]
fn test_visit_step_missing_dependency() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let steps = vec![
WorkflowStep {
id: "step1".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input".to_string(),
},
depends_on: vec!["nonexistent".to_string()],
parameters: serde_json::json!({}),
},
];
let result = engine.calculate_execution_order(&steps);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("not found"));
}
#[tokio::test]
async fn test_execute_workflow_with_data_transform_missing_variable() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let step = WorkflowStep {
id: "transform_step_missing_var".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "nonexistent_input".to_string(), },
depends_on: vec![],
parameters: serde_json::json!({}),
};
let workflow = Workflow {
id: "data-transform-missing-var-workflow".to_string(),
name: "Data Transform Missing Variable Test".to_string(),
description: "Workflow with data transform step missing variable".to_string(),
steps: vec![step],
inputs: HashMap::new(),
outputs: vec!["transform_step_missing_var".to_string()],
};
let request = WorkflowRequest {
workflow,
context: HashMap::new(),
};
let result = engine.execute_workflow(request).await.unwrap();
assert!(!result.success); assert!(result.error.is_some());
let step_result = result.step_results.get("transform_step_missing_var").unwrap();
assert!(!step_result.success);
assert!(step_result.error.is_some());
}
#[tokio::test]
async fn test_execute_workflow_with_conditional_step_with_false_branch() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let if_true_step = WorkflowStep {
id: "true_branch".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "true_value".to_string(),
},
depends_on: vec![],
parameters: serde_json::json!({}),
};
let if_false_step = WorkflowStep {
id: "false_branch".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "false_value".to_string(),
},
depends_on: vec![],
parameters: serde_json::json!({}),
};
let step = WorkflowStep {
id: "conditional_step".to_string(),
step_type: WorkflowStepType::Conditional {
condition: "test_condition == \"should_be_false\"".to_string(), if_true: Box::new(if_true_step),
if_false: Some(Box::new(if_false_step)),
},
depends_on: vec![],
parameters: serde_json::json!({}),
};
let mut inputs = HashMap::new();
inputs.insert("test_condition".to_string(), serde_json::json!("trigger_false")); inputs.insert("false_value".to_string(), serde_json::json!("false_branch_executed"));
let workflow = Workflow {
id: "conditional-false-branch-workflow".to_string(),
name: "Conditional False Branch Test".to_string(),
description: "Workflow with conditional step (false branch executed)".to_string(),
steps: vec![step],
inputs,
outputs: vec!["conditional_step".to_string()],
};
let request = WorkflowRequest {
workflow,
context: HashMap::new(),
};
let result = engine.execute_workflow(request).await.unwrap();
assert!(result.success);
let step_result = result.step_results.get("conditional_step").unwrap();
assert!(step_result.success);
assert_eq!(step_result.result, serde_json::json!("false_branch_executed"));
}
#[test]
fn test_substitute_variables_in_json_non_string_values() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let context = HashMap::new();
let json = serde_json::json!({
"number": 42,
"boolean": true,
"null": null
});
let result = engine.substitute_variables_in_json(&json, &context).unwrap();
assert_eq!(result, json);
}
#[tokio::test]
async fn test_execute_workflow_with_tool_step_variable_substitution() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let step = WorkflowStep {
id: "tool_step_with_vars".to_string(),
step_type: WorkflowStepType::ToolCall {
tool_name: "calculator".to_string(),
arguments: serde_json::json!({
"expression": "{{num1}} + {{num2}}"
}),
},
depends_on: vec![],
parameters: serde_json::json!({}),
};
let mut inputs = HashMap::new();
inputs.insert("num1".to_string(), serde_json::json!("5"));
inputs.insert("num2".to_string(), serde_json::json!("7"));
let workflow = Workflow {
id: "tool-var-substitution-workflow".to_string(),
name: "Tool Variable Substitution Test".to_string(),
description: "Workflow with tool step using variable substitution".to_string(),
steps: vec![step],
inputs,
outputs: vec!["tool_step_with_vars".to_string()],
};
let request = WorkflowRequest {
workflow,
context: HashMap::new(),
};
let result = engine.execute_workflow(request).await.unwrap();
assert!(result.success);
let step_result = result.step_results.get("tool_step_with_vars").unwrap();
assert!(step_result.success);
assert_eq!(step_result.result, serde_json::json!(12.0));
}
#[tokio::test]
async fn test_execute_workflow_with_llm_step_variable_substitution() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let step = WorkflowStep {
id: "llm_step_with_vars".to_string(),
step_type: WorkflowStepType::LLMGeneration {
prompt: "Process this data: {{input_data}}".to_string(),
model: None, max_tokens: None, temperature: None, },
depends_on: vec![],
parameters: serde_json::json!({}),
};
let mut inputs = HashMap::new();
inputs.insert("input_data".to_string(), serde_json::json!("test data"));
let workflow = Workflow {
id: "llm-var-substitution-workflow".to_string(),
name: "LLM Variable Substitution Test".to_string(),
description: "Workflow with LLM step using variable substitution".to_string(),
steps: vec![step],
inputs,
outputs: vec!["llm_step_with_vars".to_string()],
};
let request = WorkflowRequest {
workflow,
context: HashMap::new(),
};
let result = engine.execute_workflow(request).await.unwrap();
assert!(result.success);
let step_result = result.step_results.get("llm_step_with_vars").unwrap();
assert!(step_result.success);
let result_text = step_result.result.get("text").unwrap().as_str().unwrap();
assert!(result_text.contains("Process this data: test data"));
assert!(result_text.contains("model: default"));
assert!(result_text.contains("max_tokens: 512"));
assert!(result_text.contains("temp: 0.7"));
}
#[test]
fn test_execute_data_transform_extract_from_missing_step() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let context = HashMap::new();
let step_results = HashMap::new();
let result = engine.execute_data_transform("extract", "step_nonexistent_step", &context, &step_results);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("not found"));
}
#[test]
fn test_calculate_execution_order_complex_dependencies() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let steps = vec![
WorkflowStep {
id: "step_c".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input".to_string(),
},
depends_on: vec!["step_a".to_string(), "step_b".to_string()],
parameters: serde_json::json!({}),
},
WorkflowStep {
id: "step_a".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input".to_string(),
},
depends_on: vec![],
parameters: serde_json::json!({}),
},
WorkflowStep {
id: "step_b".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input".to_string(),
},
depends_on: vec!["step_a".to_string()],
parameters: serde_json::json!({}),
},
];
let order = engine.calculate_execution_order(&steps).unwrap();
let a_pos = order.iter().position(|x| x == "step_a").unwrap();
let b_pos = order.iter().position(|x| x == "step_b").unwrap();
let c_pos = order.iter().position(|x| x == "step_c").unwrap();
assert!(a_pos < b_pos);
assert!(b_pos < c_pos);
assert!(a_pos < c_pos);
}
#[test]
fn test_visit_step_already_visited() {
let tool_registry = create_test_tool_registry();
let engine = WorkflowEngine::new(tool_registry);
let steps = vec![
WorkflowStep {
id: "step1".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input".to_string(),
},
depends_on: vec![],
parameters: serde_json::json!({}),
},
WorkflowStep {
id: "step2".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input".to_string(),
},
depends_on: vec!["step1".to_string()],
parameters: serde_json::json!({}),
},
WorkflowStep {
id: "step3".to_string(),
step_type: WorkflowStepType::DataTransform {
operation: "extract".to_string(),
expression: "input".to_string(),
},
depends_on: vec!["step1".to_string()], parameters: serde_json::json!({}),
},
];
let order = engine.calculate_execution_order(&steps).unwrap();
assert_eq!(order.len(), 3);
let step1_count = order.iter().filter(|&x| x == "step1").count();
assert_eq!(step1_count, 1);
let step1_pos = order.iter().position(|x| x == "step1").unwrap();
assert_eq!(step1_pos, 0);
}
}