1use std::collections::HashMap;
34use std::future::Future;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::time::{Duration, Instant};
38
39use async_trait::async_trait;
40use parking_lot::RwLock;
41use serde::{Deserialize, Serialize};
42use tracing::{debug, info, warn};
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct AgentToolInput {
47 pub query: String,
49 pub context: HashMap<String, String>,
51 pub history: Vec<String>,
53 pub max_tokens: Option<usize>,
55}
56
57impl AgentToolInput {
58 pub fn new(query: impl Into<String>) -> Self {
59 Self {
60 query: query.into(),
61 context: HashMap::new(),
62 history: Vec::new(),
63 max_tokens: None,
64 }
65 }
66
67 pub fn with_context(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
68 self.context.insert(key.into(), value.into());
69 self
70 }
71
72 pub fn with_history(mut self, history: Vec<String>) -> Self {
73 self.history = history;
74 self
75 }
76
77 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
78 self.max_tokens = Some(max_tokens);
79 self
80 }
81}
82
83impl From<&str> for AgentToolInput {
84 fn from(query: &str) -> Self {
85 Self::new(query)
86 }
87}
88
89impl From<String> for AgentToolInput {
90 fn from(query: String) -> Self {
91 Self::new(query)
92 }
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct AgentToolOutput {
98 pub content: String,
100 pub success: bool,
102 pub confidence: f32,
104 pub metadata: HashMap<String, String>,
106 pub duration_ms: u64,
108 pub tools_used: Vec<String>,
110}
111
112impl AgentToolOutput {
113 pub fn success(content: impl Into<String>) -> Self {
114 Self {
115 content: content.into(),
116 success: true,
117 confidence: 1.0,
118 metadata: HashMap::new(),
119 duration_ms: 0,
120 tools_used: Vec::new(),
121 }
122 }
123
124 pub fn failure(content: impl Into<String>) -> Self {
125 Self {
126 content: content.into(),
127 success: false,
128 confidence: 0.0,
129 metadata: HashMap::new(),
130 duration_ms: 0,
131 tools_used: Vec::new(),
132 }
133 }
134
135 pub fn with_confidence(mut self, confidence: f32) -> Self {
136 self.confidence = confidence.clamp(0.0, 1.0);
137 self
138 }
139
140 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
141 self.metadata.insert(key.into(), value.into());
142 self
143 }
144
145 pub fn with_tools_used(mut self, tools: Vec<String>) -> Self {
146 self.tools_used = tools;
147 self
148 }
149}
150
151#[derive(Debug, thiserror::Error)]
153pub enum AgentToolError {
154 #[error("Agent execution failed: {0}")]
155 ExecutionFailed(String),
156
157 #[error("Agent timeout after {0:?}")]
158 Timeout(Duration),
159
160 #[error("Agent not available: {0}")]
161 NotAvailable(String),
162
163 #[error("Invalid input: {0}")]
164 InvalidInput(String),
165
166 #[error("Rate limited: retry after {0:?}")]
167 RateLimited(Duration),
168}
169
170#[async_trait]
172pub trait AgentToolHandler: Send + Sync {
173 async fn execute(&self, input: AgentToolInput) -> Result<AgentToolOutput, AgentToolError>;
175
176 fn name(&self) -> &str;
178
179 fn description(&self) -> &str;
181
182 fn is_available(&self) -> bool {
184 true
185 }
186
187 fn capabilities(&self) -> Vec<String> {
189 Vec::new()
190 }
191}
192
193pub type BoxedHandler = Arc<dyn AgentToolHandler>;
195
/// Execution policy for an agent tool: per-attempt timeout, retry count and result caching.
#[derive(Debug, Clone)]
pub struct AgentToolConfig {
    /// Maximum time a single handler attempt may run.
    pub timeout: Duration,
    /// Extra attempts after the first failure (`max_retries + 1` attempts in total).
    pub max_retries: u32,
    /// Whether successful outputs are cached.
    pub cache_enabled: bool,
    /// How long a cached output stays valid.
    pub cache_ttl: Duration,
}

