rmcp_openapi/
server.rs

1use bon::Builder;
2use rmcp::{
3    RoleServer, ServerHandler,
4    model::{
5        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
6        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
7        ToolsCapability,
8    },
9    service::RequestContext,
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13
14use reqwest::header::HeaderMap;
15use url::Url;
16
17use crate::config::{Authorization, AuthorizationMode};
18use crate::error::Error;
19use crate::tool::{Tool, ToolCollection, ToolMetadata};
20use tracing::{debug, info, info_span, warn};
21
22#[derive(Clone, Builder)]
23pub struct Server {
24    pub openapi_spec: serde_json::Value,
25    #[builder(default)]
26    pub tool_collection: ToolCollection,
27    pub base_url: Url,
28    pub default_headers: Option<HeaderMap>,
29    pub tag_filter: Option<Vec<String>>,
30    pub method_filter: Option<Vec<reqwest::Method>>,
31    #[builder(default)]
32    pub authorization_mode: AuthorizationMode,
33}
34
35impl Server {
36    /// Create a new Server instance with required parameters
37    pub fn new(
38        openapi_spec: serde_json::Value,
39        base_url: Url,
40        default_headers: Option<HeaderMap>,
41        tag_filter: Option<Vec<String>>,
42        method_filter: Option<Vec<reqwest::Method>>,
43    ) -> Self {
44        Self {
45            openapi_spec,
46            tool_collection: ToolCollection::new(),
47            base_url,
48            default_headers,
49            tag_filter,
50            method_filter,
51            authorization_mode: AuthorizationMode::default(),
52        }
53    }
54
55    /// Parse the `OpenAPI` specification and convert to OpenApiTool instances
56    ///
57    /// # Errors
58    ///
59    /// Returns an error if the spec cannot be parsed or tools cannot be generated
60    pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
61        let span = info_span!("tool_registration");
62        let _enter = span.enter();
63
64        // Parse the OpenAPI specification
65        let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
66
67        // Generate OpenApiTool instances directly
68        let tools = spec.to_openapi_tools(
69            self.tag_filter.as_deref(),
70            self.method_filter.as_deref(),
71            Some(self.base_url.clone()),
72            self.default_headers.clone(),
73        )?;
74
75        self.tool_collection = ToolCollection::from_tools(tools);
76
77        info!(
78            tool_count = self.tool_collection.len(),
79            "Loaded tools from OpenAPI spec"
80        );
81
82        Ok(())
83    }
84
85    /// Get the number of loaded tools
86    #[must_use]
87    pub fn tool_count(&self) -> usize {
88        self.tool_collection.len()
89    }
90
91    /// Get all tool names
92    #[must_use]
93    pub fn get_tool_names(&self) -> Vec<String> {
94        self.tool_collection.get_tool_names()
95    }
96
97    /// Check if a specific tool exists
98    #[must_use]
99    pub fn has_tool(&self, name: &str) -> bool {
100        self.tool_collection.has_tool(name)
101    }
102
103    /// Get a tool by name
104    #[must_use]
105    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
106        self.tool_collection.get_tool(name)
107    }
108
109    /// Get tool metadata by name
110    #[must_use]
111    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
112        self.get_tool(name).map(|tool| &tool.metadata)
113    }
114
115    /// Set the authorization mode for the server
116    pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
117        self.authorization_mode = mode;
118    }
119
120    /// Get the current authorization mode
121    pub fn authorization_mode(&self) -> AuthorizationMode {
122        self.authorization_mode
123    }
124
125    /// Get basic tool statistics
126    #[must_use]
127    pub fn get_tool_stats(&self) -> String {
128        self.tool_collection.get_stats()
129    }
130
131    /// Simple validation - check that tools are loaded
132    ///
133    /// # Errors
134    ///
135    /// Returns an error if no tools are loaded
136    pub fn validate_registry(&self) -> Result<(), Error> {
137        if self.tool_collection.is_empty() {
138            return Err(Error::McpError("No tools loaded".to_string()));
139        }
140        Ok(())
141    }
142}
143
144impl ServerHandler for Server {
145    fn get_info(&self) -> InitializeResult {
146        InitializeResult {
147            protocol_version: ProtocolVersion::V_2024_11_05,
148            server_info: Implementation {
149                name: "OpenAPI MCP Server".to_string(),
150                version: "0.1.0".to_string(),
151            },
152            capabilities: ServerCapabilities {
153                tools: Some(ToolsCapability {
154                    list_changed: Some(false),
155                }),
156                ..Default::default()
157            },
158            instructions: Some("Exposes OpenAPI endpoints as MCP tools".to_string()),
159        }
160    }
161
162    async fn list_tools(
163        &self,
164        _request: Option<PaginatedRequestParam>,
165        _context: RequestContext<RoleServer>,
166    ) -> Result<ListToolsResult, ErrorData> {
167        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
168        let _enter = span.enter();
169
170        debug!("Processing MCP list_tools request");
171
172        // Delegate to tool collection for MCP tool conversion
173        let tools = self.tool_collection.to_mcp_tools();
174
175        info!(
176            returned_tools = tools.len(),
177            "MCP list_tools request completed successfully"
178        );
179
180        Ok(ListToolsResult {
181            tools,
182            next_cursor: None,
183        })
184    }
185
186    async fn call_tool(
187        &self,
188        request: CallToolRequestParam,
189        context: RequestContext<RoleServer>,
190    ) -> Result<CallToolResult, ErrorData> {
191        let span = info_span!(
192            "call_tool",
193            tool_name = %request.name
194        );
195        let _enter = span.enter();
196
197        debug!(
198            tool_name = %request.name,
199            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
200            "Processing MCP call_tool request"
201        );
202
203        let arguments = request.arguments.unwrap_or_default();
204        let arguments_value = Value::Object(arguments);
205
206        // Extract authorization header from context extensions
207        let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
208
209        if auth_header.is_some() {
210            debug!("Authorization header is present");
211        }
212
213        // Create Authorization enum from mode and header
214        let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
215
216        // Delegate all tool validation and execution to the tool collection
217        match self
218            .tool_collection
219            .call_tool(&request.name, &arguments_value, authorization)
220            .await
221        {
222            Ok(result) => {
223                info!(
224                    tool_name = %request.name,
225                    success = true,
226                    "MCP call_tool request completed successfully"
227                );
228                Ok(result)
229            }
230            Err(e) => {
231                warn!(
232                    tool_name = %request.name,
233                    success = false,
234                    error = %e,
235                    "MCP call_tool request failed"
236                );
237                // Convert ToolCallError to ErrorData and return as error
238                Err(e.into())
239            }
240        }
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Metadata for the Petstore `getPetById` operation (`GET /pet/{petId}`).
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: "Find pet by ID".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": { "type": "integer" }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
            security: None,
        }
    }

    /// Metadata for the Petstore `getPetsByStatus` operation (`GET /pet/findByStatus`).
    fn get_pets_by_status_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: "Find pets by status".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": { "type": "string" }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
            security: None,
        }
    }

    /// Builds a server (null spec, `http://example.com`) whose collection holds one tool
    /// per metadata entry, in the given order.
    fn server_with_tools(metadata: Vec<ToolMetadata>) -> Server {
        let tools: Vec<Tool> = metadata
            .into_iter()
            .map(|meta| Tool::new(meta, None, None).unwrap())
            .collect();

        let mut server = Server::new(
            serde_json::Value::Null,
            url::Url::parse("http://example.com").unwrap(),
            None,
            None,
            None,
        );
        server.tool_collection = ToolCollection::from_tools(tools);
        server
    }

    /// Builds the `ToolNotFound` error for `requested`, offering the server's tool names
    /// as candidates, and converts it into `ErrorData`.
    fn tool_not_found_for(server: &Server, requested: &str) -> ErrorData {
        let names = server.get_tool_names();
        let name_refs: Vec<&str> = names.iter().map(String::as_str).collect();
        ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &name_refs,
        ))
        .into()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let server =
            server_with_tools(vec![get_pet_by_id_metadata(), get_pets_by_status_metadata()]);

        // A near-miss of "getPetById" (different letter case).
        let error_data = tool_not_found_for(&server, "getPetByID");
        let error_json = serde_json::to_value(&error_data).unwrap();

        // The snapshot pins the suggestions offered for the typo.
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let server = server_with_tools(vec![get_pet_by_id_metadata()]);

        // A name unrelated to any loaded tool.
        let error_data = tool_not_found_for(&server, "completelyUnrelatedToolName");
        let error_json = serde_json::to_value(&error_data).unwrap();

        // The snapshot pins the absence of suggestions.
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // -32602 is the JSON-RPC "Invalid params" error code.
        assert_eq!(error_json["code"], -32602);

        // The snapshot pins the full error message format.
        insta::assert_json_snapshot!(error_json);
    }
}