1use crate::agents::Agent;
7use crate::api::handlers::user_agents::resolve_agent;
8use crate::types::{AgentContext, AgentType, AppError, Result};
9use crate::utils::toml_config::{AgentConfig, WorkflowConfig};
10use crate::AppState;
11use chrono::Utc;
12use serde::{Deserialize, Serialize};
13use utoipa::ToSchema;
14
/// Result of running a configured workflow.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct WorkflowOutput {
    /// Output of the last executed step, or `"No response generated"` when no step ran.
    pub final_response: String,
    /// Number of steps that were executed (the length of `reasoning_path`).
    pub steps_executed: usize,
    /// Distinct names of the agents that ran, in order of first use.
    pub agents_used: Vec<String>,
    /// Every executed step, in execution order.
    pub reasoning_path: Vec<WorkflowStep>,
}
27
/// A single agent invocation within a workflow run.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct WorkflowStep {
    /// Name of the agent that ran in this step.
    pub agent_name: String,
    /// Input text handed to the agent.
    pub input: String,
    /// Text content of the agent's response.
    pub output: String,
    /// Time the step started, as seconds since the Unix epoch (UTC).
    pub timestamp: i64,
    /// Elapsed time of the step in milliseconds, covering agent resolution,
    /// creation and execution.
    pub duration_ms: u64,
}
42
/// Agent names that a router agent's output may select as the next hop.
///
/// Order matters: when the router output contains no exact agent name, the lenient
/// fallback in `WorkflowEngine::parse_routing_decision` checks these names in this
/// order, so earlier entries win.
const VALID_AGENTS: &[&str] = &[
    "product",
    "invoice",
    "sales",
    "finance",
    "hr",
    "orchestrator",
    "research",
    "router",
];
54
/// Executes named workflows from the configuration by chaining agents, following the
/// routing decisions produced by router agents.
pub struct WorkflowEngine {
    /// Shared application state; the engine reads the config manager and agent
    /// registry from it.
    state: AppState,
}
60
61impl WorkflowEngine {
62 pub fn new(state: AppState) -> Self {
64 Self { state }
65 }
66
67 fn parse_routing_decision(output: &str) -> Option<String> {
75 let trimmed = output.trim().to_lowercase();
76
77 if VALID_AGENTS.contains(&trimmed.as_str()) {
79 return Some(trimmed);
80 }
81
82 for word in trimmed.split(|c: char| c.is_whitespace() || c == ':' || c == ',' || c == '.') {
85 let word = word.trim();
86 if VALID_AGENTS.contains(&word) {
87 return Some(word.to_string());
88 }
89 }
90
91 for agent in VALID_AGENTS {
93 if trimmed.contains(agent) {
94 return Some(agent.to_string());
95 }
96 }
97
98 None
99 }
100
101 pub async fn execute_workflow(
113 &self,
114 workflow_name: &str,
115 user_input: &str,
116 context: &AgentContext,
117 ) -> Result<WorkflowOutput> {
118 let config = self.state.config_manager.config();
120 let workflow = config.get_workflow(workflow_name).ok_or_else(|| {
121 AppError::Configuration(format!(
122 "Workflow '{}' not found in configuration",
123 workflow_name
124 ))
125 })?;
126
127 let mut steps = Vec::new();
128 let mut agents_used = Vec::new();
129 let current_input = user_input.to_string();
130 let mut current_agent_name = workflow.entry_agent.clone();
131 let mut depth = 0;
132
133 while depth < workflow.max_depth {
135 let step_start = std::time::Instant::now();
136 let timestamp = Utc::now().timestamp();
137
138 let (user_agent, _source) = match resolve_agent(
140 &self.state,
141 &context.user_id,
142 current_agent_name.clone(),
143 )
144 .await
145 {
146 Ok(res) => res,
147 Err(e) => {
148 if let Some(ref fallback) = workflow.fallback_agent {
150 tracing::warn!(
151 "Failed to resolve agent '{}', using fallback '{}'",
152 current_agent_name,
153 fallback
154 );
155 current_agent_name = fallback.clone();
156 resolve_agent(&self.state, &context.user_id, fallback.clone()).await?
157 } else {
158 return Err(e);
159 }
160 }
161 };
162
163 let agent_config = AgentConfig {
165 model: user_agent.model.clone(),
166 system_prompt: user_agent.system_prompt.clone(),
167 tools: user_agent.tools_vec(),
168 max_tool_iterations: user_agent.max_tool_iterations as usize,
169 parallel_tools: user_agent.parallel_tools,
170 extra: std::collections::HashMap::new(),
171 };
172
173 let agent = self
175 .state
176 .agent_registry
177 .create_agent_from_config(¤t_agent_name, &agent_config)
178 .await?;
179
180 let agent_resp = agent.execute(¤t_input, context).await?;
182 let output = agent_resp.content;
183 let duration_ms = step_start.elapsed().as_millis() as u64;
184
185 steps.push(WorkflowStep {
187 agent_name: current_agent_name.clone(),
188 input: current_input.clone(),
189 output: output.clone(),
190 timestamp,
191 duration_ms,
192 });
193
194 if !agents_used.contains(¤t_agent_name) {
195 agents_used.push(current_agent_name.clone());
196 }
197
198 if agent.agent_type() == AgentType::Router {
200 let next_agent = Self::parse_routing_decision(&output);
203
204 if let Some(ref agent_name) = next_agent {
205 if resolve_agent(&self.state, &context.user_id, agent_name.clone())
207 .await
208 .is_ok()
209 {
210 current_agent_name = agent_name.clone();
211 depth += 1;
213 continue;
214 }
215 }
216
217 if let Some(ref fallback) = workflow.fallback_agent {
219 tracing::warn!(
221 "Routed agent '{:?}' not found or invalid, using fallback '{}'",
222 next_agent,
223 fallback
224 );
225 current_agent_name = fallback.clone();
226 depth += 1;
227 continue;
228 } else {
229 break;
231 }
232 }
233
234 break;
236 }
237
238 let final_response = steps
240 .last()
241 .map(|s| s.output.clone())
242 .unwrap_or_else(|| "No response generated".to_string());
243
244 Ok(WorkflowOutput {
245 final_response,
246 steps_executed: steps.len(),
247 agents_used,
248 reasoning_path: steps,
249 })
250 }
251
252 pub fn available_workflows(&self) -> Vec<String> {
254 self.state
255 .config_manager
256 .config()
257 .workflows
258 .keys()
259 .cloned()
260 .collect()
261 }
262
263 pub fn has_workflow(&self, name: &str) -> bool {
265 self.state
266 .config_manager
267 .config()
268 .workflows
269 .contains_key(name)
270 }
271
272 pub fn get_workflow_config(&self, name: &str) -> Option<WorkflowConfig> {
274 self.state
275 .config_manager
276 .config()
277 .get_workflow(name)
278 .cloned()
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::ProviderRegistry;
    use crate::tools::registry::ToolRegistry;
    use crate::utils::toml_config::{
        AgentConfig, AresConfig, AuthConfig, DatabaseConfig, ModelConfig, ProviderConfig,
        RagConfig, ServerConfig,
    };
    use crate::{AgentRegistry, AresConfigManager, DynamicConfigManager};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds a test config with:
    /// - one Ollama provider and a `default` model;
    /// - three agents (`router`, `orchestrator`, `product`);
    /// - two workflows (`default` with a fallback, `research` without).
    fn create_test_config() -> AresConfig {
        let mut providers = HashMap::new();
        providers.insert(
            "ollama-local".to_string(),
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "ministral-3:3b".to_string(),
            },
        );

        let mut models = HashMap::new();
        models.insert(
            "default".to_string(),
            ModelConfig {
                provider: "ollama-local".to_string(),
                model: "ministral-3:3b".to_string(),
                temperature: 0.7,
                max_tokens: 512,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        let mut agents = HashMap::new();
        agents.insert(
            "router".to_string(),
            AgentConfig {
                model: "default".to_string(),
                system_prompt: Some("Route queries to the appropriate agent.".to_string()),
                tools: vec![],
                max_tool_iterations: 1,
                parallel_tools: false,
                extra: HashMap::new(),
            },
        );
        agents.insert(
            "orchestrator".to_string(),
            AgentConfig {
                model: "default".to_string(),
                system_prompt: Some("Handle complex queries.".to_string()),
                tools: vec![],
                max_tool_iterations: 10,
                parallel_tools: false,
                extra: HashMap::new(),
            },
        );
        agents.insert(
            "product".to_string(),
            AgentConfig {
                model: "default".to_string(),
                system_prompt: Some("Handle product queries.".to_string()),
                tools: vec![],
                max_tool_iterations: 5,
                parallel_tools: false,
                extra: HashMap::new(),
            },
        );

        let mut workflows = HashMap::new();
        workflows.insert(
            "default".to_string(),
            WorkflowConfig {
                entry_agent: "router".to_string(),
                fallback_agent: Some("orchestrator".to_string()),
                max_depth: 3,
                max_iterations: 5,
                parallel_subagents: false,
            },
        );
        workflows.insert(
            "research".to_string(),
            WorkflowConfig {
                entry_agent: "orchestrator".to_string(),
                fallback_agent: None,
                max_depth: 3,
                max_iterations: 10,
                parallel_subagents: true,
            },
        );

        AresConfig {
            server: ServerConfig::default(),
            auth: AuthConfig::default(),
            database: DatabaseConfig::default(),
            config: crate::utils::toml_config::DynamicConfigPaths::default(),
            providers,
            models,
            tools: HashMap::new(),
            agents,
            workflows,
            rag: RagConfig::default(),
            #[cfg(feature = "skills")]
            skills: None,
        }
    }

    /// Builds an `AppState` wired around `create_test_config()`, using
    /// `PostgresClient::new_test()` clients and no MCP registry. Shared by the engine
    /// tests so the state wiring lives in one place.
    fn create_test_state() -> AppState {
        let config = Arc::new(create_test_config());
        let provider_registry = Arc::new(ProviderRegistry::from_config(&config));
        let tool_registry = Arc::new(ToolRegistry::new());
        let agent_registry = Arc::new(AgentRegistry::from_config(
            &config,
            provider_registry.clone(),
            tool_registry.clone(),
        ));

        AppState {
            config_manager: Arc::new(AresConfigManager::from_config((*config).clone())),
            dynamic_config: Arc::new(
                DynamicConfigManager::new(
                    std::path::PathBuf::from("config/agents"),
                    std::path::PathBuf::from("config/models"),
                    std::path::PathBuf::from("config/tools"),
                    std::path::PathBuf::from("config/workflows"),
                    std::path::PathBuf::from("config/mcps"),
                    false,
                )
                .unwrap(),
            ),
            db: Arc::new(crate::db::PostgresClient::new_test()),
            tenant_db: Arc::new(crate::db::TenantDb::new(Arc::new(
                crate::db::PostgresClient::new_test(),
            ))),
            llm_factory: Arc::new(crate::ConfigBasedLLMFactory::new(
                provider_registry.clone(),
                "default",
            )),
            provider_registry,
            agent_registry,
            tool_registry,
            auth_service: Arc::new(crate::auth::jwt::AuthService::new(
                "secret".to_string(),
                900,
                604800,
            )),
            mcp_registry: None,
            deploy_registry: crate::api::handlers::deploy::new_deploy_registry(),
            emergency_stop: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            context_provider: Arc::new(crate::agents::NoOpContextProvider),
        }
    }

    #[tokio::test]
    async fn test_workflow_engine_creation() {
        let engine = WorkflowEngine::new(create_test_state());

        assert!(engine.has_workflow("default"));
        assert!(engine.has_workflow("research"));
        assert!(!engine.has_workflow("nonexistent"));
    }

    #[tokio::test]
    async fn test_available_workflows() {
        let engine = WorkflowEngine::new(create_test_state());
        let workflows = engine.available_workflows();

        assert!(workflows.contains(&"default".to_string()));
        assert!(workflows.contains(&"research".to_string()));
    }

    #[tokio::test]
    async fn test_get_workflow_config() {
        let engine = WorkflowEngine::new(create_test_state());

        let default_config = engine.get_workflow_config("default").unwrap();
        assert_eq!(default_config.entry_agent, "router");
        assert_eq!(
            default_config.fallback_agent,
            Some("orchestrator".to_string())
        );
        assert_eq!(default_config.max_depth, 3);

        let research_config = engine.get_workflow_config("research").unwrap();
        assert_eq!(research_config.entry_agent, "orchestrator");
        assert!(research_config.parallel_subagents);
    }

    #[test]
    fn test_parse_routing_decision() {
        // Bare agent names, with surrounding whitespace and any casing.
        assert_eq!(
            WorkflowEngine::parse_routing_decision("product"),
            Some("product".to_string())
        );
        assert_eq!(
            WorkflowEngine::parse_routing_decision("  Finance \n"),
            Some("finance".to_string())
        );
        // Agent name inside a sentence with punctuation.
        assert_eq!(
            WorkflowEngine::parse_routing_decision("Route to: sales."),
            Some("sales".to_string())
        );
        // Plural form of an agent name.
        assert_eq!(
            WorkflowEngine::parse_routing_decision("Send this to the invoices team"),
            Some("invoice".to_string())
        );
        // Nothing recognisable.
        assert_eq!(WorkflowEngine::parse_routing_decision("no idea"), None);
        assert_eq!(WorkflowEngine::parse_routing_decision(""), None);
    }

    #[test]
    fn test_workflow_output_serialization() {
        let output = WorkflowOutput {
            final_response: "Test response".to_string(),
            steps_executed: 2,
            agents_used: vec!["router".to_string(), "product".to_string()],
            reasoning_path: vec![
                WorkflowStep {
                    agent_name: "router".to_string(),
                    input: "What products do we have?".to_string(),
                    output: "product".to_string(),
                    timestamp: 1702500000,
                    duration_ms: 150,
                },
                WorkflowStep {
                    agent_name: "product".to_string(),
                    input: "What products do we have?".to_string(),
                    output: "Test response".to_string(),
                    timestamp: 1702500001,
                    duration_ms: 500,
                },
            ],
        };

        let json = serde_json::to_string(&output).unwrap();
        assert!(json.contains("Test response"));
        assert!(json.contains("router"));
        assert!(json.contains("product"));

        let deserialized: WorkflowOutput = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.steps_executed, 2);
    }
}