rmcp_openapi/
server.rs

1use bon::Builder;
2use rmcp::{
3    RoleServer, ServerHandler,
4    model::{
5        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
6        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
7        ToolsCapability,
8    },
9    service::RequestContext,
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13
14use reqwest::header::HeaderMap;
15use url::Url;
16
17use crate::config::{Authorization, AuthorizationMode};
18use crate::error::Error;
19use crate::tool::{Tool, ToolCollection, ToolMetadata};
20use tracing::{debug, info, info_span, warn};
21
22#[derive(Clone, Builder)]
23pub struct Server {
24    pub openapi_spec: serde_json::Value,
25    #[builder(default)]
26    pub tool_collection: ToolCollection,
27    pub base_url: Url,
28    pub default_headers: Option<HeaderMap>,
29    pub tag_filter: Option<Vec<String>>,
30    pub method_filter: Option<Vec<reqwest::Method>>,
31    #[builder(default)]
32    pub authorization_mode: AuthorizationMode,
33    pub name: Option<String>,
34    pub version: Option<String>,
35    pub title: Option<String>,
36    pub instructions: Option<String>,
37}
38
39impl Server {
40    /// Create a new Server instance with required parameters
41    pub fn new(
42        openapi_spec: serde_json::Value,
43        base_url: Url,
44        default_headers: Option<HeaderMap>,
45        tag_filter: Option<Vec<String>>,
46        method_filter: Option<Vec<reqwest::Method>>,
47    ) -> Self {
48        Self {
49            openapi_spec,
50            tool_collection: ToolCollection::new(),
51            base_url,
52            default_headers,
53            tag_filter,
54            method_filter,
55            authorization_mode: AuthorizationMode::default(),
56            name: None,
57            version: None,
58            title: None,
59            instructions: None,
60        }
61    }
62
63    /// Parse the `OpenAPI` specification and convert to OpenApiTool instances
64    ///
65    /// # Errors
66    ///
67    /// Returns an error if the spec cannot be parsed or tools cannot be generated
68    pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
69        let span = info_span!("tool_registration");
70        let _enter = span.enter();
71
72        // Parse the OpenAPI specification
73        let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
74
75        // Generate OpenApiTool instances directly
76        let tools = spec.to_openapi_tools(
77            self.tag_filter.as_deref(),
78            self.method_filter.as_deref(),
79            Some(self.base_url.clone()),
80            self.default_headers.clone(),
81        )?;
82
83        self.tool_collection = ToolCollection::from_tools(tools);
84
85        info!(
86            tool_count = self.tool_collection.len(),
87            "Loaded tools from OpenAPI spec"
88        );
89
90        Ok(())
91    }
92
93    /// Get the number of loaded tools
94    #[must_use]
95    pub fn tool_count(&self) -> usize {
96        self.tool_collection.len()
97    }
98
99    /// Get all tool names
100    #[must_use]
101    pub fn get_tool_names(&self) -> Vec<String> {
102        self.tool_collection.get_tool_names()
103    }
104
105    /// Check if a specific tool exists
106    #[must_use]
107    pub fn has_tool(&self, name: &str) -> bool {
108        self.tool_collection.has_tool(name)
109    }
110
111    /// Get a tool by name
112    #[must_use]
113    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
114        self.tool_collection.get_tool(name)
115    }
116
117    /// Get tool metadata by name
118    #[must_use]
119    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
120        self.get_tool(name).map(|tool| &tool.metadata)
121    }
122
123    /// Set the authorization mode for the server
124    pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
125        self.authorization_mode = mode;
126    }
127
128    /// Get the current authorization mode
129    pub fn authorization_mode(&self) -> AuthorizationMode {
130        self.authorization_mode
131    }
132
133    /// Get basic tool statistics
134    #[must_use]
135    pub fn get_tool_stats(&self) -> String {
136        self.tool_collection.get_stats()
137    }
138
139    /// Simple validation - check that tools are loaded
140    ///
141    /// # Errors
142    ///
143    /// Returns an error if no tools are loaded
144    pub fn validate_registry(&self) -> Result<(), Error> {
145        if self.tool_collection.is_empty() {
146            return Err(Error::McpError("No tools loaded".to_string()));
147        }
148        Ok(())
149    }
150
151    /// Extract title from OpenAPI spec info section
152    fn extract_openapi_title(&self) -> Option<String> {
153        self.openapi_spec
154            .get("info")?
155            .get("title")?
156            .as_str()
157            .map(|s| s.to_string())
158    }
159
160    /// Extract version from OpenAPI spec info section
161    fn extract_openapi_version(&self) -> Option<String> {
162        self.openapi_spec
163            .get("info")?
164            .get("version")?
165            .as_str()
166            .map(|s| s.to_string())
167    }
168
169    /// Extract description from OpenAPI spec info section
170    fn extract_openapi_description(&self) -> Option<String> {
171        self.openapi_spec
172            .get("info")?
173            .get("description")?
174            .as_str()
175            .map(|s| s.to_string())
176    }
177
178    /// Extract display title from OpenAPI spec info section
179    /// First checks for x-display-title extension, then derives from title
180    fn extract_openapi_display_title(&self) -> Option<String> {
181        // First check for x-display-title extension
182        if let Some(display_title) = self
183            .openapi_spec
184            .get("info")
185            .and_then(|info| info.get("x-display-title"))
186            .and_then(|t| t.as_str())
187        {
188            return Some(display_title.to_string());
189        }
190
191        // Fallback: enhance the title with "Server" suffix if not already present
192        self.extract_openapi_title().map(|title| {
193            if title.to_lowercase().contains("server") {
194                title
195            } else {
196                format!("{} Server", title)
197            }
198        })
199    }
200}
201
202impl ServerHandler for Server {
203    fn get_info(&self) -> InitializeResult {
204        // 3-level fallback for server name: custom -> OpenAPI spec -> default
205        let server_name = self
206            .name
207            .clone()
208            .or_else(|| self.extract_openapi_title())
209            .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
210
211        // 3-level fallback for server version: custom -> OpenAPI spec -> crate version
212        let server_version = self
213            .version
214            .clone()
215            .or_else(|| self.extract_openapi_version())
216            .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
217
218        // 3-level fallback for title: custom -> OpenAPI-derived -> None
219        let server_title = self
220            .title
221            .clone()
222            .or_else(|| self.extract_openapi_display_title());
223
224        // 3-level fallback for instructions: custom -> OpenAPI spec -> default
225        let instructions = self
226            .instructions
227            .clone()
228            .or_else(|| self.extract_openapi_description())
229            .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
230
231        InitializeResult {
232            protocol_version: ProtocolVersion::V_2024_11_05,
233            server_info: Implementation {
234                name: server_name,
235                version: server_version,
236                title: server_title,
237                icons: None,
238                website_url: None,
239            },
240            capabilities: ServerCapabilities {
241                tools: Some(ToolsCapability {
242                    list_changed: Some(false),
243                }),
244                ..Default::default()
245            },
246            instructions,
247        }
248    }
249
250    async fn list_tools(
251        &self,
252        _request: Option<PaginatedRequestParam>,
253        _context: RequestContext<RoleServer>,
254    ) -> Result<ListToolsResult, ErrorData> {
255        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
256        let _enter = span.enter();
257
258        debug!("Processing MCP list_tools request");
259
260        // Delegate to tool collection for MCP tool conversion
261        let tools = self.tool_collection.to_mcp_tools();
262
263        info!(
264            returned_tools = tools.len(),
265            "MCP list_tools request completed successfully"
266        );
267
268        Ok(ListToolsResult {
269            tools,
270            next_cursor: None,
271        })
272    }
273
274    async fn call_tool(
275        &self,
276        request: CallToolRequestParam,
277        context: RequestContext<RoleServer>,
278    ) -> Result<CallToolResult, ErrorData> {
279        let span = info_span!(
280            "call_tool",
281            tool_name = %request.name
282        );
283        let _enter = span.enter();
284
285        debug!(
286            tool_name = %request.name,
287            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
288            "Processing MCP call_tool request"
289        );
290
291        let arguments = request.arguments.unwrap_or_default();
292        let arguments_value = Value::Object(arguments);
293
294        // Extract authorization header from context extensions
295        let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
296
297        if auth_header.is_some() {
298            debug!("Authorization header is present");
299        }
300
301        // Create Authorization enum from mode and header
302        let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
303
304        // Delegate all tool validation and execution to the tool collection
305        match self
306            .tool_collection
307            .call_tool(&request.name, &arguments_value, authorization)
308            .await
309        {
310            Ok(result) => {
311                info!(
312                    tool_name = %request.name,
313                    success = true,
314                    "MCP call_tool request completed successfully"
315                );
316                Ok(result)
317            }
318            Err(e) => {
319                warn!(
320                    tool_name = %request.name,
321                    success = false,
322                    error = %e,
323                    "MCP call_tool request failed"
324                );
325                // Convert ToolCallError to ErrorData and return as error
326                Err(e.into())
327            }
328        }
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Builds a server around `spec` with the shared test base URL and no
    /// default headers or filters.
    fn server_with_spec(spec: Value) -> Server {
        Server::new(
            spec,
            Url::parse("http://example.com").unwrap(),
            None,
            None,
            None,
        )
    }

    /// Builds a spec-less server whose tool collection holds exactly `tools`.
    fn server_with_tools(tools: Vec<Tool>) -> Server {
        let mut server = server_with_spec(Value::Null);
        server.tool_collection = ToolCollection::from_tools(tools);
        server
    }

    /// Fixture: the `getPetById` tool.
    fn get_pet_by_id_tool() -> Tool {
        let metadata = ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: "Find pet by ID".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
            security: None,
        };
        Tool::new(metadata, None, None).unwrap()
    }

    /// Fixture: the `getPetsByStatus` tool.
    fn get_pets_by_status_tool() -> Tool {
        let metadata = ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: "Find pets by status".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
            security: None,
        };
        Tool::new(metadata, None, None).unwrap()
    }

    /// Serializes the "tool not found" error raised for `requested` against
    /// the tool names known to `server`.
    fn tool_not_found_json(server: &Server, requested: &str) -> Value {
        let known_names = server.get_tool_names();
        let known_refs: Vec<&str> = known_names.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &known_refs,
        ));
        let error_data: ErrorData = error.into();
        serde_json::to_value(&error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let server = server_with_tools(vec![get_pet_by_id_tool(), get_pets_by_status_tool()]);

        // A near-miss name (wrong casing of "Id") should yield suggestions.
        let error_json = tool_not_found_json(&server, "getPetByID");

        // The snapshot macro stays in the test body so the snapshot keeps
        // its name.
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let server = server_with_tools(vec![get_pet_by_id_tool()]);

        // A name unrelated to any loaded tool should yield no suggestions.
        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // -32602 is the JSON-RPC "invalid params" code.
        assert_eq!(error_json["code"], -32602);

        // The full payload, including the message format, is pinned by snapshot.
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "2.1.0",
                "description": "A sample API for managing pets"
            },
            "paths": {}
        }));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        // No `description` in `info`.
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "My API",
                "version": "1.0.0"
            },
            "paths": {}
        }));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        // No `info` object at all.
        let server = server_with_spec(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = server_with_spec(Value::Null);
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "3.0.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "1.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI Server");
        assert_eq!(result.server_info.version, "1.5.0");
        assert_eq!(
            result.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        // Neither custom metadata nor a usable spec: built-in defaults apply.
        let server = server_with_spec(Value::Null);

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI MCP Server");
        assert_eq!(result.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            result.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "2.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        // Name and instructions are overridden; the version is left to the spec.
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        // Custom name wins over the spec title.
        assert_eq!(result.server_info.name, "Custom Server");
        // Version falls back to the spec.
        assert_eq!(result.server_info.version, "2.5.0");
        // Custom instructions win over the spec description.
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }
}