1use bon::Builder;
2use rmcp::{
3 handler::server::ServerHandler,
4 model::{
5 CallToolRequestParams, CallToolResult, ErrorData, Implementation, InitializeResult,
6 ListToolsResult, PaginatedRequestParams, ProtocolVersion, ServerCapabilities,
7 ToolsCapability,
8 },
9 service::{RequestContext, RoleServer},
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13use std::sync::Arc;
14
15use reqwest::header::HeaderMap;
16use url::Url;
17
18use crate::error::Error;
19use crate::filter::ToolFilter;
20use crate::tool::{Tool, ToolCollection, ToolMetadata};
21use crate::transformer::ResponseTransformer;
22use crate::{
23 config::{Authorization, AuthorizationMode},
24 spec::Filters,
25};
26use tracing::{debug, info, info_span, warn};
27
/// MCP server that exposes the operations of an OpenAPI document as MCP tools.
///
/// Construct it with `Server::new` or the `bon`-generated builder, then call
/// `Server::load_openapi_spec` to (re)build `tool_collection` from `openapi_spec`.
#[derive(Clone, Builder)]
pub struct Server {
    /// Raw OpenAPI document. Parsed by `load_openapi_spec`; its `info` object also supplies
    /// fallback name / version / title / instructions for `get_info`.
    pub openapi_spec: serde_json::Value,
    /// Tools currently exposed. Builder default is `ToolCollection::default()`; the whole
    /// collection is replaced every time `load_openapi_spec` runs.
    #[builder(default)]
    pub tool_collection: ToolCollection,
    /// Base URL handed to `to_openapi_tools` when the tools are generated.
    pub base_url: Url,
    /// Headers handed to `to_openapi_tools`; presumably attached to every upstream request —
    /// TODO confirm against the tool implementation.
    pub default_headers: Option<HeaderMap>,
    /// Optional spec filters handed to `to_openapi_tools` (presumably restricting which
    /// operations become tools).
    pub filters: Option<Filters>,
    /// Combined with the incoming `Authorization` header via `Authorization::from_mode`
    /// on every tool call.
    #[builder(default)]
    pub authorization_mode: AuthorizationMode,
    /// Overrides the reported server name (fallback: `info.title`, then a fixed default).
    pub name: Option<String>,
    /// Overrides the reported version (fallback: `info.version`, then this crate's version).
    pub version: Option<String>,
    /// Overrides the reported display title (fallback: `info.x-display-title`, then a title
    /// derived from `info.title`).
    pub title: Option<String>,
    /// Overrides the instructions returned by `get_info` (fallback: `info.description`,
    /// then a fixed default).
    pub instructions: Option<String>,
    /// Forwarded to `to_openapi_tools`; presumably drops operation descriptions from tools.
    #[builder(default)]
    pub skip_tool_descriptions: bool,
    /// Forwarded to `to_openapi_tools`; presumably drops parameter descriptions from tools.
    #[builder(default)]
    pub skip_parameter_descriptions: bool,
    /// Server-wide transformer: applied to every tool's output schema when the spec is
    /// loaded, and passed to each tool invocation.
    pub response_transformer: Option<Arc<dyn ResponseTransformer>>,
    /// Optional per-request filter consulted by `list_tools` and `call_tool` to hide or
    /// reject tools.
    pub tool_filter: Option<Arc<dyn ToolFilter>>,
}
57
58impl Server {
59 pub fn new(
61 openapi_spec: serde_json::Value,
62 base_url: Url,
63 default_headers: Option<HeaderMap>,
64 filters: Option<Filters>,
65 skip_tool_descriptions: bool,
66 skip_parameter_descriptions: bool,
67 ) -> Self {
68 Self {
69 openapi_spec,
70 tool_collection: ToolCollection::new(),
71 base_url,
72 default_headers,
73 filters,
74 authorization_mode: AuthorizationMode::default(),
75 name: None,
76 version: None,
77 title: None,
78 instructions: None,
79 skip_tool_descriptions,
80 skip_parameter_descriptions,
81 response_transformer: None,
82 tool_filter: None,
83 }
84 }
85
86 pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
92 let span = info_span!("tool_registration");
93 let _enter = span.enter();
94
95 let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
97
98 let tools = spec.to_openapi_tools(
100 self.filters.as_ref(),
101 Some(self.base_url.clone()),
102 self.default_headers.clone(),
103 self.skip_tool_descriptions,
104 self.skip_parameter_descriptions,
105 )?;
106
107 let tools = if let Some(ref transformer) = self.response_transformer {
109 tools
110 .into_iter()
111 .map(|mut tool| {
112 if let Some(schema) = tool.metadata.output_schema.take() {
113 tool.metadata.output_schema = Some(transformer.transform_schema(schema));
114 }
115 tool
116 })
117 .collect()
118 } else {
119 tools
120 };
121
122 self.tool_collection = ToolCollection::from_tools(tools);
123
124 info!(
125 tool_count = self.tool_collection.len(),
126 "Loaded tools from OpenAPI spec"
127 );
128
129 Ok(())
130 }
131
132 pub fn set_tool_transformer(
142 &mut self,
143 tool_name: &str,
144 transformer: Arc<dyn ResponseTransformer>,
145 ) -> Result<(), Error> {
146 self.tool_collection
147 .set_tool_transformer(tool_name, transformer)
148 }
149
150 pub fn set_tool_filter(&mut self, filter: Arc<dyn ToolFilter>) {
152 self.tool_filter = Some(filter);
153 }
154
155 #[must_use]
157 pub fn tool_count(&self) -> usize {
158 self.tool_collection.len()
159 }
160
161 #[must_use]
163 pub fn get_tool_names(&self) -> Vec<String> {
164 self.tool_collection.get_tool_names()
165 }
166
167 #[must_use]
169 pub fn has_tool(&self, name: &str) -> bool {
170 self.tool_collection.has_tool(name)
171 }
172
173 #[must_use]
175 pub fn get_tool(&self, name: &str) -> Option<&Tool> {
176 self.tool_collection.get_tool(name)
177 }
178
179 #[must_use]
181 pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
182 self.get_tool(name).map(|tool| &tool.metadata)
183 }
184
185 pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
187 self.authorization_mode = mode;
188 }
189
190 pub fn authorization_mode(&self) -> AuthorizationMode {
192 self.authorization_mode
193 }
194
195 #[must_use]
197 pub fn get_tool_stats(&self) -> String {
198 self.tool_collection.get_stats()
199 }
200
201 pub fn validate_registry(&self) -> Result<(), Error> {
207 if self.tool_collection.is_empty() {
208 return Err(Error::McpError("No tools loaded".to_string()));
209 }
210 Ok(())
211 }
212
213 fn extract_openapi_title(&self) -> Option<String> {
215 self.openapi_spec
216 .get("info")?
217 .get("title")?
218 .as_str()
219 .map(|s| s.to_string())
220 }
221
222 fn extract_openapi_version(&self) -> Option<String> {
224 self.openapi_spec
225 .get("info")?
226 .get("version")?
227 .as_str()
228 .map(|s| s.to_string())
229 }
230
231 fn extract_openapi_description(&self) -> Option<String> {
233 self.openapi_spec
234 .get("info")?
235 .get("description")?
236 .as_str()
237 .map(|s| s.to_string())
238 }
239
240 fn extract_openapi_display_title(&self) -> Option<String> {
243 if let Some(display_title) = self
245 .openapi_spec
246 .get("info")
247 .and_then(|info| info.get("x-display-title"))
248 .and_then(|t| t.as_str())
249 {
250 return Some(display_title.to_string());
251 }
252
253 self.extract_openapi_title().map(|title| {
255 if title.to_lowercase().contains("server") {
256 title
257 } else {
258 format!("{} Server", title)
259 }
260 })
261 }
262}
263
264impl ServerHandler for Server {
265 fn get_info(&self) -> InitializeResult {
266 let server_name = self
268 .name
269 .clone()
270 .or_else(|| self.extract_openapi_title())
271 .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
272
273 let server_version = self
275 .version
276 .clone()
277 .or_else(|| self.extract_openapi_version())
278 .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
279
280 let server_title = self
282 .title
283 .clone()
284 .or_else(|| self.extract_openapi_display_title());
285
286 let instructions = self
288 .instructions
289 .clone()
290 .or_else(|| self.extract_openapi_description())
291 .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
292
293 let mut server_info = Implementation::new(server_name, server_version);
294 server_info.title = server_title;
295 server_info.description = self.extract_openapi_description();
296
297 let mut capabilities = ServerCapabilities::default();
298 capabilities.tools = Some(ToolsCapability {
299 list_changed: Some(false),
300 });
301
302 let mut result = InitializeResult::new(capabilities)
303 .with_protocol_version(ProtocolVersion::V_2024_11_05)
304 .with_server_info(server_info);
305 result.instructions = instructions;
306 result
307 }
308
309 async fn list_tools(
310 &self,
311 _request: Option<PaginatedRequestParams>,
312 context: RequestContext<RoleServer>,
313 ) -> Result<ListToolsResult, ErrorData> {
314 let span = info_span!("list_tools", tool_count = self.tool_collection.len());
315 let _enter = span.enter();
316
317 debug!("Processing MCP list_tools request");
318
319 let mut tools = self.tool_collection.to_mcp_tools();
321
322 if let Some(filter) = &self.tool_filter {
324 let mut filtered = Vec::with_capacity(tools.len());
325 for mcp_tool in tools {
326 if let Some(tool) = self.tool_collection.get_tool(&mcp_tool.name)
327 && filter.allow(tool, &context).await
328 {
329 filtered.push(mcp_tool);
330 }
331 }
332 tools = filtered;
333 }
334
335 info!(
336 returned_tools = tools.len(),
337 "MCP list_tools request completed successfully"
338 );
339
340 Ok(ListToolsResult {
341 meta: None,
342 tools,
343 next_cursor: None,
344 })
345 }
346
347 async fn call_tool(
348 &self,
349 request: CallToolRequestParams,
350 context: RequestContext<RoleServer>,
351 ) -> Result<CallToolResult, ErrorData> {
352 use crate::error::{ToolCallError, ToolCallValidationError};
353
354 let span = info_span!(
355 "call_tool",
356 tool_name = %request.name
357 );
358 let _enter = span.enter();
359
360 debug!(
361 tool_name = %request.name,
362 has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
363 "Processing MCP call_tool request"
364 );
365
366 let allowed_tools: Vec<&Tool> = match &self.tool_filter {
368 None => self.tool_collection.iter().collect(),
369 Some(filter) => {
370 let mut allowed = Vec::new();
371 for tool in self.tool_collection.iter() {
372 if filter.allow(tool, &context).await {
373 allowed.push(tool);
374 }
375 }
376 allowed
377 }
378 };
379
380 let tool = allowed_tools
382 .iter()
383 .find(|t| t.metadata.name == request.name);
384
385 let tool = match tool {
386 Some(t) => *t,
387 None => {
388 let available_names: Vec<&str> = allowed_tools
389 .iter()
390 .map(|t| t.metadata.name.as_str())
391 .collect();
392
393 let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
395 request.name.to_string(),
396 &available_names,
397 ));
398
399 warn!(
400 tool_name = %request.name,
401 success = false,
402 error = %error,
403 "MCP call_tool request failed - tool not found or filtered"
404 );
405
406 return Err(error.into());
407 }
408 };
409
410 let arguments = request.arguments.unwrap_or_default();
411 let arguments_value = Value::Object(arguments);
412
413 let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
415
416 if auth_header.is_some() {
417 debug!("Authorization header is present");
418 }
419
420 let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
422
423 let server_transformer = self
425 .response_transformer
426 .as_ref()
427 .map(|t| t.as_ref() as &dyn ResponseTransformer);
428
429 match tool
431 .call(&arguments_value, authorization, server_transformer)
432 .await
433 {
434 Ok(result) => {
435 info!(
436 tool_name = %request.name,
437 success = true,
438 "MCP call_tool request completed successfully"
439 );
440 Ok(result)
441 }
442 Err(e) => {
443 warn!(
444 tool_name = %request.name,
445 success = false,
446 error = %e,
447 "MCP call_tool request failed"
448 );
449 Err(e.into())
451 }
452 }
453 }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{HttpClient, ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Builds a server for `openapi_spec` with no default headers, filters, or description
    /// skipping.
    fn server_for_spec(openapi_spec: serde_json::Value) -> Server {
        Server::new(
            openapi_spec,
            url::Url::parse("http://example.com").unwrap(),
            None,
            None,
            false,
            false,
        )
    }

    /// Builds a spec-less server whose collection holds exactly `tools`.
    fn server_with_tools(tools: Vec<Tool>) -> Server {
        let mut server = server_for_spec(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(tools);
        server
    }

    /// Metadata for the `GET /pet/{petId}` operation.
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: Some("Find pet by ID".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// Metadata for the `GET /pet/findByStatus` operation.
    fn get_pets_by_status_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: Some("Find pets by status".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let http_client = HttpClient::new();
        let tool1 = Tool::new(get_pet_by_id_metadata(), http_client.clone()).unwrap();
        let tool2 = Tool::new(get_pets_by_status_metadata(), http_client.clone()).unwrap();
        let server = server_with_tools(vec![tool1, tool2]);

        let tool_names = server.get_tool_names();
        let tool_name_refs: Vec<&str> = tool_names.iter().map(String::as_str).collect();

        // Differs from "getPetById" only by case, so it should produce a suggestion.
        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            "getPetByID".to_string(),
            &tool_name_refs,
        ));
        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let tool = Tool::new(get_pet_by_id_metadata(), HttpClient::new()).unwrap();
        let server = server_with_tools(vec![tool]);

        let tool_names = server.get_tool_names();
        let tool_name_refs: Vec<&str> = tool_names.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            "completelyUnrelatedToolName".to_string(),
            &tool_name_refs,
        ));
        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![crate::error::ValidationError::invalid_parameter(
                "page".to_string(),
                &["page_number".to_string(), "page_size".to_string()],
            )],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // -32602 is the JSON-RPC "Invalid params" code.
        assert_eq!(error_json["code"], -32602);
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "2.1.0",
                "description": "A sample API for managing pets"
            },
            "paths": {}
        }));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "My API",
                "version": "1.0.0"
            },
            "paths": {}
        }));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = server_for_spec(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_display_title() {
        let explicit = server_for_spec(json!({
            "info": { "title": "Pet Store API", "x-display-title": "Pets" }
        }));
        assert_eq!(
            explicit.extract_openapi_display_title(),
            Some("Pets".to_string())
        );

        let suffixed = server_for_spec(json!({ "info": { "title": "Pet Store API" } }));
        assert_eq!(
            suffixed.extract_openapi_display_title(),
            Some("Pet Store API Server".to_string())
        );

        // "server" is matched case-insensitively, so no suffix is appended.
        let already_server = server_for_spec(json!({ "info": { "title": "OpenAPI server" } }));
        assert_eq!(
            already_server.extract_openapi_display_title(),
            Some("OpenAPI server".to_string())
        );

        let missing = server_for_spec(serde_json::Value::Null);
        assert_eq!(missing.extract_openapi_display_title(), None);
    }

    #[test]
    fn test_validate_registry_rejects_empty_collection() {
        let server = server_for_spec(serde_json::Value::Null);

        assert_eq!(server.tool_count(), 0);
        assert!(server.validate_registry().is_err());
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = server_for_spec(serde_json::Value::Null);
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "3.0.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "1.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI Server");
        assert_eq!(result.server_info.version, "1.5.0");
        assert_eq!(
            result.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_title_derived_from_openapi_spec() {
        let server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "1.0.0"
            },
            "paths": {}
        }));

        let result = server.get_info();

        assert_eq!(
            result.server_info.title,
            Some("Pet Store API Server".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let server = server_for_spec(serde_json::Value::Null);

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI MCP Server");
        assert_eq!(result.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            result.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "2.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        // Overridden fields win; the version still falls back to the spec.
        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "2.5.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }
}