impl Default for AgentToolConfig {
    /// 60 s timeout, 2 retries, caching off (with a 5 minute TTL if switched on).
    fn default() -> Self {
        const ONE_MINUTE: Duration = Duration::from_secs(60);
        Self {
            timeout: ONE_MINUTE,
            max_retries: 2,
            cache_enabled: false,
            cache_ttl: ONE_MINUTE * 5,
        }
    }
}
219
220#[derive(Debug, Clone, Default, Serialize, Deserialize)]
222pub struct AgentToolStats {
223 pub total_calls: u64,
224 pub successful_calls: u64,
225 pub failed_calls: u64,
226 pub timeouts: u64,
227 pub total_duration_ms: u64,
228 pub cache_hits: u64,
229}
230
231impl AgentToolStats {
232 pub fn average_duration_ms(&self) -> f64 {
233 if self.total_calls == 0 {
234 0.0
235 } else {
236 self.total_duration_ms as f64 / self.total_calls as f64
237 }
238 }
239
240 pub fn success_rate(&self) -> f64 {
241 if self.total_calls == 0 {
242 0.0
243 } else {
244 self.successful_calls as f64 / self.total_calls as f64
245 }
246 }
247}
248
249pub struct AgentTool {
251 name: String,
252 description: String,
253 handler: BoxedHandler,
254 config: AgentToolConfig,
255 stats: Arc<RwLock<AgentToolStats>>,
256 cache: Arc<RwLock<HashMap<String, (AgentToolOutput, Instant)>>>,
257}
258
259impl AgentTool {
260 pub fn new<H: AgentToolHandler + 'static>(handler: H) -> Self {
262 Self {
263 name: handler.name().to_string(),
264 description: handler.description().to_string(),
265 handler: Arc::new(handler),
266 config: AgentToolConfig::default(),
267 stats: Arc::new(RwLock::new(AgentToolStats::default())),
268 cache: Arc::new(RwLock::new(HashMap::new())),
269 }
270 }
271
272 pub fn builder(name: impl Into<String>) -> AgentToolBuilder {
274 AgentToolBuilder::new(name)
275 }
276
277 pub fn with_config(mut self, config: AgentToolConfig) -> Self {
278 self.config = config;
279 self
280 }
281
282 pub async fn call(
284 &self,
285 input: impl Into<AgentToolInput>,
286 ) -> Result<AgentToolOutput, AgentToolError> {
287 let input = input.into();
288 let start = Instant::now();
289
290 if self.config.cache_enabled {
292 let cache_key = self.cache_key(&input);
293 if let Some(cached) = self.get_cached(&cache_key) {
294 self.stats.write().cache_hits += 1;
295 return Ok(cached);
296 }
297 }
298
299 let mut last_error = None;
301 for attempt in 0..=self.config.max_retries {
302 if attempt > 0 {
303 debug!(name = %self.name, attempt, "Retrying agent tool");
304 }
305
306 match tokio::time::timeout(self.config.timeout, self.handler.execute(input.clone()))
307 .await
308 {
309 Ok(Ok(mut output)) => {
310 let duration = start.elapsed();
311 output.duration_ms = duration.as_millis() as u64;
312
313 {
315 let mut stats = self.stats.write();
316 stats.total_calls += 1;
317 stats.successful_calls += 1;
318 stats.total_duration_ms += output.duration_ms;
319 }
320
321 if self.config.cache_enabled {
323 let cache_key = self.cache_key(&input);
324 self.set_cached(cache_key, output.clone());
325 }
326
327 return Ok(output);
328 }
329 Ok(Err(e)) => {
330 warn!(name = %self.name, error = %e, attempt, "Agent tool execution failed");
331 last_error = Some(e);
332 }
333 Err(_) => {
334 self.stats.write().timeouts += 1;
335 last_error = Some(AgentToolError::Timeout(self.config.timeout));
336 }
337 }
338 }
339
340 {
342 let mut stats = self.stats.write();
343 stats.total_calls += 1;
344 stats.failed_calls += 1;
345 stats.total_duration_ms += start.elapsed().as_millis() as u64;
346 }
347
348 Err(last_error
349 .unwrap_or_else(|| AgentToolError::ExecutionFailed("Unknown error".to_string())))
350 }
351
352 pub fn name(&self) -> &str {
353 &self.name
354 }
355
356 pub fn description(&self) -> &str {
357 &self.description
358 }
359
360 pub fn stats(&self) -> AgentToolStats {
361 self.stats.read().clone()
362 }
363
364 pub fn is_available(&self) -> bool {
365 self.handler.is_available()
366 }
367
368 pub fn capabilities(&self) -> Vec<String> {
369 self.handler.capabilities()
370 }
371
372 fn cache_key(&self, input: &AgentToolInput) -> String {
373 use std::hash::{Hash, Hasher};
374 let mut hasher = std::collections::hash_map::DefaultHasher::new();
375 input.query.hash(&mut hasher);
376 for (k, v) in &input.context {
377 k.hash(&mut hasher);
378 v.hash(&mut hasher);
379 }
380 format!("{}:{}", self.name, hasher.finish())
381 }
382
383 fn get_cached(&self, key: &str) -> Option<AgentToolOutput> {
384 let cache = self.cache.read();
385 if let Some((output, cached_at)) = cache.get(key) {
386 if cached_at.elapsed() < self.config.cache_ttl {
387 return Some(output.clone());
388 }
389 }
390 None
391 }
392
393 fn set_cached(&self, key: String, output: AgentToolOutput) {
394 let mut cache = self.cache.write();
395 cache.insert(key, (output, Instant::now()));
396
397 let ttl = self.config.cache_ttl;
399 cache.retain(|_, (_, cached_at)| cached_at.elapsed() < ttl);
400 }
401}
402
403pub struct AgentToolBuilder {
405 name: String,
406 description: String,
407 config: AgentToolConfig,
408 capabilities: Vec<String>,
409}
410
411impl AgentToolBuilder {
412 pub fn new(name: impl Into<String>) -> Self {
413 Self {
414 name: name.into(),
415 description: String::new(),
416 config: AgentToolConfig::default(),
417 capabilities: Vec::new(),
418 }
419 }
420
421 pub fn description(mut self, description: impl Into<String>) -> Self {
422 self.description = description.into();
423 self
424 }
425
426 pub fn timeout(mut self, timeout: Duration) -> Self {
427 self.config.timeout = timeout;
428 self
429 }
430
431 pub fn max_retries(mut self, max_retries: u32) -> Self {
432 self.config.max_retries = max_retries;
433 self
434 }
435
436 pub fn cache(mut self, enabled: bool, ttl: Duration) -> Self {
437 self.config.cache_enabled = enabled;
438 self.config.cache_ttl = ttl;
439 self
440 }
441
442 pub fn capability(mut self, capability: impl Into<String>) -> Self {
443 self.capabilities.push(capability.into());
444 self
445 }
446
447 pub fn handler<F, Fut>(self, handler: F) -> AgentTool
449 where
450 F: Fn(AgentToolInput) -> Fut + Send + Sync + 'static,
451 Fut: Future<Output = Result<AgentToolOutput, AgentToolError>> + Send + 'static,
452 {
453 let fn_handler = FnAgentHandler {
454 name: self.name.clone(),
455 description: self.description.clone(),
456 capabilities: self.capabilities.clone(),
457 handler: Arc::new(move |input| Box::pin(handler(input))),
458 };
459
460 AgentTool::new(fn_handler).with_config(self.config)
461 }
462}
463
464struct FnAgentHandler {
466 name: String,
467 description: String,
468 capabilities: Vec<String>,
469 handler: Arc<
470 dyn Fn(
471 AgentToolInput,
472 )
473 -> Pin<Box<dyn Future<Output = Result<AgentToolOutput, AgentToolError>> + Send>>
474 + Send
475 + Sync,
476 >,
477}
478
479#[async_trait]
480impl AgentToolHandler for FnAgentHandler {
481 async fn execute(&self, input: AgentToolInput) -> Result<AgentToolOutput, AgentToolError> {
482 (self.handler)(input).await
483 }
484
485 fn name(&self) -> &str {
486 &self.name
487 }
488
489 fn description(&self) -> &str {
490 &self.description
491 }
492
493 fn capabilities(&self) -> Vec<String> {
494 self.capabilities.clone()
495 }
496}
497
498pub struct AgentToolRegistry {
500 tools: RwLock<HashMap<String, Arc<AgentTool>>>,
501}
502
503impl AgentToolRegistry {
504 pub fn new() -> Self {
505 Self {
506 tools: RwLock::new(HashMap::new()),
507 }
508 }
509
510 pub fn register(&self, tool: AgentTool) {
512 let name = tool.name().to_string();
513 self.tools.write().insert(name.clone(), Arc::new(tool));
514 info!(name = %name, "Agent tool registered");
515 }
516
517 pub fn get(&self, name: &str) -> Option<Arc<AgentTool>> {
519 self.tools.read().get(name).cloned()
520 }
521
522 pub fn list(&self) -> Vec<String> {
524 self.tools.read().keys().cloned().collect()
525 }
526
527 pub fn list_available(&self) -> Vec<String> {
529 self.tools
530 .read()
531 .iter()
532 .filter(|(_, tool)| tool.is_available())
533 .map(|(name, _)| name.clone())
534 .collect()
535 }
536
537 pub fn find_by_capability(&self, capability: &str) -> Vec<Arc<AgentTool>> {
539 self.tools
540 .read()
541 .values()
542 .filter(|tool| tool.capabilities().contains(&capability.to_string()))
543 .cloned()
544 .collect()
545 }
546
547 pub fn remove(&self, name: &str) -> bool {
549 self.tools.write().remove(name).is_some()
550 }
551
552 pub fn all_stats(&self) -> HashMap<String, AgentToolStats> {
554 self.tools
555 .read()
556 .iter()
557 .map(|(name, tool)| (name.clone(), tool.stats()))
558 .collect()
559 }
560}
561
562impl Default for AgentToolRegistry {
563 fn default() -> Self {
564 Self::new()
565 }
566}
567
568pub struct DelegationChain {
570 tools: Vec<Arc<AgentTool>>,
571 strategy: DelegationStrategy,
572}
573
/// How a delegation chain turns several tools into one result.
#[derive(Debug, Clone, Copy)]
pub enum DelegationStrategy {
    /// Call tools in order; return the first `Ok` output whose `success` flag is set.
    FirstSuccess,
    /// Call every tool and merge the outputs into one.
    All,
    /// Call tools in order; return the first `Ok` output regardless of its `success` flag.
    Fallback,
    /// Call every tool and return the successful output with the highest confidence.
    BestConfidence,
}
585
586impl DelegationChain {
587 pub fn new(strategy: DelegationStrategy) -> Self {
588 Self {
589 tools: Vec::new(),
590 strategy,
591 }
592 }
593
594 pub fn add(mut self, tool: Arc<AgentTool>) -> Self {
595 self.tools.push(tool);
596 self
597 }
598
599 pub async fn execute(&self, input: AgentToolInput) -> Result<AgentToolOutput, AgentToolError> {
600 match self.strategy {
601 DelegationStrategy::FirstSuccess => self.execute_first_success(input).await,
602 DelegationStrategy::Fallback => self.execute_fallback(input).await,
603 DelegationStrategy::BestConfidence => self.execute_best_confidence(input).await,
604 DelegationStrategy::All => self.execute_all(input).await,
605 }
606 }
607
608 async fn execute_first_success(
609 &self,
610 input: AgentToolInput,
611 ) -> Result<AgentToolOutput, AgentToolError> {
612 for tool in &self.tools {
613 if let Ok(output) = tool.call(input.clone()).await {
614 if output.success {
615 return Ok(output);
616 }
617 }
618 }
619 Err(AgentToolError::ExecutionFailed(
620 "No tool succeeded".to_string(),
621 ))
622 }
623
624 async fn execute_fallback(
625 &self,
626 input: AgentToolInput,
627 ) -> Result<AgentToolOutput, AgentToolError> {
628 let mut last_error = None;
629 for tool in &self.tools {
630 match tool.call(input.clone()).await {
631 Ok(output) => return Ok(output),
632 Err(e) => last_error = Some(e),
633 }
634 }
635 Err(last_error
636 .unwrap_or_else(|| AgentToolError::ExecutionFailed("No tools available".to_string())))
637 }
638
639 async fn execute_best_confidence(
640 &self,
641 input: AgentToolInput,
642 ) -> Result<AgentToolOutput, AgentToolError> {
643 let mut best: Option<AgentToolOutput> = None;
644
645 for tool in &self.tools {
646 if let Ok(output) = tool.call(input.clone()).await {
647 if output.success {
648 match &best {
649 None => best = Some(output),
650 Some(current) if output.confidence > current.confidence => {
651 best = Some(output)
652 }
653 _ => {}
654 }
655 }
656 }
657 }
658
659 best.ok_or_else(|| AgentToolError::ExecutionFailed("No tool succeeded".to_string()))
660 }
661
662 async fn execute_all(&self, input: AgentToolInput) -> Result<AgentToolOutput, AgentToolError> {
663 let mut results = Vec::new();
664 let mut all_success = true;
665 let mut total_confidence = 0.0;
666 let mut all_tools_used = Vec::new();
667
668 for tool in &self.tools {
669 match tool.call(input.clone()).await {
670 Ok(output) => {
671 all_success = all_success && output.success;
672 total_confidence += output.confidence;
673 all_tools_used.extend(output.tools_used.clone());
674 results.push(output.content);
675 }
676 Err(_) => {
677 all_success = false;
678 }
679 }
680 }
681
682 if results.is_empty() {
683 return Err(AgentToolError::ExecutionFailed(
684 "All tools failed".to_string(),
685 ));
686 }
687
688 let avg_confidence = total_confidence / self.tools.len() as f32;
689 let combined_content = results.join("\n---\n");
690
691 Ok(AgentToolOutput {
692 content: combined_content,
693 success: all_success,
694 confidence: avg_confidence,
695 metadata: HashMap::new(),
696 duration_ms: 0,
697 tools_used: all_tools_used,
698 })
699 }
700}
701
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A tool that always succeeds with `content` and the given confidence.
    fn constant_tool(name: &str, content: &'static str, confidence: f32) -> Arc<AgentTool> {
        Arc::new(
            AgentTool::builder(name)
                .description("Constant output")
                .handler(move |_: AgentToolInput| async move {
                    Ok(AgentToolOutput::success(content).with_confidence(confidence))
                }),
        )
    }

    #[tokio::test]
    async fn test_agent_tool_basic() {
        let tool = AgentTool::builder("test_agent")
            .description("A test agent")
            .handler(|input: AgentToolInput| async move {
                Ok(AgentToolOutput::success(format!(
                    "Processed: {}",
                    input.query
                )))
            });

        let result = tool.call("Hello").await.unwrap();
        assert!(result.success);
        assert!(result.content.contains("Processed: Hello"));
    }

    #[tokio::test]
    async fn test_agent_tool_with_context() {
        let tool = AgentTool::builder("context_agent")
            .description("Agent that uses context")
            .handler(|input: AgentToolInput| async move {
                let name = input.context.get("name").cloned().unwrap_or_default();
                Ok(AgentToolOutput::success(format!("Hello, {}!", name)))
            });

        let input = AgentToolInput::new("greet").with_context("name", "World");
        let result = tool.call(input).await.unwrap();
        assert!(result.content.contains("Hello, World!"));
    }

    #[tokio::test]
    async fn test_agent_tool_failure() {
        let tool = AgentTool::builder("failing_agent")
            .description("Agent that fails")
            .max_retries(0)
            .handler(|_: AgentToolInput| async move {
                Err(AgentToolError::ExecutionFailed(
                    "Intentional failure".to_string(),
                ))
            });

        let result = tool.call("test").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_agent_tool_failure_updates_stats() {
        let tool = AgentTool::builder("failing_stats")
            .description("Fails once")
            .max_retries(0)
            .handler(|_: AgentToolInput| async move {
                Err(AgentToolError::ExecutionFailed("nope".to_string()))
            });

        assert!(tool.call("x").await.is_err());

        let stats = tool.stats();
        assert_eq!(stats.total_calls, 1);
        assert_eq!(stats.failed_calls, 1);
        assert_eq!(stats.successful_calls, 0);
        assert_eq!(stats.success_rate(), 0.0);
    }

    #[tokio::test]
    async fn test_agent_tool_retries_before_failing() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&attempts);
        let tool = AgentTool::builder("flaky")
            .description("Always fails")
            .max_retries(2)
            .handler(move |_: AgentToolInput| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(AgentToolError::ExecutionFailed("flaky".to_string()))
                }
            });

        assert!(tool.call("x").await.is_err());
        // One initial attempt plus two retries.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_agent_tool_cache_hit() {
        let executions = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&executions);
        let tool = AgentTool::builder("cached")
            .description("Caches results")
            .cache(true, Duration::from_secs(60))
            .handler(move |input: AgentToolInput| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok(AgentToolOutput::success(input.query))
                }
            });

        let first = tool.call("same").await.unwrap();
        let second = tool.call("same").await.unwrap();

        assert_eq!(first.content, "same");
        assert_eq!(second.content, "same");
        assert_eq!(executions.load(Ordering::SeqCst), 1);
        assert_eq!(tool.stats().cache_hits, 1);
    }

    #[tokio::test]
    async fn test_agent_tool_stats() {
        let tool = AgentTool::builder("stats_agent")
            .description("Agent for testing stats")
            .handler(|_: AgentToolInput| async move { Ok(AgentToolOutput::success("OK")) });

        tool.call("test1").await.unwrap();
        tool.call("test2").await.unwrap();
        tool.call("test3").await.unwrap();

        let stats = tool.stats();
        assert_eq!(stats.total_calls, 3);
        assert_eq!(stats.successful_calls, 3);
        assert_eq!(stats.failed_calls, 0);
    }

    #[test]
    fn test_agent_tool_registry() {
        let registry = AgentToolRegistry::new();

        let tool1 = AgentTool::builder("agent1")
            .description("First agent")
            .capability("math")
            .handler(|_: AgentToolInput| async move { Ok(AgentToolOutput::success("1")) });

        let tool2 = AgentTool::builder("agent2")
            .description("Second agent")
            .capability("text")
            .handler(|_: AgentToolInput| async move { Ok(AgentToolOutput::success("2")) });

        registry.register(tool1);
        registry.register(tool2);

        assert_eq!(registry.list().len(), 2);
        assert!(registry.get("agent1").is_some());
        assert!(registry.get("nonexistent").is_none());

        let math_tools = registry.find_by_capability("math");
        assert_eq!(math_tools.len(), 1);
    }

    #[test]
    fn test_agent_tool_registry_remove() {
        let registry = AgentToolRegistry::default();
        registry.register(
            AgentTool::builder("temp")
                .description("Temporary")
                .handler(|_: AgentToolInput| async move { Ok(AgentToolOutput::success("t")) }),
        );

        assert!(registry.remove("temp"));
        assert!(!registry.remove("temp"));
        assert!(registry.list().is_empty());
    }

    #[tokio::test]
    async fn test_delegation_chain_fallback() {
        let failing_tool = Arc::new(
            AgentTool::builder("failing")
                .description("Fails")
                .max_retries(0)
                .handler(|_: AgentToolInput| async move {
                    Err(AgentToolError::ExecutionFailed("fail".to_string()))
                }),
        );

        let success_tool = Arc::new(
            AgentTool::builder("success")
                .description("Succeeds")
                .handler(|_: AgentToolInput| async move { Ok(AgentToolOutput::success("OK")) }),
        );

        let chain = DelegationChain::new(DelegationStrategy::Fallback)
            .add(failing_tool)
            .add(success_tool);

        let result = chain.execute(AgentToolInput::new("test")).await.unwrap();
        assert!(result.success);
        assert_eq!(result.content, "OK");
    }

    #[tokio::test]
    async fn test_delegation_chain_first_success_skips_unsuccessful_output() {
        let unsuccessful = Arc::new(
            AgentTool::builder("unsuccessful")
                .description("Reports failure")
                .handler(|_: AgentToolInput| async move { Ok(AgentToolOutput::failure("no")) }),
        );

        let chain = DelegationChain::new(DelegationStrategy::FirstSuccess)
            .add(unsuccessful)
            .add(constant_tool("yes", "yes", 1.0));

        let result = chain.execute(AgentToolInput::new("q")).await.unwrap();
        assert_eq!(result.content, "yes");
    }

    #[tokio::test]
    async fn test_delegation_chain_best_confidence() {
        let chain = DelegationChain::new(DelegationStrategy::BestConfidence)
            .add(constant_tool("low", "low", 0.3))
            .add(constant_tool("high", "high", 0.9));

        let result = chain.execute(AgentToolInput::new("test")).await.unwrap();
        assert_eq!(result.content, "high");
        assert_eq!(result.confidence, 0.9);
    }

    #[tokio::test]
    async fn test_delegation_chain_all_combines_outputs() {
        let chain = DelegationChain::new(DelegationStrategy::All)
            .add(constant_tool("a", "a", 0.5))
            .add(constant_tool("b", "b", 1.0));

        let result = chain.execute(AgentToolInput::new("q")).await.unwrap();
        assert!(result.success);
        assert_eq!(result.content, "a\n---\nb");
        assert_eq!(result.confidence, 0.75);
    }

    #[tokio::test]
    async fn test_delegation_chain_empty_fails() {
        for strategy in [
            DelegationStrategy::FirstSuccess,
            DelegationStrategy::All,
            DelegationStrategy::Fallback,
            DelegationStrategy::BestConfidence,
        ] {
            let chain = DelegationChain::new(strategy);
            assert!(chain.execute(AgentToolInput::new("q")).await.is_err());
        }
    }

    #[test]
    fn test_agent_tool_output_builder() {
        let output = AgentToolOutput::success("Result")
            .with_confidence(0.85)
            .with_metadata("key", "value")
            .with_tools_used(vec!["tool1".to_string()]);

        assert!(output.success);
        assert_eq!(output.confidence, 0.85);
        assert_eq!(output.metadata.get("key").unwrap(), "value");
        assert_eq!(output.tools_used.len(), 1);
    }

    #[test]
    fn test_agent_tool_input_builder() {
        let input = AgentToolInput::new("query")
            .with_context("key", "value")
            .with_history(vec!["previous".to_string()])
            .with_max_tokens(100);

        assert_eq!(input.query, "query");
        assert_eq!(input.context.get("key").unwrap(), "value");
        assert_eq!(input.history.len(), 1);
        assert_eq!(input.max_tokens, Some(100));
    }

    #[test]
    fn test_agent_tool_capabilities() {
        let tool = AgentTool::builder("capable")
            .description("Agent with capabilities")
            .capability("math")
            .capability("science")
            .handler(|_: AgentToolInput| async move { Ok(AgentToolOutput::success("OK")) });

        let caps = tool.capabilities();
        assert!(caps.contains(&"math".to_string()));
        assert!(caps.contains(&"science".to_string()));
    }
}