1use std::any::{Any, TypeId};
2use std::collections::{HashMap, HashSet};
3use std::future::Future;
4use std::path::PathBuf;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
11use crate::error::Result;
12use crate::provider::types::ContentBlock;
13
/// Environment passed to [`Tool::call`].
///
/// `execute_tool_calls` clones this once per concurrently executed call, so the
/// registry and the extension values are held behind `Arc` and shared by clones.
pub struct ToolContext {
    /// Working directory associated with this context.
    pub working_directory: PathBuf,
    /// Registry attached via `with_registry`, if any.
    pub tool_registry: Option<Arc<ToolRegistry>>,
    /// Type-keyed extension values: at most one value per concrete type, managed
    /// through `set_extension` / `get_extension`.
    extensions: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
24
25impl ToolContext {
26 pub fn new(working_directory: PathBuf) -> Self {
27 Self {
28 working_directory,
29 tool_registry: None,
30 extensions: HashMap::new(),
31 }
32 }
33
34 pub fn with_registry(mut self, registry: Arc<ToolRegistry>) -> Self {
35 self.tool_registry = Some(registry);
36 self
37 }
38
39 pub fn set_extension<T: Any + Send + Sync + 'static>(&mut self, value: T) {
41 self.extensions.insert(TypeId::of::<T>(), Arc::new(value));
42 }
43
44 pub fn get_extension<T: Any + Send + Sync + 'static>(&self) -> Option<&T> {
46 self.extensions
47 .get(&TypeId::of::<T>())
48 .and_then(|arc| arc.downcast_ref::<T>())
49 }
50
51}
52
53impl Clone for ToolContext {
54 fn clone(&self) -> Self {
55 Self {
56 working_directory: self.working_directory.clone(),
57 tool_registry: self.tool_registry.clone(),
58 extensions: self.extensions.clone(),
59 }
60 }
61}
62
63impl std::fmt::Debug for ToolContext {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 f.debug_struct("ToolContext")
66 .field("working_directory", &self.working_directory)
67 .field("tool_registry", &self.tool_registry)
68 .field("extensions_count", &self.extensions.len())
69 .finish()
70 }
71}
72
/// Name, description, and input schema of a tool, as produced by
/// [`Tool::definition`] and the `ToolRegistry` listing methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    /// JSON value describing the accepted input. `ToolBuilder` defaults this to
    /// `{"type": "object", "properties": {}}`; deferred stubs use `{}`.
    pub input_schema: Value,
}
80
/// A request to invoke the tool called `name` with JSON `input`.
#[derive(Debug, Clone)]
pub struct ToolCall {
    /// Identifier echoed back as `tool_use_id` in the matching result block.
    pub id: String,
    /// Tool name, looked up in the `ToolRegistry`.
    pub name: String,
    /// Arguments handed to `Tool::call`.
    pub input: Value,
}
88
/// Outcome of a tool invocation: text content plus an error flag.
#[derive(Debug, Clone)]
pub struct ToolResult {
    /// Output text, or the error message when `is_error` is set.
    pub content: String,
    /// `true` for failures (see `ToolResult::error`), `false` for `ToolResult::success`.
    pub is_error: bool,
}
95
96impl ToolResult {
97 pub fn success(content: impl Into<String>) -> Self {
98 Self { content: content.into(), is_error: false }
99 }
100
101 pub fn error(content: impl Into<String>) -> Self {
102 Self { content: content.into(), is_error: true }
103 }
104}
105
/// One hit returned by `ToolRegistry::search`.
#[derive(Debug, Clone)]
pub struct ToolSearchResult {
    /// Full definition of the matching tool.
    pub definition: ToolDefinition,
    /// Relevance score; higher is better. `search` only returns positive scores.
    pub score: u32,
}
112
/// A named operation that can be invoked with JSON input.
///
/// Implementors must be `Send + Sync`: `ToolRegistry` stores them behind `Arc`,
/// and `execute_tool_calls` may call read-only tools from spawned Tokio tasks.
pub trait Tool: Send + Sync {
    /// Name used for registry lookup and for matching `ToolCall::name`.
    fn name(&self) -> &str;
    /// Human-readable description; also matched by `ToolRegistry::search`.
    fn description(&self) -> &str;
    /// JSON value describing the accepted input (conventionally a JSON Schema
    /// object).
    fn input_schema(&self) -> Value;

    /// Whether `execute_tool_calls` may run this tool concurrently with adjacent
    /// read-only calls. Defaults to `false`, meaning the call runs on its own.
    fn is_read_only(&self) -> bool {
        false
    }

