model_context_protocol/tool.rs
1//! Tool traits and types for MCP servers.
2//!
3//! This module provides the core abstractions for defining and executing MCP tools.
4
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use serde_json::Value;
11
12use crate::protocol::{McpToolDefinition, ToolContent};
13
/// Heap-allocated, pinned, type-erased future that is `Send`.
///
/// `'fut` bounds anything the future borrows; `Out` is the value it resolves to.
pub type BoxFuture<'fut, Out> = Pin<Box<dyn Future<Output = Out> + Send + 'fut>>;
16
17/// Result type for tool execution - returns content or an error message.
18pub type ToolCallResult = Result<Vec<ToolContent>, String>;
19
20// =============================================================================
21// Tool Auto-Discovery via Inventory
22// =============================================================================
23
24/// A factory function that creates a tool instance.
25pub type ToolFactory = fn() -> DynTool;
26
27/// Entry for auto-discovered tools registered via `#[mcp_tool]`.
28///
29/// This struct is used internally by the `inventory` crate to collect
30/// all tools defined with the `#[mcp_tool]` attribute at link time.
31pub struct ToolEntry {
32 /// Factory function to create the tool.
33 pub factory: ToolFactory,
34 /// The group this tool belongs to (if any).
35 pub group: Option<&'static str>,
36}
37
38impl ToolEntry {
39 /// Creates a new tool entry.
40 pub const fn new(factory: ToolFactory, group: Option<&'static str>) -> Self {
41 Self { factory, group }
42 }
43}
44
45// Register ToolEntry with inventory for compile-time collection
46inventory::collect!(ToolEntry);
47
48/// Returns all auto-discovered tools.
49///
50/// This collects all tools registered with `#[mcp_tool]` across the crate.
51pub fn all_tools() -> Vec<DynTool> {
52 inventory::iter::<ToolEntry>()
53 .map(|entry| (entry.factory)())
54 .collect()
55}
56
57/// Returns auto-discovered tools filtered by group.
58///
59/// Only returns tools that have the specified group.
60pub fn tools_in_group(group: &str) -> Vec<DynTool> {
61 inventory::iter::<ToolEntry>()
62 .filter(|entry| entry.group == Some(group))
63 .map(|entry| (entry.factory)())
64 .collect()
65}
66
67/// Trait for implementing MCP tools.
68///
69/// # Example
70///
71/// ```ignore
72/// use mcp::{McpTool, ToolCallResult, McpToolDefinition};
73/// use serde_json::Value;
74///
75/// struct CalculatorTool;
76///
77/// impl McpTool for CalculatorTool {
78/// fn definition(&self) -> McpToolDefinition {
79/// McpToolDefinition {
80/// name: "add".to_string(),
81/// description: Some("Add two numbers".to_string()),
82/// group: None,
83/// input_schema: serde_json::json!({
84/// "type": "object",
85/// "properties": {
86/// "a": { "type": "number" },
87/// "b": { "type": "number" }
88/// },
89/// "required": ["a", "b"]
90/// }),
91/// }
92/// }
93///
94/// fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult> {
95/// Box::pin(async move {
96/// let a = args["a"].as_f64().unwrap_or(0.0);
97/// let b = args["b"].as_f64().unwrap_or(0.0);
98/// Ok(vec![ToolContent::text(format!("{}", a + b))])
99/// })
100/// }
101/// }
102/// ```
103pub trait McpTool: Send + Sync {
104 /// Returns the tool definition (name, description, input schema).
105 fn definition(&self) -> McpToolDefinition;
106
107 /// Executes the tool with the given arguments.
108 fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult>;
109}
110
111/// A type-erased tool wrapper.
112pub type DynTool = Arc<dyn McpTool>;
113
114/// Trait for types that can provide multiple tools.
115///
116/// Implement this trait to group related tools together.
117///
118/// # Example
119///
120/// ```ignore
121/// use mcp::{ToolProvider, McpTool};
122///
123/// struct MathTools;
124///
125/// impl ToolProvider for MathTools {
126/// fn tools(&self) -> Vec<Arc<dyn McpTool>> {
127/// vec![
128/// Arc::new(AddTool),
129/// Arc::new(SubtractTool),
130/// Arc::new(MultiplyTool),
131/// ]
132/// }
133/// }
134/// ```
135pub trait ToolProvider: Send + Sync {
136 /// Returns a list of tools provided by this provider.
137 fn tools(&self) -> Vec<DynTool>;
138}
139
140/// A simple function-based tool.
141///
142/// This allows creating tools from closures without implementing the `McpTool` trait.
143pub struct FnTool<F>
144where
145 F: Fn(Value) -> BoxFuture<'static, ToolCallResult> + Send + Sync,
146{
147 definition: McpToolDefinition,
148 handler: F,
149}
150
151impl<F> FnTool<F>
152where
153 F: Fn(Value) -> BoxFuture<'static, ToolCallResult> + Send + Sync,
154{
155 /// Creates a new function-based tool.
156 pub fn new(definition: McpToolDefinition, handler: F) -> Self {
157 Self {
158 definition,
159 handler,
160 }
161 }
162}
163
164impl<F> McpTool for FnTool<F>
165where
166 F: Fn(Value) -> BoxFuture<'static, ToolCallResult> + Send + Sync,
167{
168 fn definition(&self) -> McpToolDefinition {
169 self.definition.clone()
170 }
171
172 fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult> {
173 (self.handler)(args)
174 }
175}
176
177/// Registry for managing tools.
178#[derive(Default)]
179pub struct ToolRegistry {
180 tools: HashMap<String, DynTool>,
181 /// Cached definitions for faster access
182 definitions_cache: parking_lot::RwLock<Option<Vec<McpToolDefinition>>>,
183}
184
185impl ToolRegistry {
186 /// Creates a new empty tool registry.
187 pub fn new() -> Self {
188 Self {
189 tools: HashMap::new(),
190 definitions_cache: parking_lot::RwLock::new(None),
191 }
192 }
193
194 /// Registers a tool.
195 pub fn register(&mut self, tool: DynTool) {
196 let name = tool.definition().name.clone();
197 self.tools.insert(name, tool);
198 // Invalidate cache when tools change
199 *self.definitions_cache.write() = None;
200 }
201
202 /// Registers multiple tools from a provider.
203 pub fn register_provider<P: ToolProvider>(&mut self, provider: P) {
204 for tool in provider.tools() {
205 let name = tool.definition().name.clone();
206 self.tools.insert(name, tool);
207 }
208 // Invalidate cache when tools change
209 *self.definitions_cache.write() = None;
210 }
211
212 /// Gets a tool by name.
213 pub fn get(&self, name: &str) -> Option<&DynTool> {
214 self.tools.get(name)
215 }
216
217 /// Returns all tool definitions (cached).
218 ///
219 /// Uses an Arc-wrapped cache to minimize cloning overhead.
220 /// Returns a clone of the Arc, so iterating is efficient.
221 pub fn definitions(&self) -> Vec<McpToolDefinition> {
222 // Try to return cached definitions
223 {
224 let cache = self.definitions_cache.read();
225 if let Some(ref defs) = *cache {
226 return defs.clone();
227 }
228 }
229
230 // Build and cache definitions
231 let defs: Vec<McpToolDefinition> = self.tools.values().map(|t| t.definition()).collect();
232 *self.definitions_cache.write() = Some(defs.clone());
233 defs
234 }
235
236 /// Returns an iterator over tool definitions without cloning.
237 ///
238 /// More efficient than `definitions()` when you only need to iterate.
239 pub fn definitions_iter(&self) -> impl Iterator<Item = McpToolDefinition> + '_ {
240 self.tools.values().map(|t| t.definition())
241 }
242
243 /// Returns all tool definitions without caching (for cases where fresh data is needed).
244 pub fn definitions_uncached(&self) -> Vec<McpToolDefinition> {
245 self.tools.values().map(|t| t.definition()).collect()
246 }
247
248 /// Invalidates the definitions cache.
249 pub fn invalidate_cache(&self) {
250 *self.definitions_cache.write() = None;
251 }
252
253 /// Returns the number of registered tools.
254 pub fn len(&self) -> usize {
255 self.tools.len()
256 }
257
258 /// Returns true if no tools are registered.
259 pub fn is_empty(&self) -> bool {
260 self.tools.is_empty()
261 }
262
263 /// Calls a tool by name with the given arguments.
264 pub async fn call(&self, name: &str, args: Value) -> ToolCallResult {
265 match self.get(name) {
266 Some(tool) => tool.call(args).await,
267 None => Err(format!("Unknown tool: {}", name)),
268 }
269 }
270}
271
/// Helper macro for creating tools from async functions.
///
/// Expands to an `FnTool` whose definition has the given name and description,
/// no group, and an input schema written as a `serde_json::json!` literal. The
/// handler may be any `Fn(serde_json::Value) -> impl Future<Output = ToolCallResult>`
/// (closure, closure variable or `async fn`); it is evaluated once, and an
/// untyped closure parameter is inferred as `serde_json::Value`. The expansion
/// names `serde_json` directly, so the invoking crate must depend on it.
///
/// # Example
///
/// ```ignore
/// use mcp::fn_tool;
/// use mcp::protocol::ToolContent;
///
/// let add_tool = fn_tool!(
///     "add",
///     "Add two numbers",
///     {
///         "type": "object",
///         "properties": {
///             "a": { "type": "number" },
///             "b": { "type": "number" }
///         }
///     },
///     |args| async move {
///         let a = args["a"].as_f64().unwrap_or(0.0);
///         let b = args["b"].as_f64().unwrap_or(0.0);
///         Ok(vec![ToolContent::text(format!("{}", a + b))])
///     }
/// );
/// ```
#[macro_export]
macro_rules! fn_tool {
    ($name:expr, $desc:expr, $schema:tt, $handler:expr) => {{
        use $crate::protocol::McpToolDefinition;
        use $crate::tool::FnTool;

        // Passing the handler through a generic function whose bound names the
        // argument type lets closures written as `|args| async move { .. }`
        // infer `args: serde_json::Value`. Calling an untyped closure literal
        // directly fails with "type annotations needed" as soon as its body
        // indexes `args` or calls a method on it.
        fn __fn_tool_handler<H, Fut>(handler: H) -> H
        where
            H: Fn(serde_json::Value) -> Fut,
            Fut: ::core::future::Future<Output = $crate::tool::ToolCallResult>,
        {
            handler
        }

        let definition = McpToolDefinition {
            name: $name.to_string(),
            description: Some($desc.to_string()),
            group: None,
            input_schema: serde_json::json!($schema),
        };

        // Evaluated exactly once, then moved into the tool's closure.
        let handler = __fn_tool_handler($handler);

        FnTool::new(definition, move |args| Box::pin(handler(args)))
    }};
}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ToolContent;

    /// Minimal tool with a configurable name whose call always yields "ok".
    struct TestTool {
        name: String,
    }

    impl McpTool for TestTool {
        fn definition(&self) -> McpToolDefinition {
            McpToolDefinition::new(&self.name)
                .with_description("Test tool")
                .with_schema(serde_json::json!({"type": "object"}))
        }

        fn call<'a>(&'a self, _args: Value) -> BoxFuture<'a, ToolCallResult> {
            Box::pin(async move { Ok(vec![ToolContent::text("ok")]) })
        }
    }

    /// Builds a `TestTool` already erased to `DynTool`.
    fn test_tool(name: &str) -> DynTool {
        Arc::new(TestTool {
            name: name.to_string(),
        })
    }

    /// Provider that supplies two tools, "a" and "b".
    struct PairProvider;

    impl ToolProvider for PairProvider {
        fn tools(&self) -> Vec<DynTool> {
            vec![test_tool("a"), test_tool("b")]
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(TestTool {
            name: "test".to_string(),
        }));

        assert_eq!(registry.len(), 1);
        assert!(registry.get("test").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_registry_register_replaces_same_name() {
        let mut registry = ToolRegistry::new();
        assert!(registry.is_empty());

        registry.register(test_tool("dup"));
        registry.register(test_tool("dup"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(TestTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(TestTool {
            name: "tool2".to_string(),
        }));

        let defs = registry.definitions();
        assert_eq!(defs.len(), 2);
    }

    #[test]
    fn test_definitions_cache_invalidated_on_register() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("tool1"));
        // Populate the cache before registering another tool.
        assert_eq!(registry.definitions().len(), 1);

        registry.register(test_tool("tool2"));
        let mut names: Vec<String> = registry
            .definitions()
            .iter()
            .map(|d| d.name.clone())
            .collect();
        names.sort();
        assert_eq!(names, vec!["tool1".to_string(), "tool2".to_string()]);
    }

    #[test]
    fn test_register_provider() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("a"));
        // Warm the cache so the provider registration must invalidate it.
        assert_eq!(registry.definitions().len(), 1);

        registry.register_provider(PairProvider);
        assert_eq!(registry.len(), 2);
        assert!(registry.get("b").is_some());
        assert_eq!(registry.definitions().len(), 2);
    }

    #[tokio::test]
    async fn test_registry_call() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(TestTool {
            name: "test".to_string(),
        }));

        let result = registry.call("test", serde_json::json!({})).await;
        assert!(result.is_ok());

        let result = registry.call("unknown", serde_json::json!({})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_registry_call_unknown_reports_name() {
        let registry = ToolRegistry::new();
        let result = registry.call("missing", serde_json::json!({})).await;
        assert_eq!(result.err().as_deref(), Some("Unknown tool: missing"));
    }

    #[tokio::test]
    async fn test_fn_tool_via_registry() {
        let tool = FnTool::new(
            McpToolDefinition::new("echo").with_description("Echo tool"),
            |args: Value| -> BoxFuture<'static, ToolCallResult> {
                Box::pin(async move {
                    let rendered = args.to_string();
                    Ok(vec![ToolContent::text(rendered.as_str())])
                })
            },
        );
        assert_eq!(tool.definition().name, "echo");

        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(tool));
        match registry.call("echo", serde_json::json!({"msg": "hi"})).await {
            Ok(contents) => assert_eq!(contents.len(), 1),
            Err(e) => panic!("unexpected error: {}", e),
        }
    }
}
373}