rmcp_openapi/
server.rs

1use bon::Builder;
2use rmcp::{
3    RoleServer, ServerHandler,
4    model::{
5        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
6        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
7        ToolsCapability,
8    },
9    service::RequestContext,
10};
11use serde_json::Value;
12
13use reqwest::header::HeaderMap;
14use url::Url;
15
16use crate::error::Error;
17use crate::tool::{Tool, ToolCollection, ToolMetadata};
18use tracing::{debug, info, info_span, warn};
19
20#[derive(Clone, Builder)]
21pub struct Server {
22    pub openapi_spec: serde_json::Value,
23    #[builder(default)]
24    pub tool_collection: ToolCollection,
25    pub base_url: Url,
26    pub default_headers: Option<HeaderMap>,
27    pub tag_filter: Option<Vec<String>>,
28    pub method_filter: Option<Vec<reqwest::Method>>,
29}
30
31impl Server {
32    /// Create a new Server instance with required parameters
33    pub fn new(
34        openapi_spec: serde_json::Value,
35        base_url: Url,
36        default_headers: Option<HeaderMap>,
37        tag_filter: Option<Vec<String>>,
38        method_filter: Option<Vec<reqwest::Method>>,
39    ) -> Self {
40        Self {
41            openapi_spec,
42            tool_collection: ToolCollection::new(),
43            base_url,
44            default_headers,
45            tag_filter,
46            method_filter,
47        }
48    }
49
50    /// Parse the `OpenAPI` specification and convert to OpenApiTool instances
51    ///
52    /// # Errors
53    ///
54    /// Returns an error if the spec cannot be parsed or tools cannot be generated
55    pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
56        let span = info_span!("tool_registration");
57        let _enter = span.enter();
58
59        // Parse the OpenAPI specification
60        let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
61
62        // Generate OpenApiTool instances directly
63        let tools = spec.to_openapi_tools(
64            self.tag_filter.as_deref(),
65            self.method_filter.as_deref(),
66            Some(self.base_url.clone()),
67            self.default_headers.clone(),
68        )?;
69
70        self.tool_collection = ToolCollection::from_tools(tools);
71
72        info!(
73            tool_count = self.tool_collection.len(),
74            "Loaded tools from OpenAPI spec"
75        );
76
77        Ok(())
78    }
79
80    /// Get the number of loaded tools
81    #[must_use]
82    pub fn tool_count(&self) -> usize {
83        self.tool_collection.len()
84    }
85
86    /// Get all tool names
87    #[must_use]
88    pub fn get_tool_names(&self) -> Vec<String> {
89        self.tool_collection.get_tool_names()
90    }
91
92    /// Check if a specific tool exists
93    #[must_use]
94    pub fn has_tool(&self, name: &str) -> bool {
95        self.tool_collection.has_tool(name)
96    }
97
98    /// Get a tool by name
99    #[must_use]
100    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
101        self.tool_collection.get_tool(name)
102    }
103
104    /// Get tool metadata by name
105    #[must_use]
106    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
107        self.get_tool(name).map(|tool| &tool.metadata)
108    }
109
110    /// Get basic tool statistics
111    #[must_use]
112    pub fn get_tool_stats(&self) -> String {
113        self.tool_collection.get_stats()
114    }
115
116    /// Simple validation - check that tools are loaded
117    ///
118    /// # Errors
119    ///
120    /// Returns an error if no tools are loaded
121    pub fn validate_registry(&self) -> Result<(), Error> {
122        if self.tool_collection.is_empty() {
123            return Err(Error::McpError("No tools loaded".to_string()));
124        }
125        Ok(())
126    }
127}
128
129impl ServerHandler for Server {
130    fn get_info(&self) -> InitializeResult {
131        InitializeResult {
132            protocol_version: ProtocolVersion::V_2024_11_05,
133            server_info: Implementation {
134                name: "OpenAPI MCP Server".to_string(),
135                version: "0.1.0".to_string(),
136            },
137            capabilities: ServerCapabilities {
138                tools: Some(ToolsCapability {
139                    list_changed: Some(false),
140                }),
141                ..Default::default()
142            },
143            instructions: Some("Exposes OpenAPI endpoints as MCP tools".to_string()),
144        }
145    }
146
147    async fn list_tools(
148        &self,
149        _request: Option<PaginatedRequestParam>,
150        _context: RequestContext<RoleServer>,
151    ) -> Result<ListToolsResult, ErrorData> {
152        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
153        let _enter = span.enter();
154
155        debug!("Processing MCP list_tools request");
156
157        // Delegate to tool collection for MCP tool conversion
158        let tools = self.tool_collection.to_mcp_tools();
159
160        info!(
161            returned_tools = tools.len(),
162            "MCP list_tools request completed successfully"
163        );
164
165        Ok(ListToolsResult {
166            tools,
167            next_cursor: None,
168        })
169    }
170
171    async fn call_tool(
172        &self,
173        request: CallToolRequestParam,
174        _context: RequestContext<RoleServer>,
175    ) -> Result<CallToolResult, ErrorData> {
176        let span = info_span!(
177            "call_tool",
178            tool_name = %request.name
179        );
180        let _enter = span.enter();
181
182        debug!(
183            tool_name = %request.name,
184            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
185            "Processing MCP call_tool request"
186        );
187
188        let arguments = request.arguments.unwrap_or_default();
189        let arguments_value = Value::Object(arguments);
190
191        // Delegate all tool validation and execution to the tool collection
192        match self
193            .tool_collection
194            .call_tool(&request.name, &arguments_value)
195            .await
196        {
197            Ok(result) => {
198                info!(
199                    tool_name = %request.name,
200                    success = true,
201                    "MCP call_tool request completed successfully"
202                );
203                Ok(result)
204            }
205            Err(e) => {
206                warn!(
207                    tool_name = %request.name,
208                    success = false,
209                    error = %e,
210                    "MCP call_tool request failed"
211                );
212                // Convert ToolCallError to ErrorData and return as error
213                Err(e.into())
214            }
215        }
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Metadata for `GET /pet/{petId}` with a required integer `petId`.
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: "Find pet by ID".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
        }
    }

    /// Metadata for `GET /pet/findByStatus` with a required string-array `status`.
    fn get_pets_by_status_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: "Find pets by status".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
        }
    }

    /// Server with a placeholder spec and base URL, holding exactly `tools`.
    fn server_with_tools(tools: Vec<Tool>) -> Server {
        let mut server = Server::new(
            Value::Null,
            Url::parse("http://example.com").unwrap(),
            None,
            None,
            None,
        );
        server.tool_collection = ToolCollection::from_tools(tools);
        server
    }

    /// JSON form of a `ToolNotFound` error for `requested`, using the server's tool names.
    fn tool_not_found_json(server: &Server, requested: &str) -> Value {
        let known = server.get_tool_names();
        let known_refs: Vec<&str> = known.iter().map(String::as_str).collect();
        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &known_refs,
        ));
        let error_data: ErrorData = error.into();
        serde_json::to_value(&error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let server = server_with_tools(vec![
            Tool::new(get_pet_by_id_metadata(), None, None).unwrap(),
            Tool::new(get_pets_by_status_metadata(), None, None).unwrap(),
        ]);

        // "getPetByID" is a near-miss of "getPetById", so suggestions are expected.
        let error_json = tool_not_found_json(&server, "getPetByID");
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let server =
            server_with_tools(vec![Tool::new(get_pet_by_id_metadata(), None, None).unwrap()]);

        // An unrelated name should produce no suggestions.
        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // JSON-RPC "invalid params" error code
        assert_eq!(error_json["code"], -32602);
        insta::assert_json_snapshot!(error_json);
    }
}