rmcp_openapi/
server.rs

1use bon::Builder;
2use rmcp::{
3    handler::server::ServerHandler,
4    model::{
5        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
6        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
7        ToolsCapability,
8    },
9    service::{RequestContext, RoleServer},
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13
14use reqwest::header::HeaderMap;
15use url::Url;
16
17use crate::error::Error;
18use crate::tool::{Tool, ToolCollection, ToolMetadata};
19use crate::{
20    config::{Authorization, AuthorizationMode},
21    spec::Filters,
22};
23use tracing::{debug, info, info_span, warn};
24
25#[derive(Clone, Builder)]
26pub struct Server {
27    pub openapi_spec: serde_json::Value,
28    #[builder(default)]
29    pub tool_collection: ToolCollection,
30    pub base_url: Url,
31    pub default_headers: Option<HeaderMap>,
32    pub filters: Option<Filters>,
33    #[builder(default)]
34    pub authorization_mode: AuthorizationMode,
35    pub name: Option<String>,
36    pub version: Option<String>,
37    pub title: Option<String>,
38    pub instructions: Option<String>,
39    #[builder(default)]
40    pub skip_tool_descriptions: bool,
41    #[builder(default)]
42    pub skip_parameter_descriptions: bool,
43}
44
45impl Server {
46    /// Create a new Server instance with required parameters
47    pub fn new(
48        openapi_spec: serde_json::Value,
49        base_url: Url,
50        default_headers: Option<HeaderMap>,
51        filters: Option<Filters>,
52        skip_tool_descriptions: bool,
53        skip_parameter_descriptions: bool,
54    ) -> Self {
55        Self {
56            openapi_spec,
57            tool_collection: ToolCollection::new(),
58            base_url,
59            default_headers,
60            filters,
61            authorization_mode: AuthorizationMode::default(),
62            name: None,
63            version: None,
64            title: None,
65            instructions: None,
66            skip_tool_descriptions,
67            skip_parameter_descriptions,
68        }
69    }
70
71    /// Parse the `OpenAPI` specification and convert to OpenApiTool instances
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the spec cannot be parsed or tools cannot be generated
76    pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
77        let span = info_span!("tool_registration");
78        let _enter = span.enter();
79
80        // Parse the OpenAPI specification
81        let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
82
83        // Generate OpenApiTool instances directly
84        let tools = spec.to_openapi_tools(
85            self.filters.as_ref(),
86            Some(self.base_url.clone()),
87            self.default_headers.clone(),
88            self.skip_tool_descriptions,
89            self.skip_parameter_descriptions,
90        )?;
91
92        self.tool_collection = ToolCollection::from_tools(tools);
93
94        info!(
95            tool_count = self.tool_collection.len(),
96            "Loaded tools from OpenAPI spec"
97        );
98
99        Ok(())
100    }
101
102    /// Get the number of loaded tools
103    #[must_use]
104    pub fn tool_count(&self) -> usize {
105        self.tool_collection.len()
106    }
107
108    /// Get all tool names
109    #[must_use]
110    pub fn get_tool_names(&self) -> Vec<String> {
111        self.tool_collection.get_tool_names()
112    }
113
114    /// Check if a specific tool exists
115    #[must_use]
116    pub fn has_tool(&self, name: &str) -> bool {
117        self.tool_collection.has_tool(name)
118    }
119
120    /// Get a tool by name
121    #[must_use]
122    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
123        self.tool_collection.get_tool(name)
124    }
125
126    /// Get tool metadata by name
127    #[must_use]
128    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
129        self.get_tool(name).map(|tool| &tool.metadata)
130    }
131
132    /// Set the authorization mode for the server
133    pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
134        self.authorization_mode = mode;
135    }
136
137    /// Get the current authorization mode
138    pub fn authorization_mode(&self) -> AuthorizationMode {
139        self.authorization_mode
140    }
141
142    /// Get basic tool statistics
143    #[must_use]
144    pub fn get_tool_stats(&self) -> String {
145        self.tool_collection.get_stats()
146    }
147
148    /// Simple validation - check that tools are loaded
149    ///
150    /// # Errors
151    ///
152    /// Returns an error if no tools are loaded
153    pub fn validate_registry(&self) -> Result<(), Error> {
154        if self.tool_collection.is_empty() {
155            return Err(Error::McpError("No tools loaded".to_string()));
156        }
157        Ok(())
158    }
159
160    /// Extract title from OpenAPI spec info section
161    fn extract_openapi_title(&self) -> Option<String> {
162        self.openapi_spec
163            .get("info")?
164            .get("title")?
165            .as_str()
166            .map(|s| s.to_string())
167    }
168
169    /// Extract version from OpenAPI spec info section
170    fn extract_openapi_version(&self) -> Option<String> {
171        self.openapi_spec
172            .get("info")?
173            .get("version")?
174            .as_str()
175            .map(|s| s.to_string())
176    }
177
178    /// Extract description from OpenAPI spec info section
179    fn extract_openapi_description(&self) -> Option<String> {
180        self.openapi_spec
181            .get("info")?
182            .get("description")?
183            .as_str()
184            .map(|s| s.to_string())
185    }
186
187    /// Extract display title from OpenAPI spec info section
188    /// First checks for x-display-title extension, then derives from title
189    fn extract_openapi_display_title(&self) -> Option<String> {
190        // First check for x-display-title extension
191        if let Some(display_title) = self
192            .openapi_spec
193            .get("info")
194            .and_then(|info| info.get("x-display-title"))
195            .and_then(|t| t.as_str())
196        {
197            return Some(display_title.to_string());
198        }
199
200        // Fallback: enhance the title with "Server" suffix if not already present
201        self.extract_openapi_title().map(|title| {
202            if title.to_lowercase().contains("server") {
203                title
204            } else {
205                format!("{} Server", title)
206            }
207        })
208    }
209}
210
211impl ServerHandler for Server {
212    fn get_info(&self) -> InitializeResult {
213        // 3-level fallback for server name: custom -> OpenAPI spec -> default
214        let server_name = self
215            .name
216            .clone()
217            .or_else(|| self.extract_openapi_title())
218            .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
219
220        // 3-level fallback for server version: custom -> OpenAPI spec -> crate version
221        let server_version = self
222            .version
223            .clone()
224            .or_else(|| self.extract_openapi_version())
225            .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
226
227        // 3-level fallback for title: custom -> OpenAPI-derived -> None
228        let server_title = self
229            .title
230            .clone()
231            .or_else(|| self.extract_openapi_display_title());
232
233        // 3-level fallback for instructions: custom -> OpenAPI spec -> default
234        let instructions = self
235            .instructions
236            .clone()
237            .or_else(|| self.extract_openapi_description())
238            .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
239
240        InitializeResult {
241            protocol_version: ProtocolVersion::V_2024_11_05,
242            server_info: Implementation {
243                name: server_name,
244                version: server_version,
245                title: server_title,
246                icons: None,
247                website_url: None,
248            },
249            capabilities: ServerCapabilities {
250                tools: Some(ToolsCapability {
251                    list_changed: Some(false),
252                }),
253                ..Default::default()
254            },
255            instructions,
256        }
257    }
258
259    async fn list_tools(
260        &self,
261        _request: Option<PaginatedRequestParam>,
262        _context: RequestContext<RoleServer>,
263    ) -> Result<ListToolsResult, ErrorData> {
264        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
265        let _enter = span.enter();
266
267        debug!("Processing MCP list_tools request");
268
269        // Delegate to tool collection for MCP tool conversion
270        let tools = self.tool_collection.to_mcp_tools();
271
272        info!(
273            returned_tools = tools.len(),
274            "MCP list_tools request completed successfully"
275        );
276
277        Ok(ListToolsResult {
278            tools,
279            next_cursor: None,
280        })
281    }
282
283    async fn call_tool(
284        &self,
285        request: CallToolRequestParam,
286        context: RequestContext<RoleServer>,
287    ) -> Result<CallToolResult, ErrorData> {
288        let span = info_span!(
289            "call_tool",
290            tool_name = %request.name
291        );
292        let _enter = span.enter();
293
294        debug!(
295            tool_name = %request.name,
296            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
297            "Processing MCP call_tool request"
298        );
299
300        let arguments = request.arguments.unwrap_or_default();
301        let arguments_value = Value::Object(arguments);
302
303        // Extract authorization header from context extensions
304        let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
305
306        if auth_header.is_some() {
307            debug!("Authorization header is present");
308        }
309
310        // Create Authorization enum from mode and header
311        let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
312
313        // Delegate all tool validation and execution to the tool collection
314        match self
315            .tool_collection
316            .call_tool(&request.name, &arguments_value, authorization)
317            .await
318        {
319            Ok(result) => {
320                info!(
321                    tool_name = %request.name,
322                    success = true,
323                    "MCP call_tool request completed successfully"
324                );
325                Ok(result)
326            }
327            Err(e) => {
328                warn!(
329                    tool_name = %request.name,
330                    success = false,
331                    error = %e,
332                    "MCP call_tool request failed"
333                );
334                // Convert ToolCallError to ErrorData and return as error
335                Err(e.into())
336            }
337        }
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Build a `Server` for `openapi_spec` pointed at a dummy base URL, with no
    /// default headers, no filters and descriptions kept.
    fn test_server(openapi_spec: serde_json::Value) -> Server {
        Server::new(
            openapi_spec,
            url::Url::parse("http://example.com").unwrap(),
            None,
            None,
            false,
            false,
        )
    }

    /// Build metadata for a `GET` endpoint with no output schema, security
    /// requirements or parameter mappings.
    fn get_endpoint_metadata(
        name: &str,
        title: &str,
        description: &str,
        path: &str,
        parameters: serde_json::Value,
    ) -> ToolMetadata {
        ToolMetadata {
            name: name.to_string(),
            title: Some(title.to_string()),
            description: Some(description.to_string()),
            parameters,
            output_schema: None,
            method: "GET".to_string(),
            path: path.to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// `getPetById`: `GET /pet/{petId}` with a required integer `petId`.
    fn get_pet_by_id_tool() -> Tool {
        let metadata = get_endpoint_metadata(
            "getPetById",
            "Get Pet by ID",
            "Find pet by ID",
            "/pet/{petId}",
            json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
        );
        Tool::new(metadata, None, None).unwrap()
    }

    /// `getPetsByStatus`: `GET /pet/findByStatus` with a required string-array `status`.
    fn get_pets_by_status_tool() -> Tool {
        let metadata = get_endpoint_metadata(
            "getPetsByStatus",
            "Find Pets by Status",
            "Find pets by status",
            "/pet/findByStatus",
            json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
        );
        Tool::new(metadata, None, None).unwrap()
    }

    /// Minimal spec whose `info` section has the given title and version.
    fn spec_with_title(title: &str, version: &str) -> serde_json::Value {
        json!({
            "openapi": "3.0.0",
            "info": {
                "title": title,
                "version": version
            },
            "paths": {}
        })
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let mut server = test_server(serde_json::Value::Null);
        server.tool_collection =
            ToolCollection::from_tools(vec![get_pet_by_id_tool(), get_pets_by_status_tool()]);

        // A near-miss of an existing tool name; the snapshot records the suggestions.
        let tool_names = server.get_tool_names();
        let tool_name_refs: Vec<&str> = tool_names.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            "getPetByID".to_string(),
            &tool_name_refs,
        ));
        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let mut server = test_server(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(vec![get_pet_by_id_tool()]);

        // A name unrelated to any tool; the snapshot records the absence of suggestions.
        let tool_names = server.get_tool_names();
        let tool_name_refs: Vec<&str> = tool_names.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            "completelyUnrelatedToolName".to_string(),
            &tool_name_refs,
        ));
        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![crate::error::ValidationError::invalid_parameter(
                "page".to_string(),
                &["page_number".to_string(), "page_size".to_string()],
            )],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // JSON-RPC "invalid params" error code
        assert_eq!(error_json["code"], -32602);

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_lookup_helpers() {
        let mut server = test_server(serde_json::Value::Null);
        assert_eq!(server.tool_count(), 0);
        assert!(server.validate_registry().is_err());

        server.tool_collection = ToolCollection::from_tools(vec![get_pet_by_id_tool()]);

        assert_eq!(server.tool_count(), 1);
        assert!(server.validate_registry().is_ok());
        assert!(server.has_tool("getPetById"));
        assert!(!server.has_tool("getPetsByStatus"));
        assert_eq!(server.get_tool_names(), vec!["getPetById".to_string()]);
        assert_eq!(
            server
                .get_tool_metadata("getPetById")
                .map(|metadata| metadata.path.as_str()),
            Some("/pet/{petId}")
        );
        assert!(server.get_tool_metadata("missing").is_none());
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = test_server(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "2.1.0",
                "description": "A sample API for managing pets"
            },
            "paths": {}
        }));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = test_server(spec_with_title("My API", "1.0.0"));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = test_server(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = test_server(serde_json::Value::Null);
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "3.0.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = test_server(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "1.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI Server");
        assert_eq!(result.server_info.version, "1.5.0");
        assert_eq!(
            result.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let result = test_server(serde_json::Value::Null).get_info();

        assert_eq!(result.server_info.name, "OpenAPI MCP Server");
        assert_eq!(result.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            result.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = test_server(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "2.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        // Custom name and instructions; version is left to fall back to the spec.
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "2.5.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_title_prefers_display_title_extension() {
        let server = test_server(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "1.0.0",
                "x-display-title": "Friendly Pets"
            },
            "paths": {}
        }));

        assert_eq!(
            server.get_info().server_info.title,
            Some("Friendly Pets".to_string())
        );
    }

    #[test]
    fn test_get_info_title_derived_from_openapi_title() {
        // A "Server" suffix is appended when the title does not mention it...
        let suffixed = test_server(spec_with_title("Pet Store API", "1.0.0"));
        assert_eq!(
            suffixed.get_info().server_info.title,
            Some("Pet Store API Server".to_string())
        );

        // ...and titles that already mention "server" are used verbatim.
        let verbatim = test_server(spec_with_title("OpenAPI Server", "1.0.0"));
        assert_eq!(
            verbatim.get_info().server_info.title,
            Some("OpenAPI Server".to_string())
        );
    }

    #[test]
    fn test_get_info_title_custom_and_absent() {
        let mut server = test_server(spec_with_title("Pet Store API", "1.0.0"));
        server.title = Some("Custom Title".to_string());
        assert_eq!(
            server.get_info().server_info.title,
            Some("Custom Title".to_string())
        );

        let bare = test_server(serde_json::Value::Null);
        assert_eq!(bare.get_info().server_info.title, None);
    }
}