1use crate::agents::Agent;
7use crate::api::handlers::user_agents::resolve_agent;
8use crate::types::{AgentContext, AgentType, AppError, Result};
9use crate::utils::toml_config::{AgentConfig, WorkflowConfig};
10use crate::AppState;
11use chrono::Utc;
12use serde::{Deserialize, Serialize};
13use utoipa::ToSchema;
14
15#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
17pub struct WorkflowOutput {
18 pub final_response: String,
20 pub steps_executed: usize,
22 pub agents_used: Vec<String>,
24 pub reasoning_path: Vec<WorkflowStep>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
30pub struct WorkflowStep {
31 pub agent_name: String,
33 pub input: String,
35 pub output: String,
37 pub timestamp: i64,
39 pub duration_ms: u64,
41}
42
/// Agent names that a router's output is allowed to select.
///
/// Order matters: `WorkflowEngine::parse_routing_decision` scans this list front to back
/// in its fallback matching step, so earlier entries win ties.
const VALID_AGENTS: &[&str] = &[
    "product", "invoice", "sales", "finance", "hr",
    "orchestrator", "research", "router",
];
54
55pub struct WorkflowEngine {
57 state: AppState,
59}
60
61impl WorkflowEngine {
62 pub fn new(state: AppState) -> Self {
64 Self { state }
65 }
66
67 fn parse_routing_decision(output: &str) -> Option<String> {
75 let trimmed = output.trim().to_lowercase();
76
77 if VALID_AGENTS.contains(&trimmed.as_str()) {
79 return Some(trimmed);
80 }
81
82 for word in trimmed.split(|c: char| c.is_whitespace() || c == ':' || c == ',' || c == '.') {
85 let word = word.trim();
86 if VALID_AGENTS.contains(&word) {
87 return Some(word.to_string());
88 }
89 }
90
91 for agent in VALID_AGENTS {
93 if trimmed.contains(agent) {
94 return Some(agent.to_string());
95 }
96 }
97
98 None
99 }
100
101 pub async fn execute_workflow(
113 &self,
114 workflow_name: &str,
115 user_input: &str,
116 context: &AgentContext,
117 ) -> Result<WorkflowOutput> {
118 let config = self.state.config_manager.config();
120 let workflow = config.get_workflow(workflow_name).ok_or_else(|| {
121 AppError::Configuration(format!(
122 "Workflow '{}' not found in configuration",
123 workflow_name
124 ))
125 })?;
126
127 let mut steps = Vec::new();
128 let mut agents_used = Vec::new();
129 let current_input = user_input.to_string();
130 let mut current_agent_name = workflow.entry_agent.clone();
131 let mut depth = 0;
132
133 while depth < workflow.max_depth {
135 let step_start = std::time::Instant::now();
136 let timestamp = Utc::now().timestamp();
137
138 let (user_agent, _source) =
140 match resolve_agent(&self.state, &context.user_id, ¤t_agent_name).await {
141 Ok(res) => res,
142 Err(e) => {
143 if let Some(ref fallback) = workflow.fallback_agent {
145 tracing::warn!(
146 "Failed to resolve agent '{}', using fallback '{}'",
147 current_agent_name,
148 fallback
149 );
150 current_agent_name = fallback.clone();
151 resolve_agent(&self.state, &context.user_id, fallback).await?
152 } else {
153 return Err(e);
154 }
155 }
156 };
157
158 let agent_config = AgentConfig {
160 model: user_agent.model.clone(),
161 system_prompt: user_agent.system_prompt.clone(),
162 tools: user_agent.tools_vec(),
163 max_tool_iterations: user_agent.max_tool_iterations as usize,
164 parallel_tools: user_agent.parallel_tools,
165 extra: std::collections::HashMap::new(),
166 };
167
168 let agent = self
170 .state
171 .agent_registry
172 .create_agent_from_config(¤t_agent_name, &agent_config)
173 .await?;
174
175 let output = agent.execute(¤t_input, context).await?;
177 let duration_ms = step_start.elapsed().as_millis() as u64;
178
179 steps.push(WorkflowStep {
181 agent_name: current_agent_name.clone(),
182 input: current_input.clone(),
183 output: output.clone(),
184 timestamp,
185 duration_ms,
186 });
187
188 if !agents_used.contains(¤t_agent_name) {
189 agents_used.push(current_agent_name.clone());
190 }
191
192 if agent.agent_type() == AgentType::Router {
194 let next_agent = Self::parse_routing_decision(&output);
197
198 if let Some(ref agent_name) = next_agent {
199 if resolve_agent(&self.state, &context.user_id, agent_name)
201 .await
202 .is_ok()
203 {
204 current_agent_name = agent_name.clone();
205 depth += 1;
207 continue;
208 }
209 }
210
211 if let Some(ref fallback) = workflow.fallback_agent {
213 tracing::warn!(
215 "Routed agent '{:?}' not found or invalid, using fallback '{}'",
216 next_agent,
217 fallback
218 );
219 current_agent_name = fallback.clone();
220 depth += 1;
221 continue;
222 } else {
223 break;
225 }
226 }
227
228 break;
230 }
231
232 let final_response = steps
234 .last()
235 .map(|s| s.output.clone())
236 .unwrap_or_else(|| "No response generated".to_string());
237
238 Ok(WorkflowOutput {
239 final_response,
240 steps_executed: steps.len(),
241 agents_used,
242 reasoning_path: steps,
243 })
244 }
245
246 pub fn available_workflows(&self) -> Vec<String> {
248 self.state
249 .config_manager
250 .config()
251 .workflows
252 .keys()
253 .cloned()
254 .collect()
255 }
256
257 pub fn has_workflow(&self, name: &str) -> bool {
259 self.state
260 .config_manager
261 .config()
262 .workflows
263 .contains_key(name)
264 }
265
266 pub fn get_workflow_config(&self, name: &str) -> Option<WorkflowConfig> {
268 self.state
269 .config_manager
270 .config()
271 .get_workflow(name)
272 .cloned()
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::ProviderRegistry;
    use crate::tools::registry::ToolRegistry;
    use crate::utils::toml_config::{
        AgentConfig, AresConfig, AuthConfig, DatabaseConfig, ModelConfig, ProviderConfig,
        RagConfig, ServerConfig,
    };
    use crate::{AgentRegistry, AresConfigManager, DynamicConfigManager};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Builds a configuration with one Ollama provider, three agents
    /// (`router`, `orchestrator`, `product`) and two workflows (`default`, `research`).
    fn create_test_config() -> AresConfig {
        let mut providers = HashMap::new();
        providers.insert(
            "ollama-local".to_string(),
            ProviderConfig::Ollama {
                base_url: "http://localhost:11434".to_string(),
                default_model: "ministral-3:3b".to_string(),
            },
        );

        let mut models = HashMap::new();
        models.insert(
            "default".to_string(),
            ModelConfig {
                provider: "ollama-local".to_string(),
                model: "ministral-3:3b".to_string(),
                temperature: 0.7,
                max_tokens: 512,
                top_p: None,
                frequency_penalty: None,
                presence_penalty: None,
            },
        );

        let mut agents = HashMap::new();
        agents.insert(
            "router".to_string(),
            AgentConfig {
                model: "default".to_string(),
                system_prompt: Some("Route queries to the appropriate agent.".to_string()),
                tools: vec![],
                max_tool_iterations: 1,
                parallel_tools: false,
                extra: HashMap::new(),
            },
        );
        agents.insert(
            "orchestrator".to_string(),
            AgentConfig {
                model: "default".to_string(),
                system_prompt: Some("Handle complex queries.".to_string()),
                tools: vec![],
                max_tool_iterations: 10,
                parallel_tools: false,
                extra: HashMap::new(),
            },
        );
        agents.insert(
            "product".to_string(),
            AgentConfig {
                model: "default".to_string(),
                system_prompt: Some("Handle product queries.".to_string()),
                tools: vec![],
                max_tool_iterations: 5,
                parallel_tools: false,
                extra: HashMap::new(),
            },
        );

        let mut workflows = HashMap::new();
        workflows.insert(
            "default".to_string(),
            WorkflowConfig {
                entry_agent: "router".to_string(),
                fallback_agent: Some("orchestrator".to_string()),
                max_depth: 3,
                max_iterations: 5,
                parallel_subagents: false,
            },
        );
        workflows.insert(
            "research".to_string(),
            WorkflowConfig {
                entry_agent: "orchestrator".to_string(),
                fallback_agent: None,
                max_depth: 3,
                max_iterations: 10,
                parallel_subagents: true,
            },
        );

        AresConfig {
            server: ServerConfig::default(),
            auth: AuthConfig::default(),
            database: DatabaseConfig::default(),
            config: crate::utils::toml_config::DynamicConfigPaths::default(),
            providers,
            models,
            tools: HashMap::new(),
            agents,
            workflows,
            rag: RagConfig::default(),
        }
    }

    /// Builds an `AppState` around `create_test_config()` with an in-memory database.
    ///
    /// Shared by every test below so the setup lives in one place.
    fn create_test_state() -> AppState {
        let config = Arc::new(create_test_config());
        let provider_registry = Arc::new(ProviderRegistry::from_config(&config));
        let tool_registry = Arc::new(ToolRegistry::new());
        let agent_registry = Arc::new(AgentRegistry::from_config(
            &config,
            provider_registry.clone(),
            tool_registry.clone(),
        ));

        AppState {
            config_manager: Arc::new(AresConfigManager::from_config((*config).clone())),
            dynamic_config: Arc::new(
                DynamicConfigManager::new(
                    std::path::PathBuf::from("config/agents"),
                    std::path::PathBuf::from("config/models"),
                    std::path::PathBuf::from("config/tools"),
                    std::path::PathBuf::from("config/workflows"),
                    std::path::PathBuf::from("config/mcps"),
                    false,
                )
                .unwrap(),
            ),
            turso: Arc::new(
                futures::executor::block_on(crate::db::TursoClient::new_memory()).unwrap(),
            ),
            llm_factory: Arc::new(crate::ConfigBasedLLMFactory::new(
                provider_registry.clone(),
                "default",
            )),
            provider_registry,
            agent_registry,
            tool_registry,
            auth_service: Arc::new(crate::auth::jwt::AuthService::new(
                "secret".to_string(),
                900,
                604800,
            )),
        }
    }

    #[test]
    fn test_workflow_engine_creation() {
        let engine = WorkflowEngine::new(create_test_state());

        assert!(engine.has_workflow("default"));
        assert!(engine.has_workflow("research"));
        assert!(!engine.has_workflow("nonexistent"));
    }

    #[test]
    fn test_available_workflows() {
        let engine = WorkflowEngine::new(create_test_state());
        let workflows = engine.available_workflows();

        assert!(workflows.contains(&"default".to_string()));
        assert!(workflows.contains(&"research".to_string()));
    }

    #[test]
    fn test_get_workflow_config() {
        let engine = WorkflowEngine::new(create_test_state());

        let default_config = engine.get_workflow_config("default").unwrap();
        assert_eq!(default_config.entry_agent, "router");
        assert_eq!(
            default_config.fallback_agent,
            Some("orchestrator".to_string())
        );
        assert_eq!(default_config.max_depth, 3);

        let research_config = engine.get_workflow_config("research").unwrap();
        assert_eq!(research_config.entry_agent, "orchestrator");
        assert!(research_config.parallel_subagents);
    }

    #[test]
    fn test_parse_routing_decision() {
        let parse = WorkflowEngine::parse_routing_decision;

        // Exact names, with surrounding whitespace and mixed case.
        assert_eq!(parse("product"), Some("product".to_string()));
        assert_eq!(parse("  Invoice \n"), Some("invoice".to_string()));
        // Names embedded in a sentence or wrapped in punctuation.
        assert_eq!(parse("Route to: sales."), Some("sales".to_string()));
        assert_eq!(parse("**research**"), Some("research".to_string()));
        // Plural forms still resolve to the agent.
        assert_eq!(parse("products"), Some("product".to_string()));
        // Nothing recognisable.
        assert_eq!(parse(""), None);
        assert_eq!(parse("I cannot decide"), None);
    }

    #[test]
    fn test_workflow_output_serialization() {
        let output = WorkflowOutput {
            final_response: "Test response".to_string(),
            steps_executed: 2,
            agents_used: vec!["router".to_string(), "product".to_string()],
            reasoning_path: vec![
                WorkflowStep {
                    agent_name: "router".to_string(),
                    input: "What products do we have?".to_string(),
                    output: "product".to_string(),
                    timestamp: 1702500000,
                    duration_ms: 150,
                },
                WorkflowStep {
                    agent_name: "product".to_string(),
                    input: "What products do we have?".to_string(),
                    output: "Test response".to_string(),
                    timestamp: 1702500001,
                    duration_ms: 500,
                },
            ],
        };

        let json = serde_json::to_string(&output).unwrap();
        assert!(json.contains("Test response"));
        assert!(json.contains("router"));
        assert!(json.contains("product"));

        let deserialized: WorkflowOutput = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.steps_executed, 2);
    }
}