mcpkit_server/capability/
tools.rs1use crate::context::Context;
7use crate::handler::ToolHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::tool::{Tool, ToolOutput};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14use std::sync::Arc;
15
16pub type BoxedToolFn = Box<
18 dyn for<'a> Fn(
19 Value,
20 &'a Context<'a>,
21 ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, McpError>> + Send + 'a>>
22 + Send
23 + Sync,
24>;
25
26pub struct RegisteredTool {
28 pub tool: Tool,
30 pub handler: BoxedToolFn,
32}
33
34pub struct ToolService {
39 tools: HashMap<String, RegisteredTool>,
40}
41
42impl Default for ToolService {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl ToolService {
49 pub fn new() -> Self {
51 Self {
52 tools: HashMap::new(),
53 }
54 }
55
56 pub fn register<F, Fut>(&mut self, tool: Tool, handler: F)
58 where
59 F: Fn(Value, &Context<'_>) -> Fut + Send + Sync + 'static,
60 Fut: Future<Output = Result<ToolOutput, McpError>> + Send + 'static,
61 {
62 let name = tool.name.clone();
63 let boxed: BoxedToolFn = Box::new(move |args, ctx| Box::pin(handler(args, ctx)));
64 self.tools.insert(
65 name,
66 RegisteredTool {
67 tool,
68 handler: boxed,
69 },
70 );
71 }
72
73 pub fn register_arc<H>(&mut self, tool: Tool, handler: Arc<H>)
75 where
76 H: for<'a> Fn(Value, &'a Context<'a>) -> Pin<Box<dyn Future<Output = Result<ToolOutput, McpError>> + Send + 'a>>
77 + Send
78 + Sync
79 + 'static,
80 {
81 let name = tool.name.clone();
82 let boxed: BoxedToolFn = Box::new(move |args, ctx| (handler)(args, ctx));
83 self.tools.insert(
84 name,
85 RegisteredTool {
86 tool,
87 handler: boxed,
88 },
89 );
90 }
91
92 pub fn get(&self, name: &str) -> Option<&RegisteredTool> {
94 self.tools.get(name)
95 }
96
97 pub fn contains(&self, name: &str) -> bool {
99 self.tools.contains_key(name)
100 }
101
102 pub fn list(&self) -> Vec<&Tool> {
104 self.tools.values().map(|r| &r.tool).collect()
105 }
106
107 pub fn len(&self) -> usize {
109 self.tools.len()
110 }
111
112 pub fn is_empty(&self) -> bool {
114 self.tools.is_empty()
115 }
116
117 pub async fn call(
119 &self,
120 name: &str,
121 arguments: Value,
122 ctx: &Context<'_>,
123 ) -> Result<ToolOutput, McpError> {
124 let registered = self.tools.get(name).ok_or_else(|| {
125 McpError::invalid_params("tools/call", format!("Unknown tool: {name}"))
126 })?;
127
128 (registered.handler)(arguments, ctx).await
129 }
130}
131
132impl ToolHandler for ToolService {
133 async fn list_tools(&self, _ctx: &Context<'_>) -> Result<Vec<Tool>, McpError> {
134 Ok(self.list().into_iter().cloned().collect())
135 }
136
137 async fn call_tool(
138 &self,
139 name: &str,
140 arguments: Value,
141 ctx: &Context<'_>,
142 ) -> Result<ToolOutput, McpError> {
143 self.call(name, arguments, ctx).await
144 }
145}
146
147pub struct ToolBuilder {
149 name: String,
150 description: Option<String>,
151 input_schema: Value,
152}
153
154impl ToolBuilder {
155 pub fn new(name: impl Into<String>) -> Self {
157 Self {
158 name: name.into(),
159 description: None,
160 input_schema: serde_json::json!({
161 "type": "object",
162 "properties": {},
163 }),
164 }
165 }
166
167 pub fn description(mut self, desc: impl Into<String>) -> Self {
169 self.description = Some(desc.into());
170 self
171 }
172
173 pub fn input_schema(mut self, schema: Value) -> Self {
175 self.input_schema = schema;
176 self
177 }
178
179 pub fn build(self) -> Tool {
181 Tool {
182 name: self.name,
183 description: self.description,
184 input_schema: self.input_schema,
185 annotations: None,
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;
    use mcpkit_core::types::tool::CallToolResult;

    /// Builds the owned values a `Context` borrows from.
    fn make_context() -> (RequestId, ClientCapabilities, ServerCapabilities, NoOpPeer) {
        (
            RequestId::Number(1),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
            NoOpPeer,
        )
    }

    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new("test")
            .description("A test tool")
            .input_schema(serde_json::json!({
                "type": "object",
                "properties": {
                    "query": { "type": "string" }
                }
            }))
            .build();

        assert_eq!(tool.name, "test");
        assert_eq!(tool.description.as_deref(), Some("A test tool"));
    }

    #[test]
    fn test_new_service_is_empty() {
        let service = ToolService::default();

        assert!(service.is_empty());
        assert_eq!(service.len(), 0);
        assert!(service.list().is_empty());
        assert!(service.get("echo").is_none());
        assert!(!service.contains("echo"));
    }

    #[test]
    fn test_register_same_name_replaces() {
        let mut service = ToolService::new();

        // Registering the same name twice must leave exactly one entry.
        for _ in 0..2 {
            service.register(ToolBuilder::new("dup").build(), |_args, _ctx| async move {
                Ok(ToolOutput::text(String::from("ok")))
            });
        }

        assert_eq!(service.len(), 1);
        assert!(service.contains("dup"));
    }

    #[tokio::test]
    async fn test_tool_service() {
        let mut service = ToolService::new();

        let tool = ToolBuilder::new("echo")
            .description("Echo back input")
            .build();

        service.register(tool, |args, _ctx| async move {
            Ok(ToolOutput::text(args.to_string()))
        });

        assert!(service.contains("echo"));
        assert_eq!(service.len(), 1);
        assert_eq!(service.list().len(), 1);

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let result = service
            .call("echo", serde_json::json!({"hello": "world"}), &ctx)
            .await
            .unwrap();

        let call_result: CallToolResult = result.into();
        assert!(!call_result.content.is_empty());
    }

    #[tokio::test]
    async fn test_call_unknown_tool_errors() {
        let service = ToolService::new();

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let result = service
            .call("missing", serde_json::json!({}), &ctx)
            .await;

        assert!(result.is_err());
    }
}
250}