Skip to main content

agentcore/tools/
tool.rs

1use std::any::{Any, TypeId};
2use std::collections::{HashMap, HashSet};
3use std::future::Future;
4use std::path::PathBuf;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
11use crate::error::Result;
12use crate::provider::types::ContentBlock;
13
14// ---------------------------------------------------------------------------
15// Core types
16// ---------------------------------------------------------------------------
17
18/// Context passed to tool execution.
19pub struct ToolContext {
20    pub working_directory: PathBuf,
21    pub tool_registry: Option<Arc<ToolRegistry>>,
22    extensions: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
23}
24
25impl ToolContext {
26    pub fn new(working_directory: PathBuf) -> Self {
27        Self {
28            working_directory,
29            tool_registry: None,
30            extensions: HashMap::new(),
31        }
32    }
33
34    pub fn with_registry(mut self, registry: Arc<ToolRegistry>) -> Self {
35        self.tool_registry = Some(registry);
36        self
37    }
38
39    /// Store a typed extension value accessible by tools.
40    pub fn set_extension<T: Any + Send + Sync + 'static>(&mut self, value: T) {
41        self.extensions.insert(TypeId::of::<T>(), Arc::new(value));
42    }
43
44    /// Retrieve a typed extension value.
45    pub fn get_extension<T: Any + Send + Sync + 'static>(&self) -> Option<&T> {
46        self.extensions
47            .get(&TypeId::of::<T>())
48            .and_then(|arc| arc.downcast_ref::<T>())
49    }
50
51}
52
53impl Clone for ToolContext {
54    fn clone(&self) -> Self {
55        Self {
56            working_directory: self.working_directory.clone(),
57            tool_registry: self.tool_registry.clone(),
58            extensions: self.extensions.clone(),
59        }
60    }
61}
62
63impl std::fmt::Debug for ToolContext {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        f.debug_struct("ToolContext")
66            .field("working_directory", &self.working_directory)
67            .field("tool_registry", &self.tool_registry)
68            .field("extensions_count", &self.extensions.len())
69            .finish()
70    }
71}
72
73/// Definition sent to the LLM as part of the tools parameter.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ToolDefinition {
76    pub name: String,
77    pub description: String,
78    pub input_schema: Value,
79}
80
81/// A tool call extracted from an LLM response.
82#[derive(Debug, Clone)]
83pub struct ToolCall {
84    pub id: String,
85    pub name: String,
86    pub input: Value,
87}
88
/// Outcome of running a tool.
#[derive(Debug, Clone)]
pub struct ToolResult {
    /// Textual payload of the result (output on success, message on error).
    pub content: String,
    /// `true` when the result represents an error.
    pub is_error: bool,
}
95
96impl ToolResult {
97    pub fn success(content: impl Into<String>) -> Self {
98        Self { content: content.into(), is_error: false }
99    }
100
101    pub fn error(content: impl Into<String>) -> Self {
102        Self { content: content.into(), is_error: true }
103    }
104}
105
106/// A search result from ToolRegistry::search().
107#[derive(Debug, Clone)]
108pub struct ToolSearchResult {
109    pub definition: ToolDefinition,
110    pub score: u32,
111}
112
113// ---------------------------------------------------------------------------
114// Tool trait
115// ---------------------------------------------------------------------------
116
117/// The core tool interface. Object-safe via boxed futures.
118pub trait Tool: Send + Sync {
119    fn name(&self) -> &str;
120    fn description(&self) -> &str;
121    fn input_schema(&self) -> Value;
122
123    fn is_read_only(&self) -> bool {
124        false
125    }
126
127    fn should_defer(&self) -> bool {
128        false
129    }
130
131    fn search_hints(&self) -> Vec<String> {
132        Vec::new()
133    }
134
135    fn call<'a>(
136        &'a self,
137        input: Value,
138        ctx: &'a ToolContext,
139    ) -> Pin<Box<dyn Future<Output = Result<ToolResult>> + Send + 'a>>;
140
141    fn definition(&self) -> ToolDefinition {
142        ToolDefinition {
143            name: self.name().to_string(),
144            description: self.description().to_string(),
145            input_schema: self.input_schema(),
146        }
147    }
148}
149
150// ---------------------------------------------------------------------------
151// Toolset trait
152// ---------------------------------------------------------------------------
153
154/// A collection of related tools.
155pub trait Toolset: Send + Sync {
156    fn tools(&self) -> Vec<Box<dyn Tool>>;
157}
158
159// ---------------------------------------------------------------------------
160// ToolRegistry
161// ---------------------------------------------------------------------------
162
163/// Registry of tools available to an agent.
164pub struct ToolRegistry {
165    pub(crate) tools: Vec<Arc<dyn Tool>>,
166}
167
168impl std::fmt::Debug for ToolRegistry {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        let names: Vec<&str> = self.tools.iter().map(|t| t.name()).collect();
171        f.debug_struct("ToolRegistry")
172            .field("tools", &names)
173            .finish()
174    }
175}
176
177impl ToolRegistry {
178    pub fn new() -> Self {
179        Self { tools: Vec::new() }
180    }
181
182    pub fn register(&mut self, tool: impl Tool + 'static) {
183        self.tools.push(Arc::new(tool));
184    }
185
186    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
187        self.tools
188            .iter()
189            .find(|t| t.name() == name)
190            .map(|t| t.as_ref() as &dyn Tool)
191    }
192
193    pub fn definitions(&self) -> Vec<ToolDefinition> {
194        self.tools.iter().map(|t| t.definition()).collect()
195    }
196
197    /// Full definitions for non-deferred + discovered deferred tools.
198    /// Undiscovered deferred tools get name-only definitions.
199    pub fn definitions_filtered(&self, discovered: &HashSet<String>) -> Vec<ToolDefinition> {
200        self.tools
201            .iter()
202            .map(|t| {
203                if t.should_defer() && !discovered.contains(t.name()) {
204                    ToolDefinition {
205                        name: t.name().to_string(),
206                        description: String::new(),
207                        input_schema: serde_json::json!({}),
208                    }
209                } else {
210                    t.definition()
211                }
212            })
213            .collect()
214    }
215
216    /// Search tools by query string. Returns matches sorted by score (highest first).
217    pub fn search(&self, query: &str) -> Vec<ToolSearchResult> {
218        let query_lower = query.to_lowercase();
219        let mut results: Vec<ToolSearchResult> = self
220            .tools
221            .iter()
222            .filter_map(|t| {
223                let mut score = 0u32;
224                let name = t.name().to_lowercase();
225                let desc = t.description().to_lowercase();
226
227                // Exact name match
228                if name == query_lower {
229                    score += 100;
230                } else if name.contains(&query_lower) {
231                    score += 50;
232                }
233
234                // Description match
235                if desc.contains(&query_lower) {
236                    score += 25;
237                }
238
239                // Search hints match
240                for hint in t.search_hints() {
241                    if hint.to_lowercase().contains(&query_lower) {
242                        score += 30;
243                    }
244                }
245
246                if score > 0 {
247                    Some(ToolSearchResult {
248                        definition: t.definition(),
249                        score,
250                    })
251                } else {
252                    None
253                }
254            })
255            .collect();
256
257        results.sort_by(|a, b| b.score.cmp(&a.score));
258        results
259    }
260
261    pub fn has_deferred_tools(&self) -> bool {
262        self.tools.iter().any(|t| t.should_defer())
263    }
264
265    pub fn is_empty(&self) -> bool {
266        self.tools.is_empty()
267    }
268}
269
270impl Clone for ToolRegistry {
271    fn clone(&self) -> Self {
272        Self {
273            tools: self.tools.clone(),
274        }
275    }
276}
277
278// ---------------------------------------------------------------------------
279// ToolBuilder
280// ---------------------------------------------------------------------------
281
282type ToolHandler = Box<
283    dyn Fn(Value, &ToolContext) -> Pin<Box<dyn Future<Output = Result<ToolResult>> + Send + '_>>
284        + Send
285        + Sync,
286>;
287
288struct BuiltTool {
289    name: String,
290    description: String,
291    schema: Value,
292    read_only: bool,
293    defer: bool,
294    hints: Vec<String>,
295    handler: ToolHandler,
296}
297
298impl Tool for BuiltTool {
299    fn name(&self) -> &str {
300        &self.name
301    }
302    fn description(&self) -> &str {
303        &self.description
304    }
305    fn input_schema(&self) -> Value {
306        self.schema.clone()
307    }
308    fn is_read_only(&self) -> bool {
309        self.read_only
310    }
311    fn should_defer(&self) -> bool {
312        self.defer
313    }
314    fn search_hints(&self) -> Vec<String> {
315        self.hints.clone()
316    }
317    fn call<'a>(
318        &'a self,
319        input: Value,
320        ctx: &'a ToolContext,
321    ) -> Pin<Box<dyn Future<Output = Result<ToolResult>> + Send + 'a>> {
322        (self.handler)(input, ctx)
323    }
324}
325
326pub struct ToolBuilder {
327    name: String,
328    description: String,
329    schema: Value,
330    read_only: bool,
331    defer: bool,
332    hints: Vec<String>,
333    handler: Option<ToolHandler>,
334}
335
336impl ToolBuilder {
337    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
338        Self {
339            name: name.into(),
340            description: description.into(),
341            schema: serde_json::json!({"type": "object", "properties": {}}),
342            read_only: false,
343            defer: false,
344            hints: Vec::new(),
345            handler: None,
346        }
347    }
348
349    pub fn schema(mut self, schema: Value) -> Self {
350        self.schema = schema;
351        self
352    }
353
354    pub fn read_only(mut self, read_only: bool) -> Self {
355        self.read_only = read_only;
356        self
357    }
358
359    pub fn should_defer(mut self, defer: bool) -> Self {
360        self.defer = defer;
361        self
362    }
363
364    pub fn search_hints(mut self, hints: Vec<String>) -> Self {
365        self.hints = hints;
366        self
367    }
368
369    pub fn handler<F>(mut self, f: F) -> Self
370    where
371        F: Fn(Value, &ToolContext) -> Pin<Box<dyn Future<Output = Result<ToolResult>> + Send + '_>>
372            + Send
373            + Sync
374            + 'static,
375    {
376        self.handler = Some(Box::new(f));
377        self
378    }
379
380    pub fn build(self) -> impl Tool {
381        let handler = self
382            .handler
383            .expect("ToolBuilder requires a handler before build()");
384        BuiltTool {
385            name: self.name,
386            description: self.description,
387            schema: self.schema,
388            read_only: self.read_only,
389            defer: self.defer,
390            hints: self.hints,
391            handler,
392        }
393    }
394}
395
396// ---------------------------------------------------------------------------
397// execute_tool_calls
398// ---------------------------------------------------------------------------
399
400enum ToolBatch {
401    Concurrent(Vec<ToolCall>),
402    Serial(ToolCall),
403}
404
405fn partition_tool_calls(calls: &[ToolCall], registry: &ToolRegistry) -> Vec<ToolBatch> {
406    let mut batches: Vec<ToolBatch> = Vec::new();
407    let mut concurrent_batch: Vec<ToolCall> = Vec::new();
408
409    for call in calls {
410        let is_read_only = registry
411            .get(&call.name)
412            .map_or(false, |t| t.is_read_only());
413
414        if is_read_only {
415            concurrent_batch.push(call.clone());
416        } else {
417            if !concurrent_batch.is_empty() {
418                batches.push(ToolBatch::Concurrent(std::mem::take(&mut concurrent_batch)));
419            }
420            batches.push(ToolBatch::Serial(call.clone()));
421        }
422    }
423
424    if !concurrent_batch.is_empty() {
425        batches.push(ToolBatch::Concurrent(concurrent_batch));
426    }
427
428    batches
429}
430
431/// Execute tool calls with concurrent read-only batching and serial write execution.
432pub async fn execute_tool_calls(
433    calls: &[ToolCall],
434    registry: &ToolRegistry,
435    ctx: &ToolContext,
436) -> Vec<ContentBlock> {
437    let batches = partition_tool_calls(calls, registry);
438    let mut results: Vec<ContentBlock> = Vec::new();
439    let semaphore = Arc::new(tokio::sync::Semaphore::new(10));
440
441    for batch in batches {
442        match batch {
443            ToolBatch::Concurrent(calls) => {
444                let mut set = tokio::task::JoinSet::new();
445                for call in calls {
446                    let sem = semaphore.clone();
447                    let ctx = ctx.clone();
448                    let tool_arc = registry
449                        .tools
450                        .iter()
451                        .find(|t| t.name() == call.name)
452                        .cloned();
453                    let call_id = call.id.clone();
454                    let call_name = call.name.clone();
455                    let input = call.input.clone();
456
457                    set.spawn(async move {
458                        let _permit = sem.acquire().await.unwrap();
459                        let result = match tool_arc {
460                            Some(t) => match t.call(input, &ctx).await {
461                                Ok(r) => r,
462                                Err(e) => ToolResult::error(format!("Tool error: {e}")),
463                            },
464                            None => ToolResult::error(format!("Unknown tool: {call_name}")),
465                        };
466                        (call_id, result)
467                    });
468                }
469
470                while let Some(join_result) = set.join_next().await {
471                    if let Ok((id, result)) = join_result {
472                        results.push(ContentBlock::ToolResult {
473                            tool_use_id: id,
474                            content: result.content,
475                            is_error: result.is_error,
476                        });
477                    }
478                }
479            }
480            ToolBatch::Serial(call) => {
481                let result = match registry.get(&call.name) {
482                    Some(tool) => match tool.call(call.input.clone(), ctx).await {
483                        Ok(r) => r,
484                        Err(e) => ToolResult::error(format!("Tool error: {e}")),
485                    },
486                    None => ToolResult::error(format!("Unknown tool: {}", call.name)),
487                };
488                results.push(ContentBlock::ToolResult {
489                    tool_use_id: call.id.clone(),
490                    content: result.content,
491                    is_error: result.is_error,
492                });
493            }
494        }
495    }
496
497    results
498}
499
500// ---------------------------------------------------------------------------
501// Tests
502// ---------------------------------------------------------------------------
503
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testutil::*;

    /// Builds a `ToolCall` with the given id, tool name, and JSON input.
    fn tool_call(id: &str, name: &str, input: Value) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            input,
        }
    }

    // --- ToolRegistry ------------------------------------------------------

    #[test]
    fn registry_register_and_get() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool::new("read_file", true, "file contents"));

        assert!(registry.get("read_file").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn registry_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool::new("read", true, "ok"));
        registry.register(MockTool::new("write", false, "ok"));

        // Definitions come back in registration order.
        let names: Vec<String> = registry
            .definitions()
            .into_iter()
            .map(|d| d.name)
            .collect();
        assert_eq!(names, ["read", "write"]);
    }

    #[test]
    fn registry_is_empty() {
        let empty = ToolRegistry::new();
        assert!(empty.is_empty());

        let mut populated = ToolRegistry::new();
        populated.register(MockTool::new("t", true, "ok"));
        assert!(!populated.is_empty());
    }

    #[test]
    fn registry_definitions_filtered_deferred() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool::new("always_visible", true, "ok"));
        registry.register(DeferredMockTool::new("deferred_tool"));

        // Nothing discovered yet: the deferred tool is listed by name only.
        let nothing_discovered = HashSet::new();
        let defs = registry.definitions_filtered(&nothing_discovered);
        assert_eq!(defs.len(), 2);
        let hidden = defs
            .iter()
            .find(|d| d.name == "deferred_tool")
            .unwrap();
        assert!(hidden.description.is_empty());
        assert_eq!(hidden.input_schema, serde_json::json!({}));

        // Once discovered, the deferred tool gets its full definition.
        let discovered: HashSet<String> =
            std::iter::once("deferred_tool".to_string()).collect();
        let defs = registry.definitions_filtered(&discovered);
        let revealed = defs
            .iter()
            .find(|d| d.name == "deferred_tool")
            .unwrap();
        assert!(!revealed.description.is_empty());
    }

    #[test]
    fn registry_has_deferred_tools() {
        let mut registry = ToolRegistry::new();

        registry.register(MockTool::new("t", true, "ok"));
        assert!(!registry.has_deferred_tools());

        registry.register(DeferredMockTool::new("d"));
        assert!(registry.has_deferred_tools());
    }

    #[test]
    fn registry_search_by_name() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool::new("read_file", true, "ok"));
        registry.register(MockTool::new("write_file", false, "ok"));

        let hits = registry.search("read");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].definition.name, "read_file");
    }

    #[test]
    fn registry_clone() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool::new("t", true, "ok"));

        let copy = registry.clone();
        assert_eq!(copy.definitions().len(), 1);
    }

    // --- execute_tool_calls ------------------------------------------------

    #[tokio::test]
    async fn execute_unknown_tool_returns_error() {
        let registry = ToolRegistry::new();
        let ctx = test_tool_context();
        let calls = vec![tool_call("c1", "nonexistent", serde_json::json!({}))];

        let results = execute_tool_calls(&calls, &registry, &ctx).await;

        assert_eq!(results.len(), 1);
        match &results[0] {
            ContentBlock::ToolResult {
                is_error, content, ..
            } => {
                assert!(*is_error);
                assert!(content.contains("Unknown tool"));
            }
            other => panic!("Expected ToolResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn execute_read_only_tools_concurrently() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool::new("read1", true, "result1"));
        registry.register(MockTool::new("read2", true, "result2"));
        let ctx = test_tool_context();

        let calls = vec![
            tool_call("c1", "read1", serde_json::json!({})),
            tool_call("c2", "read2", serde_json::json!({})),
        ];

        let results = execute_tool_calls(&calls, &registry, &ctx).await;
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn execute_serial_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool::new("write_file", false, "written"));
        let ctx = test_tool_context();

        let calls = vec![tool_call(
            "c1",
            "write_file",
            serde_json::json!({"path": "/tmp/test"}),
        )];

        let results = execute_tool_calls(&calls, &registry, &ctx).await;

        assert_eq!(results.len(), 1);
        match &results[0] {
            ContentBlock::ToolResult {
                content, is_error, ..
            } => {
                assert!(!*is_error);
                assert_eq!(content, "written");
            }
            other => panic!("Expected ToolResult, got {other:?}"),
        }
    }

    // --- ToolBuilder -------------------------------------------------------

    #[test]
    fn tool_builder_basic() {
        let schema =
            serde_json::json!({"type": "object", "properties": {"text": {"type": "string"}}});
        let tool = ToolBuilder::new("echo", "Echoes input")
            .schema(schema)
            .read_only(true)
            .handler(|input, _ctx| {
                Box::pin(async move {
                    let text = input["text"].as_str().unwrap_or("").to_string();
                    Ok(ToolResult::success(text))
                })
            })
            .build();

        assert_eq!(tool.name(), "echo");
        assert!(tool.is_read_only());
    }

    #[test]
    fn tool_builder_defer_and_hints() {
        let hints = vec!["analyze".into(), "inspect".into()];
        let tool = ToolBuilder::new("advanced", "Advanced tool")
            .should_defer(true)
            .search_hints(hints)
            .handler(|_input, _ctx| Box::pin(async { Ok(ToolResult::success("ok")) }))
            .build();

        assert!(tool.should_defer());
        assert_eq!(tool.search_hints().len(), 2);
    }

    #[test]
    #[should_panic(expected = "requires a handler")]
    fn tool_builder_panics_without_handler() {
        let _ = ToolBuilder::new("no_handler", "missing").build();
    }

    // --- ToolContext extensions -------------------------------------------

    #[test]
    fn tool_context_extensions_set_get() {
        let mut ctx = test_tool_context();
        ctx.set_extension(42u32);
        ctx.set_extension("hello".to_string());

        assert_eq!(ctx.get_extension::<u32>().copied(), Some(42));
        assert_eq!(
            ctx.get_extension::<String>().map(String::as_str),
            Some("hello")
        );
    }

    #[test]
    fn tool_context_extensions_missing() {
        let ctx = test_tool_context();
        assert!(ctx.get_extension::<u32>().is_none());
    }
}