rmcp_openapi/
server.rs

1use rmcp::{
2    RoleServer, ServerHandler,
3    model::{
4        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
5        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities, Tool,
6        ToolsCapability,
7    },
8    service::RequestContext,
9};
10use serde_json::Value;
11
12use reqwest::header::HeaderMap;
13use url::Url;
14
15use crate::error::{OpenApiError, ToolCallValidationError};
16use crate::openapi::OpenApiSpecLocation;
17use crate::tool::OpenApiTool;
18
19#[derive(Clone)]
20pub struct OpenApiServer {
21    pub spec_location: OpenApiSpecLocation,
22    pub tools: Vec<OpenApiTool>,
23    pub base_url: Option<Url>,
24    pub default_headers: Option<HeaderMap>,
25    pub tag_filter: Option<Vec<String>>,
26    pub method_filter: Option<Vec<reqwest::Method>>,
27}
28
29impl OpenApiServer {
30    #[must_use]
31    pub fn new(spec_location: OpenApiSpecLocation) -> Self {
32        Self {
33            spec_location,
34            tools: Vec::new(),
35            base_url: None,
36            default_headers: None,
37            tag_filter: None,
38            method_filter: None,
39        }
40    }
41
42    /// Create a new server with a base URL for API calls
43    ///
44    /// # Errors
45    ///
46    /// Returns an error if the base URL is invalid
47    pub fn with_base_url(
48        spec_location: OpenApiSpecLocation,
49        base_url: Url,
50    ) -> Result<Self, OpenApiError> {
51        Ok(Self {
52            spec_location,
53            tools: Vec::new(),
54            base_url: Some(base_url),
55            default_headers: None,
56            tag_filter: None,
57            method_filter: None,
58        })
59    }
60
61    /// Create a new server with both base URL and default headers
62    ///
63    /// # Errors
64    ///
65    /// Returns an error if the base URL is invalid
66    pub fn with_base_url_and_headers(
67        spec_location: OpenApiSpecLocation,
68        base_url: Url,
69        default_headers: HeaderMap,
70    ) -> Result<Self, OpenApiError> {
71        Ok(Self {
72            spec_location,
73            tools: Vec::new(),
74            base_url: Some(base_url),
75            default_headers: Some(default_headers),
76            tag_filter: None,
77            method_filter: None,
78        })
79    }
80
81    /// Create a new server with default headers but no base URL
82    #[must_use]
83    pub fn with_default_headers(
84        spec_location: OpenApiSpecLocation,
85        default_headers: HeaderMap,
86    ) -> Self {
87        Self {
88            spec_location,
89            tools: Vec::new(),
90            base_url: None,
91            default_headers: Some(default_headers),
92            tag_filter: None,
93            method_filter: None,
94        }
95    }
96
97    /// Load the `OpenAPI` specification and convert to OpenApiTool instances
98    ///
99    /// # Errors
100    ///
101    /// Returns an error if the spec cannot be loaded or tools cannot be generated
102    pub async fn load_openapi_spec(&mut self) -> Result<(), OpenApiError> {
103        // Load the OpenAPI specification
104        let spec = self.spec_location.load_spec().await?;
105
106        // Generate OpenApiTool instances directly
107        let tools = spec.to_openapi_tools(
108            self.tag_filter.as_deref(),
109            self.method_filter.as_deref(),
110            self.base_url.clone(),
111            self.default_headers.clone(),
112        )?;
113
114        self.tools = tools;
115
116        println!("Loaded {} tools from OpenAPI spec", self.tools.len());
117
118        Ok(())
119    }
120
121    /// Get the number of loaded tools
122    #[must_use]
123    pub fn tool_count(&self) -> usize {
124        self.tools.len()
125    }
126
127    /// Get all tool names
128    #[must_use]
129    pub fn get_tool_names(&self) -> Vec<String> {
130        self.tools
131            .iter()
132            .map(|tool| tool.metadata.name.clone())
133            .collect()
134    }
135
136    /// Check if a specific tool exists
137    #[must_use]
138    pub fn has_tool(&self, name: &str) -> bool {
139        self.tools.iter().any(|tool| tool.metadata.name == name)
140    }
141
142    /// Get a tool by name
143    #[must_use]
144    pub fn get_tool(&self, name: &str) -> Option<&crate::tool::OpenApiTool> {
145        self.tools.iter().find(|tool| tool.metadata.name == name)
146    }
147
148    /// Get tool metadata by name
149    #[must_use]
150    pub fn get_tool_metadata(&self, name: &str) -> Option<&crate::ToolMetadata> {
151        self.get_tool(name).map(|tool| &tool.metadata)
152    }
153
154    /// Get basic tool statistics
155    #[must_use]
156    pub fn get_tool_stats(&self) -> String {
157        format!("Total tools: {}", self.tools.len())
158    }
159
160    /// Set tag filter for this server instance
161    #[must_use]
162    pub fn with_tags(mut self, tags: Option<Vec<String>>) -> Self {
163        self.tag_filter = tags;
164        self
165    }
166
167    /// Set method filter for this server instance
168    #[must_use]
169    pub fn with_methods(mut self, methods: Option<Vec<reqwest::Method>>) -> Self {
170        self.method_filter = methods;
171        self
172    }
173
174    /// Simple validation - check that tools are loaded
175    ///
176    /// # Errors
177    ///
178    /// Returns an error if no tools are loaded
179    pub fn validate_registry(&self) -> Result<(), OpenApiError> {
180        if self.tools.is_empty() {
181            return Err(OpenApiError::McpError("No tools loaded".to_string()));
182        }
183        Ok(())
184    }
185}
186
187impl ServerHandler for OpenApiServer {
188    fn get_info(&self) -> InitializeResult {
189        InitializeResult {
190            protocol_version: ProtocolVersion::V_2024_11_05,
191            server_info: Implementation {
192                name: "OpenAPI MCP Server".to_string(),
193                version: "0.1.0".to_string(),
194            },
195            capabilities: ServerCapabilities {
196                tools: Some(ToolsCapability {
197                    list_changed: Some(false),
198                }),
199                ..Default::default()
200            },
201            instructions: Some("Exposes OpenAPI endpoints as MCP tools".to_string()),
202        }
203    }
204
205    async fn list_tools(
206        &self,
207        _request: Option<PaginatedRequestParam>,
208        _context: RequestContext<RoleServer>,
209    ) -> Result<ListToolsResult, ErrorData> {
210        let mut tools = Vec::new();
211
212        // Convert all OpenApiTool instances to MCP Tool format
213        for openapi_tool in &self.tools {
214            let tool = Tool::from(openapi_tool);
215            tools.push(tool);
216        }
217
218        Ok(ListToolsResult {
219            tools,
220            next_cursor: None,
221        })
222    }
223
224    async fn call_tool(
225        &self,
226        request: CallToolRequestParam,
227        _context: RequestContext<RoleServer>,
228    ) -> Result<CallToolResult, ErrorData> {
229        // Find the tool by name
230        if let Some(openapi_tool) = self
231            .tools
232            .iter()
233            .find(|tool| tool.metadata.name == request.name)
234        {
235            let arguments = request.arguments.unwrap_or_default();
236            let arguments_value = Value::Object(arguments.clone());
237
238            // Call the tool directly
239            match openapi_tool.call(&arguments_value).await {
240                Ok(result) => Ok(result),
241                Err(e) => {
242                    // Convert ToolCallError to ErrorData and return as error
243                    Err(e.into())
244                }
245            }
246        } else {
247            // Generate tool name suggestions when tool not found
248            let tool_names: Vec<&str> = self
249                .tools
250                .iter()
251                .map(|tool| tool.metadata.name.as_str())
252                .collect();
253            let suggestions = crate::find_similar_strings(&request.name, &tool_names);
254
255            // Create ToolCallValidationError with suggestions
256            let error = ToolCallValidationError::ToolNotFound {
257                tool_name: request.name.to_string(),
258                suggestions,
259            };
260            Err(error.into())
261        }
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolMetadata;
    use crate::error::ToolCallError;
    use serde_json::json;

    /// Metadata for a `GET /pet/{petId}` tool named `getPetById`.
    fn pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: "Find pet by ID".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
        }
    }

    /// Metadata for a `GET /pet/findByStatus` tool named `getPetsByStatus`.
    fn pets_by_status_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: "Find pets by status".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
        }
    }

    /// Build a server with a dummy spec location holding one tool per metadata entry.
    fn server_with(metadata: Vec<ToolMetadata>) -> OpenApiServer {
        let mut server = OpenApiServer::new(OpenApiSpecLocation::Url(
            Url::parse("test://example").unwrap(),
        ));
        server.tools = metadata
            .into_iter()
            .map(|meta| OpenApiTool::new(meta, None, None).unwrap())
            .collect();
        server
    }

    /// Serialize the `ErrorData` for a `ToolNotFound` error on `requested`, with
    /// suggestions computed against the server's tool names.
    fn tool_not_found_json(server: &OpenApiServer, requested: &str) -> Value {
        let names = server.get_tool_names();
        let name_refs: Vec<&str> = names.iter().map(String::as_str).collect();
        let suggestions = crate::find_similar_strings(requested, &name_refs);

        let error_data: ErrorData =
            ToolCallError::Validation(ToolCallValidationError::ToolNotFound {
                tool_name: requested.to_string(),
                suggestions,
            })
            .into();
        serde_json::to_value(&error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let server = server_with(vec![pet_by_id_metadata(), pets_by_status_metadata()]);

        // "getPetByID" is a near-miss of "getPetById"; the snapshot records the suggestions.
        let error_json = tool_not_found_json(&server, "getPetByID");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let server = server_with(vec![pet_by_id_metadata()]);

        // A name unrelated to any loaded tool; the snapshot records the (empty) suggestions.
        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::InvalidParameter {
            parameter: "page".to_string(),
            suggestions: vec!["page_number".to_string()],
            valid_parameters: vec!["page_number".to_string(), "page_size".to_string()],
        };
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // JSON-RPC "Invalid params" error code.
        assert_eq!(error_json["code"], -32602);

        insta::assert_json_snapshot!(error_json);
    }
}