    /// Whether `ToolRegistry::definitions_filtered` should reduce this tool to a
    /// name-only stub until its name is in the `discovered` set. Defaults to `false`.
    fn should_defer(&self) -> bool {
        false
    }

    /// Extra keywords matched by `ToolRegistry::search`; each hint containing the
    /// query raises the tool's score. Defaults to no hints.
    fn search_hints(&self) -> Vec<String> {
        Vec::new()
    }

    /// Runs the tool on `input` with access to `ctx`.
    ///
    /// The returned future may borrow both `self` and `ctx`. `execute_tool_calls`
    /// converts an `Err` into an error result prefixed with `"Tool error: "`.
    fn call<'a>(
        &'a self,
        input: Value,
        ctx: &'a ToolContext,
    ) -> Pin<Box<dyn Future<Output = Result<ToolResult>> + Send + 'a>>;

    /// Bundles `name`, `description`, and `input_schema` into a [`ToolDefinition`].
    fn definition(&self) -> ToolDefinition {
        ToolDefinition {
            name: self.name().to_string(),
            description: self.description().to_string(),
            input_schema: self.input_schema(),
        }
    }
}
149
/// A provider of several tools at once. Tools are returned as boxed trait
/// objects, so one set can contain different concrete tool types.
pub trait Toolset: Send + Sync {
    /// Returns the tools in this set.
    fn tools(&self) -> Vec<Box<dyn Tool>>;
}
158
/// Ordered collection of tools, looked up by name.
///
/// Tools are shared via `Arc`, so cloning the registry copies the `Arc`
/// handles, not the tools themselves.
pub struct ToolRegistry {
    /// Registered tools, in registration order. Duplicate names are not rejected;
    /// name lookups return the first match.
    pub(crate) tools: Vec<Arc<dyn Tool>>,
}
167
168impl std::fmt::Debug for ToolRegistry {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 let names: Vec<&str> = self.tools.iter().map(|t| t.name()).collect();
171 f.debug_struct("ToolRegistry")
172 .field("tools", &names)
173 .finish()
174 }
175}
176
177impl ToolRegistry {
178 pub fn new() -> Self {
179 Self { tools: Vec::new() }
180 }
181
182 pub fn register(&mut self, tool: impl Tool + 'static) {
183 self.tools.push(Arc::new(tool));
184 }
185
186 pub fn get(&self, name: &str) -> Option<&dyn Tool> {
187 self.tools
188 .iter()
189 .find(|t| t.name() == name)
190 .map(|t| t.as_ref() as &dyn Tool)
191 }
192
193 pub fn definitions(&self) -> Vec<ToolDefinition> {
194 self.tools.iter().map(|t| t.definition()).collect()
195 }
196
197 pub fn definitions_filtered(&self, discovered: &HashSet<String>) -> Vec<ToolDefinition> {
200 self.tools
201 .iter()
202 .map(|t| {
203 if t.should_defer() && !discovered.contains(t.name()) {
204 ToolDefinition {
205 name: t.name().to_string(),
206 description: String::new(),
207 input_schema: serde_json::json!({}),
208 }
209 } else {
210 t.definition()
211 }
212 })
213 .collect()
214 }
215
216 pub fn search(&self, query: &str) -> Vec<ToolSearchResult> {
218 let query_lower = query.to_lowercase();
219 let mut results: Vec<ToolSearchResult> = self
220 .tools
221 .iter()
222 .filter_map(|t| {
223 let mut score = 0u32;
224 let name = t.name().to_lowercase();
225 let desc = t.description().to_lowercase();
226
227 if name == query_lower {
229 score += 100;
230 } else if name.contains(&query_lower) {
231 score += 50;
232 }
233
234 if desc.contains(&query_lower) {
236 score += 25;
237 }
238
239 for hint in t.search_hints() {
241 if hint.to_lowercase().contains(&query_lower) {
242 score += 30;
243 }
244 }
245
246 if score > 0 {
247 Some(ToolSearchResult {
248 definition: t.definition(),
249 score,
250 })
251 } else {
252 None
253 }
254 })
255 .collect();
256
257 results.sort_by(|a, b| b.score.cmp(&a.score));
258 results
259 }
260
261 pub fn has_deferred_tools(&self) -> bool {
262 self.tools.iter().any(|t| t.should_defer())
263 }
264
265 pub fn is_empty(&self) -> bool {
266 self.tools.is_empty()
267 }
268}
269
270impl Clone for ToolRegistry {
271 fn clone(&self) -> Self {
272 Self {
273 tools: self.tools.clone(),
274 }
275 }
276}
277
/// Type-erased handler stored by `ToolBuilder`.
///
/// It receives the JSON input and a borrowed [`ToolContext`], and returns a
/// boxed `Send` future. That future may borrow the context, because the
/// return's `'_` is tied to the `&ToolContext` argument.
type ToolHandler = Box<
    dyn Fn(Value, &ToolContext) -> Pin<Box<dyn Future<Output = Result<ToolResult>> + Send + '_>>
        + Send
        + Sync,
