Skip to main content

model_context_protocol/
tool.rs

1//! Tool traits and types for MCP servers.
2//!
3//! This module provides the core abstractions for defining and executing MCP tools.
4
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use serde_json::Value;
11
12use crate::protocol::{McpToolDefinition, ToolContent};
13
/// Pinned, heap-allocated, `Send` future that produces a `T` and may borrow
/// data for at most the lifetime `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
16
17/// Result type for tool execution - returns content or an error message.
18pub type ToolCallResult = Result<Vec<ToolContent>, String>;
19
20// =============================================================================
21// Tool Auto-Discovery via Inventory
22// =============================================================================
23
24/// A factory function that creates a tool instance.
25pub type ToolFactory = fn() -> DynTool;
26
27/// Entry for auto-discovered tools registered via `#[mcp_tool]`.
28///
29/// This struct is used internally by the `inventory` crate to collect
30/// all tools defined with the `#[mcp_tool]` attribute at link time.
31pub struct ToolEntry {
32    /// Factory function to create the tool.
33    pub factory: ToolFactory,
34    /// The group this tool belongs to (if any).
35    pub group: Option<&'static str>,
36}
37
38impl ToolEntry {
39    /// Creates a new tool entry.
40    pub const fn new(factory: ToolFactory, group: Option<&'static str>) -> Self {
41        Self { factory, group }
42    }
43}
44
45// Register ToolEntry with inventory for compile-time collection
46inventory::collect!(ToolEntry);
47
48/// Returns all auto-discovered tools.
49///
50/// This collects all tools registered with `#[mcp_tool]` across the crate.
51pub fn all_tools() -> Vec<DynTool> {
52    inventory::iter::<ToolEntry>()
53        .map(|entry| (entry.factory)())
54        .collect()
55}
56
57/// Returns auto-discovered tools filtered by group.
58///
59/// Only returns tools that have the specified group.
60pub fn tools_in_group(group: &str) -> Vec<DynTool> {
61    inventory::iter::<ToolEntry>()
62        .filter(|entry| entry.group == Some(group))
63        .map(|entry| (entry.factory)())
64        .collect()
65}
66
67/// Trait for implementing MCP tools.
68///
69/// # Example
70///
71/// ```ignore
72/// use mcp::{McpTool, ToolCallResult, McpToolDefinition};
73/// use serde_json::Value;
74///
75/// struct CalculatorTool;
76///
77/// impl McpTool for CalculatorTool {
78///     fn definition(&self) -> McpToolDefinition {
79///         McpToolDefinition {
80///             name: "add".to_string(),
81///             description: Some("Add two numbers".to_string()),
82///             group: None,
83///             input_schema: serde_json::json!({
84///                 "type": "object",
85///                 "properties": {
86///                     "a": { "type": "number" },
87///                     "b": { "type": "number" }
88///                 },
89///                 "required": ["a", "b"]
90///             }),
91///         }
92///     }
93///
94///     fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult> {
95///         Box::pin(async move {
96///             let a = args["a"].as_f64().unwrap_or(0.0);
97///             let b = args["b"].as_f64().unwrap_or(0.0);
98///             Ok(vec![ToolContent::text(format!("{}", a + b))])
99///         })
100///     }
101/// }
102/// ```
103pub trait McpTool: Send + Sync {
104    /// Returns the tool definition (name, description, input schema).
105    fn definition(&self) -> McpToolDefinition;
106
107    /// Executes the tool with the given arguments.
108    fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult>;
109}
110
111/// A type-erased tool wrapper.
112pub type DynTool = Arc<dyn McpTool>;
113
114/// Trait for types that can provide multiple tools.
115///
116/// Implement this trait to group related tools together.
117///
118/// # Example
119///
120/// ```ignore
121/// use mcp::{ToolProvider, McpTool};
122///
123/// struct MathTools;
124///
125/// impl ToolProvider for MathTools {
126///     fn tools(&self) -> Vec<Arc<dyn McpTool>> {
127///         vec![
128///             Arc::new(AddTool),
129///             Arc::new(SubtractTool),
130///             Arc::new(MultiplyTool),
131///         ]
132///     }
133/// }
134/// ```
135pub trait ToolProvider: Send + Sync {
136    /// Returns a list of tools provided by this provider.
137    fn tools(&self) -> Vec<DynTool>;
138}
139
140/// A simple function-based tool.
141///
142/// This allows creating tools from closures without implementing the `McpTool` trait.
143pub struct FnTool<F>
144where
145    F: Fn(Value) -> BoxFuture<'static, ToolCallResult> + Send + Sync,
146{
147    definition: McpToolDefinition,
148    handler: F,
149}
150
151impl<F> FnTool<F>
152where
153    F: Fn(Value) -> BoxFuture<'static, ToolCallResult> + Send + Sync,
154{
155    /// Creates a new function-based tool.
156    pub fn new(definition: McpToolDefinition, handler: F) -> Self {
157        Self {
158            definition,
159            handler,
160        }
161    }
162}
163
164impl<F> McpTool for FnTool<F>
165where
166    F: Fn(Value) -> BoxFuture<'static, ToolCallResult> + Send + Sync,
167{
168    fn definition(&self) -> McpToolDefinition {
169        self.definition.clone()
170    }
171
172    fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult> {
173        (self.handler)(args)
174    }
175}
176
177/// Registry for managing tools.
178#[derive(Default)]
179pub struct ToolRegistry {
180    tools: HashMap<String, DynTool>,
181    /// Cached definitions for faster access
182    definitions_cache: parking_lot::RwLock<Option<Vec<McpToolDefinition>>>,
183}
184
185impl ToolRegistry {
186    /// Creates a new empty tool registry.
187    pub fn new() -> Self {
188        Self {
189            tools: HashMap::new(),
190            definitions_cache: parking_lot::RwLock::new(None),
191        }
192    }
193
194    /// Registers a tool.
195    pub fn register(&mut self, tool: DynTool) {
196        let name = tool.definition().name.clone();
197        self.tools.insert(name, tool);
198        // Invalidate cache when tools change
199        *self.definitions_cache.write() = None;
200    }
201
202    /// Registers multiple tools from a provider.
203    pub fn register_provider<P: ToolProvider>(&mut self, provider: P) {
204        for tool in provider.tools() {
205            let name = tool.definition().name.clone();
206            self.tools.insert(name, tool);
207        }
208        // Invalidate cache when tools change
209        *self.definitions_cache.write() = None;
210    }
211
212    /// Gets a tool by name.
213    pub fn get(&self, name: &str) -> Option<&DynTool> {
214        self.tools.get(name)
215    }
216
217    /// Returns all tool definitions (cached).
218    ///
219    /// Uses an Arc-wrapped cache to minimize cloning overhead.
220    /// Returns a clone of the Arc, so iterating is efficient.
221    pub fn definitions(&self) -> Vec<McpToolDefinition> {
222        // Try to return cached definitions
223        {
224            let cache = self.definitions_cache.read();
225            if let Some(ref defs) = *cache {
226                return defs.clone();
227            }
228        }
229
230        // Build and cache definitions
231        let defs: Vec<McpToolDefinition> = self.tools.values().map(|t| t.definition()).collect();
232        *self.definitions_cache.write() = Some(defs.clone());
233        defs
234    }
235
236    /// Returns an iterator over tool definitions without cloning.
237    ///
238    /// More efficient than `definitions()` when you only need to iterate.
239    pub fn definitions_iter(&self) -> impl Iterator<Item = McpToolDefinition> + '_ {
240        self.tools.values().map(|t| t.definition())
241    }
242
243    /// Returns all tool definitions without caching (for cases where fresh data is needed).
244    pub fn definitions_uncached(&self) -> Vec<McpToolDefinition> {
245        self.tools.values().map(|t| t.definition()).collect()
246    }
247
248    /// Invalidates the definitions cache.
249    pub fn invalidate_cache(&self) {
250        *self.definitions_cache.write() = None;
251    }
252
253    /// Returns the number of registered tools.
254    pub fn len(&self) -> usize {
255        self.tools.len()
256    }
257
258    /// Returns true if no tools are registered.
259    pub fn is_empty(&self) -> bool {
260        self.tools.is_empty()
261    }
262
263    /// Calls a tool by name with the given arguments.
264    pub async fn call(&self, name: &str, args: Value) -> ToolCallResult {
265        match self.get(name) {
266            Some(tool) => tool.call(args).await,
267            None => Err(format!("Unknown tool: {}", name)),
268        }
269    }
270}
271
/// Helper macro for creating an `FnTool` from an async closure.
///
/// Arguments, in order:
/// 1. `name` — any expression with `.to_string()`, used as the tool name;
/// 2. `description` — any expression with `.to_string()`;
/// 3. `schema` — a single token tree (typically a `{ ... }` object) passed to
///    `serde_json::json!` to build the input schema;
/// 4. `handler` — a callable taking a `serde_json::Value` and returning a
///    `Send + 'static` future that resolves to a `ToolCallResult`.
///
/// The created tool has no group.
///
/// The expansion refers to `serde_json::json!` by its plain path, so the
/// calling crate must depend on `serde_json` directly.
///
/// # Example
///
/// ```ignore
/// use mcp::{fn_tool, ToolContent};
/// use serde_json::Value;
///
/// let add_tool = fn_tool!(
///     "add",
///     "Add two numbers",
///     {
///         "type": "object",
///         "properties": {
///             "a": { "type": "number" },
///             "b": { "type": "number" }
///         }
///     },
///     // Annotate the parameter: the closure is type-checked before it is
///     // called, so `args["a"]` needs the type of `args` to be known.
///     |args: Value| async move {
///         let a = args["a"].as_f64().unwrap_or(0.0);
///         let b = args["b"].as_f64().unwrap_or(0.0);
///         Ok(vec![ToolContent::text(format!("{}", a + b))])
///     }
/// );
/// ```
#[macro_export]
macro_rules! fn_tool {
    ($name:expr, $desc:expr, $schema:tt, $handler:expr) => {{
        use $crate::protocol::McpToolDefinition;
        use $crate::tool::FnTool;

        let definition = McpToolDefinition {
            name: $name.to_string(),
            description: Some($desc.to_string()),
            group: None,
            input_schema: serde_json::json!($schema),
        };

        FnTool::new(definition, move |args| Box::pin($handler(args)))
    }};
}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ToolContent;

    /// Minimal tool with a configurable name whose call always succeeds.
    struct TestTool {
        name: String,
    }

    impl McpTool for TestTool {
        fn definition(&self) -> McpToolDefinition {
            McpToolDefinition::new(&self.name)
                .with_description("Test tool")
                .with_schema(serde_json::json!({"type": "object"}))
        }

        fn call<'a>(&'a self, _args: Value) -> BoxFuture<'a, ToolCallResult> {
            Box::pin(async move { Ok(vec![ToolContent::text("ok")]) })
        }
    }

    /// Builds a type-erased `TestTool` with the given name.
    fn test_tool(name: &str) -> DynTool {
        Arc::new(TestTool {
            name: name.to_string(),
        })
    }

    #[test]
    fn test_registry_register_and_get() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(TestTool {
            name: "test".to_string(),
        }));

        assert_eq!(registry.len(), 1);
        assert!(registry.get("test").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_registry_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(TestTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(TestTool {
            name: "tool2".to_string(),
        }));

        let defs = registry.definitions();
        assert_eq!(defs.len(), 2);
    }

    #[test]
    fn test_empty_registry() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty());
        assert!(registry.definitions().is_empty());
    }

    #[test]
    fn test_registry_duplicate_name_replaces() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("dup"));
        registry.register(test_tool("dup"));

        assert_eq!(registry.len(), 1);
        assert_eq!(registry.definitions().len(), 1);
    }

    #[test]
    fn test_registry_cache_invalidated_on_register() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("tool1"));
        // Populate the cache.
        assert_eq!(registry.definitions().len(), 1);

        // A registration after the cache was built must be visible.
        registry.register(test_tool("tool2"));
        assert_eq!(registry.definitions().len(), 2);
        assert_eq!(registry.definitions_uncached().len(), 2);
    }

    #[test]
    fn test_register_provider() {
        struct Pair;

        impl ToolProvider for Pair {
            fn tools(&self) -> Vec<DynTool> {
                vec![test_tool("a"), test_tool("b")]
            }
        }

        let mut registry = ToolRegistry::new();
        registry.register_provider(Pair);

        assert_eq!(registry.len(), 2);
        assert!(registry.get("a").is_some());
        assert!(registry.get("b").is_some());
    }

    #[tokio::test]
    async fn test_registry_call() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(TestTool {
            name: "test".to_string(),
        }));

        let result = registry.call("test", serde_json::json!({})).await;
        assert!(result.is_ok());

        let result = registry.call("unknown", serde_json::json!({})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_fn_tool_macro() {
        let tool = crate::fn_tool!(
            "echo",
            "Echo tool",
            { "type": "object" },
            |_args: Value| async move { Ok::<_, String>(vec![ToolContent::text("ok")]) }
        );

        assert_eq!(tool.definition().name, "echo");
        assert!(tool.call(serde_json::json!({})).await.is_ok());
    }
}