rmcp_openapi/
server.rs

use bon::Builder;
use rmcp::{
    handler::server::ServerHandler,
    model::{
        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
        ToolsCapability,
    },
    service::{RequestContext, RoleServer},
};
use rmcp_actix_web::transport::AuthorizationHeader;
use serde_json::Value;

use reqwest::header::HeaderMap;
use url::Url;

use crate::error::Error;
use crate::tool::{Tool, ToolCollection, ToolMetadata};
use crate::{
    config::{Authorization, AuthorizationMode},
    spec::Filters,
};
use tracing::{Instrument, debug, info, info_span, warn};
24
25#[derive(Clone, Builder)]
26pub struct Server {
27    pub openapi_spec: serde_json::Value,
28    #[builder(default)]
29    pub tool_collection: ToolCollection,
30    pub base_url: Url,
31    pub default_headers: Option<HeaderMap>,
32    pub filters: Option<Filters>,
33    #[builder(default)]
34    pub authorization_mode: AuthorizationMode,
35    pub name: Option<String>,
36    pub version: Option<String>,
37    pub title: Option<String>,
38    pub instructions: Option<String>,
39    #[builder(default)]
40    pub skip_tool_descriptions: bool,
41    #[builder(default)]
42    pub skip_parameter_descriptions: bool,
43}
44
45impl Server {
46    /// Create a new Server instance with required parameters
47    pub fn new(
48        openapi_spec: serde_json::Value,
49        base_url: Url,
50        default_headers: Option<HeaderMap>,
51        filters: Option<Filters>,
52        skip_tool_descriptions: bool,
53        skip_parameter_descriptions: bool,
54    ) -> Self {
55        Self {
56            openapi_spec,
57            tool_collection: ToolCollection::new(),
58            base_url,
59            default_headers,
60            filters,
61            authorization_mode: AuthorizationMode::default(),
62            name: None,
63            version: None,
64            title: None,
65            instructions: None,
66            skip_tool_descriptions,
67            skip_parameter_descriptions,
68        }
69    }
70
71    /// Parse the `OpenAPI` specification and convert to OpenApiTool instances
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the spec cannot be parsed or tools cannot be generated
76    pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
77        let span = info_span!("tool_registration");
78        let _enter = span.enter();
79
80        // Parse the OpenAPI specification
81        let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
82
83        // Generate OpenApiTool instances directly
84        let tools = spec.to_openapi_tools(
85            self.filters.as_ref(),
86            Some(self.base_url.clone()),
87            self.default_headers.clone(),
88            self.skip_tool_descriptions,
89            self.skip_parameter_descriptions,
90        )?;
91
92        self.tool_collection = ToolCollection::from_tools(tools);
93
94        info!(
95            tool_count = self.tool_collection.len(),
96            "Loaded tools from OpenAPI spec"
97        );
98
99        Ok(())
100    }
101
102    /// Get the number of loaded tools
103    #[must_use]
104    pub fn tool_count(&self) -> usize {
105        self.tool_collection.len()
106    }
107
108    /// Get all tool names
109    #[must_use]
110    pub fn get_tool_names(&self) -> Vec<String> {
111        self.tool_collection.get_tool_names()
112    }
113
114    /// Check if a specific tool exists
115    #[must_use]
116    pub fn has_tool(&self, name: &str) -> bool {
117        self.tool_collection.has_tool(name)
118    }
119
120    /// Get a tool by name
121    #[must_use]
122    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
123        self.tool_collection.get_tool(name)
124    }
125
126    /// Get tool metadata by name
127    #[must_use]
128    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
129        self.get_tool(name).map(|tool| &tool.metadata)
130    }
131
132    /// Set the authorization mode for the server
133    pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
134        self.authorization_mode = mode;
135    }
136
137    /// Get the current authorization mode
138    pub fn authorization_mode(&self) -> AuthorizationMode {
139        self.authorization_mode
140    }
141
142    /// Get basic tool statistics
143    #[must_use]
144    pub fn get_tool_stats(&self) -> String {
145        self.tool_collection.get_stats()
146    }
147
148    /// Simple validation - check that tools are loaded
149    ///
150    /// # Errors
151    ///
152    /// Returns an error if no tools are loaded
153    pub fn validate_registry(&self) -> Result<(), Error> {
154        if self.tool_collection.is_empty() {
155            return Err(Error::McpError("No tools loaded".to_string()));
156        }
157        Ok(())
158    }
159
160    /// Extract title from OpenAPI spec info section
161    fn extract_openapi_title(&self) -> Option<String> {
162        self.openapi_spec
163            .get("info")?
164            .get("title")?
165            .as_str()
166            .map(|s| s.to_string())
167    }
168
169    /// Extract version from OpenAPI spec info section
170    fn extract_openapi_version(&self) -> Option<String> {
171        self.openapi_spec
172            .get("info")?
173            .get("version")?
174            .as_str()
175            .map(|s| s.to_string())
176    }
177
178    /// Extract description from OpenAPI spec info section
179    fn extract_openapi_description(&self) -> Option<String> {
180        self.openapi_spec
181            .get("info")?
182            .get("description")?
183            .as_str()
184            .map(|s| s.to_string())
185    }
186
187    /// Extract display title from OpenAPI spec info section
188    /// First checks for x-display-title extension, then derives from title
189    fn extract_openapi_display_title(&self) -> Option<String> {
190        // First check for x-display-title extension
191        if let Some(display_title) = self
192            .openapi_spec
193            .get("info")
194            .and_then(|info| info.get("x-display-title"))
195            .and_then(|t| t.as_str())
196        {
197            return Some(display_title.to_string());
198        }
199
200        // Fallback: enhance the title with "Server" suffix if not already present
201        self.extract_openapi_title().map(|title| {
202            if title.to_lowercase().contains("server") {
203                title
204            } else {
205                format!("{} Server", title)
206            }
207        })
208    }
209}
210
211impl ServerHandler for Server {
212    fn get_info(&self) -> InitializeResult {
213        // 3-level fallback for server name: custom -> OpenAPI spec -> default
214        let server_name = self
215            .name
216            .clone()
217            .or_else(|| self.extract_openapi_title())
218            .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
219
220        // 3-level fallback for server version: custom -> OpenAPI spec -> crate version
221        let server_version = self
222            .version
223            .clone()
224            .or_else(|| self.extract_openapi_version())
225            .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
226
227        // 3-level fallback for title: custom -> OpenAPI-derived -> None
228        let server_title = self
229            .title
230            .clone()
231            .or_else(|| self.extract_openapi_display_title());
232
233        // 3-level fallback for instructions: custom -> OpenAPI spec -> default
234        let instructions = self
235            .instructions
236            .clone()
237            .or_else(|| self.extract_openapi_description())
238            .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
239
240        InitializeResult {
241            protocol_version: ProtocolVersion::V_2024_11_05,
242            server_info: Implementation {
243                name: server_name,
244                version: server_version,
245                title: server_title,
246                icons: None,
247                website_url: None,
248            },
249            capabilities: ServerCapabilities {
250                tools: Some(ToolsCapability {
251                    list_changed: Some(false),
252                }),
253                ..Default::default()
254            },
255            instructions,
256        }
257    }
258
259    async fn list_tools(
260        &self,
261        _request: Option<PaginatedRequestParam>,
262        _context: RequestContext<RoleServer>,
263    ) -> Result<ListToolsResult, ErrorData> {
264        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
265        let _enter = span.enter();
266
267        debug!("Processing MCP list_tools request");
268
269        // Delegate to tool collection for MCP tool conversion
270        let tools = self.tool_collection.to_mcp_tools();
271
272        info!(
273            returned_tools = tools.len(),
274            "MCP list_tools request completed successfully"
275        );
276
277        Ok(ListToolsResult {
278            meta: None,
279            tools,
280            next_cursor: None,
281        })
282    }
283
284    async fn call_tool(
285        &self,
286        request: CallToolRequestParam,
287        context: RequestContext<RoleServer>,
288    ) -> Result<CallToolResult, ErrorData> {
289        let span = info_span!(
290            "call_tool",
291            tool_name = %request.name
292        );
293        let _enter = span.enter();
294
295        debug!(
296            tool_name = %request.name,
297            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
298            "Processing MCP call_tool request"
299        );
300
301        let arguments = request.arguments.unwrap_or_default();
302        let arguments_value = Value::Object(arguments);
303
304        // Extract authorization header from context extensions
305        let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
306
307        if auth_header.is_some() {
308            debug!("Authorization header is present");
309        }
310
311        // Create Authorization enum from mode and header
312        let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
313
314        // Delegate all tool validation and execution to the tool collection
315        match self
316            .tool_collection
317            .call_tool(&request.name, &arguments_value, authorization)
318            .await
319        {
320            Ok(result) => {
321                info!(
322                    tool_name = %request.name,
323                    success = true,
324                    "MCP call_tool request completed successfully"
325                );
326                Ok(result)
327            }
328            Err(e) => {
329                warn!(
330                    tool_name = %request.name,
331                    success = false,
332                    error = %e,
333                    "MCP call_tool request failed"
334                );
335                // Convert ToolCallError to ErrorData and return as error
336                Err(e.into())
337            }
338        }
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{HttpClient, ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Build a server for `spec` with the configuration every test here shares:
    /// an `http://example.com` base URL, no default headers, no filters, and
    /// descriptions kept.
    fn server_with_spec(spec: serde_json::Value) -> Server {
        Server::new(
            spec,
            url::Url::parse("http://example.com").unwrap(),
            None,
            None,
            false,
            false,
        )
    }

    /// Metadata for a `GET /pet/{petId}` tool named `getPetById`.
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: Some("Find pet by ID".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// Metadata for a `GET /pet/findByStatus` tool named `getPetsByStatus`.
    fn get_pets_by_status_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: Some("Find pets by status".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// Serialize a tool-not-found error for `requested`, using the tools loaded
    /// in `server` as the suggestion candidates.
    fn tool_not_found_json(server: &Server, requested: &str) -> serde_json::Value {
        let tool_names = server.get_tool_names();
        let tool_name_refs: Vec<&str> = tool_names.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &tool_name_refs,
        ));
        let error_data: ErrorData = error.into();
        serde_json::to_value(&error_data).unwrap()
    }

    /// A spec whose `info` section contains only the given fields.
    fn spec_with_info(info: serde_json::Value) -> serde_json::Value {
        json!({
            "openapi": "3.0.0",
            "info": info,
            "paths": {}
        })
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let http_client = HttpClient::new();
        let tool1 = Tool::new(get_pet_by_id_metadata(), http_client.clone()).unwrap();
        let tool2 = Tool::new(get_pets_by_status_metadata(), http_client).unwrap();

        let mut server = server_with_spec(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(vec![tool1, tool2]);

        // A near-miss typo should produce suggestions.
        let error_json = tool_not_found_json(&server, "getPetByID");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let tool = Tool::new(get_pet_by_id_metadata(), HttpClient::new()).unwrap();

        let mut server = server_with_spec(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(vec![tool]);

        // An unrelated name should produce no suggestions.
        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![crate::error::ValidationError::invalid_parameter(
                "page".to_string(),
                &["page_number".to_string(), "page_size".to_string()],
            )],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // JSON-RPC "invalid params" error code.
        assert_eq!(error_json["code"], -32602);

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = server_with_spec(spec_with_info(json!({
            "title": "Pet Store API",
            "version": "2.1.0",
            "description": "A sample API for managing pets"
        })));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = server_with_spec(spec_with_info(json!({
            "title": "My API",
            "version": "1.0.0"
        })));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = server_with_spec(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = server_with_spec(serde_json::Value::Null);
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.title = Some("Custom Title".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "3.0.0");
        assert_eq!(result.server_info.title, Some("Custom Title".to_string()));
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = server_with_spec(spec_with_info(json!({
            "title": "OpenAPI Server",
            "version": "1.5.0",
            "description": "Server from OpenAPI spec"
        })));

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI Server");
        assert_eq!(result.server_info.version, "1.5.0");
        // The title already mentions "server", so no suffix is appended.
        assert_eq!(
            result.server_info.title,
            Some("OpenAPI Server".to_string())
        );
        assert_eq!(
            result.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let server = server_with_spec(serde_json::Value::Null);

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI MCP Server");
        assert_eq!(result.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(result.server_info.title, None);
        assert_eq!(
            result.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = server_with_spec(spec_with_info(json!({
            "title": "OpenAPI Server",
            "version": "2.5.0",
            "description": "Server from OpenAPI spec"
        })));

        // Custom name and instructions; the version falls back to the spec.
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "2.5.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_title_derived_from_openapi_title() {
        let server = server_with_spec(spec_with_info(json!({
            "title": "Pet Store API",
            "version": "1.0.0"
        })));

        assert_eq!(
            server.get_info().server_info.title,
            Some("Pet Store API Server".to_string())
        );
    }

    #[test]
    fn test_get_info_title_prefers_x_display_title_then_custom() {
        let mut server = server_with_spec(spec_with_info(json!({
            "title": "Pet Store API",
            "version": "1.0.0",
            "x-display-title": "Pets"
        })));

        assert_eq!(server.get_info().server_info.title, Some("Pets".to_string()));

        server.title = Some("Custom Title".to_string());
        assert_eq!(
            server.get_info().server_info.title,
            Some("Custom Title".to_string())
        );
    }
}