1use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct HandoffDecision {
33 pub can_handle: bool,
35 pub recommended_agent: Option<String>,
37 pub reason: Option<String>,
39 pub confidence: f32,
41}
42
43impl HandoffDecision {
44 pub fn accept(confidence: f32) -> Self {
46 Self {
47 can_handle: true,
48 recommended_agent: None,
49 reason: None,
50 confidence: confidence.clamp(0.0, 1.0),
51 }
52 }
53
54 pub fn handoff(recommended_agent: &str, reason: &str) -> Self {
56 Self {
57 can_handle: false,
58 recommended_agent: Some(recommended_agent.to_string()),
59 reason: Some(reason.to_string()),
60 confidence: 0.0,
61 }
62 }
63
64 pub fn uncertain(confidence: f32, recommended_agent: Option<&str>) -> Self {
66 Self {
67 can_handle: confidence > 0.5,
68 recommended_agent: recommended_agent.map(|s| s.to_string()),
69 reason: Some("Low confidence in task match".to_string()),
70 confidence: confidence.clamp(0.0, 1.0),
71 }
72 }
73
74 pub fn needs_handoff(&self) -> bool {
76 !self.can_handle && self.recommended_agent.is_some()
77 }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct AgentCapability {
83 pub name: String,
85 pub domain: String,
87 pub keywords: Vec<String>,
89 pub tools: Vec<String>,
91 pub priority: u32,
93 pub description: Option<String>,
95}
96
97impl AgentCapability {
98 pub fn new(
100 name: impl Into<String>,
101 domain: impl Into<String>,
102 keywords: Vec<impl Into<String>>,
103 tools: Vec<impl Into<String>>,
104 ) -> Self {
105 Self {
106 name: name.into(),
107 domain: domain.into(),
108 keywords: keywords.into_iter().map(|s| s.into()).collect(),
109 tools: tools.into_iter().map(|s| s.into()).collect(),
110 priority: 0,
111 description: None,
112 }
113 }
114
115 pub fn with_priority(mut self, priority: u32) -> Self {
117 self.priority = priority;
118 self
119 }
120
121 pub fn with_description(mut self, description: impl Into<String>) -> Self {
123 self.description = Some(description.into());
124 self
125 }
126
127 pub fn match_score(&self, task: &str) -> f32 {
129 let task_lower = task.to_lowercase();
130 let mut score = 0.0f32;
131 let mut matches = 0;
132
133 for keyword in &self.keywords {
135 if task_lower.contains(&keyword.to_lowercase()) {
136 matches += 1;
137 }
138 }
139
140 if !self.keywords.is_empty() {
141 score = matches as f32 / self.keywords.len() as f32;
142 }
143
144 if task_lower.contains(&self.domain.to_lowercase()) {
146 score = (score + 0.3).min(1.0);
147 }
148
149 if self.priority > 0 {
151 score = (score + (self.priority as f32 * 0.01)).min(1.0);
152 }
153
154 score
155 }
156
157 pub fn has_tool(&self, tool: &str) -> bool {
159 self.tools.iter().any(|t| t == tool)
160 }
161
162 pub fn has_any_tool(&self, tools: &[&str]) -> bool {
164 tools.iter().any(|t| self.has_tool(t))
165 }
166}
167
168#[derive(Debug, Clone, Default)]
170pub struct AgentRegistry {
171 agents: HashMap<String, AgentCapability>,
172}
173
174impl AgentRegistry {
175 pub fn new() -> Self {
177 Self::default()
178 }
179
180 pub fn register(&mut self, capability: AgentCapability) {
182 self.agents.insert(capability.name.clone(), capability);
183 }
184
185 pub fn unregister(&mut self, name: &str) -> Option<AgentCapability> {
187 self.agents.remove(name)
188 }
189
190 pub fn get(&self, name: &str) -> Option<&AgentCapability> {
192 self.agents.get(name)
193 }
194
195 pub fn find_best_agent(&self, task: &str, exclude: Option<&str>) -> Option<(String, f32)> {
197 let mut best: Option<(String, f32)> = None;
198
199 for (name, agent) in &self.agents {
200 if let Some(excluded) = exclude {
202 if name == excluded {
203 continue;
204 }
205 }
206
207 let score = agent.match_score(task);
208 if score > 0.0 && (best.is_none() || score > best.as_ref().unwrap().1) {
209 best = Some((name.clone(), score));
210 }
211 }
212
213 best
214 }
215
216 pub fn find_by_tool(&self, tool: &str) -> Vec<&AgentCapability> {
218 self.agents.values().filter(|a| a.has_tool(tool)).collect()
219 }
220
221 pub fn find_by_domain(&self, domain: &str) -> Vec<&AgentCapability> {
223 let domain_lower = domain.to_lowercase();
224 self.agents
225 .values()
226 .filter(|a| a.domain.to_lowercase() == domain_lower)
227 .collect()
228 }
229
230 pub fn all(&self) -> Vec<&AgentCapability> {
232 self.agents.values().collect()
233 }
234
235 pub fn len(&self) -> usize {
237 self.agents.len()
238 }
239
240 pub fn is_empty(&self) -> bool {
242 self.agents.is_empty()
243 }
244}
245
246pub struct HandoffRouter {
248 registry: AgentRegistry,
249 min_confidence: f32,
251 handoff_history: Vec<HandoffRecord>,
253 max_history: usize,
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
259pub struct HandoffRecord {
260 pub from_agent: String,
262 pub to_agent: String,
264 pub task: String,
266 pub score: f32,
268 pub timestamp: i64,
270}
271
272impl HandoffRouter {
273 pub fn new(registry: AgentRegistry) -> Self {
275 Self {
276 registry,
277 min_confidence: 0.3,
278 handoff_history: Vec::new(),
279 max_history: 100,
280 }
281 }
282
283 pub fn with_min_confidence(mut self, threshold: f32) -> Self {
285 self.min_confidence = threshold.clamp(0.0, 1.0);
286 self
287 }
288
289 pub fn route(&self, task: &str) -> Option<(String, f32)> {
291 self.registry.find_best_agent(task, None)
292 }
293
294 pub fn route_excluding(&self, task: &str, exclude: &str) -> Option<(String, f32)> {
296 self.registry.find_best_agent(task, Some(exclude))
297 }
298
299 pub fn evaluate(&self, agent: &str, task: &str) -> HandoffDecision {
301 let current_agent = self.registry.get(agent);
302 let current_score = current_agent.map(|a| a.match_score(task)).unwrap_or(0.0);
303
304 if current_score >= self.min_confidence {
306 return HandoffDecision::accept(current_score);
307 }
308
309 if let Some((best_agent, best_score)) = self.route_excluding(task, agent) {
311 if best_score > current_score {
312 return HandoffDecision::handoff(
313 &best_agent,
314 &format!(
315 "Agent '{}' is better suited for this task (score: {:.2} vs {:.2})",
316 best_agent, best_score, current_score
317 ),
318 );
319 }
320 }
321
322 HandoffDecision::uncertain(current_score, None)
324 }
325
326 pub fn record_handoff(&mut self, from: &str, to: &str, task: &str, score: f32) {
328 let record = HandoffRecord {
329 from_agent: from.to_string(),
330 to_agent: to.to_string(),
331 task: task.to_string(),
332 score,
333 timestamp: chrono::Utc::now().timestamp(),
334 };
335
336 self.handoff_history.push(record);
337
338 if self.handoff_history.len() > self.max_history {
340 self.handoff_history.remove(0);
341 }
342 }
343
344 pub fn stats(&self) -> HandoffStats {
346 let mut stats = HandoffStats::default();
347 let mut agent_counts: HashMap<String, u32> = HashMap::new();
348
349 for record in &self.handoff_history {
350 stats.total_handoffs += 1;
351 stats.total_score += record.score;
352 *agent_counts.entry(record.to_agent.clone()).or_insert(0) += 1;
353 }
354
355 if stats.total_handoffs > 0 {
356 stats.avg_score = stats.total_score / stats.total_handoffs as f32;
357 }
358
359 if let Some((agent, count)) = agent_counts.iter().max_by_key(|(_, c)| *c) {
361 stats.most_common_target = Some(agent.clone());
362 stats.most_common_count = *count;
363 }
364
365 stats
366 }
367
368 pub fn registry(&self) -> &AgentRegistry {
370 &self.registry
371 }
372
373 pub fn registry_mut(&mut self) -> &mut AgentRegistry {
375 &mut self.registry
376 }
377}
378
379#[derive(Debug, Clone, Default, Serialize, Deserialize)]
381pub struct HandoffStats {
382 pub total_handoffs: u32,
383 pub avg_score: f32,
384 pub total_score: f32,
385 pub most_common_target: Option<String>,
386 pub most_common_count: u32,
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Registry with three agents whose keyword sets and domains do not overlap.
    fn create_test_registry() -> AgentRegistry {
        let mut registry = AgentRegistry::new();

        registry.register(AgentCapability::new(
            "code-expert",
            "programming",
            vec!["rust", "python", "code", "debug", "compile"],
            vec!["execute_code", "analyze_code"],
        ));

        registry.register(AgentCapability::new(
            "data-analyst",
            "data",
            vec!["data", "analysis", "statistics", "chart", "graph"],
            vec!["query_database", "create_chart"],
        ));

        registry.register(AgentCapability::new(
            "writer",
            "writing",
            vec!["write", "edit", "document", "essay", "article"],
            vec!["search_web", "create_document"],
        ));

        registry
    }

    #[test]
    fn test_capability_match_score() {
        let agent = AgentCapability::new(
            "code-expert",
            "programming",
            vec!["rust", "python", "code"],
            Vec::<String>::new(),
        );

        assert!(agent.match_score("help me write rust code") > 0.3);
        assert!(agent.match_score("python debugging") > 0.0);
        assert_eq!(agent.match_score("cook me dinner"), 0.0);
    }

    #[test]
    fn test_registry_find_best() {
        let registry = create_test_registry();

        let (agent, score) = registry
            .find_best_agent("help me debug this rust code", None)
            .unwrap();
        assert_eq!(agent, "code-expert");
        assert!(score > 0.3);

        let (agent, _) = registry
            .find_best_agent("analyze the sales data", None)
            .unwrap();
        assert_eq!(agent, "data-analyst");

        let (agent, _) = registry
            .find_best_agent("write an article about AI", None)
            .unwrap();
        assert_eq!(agent, "writer");
    }

    #[test]
    fn test_registry_exclude() {
        let registry = create_test_registry();

        // Only code-expert mentions rust/programming, so excluding it leaves no match.
        let result = registry.find_best_agent("rust programming", Some("code-expert"));
        assert!(result.is_none());

        // Excluding an unrelated agent does not change the winner.
        let (agent, _) = registry
            .find_best_agent("rust programming", Some("writer"))
            .unwrap();
        assert_eq!(agent, "code-expert");
    }

    #[test]
    fn test_handoff_decision() {
        let decision = HandoffDecision::accept(0.8);
        assert!(decision.can_handle);
        assert!(!decision.needs_handoff());

        let decision = HandoffDecision::handoff("other-agent", "better match");
        assert!(!decision.can_handle);
        assert!(decision.needs_handoff());
        assert_eq!(decision.recommended_agent, Some("other-agent".to_string()));

        let decision = HandoffDecision::uncertain(0.2, Some("other-agent"));
        assert!(!decision.can_handle);
        assert!(decision.needs_handoff());

        let decision = HandoffDecision::uncertain(0.9, None);
        assert!(decision.can_handle);
        assert!(!decision.needs_handoff());
    }

    #[test]
    fn test_router_evaluate() {
        let registry = create_test_registry();
        let router = HandoffRouter::new(registry).with_min_confidence(0.3);

        let decision = router.evaluate("code-expert", "debug this rust code");
        assert!(decision.can_handle);
        assert!(decision.confidence > 0.3);

        // The writer scores 0 on this task while code-expert scores > 0, so a
        // handoff must be recommended.
        let decision = router.evaluate("writer", "debug this rust code");
        assert!(decision.needs_handoff());
        assert_eq!(decision.recommended_agent, Some("code-expert".to_string()));

        // An unregistered agent scores 0 and is handed off the same way.
        let decision = router.evaluate("unknown-agent", "debug this rust code");
        assert!(decision.needs_handoff());
        assert_eq!(decision.recommended_agent, Some("code-expert".to_string()));
    }

    #[test]
    fn test_router_evaluate_without_candidate() {
        let router = HandoffRouter::new(create_test_registry());

        // No agent matches, so the result is an "uncertain" decision with no target.
        let decision = router.evaluate("writer", "cook me dinner");
        assert!(!decision.can_handle);
        assert!(!decision.needs_handoff());
        assert_eq!(decision.recommended_agent, None);
        assert_eq!(decision.confidence, 0.0);
        assert!(decision.reason.is_some());
    }

    #[test]
    fn test_router_route() {
        let registry = create_test_registry();
        let router = HandoffRouter::new(registry);

        let (agent, score) = router.route("analyze the data and create a chart").unwrap();
        assert_eq!(agent, "data-analyst");
        assert!(score > 0.0);
    }

    #[test]
    fn test_find_by_tool() {
        let registry = create_test_registry();

        let agents = registry.find_by_tool("execute_code");
        assert_eq!(agents.len(), 1);
        assert_eq!(agents[0].name, "code-expert");

        let agents = registry.find_by_tool("search_web");
        assert_eq!(agents.len(), 1);
        assert_eq!(agents[0].name, "writer");
    }

    #[test]
    fn test_handoff_stats() {
        let registry = create_test_registry();
        let mut router = HandoffRouter::new(registry);

        router.record_handoff("writer", "code-expert", "debug code", 0.8);
        router.record_handoff("data-analyst", "code-expert", "fix bug", 0.9);
        router.record_handoff("writer", "data-analyst", "analyze data", 0.7);

        let stats = router.stats();
        assert_eq!(stats.total_handoffs, 3);
        assert!(stats.avg_score > 0.7);
        assert_eq!(stats.most_common_target, Some("code-expert".to_string()));
        assert_eq!(stats.most_common_count, 2);
    }

    #[test]
    fn test_handoff_history_is_capped() {
        let mut router = HandoffRouter::new(create_test_registry());

        for _ in 0..105 {
            router.record_handoff("writer", "code-expert", "task", 0.5);
        }

        let stats = router.stats();
        assert_eq!(stats.total_handoffs, 100);
        assert_eq!(stats.most_common_count, 100);
        assert_eq!(stats.avg_score, 0.5);
    }

    #[test]
    fn test_empty_stats() {
        let router = HandoffRouter::new(create_test_registry());
        let stats = router.stats();
        assert_eq!(stats.total_handoffs, 0);
        assert_eq!(stats.avg_score, 0.0);
        assert_eq!(stats.most_common_target, None);
    }
}