Skip to main content

model_context_protocol/
tool.rs

1//! Tool traits and types for MCP servers.
2//!
3//! This module provides the core abstractions for defining and executing MCP tools.
4
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use serde_json::Value;
11
12use crate::protocol::{McpToolDef, ToolContent};
13
14/// A boxed future for async tool execution.
15pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
16
17/// Result type for tool execution - returns content or an error message.
18pub type ToolCallResult = Result<Vec<ToolContent>, String>;
19
20// =============================================================================
21// Tool Auto-Discovery via Inventory
22// =============================================================================
23
24/// A factory function that creates a tool instance.
25pub type ToolFactory = fn() -> DynTool;
26
27/// Entry for auto-discovered tools registered via `#[mcp_tool]`.
28///
29/// This struct is used internally by the `inventory` crate to collect
30/// all tools defined with the `#[mcp_tool]` attribute at link time.
31pub struct ToolEntry {
32    /// Factory function to create the tool.
33    pub factory: ToolFactory,
34    /// The group this tool belongs to (if any).
35    pub group: Option<&'static str>,
36}
37
38impl ToolEntry {
39    /// Creates a new tool entry.
40    pub const fn new(factory: ToolFactory, group: Option<&'static str>) -> Self {
41        Self { factory, group }
42    }
43}
44
// Register `ToolEntry` with the `inventory` crate so that entries registered
// for `#[mcp_tool]` tools can be enumerated at runtime with
// `inventory::iter::<ToolEntry>()` (as `all_tools` and `tools_in_group` do).
inventory::collect!(ToolEntry);
47
48/// Returns all auto-discovered tools.
49///
50/// This collects all tools registered with `#[mcp_tool]` across the crate.
51pub fn all_tools() -> Vec<DynTool> {
52    inventory::iter::<ToolEntry>()
53        .map(|entry| (entry.factory)())
54        .collect()
55}
56
57/// Returns auto-discovered tools filtered by group.
58///
59/// Only returns tools that have the specified group.
60pub fn tools_in_group(group: &str) -> Vec<DynTool> {
61    inventory::iter::<ToolEntry>()
62        .filter(|entry| entry.group == Some(group))
63        .map(|entry| (entry.factory)())
64        .collect()
65}
66
67/// Trait for implementing MCP tools.
68///
69/// # Example
70///
71/// ```ignore
72/// use mcp::{McpTool, ToolCallResult, McpToolDef};
73/// use serde_json::Value;
74///
75/// struct CalculatorTool;
76///
77/// impl McpTool for CalculatorTool {
78///     fn definition(&self) -> McpToolDef {
79///         McpToolDef {
80///             name: "add".to_string(),
81///             description: Some("Add two numbers".to_string()),
82///             input_schema: serde_json::json!({
83///                 "type": "object",
84///                 "properties": {
85///                     "a": { "type": "number" },
86///                     "b": { "type": "number" }
87///                 },
88///                 "required": ["a", "b"]
89///             }),
90///         }
91///     }
92///
93///     fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult> {
94///         Box::pin(async move {
95///             let a = args["a"].as_f64().unwrap_or(0.0);
96///             let b = args["b"].as_f64().unwrap_or(0.0);
97///             Ok(vec![ToolContent::text(format!("{}", a + b))])
98///         })
99///     }
100/// }
101/// ```
102pub trait McpTool: Send + Sync {
103    /// Returns the tool definition (name, description, input schema).
104    fn definition(&self) -> McpToolDef;
105
106    /// Executes the tool with the given arguments.
107    fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult>;
108}
109
110/// A type-erased tool wrapper.
111pub type DynTool = Arc<dyn McpTool>;
112
113/// Trait for types that can provide multiple tools.
114///
115/// Implement this trait to group related tools together.
116///
117/// # Example
118///
119/// ```ignore
120/// use mcp::{ToolProvider, McpTool};
121///
122/// struct MathTools;
123///
124/// impl ToolProvider for MathTools {
125///     fn tools(&self) -> Vec<Arc<dyn McpTool>> {
126///         vec![
127///             Arc::new(AddTool),
128///             Arc::new(SubtractTool),
129///             Arc::new(MultiplyTool),
130///         ]
131///     }
132/// }
133/// ```
134pub trait ToolProvider: Send + Sync {
135    /// Returns a list of tools provided by this provider.
136    fn tools(&self) -> Vec<DynTool>;
137}
138
139/// A simple function-based tool.
140///
141/// This allows creating tools from closures without implementing the `McpTool` trait.
142pub struct FnTool<F>
143where
144    F: Fn(Value) -> BoxFuture<'static, ToolCallResult> + Send + Sync,
145{
146    definition: McpToolDef,
147    handler: F,
148}
149
150impl<F> FnTool<F>
151where
152    F: Fn(Value) -> BoxFuture<'static, ToolCallResult> + Send + Sync,
153{
154    /// Creates a new function-based tool.
155    pub fn new(definition: McpToolDef, handler: F) -> Self {
156        Self {
157            definition,
158            handler,
159        }
160    }
161}
162
163impl<F> McpTool for FnTool<F>
164where
165    F: Fn(Value) -> BoxFuture<'static, ToolCallResult> + Send + Sync,
166{
167    fn definition(&self) -> McpToolDef {
168        self.definition.clone()
169    }
170
171    fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult> {
172        (self.handler)(args)
173    }
174}
175
176/// Registry for managing tools.
177#[derive(Default)]
178pub struct ToolRegistry {
179    tools: HashMap<String, DynTool>,
180}
181
182impl ToolRegistry {
183    /// Creates a new empty tool registry.
184    pub fn new() -> Self {
185        Self::default()
186    }
187
188    /// Registers a tool.
189    pub fn register(&mut self, tool: DynTool) {
190        let name = tool.definition().name.clone();
191        self.tools.insert(name, tool);
192    }
193
194    /// Registers multiple tools from a provider.
195    pub fn register_provider<P: ToolProvider>(&mut self, provider: P) {
196        for tool in provider.tools() {
197            self.register(tool);
198        }
199    }
200
201    /// Gets a tool by name.
202    pub fn get(&self, name: &str) -> Option<&DynTool> {
203        self.tools.get(name)
204    }
205
206    /// Returns all tool definitions.
207    pub fn definitions(&self) -> Vec<McpToolDef> {
208        self.tools.values().map(|t| t.definition()).collect()
209    }
210
211    /// Returns the number of registered tools.
212    pub fn len(&self) -> usize {
213        self.tools.len()
214    }
215
216    /// Returns true if no tools are registered.
217    pub fn is_empty(&self) -> bool {
218        self.tools.is_empty()
219    }
220
221    /// Calls a tool by name with the given arguments.
222    pub async fn call(&self, name: &str, args: Value) -> ToolCallResult {
223        match self.get(name) {
224            Some(tool) => tool.call(args).await,
225            None => Err(format!("Unknown tool: {}", name)),
226        }
227    }
228}
229
/// Helper macro for creating tools from async functions.
///
/// Expands to a [`FnTool`](crate::tool::FnTool) built from a name, a
/// description, an inline JSON schema and an async handler. The definition is
/// created with `group: None`. A trailing comma after the handler is accepted.
///
/// The schema is passed verbatim to `serde_json::json!`. Because that path is
/// resolved at the call site, the calling crate must depend on `serde_json`.
/// The macro evaluates to a `FnTool`; wrap it in `Arc::new(...)` to obtain a
/// `DynTool` suitable for `ToolRegistry::register`.
///
/// # Example
///
/// ```ignore
/// use std::sync::Arc;
/// use mcp::fn_tool;
/// use mcp::protocol::ToolContent;
/// use mcp::tool::ToolRegistry;
///
/// let add_tool = fn_tool!(
///     "add",
///     "Add two numbers",
///     {
///         "type": "object",
///         "properties": {
///             "a": { "type": "number" },
///             "b": { "type": "number" }
///         }
///     },
///     |args| async move {
///         let a = args["a"].as_f64().unwrap_or(0.0);
///         let b = args["b"].as_f64().unwrap_or(0.0);
///         Ok(vec![ToolContent::text(format!("{}", a + b))])
///     },
/// );
///
/// let mut registry = ToolRegistry::new();
/// registry.register(Arc::new(add_tool));
/// ```
#[macro_export]
macro_rules! fn_tool {
    ($name:expr, $desc:expr, $schema:tt, $handler:expr $(,)?) => {{
        // Fully qualified `$crate` paths avoid injecting `use` items into the
        // caller's scope, where `$handler` and `$schema` are also expanded.
        let definition = $crate::protocol::McpToolDef {
            name: $name.to_string(),
            description: Some($desc.to_string()),
            group: None,
            input_schema: serde_json::json!($schema),
        };

        $crate::tool::FnTool::new(definition, move |args| Box::pin($handler(args)))
    }};
}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ToolContent;

    /// Test double: a tool with a configurable name whose call always succeeds.
    struct TestTool {
        name: String,
    }

    impl McpTool for TestTool {
        fn definition(&self) -> McpToolDef {
            McpToolDef {
                name: self.name.clone(),
                description: Some("Test tool".to_string()),
                group: None,
                input_schema: serde_json::json!({"type": "object"}),
            }
        }

        fn call<'a>(&'a self, _args: Value) -> BoxFuture<'a, ToolCallResult> {
            Box::pin(async move { Ok(vec![ToolContent::text("ok")]) })
        }
    }

    /// Builds a type-erased `TestTool` with the given name.
    fn test_tool(name: &str) -> DynTool {
        Arc::new(TestTool {
            name: name.to_string(),
        })
    }

    /// Provider that supplies two distinctly named test tools.
    struct PairProvider;

    impl ToolProvider for PairProvider {
        fn tools(&self) -> Vec<DynTool> {
            vec![test_tool("first"), test_tool("second")]
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let mut registry = ToolRegistry::new();
        assert!(registry.is_empty());
        registry.register(test_tool("test"));

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
        assert!(registry.get("test").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_registry_register_replaces_same_name() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("dup"));
        registry.register(test_tool("dup"));

        assert_eq!(registry.len(), 1);
        assert!(registry.get("dup").is_some());
    }

    #[test]
    fn test_registry_register_provider() {
        let mut registry = ToolRegistry::new();
        registry.register_provider(PairProvider);

        assert_eq!(registry.len(), 2);
        assert!(registry.get("first").is_some());
        assert!(registry.get("second").is_some());
    }

    #[test]
    fn test_registry_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("tool1"));
        registry.register(test_tool("tool2"));

        // Sort locally so the assertion does not depend on the registry's order.
        let mut names: Vec<String> = registry
            .definitions()
            .into_iter()
            .map(|d| d.name)
            .collect();
        names.sort();
        assert_eq!(names, vec!["tool1".to_string(), "tool2".to_string()]);
    }

    #[tokio::test]
    async fn test_registry_call() {
        let mut registry = ToolRegistry::new();
        registry.register(test_tool("test"));

        let result = registry.call("test", serde_json::json!({})).await;
        assert!(result.is_ok());

        let result = registry.call("unknown", serde_json::json!({})).await;
        assert_eq!(result.err().as_deref(), Some("Unknown tool: unknown"));
    }

    #[tokio::test]
    async fn test_fn_tool_delegates_to_handler() {
        let tool = FnTool::new(
            McpToolDef {
                name: "fn_tool".to_string(),
                description: None,
                group: None,
                input_schema: serde_json::json!({"type": "object"}),
            },
            |_args: Value| -> BoxFuture<'static, ToolCallResult> {
                Box::pin(async { Ok(vec![ToolContent::text("from fn")]) })
            },
        );

        assert_eq!(tool.definition().name, "fn_tool");
        assert!(tool.call(serde_json::json!({})).await.is_ok());
    }
}