rmcp_openapi/
server.rs

1use bon::Builder;
2use rmcp::{
3    handler::server::ServerHandler,
4    model::{
5        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
6        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
7        ToolsCapability,
8    },
9    service::{RequestContext, RoleServer},
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13use std::sync::Arc;
14
15use reqwest::header::HeaderMap;
16use url::Url;
17
18use crate::error::Error;
19use crate::filter::ToolFilter;
20use crate::tool::{Tool, ToolCollection, ToolMetadata};
21use crate::transformer::ResponseTransformer;
22use crate::{
23    config::{Authorization, AuthorizationMode},
24    spec::Filters,
25};
26use tracing::{debug, info, info_span, warn};
27
28#[derive(Clone, Builder)]
29pub struct Server {
30    pub openapi_spec: serde_json::Value,
31    #[builder(default)]
32    pub tool_collection: ToolCollection,
33    pub base_url: Url,
34    pub default_headers: Option<HeaderMap>,
35    pub filters: Option<Filters>,
36    #[builder(default)]
37    pub authorization_mode: AuthorizationMode,
38    pub name: Option<String>,
39    pub version: Option<String>,
40    pub title: Option<String>,
41    pub instructions: Option<String>,
42    #[builder(default)]
43    pub skip_tool_descriptions: bool,
44    #[builder(default)]
45    pub skip_parameter_descriptions: bool,
46    /// Global response transformer applied to all tools.
47    ///
48    /// Uses dynamic dispatch (`Arc<dyn>`) because:
49    /// - Transformer runs once per HTTP call (10-1000ms latency)
50    /// - Vtable lookup overhead (~1ns) is unmeasurable
51    /// - Avoids viral generics throughout Server, Tool, ToolCollection
52    pub response_transformer: Option<Arc<dyn ResponseTransformer>>,
53    /// Dynamic tool filter applied to list_tools and call_tool.
54    /// Uses dynamic dispatch (`Arc<dyn>`) for same reasons as response_transformer.
55    pub tool_filter: Option<Arc<dyn ToolFilter>>,
56}
57
58impl Server {
59    /// Create a new Server instance with required parameters
60    pub fn new(
61        openapi_spec: serde_json::Value,
62        base_url: Url,
63        default_headers: Option<HeaderMap>,
64        filters: Option<Filters>,
65        skip_tool_descriptions: bool,
66        skip_parameter_descriptions: bool,
67    ) -> Self {
68        Self {
69            openapi_spec,
70            tool_collection: ToolCollection::new(),
71            base_url,
72            default_headers,
73            filters,
74            authorization_mode: AuthorizationMode::default(),
75            name: None,
76            version: None,
77            title: None,
78            instructions: None,
79            skip_tool_descriptions,
80            skip_parameter_descriptions,
81            response_transformer: None,
82            tool_filter: None,
83        }
84    }
85
86    /// Parse the `OpenAPI` specification and convert to OpenApiTool instances
87    ///
88    /// # Errors
89    ///
90    /// Returns an error if the spec cannot be parsed or tools cannot be generated
91    pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
92        let span = info_span!("tool_registration");
93        let _enter = span.enter();
94
95        // Parse the OpenAPI specification
96        let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
97
98        // Generate OpenApiTool instances directly
99        let tools = spec.to_openapi_tools(
100            self.filters.as_ref(),
101            Some(self.base_url.clone()),
102            self.default_headers.clone(),
103            self.skip_tool_descriptions,
104            self.skip_parameter_descriptions,
105        )?;
106
107        // Apply global transformer to schemas if present
108        let tools = if let Some(ref transformer) = self.response_transformer {
109            tools
110                .into_iter()
111                .map(|mut tool| {
112                    if let Some(schema) = tool.metadata.output_schema.take() {
113                        tool.metadata.output_schema = Some(transformer.transform_schema(schema));
114                    }
115                    tool
116                })
117                .collect()
118        } else {
119            tools
120        };
121
122        self.tool_collection = ToolCollection::from_tools(tools);
123
124        info!(
125            tool_count = self.tool_collection.len(),
126            "Loaded tools from OpenAPI spec"
127        );
128
129        Ok(())
130    }
131
132    /// Set a response transformer for a specific tool, overriding the global one.
133    ///
134    /// The transformer's `transform_schema` method is immediately applied to the tool's
135    /// output schema. The `transform_response` method will be applied to responses
136    /// when the tool is called.
137    ///
138    /// # Errors
139    ///
140    /// Returns an error if the tool is not found
141    pub fn set_tool_transformer(
142        &mut self,
143        tool_name: &str,
144        transformer: Arc<dyn ResponseTransformer>,
145    ) -> Result<(), Error> {
146        self.tool_collection
147            .set_tool_transformer(tool_name, transformer)
148    }
149
150    /// Set the tool filter at runtime.
151    pub fn set_tool_filter(&mut self, filter: Arc<dyn ToolFilter>) {
152        self.tool_filter = Some(filter);
153    }
154
155    /// Get the number of loaded tools
156    #[must_use]
157    pub fn tool_count(&self) -> usize {
158        self.tool_collection.len()
159    }
160
161    /// Get all tool names
162    #[must_use]
163    pub fn get_tool_names(&self) -> Vec<String> {
164        self.tool_collection.get_tool_names()
165    }
166
167    /// Check if a specific tool exists
168    #[must_use]
169    pub fn has_tool(&self, name: &str) -> bool {
170        self.tool_collection.has_tool(name)
171    }
172
173    /// Get a tool by name
174    #[must_use]
175    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
176        self.tool_collection.get_tool(name)
177    }
178
179    /// Get tool metadata by name
180    #[must_use]
181    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
182        self.get_tool(name).map(|tool| &tool.metadata)
183    }
184
185    /// Set the authorization mode for the server
186    pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
187        self.authorization_mode = mode;
188    }
189
190    /// Get the current authorization mode
191    pub fn authorization_mode(&self) -> AuthorizationMode {
192        self.authorization_mode
193    }
194
195    /// Get basic tool statistics
196    #[must_use]
197    pub fn get_tool_stats(&self) -> String {
198        self.tool_collection.get_stats()
199    }
200
201    /// Simple validation - check that tools are loaded
202    ///
203    /// # Errors
204    ///
205    /// Returns an error if no tools are loaded
206    pub fn validate_registry(&self) -> Result<(), Error> {
207        if self.tool_collection.is_empty() {
208            return Err(Error::McpError("No tools loaded".to_string()));
209        }
210        Ok(())
211    }
212
213    /// Extract title from OpenAPI spec info section
214    fn extract_openapi_title(&self) -> Option<String> {
215        self.openapi_spec
216            .get("info")?
217            .get("title")?
218            .as_str()
219            .map(|s| s.to_string())
220    }
221
222    /// Extract version from OpenAPI spec info section
223    fn extract_openapi_version(&self) -> Option<String> {
224        self.openapi_spec
225            .get("info")?
226            .get("version")?
227            .as_str()
228            .map(|s| s.to_string())
229    }
230
231    /// Extract description from OpenAPI spec info section
232    fn extract_openapi_description(&self) -> Option<String> {
233        self.openapi_spec
234            .get("info")?
235            .get("description")?
236            .as_str()
237            .map(|s| s.to_string())
238    }
239
240    /// Extract display title from OpenAPI spec info section
241    /// First checks for x-display-title extension, then derives from title
242    fn extract_openapi_display_title(&self) -> Option<String> {
243        // First check for x-display-title extension
244        if let Some(display_title) = self
245            .openapi_spec
246            .get("info")
247            .and_then(|info| info.get("x-display-title"))
248            .and_then(|t| t.as_str())
249        {
250            return Some(display_title.to_string());
251        }
252
253        // Fallback: enhance the title with "Server" suffix if not already present
254        self.extract_openapi_title().map(|title| {
255            if title.to_lowercase().contains("server") {
256                title
257            } else {
258                format!("{} Server", title)
259            }
260        })
261    }
262}
263
264impl ServerHandler for Server {
265    fn get_info(&self) -> InitializeResult {
266        // 3-level fallback for server name: custom -> OpenAPI spec -> default
267        let server_name = self
268            .name
269            .clone()
270            .or_else(|| self.extract_openapi_title())
271            .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
272
273        // 3-level fallback for server version: custom -> OpenAPI spec -> crate version
274        let server_version = self
275            .version
276            .clone()
277            .or_else(|| self.extract_openapi_version())
278            .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
279
280        // 3-level fallback for title: custom -> OpenAPI-derived -> None
281        let server_title = self
282            .title
283            .clone()
284            .or_else(|| self.extract_openapi_display_title());
285
286        // 3-level fallback for instructions: custom -> OpenAPI spec -> default
287        let instructions = self
288            .instructions
289            .clone()
290            .or_else(|| self.extract_openapi_description())
291            .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
292
293        InitializeResult {
294            protocol_version: ProtocolVersion::V_2024_11_05,
295            server_info: Implementation {
296                name: server_name,
297                version: server_version,
298                title: server_title,
299                icons: None,
300                website_url: None,
301            },
302            capabilities: ServerCapabilities {
303                tools: Some(ToolsCapability {
304                    list_changed: Some(false),
305                }),
306                ..Default::default()
307            },
308            instructions,
309        }
310    }
311
312    async fn list_tools(
313        &self,
314        _request: Option<PaginatedRequestParam>,
315        context: RequestContext<RoleServer>,
316    ) -> Result<ListToolsResult, ErrorData> {
317        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
318        let _enter = span.enter();
319
320        debug!("Processing MCP list_tools request");
321
322        // Delegate to tool collection for MCP tool conversion
323        let mut tools = self.tool_collection.to_mcp_tools();
324
325        // Apply dynamic filter if configured
326        if let Some(filter) = &self.tool_filter {
327            let mut filtered = Vec::with_capacity(tools.len());
328            for mcp_tool in tools {
329                if let Some(tool) = self.tool_collection.get_tool(&mcp_tool.name)
330                    && filter.allow(tool, &context).await
331                {
332                    filtered.push(mcp_tool);
333                }
334            }
335            tools = filtered;
336        }
337
338        info!(
339            returned_tools = tools.len(),
340            "MCP list_tools request completed successfully"
341        );
342
343        Ok(ListToolsResult {
344            meta: None,
345            tools,
346            next_cursor: None,
347        })
348    }
349
350    async fn call_tool(
351        &self,
352        request: CallToolRequestParam,
353        context: RequestContext<RoleServer>,
354    ) -> Result<CallToolResult, ErrorData> {
355        use crate::error::{ToolCallError, ToolCallValidationError};
356
357        let span = info_span!(
358            "call_tool",
359            tool_name = %request.name
360        );
361        let _enter = span.enter();
362
363        debug!(
364            tool_name = %request.name,
365            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
366            "Processing MCP call_tool request"
367        );
368
369        // Filter all tools once upfront (for both access check and suggestions)
370        let allowed_tools: Vec<&Tool> = match &self.tool_filter {
371            None => self.tool_collection.iter().collect(),
372            Some(filter) => {
373                let mut allowed = Vec::new();
374                for tool in self.tool_collection.iter() {
375                    if filter.allow(tool, &context).await {
376                        allowed.push(tool);
377                    }
378                }
379                allowed
380            }
381        };
382
383        // Check if requested tool is in filtered list
384        let tool = allowed_tools
385            .iter()
386            .find(|t| t.metadata.name == request.name);
387
388        let tool = match tool {
389            Some(t) => *t,
390            None => {
391                let available_names: Vec<&str> = allowed_tools
392                    .iter()
393                    .map(|t| t.metadata.name.as_str())
394                    .collect();
395
396                // Uses Jaro distance for suggestions internally
397                let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
398                    request.name.to_string(),
399                    &available_names,
400                ));
401
402                warn!(
403                    tool_name = %request.name,
404                    success = false,
405                    error = %error,
406                    "MCP call_tool request failed - tool not found or filtered"
407                );
408
409                return Err(error.into());
410            }
411        };
412
413        let arguments = request.arguments.unwrap_or_default();
414        let arguments_value = Value::Object(arguments);
415
416        // Extract authorization header from context extensions
417        let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
418
419        if auth_header.is_some() {
420            debug!("Authorization header is present");
421        }
422
423        // Create Authorization enum from mode and header
424        let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
425
426        // Get the server-level transformer as a reference for the tool call
427        let server_transformer = self
428            .response_transformer
429            .as_ref()
430            .map(|t| t.as_ref() as &dyn ResponseTransformer);
431
432        // Execute the tool directly (we already have the validated tool reference)
433        match tool
434            .call(&arguments_value, authorization, server_transformer)
435            .await
436        {
437            Ok(result) => {
438                info!(
439                    tool_name = %request.name,
440                    success = true,
441                    "MCP call_tool request completed successfully"
442                );
443                Ok(result)
444            }
445            Err(e) => {
446                warn!(
447                    tool_name = %request.name,
448                    success = false,
449                    error = %e,
450                    "MCP call_tool request failed"
451                );
452                // Convert ToolCallError to ErrorData and return as error
453                Err(e.into())
454            }
455        }
456    }
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{HttpClient, ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Placeholder base URL shared by every server built in these tests.
    const EXAMPLE_BASE_URL: &str = "http://example.com";

    /// Build a server over `spec` with no default headers, no filters and descriptions kept.
    fn make_server(spec: serde_json::Value) -> Server {
        let base_url = url::Url::parse(EXAMPLE_BASE_URL).unwrap();
        Server::new(spec, base_url, None, None, false, false)
    }

    /// Wrap an `info` object in a minimal OpenAPI 3.0.0 document with no paths.
    fn spec_with_info(info: serde_json::Value) -> serde_json::Value {
        json!({
            "openapi": "3.0.0",
            "info": info,
            "paths": {}
        })
    }

    /// Metadata for the Petstore `getPetById` operation.
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: Some("Find pet by ID".to_string()),
            parameters: json!({
                "type": "object",
                "properties": { "petId": { "type": "integer" } },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// Metadata for the Petstore `getPetsByStatus` operation.
    fn get_pets_by_status_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: Some("Find pets by status".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": { "type": "array", "items": { "type": "string" } }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// Serialize the "tool not found" error produced for `requested` against the names
    /// of every tool loaded in `server`.
    fn not_found_error_json(server: &Server, requested: &str) -> serde_json::Value {
        let known_names = server.get_tool_names();
        let candidates: Vec<&str> = known_names.iter().map(String::as_str).collect();
        let error_data: ErrorData = ToolCallError::Validation(
            ToolCallValidationError::tool_not_found(requested.to_string(), &candidates),
        )
        .into();
        serde_json::to_value(&error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let http_client = HttpClient::new();
        let tools = vec![
            Tool::new(get_pet_by_id_metadata(), http_client.clone()).unwrap(),
            Tool::new(get_pets_by_status_metadata(), http_client.clone()).unwrap(),
        ];

        let mut server = make_server(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(tools);

        // "getPetByID" is a near-miss of "getPetById". The snapshot records the
        // suggestions offered for it.
        let error_json = not_found_error_json(&server, "getPetByID");
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let only_tool = Tool::new(get_pet_by_id_metadata(), HttpClient::new()).unwrap();

        let mut server = make_server(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(vec![only_tool]);

        // A name unrelated to anything loaded; the snapshot pins the no-suggestion form.
        let error_json = not_found_error_json(&server, "completelyUnrelatedToolName");
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // JSON-RPC "invalid params" code
        assert_eq!(error_json["code"], -32602);

        // The snapshot pins the full message format
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = make_server(spec_with_info(json!({
            "title": "Pet Store API",
            "version": "2.1.0",
            "description": "A sample API for managing pets"
        })));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = make_server(spec_with_info(json!({
            "title": "My API",
            "version": "1.0.0"
        })));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = make_server(json!({ "invalid": "spec" }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = make_server(serde_json::Value::Null);
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let info = server.get_info();

        assert_eq!(info.server_info.name, "Custom Server");
        assert_eq!(info.server_info.version, "3.0.0");
        assert_eq!(info.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = make_server(spec_with_info(json!({
            "title": "OpenAPI Server",
            "version": "1.5.0",
            "description": "Server from OpenAPI spec"
        })));

        let info = server.get_info();

        assert_eq!(info.server_info.name, "OpenAPI Server");
        assert_eq!(info.server_info.version, "1.5.0");
        assert_eq!(
            info.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let info = make_server(serde_json::Value::Null).get_info();

        assert_eq!(info.server_info.name, "OpenAPI MCP Server");
        assert_eq!(info.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            info.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = make_server(spec_with_info(json!({
            "title": "OpenAPI Server",
            "version": "2.5.0",
            "description": "Server from OpenAPI spec"
        })));

        // Override name and instructions; leave version to come from the spec.
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let info = server.get_info();

        // Explicit name wins over info.title
        assert_eq!(info.server_info.name, "Custom Server");
        // No explicit version, so info.version is used
        assert_eq!(info.server_info.version, "2.5.0");
        // Explicit instructions win over info.description
        assert_eq!(info.instructions, Some("Custom instructions".to_string()));
    }
}