Skip to main content

rmcp_openapi/
server.rs

1use bon::Builder;
2use rmcp::{
3    handler::server::ServerHandler,
4    model::{
5        CallToolRequestParams, CallToolResult, ErrorData, Implementation, InitializeResult,
6        ListToolsResult, PaginatedRequestParams, ProtocolVersion, ServerCapabilities,
7        ToolsCapability,
8    },
9    service::{RequestContext, RoleServer},
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13use std::sync::Arc;
14
15use reqwest::header::HeaderMap;
16use url::Url;
17
18use crate::error::Error;
19use crate::filter::ToolFilter;
20use crate::tool::{Tool, ToolCollection, ToolMetadata};
21use crate::transformer::ResponseTransformer;
22use crate::{
23    config::{Authorization, AuthorizationMode},
24    spec::Filters,
25};
26use tracing::{debug, info, info_span, warn};
27
28#[derive(Clone, Builder)]
29pub struct Server {
30    pub openapi_spec: serde_json::Value,
31    #[builder(default)]
32    pub tool_collection: ToolCollection,
33    pub base_url: Url,
34    pub default_headers: Option<HeaderMap>,
35    pub filters: Option<Filters>,
36    #[builder(default)]
37    pub authorization_mode: AuthorizationMode,
38    pub name: Option<String>,
39    pub version: Option<String>,
40    pub title: Option<String>,
41    pub instructions: Option<String>,
42    #[builder(default)]
43    pub skip_tool_descriptions: bool,
44    #[builder(default)]
45    pub skip_parameter_descriptions: bool,
46    /// Global response transformer applied to all tools.
47    ///
48    /// Uses dynamic dispatch (`Arc<dyn>`) because:
49    /// - Transformer runs once per HTTP call (10-1000ms latency)
50    /// - Vtable lookup overhead (~1ns) is unmeasurable
51    /// - Avoids viral generics throughout Server, Tool, ToolCollection
52    pub response_transformer: Option<Arc<dyn ResponseTransformer>>,
53    /// Dynamic tool filter applied to list_tools and call_tool.
54    /// Uses dynamic dispatch (`Arc<dyn>`) for same reasons as response_transformer.
55    pub tool_filter: Option<Arc<dyn ToolFilter>>,
56}
57
58impl Server {
59    /// Create a new Server instance with required parameters
60    pub fn new(
61        openapi_spec: serde_json::Value,
62        base_url: Url,
63        default_headers: Option<HeaderMap>,
64        filters: Option<Filters>,
65        skip_tool_descriptions: bool,
66        skip_parameter_descriptions: bool,
67    ) -> Self {
68        Self {
69            openapi_spec,
70            tool_collection: ToolCollection::new(),
71            base_url,
72            default_headers,
73            filters,
74            authorization_mode: AuthorizationMode::default(),
75            name: None,
76            version: None,
77            title: None,
78            instructions: None,
79            skip_tool_descriptions,
80            skip_parameter_descriptions,
81            response_transformer: None,
82            tool_filter: None,
83        }
84    }
85
86    /// Parse the `OpenAPI` specification and convert to OpenApiTool instances
87    ///
88    /// # Errors
89    ///
90    /// Returns an error if the spec cannot be parsed or tools cannot be generated
91    pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
92        let span = info_span!("tool_registration");
93        let _enter = span.enter();
94
95        // Parse the OpenAPI specification
96        let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
97
98        // Generate OpenApiTool instances directly
99        let tools = spec.to_openapi_tools(
100            self.filters.as_ref(),
101            Some(self.base_url.clone()),
102            self.default_headers.clone(),
103            self.skip_tool_descriptions,
104            self.skip_parameter_descriptions,
105        )?;
106
107        // Apply global transformer to schemas if present
108        let tools = if let Some(ref transformer) = self.response_transformer {
109            tools
110                .into_iter()
111                .map(|mut tool| {
112                    if let Some(schema) = tool.metadata.output_schema.take() {
113                        tool.metadata.output_schema = Some(transformer.transform_schema(schema));
114                    }
115                    tool
116                })
117                .collect()
118        } else {
119            tools
120        };
121
122        self.tool_collection = ToolCollection::from_tools(tools);
123
124        info!(
125            tool_count = self.tool_collection.len(),
126            "Loaded tools from OpenAPI spec"
127        );
128
129        Ok(())
130    }
131
132    /// Set a response transformer for a specific tool, overriding the global one.
133    ///
134    /// The transformer's `transform_schema` method is immediately applied to the tool's
135    /// output schema. The `transform_response` method will be applied to responses
136    /// when the tool is called.
137    ///
138    /// # Errors
139    ///
140    /// Returns an error if the tool is not found
141    pub fn set_tool_transformer(
142        &mut self,
143        tool_name: &str,
144        transformer: Arc<dyn ResponseTransformer>,
145    ) -> Result<(), Error> {
146        self.tool_collection
147            .set_tool_transformer(tool_name, transformer)
148    }
149
150    /// Set the tool filter at runtime.
151    pub fn set_tool_filter(&mut self, filter: Arc<dyn ToolFilter>) {
152        self.tool_filter = Some(filter);
153    }
154
155    /// Get the number of loaded tools
156    #[must_use]
157    pub fn tool_count(&self) -> usize {
158        self.tool_collection.len()
159    }
160
161    /// Get all tool names
162    #[must_use]
163    pub fn get_tool_names(&self) -> Vec<String> {
164        self.tool_collection.get_tool_names()
165    }
166
167    /// Check if a specific tool exists
168    #[must_use]
169    pub fn has_tool(&self, name: &str) -> bool {
170        self.tool_collection.has_tool(name)
171    }
172
173    /// Get a tool by name
174    #[must_use]
175    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
176        self.tool_collection.get_tool(name)
177    }
178
179    /// Get tool metadata by name
180    #[must_use]
181    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
182        self.get_tool(name).map(|tool| &tool.metadata)
183    }
184
185    /// Set the authorization mode for the server
186    pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
187        self.authorization_mode = mode;
188    }
189
190    /// Get the current authorization mode
191    pub fn authorization_mode(&self) -> AuthorizationMode {
192        self.authorization_mode
193    }
194
195    /// Get basic tool statistics
196    #[must_use]
197    pub fn get_tool_stats(&self) -> String {
198        self.tool_collection.get_stats()
199    }
200
201    /// Simple validation - check that tools are loaded
202    ///
203    /// # Errors
204    ///
205    /// Returns an error if no tools are loaded
206    pub fn validate_registry(&self) -> Result<(), Error> {
207        if self.tool_collection.is_empty() {
208            return Err(Error::McpError("No tools loaded".to_string()));
209        }
210        Ok(())
211    }
212
213    /// Extract title from OpenAPI spec info section
214    fn extract_openapi_title(&self) -> Option<String> {
215        self.openapi_spec
216            .get("info")?
217            .get("title")?
218            .as_str()
219            .map(|s| s.to_string())
220    }
221
222    /// Extract version from OpenAPI spec info section
223    fn extract_openapi_version(&self) -> Option<String> {
224        self.openapi_spec
225            .get("info")?
226            .get("version")?
227            .as_str()
228            .map(|s| s.to_string())
229    }
230
231    /// Extract description from OpenAPI spec info section
232    fn extract_openapi_description(&self) -> Option<String> {
233        self.openapi_spec
234            .get("info")?
235            .get("description")?
236            .as_str()
237            .map(|s| s.to_string())
238    }
239
240    /// Extract display title from OpenAPI spec info section
241    /// First checks for x-display-title extension, then derives from title
242    fn extract_openapi_display_title(&self) -> Option<String> {
243        // First check for x-display-title extension
244        if let Some(display_title) = self
245            .openapi_spec
246            .get("info")
247            .and_then(|info| info.get("x-display-title"))
248            .and_then(|t| t.as_str())
249        {
250            return Some(display_title.to_string());
251        }
252
253        // Fallback: enhance the title with "Server" suffix if not already present
254        self.extract_openapi_title().map(|title| {
255            if title.to_lowercase().contains("server") {
256                title
257            } else {
258                format!("{} Server", title)
259            }
260        })
261    }
262}
263
264impl ServerHandler for Server {
265    fn get_info(&self) -> InitializeResult {
266        // 3-level fallback for server name: custom -> OpenAPI spec -> default
267        let server_name = self
268            .name
269            .clone()
270            .or_else(|| self.extract_openapi_title())
271            .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
272
273        // 3-level fallback for server version: custom -> OpenAPI spec -> crate version
274        let server_version = self
275            .version
276            .clone()
277            .or_else(|| self.extract_openapi_version())
278            .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
279
280        // 3-level fallback for title: custom -> OpenAPI-derived -> None
281        let server_title = self
282            .title
283            .clone()
284            .or_else(|| self.extract_openapi_display_title());
285
286        // 3-level fallback for instructions: custom -> OpenAPI spec -> default
287        let instructions = self
288            .instructions
289            .clone()
290            .or_else(|| self.extract_openapi_description())
291            .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
292
293        InitializeResult {
294            protocol_version: ProtocolVersion::V_2024_11_05,
295            server_info: Implementation {
296                name: server_name,
297                version: server_version,
298                title: server_title,
299                description: self.extract_openapi_description(),
300                icons: None,
301                website_url: None,
302            },
303            capabilities: ServerCapabilities {
304                tools: Some(ToolsCapability {
305                    list_changed: Some(false),
306                }),
307                ..Default::default()
308            },
309            instructions,
310        }
311    }
312
313    async fn list_tools(
314        &self,
315        _request: Option<PaginatedRequestParams>,
316        context: RequestContext<RoleServer>,
317    ) -> Result<ListToolsResult, ErrorData> {
318        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
319        let _enter = span.enter();
320
321        debug!("Processing MCP list_tools request");
322
323        // Delegate to tool collection for MCP tool conversion
324        let mut tools = self.tool_collection.to_mcp_tools();
325
326        // Apply dynamic filter if configured
327        if let Some(filter) = &self.tool_filter {
328            let mut filtered = Vec::with_capacity(tools.len());
329            for mcp_tool in tools {
330                if let Some(tool) = self.tool_collection.get_tool(&mcp_tool.name)
331                    && filter.allow(tool, &context).await
332                {
333                    filtered.push(mcp_tool);
334                }
335            }
336            tools = filtered;
337        }
338
339        info!(
340            returned_tools = tools.len(),
341            "MCP list_tools request completed successfully"
342        );
343
344        Ok(ListToolsResult {
345            meta: None,
346            tools,
347            next_cursor: None,
348        })
349    }
350
351    async fn call_tool(
352        &self,
353        request: CallToolRequestParams,
354        context: RequestContext<RoleServer>,
355    ) -> Result<CallToolResult, ErrorData> {
356        use crate::error::{ToolCallError, ToolCallValidationError};
357
358        let span = info_span!(
359            "call_tool",
360            tool_name = %request.name
361        );
362        let _enter = span.enter();
363
364        debug!(
365            tool_name = %request.name,
366            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
367            "Processing MCP call_tool request"
368        );
369
370        // Filter all tools once upfront (for both access check and suggestions)
371        let allowed_tools: Vec<&Tool> = match &self.tool_filter {
372            None => self.tool_collection.iter().collect(),
373            Some(filter) => {
374                let mut allowed = Vec::new();
375                for tool in self.tool_collection.iter() {
376                    if filter.allow(tool, &context).await {
377                        allowed.push(tool);
378                    }
379                }
380                allowed
381            }
382        };
383
384        // Check if requested tool is in filtered list
385        let tool = allowed_tools
386            .iter()
387            .find(|t| t.metadata.name == request.name);
388
389        let tool = match tool {
390            Some(t) => *t,
391            None => {
392                let available_names: Vec<&str> = allowed_tools
393                    .iter()
394                    .map(|t| t.metadata.name.as_str())
395                    .collect();
396
397                // Uses Jaro distance for suggestions internally
398                let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
399                    request.name.to_string(),
400                    &available_names,
401                ));
402
403                warn!(
404                    tool_name = %request.name,
405                    success = false,
406                    error = %error,
407                    "MCP call_tool request failed - tool not found or filtered"
408                );
409
410                return Err(error.into());
411            }
412        };
413
414        let arguments = request.arguments.unwrap_or_default();
415        let arguments_value = Value::Object(arguments);
416
417        // Extract authorization header from context extensions
418        let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
419
420        if auth_header.is_some() {
421            debug!("Authorization header is present");
422        }
423
424        // Create Authorization enum from mode and header
425        let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
426
427        // Get the server-level transformer as a reference for the tool call
428        let server_transformer = self
429            .response_transformer
430            .as_ref()
431            .map(|t| t.as_ref() as &dyn ResponseTransformer);
432
433        // Execute the tool directly (we already have the validated tool reference)
434        match tool
435            .call(&arguments_value, authorization, server_transformer)
436            .await
437        {
438            Ok(result) => {
439                info!(
440                    tool_name = %request.name,
441                    success = true,
442                    "MCP call_tool request completed successfully"
443                );
444                Ok(result)
445            }
446            Err(e) => {
447                warn!(
448                    tool_name = %request.name,
449                    success = false,
450                    error = %e,
451                    "MCP call_tool request failed"
452                );
453                // Convert ToolCallError to ErrorData and return as error
454                Err(e.into())
455            }
456        }
457    }
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{HttpClient, ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Server with default options around `spec`, pointed at a dummy base URL.
    fn server_with_spec(spec: serde_json::Value) -> Server {
        Server::new(
            spec,
            url::Url::parse("http://example.com").unwrap(),
            None,
            None,
            false,
            false,
        )
    }

    /// Metadata for a `GET /pet/{petId}` tool named `getPetById`.
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: Some("Find pet by ID".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let get_pets_by_status = ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: Some("Find pets by status".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        };

        let http_client = HttpClient::new();
        let tool1 = Tool::new(get_pet_by_id_metadata(), http_client.clone()).unwrap();
        let tool2 = Tool::new(get_pets_by_status, http_client.clone()).unwrap();

        let mut server = server_with_spec(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(vec![tool1, tool2]);

        // A near-miss typo should produce suggestions.
        let tool_names = server.get_tool_names();
        let tool_name_refs: Vec<&str> = tool_names.iter().map(|s| s.as_str()).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            "getPetByID".to_string(),
            &tool_name_refs,
        ));
        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let tool = Tool::new(get_pet_by_id_metadata(), HttpClient::new()).unwrap();

        let mut server = server_with_spec(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(vec![tool]);

        // An unrelated name should produce no suggestions.
        let tool_names = server.get_tool_names();
        let tool_name_refs: Vec<&str> = tool_names.iter().map(|s| s.as_str()).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            "completelyUnrelatedToolName".to_string(),
            &tool_name_refs,
        ));
        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![crate::error::ValidationError::invalid_parameter(
                "page".to_string(),
                &["page_number".to_string(), "page_size".to_string()],
            )],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // JSON-RPC "invalid params" code.
        assert_eq!(error_json["code"], -32602);

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "2.1.0",
                "description": "A sample API for managing pets"
            },
            "paths": {}
        }));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "My API",
                "version": "1.0.0"
            },
            "paths": {}
        }));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = server_with_spec(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_display_title_variants() {
        // The x-display-title extension wins over the plain title.
        let server = server_with_spec(json!({
            "info": { "title": "Pet Store API", "x-display-title": "Pets" }
        }));
        assert_eq!(
            server.extract_openapi_display_title(),
            Some("Pets".to_string())
        );

        // Without the extension, " Server" is appended to the title.
        let server = server_with_spec(json!({ "info": { "title": "Pet Store API" } }));
        assert_eq!(
            server.extract_openapi_display_title(),
            Some("Pet Store API Server".to_string())
        );

        // The suffix check is case-insensitive, so no double suffix.
        let server = server_with_spec(json!({ "info": { "title": "My server" } }));
        assert_eq!(
            server.extract_openapi_display_title(),
            Some("My server".to_string())
        );

        // No info section at all.
        let server = server_with_spec(serde_json::Value::Null);
        assert_eq!(server.extract_openapi_display_title(), None);
    }

    #[test]
    fn test_get_info_title_fallback() {
        let mut server = server_with_spec(json!({
            "info": { "title": "Pet Store API", "version": "1.0.0" }
        }));

        // Derived from the spec when no custom title is configured.
        assert_eq!(
            server.get_info().server_info.title,
            Some("Pet Store API Server".to_string())
        );

        // A custom title takes precedence.
        server.title = Some("Custom Title".to_string());
        assert_eq!(
            server.get_info().server_info.title,
            Some("Custom Title".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = server_with_spec(serde_json::Value::Null);
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "3.0.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "1.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI Server");
        assert_eq!(result.server_info.version, "1.5.0");
        assert_eq!(
            result.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let server = server_with_spec(serde_json::Value::Null);

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI MCP Server");
        assert_eq!(result.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            result.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "2.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        // Custom name and instructions; version falls back to the spec.
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "2.5.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }
}