// rmcp_openapi/server.rs

1use bon::Builder;
2use rmcp::{
3    RoleServer, ServerHandler,
4    model::{
5        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
6        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
7        ToolsCapability,
8    },
9    service::RequestContext,
10};
11use serde_json::Value;
12
13use reqwest::header::HeaderMap;
14use url::Url;
15
16use crate::error::Error;
17use crate::spec::SpecLocation;
18use crate::tool::{Tool, ToolCollection, ToolMetadata};
19use tracing::{debug, info, info_span, warn};
20
21#[derive(Clone, Builder)]
22pub struct Server {
23    pub spec_location: SpecLocation,
24    #[builder(default)]
25    pub tool_collection: ToolCollection,
26    pub base_url: Option<Url>,
27    pub default_headers: Option<HeaderMap>,
28    pub tag_filter: Option<Vec<String>>,
29    pub method_filter: Option<Vec<reqwest::Method>>,
30}
31
32impl Server {
33    /// Create a new server with basic configuration (backwards compatibility)
34    #[must_use]
35    pub fn new(spec_location: SpecLocation) -> Self {
36        Self::builder().spec_location(spec_location).build()
37    }
38
39    /// Create a new server with a base URL for API calls (backwards compatibility)
40    ///
41    /// # Errors
42    ///
43    /// Returns an error if the base URL is invalid
44    pub fn with_base_url(spec_location: SpecLocation, base_url: Url) -> Result<Self, Error> {
45        Ok(Self::builder()
46            .spec_location(spec_location)
47            .base_url(base_url)
48            .build())
49    }
50
51    /// Create a new server with both base URL and default headers (backwards compatibility)
52    ///
53    /// # Errors
54    ///
55    /// Returns an error if the base URL is invalid
56    pub fn with_base_url_and_headers(
57        spec_location: SpecLocation,
58        base_url: Url,
59        default_headers: HeaderMap,
60    ) -> Result<Self, Error> {
61        Ok(Self::builder()
62            .spec_location(spec_location)
63            .base_url(base_url)
64            .default_headers(default_headers)
65            .build())
66    }
67
68    /// Create a new server with default headers but no base URL (backwards compatibility)
69    #[must_use]
70    pub fn with_default_headers(spec_location: SpecLocation, default_headers: HeaderMap) -> Self {
71        Self::builder()
72            .spec_location(spec_location)
73            .default_headers(default_headers)
74            .build()
75    }
76
77    /// Load the `OpenAPI` specification and convert to OpenApiTool instances
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if the spec cannot be loaded or tools cannot be generated
82    pub async fn load_openapi_spec(&mut self) -> Result<(), Error> {
83        let span = info_span!(
84            "tool_registration",
85            spec_location = %self.spec_location
86        );
87        let _enter = span.enter();
88
89        // Load the OpenAPI specification
90        let spec = self.spec_location.load_spec().await?;
91
92        // Generate OpenApiTool instances directly
93        let tools = spec.to_openapi_tools(
94            self.tag_filter.as_deref(),
95            self.method_filter.as_deref(),
96            self.base_url.clone(),
97            self.default_headers.clone(),
98        )?;
99
100        self.tool_collection = ToolCollection::from_tools(tools);
101
102        info!(
103            tool_count = self.tool_collection.len(),
104            "Loaded tools from OpenAPI spec"
105        );
106
107        Ok(())
108    }
109
110    /// Get the number of loaded tools
111    #[must_use]
112    pub fn tool_count(&self) -> usize {
113        self.tool_collection.len()
114    }
115
116    /// Get all tool names
117    #[must_use]
118    pub fn get_tool_names(&self) -> Vec<String> {
119        self.tool_collection.get_tool_names()
120    }
121
122    /// Check if a specific tool exists
123    #[must_use]
124    pub fn has_tool(&self, name: &str) -> bool {
125        self.tool_collection.has_tool(name)
126    }
127
128    /// Get a tool by name
129    #[must_use]
130    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
131        self.tool_collection.get_tool(name)
132    }
133
134    /// Get tool metadata by name
135    #[must_use]
136    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
137        self.get_tool(name).map(|tool| &tool.metadata)
138    }
139
140    /// Get basic tool statistics
141    #[must_use]
142    pub fn get_tool_stats(&self) -> String {
143        self.tool_collection.get_stats()
144    }
145
146    /// Set tag filter for this server instance
147    #[must_use]
148    pub fn with_tags(mut self, tags: Option<Vec<String>>) -> Self {
149        self.tag_filter = tags;
150        self
151    }
152
153    /// Set method filter for this server instance
154    #[must_use]
155    pub fn with_methods(mut self, methods: Option<Vec<reqwest::Method>>) -> Self {
156        self.method_filter = methods;
157        self
158    }
159
160    /// Simple validation - check that tools are loaded
161    ///
162    /// # Errors
163    ///
164    /// Returns an error if no tools are loaded
165    pub fn validate_registry(&self) -> Result<(), Error> {
166        if self.tool_collection.is_empty() {
167            return Err(Error::McpError("No tools loaded".to_string()));
168        }
169        Ok(())
170    }
171}
172
173impl ServerHandler for Server {
174    fn get_info(&self) -> InitializeResult {
175        InitializeResult {
176            protocol_version: ProtocolVersion::V_2024_11_05,
177            server_info: Implementation {
178                name: "OpenAPI MCP Server".to_string(),
179                version: "0.1.0".to_string(),
180            },
181            capabilities: ServerCapabilities {
182                tools: Some(ToolsCapability {
183                    list_changed: Some(false),
184                }),
185                ..Default::default()
186            },
187            instructions: Some("Exposes OpenAPI endpoints as MCP tools".to_string()),
188        }
189    }
190
191    async fn list_tools(
192        &self,
193        _request: Option<PaginatedRequestParam>,
194        _context: RequestContext<RoleServer>,
195    ) -> Result<ListToolsResult, ErrorData> {
196        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
197        let _enter = span.enter();
198
199        debug!("Processing MCP list_tools request");
200
201        // Delegate to tool collection for MCP tool conversion
202        let tools = self.tool_collection.to_mcp_tools();
203
204        info!(
205            returned_tools = tools.len(),
206            "MCP list_tools request completed successfully"
207        );
208
209        Ok(ListToolsResult {
210            tools,
211            next_cursor: None,
212        })
213    }
214
215    async fn call_tool(
216        &self,
217        request: CallToolRequestParam,
218        _context: RequestContext<RoleServer>,
219    ) -> Result<CallToolResult, ErrorData> {
220        let span = info_span!(
221            "call_tool",
222            tool_name = %request.name
223        );
224        let _enter = span.enter();
225
226        debug!(
227            tool_name = %request.name,
228            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
229            "Processing MCP call_tool request"
230        );
231
232        let arguments = request.arguments.unwrap_or_default();
233        let arguments_value = Value::Object(arguments);
234
235        // Delegate all tool validation and execution to the tool collection
236        match self
237            .tool_collection
238            .call_tool(&request.name, &arguments_value)
239            .await
240        {
241            Ok(result) => {
242                info!(
243                    tool_name = %request.name,
244                    success = true,
245                    "MCP call_tool request completed successfully"
246                );
247                Ok(result)
248            }
249            Err(e) => {
250                warn!(
251                    tool_name = %request.name,
252                    success = false,
253                    error = %e,
254                    "MCP call_tool request failed"
255                );
256                // Convert ToolCallError to ErrorData and return as error
257                Err(e.into())
258            }
259        }
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Metadata fixture for the `getPetById` operation.
    fn pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: "Find pet by ID".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
        }
    }

    /// Metadata fixture for the `getPetsByStatus` operation.
    fn pets_by_status_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: "Find pets by status".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
        }
    }

    /// Builds a server whose tool collection holds one tool per metadata entry.
    fn server_with_tools(metadata: Vec<ToolMetadata>) -> Server {
        let tools = metadata
            .into_iter()
            .map(|meta| Tool::new(meta, None, None).unwrap())
            .collect::<Vec<_>>();

        let mut server = Server::new(SpecLocation::Url(Url::parse("test://example").unwrap()));
        server.tool_collection = ToolCollection::from_tools(tools);
        server
    }

    /// Serializes the `ErrorData` produced by a tool-not-found error for `requested`,
    /// using the server's loaded tool names as the candidate list.
    fn tool_not_found_json(server: &Server, requested: &str) -> serde_json::Value {
        let known_names = server.get_tool_names();
        let known_refs: Vec<&str> = known_names.iter().map(String::as_str).collect();

        let not_found = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &known_refs,
        ));
        let as_error_data: ErrorData = not_found.into();
        serde_json::to_value(&as_error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let server = server_with_tools(vec![pet_by_id_metadata(), pets_by_status_metadata()]);

        // A near-miss of `getPetById` (wrong casing in the suffix).
        let error_json = tool_not_found_json(&server, "getPetByID");

        // Snapshot the error to verify suggestions
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let server = server_with_tools(vec![pet_by_id_metadata()]);

        // A name unrelated to any loaded tool.
        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        // Snapshot the error to verify no suggestions
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        // An invalid-parameter violation for `page`, with two valid parameter names supplied.
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // Verify the basic structure
        assert_eq!(error_json["code"], -32602); // Invalid params error code

        // Snapshot the full error to verify the new error message format
        insta::assert_json_snapshot!(error_json);
    }
}