rmcp_openapi/
server.rs

use bon::Builder;
use rmcp::{
    RoleServer, ServerHandler,
    model::{
        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
        ToolsCapability,
    },
    service::RequestContext,
};
use serde_json::Value;

use reqwest::header::HeaderMap;
use url::Url;

use crate::error::Error;
use crate::tool::{Tool, ToolCollection, ToolMetadata};
use tracing::{Instrument, debug, info, info_span, warn};
19
/// MCP server that exposes the operations of an OpenAPI document as MCP tools.
///
/// Construct it with [`Server::new`], one of the `with_*` constructors, or the
/// `bon`-generated `Server::builder()`, then call `load_openapi_spec` to generate
/// the tools held in `tool_collection`.
#[derive(Clone, Builder)]
pub struct Server {
    /// The raw OpenAPI document; parsed each time `load_openapi_spec` runs.
    pub openapi_spec: serde_json::Value,
    /// Tools generated from `openapi_spec`. Starts at `ToolCollection::default()`
    /// when built, and is replaced wholesale by a successful `load_openapi_spec`.
    #[builder(default)]
    pub tool_collection: ToolCollection,
    /// Base URL forwarded to tool generation; presumably the root that generated
    /// tools send their HTTP requests to — TODO confirm in `Spec::to_openapi_tools`.
    pub base_url: Option<Url>,
    /// Headers forwarded to tool generation; presumably attached to every request
    /// a generated tool makes — TODO confirm in `Spec::to_openapi_tools`.
    pub default_headers: Option<HeaderMap>,
    /// Optional tag filter passed to tool generation (`None` = no filter); presumably
    /// keeps only operations carrying one of these tags.
    pub tag_filter: Option<Vec<String>>,
    /// Optional HTTP-method filter passed to tool generation (`None` = no filter).
    pub method_filter: Option<Vec<reqwest::Method>>,
}
30
31impl Server {
32    /// Create a new server with basic configuration
33    #[must_use]
34    pub fn new(openapi_spec: serde_json::Value) -> Self {
35        Self::builder().openapi_spec(openapi_spec).build()
36    }
37
38    /// Create a new server with a base URL for API calls
39    ///
40    /// # Errors
41    ///
42    /// Returns an error if the base URL is invalid
43    pub fn with_base_url(openapi_spec: serde_json::Value, base_url: Url) -> Result<Self, Error> {
44        Ok(Self::builder()
45            .openapi_spec(openapi_spec)
46            .base_url(base_url)
47            .build())
48    }
49
50    /// Create a new server with both base URL and default headers
51    ///
52    /// # Errors
53    ///
54    /// Returns an error if the base URL is invalid
55    pub fn with_base_url_and_headers(
56        openapi_spec: serde_json::Value,
57        base_url: Url,
58        default_headers: HeaderMap,
59    ) -> Result<Self, Error> {
60        Ok(Self::builder()
61            .openapi_spec(openapi_spec)
62            .base_url(base_url)
63            .default_headers(default_headers)
64            .build())
65    }
66
67    /// Create a new server with default headers but no base URL
68    #[must_use]
69    pub fn with_default_headers(
70        openapi_spec: serde_json::Value,
71        default_headers: HeaderMap,
72    ) -> Self {
73        Self::builder()
74            .openapi_spec(openapi_spec)
75            .default_headers(default_headers)
76            .build()
77    }
78
79    /// Parse the `OpenAPI` specification and convert to OpenApiTool instances
80    ///
81    /// # Errors
82    ///
83    /// Returns an error if the spec cannot be parsed or tools cannot be generated
84    pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
85        let span = info_span!("tool_registration");
86        let _enter = span.enter();
87
88        // Parse the OpenAPI specification
89        let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
90
91        // Generate OpenApiTool instances directly
92        let tools = spec.to_openapi_tools(
93            self.tag_filter.as_deref(),
94            self.method_filter.as_deref(),
95            self.base_url.clone(),
96            self.default_headers.clone(),
97        )?;
98
99        self.tool_collection = ToolCollection::from_tools(tools);
100
101        info!(
102            tool_count = self.tool_collection.len(),
103            "Loaded tools from OpenAPI spec"
104        );
105
106        Ok(())
107    }
108
109    /// Get the number of loaded tools
110    #[must_use]
111    pub fn tool_count(&self) -> usize {
112        self.tool_collection.len()
113    }
114
115    /// Get all tool names
116    #[must_use]
117    pub fn get_tool_names(&self) -> Vec<String> {
118        self.tool_collection.get_tool_names()
119    }
120
121    /// Check if a specific tool exists
122    #[must_use]
123    pub fn has_tool(&self, name: &str) -> bool {
124        self.tool_collection.has_tool(name)
125    }
126
127    /// Get a tool by name
128    #[must_use]
129    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
130        self.tool_collection.get_tool(name)
131    }
132
133    /// Get tool metadata by name
134    #[must_use]
135    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
136        self.get_tool(name).map(|tool| &tool.metadata)
137    }
138
139    /// Get basic tool statistics
140    #[must_use]
141    pub fn get_tool_stats(&self) -> String {
142        self.tool_collection.get_stats()
143    }
144
145    /// Set tag filter for this server instance
146    #[must_use]
147    pub fn with_tags(mut self, tags: Option<Vec<String>>) -> Self {
148        self.tag_filter = tags;
149        self
150    }
151
152    /// Set method filter for this server instance
153    #[must_use]
154    pub fn with_methods(mut self, methods: Option<Vec<reqwest::Method>>) -> Self {
155        self.method_filter = methods;
156        self
157    }
158
159    /// Simple validation - check that tools are loaded
160    ///
161    /// # Errors
162    ///
163    /// Returns an error if no tools are loaded
164    pub fn validate_registry(&self) -> Result<(), Error> {
165        if self.tool_collection.is_empty() {
166            return Err(Error::McpError("No tools loaded".to_string()));
167        }
168        Ok(())
169    }
170}
171
172impl ServerHandler for Server {
173    fn get_info(&self) -> InitializeResult {
174        InitializeResult {
175            protocol_version: ProtocolVersion::V_2024_11_05,
176            server_info: Implementation {
177                name: "OpenAPI MCP Server".to_string(),
178                version: "0.1.0".to_string(),
179            },
180            capabilities: ServerCapabilities {
181                tools: Some(ToolsCapability {
182                    list_changed: Some(false),
183                }),
184                ..Default::default()
185            },
186            instructions: Some("Exposes OpenAPI endpoints as MCP tools".to_string()),
187        }
188    }
189
190    async fn list_tools(
191        &self,
192        _request: Option<PaginatedRequestParam>,
193        _context: RequestContext<RoleServer>,
194    ) -> Result<ListToolsResult, ErrorData> {
195        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
196        let _enter = span.enter();
197
198        debug!("Processing MCP list_tools request");
199
200        // Delegate to tool collection for MCP tool conversion
201        let tools = self.tool_collection.to_mcp_tools();
202
203        info!(
204            returned_tools = tools.len(),
205            "MCP list_tools request completed successfully"
206        );
207
208        Ok(ListToolsResult {
209            tools,
210            next_cursor: None,
211        })
212    }
213
214    async fn call_tool(
215        &self,
216        request: CallToolRequestParam,
217        _context: RequestContext<RoleServer>,
218    ) -> Result<CallToolResult, ErrorData> {
219        let span = info_span!(
220            "call_tool",
221            tool_name = %request.name
222        );
223        let _enter = span.enter();
224
225        debug!(
226            tool_name = %request.name,
227            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
228            "Processing MCP call_tool request"
229        );
230
231        let arguments = request.arguments.unwrap_or_default();
232        let arguments_value = Value::Object(arguments);
233
234        // Delegate all tool validation and execution to the tool collection
235        match self
236            .tool_collection
237            .call_tool(&request.name, &arguments_value)
238            .await
239        {
240            Ok(result) => {
241                info!(
242                    tool_name = %request.name,
243                    success = true,
244                    "MCP call_tool request completed successfully"
245                );
246                Ok(result)
247            }
248            Err(e) => {
249                warn!(
250                    tool_name = %request.name,
251                    success = false,
252                    error = %e,
253                    "MCP call_tool request failed"
254                );
255                // Convert ToolCallError to ErrorData and return as error
256                Err(e.into())
257            }
258        }
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Metadata for the `GET /pet/{petId}` tool used by both not-found tests.
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: "Find pet by ID".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
        }
    }

    /// Build a server (with a `Null` spec) whose collection holds one tool per entry.
    fn server_with(all_metadata: Vec<ToolMetadata>) -> Server {
        let tools: Vec<Tool> = all_metadata
            .into_iter()
            .map(|metadata| Tool::new(metadata, None, None).unwrap())
            .collect();

        let mut server = Server::new(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(tools);
        server
    }

    /// Serialize the `ToolNotFound` error for `requested`, offering `server`'s tool
    /// names as candidate suggestions.
    fn tool_not_found_json(server: &Server, requested: &str) -> serde_json::Value {
        let known_names = server.get_tool_names();
        let known_refs: Vec<&str> = known_names.iter().map(String::as_str).collect();

        let not_found = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &known_refs,
        ));
        let as_error_data: ErrorData = not_found.into();
        serde_json::to_value(&as_error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let get_pets_by_status = ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: "Find pets by status".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
        };
        let server = server_with(vec![get_pet_by_id_metadata(), get_pets_by_status]);

        // "getPetByID" is a near-miss of "getPetById"; the snapshot records the
        // suggestions the error carries.
        let error_json = tool_not_found_json(&server, "getPetByID");
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let server = server_with(vec![get_pet_by_id_metadata()]);

        // A name unrelated to any loaded tool; the snapshot records the error as-is.
        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // JSON-RPC "Invalid params" code.
        assert_eq!(error_json["code"], -32602);

        insta::assert_json_snapshot!(error_json);
    }
}