// ares/workflows/engine.rs

1//! Workflow Engine
2//!
3//! Executes declarative workflows by orchestrating agent execution based on
4//! TOML configuration.
5
6use crate::agents::Agent;
7use crate::api::handlers::user_agents::resolve_agent;
8use crate::types::{AgentContext, AgentType, AppError, Result};
9use crate::utils::toml_config::{AgentConfig, WorkflowConfig};
10use crate::AppState;
11use chrono::Utc;
12use serde::{Deserialize, Serialize};
13use utoipa::ToSchema;
14
/// Output from a workflow execution
///
/// Assembled by `WorkflowEngine::execute_workflow` after its step loop ends.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct WorkflowOutput {
    /// The final response from the workflow: the output of the last executed
    /// step, or the text "No response generated" when no step ran
    pub final_response: String,
    /// Number of steps executed (the length of `reasoning_path`; router steps
    /// are counted as well)
    pub steps_executed: usize,
    /// List of agent names that were used, without duplicates, in order of
    /// first use
    pub agents_used: Vec<String>,
    /// Detailed reasoning path showing each step in execution order
    pub reasoning_path: Vec<WorkflowStep>,
}
27
/// A single step in the workflow execution
///
/// One value is recorded for every agent that runs, routers included.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct WorkflowStep {
    /// The agent that executed this step
    pub agent_name: String,
    /// The input provided to the agent
    pub input: String,
    /// The output from the agent
    pub output: String,
    /// Unix timestamp (seconds, UTC) taken when this step started
    pub timestamp: i64,
    /// Duration of this step in milliseconds, measured from just before the
    /// agent is resolved until its response is available
    pub duration_ms: u64,
}
42
/// Valid agent names for routing
///
/// `WorkflowEngine::parse_routing_decision` only ever returns one of these
/// names. Router output is lowercased before it is compared, which is why
/// every entry here is lowercase.
const VALID_AGENTS: &[&str] = &[
    "product",
    "invoice",
    "sales",
    "finance",
    "hr",
    "orchestrator",
    "research",
    "router",
];
54
/// Workflow engine that orchestrates agent execution
///
/// Holds the application state it needs to look up workflow configuration,
/// resolve agents, and build agent instances.
pub struct WorkflowEngine {
    /// Application state for resolving agents
    state: AppState,
}
60
61impl WorkflowEngine {
62    /// Create a new workflow engine
63    pub fn new(state: AppState) -> Self {
64        Self { state }
65    }
66
67    /// Parse routing decision from router output
68    ///
69    /// This handles various output formats:
70    /// - Clean output: "product"
71    /// - With whitespace: "  product  "
72    /// - With extra text: "I would route this to product"
73    /// - Agent suffix: "product agent"
74    fn parse_routing_decision(output: &str) -> Option<String> {
75        let trimmed = output.trim().to_lowercase();
76
77        // First, try exact match
78        if VALID_AGENTS.contains(&trimmed.as_str()) {
79            return Some(trimmed);
80        }
81
82        // Try to extract valid agent name from output
83        // Split by common delimiters and check each word
84        for word in trimmed.split(|c: char| c.is_whitespace() || c == ':' || c == ',' || c == '.') {
85            let word = word.trim();
86            if VALID_AGENTS.contains(&word) {
87                return Some(word.to_string());
88            }
89        }
90
91        // Check if any valid agent name is contained in the output
92        for agent in VALID_AGENTS {
93            if trimmed.contains(agent) {
94                return Some(agent.to_string());
95            }
96        }
97
98        None
99    }
100
101    /// Execute a workflow by name
102    ///
103    /// # Arguments
104    ///
105    /// * `workflow_name` - The name of the workflow to execute (e.g., "default", "research")
106    /// * `user_input` - The user's query or input
107    /// * `context` - The agent context with user info and conversation history
108    ///
109    /// # Returns
110    ///
111    /// A `WorkflowOutput` containing the final response and execution details.
112    pub async fn execute_workflow(
113        &self,
114        workflow_name: &str,
115        user_input: &str,
116        context: &AgentContext,
117    ) -> Result<WorkflowOutput> {
118        // Get workflow configuration
119        let config = self.state.config_manager.config();
120        let workflow = config.get_workflow(workflow_name).ok_or_else(|| {
121            AppError::Configuration(format!(
122                "Workflow '{}' not found in configuration",
123                workflow_name
124            ))
125        })?;
126
127        let mut steps = Vec::new();
128        let mut agents_used = Vec::new();
129        let current_input = user_input.to_string();
130        let mut current_agent_name = workflow.entry_agent.clone();
131        let mut depth = 0;
132
133        // Execute workflow with depth limiting
134        while depth < workflow.max_depth {
135            let step_start = std::time::Instant::now();
136            let timestamp = Utc::now().timestamp();
137
138            // Resolve agent using the 3-tier hierarchy
139            let (user_agent, _source) = match resolve_agent(
140                &self.state,
141                &context.user_id,
142                current_agent_name.clone(),
143            )
144            .await
145            {
146                Ok(res) => res,
147                Err(e) => {
148                    // Try fallback agent if available
149                    if let Some(ref fallback) = workflow.fallback_agent {
150                        tracing::warn!(
151                            "Failed to resolve agent '{}', using fallback '{}'",
152                            current_agent_name,
153                            fallback
154                        );
155                        current_agent_name = fallback.clone();
156                        resolve_agent(&self.state, &context.user_id, fallback.clone()).await?
157                    } else {
158                        return Err(e);
159                    }
160                }
161            };
162
163            // Convert UserAgent to AgentConfig
164            let agent_config = AgentConfig {
165                model: user_agent.model.clone(),
166                system_prompt: user_agent.system_prompt.clone(),
167                tools: user_agent.tools_vec(),
168                max_tool_iterations: user_agent.max_tool_iterations as usize,
169                parallel_tools: user_agent.parallel_tools,
170                extra: std::collections::HashMap::new(),
171            };
172
173            // Create the agent
174            let agent = self
175                .state
176                .agent_registry
177                .create_agent_from_config(&current_agent_name, &agent_config)
178                .await?;
179
180            // Execute the agent
181            let agent_resp = agent.execute(&current_input, context).await?;
182            let output = agent_resp.content;
183            let duration_ms = step_start.elapsed().as_millis() as u64;
184
185            // Record this step
186            steps.push(WorkflowStep {
187                agent_name: current_agent_name.clone(),
188                input: current_input.clone(),
189                output: output.clone(),
190                timestamp,
191                duration_ms,
192            });
193
194            if !agents_used.contains(&current_agent_name) {
195                agents_used.push(current_agent_name.clone());
196            }
197
198            // Check if the agent is a router and needs to delegate
199            if agent.agent_type() == AgentType::Router {
200                // Router's output should be an agent name
201                // Use robust parsing to handle various output formats
202                let next_agent = Self::parse_routing_decision(&output);
203
204                if let Some(ref agent_name) = next_agent {
205                    // Validate the routed agent exists (check hierarchy)
206                    if resolve_agent(&self.state, &context.user_id, agent_name.clone())
207                        .await
208                        .is_ok()
209                    {
210                        current_agent_name = agent_name.clone();
211                        // Keep the original user input for the routed agent
212                        depth += 1;
213                        continue;
214                    }
215                }
216
217                // Agent not found or couldn't parse - try fallback
218                if let Some(ref fallback) = workflow.fallback_agent {
219                    // Use fallback if routed agent doesn't exist
220                    tracing::warn!(
221                        "Routed agent '{:?}' not found or invalid, using fallback '{}'",
222                        next_agent,
223                        fallback
224                    );
225                    current_agent_name = fallback.clone();
226                    depth += 1;
227                    continue;
228                } else {
229                    // No fallback, return the router's output as final
230                    break;
231                }
232            }
233
234            // Non-router agent - this is the final response
235            break;
236        }
237
238        // Build the final output
239        let final_response = steps
240            .last()
241            .map(|s| s.output.clone())
242            .unwrap_or_else(|| "No response generated".to_string());
243
244        Ok(WorkflowOutput {
245            final_response,
246            steps_executed: steps.len(),
247            agents_used,
248            reasoning_path: steps,
249        })
250    }
251
252    /// Get available workflow names
253    pub fn available_workflows(&self) -> Vec<String> {
254        self.state
255            .config_manager
256            .config()
257            .workflows
258            .keys()
259            .cloned()
260            .collect()
261    }
262
263    /// Check if a workflow exists
264    pub fn has_workflow(&self, name: &str) -> bool {
265        self.state
266            .config_manager
267            .config()
268            .workflows
269            .contains_key(name)
270    }
271
272    /// Get workflow configuration
273    pub fn get_workflow_config(&self, name: &str) -> Option<WorkflowConfig> {
274        self.state
275            .config_manager
276            .config()
277            .get_workflow(name)
278            .cloned()
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::ProviderRegistry;
    use crate::tools::registry::ToolRegistry;
    use crate::utils::toml_config::{
        AgentConfig, AresConfig, AuthConfig, DatabaseConfig, ModelConfig, ProviderConfig,
        RagConfig, ServerConfig,
    };
    use crate::{AgentRegistry, AresConfigManager, DynamicConfigManager};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Build a tool-less agent config on the "default" model.
    fn test_agent(system_prompt: &str, max_tool_iterations: usize) -> AgentConfig {
        AgentConfig {
            model: "default".to_string(),
            system_prompt: Some(system_prompt.to_string()),
            tools: vec![],
            max_tool_iterations,
            parallel_tools: false,
            extra: HashMap::new(),
        }
    }

    /// Build a configuration with one provider, one model, three agents
    /// (router, orchestrator, product) and two workflows (default, research).
    fn create_test_config() -> AresConfig {
        let mut providers = HashMap::new();
        providers.insert(
            "ollama-local".to_string(),
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "ministral-3:3b".to_string(),
            },
        );

        let mut models = HashMap::new();
        models.insert(
            "default".to_string(),
            ModelConfig {
                provider: "ollama-local".to_string(),
                model: "ministral-3:3b".to_string(),
                temperature: 0.7,
                max_tokens: 512,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        let mut agents = HashMap::new();
        agents.insert(
            "router".to_string(),
            test_agent("Route queries to the appropriate agent.", 1),
        );
        agents.insert(
            "orchestrator".to_string(),
            test_agent("Handle complex queries.", 10),
        );
        agents.insert(
            "product".to_string(),
            test_agent("Handle product queries.", 5),
        );

        let mut workflows = HashMap::new();
        workflows.insert(
            "default".to_string(),
            WorkflowConfig {
                entry_agent: "router".to_string(),
                fallback_agent: Some("orchestrator".to_string()),
                max_depth: 3,
                max_iterations: 5,
                parallel_subagents: false,
            },
        );
        workflows.insert(
            "research".to_string(),
            WorkflowConfig {
                entry_agent: "orchestrator".to_string(),
                fallback_agent: None,
                max_depth: 3,
                max_iterations: 10,
                parallel_subagents: true,
            },
        );

        AresConfig {
            server: ServerConfig::default(),
            auth: AuthConfig::default(),
            database: DatabaseConfig::default(),
            config: crate::utils::toml_config::DynamicConfigPaths::default(),
            providers,
            models,
            tools: HashMap::new(),
            agents,
            workflows,
            rag: RagConfig::default(),
            #[cfg(feature = "skills")]
            skills: None,
        }
    }

    /// Build a `WorkflowEngine` over a dummy `AppState` derived from
    /// `create_test_config()`.
    ///
    /// Shared by every test that needs an engine so the state wiring lives in
    /// one place. Call it from inside a `#[tokio::test]`, as the tests below do.
    fn create_test_engine() -> WorkflowEngine {
        let config = Arc::new(create_test_config());
        let provider_registry = Arc::new(ProviderRegistry::from_config(&config));
        let tool_registry = Arc::new(ToolRegistry::new());
        let agent_registry = Arc::new(AgentRegistry::from_config(
            &config,
            provider_registry.clone(),
            tool_registry.clone(),
        ));

        // Create a dummy AppState for testing
        let state = AppState {
            config_manager: Arc::new(AresConfigManager::from_config((*config).clone())),
            dynamic_config: Arc::new(
                DynamicConfigManager::new(
                    std::path::PathBuf::from("config/agents"),
                    std::path::PathBuf::from("config/models"),
                    std::path::PathBuf::from("config/tools"),
                    std::path::PathBuf::from("config/workflows"),
                    std::path::PathBuf::from("config/mcps"),
                    false,
                )
                .unwrap(),
            ),
            db: Arc::new(crate::db::PostgresClient::new_test()),
            tenant_db: Arc::new(crate::db::TenantDb::new(Arc::new(
                crate::db::PostgresClient::new_test(),
            ))),
            llm_factory: Arc::new(crate::ConfigBasedLLMFactory::new(
                provider_registry.clone(),
                "default",
            )),
            provider_registry,
            agent_registry,
            tool_registry,
            auth_service: Arc::new(crate::auth::jwt::AuthService::new(
                "secret".to_string(),
                900,
                604800,
            )),
            mcp_registry: None,
            deploy_registry: crate::api::handlers::deploy::new_deploy_registry(),
            emergency_stop: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            context_provider: Arc::new(crate::agents::NoOpContextProvider),
        };

        WorkflowEngine::new(state)
    }

    #[tokio::test]
    async fn test_workflow_engine_creation() {
        let engine = create_test_engine();

        assert!(engine.has_workflow("default"));
        assert!(engine.has_workflow("research"));
        assert!(!engine.has_workflow("nonexistent"));
    }

    #[tokio::test]
    async fn test_available_workflows() {
        let engine = create_test_engine();
        let workflows = engine.available_workflows();

        assert!(workflows.contains(&"default".to_string()));
        assert!(workflows.contains(&"research".to_string()));
    }

    #[tokio::test]
    async fn test_get_workflow_config() {
        let engine = create_test_engine();

        let default_config = engine.get_workflow_config("default").unwrap();
        assert_eq!(default_config.entry_agent, "router");
        assert_eq!(
            default_config.fallback_agent,
            Some("orchestrator".to_string())
        );
        assert_eq!(default_config.max_depth, 3);

        let research_config = engine.get_workflow_config("research").unwrap();
        assert_eq!(research_config.entry_agent, "orchestrator");
        assert!(research_config.parallel_subagents);
    }

    #[test]
    fn test_parse_routing_decision_documented_formats() {
        let product = Some("product".to_string());

        assert_eq!(WorkflowEngine::parse_routing_decision("product"), product);
        assert_eq!(
            WorkflowEngine::parse_routing_decision("  Product  "),
            product
        );
        assert_eq!(
            WorkflowEngine::parse_routing_decision("I would route this to product"),
            product
        );
        assert_eq!(
            WorkflowEngine::parse_routing_decision("product agent"),
            product
        );
        assert_eq!(
            WorkflowEngine::parse_routing_decision("I do not know"),
            None
        );
    }

    #[test]
    fn test_workflow_output_serialization() {
        let output = WorkflowOutput {
            final_response: "Test response".to_string(),
            steps_executed: 2,
            agents_used: vec!["router".to_string(), "product".to_string()],
            reasoning_path: vec![
                WorkflowStep {
                    agent_name: "router".to_string(),
                    input: "What products do we have?".to_string(),
                    output: "product".to_string(),
                    timestamp: 1702500000,
                    duration_ms: 150,
                },
                WorkflowStep {
                    agent_name: "product".to_string(),
                    input: "What products do we have?".to_string(),
                    output: "Test response".to_string(),
                    timestamp: 1702500001,
                    duration_ms: 500,
                },
            ],
        };

        let json = serde_json::to_string(&output).unwrap();
        assert!(json.contains("Test response"));
        assert!(json.contains("router"));
        assert!(json.contains("product"));

        let deserialized: WorkflowOutput = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.steps_executed, 2);
    }
}