Skip to main content

rmcp_openapi/
server.rs

1use bon::Builder;
2use rmcp::{
3    handler::server::ServerHandler,
4    model::{
5        CallToolRequestParams, CallToolResult, ErrorData, Implementation, InitializeResult,
6        ListToolsResult, PaginatedRequestParams, ProtocolVersion, ServerCapabilities,
7        ToolsCapability,
8    },
9    service::{RequestContext, RoleServer},
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13use std::sync::Arc;
14
15use reqwest::header::HeaderMap;
16use url::Url;
17
18use crate::error::Error;
19use crate::filter::ToolFilter;
20use crate::tool::{Tool, ToolCollection, ToolMetadata};
21use crate::transformer::ResponseTransformer;
22use crate::{
23    config::{Authorization, AuthorizationMode},
24    spec::Filters,
25};
26use tracing::{debug, info, info_span, warn};
27
/// MCP server that exposes the operations of an OpenAPI specification as MCP tools.
///
/// Construct it with [`Server::new`] or the `bon`-generated builder, then call
/// [`Server::load_openapi_spec`] to populate `tool_collection` from `openapi_spec`.
#[derive(Clone, Builder)]
pub struct Server {
    /// Raw OpenAPI document. Parsed by `load_openapi_spec`; its `info` object also
    /// supplies fallback name/version/title/instructions in `get_info`.
    pub openapi_spec: serde_json::Value,
    /// Tools generated from the spec; empty until `load_openapi_spec` replaces it.
    #[builder(default)]
    pub tool_collection: ToolCollection,
    /// Base URL handed to tool generation (presumably the upstream API root — confirm).
    pub base_url: Url,
    /// Headers handed to tool generation; presumably attached to every upstream
    /// request — TODO confirm against `Tool`.
    pub default_headers: Option<HeaderMap>,
    /// Spec-level filters handed to tool generation; presumably restricts which
    /// operations become tools — confirm against `Spec::to_openapi_tools`.
    pub filters: Option<Filters>,
    /// Combined with the incoming `Authorization` header (if any) on each tool call.
    #[builder(default)]
    pub authorization_mode: AuthorizationMode,
    /// Server name override; falls back to the spec's `info.title`, then a default.
    pub name: Option<String>,
    /// Version override; falls back to the spec's `info.version`, then this crate's version.
    pub version: Option<String>,
    /// Display title override; falls back to `info.x-display-title` or a title
    /// derived from `info.title`.
    pub title: Option<String>,
    /// Instructions override; falls back to `info.description`, then a default sentence.
    pub instructions: Option<String>,
    /// Forwarded to tool generation (presumably drops tool descriptions — confirm).
    #[builder(default)]
    pub skip_tool_descriptions: bool,
    /// Forwarded to tool generation (presumably drops parameter descriptions — confirm).
    #[builder(default)]
    pub skip_parameter_descriptions: bool,
    /// Global response transformer applied to all tools.
    ///
    /// Its `transform_schema` rewrites each tool's output schema in
    /// `load_openapi_spec`, and it is passed to every tool call.
    ///
    /// Uses dynamic dispatch (`Arc<dyn>`) because:
    /// - Transformer runs once per HTTP call (10-1000ms latency)
    /// - Vtable lookup overhead (~1ns) is unmeasurable
    /// - Avoids viral generics throughout Server, Tool, ToolCollection
    pub response_transformer: Option<Arc<dyn ResponseTransformer>>,
    /// Dynamic tool filter applied to list_tools and call_tool.
    /// Uses dynamic dispatch (`Arc<dyn>`) for same reasons as response_transformer.
    pub tool_filter: Option<Arc<dyn ToolFilter>>,
}
57
58impl Server {
59    /// Create a new Server instance with required parameters
60    pub fn new(
61        openapi_spec: serde_json::Value,
62        base_url: Url,
63        default_headers: Option<HeaderMap>,
64        filters: Option<Filters>,
65        skip_tool_descriptions: bool,
66        skip_parameter_descriptions: bool,
67    ) -> Self {
68        Self {
69            openapi_spec,
70            tool_collection: ToolCollection::new(),
71            base_url,
72            default_headers,
73            filters,
74            authorization_mode: AuthorizationMode::default(),
75            name: None,
76            version: None,
77            title: None,
78            instructions: None,
79            skip_tool_descriptions,
80            skip_parameter_descriptions,
81            response_transformer: None,
82            tool_filter: None,
83        }
84    }
85
86    /// Parse the `OpenAPI` specification and convert to OpenApiTool instances
87    ///
88    /// # Errors
89    ///
90    /// Returns an error if the spec cannot be parsed or tools cannot be generated
91    pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
92        let span = info_span!("tool_registration");
93        let _enter = span.enter();
94
95        // Parse the OpenAPI specification
96        let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
97
98        // Generate OpenApiTool instances directly
99        let tools = spec.to_openapi_tools(
100            self.filters.as_ref(),
101            Some(self.base_url.clone()),
102            self.default_headers.clone(),
103            self.skip_tool_descriptions,
104            self.skip_parameter_descriptions,
105        )?;
106
107        // Apply global transformer to schemas if present
108        let tools = if let Some(ref transformer) = self.response_transformer {
109            tools
110                .into_iter()
111                .map(|mut tool| {
112                    if let Some(schema) = tool.metadata.output_schema.take() {
113                        tool.metadata.output_schema = Some(transformer.transform_schema(schema));
114                    }
115                    tool
116                })
117                .collect()
118        } else {
119            tools
120        };
121
122        self.tool_collection = ToolCollection::from_tools(tools);
123
124        info!(
125            tool_count = self.tool_collection.len(),
126            "Loaded tools from OpenAPI spec"
127        );
128
129        Ok(())
130    }
131
132    /// Set a response transformer for a specific tool, overriding the global one.
133    ///
134    /// The transformer's `transform_schema` method is immediately applied to the tool's
135    /// output schema. The `transform_response` method will be applied to responses
136    /// when the tool is called.
137    ///
138    /// # Errors
139    ///
140    /// Returns an error if the tool is not found
141    pub fn set_tool_transformer(
142        &mut self,
143        tool_name: &str,
144        transformer: Arc<dyn ResponseTransformer>,
145    ) -> Result<(), Error> {
146        self.tool_collection
147            .set_tool_transformer(tool_name, transformer)
148    }
149
150    /// Set the tool filter at runtime.
151    pub fn set_tool_filter(&mut self, filter: Arc<dyn ToolFilter>) {
152        self.tool_filter = Some(filter);
153    }
154
155    /// Get the number of loaded tools
156    #[must_use]
157    pub fn tool_count(&self) -> usize {
158        self.tool_collection.len()
159    }
160
161    /// Get all tool names
162    #[must_use]
163    pub fn get_tool_names(&self) -> Vec<String> {
164        self.tool_collection.get_tool_names()
165    }
166
167    /// Check if a specific tool exists
168    #[must_use]
169    pub fn has_tool(&self, name: &str) -> bool {
170        self.tool_collection.has_tool(name)
171    }
172
173    /// Get a tool by name
174    #[must_use]
175    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
176        self.tool_collection.get_tool(name)
177    }
178
179    /// Get tool metadata by name
180    #[must_use]
181    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
182        self.get_tool(name).map(|tool| &tool.metadata)
183    }
184
185    /// Set the authorization mode for the server
186    pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
187        self.authorization_mode = mode;
188    }
189
190    /// Get the current authorization mode
191    pub fn authorization_mode(&self) -> AuthorizationMode {
192        self.authorization_mode
193    }
194
195    /// Get basic tool statistics
196    #[must_use]
197    pub fn get_tool_stats(&self) -> String {
198        self.tool_collection.get_stats()
199    }
200
201    /// Simple validation - check that tools are loaded
202    ///
203    /// # Errors
204    ///
205    /// Returns an error if no tools are loaded
206    pub fn validate_registry(&self) -> Result<(), Error> {
207        if self.tool_collection.is_empty() {
208            return Err(Error::McpError("No tools loaded".to_string()));
209        }
210        Ok(())
211    }
212
213    /// Extract title from OpenAPI spec info section
214    fn extract_openapi_title(&self) -> Option<String> {
215        self.openapi_spec
216            .get("info")?
217            .get("title")?
218            .as_str()
219            .map(|s| s.to_string())
220    }
221
222    /// Extract version from OpenAPI spec info section
223    fn extract_openapi_version(&self) -> Option<String> {
224        self.openapi_spec
225            .get("info")?
226            .get("version")?
227            .as_str()
228            .map(|s| s.to_string())
229    }
230
231    /// Extract description from OpenAPI spec info section
232    fn extract_openapi_description(&self) -> Option<String> {
233        self.openapi_spec
234            .get("info")?
235            .get("description")?
236            .as_str()
237            .map(|s| s.to_string())
238    }
239
240    /// Extract display title from OpenAPI spec info section
241    /// First checks for x-display-title extension, then derives from title
242    fn extract_openapi_display_title(&self) -> Option<String> {
243        // First check for x-display-title extension
244        if let Some(display_title) = self
245            .openapi_spec
246            .get("info")
247            .and_then(|info| info.get("x-display-title"))
248            .and_then(|t| t.as_str())
249        {
250            return Some(display_title.to_string());
251        }
252
253        // Fallback: enhance the title with "Server" suffix if not already present
254        self.extract_openapi_title().map(|title| {
255            if title.to_lowercase().contains("server") {
256                title
257            } else {
258                format!("{} Server", title)
259            }
260        })
261    }
262}
263
264impl ServerHandler for Server {
265    fn get_info(&self) -> InitializeResult {
266        // 3-level fallback for server name: custom -> OpenAPI spec -> default
267        let server_name = self
268            .name
269            .clone()
270            .or_else(|| self.extract_openapi_title())
271            .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
272
273        // 3-level fallback for server version: custom -> OpenAPI spec -> crate version
274        let server_version = self
275            .version
276            .clone()
277            .or_else(|| self.extract_openapi_version())
278            .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
279
280        // 3-level fallback for title: custom -> OpenAPI-derived -> None
281        let server_title = self
282            .title
283            .clone()
284            .or_else(|| self.extract_openapi_display_title());
285
286        // 3-level fallback for instructions: custom -> OpenAPI spec -> default
287        let instructions = self
288            .instructions
289            .clone()
290            .or_else(|| self.extract_openapi_description())
291            .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
292
293        let mut server_info = Implementation::new(server_name, server_version);
294        server_info.title = server_title;
295        server_info.description = self.extract_openapi_description();
296
297        let mut capabilities = ServerCapabilities::default();
298        capabilities.tools = Some(ToolsCapability {
299            list_changed: Some(false),
300        });
301
302        let mut result = InitializeResult::new(capabilities)
303            .with_protocol_version(ProtocolVersion::V_2024_11_05)
304            .with_server_info(server_info);
305        result.instructions = instructions;
306        result
307    }
308
309    async fn list_tools(
310        &self,
311        _request: Option<PaginatedRequestParams>,
312        context: RequestContext<RoleServer>,
313    ) -> Result<ListToolsResult, ErrorData> {
314        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
315        let _enter = span.enter();
316
317        debug!("Processing MCP list_tools request");
318
319        // Delegate to tool collection for MCP tool conversion
320        let mut tools = self.tool_collection.to_mcp_tools();
321
322        // Apply dynamic filter if configured
323        if let Some(filter) = &self.tool_filter {
324            let mut filtered = Vec::with_capacity(tools.len());
325            for mcp_tool in tools {
326                if let Some(tool) = self.tool_collection.get_tool(&mcp_tool.name)
327                    && filter.allow(tool, &context).await
328                {
329                    filtered.push(mcp_tool);
330                }
331            }
332            tools = filtered;
333        }
334
335        info!(
336            returned_tools = tools.len(),
337            "MCP list_tools request completed successfully"
338        );
339
340        Ok(ListToolsResult {
341            meta: None,
342            tools,
343            next_cursor: None,
344        })
345    }
346
347    async fn call_tool(
348        &self,
349        request: CallToolRequestParams,
350        context: RequestContext<RoleServer>,
351    ) -> Result<CallToolResult, ErrorData> {
352        use crate::error::{ToolCallError, ToolCallValidationError};
353
354        let span = info_span!(
355            "call_tool",
356            tool_name = %request.name
357        );
358        let _enter = span.enter();
359
360        debug!(
361            tool_name = %request.name,
362            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
363            "Processing MCP call_tool request"
364        );
365
366        // Filter all tools once upfront (for both access check and suggestions)
367        let allowed_tools: Vec<&Tool> = match &self.tool_filter {
368            None => self.tool_collection.iter().collect(),
369            Some(filter) => {
370                let mut allowed = Vec::new();
371                for tool in self.tool_collection.iter() {
372                    if filter.allow(tool, &context).await {
373                        allowed.push(tool);
374                    }
375                }
376                allowed
377            }
378        };
379
380        // Check if requested tool is in filtered list
381        let tool = allowed_tools
382            .iter()
383            .find(|t| t.metadata.name == request.name);
384
385        let tool = match tool {
386            Some(t) => *t,
387            None => {
388                let available_names: Vec<&str> = allowed_tools
389                    .iter()
390                    .map(|t| t.metadata.name.as_str())
391                    .collect();
392
393                // Uses Jaro distance for suggestions internally
394                let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
395                    request.name.to_string(),
396                    &available_names,
397                ));
398
399                warn!(
400                    tool_name = %request.name,
401                    success = false,
402                    error = %error,
403                    "MCP call_tool request failed - tool not found or filtered"
404                );
405
406                return Err(error.into());
407            }
408        };
409
410        let arguments = request.arguments.unwrap_or_default();
411        let arguments_value = Value::Object(arguments);
412
413        // Extract authorization header from context extensions
414        let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
415
416        if auth_header.is_some() {
417            debug!("Authorization header is present");
418        }
419
420        // Create Authorization enum from mode and header
421        let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
422
423        // Get the server-level transformer as a reference for the tool call
424        let server_transformer = self
425            .response_transformer
426            .as_ref()
427            .map(|t| t.as_ref() as &dyn ResponseTransformer);
428
429        // Execute the tool directly (we already have the validated tool reference)
430        match tool
431            .call(&arguments_value, authorization, server_transformer)
432            .await
433        {
434            Ok(result) => {
435                info!(
436                    tool_name = %request.name,
437                    success = true,
438                    "MCP call_tool request completed successfully"
439                );
440                Ok(result)
441            }
442            Err(e) => {
443                warn!(
444                    tool_name = %request.name,
445                    success = false,
446                    error = %e,
447                    "MCP call_tool request failed"
448                );
449                // Convert ToolCallError to ErrorData and return as error
450                Err(e.into())
451            }
452        }
453    }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{HttpClient, ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Server over `openapi_spec` targeting `http://example.com`, with no
    /// default headers or filters and both description flags left `false`.
    fn server_for(openapi_spec: serde_json::Value) -> Server {
        let base_url = url::Url::parse("http://example.com").unwrap();
        Server::new(openapi_spec, base_url, None, None, false, false)
    }

    /// Metadata for a `GET` operation with no output schema, security or parameter mappings.
    fn get_operation(
        name: &str,
        title: &str,
        description: &str,
        path: &str,
        parameters: serde_json::Value,
    ) -> ToolMetadata {
        ToolMetadata {
            name: name.to_string(),
            title: Some(title.to_string()),
            description: Some(description.to_string()),
            parameters,
            output_schema: None,
            method: "GET".to_string(),
            path: path.to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// Parameter schema with a single required integer `petId`.
    fn pet_id_parameters() -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "petId": {
                    "type": "integer"
                }
            },
            "required": ["petId"]
        })
    }

    /// Borrow each owned tool name as `&str`, the form passed to `tool_not_found`.
    fn as_refs(names: &[String]) -> Vec<&str> {
        names.iter().map(String::as_str).collect()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let by_id = get_operation(
            "getPetById",
            "Get Pet by ID",
            "Find pet by ID",
            "/pet/{petId}",
            pet_id_parameters(),
        );
        let by_status = get_operation(
            "getPetsByStatus",
            "Find Pets by Status",
            "Find pets by status",
            "/pet/findByStatus",
            json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
        );

        let client = HttpClient::new();
        let loaded = vec![
            Tool::new(by_id, client.clone()).unwrap(),
            Tool::new(by_status, client.clone()).unwrap(),
        ];

        let mut server = server_for(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(loaded);

        // A near-miss spelling of an existing tool should produce suggestions.
        let names = server.get_tool_names();
        let name_refs = as_refs(&names);
        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            "getPetByID".to_string(),
            &name_refs,
        ));
        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let metadata = get_operation(
            "getPetById",
            "Get Pet by ID",
            "Find pet by ID",
            "/pet/{petId}",
            pet_id_parameters(),
        );
        let only_tool = Tool::new(metadata, HttpClient::new()).unwrap();

        let mut server = server_for(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(vec![only_tool]);

        // An unrelated name should yield no suggestions.
        let names = server.get_tool_names();
        let name_refs = as_refs(&names);
        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            "completelyUnrelatedToolName".to_string(),
            &name_refs,
        ));
        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // JSON-RPC "Invalid params" code.
        assert_eq!(error_json["code"], -32602);

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = server_for(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "2.1.0",
                "description": "A sample API for managing pets"
            },
            "paths": {}
        }));

        assert_eq!(server.extract_openapi_title().as_deref(), Some("Pet Store API"));
        assert_eq!(server.extract_openapi_version().as_deref(), Some("2.1.0"));
        assert_eq!(
            server.extract_openapi_description().as_deref(),
            Some("A sample API for managing pets")
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = server_for(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "My API",
                "version": "1.0.0"
            },
            "paths": {}
        }));

        assert_eq!(server.extract_openapi_title().as_deref(), Some("My API"));
        assert_eq!(server.extract_openapi_version().as_deref(), Some("1.0.0"));
        assert!(server.extract_openapi_description().is_none());
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = server_for(json!({
            "invalid": "spec"
        }));

        assert!(server.extract_openapi_title().is_none());
        assert!(server.extract_openapi_version().is_none());
        assert!(server.extract_openapi_description().is_none());
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = server_for(serde_json::Value::Null);
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let info = server.get_info();

        assert_eq!(info.server_info.name, "Custom Server");
        assert_eq!(info.server_info.version, "3.0.0");
        assert_eq!(info.instructions.as_deref(), Some("Custom instructions"));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = server_for(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "1.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        let info = server.get_info();

        assert_eq!(info.server_info.name, "OpenAPI Server");
        assert_eq!(info.server_info.version, "1.5.0");
        assert_eq!(info.instructions.as_deref(), Some("Server from OpenAPI spec"));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let info = server_for(serde_json::Value::Null).get_info();

        assert_eq!(info.server_info.name, "OpenAPI MCP Server");
        assert_eq!(info.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            info.instructions.as_deref(),
            Some("Exposes OpenAPI endpoints as MCP tools")
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = server_for(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "2.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        // Override name and instructions only; version must come from the spec.
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let info = server.get_info();

        assert_eq!(info.server_info.name, "Custom Server");
        assert_eq!(info.server_info.version, "2.5.0");
        assert_eq!(info.instructions.as_deref(), Some("Custom instructions"));
    }
}