ares/workflows/
engine.rs

1//! Workflow Engine
2//!
3//! Executes declarative workflows by orchestrating agent execution based on
4//! TOML configuration.
5
6use crate::agents::Agent;
7use crate::api::handlers::user_agents::resolve_agent;
8use crate::types::{AgentContext, AgentType, AppError, Result};
9use crate::utils::toml_config::{AgentConfig, WorkflowConfig};
10use crate::AppState;
11use chrono::Utc;
12use serde::{Deserialize, Serialize};
13use utoipa::ToSchema;
14
/// Output from a workflow execution
// Populated by `WorkflowEngine::execute_workflow`.
// NOTE(review): utoipa's `ToSchema` derive presumably turns the `///` comments
// in this struct into OpenAPI descriptions — confirm before rewording them.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct WorkflowOutput {
    /// The final response from the workflow
    // Output of the last recorded step, or "No response generated" when no step ran.
    pub final_response: String,
    /// Number of steps executed
    // As populated by the engine this always equals `reasoning_path.len()`.
    pub steps_executed: usize,
    /// List of agent names that were used
    // Deduplicated, in order of first use.
    pub agents_used: Vec<String>,
    /// Detailed reasoning path showing each step
    // One entry per executed step, in execution order.
    pub reasoning_path: Vec<WorkflowStep>,
}
27
28/// A single step in the workflow execution
29#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
30pub struct WorkflowStep {
31    /// The agent that executed this step
32    pub agent_name: String,
33    /// The input provided to the agent
34    pub input: String,
35    /// The output from the agent
36    pub output: String,
37    /// Unix timestamp when this step was executed
38    pub timestamp: i64,
39    /// Duration of this step in milliseconds
40    pub duration_ms: u64,
41}
42
/// Agent names a router agent is allowed to hand a request off to.
///
/// Router output is trimmed and lower-cased before it is compared against
/// these entries, so each entry is written in lowercase.
const VALID_AGENTS: &[&str] = &[
    "product",
    "invoice",
    "sales",
    "finance",
    "hr",
    "orchestrator",
    "research",
    "router",
];
54
55/// Workflow engine that orchestrates agent execution
56pub struct WorkflowEngine {
57    /// Application state for resolving agents
58    state: AppState,
59}
60
61impl WorkflowEngine {
62    /// Create a new workflow engine
63    pub fn new(state: AppState) -> Self {
64        Self { state }
65    }
66
67    /// Parse routing decision from router output
68    ///
69    /// This handles various output formats:
70    /// - Clean output: "product"
71    /// - With whitespace: "  product  "
72    /// - With extra text: "I would route this to product"
73    /// - Agent suffix: "product agent"
74    fn parse_routing_decision(output: &str) -> Option<String> {
75        let trimmed = output.trim().to_lowercase();
76
77        // First, try exact match
78        if VALID_AGENTS.contains(&trimmed.as_str()) {
79            return Some(trimmed);
80        }
81
82        // Try to extract valid agent name from output
83        // Split by common delimiters and check each word
84        for word in trimmed.split(|c: char| c.is_whitespace() || c == ':' || c == ',' || c == '.') {
85            let word = word.trim();
86            if VALID_AGENTS.contains(&word) {
87                return Some(word.to_string());
88            }
89        }
90
91        // Check if any valid agent name is contained in the output
92        for agent in VALID_AGENTS {
93            if trimmed.contains(agent) {
94                return Some(agent.to_string());
95            }
96        }
97
98        None
99    }
100
101    /// Execute a workflow by name
102    ///
103    /// # Arguments
104    ///
105    /// * `workflow_name` - The name of the workflow to execute (e.g., "default", "research")
106    /// * `user_input` - The user's query or input
107    /// * `context` - The agent context with user info and conversation history
108    ///
109    /// # Returns
110    ///
111    /// A `WorkflowOutput` containing the final response and execution details.
112    pub async fn execute_workflow(
113        &self,
114        workflow_name: &str,
115        user_input: &str,
116        context: &AgentContext,
117    ) -> Result<WorkflowOutput> {
118        // Get workflow configuration
119        let config = self.state.config_manager.config();
120        let workflow = config.get_workflow(workflow_name).ok_or_else(|| {
121            AppError::Configuration(format!(
122                "Workflow '{}' not found in configuration",
123                workflow_name
124            ))
125        })?;
126
127        let mut steps = Vec::new();
128        let mut agents_used = Vec::new();
129        let current_input = user_input.to_string();
130        let mut current_agent_name = workflow.entry_agent.clone();
131        let mut depth = 0;
132
133        // Execute workflow with depth limiting
134        while depth < workflow.max_depth {
135            let step_start = std::time::Instant::now();
136            let timestamp = Utc::now().timestamp();
137
138            // Resolve agent using the 3-tier hierarchy
139            let (user_agent, _source) =
140                match resolve_agent(&self.state, &context.user_id, &current_agent_name).await {
141                    Ok(res) => res,
142                    Err(e) => {
143                        // Try fallback agent if available
144                        if let Some(ref fallback) = workflow.fallback_agent {
145                            tracing::warn!(
146                                "Failed to resolve agent '{}', using fallback '{}'",
147                                current_agent_name,
148                                fallback
149                            );
150                            current_agent_name = fallback.clone();
151                            resolve_agent(&self.state, &context.user_id, fallback).await?
152                        } else {
153                            return Err(e);
154                        }
155                    }
156                };
157
158            // Convert UserAgent to AgentConfig
159            let agent_config = AgentConfig {
160                model: user_agent.model.clone(),
161                system_prompt: user_agent.system_prompt.clone(),
162                tools: user_agent.tools_vec(),
163                max_tool_iterations: user_agent.max_tool_iterations as usize,
164                parallel_tools: user_agent.parallel_tools,
165                extra: std::collections::HashMap::new(),
166            };
167
168            // Create the agent
169            let agent = self
170                .state
171                .agent_registry
172                .create_agent_from_config(&current_agent_name, &agent_config)
173                .await?;
174
175            // Execute the agent
176            let output = agent.execute(&current_input, context).await?;
177            let duration_ms = step_start.elapsed().as_millis() as u64;
178
179            // Record this step
180            steps.push(WorkflowStep {
181                agent_name: current_agent_name.clone(),
182                input: current_input.clone(),
183                output: output.clone(),
184                timestamp,
185                duration_ms,
186            });
187
188            if !agents_used.contains(&current_agent_name) {
189                agents_used.push(current_agent_name.clone());
190            }
191
192            // Check if the agent is a router and needs to delegate
193            if agent.agent_type() == AgentType::Router {
194                // Router's output should be an agent name
195                // Use robust parsing to handle various output formats
196                let next_agent = Self::parse_routing_decision(&output);
197
198                if let Some(ref agent_name) = next_agent {
199                    // Validate the routed agent exists (check hierarchy)
200                    if resolve_agent(&self.state, &context.user_id, agent_name)
201                        .await
202                        .is_ok()
203                    {
204                        current_agent_name = agent_name.clone();
205                        // Keep the original user input for the routed agent
206                        depth += 1;
207                        continue;
208                    }
209                }
210
211                // Agent not found or couldn't parse - try fallback
212                if let Some(ref fallback) = workflow.fallback_agent {
213                    // Use fallback if routed agent doesn't exist
214                    tracing::warn!(
215                        "Routed agent '{:?}' not found or invalid, using fallback '{}'",
216                        next_agent,
217                        fallback
218                    );
219                    current_agent_name = fallback.clone();
220                    depth += 1;
221                    continue;
222                } else {
223                    // No fallback, return the router's output as final
224                    break;
225                }
226            }
227
228            // Non-router agent - this is the final response
229            break;
230        }
231
232        // Build the final output
233        let final_response = steps
234            .last()
235            .map(|s| s.output.clone())
236            .unwrap_or_else(|| "No response generated".to_string());
237
238        Ok(WorkflowOutput {
239            final_response,
240            steps_executed: steps.len(),
241            agents_used,
242            reasoning_path: steps,
243        })
244    }
245
246    /// Get available workflow names
247    pub fn available_workflows(&self) -> Vec<String> {
248        self.state
249            .config_manager
250            .config()
251            .workflows
252            .keys()
253            .cloned()
254            .collect()
255    }
256
257    /// Check if a workflow exists
258    pub fn has_workflow(&self, name: &str) -> bool {
259        self.state
260            .config_manager
261            .config()
262            .workflows
263            .contains_key(name)
264    }
265
266    /// Get workflow configuration
267    pub fn get_workflow_config(&self, name: &str) -> Option<WorkflowConfig> {
268        self.state
269            .config_manager
270            .config()
271            .get_workflow(name)
272            .cloned()
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::ProviderRegistry;
    use crate::tools::registry::ToolRegistry;
    use crate::utils::toml_config::{
        AgentConfig, AresConfig, AuthConfig, DatabaseConfig, ModelConfig, ProviderConfig,
        RagConfig, ServerConfig,
    };
    use crate::{AgentRegistry, AresConfigManager, DynamicConfigManager};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Build a minimal configuration with one Ollama provider, one model,
    /// three agents (`router`, `orchestrator`, `product`) and two workflows
    /// (`default`, `research`).
    fn create_test_config() -> AresConfig {
        let mut providers = HashMap::new();
        providers.insert(
            "ollama-local".to_string(),
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "ministral-3:3b".to_string(),
            },
        );

        let mut models = HashMap::new();
        models.insert(
            "default".to_string(),
            ModelConfig {
                provider: "ollama-local".to_string(),
                model: "ministral-3:3b".to_string(),
                temperature: 0.7,
                max_tokens: 512,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        let mut agents = HashMap::new();
        agents.insert(
            "router".to_string(),
            AgentConfig {
                model: "default".to_string(),
                system_prompt: Some("Route queries to the appropriate agent.".to_string()),
                tools: vec![],
                max_tool_iterations: 1,
                parallel_tools: false,
                extra: HashMap::new(),
            },
        );
        agents.insert(
            "orchestrator".to_string(),
            AgentConfig {
                model: "default".to_string(),
                system_prompt: Some("Handle complex queries.".to_string()),
                tools: vec![],
                max_tool_iterations: 10,
                parallel_tools: false,
                extra: HashMap::new(),
            },
        );
        agents.insert(
            "product".to_string(),
            AgentConfig {
                model: "default".to_string(),
                system_prompt: Some("Handle product queries.".to_string()),
                tools: vec![],
                max_tool_iterations: 5,
                parallel_tools: false,
                extra: HashMap::new(),
            },
        );

        let mut workflows = HashMap::new();
        workflows.insert(
            "default".to_string(),
            WorkflowConfig {
                entry_agent: "router".to_string(),
                fallback_agent: Some("orchestrator".to_string()),
                max_depth: 3,
                max_iterations: 5,
                parallel_subagents: false,
            },
        );
        workflows.insert(
            "research".to_string(),
            WorkflowConfig {
                entry_agent: "orchestrator".to_string(),
                fallback_agent: None,
                max_depth: 3,
                max_iterations: 10,
                parallel_subagents: true,
            },
        );

        AresConfig {
            server: ServerConfig::default(),
            auth: AuthConfig::default(),
            database: DatabaseConfig::default(),
            config: crate::utils::toml_config::DynamicConfigPaths::default(),
            providers,
            models,
            tools: HashMap::new(),
            agents,
            workflows,
            rag: RagConfig::default(),
        }
    }

    /// Build a dummy `AppState` around `create_test_config()`, backed by an
    /// in-memory Turso client. Shared by every engine test below.
    fn create_test_state() -> AppState {
        let config = Arc::new(create_test_config());
        let provider_registry = Arc::new(ProviderRegistry::from_config(&config));
        let tool_registry = Arc::new(ToolRegistry::new());
        let agent_registry = Arc::new(AgentRegistry::from_config(
            &config,
            provider_registry.clone(),
            tool_registry.clone(),
        ));

        AppState {
            config_manager: Arc::new(AresConfigManager::from_config((*config).clone())),
            dynamic_config: Arc::new(
                DynamicConfigManager::new(
                    std::path::PathBuf::from("config/agents"),
                    std::path::PathBuf::from("config/models"),
                    std::path::PathBuf::from("config/tools"),
                    std::path::PathBuf::from("config/workflows"),
                    std::path::PathBuf::from("config/mcps"),
                    false,
                )
                .expect("dynamic config manager should build without hot reload"),
            ),
            turso: Arc::new(
                futures::executor::block_on(crate::db::TursoClient::new_memory())
                    .expect("in-memory Turso client should open"),
            ),
            // Built before `provider_registry` is moved into the struct below.
            llm_factory: Arc::new(crate::ConfigBasedLLMFactory::new(
                provider_registry.clone(),
                "default",
            )),
            provider_registry,
            agent_registry,
            tool_registry,
            auth_service: Arc::new(crate::auth::jwt::AuthService::new(
                "secret".to_string(),
                900,
                604800,
            )),
        }
    }

    #[test]
    fn test_workflow_engine_creation() {
        let engine = WorkflowEngine::new(create_test_state());

        assert!(engine.has_workflow("default"));
        assert!(engine.has_workflow("research"));
        assert!(!engine.has_workflow("nonexistent"));
    }

    #[test]
    fn test_available_workflows() {
        let engine = WorkflowEngine::new(create_test_state());
        let workflows = engine.available_workflows();

        assert!(workflows.contains(&"default".to_string()));
        assert!(workflows.contains(&"research".to_string()));
    }

    #[test]
    fn test_get_workflow_config() {
        let engine = WorkflowEngine::new(create_test_state());

        let default_config = engine.get_workflow_config("default").unwrap();
        assert_eq!(default_config.entry_agent, "router");
        assert_eq!(
            default_config.fallback_agent,
            Some("orchestrator".to_string())
        );
        assert_eq!(default_config.max_depth, 3);

        let research_config = engine.get_workflow_config("research").unwrap();
        assert_eq!(research_config.entry_agent, "orchestrator");
        assert!(research_config.parallel_subagents);
    }

    #[test]
    fn test_parse_routing_decision() {
        let parse = WorkflowEngine::parse_routing_decision;

        // Clean and whitespace/case-padded output
        assert_eq!(parse("product"), Some("product".to_string()));
        assert_eq!(parse("  HR  \n"), Some("hr".to_string()));
        // Agent name embedded in a sentence or with punctuation
        assert_eq!(
            parse("I would route this to product"),
            Some("product".to_string())
        );
        assert_eq!(parse("product agent"), Some("product".to_string()));
        assert_eq!(parse("Route: finance."), Some("finance".to_string()));
        assert_eq!(parse("sales, please"), Some("sales".to_string()));
        // Plural form still maps to the agent
        assert_eq!(parse("Products"), Some("product".to_string()));
        // Nothing recognisable
        assert_eq!(parse(""), None);
        assert_eq!(parse("no idea"), None);
    }

    #[test]
    fn test_workflow_output_serialization() {
        let output = WorkflowOutput {
            final_response: "Test response".to_string(),
            steps_executed: 2,
            agents_used: vec!["router".to_string(), "product".to_string()],
            reasoning_path: vec![
                WorkflowStep {
                    agent_name: "router".to_string(),
                    input: "What products do we have?".to_string(),
                    output: "product".to_string(),
                    timestamp: 1702500000,
                    duration_ms: 150,
                },
                WorkflowStep {
                    agent_name: "product".to_string(),
                    input: "What products do we have?".to_string(),
                    output: "Test response".to_string(),
                    timestamp: 1702500001,
                    duration_ms: 500,
                },
            ],
        };

        let json = serde_json::to_string(&output).unwrap();
        assert!(json.contains("Test response"));
        assert!(json.contains("router"));
        assert!(json.contains("product"));

        let deserialized: WorkflowOutput = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.steps_executed, 2);
    }
}