rmcp_openapi/
server.rs

1use bon::Builder;
2use rmcp::{
3    RoleServer, ServerHandler,
4    model::{
5        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
6        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
7        ToolsCapability,
8    },
9    service::RequestContext,
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13
14use reqwest::header::HeaderMap;
15use url::Url;
16
17use crate::config::{Authorization, AuthorizationMode};
18use crate::error::Error;
19use crate::tool::{Tool, ToolCollection, ToolMetadata};
20use tracing::{debug, info, info_span, warn};
21
22#[derive(Clone, Builder)]
23pub struct Server {
24    pub openapi_spec: serde_json::Value,
25    #[builder(default)]
26    pub tool_collection: ToolCollection,
27    pub base_url: Url,
28    pub default_headers: Option<HeaderMap>,
29    pub tag_filter: Option<Vec<String>>,
30    pub method_filter: Option<Vec<reqwest::Method>>,
31    #[builder(default)]
32    pub authorization_mode: AuthorizationMode,
33    pub server_name: Option<String>,
34    pub server_version: Option<String>,
35    pub server_instructions: Option<String>,
36}
37
38impl Server {
39    /// Create a new Server instance with required parameters
40    pub fn new(
41        openapi_spec: serde_json::Value,
42        base_url: Url,
43        default_headers: Option<HeaderMap>,
44        tag_filter: Option<Vec<String>>,
45        method_filter: Option<Vec<reqwest::Method>>,
46    ) -> Self {
47        Self {
48            openapi_spec,
49            tool_collection: ToolCollection::new(),
50            base_url,
51            default_headers,
52            tag_filter,
53            method_filter,
54            authorization_mode: AuthorizationMode::default(),
55            server_name: None,
56            server_version: None,
57            server_instructions: None,
58        }
59    }
60
61    /// Parse the `OpenAPI` specification and convert to OpenApiTool instances
62    ///
63    /// # Errors
64    ///
65    /// Returns an error if the spec cannot be parsed or tools cannot be generated
66    pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
67        let span = info_span!("tool_registration");
68        let _enter = span.enter();
69
70        // Parse the OpenAPI specification
71        let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
72
73        // Generate OpenApiTool instances directly
74        let tools = spec.to_openapi_tools(
75            self.tag_filter.as_deref(),
76            self.method_filter.as_deref(),
77            Some(self.base_url.clone()),
78            self.default_headers.clone(),
79        )?;
80
81        self.tool_collection = ToolCollection::from_tools(tools);
82
83        info!(
84            tool_count = self.tool_collection.len(),
85            "Loaded tools from OpenAPI spec"
86        );
87
88        Ok(())
89    }
90
91    /// Get the number of loaded tools
92    #[must_use]
93    pub fn tool_count(&self) -> usize {
94        self.tool_collection.len()
95    }
96
97    /// Get all tool names
98    #[must_use]
99    pub fn get_tool_names(&self) -> Vec<String> {
100        self.tool_collection.get_tool_names()
101    }
102
103    /// Check if a specific tool exists
104    #[must_use]
105    pub fn has_tool(&self, name: &str) -> bool {
106        self.tool_collection.has_tool(name)
107    }
108
109    /// Get a tool by name
110    #[must_use]
111    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
112        self.tool_collection.get_tool(name)
113    }
114
115    /// Get tool metadata by name
116    #[must_use]
117    pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
118        self.get_tool(name).map(|tool| &tool.metadata)
119    }
120
121    /// Set the authorization mode for the server
122    pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
123        self.authorization_mode = mode;
124    }
125
126    /// Get the current authorization mode
127    pub fn authorization_mode(&self) -> AuthorizationMode {
128        self.authorization_mode
129    }
130
131    /// Get basic tool statistics
132    #[must_use]
133    pub fn get_tool_stats(&self) -> String {
134        self.tool_collection.get_stats()
135    }
136
137    /// Simple validation - check that tools are loaded
138    ///
139    /// # Errors
140    ///
141    /// Returns an error if no tools are loaded
142    pub fn validate_registry(&self) -> Result<(), Error> {
143        if self.tool_collection.is_empty() {
144            return Err(Error::McpError("No tools loaded".to_string()));
145        }
146        Ok(())
147    }
148
149    /// Extract title from OpenAPI spec info section
150    fn extract_openapi_title(&self) -> Option<String> {
151        self.openapi_spec
152            .get("info")?
153            .get("title")?
154            .as_str()
155            .map(|s| s.to_string())
156    }
157
158    /// Extract version from OpenAPI spec info section
159    fn extract_openapi_version(&self) -> Option<String> {
160        self.openapi_spec
161            .get("info")?
162            .get("version")?
163            .as_str()
164            .map(|s| s.to_string())
165    }
166
167    /// Extract description from OpenAPI spec info section
168    fn extract_openapi_description(&self) -> Option<String> {
169        self.openapi_spec
170            .get("info")?
171            .get("description")?
172            .as_str()
173            .map(|s| s.to_string())
174    }
175}
176
177impl ServerHandler for Server {
178    fn get_info(&self) -> InitializeResult {
179        // 3-level fallback for server name: custom -> OpenAPI spec -> default
180        let server_name = self
181            .server_name
182            .clone()
183            .or_else(|| self.extract_openapi_title())
184            .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
185
186        // 3-level fallback for server version: custom -> OpenAPI spec -> crate version
187        let server_version = self
188            .server_version
189            .clone()
190            .or_else(|| self.extract_openapi_version())
191            .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
192
193        // 3-level fallback for instructions: custom -> OpenAPI spec -> default
194        let instructions = self
195            .server_instructions
196            .clone()
197            .or_else(|| self.extract_openapi_description())
198            .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
199
200        InitializeResult {
201            protocol_version: ProtocolVersion::V_2024_11_05,
202            server_info: Implementation {
203                name: server_name,
204                version: server_version,
205            },
206            capabilities: ServerCapabilities {
207                tools: Some(ToolsCapability {
208                    list_changed: Some(false),
209                }),
210                ..Default::default()
211            },
212            instructions,
213        }
214    }
215
216    async fn list_tools(
217        &self,
218        _request: Option<PaginatedRequestParam>,
219        _context: RequestContext<RoleServer>,
220    ) -> Result<ListToolsResult, ErrorData> {
221        let span = info_span!("list_tools", tool_count = self.tool_collection.len());
222        let _enter = span.enter();
223
224        debug!("Processing MCP list_tools request");
225
226        // Delegate to tool collection for MCP tool conversion
227        let tools = self.tool_collection.to_mcp_tools();
228
229        info!(
230            returned_tools = tools.len(),
231            "MCP list_tools request completed successfully"
232        );
233
234        Ok(ListToolsResult {
235            tools,
236            next_cursor: None,
237        })
238    }
239
240    async fn call_tool(
241        &self,
242        request: CallToolRequestParam,
243        context: RequestContext<RoleServer>,
244    ) -> Result<CallToolResult, ErrorData> {
245        let span = info_span!(
246            "call_tool",
247            tool_name = %request.name
248        );
249        let _enter = span.enter();
250
251        debug!(
252            tool_name = %request.name,
253            has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
254            "Processing MCP call_tool request"
255        );
256
257        let arguments = request.arguments.unwrap_or_default();
258        let arguments_value = Value::Object(arguments);
259
260        // Extract authorization header from context extensions
261        let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
262
263        if auth_header.is_some() {
264            debug!("Authorization header is present");
265        }
266
267        // Create Authorization enum from mode and header
268        let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
269
270        // Delegate all tool validation and execution to the tool collection
271        match self
272            .tool_collection
273            .call_tool(&request.name, &arguments_value, authorization)
274            .await
275        {
276            Ok(result) => {
277                info!(
278                    tool_name = %request.name,
279                    success = true,
280                    "MCP call_tool request completed successfully"
281                );
282                Ok(result)
283            }
284            Err(e) => {
285                warn!(
286                    tool_name = %request.name,
287                    success = false,
288                    error = %e,
289                    "MCP call_tool request failed"
290                );
291                // Convert ToolCallError to ErrorData and return as error
292                Err(e.into())
293            }
294        }
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Server for `spec` rooted at `http://example.com`, with no headers or filters.
    fn example_server(spec: Value) -> Server {
        let base_url = url::Url::parse("http://example.com").unwrap();
        Server::new(spec, base_url, None, None, None)
    }

    /// Server with a null spec whose tool collection holds exactly `tools`.
    fn server_with_tools(tools: Vec<Tool>) -> Server {
        let mut server = example_server(Value::Null);
        server.tool_collection = ToolCollection::from_tools(tools);
        server
    }

    /// Metadata for the `getPetById` tool (`GET /pet/{petId}`).
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: "Find pet by ID".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
            security: None,
        }
    }

    /// Convert a tool-call error to `ErrorData` and serialize it to JSON.
    fn error_to_json(error: ToolCallError) -> Value {
        let error_data: ErrorData = error.into();
        serde_json::to_value(&error_data).unwrap()
    }

    /// Build a "tool not found" error for `requested` against the server's tool names.
    fn tool_not_found_json(server: &Server, requested: &str) -> Value {
        let names = server.get_tool_names();
        let name_refs: Vec<&str> = names.iter().map(String::as_str).collect();
        error_to_json(ToolCallError::Validation(
            ToolCallValidationError::tool_not_found(requested.to_string(), &name_refs),
        ))
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let find_by_status = ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: "Find pets by status".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
            security: None,
        };

        let by_id_tool = Tool::new(get_pet_by_id_metadata(), None, None).unwrap();
        let by_status_tool = Tool::new(find_by_status, None, None).unwrap();
        let server = server_with_tools(vec![by_id_tool, by_status_tool]);

        // A near-miss of "getPetById" should yield suggestions.
        let error_json = tool_not_found_json(&server, "getPetByID");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let by_id_tool = Tool::new(get_pet_by_id_metadata(), None, None).unwrap();
        let server = server_with_tools(vec![by_id_tool]);

        // A name unrelated to any loaded tool should yield no suggestions.
        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error_json = error_to_json(ToolCallError::Validation(
            ToolCallValidationError::InvalidParameters {
                violations: vec![violation],
            },
        ));

        // JSON-RPC "invalid params" code.
        assert_eq!(error_json["code"], -32602);

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = example_server(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "2.1.0",
                "description": "A sample API for managing pets"
            },
            "paths": {}
        }));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = example_server(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "My API",
                "version": "1.0.0"
            },
            "paths": {}
        }));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = example_server(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = example_server(Value::Null);
        server.server_name = Some("Custom Server".to_string());
        server.server_version = Some("3.0.0".to_string());
        server.server_instructions = Some("Custom instructions".to_string());

        let info = server.get_info();

        assert_eq!(info.server_info.name, "Custom Server");
        assert_eq!(info.server_info.version, "3.0.0");
        assert_eq!(info.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = example_server(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "1.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        let info = server.get_info();

        assert_eq!(info.server_info.name, "OpenAPI Server");
        assert_eq!(info.server_info.version, "1.5.0");
        assert_eq!(
            info.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let info = example_server(Value::Null).get_info();

        assert_eq!(info.server_info.name, "OpenAPI MCP Server");
        assert_eq!(info.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            info.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = example_server(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "2.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        // Custom name and instructions; version is left to the OpenAPI fallback.
        server.server_name = Some("Custom Server".to_string());
        server.server_instructions = Some("Custom instructions".to_string());

        let info = server.get_info();

        assert_eq!(info.server_info.name, "Custom Server");
        assert_eq!(info.server_info.version, "2.5.0");
        assert_eq!(info.instructions, Some("Custom instructions".to_string()));
    }
}