1use bon::Builder;
2use rmcp::{
3 RoleServer, ServerHandler,
4 model::{
5 CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
6 ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
7 ToolsCapability,
8 },
9 service::RequestContext,
10};
11use serde_json::Value;
12
13use reqwest::header::HeaderMap;
14use url::Url;
15
16use crate::error::Error;
17use crate::spec::SpecLocation;
18use crate::tool::{Tool, ToolCollection, ToolMetadata};
19use tracing::{debug, info, info_span, warn};
20
/// MCP server that exposes the operations of an OpenAPI spec as MCP tools.
///
/// Build it with [`Server::new`], one of the `with_*` constructors, or the
/// `bon`-generated `Server::builder()`. Then call `load_openapi_spec` to populate
/// `tool_collection`.
#[derive(Clone, Builder)]
pub struct Server {
    /// Where the OpenAPI spec is loaded from (`load_spec()` is called on it).
    pub spec_location: SpecLocation,
    /// Tools currently registered with the server. Falls back to `Default` when omitted
    /// from the builder, and is replaced wholesale by `load_openapi_spec`.
    #[builder(default)]
    pub tool_collection: ToolCollection,
    /// Base URL forwarded to tool generation. What `None` means is decided by
    /// `to_openapi_tools` (TODO confirm: presumably falls back to the spec's servers).
    pub base_url: Option<Url>,
    /// Headers forwarded to tool generation; presumably attached to every outgoing
    /// request — confirm in `to_openapi_tools`.
    pub default_headers: Option<HeaderMap>,
    /// Optional tag filter forwarded to tool generation; presumably only operations
    /// carrying one of these tags become tools — confirm.
    pub tag_filter: Option<Vec<String>>,
    /// Optional HTTP-method filter forwarded to tool generation; presumably only
    /// operations using one of these methods become tools — confirm.
    pub method_filter: Option<Vec<reqwest::Method>>,
}
31
32impl Server {
33 #[must_use]
35 pub fn new(spec_location: SpecLocation) -> Self {
36 Self::builder().spec_location(spec_location).build()
37 }
38
39 pub fn with_base_url(spec_location: SpecLocation, base_url: Url) -> Result<Self, Error> {
45 Ok(Self::builder()
46 .spec_location(spec_location)
47 .base_url(base_url)
48 .build())
49 }
50
51 pub fn with_base_url_and_headers(
57 spec_location: SpecLocation,
58 base_url: Url,
59 default_headers: HeaderMap,
60 ) -> Result<Self, Error> {
61 Ok(Self::builder()
62 .spec_location(spec_location)
63 .base_url(base_url)
64 .default_headers(default_headers)
65 .build())
66 }
67
68 #[must_use]
70 pub fn with_default_headers(spec_location: SpecLocation, default_headers: HeaderMap) -> Self {
71 Self::builder()
72 .spec_location(spec_location)
73 .default_headers(default_headers)
74 .build()
75 }
76
77 pub async fn load_openapi_spec(&mut self) -> Result<(), Error> {
83 let span = info_span!(
84 "tool_registration",
85 spec_location = %self.spec_location
86 );
87 let _enter = span.enter();
88
89 let spec = self.spec_location.load_spec().await?;
91
92 let tools = spec.to_openapi_tools(
94 self.tag_filter.as_deref(),
95 self.method_filter.as_deref(),
96 self.base_url.clone(),
97 self.default_headers.clone(),
98 )?;
99
100 self.tool_collection = ToolCollection::from_tools(tools);
101
102 info!(
103 tool_count = self.tool_collection.len(),
104 "Loaded tools from OpenAPI spec"
105 );
106
107 Ok(())
108 }
109
110 #[must_use]
112 pub fn tool_count(&self) -> usize {
113 self.tool_collection.len()
114 }
115
116 #[must_use]
118 pub fn get_tool_names(&self) -> Vec<String> {
119 self.tool_collection.get_tool_names()
120 }
121
122 #[must_use]
124 pub fn has_tool(&self, name: &str) -> bool {
125 self.tool_collection.has_tool(name)
126 }
127
128 #[must_use]
130 pub fn get_tool(&self, name: &str) -> Option<&Tool> {
131 self.tool_collection.get_tool(name)
132 }
133
134 #[must_use]
136 pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
137 self.get_tool(name).map(|tool| &tool.metadata)
138 }
139
140 #[must_use]
142 pub fn get_tool_stats(&self) -> String {
143 self.tool_collection.get_stats()
144 }
145
146 #[must_use]
148 pub fn with_tags(mut self, tags: Option<Vec<String>>) -> Self {
149 self.tag_filter = tags;
150 self
151 }
152
153 #[must_use]
155 pub fn with_methods(mut self, methods: Option<Vec<reqwest::Method>>) -> Self {
156 self.method_filter = methods;
157 self
158 }
159
160 pub fn validate_registry(&self) -> Result<(), Error> {
166 if self.tool_collection.is_empty() {
167 return Err(Error::McpError("No tools loaded".to_string()));
168 }
169 Ok(())
170 }
171}
172
173impl ServerHandler for Server {
174 fn get_info(&self) -> InitializeResult {
175 InitializeResult {
176 protocol_version: ProtocolVersion::V_2024_11_05,
177 server_info: Implementation {
178 name: "OpenAPI MCP Server".to_string(),
179 version: "0.1.0".to_string(),
180 },
181 capabilities: ServerCapabilities {
182 tools: Some(ToolsCapability {
183 list_changed: Some(false),
184 }),
185 ..Default::default()
186 },
187 instructions: Some("Exposes OpenAPI endpoints as MCP tools".to_string()),
188 }
189 }
190
191 async fn list_tools(
192 &self,
193 _request: Option<PaginatedRequestParam>,
194 _context: RequestContext<RoleServer>,
195 ) -> Result<ListToolsResult, ErrorData> {
196 let span = info_span!("list_tools", tool_count = self.tool_collection.len());
197 let _enter = span.enter();
198
199 debug!("Processing MCP list_tools request");
200
201 let tools = self.tool_collection.to_mcp_tools();
203
204 info!(
205 returned_tools = tools.len(),
206 "MCP list_tools request completed successfully"
207 );
208
209 Ok(ListToolsResult {
210 tools,
211 next_cursor: None,
212 })
213 }
214
215 async fn call_tool(
216 &self,
217 request: CallToolRequestParam,
218 _context: RequestContext<RoleServer>,
219 ) -> Result<CallToolResult, ErrorData> {
220 let span = info_span!(
221 "call_tool",
222 tool_name = %request.name
223 );
224 let _enter = span.enter();
225
226 debug!(
227 tool_name = %request.name,
228 has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
229 "Processing MCP call_tool request"
230 );
231
232 let arguments = request.arguments.unwrap_or_default();
233 let arguments_value = Value::Object(arguments);
234
235 match self
237 .tool_collection
238 .call_tool(&request.name, &arguments_value)
239 .await
240 {
241 Ok(result) => {
242 info!(
243 tool_name = %request.name,
244 success = true,
245 "MCP call_tool request completed successfully"
246 );
247 Ok(result)
248 }
249 Err(e) => {
250 warn!(
251 tool_name = %request.name,
252 success = false,
253 error = %e,
254 "MCP call_tool request failed"
255 );
256 Err(e.into())
258 }
259 }
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Metadata for the `GET /pet/{petId}` tool shared by the tool-not-found tests.
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: "Find pet by ID".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": { "type": "integer" }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
        }
    }

    /// Builds a server whose registry contains exactly `tools`.
    fn server_with_tools(tools: Vec<Tool>) -> Server {
        let mut server = Server::new(SpecLocation::Url(Url::parse("test://example").unwrap()));
        server.tool_collection = ToolCollection::from_tools(tools);
        server
    }

    /// Serializes the MCP error produced when `requested` is not a registered tool of `server`.
    fn tool_not_found_json(server: &Server, requested: &str) -> Value {
        let known = server.get_tool_names();
        let known_refs: Vec<&str> = known.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &known_refs,
        ));
        let error_data: ErrorData = error.into();
        serde_json::to_value(&error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let by_status_metadata = ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: "Find pets by status".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": { "type": "string" }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
        };

        let server = server_with_tools(vec![
            Tool::new(get_pet_by_id_metadata(), None, None).unwrap(),
            Tool::new(by_status_metadata, None, None).unwrap(),
        ]);

        // "getPetByID" differs from the registered "getPetById" only by letter case.
        let error_json = tool_not_found_json(&server, "getPetByID");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let server =
            server_with_tools(vec![Tool::new(get_pet_by_id_metadata(), None, None).unwrap()]);

        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        assert_eq!(error_json["code"], -32602);
        insta::assert_json_snapshot!(error_json);
    }
}