rmcp_openapi/
server.rs

1use bon::Builder;
2use rmcp::{
3    handler::server::ServerHandler,
4    model::{
5        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
6        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
7        ToolsCapability,
8    },
9    service::{RequestContext, RoleServer},
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13use std::sync::Arc;
14
15use reqwest::header::HeaderMap;
16use url::Url;
17
18use crate::error::Error;
19use crate::tool::{Tool, ToolCollection, ToolMetadata};
20use crate::transformer::ResponseTransformer;
21use crate::{
22    config::{Authorization, AuthorizationMode},
23    spec::Filters,
24};
use tracing::{Instrument, debug, info, info_span, warn};
26
27#[derive(Clone, Builder)]
28pub struct Server {
29    pub openapi_spec: serde_json::Value,
30    #[builder(default)]
31    pub tool_collection: ToolCollection,
32    pub base_url: Url,
33    pub default_headers: Option<HeaderMap>,
34    pub filters: Option<Filters>,
35    #[builder(default)]
36    pub authorization_mode: AuthorizationMode,
37    pub name: Option<String>,
38    pub version: Option<String>,
39    pub title: Option<String>,
40    pub instructions: Option<String>,
41    #[builder(default)]
42    pub skip_tool_descriptions: bool,
43    #[builder(default)]
44    pub skip_parameter_descriptions: bool,
45    /// Global response transformer applied to all tools.
46    ///
47    /// Uses dynamic dispatch (`Arc<dyn>`) because:
48    /// - Transformer runs once per HTTP call (10-1000ms latency)
49    /// - Vtable lookup overhead (~1ns) is unmeasurable
50    /// - Avoids viral generics throughout Server, Tool, ToolCollection
51    pub response_transformer: Option<Arc<dyn ResponseTransformer>>,
52}
53
54impl Server {
55    /// Create a new Server instance with required parameters
56    pub fn new(
57        openapi_spec: serde_json::Value,
58        base_url: Url,
59        default_headers: Option<HeaderMap>,
60        filters: Option<Filters>,
61        skip_tool_descriptions: bool,
62        skip_parameter_descriptions: bool,
63    ) -> Self {
64        Self {
65            openapi_spec,
66            tool_collection: ToolCollection::new(),
67            base_url,
68            default_headers,
69            filters,
70            authorization_mode: AuthorizationMode::default(),
71            name: None,
72            version: None,
73            title: None,
74            instructions: None,
75            skip_tool_descriptions,
76            skip_parameter_descriptions,
77            response_transformer: None,
78        }
79    }
80
81    /// Parse the `OpenAPI` specification and convert to OpenApiTool instances
82    ///
83    /// # Errors
84    ///
85    /// Returns an error if the spec cannot be parsed or tools cannot be generated
86    pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
87        let span = info_span!("tool_registration");
88        let _enter = span.enter();
89
90        // Parse the OpenAPI specification
91        let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
92
93        // Generate OpenApiTool instances directly
94        let tools = spec.to_openapi_tools(
95            self.filters.as_ref(),
96            Some(self.base_url.clone()),
97            self.default_headers.clone(),
98            self.skip_tool_descriptions,
99            self.skip_parameter_descriptions,
100        )?;
101
102        // Apply global transformer to schemas if present
103        let tools = if let Some(ref transformer) = self.response_transformer {
104            tools
105                .into_iter()
106                .map(|mut tool| {
107                    if let Some(schema) = tool.metadata.output_schema.take() {
108                        tool.metadata.output_schema = Some(transformer.transform_schema(schema));
109                    }
110                    tool
111                })
112                .collect()
113        } else {
114            tools
115        };
116
117        self.tool_collection = ToolCollection::from_tools(tools);
118
119        info!(
120            tool_count = self.tool_collection.len(),
121            "Loaded tools from OpenAPI spec"
122        );
123
124        Ok(())
125    }
126
127    /// Set a response transformer for a specific tool, overriding the global one.
128    ///
129    /// The transformer's `transform_schema` method is immediately applied to the tool's
130    /// output schema. The `transform_response` method will be applied to responses
131    /// when the tool is called.
132    ///
133    /// # Errors
134    ///
135    /// Returns an error if the tool is not found
136    pub fn set_tool_transformer(
137        &mut self,
138        tool_name: &str,
139        transformer: Arc<dyn ResponseTransformer>,
140    ) -> Result<(), Error> {
141        self.tool_collection
142            .set_tool_transformer(tool_name, transformer)
143    }
144
145    /// Get the number of loaded tools
146    #[must_use]
147    pub fn tool_count(&self) -> usize {
148        self.tool_collection.len()
149    }
150
151    /// Get all tool names
152    #[must_use]
153    pub fn get_tool_names(&self) -> Vec<String> {
154        self.tool_collection.get_tool_names()
155    }
156
157    /// Check if a specific tool exists
158    #[must_use]
159    pub fn has_tool(&self, name: &str) -> bool {
160        self.tool_collection.has_tool(name)
161    }
162
163    /// Get a tool by name
164    #[must_use]
165    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
166        self.tool_collection.get_tool(name)
167    }
168
169    /// Get tool metadata by name
170    #[must_use]
171    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
172        self.get_tool(name).map(|tool| &tool.metadata)
173    }
174
175    /// Set the authorization mode for the server
176    pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
177        self.authorization_mode = mode;
178    }
179
180    /// Get the current authorization mode
181    pub fn authorization_mode(&self) -> AuthorizationMode {
182        self.authorization_mode
183    }
184
185    /// Get basic tool statistics
186    #[must_use]
187    pub fn get_tool_stats(&self) -> String {
188        self.tool_collection.get_stats()
189    }
190
191    /// Simple validation - check that tools are loaded
192    ///
193    /// # Errors
194    ///
195    /// Returns an error if no tools are loaded
196    pub fn validate_registry(&self) -> Result<(), Error> {
197        if self.tool_collection.is_empty() {
198            return Err(Error::McpError("No tools loaded".to_string()));
199        }
200        Ok(())
201    }
202
203    /// Extract title from OpenAPI spec info section
204    fn extract_openapi_title(&self) -> Option<String> {
205        self.openapi_spec
206            .get("info")?
207            .get("title")?
208            .as_str()
209            .map(|s| s.to_string())
210    }
211
212    /// Extract version from OpenAPI spec info section
213    fn extract_openapi_version(&self) -> Option<String> {
214        self.openapi_spec
215            .get("info")?
216            .get("version")?
217            .as_str()
218            .map(|s| s.to_string())
219    }
220
221    /// Extract description from OpenAPI spec info section
222    fn extract_openapi_description(&self) -> Option<String> {
223        self.openapi_spec
224            .get("info")?
225            .get("description")?
226            .as_str()
227            .map(|s| s.to_string())
228    }
229
230    /// Extract display title from OpenAPI spec info section
231    /// First checks for x-display-title extension, then derives from title
232    fn extract_openapi_display_title(&self) -> Option<String> {
233        // First check for x-display-title extension
234        if let Some(display_title) = self
235            .openapi_spec
236            .get("info")
237            .and_then(|info| info.get("x-display-title"))
238            .and_then(|t| t.as_str())
239        {
240            return Some(display_title.to_string());
241        }
242
243        // Fallback: enhance the title with "Server" suffix if not already present
244        self.extract_openapi_title().map(|title| {
245            if title.to_lowercase().contains("server") {
246                title
247            } else {
248                format!("{} Server", title)
249            }
250        })
251    }
252}
253
254impl ServerHandler for Server {
255    fn get_info(&self) -> InitializeResult {
256        // 3-level fallback for server name: custom -> OpenAPI spec -> default
257        let server_name = self
258            .name
259            .clone()
260            .or_else(|| self.extract_openapi_title())
261            .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
262
263        // 3-level fallback for server version: custom -> OpenAPI spec -> crate version
264        let server_version = self
265            .version
266            .clone()
267            .or_else(|| self.extract_openapi_version())
268            .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
269
270        // 3-level fallback for title: custom -> OpenAPI-derived -> None
271        let server_title = self
272            .title
273            .clone()
274            .or_else(|| self.extract_openapi_display_title());
275
276        // 3-level fallback for instructions: custom -> OpenAPI spec -> default
277        let instructions = self
278            .instructions
279            .clone()
280            .or_else(|| self.extract_openapi_description())
281            .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
282
283        InitializeResult {
284            protocol_version: ProtocolVersion::V_2024_11_05,
285            server_info: Implementation {
286                name: server_name,
287                version: server_version,
288                title: server_title,
289                icons: None,
290                website_url: None,
291            },
292            capabilities: ServerCapabilities {
293                tools: Some(ToolsCapability {
294                    list_changed: Some(false),
295                }),
296                ..Default::default()
297            },
298            instructions,
299        }
300    }
301
302    async fn list_tools(
303        &self,
304        _request: Option<PaginatedRequestParam>,
305        _context: RequestContext<RoleServer>,
306    ) -> Result<ListToolsResult, ErrorData> {
307        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
308        let _enter = span.enter();
309
310        debug!("Processing MCP list_tools request");
311
312        // Delegate to tool collection for MCP tool conversion
313        let tools = self.tool_collection.to_mcp_tools();
314
315        info!(
316            returned_tools = tools.len(),
317            "MCP list_tools request completed successfully"
318        );
319
320        Ok(ListToolsResult {
321            meta: None,
322            tools,
323            next_cursor: None,
324        })
325    }
326
327    async fn call_tool(
328        &self,
329        request: CallToolRequestParam,
330        context: RequestContext<RoleServer>,
331    ) -> Result<CallToolResult, ErrorData> {
332        let span = info_span!(
333            "call_tool",
334            tool_name = %request.name
335        );
336        let _enter = span.enter();
337
338        debug!(
339            tool_name = %request.name,
340            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
341            "Processing MCP call_tool request"
342        );
343
344        let arguments = request.arguments.unwrap_or_default();
345        let arguments_value = Value::Object(arguments);
346
347        // Extract authorization header from context extensions
348        let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
349
350        if auth_header.is_some() {
351            debug!("Authorization header is present");
352        }
353
354        // Create Authorization enum from mode and header
355        let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
356
357        // Get the server-level transformer as a reference for the tool call
358        let server_transformer = self
359            .response_transformer
360            .as_ref()
361            .map(|t| t.as_ref() as &dyn ResponseTransformer);
362
363        // Delegate all tool validation and execution to the tool collection
364        match self
365            .tool_collection
366            .call_tool(
367                &request.name,
368                &arguments_value,
369                authorization,
370                server_transformer,
371            )
372            .await
373        {
374            Ok(result) => {
375                info!(
376                    tool_name = %request.name,
377                    success = true,
378                    "MCP call_tool request completed successfully"
379                );
380                Ok(result)
381            }
382            Err(e) => {
383                warn!(
384                    tool_name = %request.name,
385                    success = false,
386                    error = %e,
387                    "MCP call_tool request failed"
388                );
389                // Convert ToolCallError to ErrorData and return as error
390                Err(e.into())
391            }
392        }
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{HttpClient, ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Build a `Server` over `openapi_spec` with base URL `http://example.com`, no
    /// default headers, no filters, and descriptions enabled.
    fn test_server(openapi_spec: Value) -> Server {
        Server::new(
            openapi_spec,
            url::Url::parse("http://example.com").unwrap(),
            None,
            None,
            false,
            false,
        )
    }

    /// Wrap an `info` object in a minimal OpenAPI 3.0.0 document with no paths.
    fn spec_with_info(info: Value) -> Value {
        json!({
            "openapi": "3.0.0",
            "info": info,
            "paths": {}
        })
    }

    /// Metadata for a `GET /pet/{petId}` tool named `getPetById`.
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: Some("Find pet by ID".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// Metadata for a `GET /pet/findByStatus` tool named `getPetsByStatus`.
    fn get_pets_by_status_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: Some("Find pets by status".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// Serialize the `ToolNotFound` error produced when `requested` is looked up
    /// against the tools currently registered on `server`.
    fn tool_not_found_json(server: &Server, requested: &str) -> Value {
        let tool_names = server.get_tool_names();
        let tool_name_refs: Vec<&str> = tool_names.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &tool_name_refs,
        ));
        let error_data: ErrorData = error.into();
        serde_json::to_value(&error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let http_client = HttpClient::new();
        let tool1 = Tool::new(get_pet_by_id_metadata(), http_client.clone()).unwrap();
        let tool2 = Tool::new(get_pets_by_status_metadata(), http_client).unwrap();

        let mut server = test_server(Value::Null);
        server.tool_collection = ToolCollection::from_tools(vec![tool1, tool2]);

        // A casing typo of an existing tool name should produce suggestions.
        let error_json = tool_not_found_json(&server, "getPetByID");

        // Snapshot the error to verify suggestions
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let tool = Tool::new(get_pet_by_id_metadata(), HttpClient::new()).unwrap();

        let mut server = test_server(Value::Null);
        server.tool_collection = ToolCollection::from_tools(vec![tool]);

        // An unrelated name should produce no suggestions.
        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        // Snapshot the error to verify no suggestions
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        // Test that validation errors are properly converted to ErrorData
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![crate::error::ValidationError::invalid_parameter(
                "page".to_string(),
                &["page_number".to_string(), "page_size".to_string()],
            )],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // Verify the basic structure
        assert_eq!(error_json["code"], -32602); // Invalid params error code

        // Snapshot the full error to verify the new error message format
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = test_server(spec_with_info(json!({
            "title": "Pet Store API",
            "version": "2.1.0",
            "description": "A sample API for managing pets"
        })));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = test_server(spec_with_info(json!({
            "title": "My API",
            "version": "1.0.0"
        })));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = test_server(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
        assert_eq!(server.extract_openapi_display_title(), None);
    }

    #[test]
    fn test_display_title_prefers_extension() {
        let server = test_server(spec_with_info(json!({
            "title": "Pet Store API",
            "version": "1.0.0",
            "x-display-title": "Pet Store"
        })));

        assert_eq!(
            server.extract_openapi_display_title(),
            Some("Pet Store".to_string())
        );
    }

    #[test]
    fn test_display_title_appends_server_suffix() {
        let server = test_server(spec_with_info(json!({
            "title": "Pet Store API",
            "version": "1.0.0"
        })));

        assert_eq!(
            server.extract_openapi_display_title(),
            Some("Pet Store API Server".to_string())
        );
    }

    #[test]
    fn test_display_title_keeps_existing_server_word() {
        let server = test_server(spec_with_info(json!({
            "title": "Inventory server",
            "version": "1.0.0"
        })));

        assert_eq!(
            server.extract_openapi_display_title(),
            Some("Inventory server".to_string())
        );
    }

    #[test]
    fn test_display_title_ignores_non_string_extension() {
        let server = test_server(spec_with_info(json!({
            "title": "Pet Store API",
            "version": "1.0.0",
            "x-display-title": 42
        })));

        assert_eq!(
            server.extract_openapi_display_title(),
            Some("Pet Store API Server".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = test_server(Value::Null);

        // Set custom metadata directly
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.title = Some("Custom Title".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "3.0.0");
        assert_eq!(result.server_info.title, Some("Custom Title".to_string()));
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = test_server(spec_with_info(json!({
            "title": "OpenAPI Server",
            "version": "1.5.0",
            "description": "Server from OpenAPI spec"
        })));

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI Server");
        assert_eq!(result.server_info.version, "1.5.0");
        // The title already mentions "server", so no suffix is appended.
        assert_eq!(result.server_info.title, Some("OpenAPI Server".to_string()));
        assert_eq!(
            result.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let server = test_server(Value::Null);

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI MCP Server");
        assert_eq!(result.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(result.server_info.title, None);
        assert_eq!(
            result.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = test_server(spec_with_info(json!({
            "title": "OpenAPI Server",
            "version": "2.5.0",
            "description": "Server from OpenAPI spec"
        })));

        // Set custom name and instructions, leave version and title to fall back
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        // Custom name takes precedence
        assert_eq!(result.server_info.name, "Custom Server");
        // OpenAPI version is used
        assert_eq!(result.server_info.version, "2.5.0");
        // OpenAPI-derived title is used
        assert_eq!(result.server_info.title, Some("OpenAPI Server".to_string()));
        // Custom instructions take precedence
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }
}