use crate::agents::Agent;
use crate::api::handlers::user_agents::resolve_agent;
use crate::types::{AgentContext, AgentType, AppError, Result};
use crate::utils::toml_config::{AgentConfig, WorkflowConfig};
use crate::AppState;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
/// Result of a workflow run, as returned by `WorkflowEngine::execute_workflow`.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct WorkflowOutput {
/// Output of the last executed step, or "No response generated" if no step ran.
pub final_response: String,
/// Number of steps that were executed (the length of `reasoning_path`).
pub steps_executed: usize,
/// Distinct names of the agents that ran, in order of first use.
pub agents_used: Vec<String>,
/// Every executed step, in execution order.
pub reasoning_path: Vec<WorkflowStep>,
}
/// One agent invocation recorded during a workflow run.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct WorkflowStep {
/// Name the agent was resolved and created under for this step.
pub agent_name: String,
/// Text passed to the agent; the engine currently always passes the original user input.
pub input: String,
/// Content of the agent's response. For a router agent this is the raw routing text.
pub output: String,
/// Unix timestamp in seconds (UTC) taken when the step started.
pub timestamp: i64,
/// Wall-clock duration of the step in milliseconds, measured from the start of the step
/// until the agent's response was received.
pub duration_ms: u64,
}
/// Agent names that a router agent's output may select (see
/// `WorkflowEngine::parse_routing_decision`). Matching is done against the lowercased
/// output. Note that "router" itself is included, so a router may route to a router.
const VALID_AGENTS: &[&str] = &[
"product",
"invoice",
"sales",
"finance",
"hr",
"orchestrator",
"research",
"router",
];
/// Executes workflows defined in the application configuration: starts at a workflow's
/// entry agent, follows router decisions, and falls back to the configured fallback agent.
pub struct WorkflowEngine {
// Application state used for config lookup, agent resolution and agent creation.
state: AppState,
}
impl WorkflowEngine {
pub fn new(state: AppState) -> Self {
Self { state }
}
fn parse_routing_decision(output: &str) -> Option<String> {
let trimmed = output.trim().to_lowercase();
if VALID_AGENTS.contains(&trimmed.as_str()) {
return Some(trimmed);
}
for word in trimmed.split(|c: char| c.is_whitespace() || c == ':' || c == ',' || c == '.') {
let word = word.trim();
if VALID_AGENTS.contains(&word) {
return Some(word.to_string());
}
}
for agent in VALID_AGENTS {
if trimmed.contains(agent) {
return Some(agent.to_string());
}
}
None
}
pub async fn execute_workflow(
&self,
workflow_name: &str,
user_input: &str,
context: &AgentContext,
) -> Result<WorkflowOutput> {
let config = self.state.config_manager.config();
let workflow = config.get_workflow(workflow_name).ok_or_else(|| {
AppError::Configuration(format!(
"Workflow '{}' not found in configuration",
workflow_name
))
})?;
let mut steps = Vec::new();
let mut agents_used = Vec::new();
let current_input = user_input.to_string();
let mut current_agent_name = workflow.entry_agent.clone();
let mut depth = 0;
while depth < workflow.max_depth {
let step_start = std::time::Instant::now();
let timestamp = Utc::now().timestamp();
let (user_agent, _source) = match resolve_agent(
&self.state,
&context.user_id,
current_agent_name.clone(),
)
.await
{
Ok(res) => res,
Err(e) => {
if let Some(ref fallback) = workflow.fallback_agent {
tracing::warn!(
"Failed to resolve agent '{}', using fallback '{}'",
current_agent_name,
fallback
);
current_agent_name = fallback.clone();
resolve_agent(&self.state, &context.user_id, fallback.clone()).await?
} else {
return Err(e);
}
}
};
let agent_config = AgentConfig {
model: user_agent.model.clone(),
system_prompt: user_agent.system_prompt.clone(),
tools: user_agent.tools_vec(),
max_tool_iterations: user_agent.max_tool_iterations as usize,
parallel_tools: user_agent.parallel_tools,
extra: std::collections::HashMap::new(),
};
let agent = self
.state
.agent_registry
.create_agent_from_config(¤t_agent_name, &agent_config)
.await?;
let agent_resp = agent.execute(¤t_input, context).await?;
let output = agent_resp.content;
let duration_ms = step_start.elapsed().as_millis() as u64;
steps.push(WorkflowStep {
agent_name: current_agent_name.clone(),
input: current_input.clone(),
output: output.clone(),
timestamp,
duration_ms,
});
if !agents_used.contains(¤t_agent_name) {
agents_used.push(current_agent_name.clone());
}
if agent.agent_type() == AgentType::Router {
let next_agent = Self::parse_routing_decision(&output);
if let Some(ref agent_name) = next_agent {
if resolve_agent(&self.state, &context.user_id, agent_name.clone())
.await
.is_ok()
{
current_agent_name = agent_name.clone();
depth += 1;
continue;
}
}
if let Some(ref fallback) = workflow.fallback_agent {
tracing::warn!(
"Routed agent '{:?}' not found or invalid, using fallback '{}'",
next_agent,
fallback
);
current_agent_name = fallback.clone();
depth += 1;
continue;
} else {
break;
}
}
break;
}
let final_response = steps
.last()
.map(|s| s.output.clone())
.unwrap_or_else(|| "No response generated".to_string());
Ok(WorkflowOutput {
final_response,
steps_executed: steps.len(),
agents_used,
reasoning_path: steps,
})
}
pub fn available_workflows(&self) -> Vec<String> {
self.state
.config_manager
.config()
.workflows
.keys()
.cloned()
.collect()
}
pub fn has_workflow(&self, name: &str) -> bool {
self.state
.config_manager
.config()
.workflows
.contains_key(name)
}
pub fn get_workflow_config(&self, name: &str) -> Option<WorkflowConfig> {
self.state
.config_manager
.config()
.get_workflow(name)
.cloned()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::ProviderRegistry;
    use crate::tools::registry::ToolRegistry;
    use crate::utils::toml_config::{
        AgentConfig, AresConfig, AuthConfig, DatabaseConfig, ModelConfig, ProviderConfig,
        RagConfig, ServerConfig,
    };
    use crate::{AgentRegistry, AresConfigManager, DynamicConfigManager};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Tool-less agent definition on the shared "default" model.
    fn test_agent(system_prompt: &str, max_tool_iterations: usize) -> AgentConfig {
        AgentConfig {
            model: "default".to_string(),
            system_prompt: Some(system_prompt.to_string()),
            tools: vec![],
            max_tool_iterations,
            parallel_tools: false,
            extra: HashMap::new(),
        }
    }

    /// Minimal configuration: one local Ollama provider, a "default" model, three agents
    /// (`router`, `orchestrator`, `product`) and two workflows (`default`, `research`).
    fn create_test_config() -> AresConfig {
        let mut providers = HashMap::new();
        providers.insert(
            "ollama-local".to_string(),
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "ministral-3:3b".to_string(),
            },
        );

        let mut models = HashMap::new();
        models.insert(
            "default".to_string(),
            ModelConfig {
                provider: "ollama-local".to_string(),
                model: "ministral-3:3b".to_string(),
                temperature: 0.7,
                max_tokens: 512,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        let mut agents = HashMap::new();
        agents.insert(
            "router".to_string(),
            test_agent("Route queries to the appropriate agent.", 1),
        );
        agents.insert(
            "orchestrator".to_string(),
            test_agent("Handle complex queries.", 10),
        );
        agents.insert(
            "product".to_string(),
            test_agent("Handle product queries.", 5),
        );

        let mut workflows = HashMap::new();
        workflows.insert(
            "default".to_string(),
            WorkflowConfig {
                entry_agent: "router".to_string(),
                fallback_agent: Some("orchestrator".to_string()),
                max_depth: 3,
                max_iterations: 5,
                parallel_subagents: false,
            },
        );
        workflows.insert(
            "research".to_string(),
            WorkflowConfig {
                entry_agent: "orchestrator".to_string(),
                fallback_agent: None,
                max_depth: 3,
                max_iterations: 10,
                parallel_subagents: true,
            },
        );

        AresConfig {
            server: ServerConfig::default(),
            auth: AuthConfig::default(),
            database: DatabaseConfig::default(),
            config: crate::utils::toml_config::DynamicConfigPaths::default(),
            providers,
            models,
            tools: HashMap::new(),
            agents,
            workflows,
            rag: RagConfig::default(),
            #[cfg(feature = "skills")]
            skills: None,
        }
    }

    /// Builds an `AppState` around `create_test_config()` with test database clients,
    /// no MCP registry and a placeholder JWT secret.
    fn create_test_state() -> AppState {
        let config = Arc::new(create_test_config());
        let provider_registry = Arc::new(ProviderRegistry::from_config(&config));
        let tool_registry = Arc::new(ToolRegistry::new());
        let agent_registry = Arc::new(AgentRegistry::from_config(
            &config,
            provider_registry.clone(),
            tool_registry.clone(),
        ));
        AppState {
            config_manager: Arc::new(AresConfigManager::from_config((*config).clone())),
            dynamic_config: Arc::new(
                DynamicConfigManager::new(
                    std::path::PathBuf::from("config/agents"),
                    std::path::PathBuf::from("config/models"),
                    std::path::PathBuf::from("config/tools"),
                    std::path::PathBuf::from("config/workflows"),
                    std::path::PathBuf::from("config/mcps"),
                    false,
                )
                .expect("dynamic config manager should initialise for tests"),
            ),
            db: Arc::new(crate::db::PostgresClient::new_test()),
            tenant_db: Arc::new(crate::db::TenantDb::new(Arc::new(
                crate::db::PostgresClient::new_test(),
            ))),
            // Evaluated before `provider_registry` is moved into its own field below.
            llm_factory: Arc::new(crate::ConfigBasedLLMFactory::new(
                provider_registry.clone(),
                "default",
            )),
            provider_registry,
            agent_registry,
            tool_registry,
            auth_service: Arc::new(crate::auth::jwt::AuthService::new(
                "secret".to_string(),
                900,
                604800,
            )),
            mcp_registry: None,
            deploy_registry: crate::api::handlers::deploy::new_deploy_registry(),
            emergency_stop: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            context_provider: Arc::new(crate::agents::NoOpContextProvider),
        }
    }

    #[tokio::test]
    async fn test_workflow_engine_creation() {
        let engine = WorkflowEngine::new(create_test_state());
        assert!(engine.has_workflow("default"));
        assert!(engine.has_workflow("research"));
        assert!(!engine.has_workflow("nonexistent"));
    }

    #[tokio::test]
    async fn test_available_workflows() {
        let engine = WorkflowEngine::new(create_test_state());
        let workflows = engine.available_workflows();
        assert!(workflows.contains(&"default".to_string()));
        assert!(workflows.contains(&"research".to_string()));
    }

    #[tokio::test]
    async fn test_get_workflow_config() {
        let engine = WorkflowEngine::new(create_test_state());

        let default_config = engine.get_workflow_config("default").unwrap();
        assert_eq!(default_config.entry_agent, "router");
        assert_eq!(
            default_config.fallback_agent,
            Some("orchestrator".to_string())
        );
        assert_eq!(default_config.max_depth, 3);

        let research_config = engine.get_workflow_config("research").unwrap();
        assert_eq!(research_config.entry_agent, "orchestrator");
        assert!(research_config.parallel_subagents);

        assert!(engine.get_workflow_config("nonexistent").is_none());
    }

    #[test]
    fn test_parse_routing_decision() {
        let parse = WorkflowEngine::parse_routing_decision;
        assert_eq!(parse("product"), Some("product".to_string()));
        assert_eq!(parse("  SALES \n"), Some("sales".to_string()));
        assert_eq!(parse("Route to: finance."), Some("finance".to_string()));
        assert_eq!(parse("**invoice**"), Some("invoice".to_string()));
        assert_eq!(parse("finances"), Some("finance".to_string()));
        assert_eq!(parse("I don't know"), None);
        assert_eq!(parse(""), None);
    }

    #[test]
    fn test_workflow_output_serialization() {
        let output = WorkflowOutput {
            final_response: "Test response".to_string(),
            steps_executed: 2,
            agents_used: vec!["router".to_string(), "product".to_string()],
            reasoning_path: vec![
                WorkflowStep {
                    agent_name: "router".to_string(),
                    input: "What products do we have?".to_string(),
                    output: "product".to_string(),
                    timestamp: 1702500000,
                    duration_ms: 150,
                },
                WorkflowStep {
                    agent_name: "product".to_string(),
                    input: "What products do we have?".to_string(),
                    output: "Test response".to_string(),
                    timestamp: 1702500001,
                    duration_ms: 500,
                },
            ],
        };
        let json = serde_json::to_string(&output).unwrap();
        assert!(json.contains("Test response"));
        assert!(json.contains("router"));
        assert!(json.contains("product"));
        let deserialized: WorkflowOutput = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.steps_executed, 2);
    }
}