rmcp_openapi/
server.rs

1use bon::Builder;
2use rmcp::{
3    handler::server::ServerHandler,
4    model::{
5        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
6        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
7        ToolsCapability,
8    },
9    service::{RequestContext, RoleServer},
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13
14use reqwest::header::HeaderMap;
15use url::Url;
16
17use crate::config::{Authorization, AuthorizationMode};
18use crate::error::Error;
19use crate::tool::{Tool, ToolCollection, ToolMetadata};
20use tracing::{debug, info, info_span, warn};
21
22#[derive(Clone, Builder)]
23pub struct Server {
24    pub openapi_spec: serde_json::Value,
25    #[builder(default)]
26    pub tool_collection: ToolCollection,
27    pub base_url: Url,
28    pub default_headers: Option<HeaderMap>,
29    pub tag_filter: Option<Vec<String>>,
30    pub method_filter: Option<Vec<reqwest::Method>>,
31    #[builder(default)]
32    pub authorization_mode: AuthorizationMode,
33    pub name: Option<String>,
34    pub version: Option<String>,
35    pub title: Option<String>,
36    pub instructions: Option<String>,
37    #[builder(default)]
38    pub skip_tool_descriptions: bool,
39    #[builder(default)]
40    pub skip_parameter_descriptions: bool,
41}
42
43impl Server {
44    /// Create a new Server instance with required parameters
45    pub fn new(
46        openapi_spec: serde_json::Value,
47        base_url: Url,
48        default_headers: Option<HeaderMap>,
49        tag_filter: Option<Vec<String>>,
50        method_filter: Option<Vec<reqwest::Method>>,
51        skip_tool_descriptions: bool,
52        skip_parameter_descriptions: bool,
53    ) -> Self {
54        Self {
55            openapi_spec,
56            tool_collection: ToolCollection::new(),
57            base_url,
58            default_headers,
59            tag_filter,
60            method_filter,
61            authorization_mode: AuthorizationMode::default(),
62            name: None,
63            version: None,
64            title: None,
65            instructions: None,
66            skip_tool_descriptions,
67            skip_parameter_descriptions,
68        }
69    }
70
71    /// Parse the `OpenAPI` specification and convert to OpenApiTool instances
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the spec cannot be parsed or tools cannot be generated
76    pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
77        let span = info_span!("tool_registration");
78        let _enter = span.enter();
79
80        // Parse the OpenAPI specification
81        let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
82
83        // Generate OpenApiTool instances directly
84        let tools = spec.to_openapi_tools(
85            self.tag_filter.as_deref(),
86            self.method_filter.as_deref(),
87            Some(self.base_url.clone()),
88            self.default_headers.clone(),
89            self.skip_tool_descriptions,
90            self.skip_parameter_descriptions,
91        )?;
92
93        self.tool_collection = ToolCollection::from_tools(tools);
94
95        info!(
96            tool_count = self.tool_collection.len(),
97            "Loaded tools from OpenAPI spec"
98        );
99
100        Ok(())
101    }
102
103    /// Get the number of loaded tools
104    #[must_use]
105    pub fn tool_count(&self) -> usize {
106        self.tool_collection.len()
107    }
108
109    /// Get all tool names
110    #[must_use]
111    pub fn get_tool_names(&self) -> Vec<String> {
112        self.tool_collection.get_tool_names()
113    }
114
115    /// Check if a specific tool exists
116    #[must_use]
117    pub fn has_tool(&self, name: &str) -> bool {
118        self.tool_collection.has_tool(name)
119    }
120
121    /// Get a tool by name
122    #[must_use]
123    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
124        self.tool_collection.get_tool(name)
125    }
126
127    /// Get tool metadata by name
128    #[must_use]
129    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
130        self.get_tool(name).map(|tool| &tool.metadata)
131    }
132
133    /// Set the authorization mode for the server
134    pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
135        self.authorization_mode = mode;
136    }
137
138    /// Get the current authorization mode
139    pub fn authorization_mode(&self) -> AuthorizationMode {
140        self.authorization_mode
141    }
142
143    /// Get basic tool statistics
144    #[must_use]
145    pub fn get_tool_stats(&self) -> String {
146        self.tool_collection.get_stats()
147    }
148
149    /// Simple validation - check that tools are loaded
150    ///
151    /// # Errors
152    ///
153    /// Returns an error if no tools are loaded
154    pub fn validate_registry(&self) -> Result<(), Error> {
155        if self.tool_collection.is_empty() {
156            return Err(Error::McpError("No tools loaded".to_string()));
157        }
158        Ok(())
159    }
160
161    /// Extract title from OpenAPI spec info section
162    fn extract_openapi_title(&self) -> Option<String> {
163        self.openapi_spec
164            .get("info")?
165            .get("title")?
166            .as_str()
167            .map(|s| s.to_string())
168    }
169
170    /// Extract version from OpenAPI spec info section
171    fn extract_openapi_version(&self) -> Option<String> {
172        self.openapi_spec
173            .get("info")?
174            .get("version")?
175            .as_str()
176            .map(|s| s.to_string())
177    }
178
179    /// Extract description from OpenAPI spec info section
180    fn extract_openapi_description(&self) -> Option<String> {
181        self.openapi_spec
182            .get("info")?
183            .get("description")?
184            .as_str()
185            .map(|s| s.to_string())
186    }
187
188    /// Extract display title from OpenAPI spec info section
189    /// First checks for x-display-title extension, then derives from title
190    fn extract_openapi_display_title(&self) -> Option<String> {
191        // First check for x-display-title extension
192        if let Some(display_title) = self
193            .openapi_spec
194            .get("info")
195            .and_then(|info| info.get("x-display-title"))
196            .and_then(|t| t.as_str())
197        {
198            return Some(display_title.to_string());
199        }
200
201        // Fallback: enhance the title with "Server" suffix if not already present
202        self.extract_openapi_title().map(|title| {
203            if title.to_lowercase().contains("server") {
204                title
205            } else {
206                format!("{} Server", title)
207            }
208        })
209    }
210}
211
212impl ServerHandler for Server {
213    fn get_info(&self) -> InitializeResult {
214        // 3-level fallback for server name: custom -> OpenAPI spec -> default
215        let server_name = self
216            .name
217            .clone()
218            .or_else(|| self.extract_openapi_title())
219            .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
220
221        // 3-level fallback for server version: custom -> OpenAPI spec -> crate version
222        let server_version = self
223            .version
224            .clone()
225            .or_else(|| self.extract_openapi_version())
226            .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
227
228        // 3-level fallback for title: custom -> OpenAPI-derived -> None
229        let server_title = self
230            .title
231            .clone()
232            .or_else(|| self.extract_openapi_display_title());
233
234        // 3-level fallback for instructions: custom -> OpenAPI spec -> default
235        let instructions = self
236            .instructions
237            .clone()
238            .or_else(|| self.extract_openapi_description())
239            .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
240
241        InitializeResult {
242            protocol_version: ProtocolVersion::V_2024_11_05,
243            server_info: Implementation {
244                name: server_name,
245                version: server_version,
246                title: server_title,
247                icons: None,
248                website_url: None,
249            },
250            capabilities: ServerCapabilities {
251                tools: Some(ToolsCapability {
252                    list_changed: Some(false),
253                }),
254                ..Default::default()
255            },
256            instructions,
257        }
258    }
259
260    async fn list_tools(
261        &self,
262        _request: Option<PaginatedRequestParam>,
263        _context: RequestContext<RoleServer>,
264    ) -> Result<ListToolsResult, ErrorData> {
265        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
266        let _enter = span.enter();
267
268        debug!("Processing MCP list_tools request");
269
270        // Delegate to tool collection for MCP tool conversion
271        let tools = self.tool_collection.to_mcp_tools();
272
273        info!(
274            returned_tools = tools.len(),
275            "MCP list_tools request completed successfully"
276        );
277
278        Ok(ListToolsResult {
279            tools,
280            next_cursor: None,
281        })
282    }
283
284    async fn call_tool(
285        &self,
286        request: CallToolRequestParam,
287        context: RequestContext<RoleServer>,
288    ) -> Result<CallToolResult, ErrorData> {
289        let span = info_span!(
290            "call_tool",
291            tool_name = %request.name
292        );
293        let _enter = span.enter();
294
295        debug!(
296            tool_name = %request.name,
297            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
298            "Processing MCP call_tool request"
299        );
300
301        let arguments = request.arguments.unwrap_or_default();
302        let arguments_value = Value::Object(arguments);
303
304        // Extract authorization header from context extensions
305        let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
306
307        if auth_header.is_some() {
308            debug!("Authorization header is present");
309        }
310
311        // Create Authorization enum from mode and header
312        let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
313
314        // Delegate all tool validation and execution to the tool collection
315        match self
316            .tool_collection
317            .call_tool(&request.name, &arguments_value, authorization)
318            .await
319        {
320            Ok(result) => {
321                info!(
322                    tool_name = %request.name,
323                    success = true,
324                    "MCP call_tool request completed successfully"
325                );
326                Ok(result)
327            }
328            Err(e) => {
329                warn!(
330                    tool_name = %request.name,
331                    success = false,
332                    error = %e,
333                    "MCP call_tool request failed"
334                );
335                // Convert ToolCallError to ErrorData and return as error
336                Err(e.into())
337            }
338        }
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Server fixture with a dummy base URL, no default headers, no tag or
    /// method filters, and descriptions kept. Only the OpenAPI document varies.
    fn server_with_spec(openapi_spec: serde_json::Value) -> Server {
        Server::new(
            openapi_spec,
            url::Url::parse("http://example.com").unwrap(),
            None,
            None,
            None,
            false,
            false,
        )
    }

    /// Build a `GET` tool with no output schema and no security requirements.
    fn get_tool_fixture(
        name: &str,
        title: &str,
        description: &str,
        path: &str,
        parameters: serde_json::Value,
    ) -> Tool {
        let metadata = ToolMetadata {
            name: name.to_string(),
            title: Some(title.to_string()),
            description: Some(description.to_string()),
            parameters,
            output_schema: None,
            method: "GET".to_string(),
            path: path.to_string(),
            security: None,
        };
        Tool::new(metadata, None, None).unwrap()
    }

    /// The `getPetById` tool shared by both tool-not-found tests.
    fn get_pet_by_id_tool() -> Tool {
        get_tool_fixture(
            "getPetById",
            "Get Pet by ID",
            "Find pet by ID",
            "/pet/{petId}",
            json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
        )
    }

    /// Serialize the tool-not-found error for `requested`, checked against the
    /// names of the tools currently loaded on `server`.
    fn tool_not_found_json(server: &Server, requested: &str) -> serde_json::Value {
        let known_names = server.get_tool_names();
        let known_refs: Vec<&str> = known_names.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &known_refs,
        ));
        let error_data: ErrorData = error.into();
        serde_json::to_value(&error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let find_by_status = get_tool_fixture(
            "getPetsByStatus",
            "Find Pets by Status",
            "Find pets by status",
            "/pet/findByStatus",
            json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
        );

        let mut server = server_with_spec(serde_json::Value::Null);
        server.tool_collection =
            ToolCollection::from_tools(vec![get_pet_by_id_tool(), find_by_status]);

        // A near-miss of "getPetById" (only the case of the final letter differs).
        let error_json = tool_not_found_json(&server, "getPetByID");

        // The snapshot pins the suggested tool names.
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let mut server = server_with_spec(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(vec![get_pet_by_id_tool()]);

        // A name that resembles nothing that is loaded.
        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        // The snapshot pins the absence of suggestions.
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // JSON-RPC "invalid params" code
        assert_eq!(error_json["code"], -32602);

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "2.1.0",
                "description": "A sample API for managing pets"
            },
            "paths": {}
        }));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "My API",
                "version": "1.0.0"
            },
            "paths": {}
        }));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = server_with_spec(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = server_with_spec(serde_json::Value::Null);
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let info = server.get_info();

        assert_eq!(info.server_info.name, "Custom Server");
        assert_eq!(info.server_info.version, "3.0.0");
        assert_eq!(info.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "1.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        let info = server.get_info();

        assert_eq!(info.server_info.name, "OpenAPI Server");
        assert_eq!(info.server_info.version, "1.5.0");
        assert_eq!(
            info.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let info = server_with_spec(serde_json::Value::Null).get_info();

        assert_eq!(info.server_info.name, "OpenAPI MCP Server");
        assert_eq!(info.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            info.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "2.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        // Override name and instructions only; version must come from the spec.
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let info = server.get_info();

        assert_eq!(info.server_info.name, "Custom Server");
        assert_eq!(info.server_info.version, "2.5.0");
        assert_eq!(info.instructions, Some("Custom instructions".to_string()));
    }
}