model_context_protocol/tool.rs
1//! Tool traits and types for MCP servers.
2//!
3//! This module provides the core abstractions for defining and executing MCP tools.
4
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use serde_json::Value;
11
12use crate::protocol::{McpToolDef, ToolContent};
13
14/// A boxed future for async tool execution.
15pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
16
17/// Result type for tool execution - returns content or an error message.
18pub type ToolCallResult = Result<Vec<ToolContent>, String>;
19
20// =============================================================================
21// Tool Auto-Discovery via Inventory
22// =============================================================================
23
24/// A factory function that creates a tool instance.
25pub type ToolFactory = fn() -> DynTool;
26
27/// Entry for auto-discovered tools registered via `#[mcp_tool]`.
28///
29/// This struct is used internally by the `inventory` crate to collect
30/// all tools defined with the `#[mcp_tool]` attribute at link time.
31pub struct ToolEntry {
32 /// Factory function to create the tool.
33 pub factory: ToolFactory,
34 /// The group this tool belongs to (if any).
35 pub group: Option<&'static str>,
36}
37
38impl ToolEntry {
39 /// Creates a new tool entry.
40 pub const fn new(factory: ToolFactory, group: Option<&'static str>) -> Self {
41 Self { factory, group }
42 }
43}
44
45// Register ToolEntry with inventory for compile-time collection
46inventory::collect!(ToolEntry);
47
48/// Returns all auto-discovered tools.
49///
50/// This collects all tools registered with `#[mcp_tool]` across the crate.
51pub fn all_tools() -> Vec<DynTool> {
52 inventory::iter::<ToolEntry>()
53 .map(|entry| (entry.factory)())
54 .collect()
55}
56
57/// Returns auto-discovered tools filtered by group.
58///
59/// Only returns tools that have the specified group.
60pub fn tools_in_group(group: &str) -> Vec<DynTool> {
61 inventory::iter::<ToolEntry>()
62 .filter(|entry| entry.group == Some(group))
63 .map(|entry| (entry.factory)())
64 .collect()
65}
66
67/// Trait for implementing MCP tools.
68///
69/// # Example
70///
71/// ```ignore
72/// use mcp::{McpTool, ToolCallResult, McpToolDef};
73/// use serde_json::Value;
74///
75/// struct CalculatorTool;
76///
77/// impl McpTool for CalculatorTool {
78/// fn definition(&self) -> McpToolDef {
79/// McpToolDef {
80/// name: "add".to_string(),
81/// description: Some("Add two numbers".to_string()),
82/// input_schema: serde_json::json!({
83/// "type": "object",
84/// "properties": {
85/// "a": { "type": "number" },
86/// "b": { "type": "number" }
87/// },
88/// "required": ["a", "b"]
89/// }),
90/// }
91/// }
92///
93/// fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult> {
94/// Box::pin(async move {
95/// let a = args["a"].as_f64().unwrap_or(0.0);
96/// let b = args["b"].as_f64().unwrap_or(0.0);
97/// Ok(vec![ToolContent::text(format!("{}", a + b))])
98/// })
99/// }
100/// }
101/// ```
102pub trait McpTool: Send + Sync {
103 /// Returns the tool definition (name, description, input schema).
104 fn definition(&self) -> McpToolDef;
105
106 /// Executes the tool with the given arguments.
107 fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult>;
108}
109
110/// A type-erased tool wrapper.
111pub type DynTool = Arc<dyn McpTool>;
112
113/// Trait for types that can provide multiple tools.
114///
115/// Implement this trait to group related tools together.
116///
117/// # Example
118///
119/// ```ignore
120/// use mcp::{ToolProvider, McpTool};
121///
122/// struct MathTools;
123///
124/// impl ToolProvider for MathTools {
125/// fn tools(&self) -> Vec<Arc<dyn McpTool>> {
126/// vec![
127/// Arc::new(AddTool),
128/// Arc::new(SubtractTool),
129/// Arc::new(MultiplyTool),
130/// ]
131/// }
132/// }
133/// ```
134pub trait ToolProvider: Send + Sync {
135 /// Returns a list of tools provided by this provider.
136 fn tools(&self) -> Vec<DynTool>;
137}
138
139/// A simple function-based tool.
140///
141/// This allows creating tools from closures without implementing the `McpTool` trait.
142pub struct FnTool<F>
143where
144 F: Fn(Value) -> BoxFuture<'static, ToolCallResult> + Send + Sync,
145{
146 definition: McpToolDef,
147 handler: F,
148}
149
150impl<F> FnTool<F>
151where
152 F: Fn(Value) -> BoxFuture<'static, ToolCallResult> + Send + Sync,
153{
154 /// Creates a new function-based tool.
155 pub fn new(definition: McpToolDef, handler: F) -> Self {
156 Self {
157 definition,
158 handler,
159 }
160 }
161}
162
163impl<F> McpTool for FnTool<F>
164where
165 F: Fn(Value) -> BoxFuture<'static, ToolCallResult> + Send + Sync,
166{
167 fn definition(&self) -> McpToolDef {
168 self.definition.clone()
169 }
170
171 fn call<'a>(&'a self, args: Value) -> BoxFuture<'a, ToolCallResult> {
172 (self.handler)(args)
173 }
174}
175
176/// Registry for managing tools.
177#[derive(Default)]
178pub struct ToolRegistry {
179 tools: HashMap<String, DynTool>,
180}
181
182impl ToolRegistry {
183 /// Creates a new empty tool registry.
184 pub fn new() -> Self {
185 Self::default()
186 }
187
188 /// Registers a tool.
189 pub fn register(&mut self, tool: DynTool) {
190 let name = tool.definition().name.clone();
191 self.tools.insert(name, tool);
192 }
193
194 /// Registers multiple tools from a provider.
195 pub fn register_provider<P: ToolProvider>(&mut self, provider: P) {
196 for tool in provider.tools() {
197 self.register(tool);
198 }
199 }
200
201 /// Gets a tool by name.
202 pub fn get(&self, name: &str) -> Option<&DynTool> {
203 self.tools.get(name)
204 }
205
206 /// Returns all tool definitions.
207 pub fn definitions(&self) -> Vec<McpToolDef> {
208 self.tools.values().map(|t| t.definition()).collect()
209 }
210
211 /// Returns the number of registered tools.
212 pub fn len(&self) -> usize {
213 self.tools.len()
214 }
215
216 /// Returns true if no tools are registered.
217 pub fn is_empty(&self) -> bool {
218 self.tools.is_empty()
219 }
220
221 /// Calls a tool by name with the given arguments.
222 pub async fn call(&self, name: &str, args: Value) -> ToolCallResult {
223 match self.get(name) {
224 Some(tool) => tool.call(args).await,
225 None => Err(format!("Unknown tool: {}", name)),
226 }
227 }
228}
229
/// Helper macro for creating tools from async functions.
///
/// Takes the tool name, a description, the input schema as a JSON literal
/// (forwarded to `serde_json::json!`) and a handler callable with a
/// `serde_json::Value` that returns a future resolving to a `ToolCallResult`.
/// Expands to an `FnTool` whose definition has `group: None`.
///
/// The handler expression is evaluated inside the generated `move` closure,
/// so variables it mentions are moved into the tool. A trailing comma after
/// the handler is accepted.
///
/// The calling crate must have `serde_json` available under that name.
///
/// # Example
///
/// ```ignore
/// use mcp::fn_tool;
/// use mcp::protocol::ToolContent;
///
/// let add_tool = fn_tool!(
///     "add",
///     "Add two numbers",
///     {
///         "type": "object",
///         "properties": {
///             "a": { "type": "number" },
///             "b": { "type": "number" }
///         }
///     },
///     |args| async move {
///         let a = args["a"].as_f64().unwrap_or(0.0);
///         let b = args["b"].as_f64().unwrap_or(0.0);
///         Ok(vec![ToolContent::text((a + b).to_string())])
///     }
/// );
/// ```
#[macro_export]
macro_rules! fn_tool {
    ($name:expr, $desc:expr, $schema:tt, $handler:expr $(,)?) => {{
        use $crate::protocol::McpToolDef;
        use $crate::tool::FnTool;

        // Identity function whose bound gives an unannotated closure literal
        // (`|args| ...`) an expected signature, so its parameter is known to be
        // a `Value` before the closure body is type-checked. `FnOnce` is the
        // weakest bound and therefore accepts every handler that was callable
        // before.
        fn __fn_tool_constrain_handler<H, Fut>(handler: H) -> H
        where
            H: FnOnce(serde_json::Value) -> Fut,
        {
            handler
        }

        let definition = McpToolDef {
            name: $name.to_string(),
            description: Some($desc.to_string()),
            group: None,
            input_schema: serde_json::json!($schema),
        };

        FnTool::new(definition, move |args| {
            Box::pin(__fn_tool_constrain_handler($handler)(args))
        })
    }};
}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ToolContent;

    /// Minimal tool that reports a configurable name and always answers "ok".
    struct TestTool {
        name: String,
    }

    impl TestTool {
        /// Builds a shareable test tool with the given name.
        fn shared(name: &str) -> DynTool {
            Arc::new(Self {
                name: name.to_string(),
            })
        }
    }

    impl McpTool for TestTool {
        fn definition(&self) -> McpToolDef {
            McpToolDef {
                name: self.name.clone(),
                description: Some("Test tool".to_string()),
                group: None,
                input_schema: serde_json::json!({"type": "object"}),
            }
        }

        fn call<'a>(&'a self, _args: Value) -> BoxFuture<'a, ToolCallResult> {
            Box::pin(async move { Ok(vec![ToolContent::text("ok")]) })
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let mut registry = ToolRegistry::new();
        assert!(registry.is_empty());

        registry.register(TestTool::shared("test"));

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
        assert!(registry.get("test").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_registry_register_replaces_same_name() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::shared("test"));
        registry.register(TestTool::shared("test"));

        // Tools are keyed by name, so the second registration replaces the first.
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::shared("tool1"));
        registry.register(TestTool::shared("tool2"));

        // Sort locally so the test does not depend on the listing order.
        let mut names: Vec<String> = registry
            .definitions()
            .into_iter()
            .map(|def| def.name)
            .collect();
        names.sort();
        assert_eq!(names, ["tool1", "tool2"]);
    }

    #[tokio::test]
    async fn test_registry_call() {
        let mut registry = ToolRegistry::new();
        registry.register(TestTool::shared("test"));

        let content = registry
            .call("test", serde_json::json!({}))
            .await
            .expect("a registered tool should be callable");
        assert_eq!(content.len(), 1);

        match registry.call("unknown", serde_json::json!({})).await {
            Err(message) => assert_eq!(message, "Unknown tool: unknown"),
            Ok(_) => panic!("calling an unregistered tool must fail"),
        }
    }
}