rmcp_openapi/
server.rs

use bon::Builder;
use rmcp::{
    RoleServer, ServerHandler,
    model::{
        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
        ToolsCapability,
    },
    service::RequestContext,
};
use rmcp_actix_web::transport::AuthorizationHeader;
use serde_json::Value;

use reqwest::header::HeaderMap;
use url::Url;

use crate::error::Error;
use crate::tool::{Tool, ToolCollection, ToolMetadata};
use tracing::{Instrument, debug, info, info_span, warn};
20
21#[derive(Clone, Builder)]
22pub struct Server {
23    pub openapi_spec: serde_json::Value,
24    #[builder(default)]
25    pub tool_collection: ToolCollection,
26    pub base_url: Url,
27    pub default_headers: Option<HeaderMap>,
28    pub tag_filter: Option<Vec<String>>,
29    pub method_filter: Option<Vec<reqwest::Method>>,
30}
31
32impl Server {
33    /// Create a new Server instance with required parameters
34    pub fn new(
35        openapi_spec: serde_json::Value,
36        base_url: Url,
37        default_headers: Option<HeaderMap>,
38        tag_filter: Option<Vec<String>>,
39        method_filter: Option<Vec<reqwest::Method>>,
40    ) -> Self {
41        Self {
42            openapi_spec,
43            tool_collection: ToolCollection::new(),
44            base_url,
45            default_headers,
46            tag_filter,
47            method_filter,
48        }
49    }
50
51    /// Parse the `OpenAPI` specification and convert to OpenApiTool instances
52    ///
53    /// # Errors
54    ///
55    /// Returns an error if the spec cannot be parsed or tools cannot be generated
56    pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
57        let span = info_span!("tool_registration");
58        let _enter = span.enter();
59
60        // Parse the OpenAPI specification
61        let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
62
63        // Generate OpenApiTool instances directly
64        let tools = spec.to_openapi_tools(
65            self.tag_filter.as_deref(),
66            self.method_filter.as_deref(),
67            Some(self.base_url.clone()),
68            self.default_headers.clone(),
69        )?;
70
71        self.tool_collection = ToolCollection::from_tools(tools);
72
73        info!(
74            tool_count = self.tool_collection.len(),
75            "Loaded tools from OpenAPI spec"
76        );
77
78        Ok(())
79    }
80
81    /// Get the number of loaded tools
82    #[must_use]
83    pub fn tool_count(&self) -> usize {
84        self.tool_collection.len()
85    }
86
87    /// Get all tool names
88    #[must_use]
89    pub fn get_tool_names(&self) -> Vec<String> {
90        self.tool_collection.get_tool_names()
91    }
92
93    /// Check if a specific tool exists
94    #[must_use]
95    pub fn has_tool(&self, name: &str) -> bool {
96        self.tool_collection.has_tool(name)
97    }
98
99    /// Get a tool by name
100    #[must_use]
101    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
102        self.tool_collection.get_tool(name)
103    }
104
105    /// Get tool metadata by name
106    #[must_use]
107    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
108        self.get_tool(name).map(|tool| &tool.metadata)
109    }
110
111    /// Get basic tool statistics
112    #[must_use]
113    pub fn get_tool_stats(&self) -> String {
114        self.tool_collection.get_stats()
115    }
116
117    /// Simple validation - check that tools are loaded
118    ///
119    /// # Errors
120    ///
121    /// Returns an error if no tools are loaded
122    pub fn validate_registry(&self) -> Result<(), Error> {
123        if self.tool_collection.is_empty() {
124            return Err(Error::McpError("No tools loaded".to_string()));
125        }
126        Ok(())
127    }
128}
129
130impl ServerHandler for Server {
131    fn get_info(&self) -> InitializeResult {
132        InitializeResult {
133            protocol_version: ProtocolVersion::V_2024_11_05,
134            server_info: Implementation {
135                name: "OpenAPI MCP Server".to_string(),
136                version: "0.1.0".to_string(),
137            },
138            capabilities: ServerCapabilities {
139                tools: Some(ToolsCapability {
140                    list_changed: Some(false),
141                }),
142                ..Default::default()
143            },
144            instructions: Some("Exposes OpenAPI endpoints as MCP tools".to_string()),
145        }
146    }
147
148    async fn list_tools(
149        &self,
150        _request: Option<PaginatedRequestParam>,
151        _context: RequestContext<RoleServer>,
152    ) -> Result<ListToolsResult, ErrorData> {
153        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
154        let _enter = span.enter();
155
156        debug!("Processing MCP list_tools request");
157
158        // Delegate to tool collection for MCP tool conversion
159        let tools = self.tool_collection.to_mcp_tools();
160
161        info!(
162            returned_tools = tools.len(),
163            "MCP list_tools request completed successfully"
164        );
165
166        Ok(ListToolsResult {
167            tools,
168            next_cursor: None,
169        })
170    }
171
172    async fn call_tool(
173        &self,
174        request: CallToolRequestParam,
175        context: RequestContext<RoleServer>,
176    ) -> Result<CallToolResult, ErrorData> {
177        let span = info_span!(
178            "call_tool",
179            tool_name = %request.name
180        );
181        let _enter = span.enter();
182
183        debug!(
184            tool_name = %request.name,
185            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
186            "Processing MCP call_tool request"
187        );
188
189        let arguments = request.arguments.unwrap_or_default();
190        let arguments_value = Value::Object(arguments);
191
192        // Extract authorization header from context extensions
193        let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
194
195        if auth_header.is_some() {
196            debug!("Authorization header is present");
197        }
198
199        // Delegate all tool validation and execution to the tool collection
200        match self
201            .tool_collection
202            .call_tool(&request.name, &arguments_value, auth_header)
203            .await
204        {
205            Ok(result) => {
206                info!(
207                    tool_name = %request.name,
208                    success = true,
209                    "MCP call_tool request completed successfully"
210                );
211                Ok(result)
212            }
213            Err(e) => {
214                warn!(
215                    tool_name = %request.name,
216                    success = false,
217                    error = %e,
218                    "MCP call_tool request failed"
219                );
220                // Convert ToolCallError to ErrorData and return as error
221                Err(e.into())
222            }
223        }
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Metadata for `GET /pet/{petId}` with a required integer `petId`.
    fn pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: "Find pet by ID".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": { "type": "integer" }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
        }
    }

    /// Metadata for `GET /pet/findByStatus` with a required string-array `status`.
    fn pets_by_status_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: "Find pets by status".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": { "type": "string" }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
        }
    }

    /// Server with a null spec, `http://example.com` base URL and no filters,
    /// whose collection holds one tool per entry of `metadata`.
    fn server_with(metadata: Vec<ToolMetadata>) -> Server {
        let tools: Vec<Tool> = metadata
            .into_iter()
            .map(|meta| Tool::new(meta, None, None).unwrap())
            .collect();

        let mut server = Server::new(
            serde_json::Value::Null,
            url::Url::parse("http://example.com").unwrap(),
            None,
            None,
            None,
        );
        server.tool_collection = ToolCollection::from_tools(tools);
        server
    }

    /// JSON form of the `ErrorData` for a tool-not-found error on `requested`,
    /// with the server's registered tool names as candidates.
    fn tool_not_found_json(server: &Server, requested: &str) -> serde_json::Value {
        let names = server.get_tool_names();
        let candidates: Vec<&str> = names.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &candidates,
        ));
        let error_data: ErrorData = error.into();
        serde_json::to_value(&error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let server = server_with(vec![pet_by_id_metadata(), pets_by_status_metadata()]);

        // "getPetByID" is a near-miss of "getPetById"; the snapshot records
        // whatever suggestions the error carries.
        let error_json = tool_not_found_json(&server, "getPetByID");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let server = server_with(vec![pet_by_id_metadata()]);

        // A name unrelated to any registered tool; the snapshot records the
        // error without suggestions.
        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // -32602 is the JSON-RPC "Invalid params" error code.
        assert_eq!(error_json["code"], -32602);

        insta::assert_json_snapshot!(error_json);
    }
}