1use bon::Builder;
2use rmcp::{
3 handler::server::ServerHandler,
4 model::{
5 CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
6 ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
7 ToolsCapability,
8 },
9 service::{RequestContext, RoleServer},
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13
14use reqwest::header::HeaderMap;
15use url::Url;
16
17use crate::config::{Authorization, AuthorizationMode};
18use crate::error::Error;
19use crate::tool::{Tool, ToolCollection, ToolMetadata};
20use tracing::{debug, info, info_span, warn};
21
22#[derive(Clone, Builder)]
23pub struct Server {
24 pub openapi_spec: serde_json::Value,
25 #[builder(default)]
26 pub tool_collection: ToolCollection,
27 pub base_url: Url,
28 pub default_headers: Option<HeaderMap>,
29 pub tag_filter: Option<Vec<String>>,
30 pub method_filter: Option<Vec<reqwest::Method>>,
31 #[builder(default)]
32 pub authorization_mode: AuthorizationMode,
33 pub name: Option<String>,
34 pub version: Option<String>,
35 pub title: Option<String>,
36 pub instructions: Option<String>,
37 #[builder(default)]
38 pub skip_tool_descriptions: bool,
39 #[builder(default)]
40 pub skip_parameter_descriptions: bool,
41}
42
43impl Server {
44 pub fn new(
46 openapi_spec: serde_json::Value,
47 base_url: Url,
48 default_headers: Option<HeaderMap>,
49 tag_filter: Option<Vec<String>>,
50 method_filter: Option<Vec<reqwest::Method>>,
51 skip_tool_descriptions: bool,
52 skip_parameter_descriptions: bool,
53 ) -> Self {
54 Self {
55 openapi_spec,
56 tool_collection: ToolCollection::new(),
57 base_url,
58 default_headers,
59 tag_filter,
60 method_filter,
61 authorization_mode: AuthorizationMode::default(),
62 name: None,
63 version: None,
64 title: None,
65 instructions: None,
66 skip_tool_descriptions,
67 skip_parameter_descriptions,
68 }
69 }
70
71 pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
77 let span = info_span!("tool_registration");
78 let _enter = span.enter();
79
80 let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
82
83 let tools = spec.to_openapi_tools(
85 self.tag_filter.as_deref(),
86 self.method_filter.as_deref(),
87 Some(self.base_url.clone()),
88 self.default_headers.clone(),
89 self.skip_tool_descriptions,
90 self.skip_parameter_descriptions,
91 )?;
92
93 self.tool_collection = ToolCollection::from_tools(tools);
94
95 info!(
96 tool_count = self.tool_collection.len(),
97 "Loaded tools from OpenAPI spec"
98 );
99
100 Ok(())
101 }
102
103 #[must_use]
105 pub fn tool_count(&self) -> usize {
106 self.tool_collection.len()
107 }
108
109 #[must_use]
111 pub fn get_tool_names(&self) -> Vec<String> {
112 self.tool_collection.get_tool_names()
113 }
114
115 #[must_use]
117 pub fn has_tool(&self, name: &str) -> bool {
118 self.tool_collection.has_tool(name)
119 }
120
121 #[must_use]
123 pub fn get_tool(&self, name: &str) -> Option<&Tool> {
124 self.tool_collection.get_tool(name)
125 }
126
127 #[must_use]
129 pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
130 self.get_tool(name).map(|tool| &tool.metadata)
131 }
132
133 pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
135 self.authorization_mode = mode;
136 }
137
138 pub fn authorization_mode(&self) -> AuthorizationMode {
140 self.authorization_mode
141 }
142
143 #[must_use]
145 pub fn get_tool_stats(&self) -> String {
146 self.tool_collection.get_stats()
147 }
148
149 pub fn validate_registry(&self) -> Result<(), Error> {
155 if self.tool_collection.is_empty() {
156 return Err(Error::McpError("No tools loaded".to_string()));
157 }
158 Ok(())
159 }
160
161 fn extract_openapi_title(&self) -> Option<String> {
163 self.openapi_spec
164 .get("info")?
165 .get("title")?
166 .as_str()
167 .map(|s| s.to_string())
168 }
169
170 fn extract_openapi_version(&self) -> Option<String> {
172 self.openapi_spec
173 .get("info")?
174 .get("version")?
175 .as_str()
176 .map(|s| s.to_string())
177 }
178
179 fn extract_openapi_description(&self) -> Option<String> {
181 self.openapi_spec
182 .get("info")?
183 .get("description")?
184 .as_str()
185 .map(|s| s.to_string())
186 }
187
188 fn extract_openapi_display_title(&self) -> Option<String> {
191 if let Some(display_title) = self
193 .openapi_spec
194 .get("info")
195 .and_then(|info| info.get("x-display-title"))
196 .and_then(|t| t.as_str())
197 {
198 return Some(display_title.to_string());
199 }
200
201 self.extract_openapi_title().map(|title| {
203 if title.to_lowercase().contains("server") {
204 title
205 } else {
206 format!("{} Server", title)
207 }
208 })
209 }
210}
211
212impl ServerHandler for Server {
213 fn get_info(&self) -> InitializeResult {
214 let server_name = self
216 .name
217 .clone()
218 .or_else(|| self.extract_openapi_title())
219 .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
220
221 let server_version = self
223 .version
224 .clone()
225 .or_else(|| self.extract_openapi_version())
226 .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
227
228 let server_title = self
230 .title
231 .clone()
232 .or_else(|| self.extract_openapi_display_title());
233
234 let instructions = self
236 .instructions
237 .clone()
238 .or_else(|| self.extract_openapi_description())
239 .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
240
241 InitializeResult {
242 protocol_version: ProtocolVersion::V_2024_11_05,
243 server_info: Implementation {
244 name: server_name,
245 version: server_version,
246 title: server_title,
247 icons: None,
248 website_url: None,
249 },
250 capabilities: ServerCapabilities {
251 tools: Some(ToolsCapability {
252 list_changed: Some(false),
253 }),
254 ..Default::default()
255 },
256 instructions,
257 }
258 }
259
260 async fn list_tools(
261 &self,
262 _request: Option<PaginatedRequestParam>,
263 _context: RequestContext<RoleServer>,
264 ) -> Result<ListToolsResult, ErrorData> {
265 let span = info_span!("list_tools", tool_count = self.tool_collection.len());
266 let _enter = span.enter();
267
268 debug!("Processing MCP list_tools request");
269
270 let tools = self.tool_collection.to_mcp_tools();
272
273 info!(
274 returned_tools = tools.len(),
275 "MCP list_tools request completed successfully"
276 );
277
278 Ok(ListToolsResult {
279 tools,
280 next_cursor: None,
281 })
282 }
283
284 async fn call_tool(
285 &self,
286 request: CallToolRequestParam,
287 context: RequestContext<RoleServer>,
288 ) -> Result<CallToolResult, ErrorData> {
289 let span = info_span!(
290 "call_tool",
291 tool_name = %request.name
292 );
293 let _enter = span.enter();
294
295 debug!(
296 tool_name = %request.name,
297 has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
298 "Processing MCP call_tool request"
299 );
300
301 let arguments = request.arguments.unwrap_or_default();
302 let arguments_value = Value::Object(arguments);
303
304 let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
306
307 if auth_header.is_some() {
308 debug!("Authorization header is present");
309 }
310
311 let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
313
314 match self
316 .tool_collection
317 .call_tool(&request.name, &arguments_value, authorization)
318 .await
319 {
320 Ok(result) => {
321 info!(
322 tool_name = %request.name,
323 success = true,
324 "MCP call_tool request completed successfully"
325 );
326 Ok(result)
327 }
328 Err(e) => {
329 warn!(
330 tool_name = %request.name,
331 success = false,
332 error = %e,
333 "MCP call_tool request failed"
334 );
335 Err(e.into())
337 }
338 }
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Server over `openapi_spec` with no headers, no filters and
    /// descriptions enabled.
    fn server_with_spec(openapi_spec: Value) -> Server {
        Server::new(
            openapi_spec,
            Url::parse("http://example.com").unwrap(),
            None,
            None,
            None,
            false,
            false,
        )
    }

    /// Metadata for a `GET` operation with no output schema, security or
    /// parameter mappings.
    fn get_metadata(
        name: &str,
        title: &str,
        description: &str,
        path: &str,
        parameters: Value,
    ) -> ToolMetadata {
        ToolMetadata {
            name: name.to_string(),
            title: Some(title.to_string()),
            description: Some(description.to_string()),
            parameters,
            output_schema: None,
            method: "GET".to_string(),
            path: path.to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// The `getPetById` tool used by the not-found tests.
    fn pet_by_id_tool() -> Tool {
        let metadata = get_metadata(
            "getPetById",
            "Get Pet by ID",
            "Find pet by ID",
            "/pet/{petId}",
            json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
        );
        Tool::new(metadata, None, None).unwrap()
    }

    /// The `getPetsByStatus` tool used by the suggestions test.
    fn pets_by_status_tool() -> Tool {
        let metadata = get_metadata(
            "getPetsByStatus",
            "Find Pets by Status",
            "Find pets by status",
            "/pet/findByStatus",
            json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
        );
        Tool::new(metadata, None, None).unwrap()
    }

    /// Serialized `ErrorData` for a "tool not found" error raised against the
    /// tools registered on `server`.
    fn tool_not_found_json(server: &Server, requested: &str) -> Value {
        let tool_names = server.get_tool_names();
        let tool_name_refs: Vec<&str> = tool_names.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &tool_name_refs,
        ));
        let error_data: ErrorData = error.into();
        serde_json::to_value(&error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let mut server = server_with_spec(Value::Null);
        server.tool_collection =
            ToolCollection::from_tools(vec![pet_by_id_tool(), pets_by_status_tool()]);

        // Differs from the registered name only in casing.
        let error_json = tool_not_found_json(&server, "getPetByID");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let mut server = server_with_spec(Value::Null);
        server.tool_collection = ToolCollection::from_tools(vec![pet_by_id_tool()]);

        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        assert_eq!(error_json["code"], -32602);
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "2.1.0",
                "description": "A sample API for managing pets"
            },
            "paths": {}
        }));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "My API",
                "version": "1.0.0"
            },
            "paths": {}
        }));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = server_with_spec(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = server_with_spec(Value::Null);
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "3.0.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "1.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI Server");
        assert_eq!(result.server_info.version, "1.5.0");
        assert_eq!(
            result.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let server = server_with_spec(Value::Null);

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI MCP Server");
        assert_eq!(result.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            result.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "2.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        // Only name and instructions are overridden; version comes from the spec.
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "2.5.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }
}