rmcp_openapi/
server.rs

1use bon::Builder;
2use rmcp::{
3    handler::server::ServerHandler,
4    model::{
5        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
6        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
7        ToolsCapability,
8    },
9    service::{RequestContext, RoleServer},
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13
14use reqwest::header::HeaderMap;
15use url::Url;
16
17use crate::config::{Authorization, AuthorizationMode};
18use crate::error::Error;
19use crate::tool::{Tool, ToolCollection, ToolMetadata};
20use tracing::{debug, info, info_span, warn};
21
22#[derive(Clone, Builder)]
23pub struct Server {
24    pub openapi_spec: serde_json::Value,
25    #[builder(default)]
26    pub tool_collection: ToolCollection,
27    pub base_url: Url,
28    pub default_headers: Option<HeaderMap>,
29    pub tag_filter: Option<Vec<String>>,
30    pub method_filter: Option<Vec<reqwest::Method>>,
31    #[builder(default)]
32    pub authorization_mode: AuthorizationMode,
33    pub name: Option<String>,
34    pub version: Option<String>,
35    pub title: Option<String>,
36    pub instructions: Option<String>,
37    #[builder(default)]
38    pub skip_tool_descriptions: bool,
39    #[builder(default)]
40    pub skip_parameter_descriptions: bool,
41}
42
43impl Server {
44    /// Create a new Server instance with required parameters
45    pub fn new(
46        openapi_spec: serde_json::Value,
47        base_url: Url,
48        default_headers: Option<HeaderMap>,
49        tag_filter: Option<Vec<String>>,
50        method_filter: Option<Vec<reqwest::Method>>,
51        skip_tool_descriptions: bool,
52        skip_parameter_descriptions: bool,
53    ) -> Self {
54        Self {
55            openapi_spec,
56            tool_collection: ToolCollection::new(),
57            base_url,
58            default_headers,
59            tag_filter,
60            method_filter,
61            authorization_mode: AuthorizationMode::default(),
62            name: None,
63            version: None,
64            title: None,
65            instructions: None,
66            skip_tool_descriptions,
67            skip_parameter_descriptions,
68        }
69    }
70
71    /// Parse the `OpenAPI` specification and convert to OpenApiTool instances
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the spec cannot be parsed or tools cannot be generated
76    pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
77        let span = info_span!("tool_registration");
78        let _enter = span.enter();
79
80        // Parse the OpenAPI specification
81        let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
82
83        // Generate OpenApiTool instances directly
84        let tools = spec.to_openapi_tools(
85            self.tag_filter.as_deref(),
86            self.method_filter.as_deref(),
87            Some(self.base_url.clone()),
88            self.default_headers.clone(),
89            self.skip_tool_descriptions,
90            self.skip_parameter_descriptions,
91        )?;
92
93        self.tool_collection = ToolCollection::from_tools(tools);
94
95        info!(
96            tool_count = self.tool_collection.len(),
97            "Loaded tools from OpenAPI spec"
98        );
99
100        Ok(())
101    }
102
103    /// Get the number of loaded tools
104    #[must_use]
105    pub fn tool_count(&self) -> usize {
106        self.tool_collection.len()
107    }
108
109    /// Get all tool names
110    #[must_use]
111    pub fn get_tool_names(&self) -> Vec<String> {
112        self.tool_collection.get_tool_names()
113    }
114
115    /// Check if a specific tool exists
116    #[must_use]
117    pub fn has_tool(&self, name: &str) -> bool {
118        self.tool_collection.has_tool(name)
119    }
120
121    /// Get a tool by name
122    #[must_use]
123    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
124        self.tool_collection.get_tool(name)
125    }
126
127    /// Get tool metadata by name
128    #[must_use]
129    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
130        self.get_tool(name).map(|tool| &tool.metadata)
131    }
132
133    /// Set the authorization mode for the server
134    pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
135        self.authorization_mode = mode;
136    }
137
138    /// Get the current authorization mode
139    pub fn authorization_mode(&self) -> AuthorizationMode {
140        self.authorization_mode
141    }
142
143    /// Get basic tool statistics
144    #[must_use]
145    pub fn get_tool_stats(&self) -> String {
146        self.tool_collection.get_stats()
147    }
148
149    /// Simple validation - check that tools are loaded
150    ///
151    /// # Errors
152    ///
153    /// Returns an error if no tools are loaded
154    pub fn validate_registry(&self) -> Result<(), Error> {
155        if self.tool_collection.is_empty() {
156            return Err(Error::McpError("No tools loaded".to_string()));
157        }
158        Ok(())
159    }
160
161    /// Extract title from OpenAPI spec info section
162    fn extract_openapi_title(&self) -> Option<String> {
163        self.openapi_spec
164            .get("info")?
165            .get("title")?
166            .as_str()
167            .map(|s| s.to_string())
168    }
169
170    /// Extract version from OpenAPI spec info section
171    fn extract_openapi_version(&self) -> Option<String> {
172        self.openapi_spec
173            .get("info")?
174            .get("version")?
175            .as_str()
176            .map(|s| s.to_string())
177    }
178
179    /// Extract description from OpenAPI spec info section
180    fn extract_openapi_description(&self) -> Option<String> {
181        self.openapi_spec
182            .get("info")?
183            .get("description")?
184            .as_str()
185            .map(|s| s.to_string())
186    }
187
188    /// Extract display title from OpenAPI spec info section
189    /// First checks for x-display-title extension, then derives from title
190    fn extract_openapi_display_title(&self) -> Option<String> {
191        // First check for x-display-title extension
192        if let Some(display_title) = self
193            .openapi_spec
194            .get("info")
195            .and_then(|info| info.get("x-display-title"))
196            .and_then(|t| t.as_str())
197        {
198            return Some(display_title.to_string());
199        }
200
201        // Fallback: enhance the title with "Server" suffix if not already present
202        self.extract_openapi_title().map(|title| {
203            if title.to_lowercase().contains("server") {
204                title
205            } else {
206                format!("{} Server", title)
207            }
208        })
209    }
210}
211
212impl ServerHandler for Server {
213    fn get_info(&self) -> InitializeResult {
214        // 3-level fallback for server name: custom -> OpenAPI spec -> default
215        let server_name = self
216            .name
217            .clone()
218            .or_else(|| self.extract_openapi_title())
219            .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
220
221        // 3-level fallback for server version: custom -> OpenAPI spec -> crate version
222        let server_version = self
223            .version
224            .clone()
225            .or_else(|| self.extract_openapi_version())
226            .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
227
228        // 3-level fallback for title: custom -> OpenAPI-derived -> None
229        let server_title = self
230            .title
231            .clone()
232            .or_else(|| self.extract_openapi_display_title());
233
234        // 3-level fallback for instructions: custom -> OpenAPI spec -> default
235        let instructions = self
236            .instructions
237            .clone()
238            .or_else(|| self.extract_openapi_description())
239            .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
240
241        InitializeResult {
242            protocol_version: ProtocolVersion::V_2024_11_05,
243            server_info: Implementation {
244                name: server_name,
245                version: server_version,
246                title: server_title,
247                icons: None,
248                website_url: None,
249            },
250            capabilities: ServerCapabilities {
251                tools: Some(ToolsCapability {
252                    list_changed: Some(false),
253                }),
254                ..Default::default()
255            },
256            instructions,
257        }
258    }
259
260    async fn list_tools(
261        &self,
262        _request: Option<PaginatedRequestParam>,
263        _context: RequestContext<RoleServer>,
264    ) -> Result<ListToolsResult, ErrorData> {
265        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
266        let _enter = span.enter();
267
268        debug!("Processing MCP list_tools request");
269
270        // Delegate to tool collection for MCP tool conversion
271        let tools = self.tool_collection.to_mcp_tools();
272
273        info!(
274            returned_tools = tools.len(),
275            "MCP list_tools request completed successfully"
276        );
277
278        Ok(ListToolsResult {
279            tools,
280            next_cursor: None,
281        })
282    }
283
284    async fn call_tool(
285        &self,
286        request: CallToolRequestParam,
287        context: RequestContext<RoleServer>,
288    ) -> Result<CallToolResult, ErrorData> {
289        let span = info_span!(
290            "call_tool",
291            tool_name = %request.name
292        );
293        let _enter = span.enter();
294
295        debug!(
296            tool_name = %request.name,
297            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
298            "Processing MCP call_tool request"
299        );
300
301        let arguments = request.arguments.unwrap_or_default();
302        let arguments_value = Value::Object(arguments);
303
304        // Extract authorization header from context extensions
305        let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
306
307        if auth_header.is_some() {
308            debug!("Authorization header is present");
309        }
310
311        // Create Authorization enum from mode and header
312        let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
313
314        // Delegate all tool validation and execution to the tool collection
315        match self
316            .tool_collection
317            .call_tool(&request.name, &arguments_value, authorization)
318            .await
319        {
320            Ok(result) => {
321                info!(
322                    tool_name = %request.name,
323                    success = true,
324                    "MCP call_tool request completed successfully"
325                );
326                Ok(result)
327            }
328            Err(e) => {
329                warn!(
330                    tool_name = %request.name,
331                    success = false,
332                    error = %e,
333                    "MCP call_tool request failed"
334                );
335                // Convert ToolCallError to ErrorData and return as error
336                Err(e.into())
337            }
338        }
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Server over `spec` at `http://example.com` with no default headers, no
    /// filters and both description-skipping flags off.
    fn server_with_spec(spec: serde_json::Value) -> Server {
        Server::new(
            spec,
            url::Url::parse("http://example.com").unwrap(),
            None,
            None,
            None,
            false,
            false,
        )
    }

    /// Spec-less server whose collection holds one tool per metadata entry.
    fn server_with_tools(metadata: Vec<ToolMetadata>) -> Server {
        let tools: Vec<Tool> = metadata
            .into_iter()
            .map(|entry| Tool::new(entry, None, None).unwrap())
            .collect();

        let mut server = server_with_spec(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(tools);
        server
    }

    /// Metadata for the `getPetById` tool (`GET /pet/{petId}`).
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: Some("Find pet by ID".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// Metadata for the `getPetsByStatus` tool (`GET /pet/findByStatus`).
    fn get_pets_by_status_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: Some("Find pets by status".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// JSON form of the `ErrorData` built for a lookup of `requested` against
    /// the tool names registered on `server`.
    fn tool_not_found_json(server: &Server, requested: &str) -> serde_json::Value {
        let names = server.get_tool_names();
        let name_refs: Vec<&str> = names.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &name_refs,
        ));
        let error_data: ErrorData = error.into();
        serde_json::to_value(&error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let server = server_with_tools(vec![
            get_pet_by_id_metadata(),
            get_pets_by_status_metadata(),
        ]);

        // Requested name is a near-miss of a registered tool
        let error_json = tool_not_found_json(&server, "getPetByID");

        // Snapshot the error to verify suggestions
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let server = server_with_tools(vec![get_pet_by_id_metadata()]);

        // Requested name is unrelated to the registered tool
        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        // Snapshot the error to verify no suggestions
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        // A validation failure must surface as ErrorData
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // Invalid params error code
        assert_eq!(error_json["code"], -32602);

        // Snapshot the full error to verify the error message format
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "2.1.0",
                "description": "A sample API for managing pets"
            },
            "paths": {}
        }));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "My API",
                "version": "1.0.0"
            },
            "paths": {}
        }));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = server_with_spec(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = server_with_spec(serde_json::Value::Null);

        // Set custom metadata directly
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "3.0.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "1.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI Server");
        assert_eq!(result.server_info.version, "1.5.0");
        assert_eq!(
            result.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let server = server_with_spec(serde_json::Value::Null);

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI MCP Server");
        assert_eq!(result.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            result.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "2.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        // Override name and instructions; version is left to fall back to the spec
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        // Custom name takes precedence
        assert_eq!(result.server_info.name, "Custom Server");
        // OpenAPI version is used
        assert_eq!(result.server_info.version, "2.5.0");
        // Custom instructions take precedence
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }
}