Skip to main content

ares/workflows/
engine.rs

1//! Workflow Engine
2//!
3//! Executes declarative workflows by orchestrating agent execution based on
4//! TOML configuration.
5
6use crate::agents::Agent;
7use crate::api::handlers::user_agents::resolve_agent;
8use crate::types::{AgentContext, AgentType, AppError, Result};
9use crate::utils::toml_config::{AgentConfig, WorkflowConfig};
10use crate::AppState;
11use chrono::Utc;
12use serde::{Deserialize, Serialize};
13use utoipa::ToSchema;
14
15/// Output from a workflow execution
16#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
17pub struct WorkflowOutput {
18    /// The final response from the workflow
19    pub final_response: String,
20    /// Number of steps executed
21    pub steps_executed: usize,
22    /// List of agent names that were used
23    pub agents_used: Vec<String>,
24    /// Detailed reasoning path showing each step
25    pub reasoning_path: Vec<WorkflowStep>,
26}
27
28/// A single step in the workflow execution
29#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
30pub struct WorkflowStep {
31    /// The agent that executed this step
32    pub agent_name: String,
33    /// The input provided to the agent
34    pub input: String,
35    /// The output from the agent
36    pub output: String,
37    /// Unix timestamp when this step was executed
38    pub timestamp: i64,
39    /// Duration of this step in milliseconds
40    pub duration_ms: u64,
41}
42
/// Agent names that a router's output may select.
///
/// Router output is lowercased before it is compared against this list, so
/// every entry must itself be lowercase.
const VALID_AGENTS: &[&str] = &[
    "product", "invoice", "sales", "finance", "hr", // domain agents
    "orchestrator", "research", "router",
];
54
55/// Workflow engine that orchestrates agent execution
56pub struct WorkflowEngine {
57    /// Application state for resolving agents
58    state: AppState,
59}
60
61impl WorkflowEngine {
62    /// Create a new workflow engine
63    pub fn new(state: AppState) -> Self {
64        Self { state }
65    }
66
67    /// Parse routing decision from router output
68    ///
69    /// This handles various output formats:
70    /// - Clean output: "product"
71    /// - With whitespace: "  product  "
72    /// - With extra text: "I would route this to product"
73    /// - Agent suffix: "product agent"
74    fn parse_routing_decision(output: &str) -> Option<String> {
75        let trimmed = output.trim().to_lowercase();
76
77        // First, try exact match
78        if VALID_AGENTS.contains(&trimmed.as_str()) {
79            return Some(trimmed);
80        }
81
82        // Try to extract valid agent name from output
83        // Split by common delimiters and check each word
84        for word in trimmed.split(|c: char| c.is_whitespace() || c == ':' || c == ',' || c == '.') {
85            let word = word.trim();
86            if VALID_AGENTS.contains(&word) {
87                return Some(word.to_string());
88            }
89        }
90
91        // Check if any valid agent name is contained in the output
92        for agent in VALID_AGENTS {
93            if trimmed.contains(agent) {
94                return Some(agent.to_string());
95            }
96        }
97
98        None
99    }
100
101    /// Execute a workflow by name
102    ///
103    /// # Arguments
104    ///
105    /// * `workflow_name` - The name of the workflow to execute (e.g., "default", "research")
106    /// * `user_input` - The user's query or input
107    /// * `context` - The agent context with user info and conversation history
108    ///
109    /// # Returns
110    ///
111    /// A `WorkflowOutput` containing the final response and execution details.
112    pub async fn execute_workflow(
113        &self,
114        workflow_name: &str,
115        user_input: &str,
116        context: &AgentContext,
117    ) -> Result<WorkflowOutput> {
118        // Get workflow configuration
119        let config = self.state.config_manager.config();
120        let workflow = config.get_workflow(workflow_name).ok_or_else(|| {
121            AppError::Configuration(format!(
122                "Workflow '{}' not found in configuration",
123                workflow_name
124            ))
125        })?;
126
127        let mut steps = Vec::new();
128        let mut agents_used = Vec::new();
129        let current_input = user_input.to_string();
130        let mut current_agent_name = workflow.entry_agent.clone();
131        let mut depth = 0;
132
133        // Execute workflow with depth limiting
134        while depth < workflow.max_depth {
135            let step_start = std::time::Instant::now();
136            let timestamp = Utc::now().timestamp();
137
138            // Resolve agent using the 3-tier hierarchy
139            let (user_agent, _source) =
140                match resolve_agent(&self.state, &context.user_id, current_agent_name.clone()).await {
141                    Ok(res) => res,
142                    Err(e) => {
143                        // Try fallback agent if available
144                        if let Some(ref fallback) = workflow.fallback_agent {
145                            tracing::warn!(
146                                "Failed to resolve agent '{}', using fallback '{}'",
147                                current_agent_name,
148                                fallback
149                            );
150                            current_agent_name = fallback.clone();
151                            resolve_agent(&self.state, &context.user_id, fallback.clone()).await?
152                        } else {
153                            return Err(e);
154                        }
155                    }
156                };
157
158            // Convert UserAgent to AgentConfig
159            let agent_config = AgentConfig {
160                model: user_agent.model.clone(),
161                system_prompt: user_agent.system_prompt.clone(),
162                tools: user_agent.tools_vec(),
163                max_tool_iterations: user_agent.max_tool_iterations as usize,
164                parallel_tools: user_agent.parallel_tools,
165                extra: std::collections::HashMap::new(),
166            };
167
168            // Create the agent
169            let agent = self
170                .state
171                .agent_registry
172                .create_agent_from_config(&current_agent_name, &agent_config)
173                .await?;
174
175            // Execute the agent
176            let output = agent.execute(&current_input, context).await?;
177            let duration_ms = step_start.elapsed().as_millis() as u64;
178
179            // Record this step
180            steps.push(WorkflowStep {
181                agent_name: current_agent_name.clone(),
182                input: current_input.clone(),
183                output: output.clone(),
184                timestamp,
185                duration_ms,
186            });
187
188            if !agents_used.contains(&current_agent_name) {
189                agents_used.push(current_agent_name.clone());
190            }
191
192            // Check if the agent is a router and needs to delegate
193            if agent.agent_type() == AgentType::Router {
194                // Router's output should be an agent name
195                // Use robust parsing to handle various output formats
196                let next_agent = Self::parse_routing_decision(&output);
197
198                if let Some(ref agent_name) = next_agent {
199                    // Validate the routed agent exists (check hierarchy)
200                    if resolve_agent(&self.state, &context.user_id, agent_name.clone())
201                        .await
202                        .is_ok()
203                    {
204                        current_agent_name = agent_name.clone();
205                        // Keep the original user input for the routed agent
206                        depth += 1;
207                        continue;
208                    }
209                }
210
211                // Agent not found or couldn't parse - try fallback
212                if let Some(ref fallback) = workflow.fallback_agent {
213                    // Use fallback if routed agent doesn't exist
214                    tracing::warn!(
215                        "Routed agent '{:?}' not found or invalid, using fallback '{}'",
216                        next_agent,
217                        fallback
218                    );
219                    current_agent_name = fallback.clone();
220                    depth += 1;
221                    continue;
222                } else {
223                    // No fallback, return the router's output as final
224                    break;
225                }
226            }
227
228            // Non-router agent - this is the final response
229            break;
230        }
231
232        // Build the final output
233        let final_response = steps
234            .last()
235            .map(|s| s.output.clone())
236            .unwrap_or_else(|| "No response generated".to_string());
237
238        Ok(WorkflowOutput {
239            final_response,
240            steps_executed: steps.len(),
241            agents_used,
242            reasoning_path: steps,
243        })
244    }
245
246    /// Get available workflow names
247    pub fn available_workflows(&self) -> Vec<String> {
248        self.state
249            .config_manager
250            .config()
251            .workflows
252            .keys()
253            .cloned()
254            .collect()
255    }
256
257    /// Check if a workflow exists
258    pub fn has_workflow(&self, name: &str) -> bool {
259        self.state
260            .config_manager
261            .config()
262            .workflows
263            .contains_key(name)
264    }
265
266    /// Get workflow configuration
267    pub fn get_workflow_config(&self, name: &str) -> Option<WorkflowConfig> {
268        self.state
269            .config_manager
270            .config()
271            .get_workflow(name)
272            .cloned()
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::ProviderRegistry;
    use crate::tools::registry::ToolRegistry;
    use crate::utils::toml_config::{
        AgentConfig, AresConfig, AuthConfig, DatabaseConfig, ModelConfig, ProviderConfig,
        RagConfig, ServerConfig,
    };
    use crate::{AgentRegistry, AresConfigManager, DynamicConfigManager};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Build a config with one Ollama provider, one model, three agents
    /// (`router`, `orchestrator`, `product`) and two workflows (`default`, `research`).
    fn create_test_config() -> AresConfig {
        let mut providers = HashMap::new();
        providers.insert(
            "ollama-local".to_string(),
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "ministral-3:3b".to_string(),
            },
        );

        let mut models = HashMap::new();
        models.insert(
            "default".to_string(),
            ModelConfig {
                provider: "ollama-local".to_string(),
                model: "ministral-3:3b".to_string(),
                temperature: 0.7,
                max_tokens: 512,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        let mut agents = HashMap::new();
        agents.insert(
            "router".to_string(),
            AgentConfig {
                model: "default".to_string(),
                system_prompt: Some("Route queries to the appropriate agent.".to_string()),
                tools: vec![],
                max_tool_iterations: 1,
                parallel_tools: false,
                extra: HashMap::new(),
            },
        );
        agents.insert(
            "orchestrator".to_string(),
            AgentConfig {
                model: "default".to_string(),
                system_prompt: Some("Handle complex queries.".to_string()),
                tools: vec![],
                max_tool_iterations: 10,
                parallel_tools: false,
                extra: HashMap::new(),
            },
        );
        agents.insert(
            "product".to_string(),
            AgentConfig {
                model: "default".to_string(),
                system_prompt: Some("Handle product queries.".to_string()),
                tools: vec![],
                max_tool_iterations: 5,
                parallel_tools: false,
                extra: HashMap::new(),
            },
        );

        let mut workflows = HashMap::new();
        workflows.insert(
            "default".to_string(),
            WorkflowConfig {
                entry_agent: "router".to_string(),
                fallback_agent: Some("orchestrator".to_string()),
                max_depth: 3,
                max_iterations: 5,
                parallel_subagents: false,
            },
        );
        workflows.insert(
            "research".to_string(),
            WorkflowConfig {
                entry_agent: "orchestrator".to_string(),
                fallback_agent: None,
                max_depth: 3,
                max_iterations: 10,
                parallel_subagents: true,
            },
        );

        AresConfig {
            server: ServerConfig::default(),
            auth: AuthConfig::default(),
            database: DatabaseConfig::default(),
            config: crate::utils::toml_config::DynamicConfigPaths::default(),
            providers,
            models,
            tools: HashMap::new(),
            agents,
            workflows,
            rag: RagConfig::default(),
        }
    }

    /// Build a dummy `AppState` around `create_test_config()`.
    ///
    /// Shared by every engine test so the setup lives in one place. Both database
    /// handles come from `PostgresClient::new_memory()`. The dynamic config manager
    /// points at the default `config/*` directories.
    fn create_test_state() -> AppState {
        let config = Arc::new(create_test_config());
        let provider_registry = Arc::new(ProviderRegistry::from_config(&config));
        let tool_registry = Arc::new(ToolRegistry::new());
        let agent_registry = Arc::new(AgentRegistry::from_config(
            &config,
            provider_registry.clone(),
            tool_registry.clone(),
        ));

        AppState {
            config_manager: Arc::new(AresConfigManager::from_config((*config).clone())),
            dynamic_config: Arc::new(
                DynamicConfigManager::new(
                    std::path::PathBuf::from("config/agents"),
                    std::path::PathBuf::from("config/models"),
                    std::path::PathBuf::from("config/tools"),
                    std::path::PathBuf::from("config/workflows"),
                    std::path::PathBuf::from("config/mcps"),
                    false,
                )
                .unwrap(),
            ),
            db: Arc::new(
                futures::executor::block_on(crate::db::PostgresClient::new_memory()).unwrap(),
            ),
            tenant_db: Arc::new(crate::db::TenantDb::new(Arc::new(
                futures::executor::block_on(crate::db::PostgresClient::new_memory()).unwrap(),
            ))),
            llm_factory: Arc::new(crate::ConfigBasedLLMFactory::new(
                provider_registry.clone(),
                "default",
            )),
            provider_registry,
            agent_registry,
            tool_registry,
            auth_service: Arc::new(crate::auth::jwt::AuthService::new(
                "secret".to_string(),
                900,
                604800,
            )),
            mcp_registry: None,
            deploy_registry: crate::api::handlers::deploy::new_deploy_registry(),
        }
    }

    #[test]
    fn test_workflow_engine_creation() {
        let engine = WorkflowEngine::new(create_test_state());

        assert!(engine.has_workflow("default"));
        assert!(engine.has_workflow("research"));
        assert!(!engine.has_workflow("nonexistent"));
    }

    #[test]
    fn test_available_workflows() {
        let engine = WorkflowEngine::new(create_test_state());
        let workflows = engine.available_workflows();

        assert!(workflows.contains(&"default".to_string()));
        assert!(workflows.contains(&"research".to_string()));
    }

    #[test]
    fn test_get_workflow_config() {
        let engine = WorkflowEngine::new(create_test_state());

        let default_config = engine.get_workflow_config("default").unwrap();
        assert_eq!(default_config.entry_agent, "router");
        assert_eq!(
            default_config.fallback_agent,
            Some("orchestrator".to_string())
        );
        assert_eq!(default_config.max_depth, 3);

        let research_config = engine.get_workflow_config("research").unwrap();
        assert_eq!(research_config.entry_agent, "orchestrator");
        assert!(research_config.parallel_subagents);
    }

    /// The router-output formats promised by `parse_routing_decision`'s docs.
    #[test]
    fn test_parse_routing_decision_formats() {
        let parse = WorkflowEngine::parse_routing_decision;

        assert_eq!(parse("product").as_deref(), Some("product"));
        assert_eq!(parse("  Product  ").as_deref(), Some("product"));
        assert_eq!(
            parse("I would route this to product").as_deref(),
            Some("product")
        );
        assert_eq!(parse("sales agent").as_deref(), Some("sales"));
        assert_eq!(parse("**finance**").as_deref(), Some("finance"));
        assert_eq!(parse("I cannot decide"), None);
    }

    #[test]
    fn test_workflow_output_serialization() {
        let output = WorkflowOutput {
            final_response: "Test response".to_string(),
            steps_executed: 2,
            agents_used: vec!["router".to_string(), "product".to_string()],
            reasoning_path: vec![
                WorkflowStep {
                    agent_name: "router".to_string(),
                    input: "What products do we have?".to_string(),
                    output: "product".to_string(),
                    timestamp: 1702500000,
                    duration_ms: 150,
                },
                WorkflowStep {
                    agent_name: "product".to_string(),
                    input: "What products do we have?".to_string(),
                    output: "Test response".to_string(),
                    timestamp: 1702500001,
                    duration_ms: 500,
                },
            ],
        };

        let json = serde_json::to_string(&output).unwrap();
        assert!(json.contains("Test response"));
        assert!(json.contains("router"));
        assert!(json.contains("product"));

        let deserialized: WorkflowOutput = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.steps_executed, 2);
    }
}