>;
287
/// [`Tool`] implementation produced by `ToolBuilder::build`.
///
/// Each trait method returns the matching stored field, and `call` delegates to
/// `handler`.
struct BuiltTool {
    name: String,
    description: String,
    schema: Value,
    read_only: bool,
    defer: bool,
    hints: Vec<String>,
    handler: ToolHandler,
}
297
298impl Tool for BuiltTool {
299 fn name(&self) -> &str {
300 &self.name
301 }
302 fn description(&self) -> &str {
303 &self.description
304 }
305 fn input_schema(&self) -> Value {
306 self.schema.clone()
307 }
308 fn is_read_only(&self) -> bool {
309 self.read_only
310 }
311 fn should_defer(&self) -> bool {
312 self.defer
313 }
314 fn search_hints(&self) -> Vec<String> {
315 self.hints.clone()
316 }
317 fn call<'a>(
318 &'a self,
319 input: Value,
320 ctx: &'a ToolContext,
321 ) -> Pin<Box<dyn Future<Output = Result<ToolResult>> + Send + 'a>> {
322 (self.handler)(input, ctx)
323 }
324}
325
/// Builder for closure-backed tools.
///
/// Defaults: an empty object schema, not read-only, not deferred, and no
/// search hints. A handler must be supplied before `build` is called.
pub struct ToolBuilder {
    name: String,
    description: String,
    schema: Value,
    read_only: bool,
    defer: bool,
    hints: Vec<String>,
    // `None` until `handler` is called; `build` panics if it is still `None`.
    handler: Option<ToolHandler>,
}
335
336impl ToolBuilder {
337 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
338 Self {
339 name: name.into(),
340 description: description.into(),
341 schema: serde_json::json!({"type": "object", "properties": {}}),
342 read_only: false,
343 defer: false,
344 hints: Vec::new(),
345 handler: None,
346 }
347 }
348
349 pub fn schema(mut self, schema: Value) -> Self {
350 self.schema = schema;
351 self
352 }
353
354 pub fn read_only(mut self, read_only: bool) -> Self {
355 self.read_only = read_only;
356 self
357 }
358
359 pub fn should_defer(mut self, defer: bool) -> Self {
360 self.defer = defer;
361 self
362 }
363
364 pub fn search_hints(mut self, hints: Vec<String>) -> Self {
365 self.hints = hints;
366 self
367 }
368
369 pub fn handler<F>(mut self, f: F) -> Self
370 where
371 F: Fn(Value, &ToolContext) -> Pin<Box<dyn Future<Output = Result<ToolResult>> + Send + '_>>
372 + Send
373 + Sync
374 + 'static,
375 {
376 self.handler = Some(Box::new(f));
377 self
378 }
379
380 pub fn build(self) -> impl Tool {
381 let handler = self
382 .handler
383 .expect("ToolBuilder requires a handler before build()");
384 BuiltTool {
385 name: self.name,
386 description: self.description,
387 schema: self.schema,
388 read_only: self.read_only,
389 defer: self.defer,
390 hints: self.hints,
391 handler,
392 }
393 }
394}
395
/// A unit of work produced by `partition_tool_calls`.
enum ToolBatch {
    /// A run of consecutive calls to read-only tools that may execute concurrently.
    Concurrent(Vec<ToolCall>),
    /// A single call that executes on its own, in order.
    Serial(ToolCall),
}
404
405fn partition_tool_calls(calls: &[ToolCall], registry: &ToolRegistry) -> Vec<ToolBatch> {
406 let mut batches: Vec<ToolBatch> = Vec::new();
407 let mut concurrent_batch: Vec<ToolCall> = Vec::new();
408
409 for call in calls {
410 let is_read_only = registry
411 .get(&call.name)
412 .map_or(false, |t| t.is_read_only());
413
414 if is_read_only {
415 concurrent_batch.push(call.clone());
416 } else {
417 if !concurrent_batch.is_empty() {
418 batches.push(ToolBatch::Concurrent(std::mem::take(&mut concurrent_batch)));
419 }
420 batches.push(ToolBatch::Serial(call.clone()));
421 }
422 }
423
424 if !concurrent_batch.is_empty() {
425 batches.push(ToolBatch::Concurrent(concurrent_batch));
426 }
427
428 batches
429}
430
431pub async fn execute_tool_calls(
433 calls: &[ToolCall],
434 registry: &ToolRegistry,
435 ctx: &ToolContext,
436) -> Vec<ContentBlock> {
437 let batches = partition_tool_calls(calls, registry);
438 let mut results: Vec<ContentBlock> = Vec::new();
439 let semaphore = Arc::new(tokio::sync::Semaphore::new(10));
440
441 for batch in batches {
442 match batch {
443 ToolBatch::Concurrent(calls) => {
444 let mut set = tokio::task::JoinSet::new();
445 for call in calls {
446 let sem = semaphore.clone();
447 let ctx = ctx.clone();
448 let tool_arc = registry
449 .tools
450 .iter()
451 .find(|t| t.name() == call.name)
452 .cloned();
453 let call_id = call.id.clone();
454 let call_name = call.name.clone();
455 let input = call.input.clone();
456
457 set.spawn(async move {
458 let _permit = sem.acquire().await.unwrap();
459 let result = match tool_arc {
460 Some(t) => match t.call(input, &ctx).await {
461 Ok(r) => r,
462 Err(e) => ToolResult::error(format!("Tool error: {e}")),
463 },
464 None => ToolResult::error(format!("Unknown tool: {call_name}")),
465 };
466 (call_id, result)
467 });
468 }
469
470 while let Some(join_result) = set.join_next().await {
471 if let Ok((id, result)) = join_result {
472 results.push(ContentBlock::ToolResult {
473 tool_use_id: id,
474 content: result.content,
475 is_error: result.is_error,
476 });
477 }
478 }
479 }
480 ToolBatch::Serial(call) => {
481 let result = match registry.get(&call.name) {
482 Some(tool) => match tool.call(call.input.clone(), ctx).await {
483 Ok(r) => r,
484 Err(e) => ToolResult::error(format!("Tool error: {e}")),
485 },
486 None => ToolResult::error(format!("Unknown tool: {}", call.name)),
487 };
488 results.push(ContentBlock::ToolResult {
489 tool_use_id: call.id.clone(),
490 content: result.content,
491 is_error: result.is_error,
492 });
493 }
494 }
495 }
496
497 results
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testutil::*;

    /// Builds a closure-backed tool whose handler always succeeds with "ok".
    fn noop_tool(name: &'static str, description: &'static str, hints: Vec<String>) -> impl Tool {
        ToolBuilder::new(name, description)
            .search_hints(hints)
            .handler(|_input, _ctx| Box::pin(async { Ok(ToolResult::success("ok")) }))
            .build()
    }

    #[test]
    fn registry_register_and_get() {
        let mut registry = ToolRegistry::new();
        let tool = MockTool::new("read_file", true, "file contents");
        registry.register(tool);

        assert!(registry.get("read_file").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn registry_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool::new("read", true, "ok"));
        registry.register(MockTool::new("write", false, "ok"));

        let defs = registry.definitions();
        assert_eq!(defs.len(), 2);
        assert_eq!(defs[0].name, "read");
        assert_eq!(defs[1].name, "write");
    }

    #[test]
    fn registry_is_empty() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());

        let mut registry = ToolRegistry::new();
        registry.register(MockTool::new("t", true, "ok"));
        assert!(!registry.is_empty());
    }

    #[test]
    fn registry_definitions_filtered_deferred() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool::new("always_visible", true, "ok"));
        registry.register(DeferredMockTool::new("deferred_tool"));

        // Nothing discovered yet: the deferred tool is reduced to a stub.
        let discovered = HashSet::new();
        let defs = registry.definitions_filtered(&discovered);
        assert_eq!(defs.len(), 2);
        let deferred = defs.iter().find(|d| d.name == "deferred_tool").unwrap();
        assert!(deferred.description.is_empty());
        assert_eq!(deferred.input_schema, serde_json::json!({}));

        // Once discovered, the full definition is returned.
        let mut discovered = HashSet::new();
        discovered.insert("deferred_tool".to_string());
        let defs = registry.definitions_filtered(&discovered);
        let deferred = defs.iter().find(|d| d.name == "deferred_tool").unwrap();
        assert!(!deferred.description.is_empty());
    }

    #[test]
    fn registry_has_deferred_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool::new("t", true, "ok"));
        assert!(!registry.has_deferred_tools());

        registry.register(DeferredMockTool::new("d"));
        assert!(registry.has_deferred_tools());
    }

    #[test]
    fn registry_search_by_name() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool::new("read_file", true, "ok"));
        registry.register(MockTool::new("write_file", false, "ok"));

        let results = registry.search("read");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].definition.name, "read_file");
    }

    #[test]
    fn registry_search_ranks_exact_name_above_partial_and_hints() {
        let mut registry = ToolRegistry::new();
        registry.register(noop_tool("grep_files", "Search file contents", vec![]));
        registry.register(noop_tool("grep", "Search text", vec![]));
        registry.register(noop_tool("find", "Locate paths", vec!["grep-like".to_string()]));
        registry.register(noop_tool("write", "Write a file", vec![]));

        // Matching is case-insensitive.
        let results = registry.search("GREP");
        let names: Vec<&str> = results.iter().map(|r| r.definition.name.as_str()).collect();
        assert_eq!(names, vec!["grep", "grep_files", "find"]);
        assert_eq!(results[0].score, 100);
        assert_eq!(results[1].score, 50);
        assert_eq!(results[2].score, 30);
    }

    #[test]
    fn registry_clone() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool::new("t", true, "ok"));
        let cloned = registry.clone();
        assert_eq!(cloned.definitions().len(), 1);
    }

    #[tokio::test]
    async fn execute_unknown_tool_returns_error() {
        let registry = ToolRegistry::new();
        let ctx = test_tool_context();
        let calls = vec![ToolCall {
            id: "c1".into(),
            name: "nonexistent".into(),
            input: serde_json::json!({}),
        }];

        let results = execute_tool_calls(&calls, &registry, &ctx).await;
        assert_eq!(results.len(), 1);
        match &results[0] {
            ContentBlock::ToolResult {
                is_error, content, ..
            } => {
                assert!(is_error);
                assert!(content.contains("Unknown tool"));
            }
            other => panic!("Expected ToolResult, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn execute_read_only_tools_concurrently() {
        let mut registry = ToolRegistry::new();
        registry.register(MockTool::new("read1", true, "result1"));
        registry.register(MockTool::new("read2", true, "result2"));
        let ctx = test_tool_context();

        let calls = vec![
            ToolCall {
                id: "c1".into(),
                name: "read1".into(),
                input: serde_json::json!({}),
            },
            ToolCall {
                id: "c2".into(),
                name: "read2".into(),
                input: serde_json::json!({}),
            },
        ];

        let results = execute_tool_calls(&calls, &registry, &ctx).await;
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn execute_serial_tool() {
        let mut registry = ToolRegistry::new();
        let tool = MockTool::new("write_file", false, "written");
        registry.register(tool);
        let ctx = test_tool_context();

        let calls = vec![ToolCall {
            id: "c1".into(),
            name: "write_file".into(),
            input: serde_json::json!({"path": "/tmp/test"}),
        }];

        let results = execute_tool_calls(&calls, &registry, &ctx).await;
        assert_eq!(results.len(), 1);
        match &results[0] {
            ContentBlock::ToolResult {
                content, is_error, ..
            } => {
                assert!(!is_error);
                assert_eq!(content, "written");
            }
            other => panic!("Expected ToolResult, got {other:?}"),
        }
    }

    #[test]
    fn tool_builder_basic() {
        let tool = ToolBuilder::new("echo", "Echoes input")
            .schema(serde_json::json!({"type": "object", "properties": {"text": {"type": "string"}}}))
            .read_only(true)
            .handler(|input, _ctx| {
                Box::pin(async move {
                    let text = input["text"].as_str().unwrap_or("").to_string();
                    Ok(ToolResult::success(text))
                })
            })
            .build();

        assert_eq!(tool.name(), "echo");
        assert!(tool.is_read_only());
    }

    #[test]
    fn tool_builder_defer_and_hints() {
        let tool = ToolBuilder::new("advanced", "Advanced tool")
            .should_defer(true)
            .search_hints(vec!["analyze".into(), "inspect".into()])
            .handler(|_input, _ctx| {
                Box::pin(async { Ok(ToolResult::success("ok")) })
            })
            .build();

        assert!(tool.should_defer());
        assert_eq!(tool.search_hints().len(), 2);
    }

    #[test]
    #[should_panic(expected = "requires a handler")]
    fn tool_builder_panics_without_handler() {
        let _ = ToolBuilder::new("no_handler", "missing").build();
    }

    #[test]
    fn tool_context_extensions_set_get() {
        let mut ctx = test_tool_context();
        ctx.set_extension(42u32);
        ctx.set_extension("hello".to_string());

        assert_eq!(ctx.get_extension::<u32>(), Some(&42));
        assert_eq!(ctx.get_extension::<String>(), Some(&"hello".to_string()));
    }

    #[test]
    fn tool_context_extensions_missing() {
        let ctx = test_tool_context();
        assert!(ctx.get_extension::<u32>().is_none());
    }
}