1use bon::Builder;
2use rmcp::{
3 handler::server::ServerHandler,
4 model::{
5 CallToolRequestParams, CallToolResult, ErrorData, Implementation, InitializeResult,
6 ListToolsResult, PaginatedRequestParams, ProtocolVersion, ServerCapabilities,
7 ToolsCapability,
8 },
9 service::{RequestContext, RoleServer},
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13use std::sync::Arc;
14
15use reqwest::header::HeaderMap;
16use url::Url;
17
18use crate::error::Error;
19use crate::filter::ToolFilter;
20use crate::tool::{Tool, ToolCollection, ToolMetadata};
21use crate::transformer::ResponseTransformer;
22use crate::{
23 config::{Authorization, AuthorizationMode},
24 spec::Filters,
25};
26use tracing::{debug, info, info_span, warn};
27
28#[derive(Clone, Builder)]
29pub struct Server {
30 pub openapi_spec: serde_json::Value,
31 #[builder(default)]
32 pub tool_collection: ToolCollection,
33 pub base_url: Url,
34 pub default_headers: Option<HeaderMap>,
35 pub filters: Option<Filters>,
36 #[builder(default)]
37 pub authorization_mode: AuthorizationMode,
38 pub name: Option<String>,
39 pub version: Option<String>,
40 pub title: Option<String>,
41 pub instructions: Option<String>,
42 #[builder(default)]
43 pub skip_tool_descriptions: bool,
44 #[builder(default)]
45 pub skip_parameter_descriptions: bool,
46 pub response_transformer: Option<Arc<dyn ResponseTransformer>>,
53 pub tool_filter: Option<Arc<dyn ToolFilter>>,
56}
57
58impl Server {
59 pub fn new(
61 openapi_spec: serde_json::Value,
62 base_url: Url,
63 default_headers: Option<HeaderMap>,
64 filters: Option<Filters>,
65 skip_tool_descriptions: bool,
66 skip_parameter_descriptions: bool,
67 ) -> Self {
68 Self {
69 openapi_spec,
70 tool_collection: ToolCollection::new(),
71 base_url,
72 default_headers,
73 filters,
74 authorization_mode: AuthorizationMode::default(),
75 name: None,
76 version: None,
77 title: None,
78 instructions: None,
79 skip_tool_descriptions,
80 skip_parameter_descriptions,
81 response_transformer: None,
82 tool_filter: None,
83 }
84 }
85
86 pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
92 let span = info_span!("tool_registration");
93 let _enter = span.enter();
94
95 let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
97
98 let tools = spec.to_openapi_tools(
100 self.filters.as_ref(),
101 Some(self.base_url.clone()),
102 self.default_headers.clone(),
103 self.skip_tool_descriptions,
104 self.skip_parameter_descriptions,
105 )?;
106
107 let tools = if let Some(ref transformer) = self.response_transformer {
109 tools
110 .into_iter()
111 .map(|mut tool| {
112 if let Some(schema) = tool.metadata.output_schema.take() {
113 tool.metadata.output_schema = Some(transformer.transform_schema(schema));
114 }
115 tool
116 })
117 .collect()
118 } else {
119 tools
120 };
121
122 self.tool_collection = ToolCollection::from_tools(tools);
123
124 info!(
125 tool_count = self.tool_collection.len(),
126 "Loaded tools from OpenAPI spec"
127 );
128
129 Ok(())
130 }
131
132 pub fn set_tool_transformer(
142 &mut self,
143 tool_name: &str,
144 transformer: Arc<dyn ResponseTransformer>,
145 ) -> Result<(), Error> {
146 self.tool_collection
147 .set_tool_transformer(tool_name, transformer)
148 }
149
150 pub fn set_tool_filter(&mut self, filter: Arc<dyn ToolFilter>) {
152 self.tool_filter = Some(filter);
153 }
154
155 #[must_use]
157 pub fn tool_count(&self) -> usize {
158 self.tool_collection.len()
159 }
160
161 #[must_use]
163 pub fn get_tool_names(&self) -> Vec<String> {
164 self.tool_collection.get_tool_names()
165 }
166
167 #[must_use]
169 pub fn has_tool(&self, name: &str) -> bool {
170 self.tool_collection.has_tool(name)
171 }
172
173 #[must_use]
175 pub fn get_tool(&self, name: &str) -> Option<&Tool> {
176 self.tool_collection.get_tool(name)
177 }
178
179 #[must_use]
181 pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
182 self.get_tool(name).map(|tool| &tool.metadata)
183 }
184
185 pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
187 self.authorization_mode = mode;
188 }
189
190 pub fn authorization_mode(&self) -> AuthorizationMode {
192 self.authorization_mode
193 }
194
195 #[must_use]
197 pub fn get_tool_stats(&self) -> String {
198 self.tool_collection.get_stats()
199 }
200
201 pub fn validate_registry(&self) -> Result<(), Error> {
207 if self.tool_collection.is_empty() {
208 return Err(Error::McpError("No tools loaded".to_string()));
209 }
210 Ok(())
211 }
212
213 fn extract_openapi_title(&self) -> Option<String> {
215 self.openapi_spec
216 .get("info")?
217 .get("title")?
218 .as_str()
219 .map(|s| s.to_string())
220 }
221
222 fn extract_openapi_version(&self) -> Option<String> {
224 self.openapi_spec
225 .get("info")?
226 .get("version")?
227 .as_str()
228 .map(|s| s.to_string())
229 }
230
231 fn extract_openapi_description(&self) -> Option<String> {
233 self.openapi_spec
234 .get("info")?
235 .get("description")?
236 .as_str()
237 .map(|s| s.to_string())
238 }
239
240 fn extract_openapi_display_title(&self) -> Option<String> {
243 if let Some(display_title) = self
245 .openapi_spec
246 .get("info")
247 .and_then(|info| info.get("x-display-title"))
248 .and_then(|t| t.as_str())
249 {
250 return Some(display_title.to_string());
251 }
252
253 self.extract_openapi_title().map(|title| {
255 if title.to_lowercase().contains("server") {
256 title
257 } else {
258 format!("{} Server", title)
259 }
260 })
261 }
262}
263
264impl ServerHandler for Server {
265 fn get_info(&self) -> InitializeResult {
266 let server_name = self
268 .name
269 .clone()
270 .or_else(|| self.extract_openapi_title())
271 .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
272
273 let server_version = self
275 .version
276 .clone()
277 .or_else(|| self.extract_openapi_version())
278 .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
279
280 let server_title = self
282 .title
283 .clone()
284 .or_else(|| self.extract_openapi_display_title());
285
286 let instructions = self
288 .instructions
289 .clone()
290 .or_else(|| self.extract_openapi_description())
291 .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
292
293 InitializeResult {
294 protocol_version: ProtocolVersion::V_2024_11_05,
295 server_info: Implementation {
296 name: server_name,
297 version: server_version,
298 title: server_title,
299 description: self.extract_openapi_description(),
300 icons: None,
301 website_url: None,
302 },
303 capabilities: ServerCapabilities {
304 tools: Some(ToolsCapability {
305 list_changed: Some(false),
306 }),
307 ..Default::default()
308 },
309 instructions,
310 }
311 }
312
313 async fn list_tools(
314 &self,
315 _request: Option<PaginatedRequestParams>,
316 context: RequestContext<RoleServer>,
317 ) -> Result<ListToolsResult, ErrorData> {
318 let span = info_span!("list_tools", tool_count = self.tool_collection.len());
319 let _enter = span.enter();
320
321 debug!("Processing MCP list_tools request");
322
323 let mut tools = self.tool_collection.to_mcp_tools();
325
326 if let Some(filter) = &self.tool_filter {
328 let mut filtered = Vec::with_capacity(tools.len());
329 for mcp_tool in tools {
330 if let Some(tool) = self.tool_collection.get_tool(&mcp_tool.name)
331 && filter.allow(tool, &context).await
332 {
333 filtered.push(mcp_tool);
334 }
335 }
336 tools = filtered;
337 }
338
339 info!(
340 returned_tools = tools.len(),
341 "MCP list_tools request completed successfully"
342 );
343
344 Ok(ListToolsResult {
345 meta: None,
346 tools,
347 next_cursor: None,
348 })
349 }
350
351 async fn call_tool(
352 &self,
353 request: CallToolRequestParams,
354 context: RequestContext<RoleServer>,
355 ) -> Result<CallToolResult, ErrorData> {
356 use crate::error::{ToolCallError, ToolCallValidationError};
357
358 let span = info_span!(
359 "call_tool",
360 tool_name = %request.name
361 );
362 let _enter = span.enter();
363
364 debug!(
365 tool_name = %request.name,
366 has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
367 "Processing MCP call_tool request"
368 );
369
370 let allowed_tools: Vec<&Tool> = match &self.tool_filter {
372 None => self.tool_collection.iter().collect(),
373 Some(filter) => {
374 let mut allowed = Vec::new();
375 for tool in self.tool_collection.iter() {
376 if filter.allow(tool, &context).await {
377 allowed.push(tool);
378 }
379 }
380 allowed
381 }
382 };
383
384 let tool = allowed_tools
386 .iter()
387 .find(|t| t.metadata.name == request.name);
388
389 let tool = match tool {
390 Some(t) => *t,
391 None => {
392 let available_names: Vec<&str> = allowed_tools
393 .iter()
394 .map(|t| t.metadata.name.as_str())
395 .collect();
396
397 let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
399 request.name.to_string(),
400 &available_names,
401 ));
402
403 warn!(
404 tool_name = %request.name,
405 success = false,
406 error = %error,
407 "MCP call_tool request failed - tool not found or filtered"
408 );
409
410 return Err(error.into());
411 }
412 };
413
414 let arguments = request.arguments.unwrap_or_default();
415 let arguments_value = Value::Object(arguments);
416
417 let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
419
420 if auth_header.is_some() {
421 debug!("Authorization header is present");
422 }
423
424 let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
426
427 let server_transformer = self
429 .response_transformer
430 .as_ref()
431 .map(|t| t.as_ref() as &dyn ResponseTransformer);
432
433 match tool
435 .call(&arguments_value, authorization, server_transformer)
436 .await
437 {
438 Ok(result) => {
439 info!(
440 tool_name = %request.name,
441 success = true,
442 "MCP call_tool request completed successfully"
443 );
444 Ok(result)
445 }
446 Err(e) => {
447 warn!(
448 tool_name = %request.name,
449 success = false,
450 error = %e,
451 "MCP call_tool request failed"
452 );
453 Err(e.into())
455 }
456 }
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{HttpClient, ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Builds a `Server` over `openapi_spec` targeting `http://example.com`,
    /// with no default headers, no filters and descriptions kept.
    fn server_for_spec(openapi_spec: serde_json::Value) -> Server {
        Server::new(
            openapi_spec,
            url::Url::parse("http://example.com").unwrap(),
            None,
            None,
            false,
            false,
        )
    }

    /// Metadata for the `getPetById` tool (`GET /pet/{petId}`).
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: Some("Find pet by ID".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// Metadata for the `getPetsByStatus` tool (`GET /pet/findByStatus`).
    fn get_pets_by_status_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: Some("Find pets by status".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// Serializes a `tool_not_found` error for `requested`, offering `server`'s tool names.
    fn tool_not_found_json(server: &Server, requested: &str) -> serde_json::Value {
        let tool_names = server.get_tool_names();
        let tool_name_refs: Vec<&str> = tool_names.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &tool_name_refs,
        ));
        let error_data: ErrorData = error.into();
        serde_json::to_value(&error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let http_client = HttpClient::new();
        let tool1 = Tool::new(get_pet_by_id_metadata(), http_client.clone()).unwrap();
        let tool2 = Tool::new(get_pets_by_status_metadata(), http_client.clone()).unwrap();

        let mut server = server_for_spec(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(vec![tool1, tool2]);

        // "getPetByID" is a near-miss of the loaded "getPetById".
        let error_json = tool_not_found_json(&server, "getPetByID");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let tool = Tool::new(get_pet_by_id_metadata(), HttpClient::new()).unwrap();

        let mut server = server_for_spec(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(vec![tool]);

        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![crate::error::ValidationError::invalid_parameter(
                "page".to_string(),
                &["page_number".to_string(), "page_size".to_string()],
            )],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // JSON-RPC "Invalid params".
        assert_eq!(error_json["code"], -32602);
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "2.1.0",
                "description": "A sample API for managing pets"
            },
            "paths": {}
        }));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "My API",
                "version": "1.0.0"
            },
            "paths": {}
        }));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = server_for_spec(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = server_for_spec(serde_json::Value::Null);
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "3.0.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "1.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI Server");
        assert_eq!(result.server_info.version, "1.5.0");
        assert_eq!(
            result.server_info.description,
            Some("Server from OpenAPI spec".to_string())
        );
        assert_eq!(
            result.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let server = server_for_spec(serde_json::Value::Null);

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI MCP Server");
        assert_eq!(result.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(result.server_info.title, None);
        assert_eq!(
            result.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "2.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        // Override name and instructions only; version must still come from the spec.
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "2.5.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_display_title_appends_server_suffix() {
        let server = server_for_spec(json!({
            "info": { "title": "Pet Store API", "version": "1.0.0" }
        }));

        assert_eq!(
            server.extract_openapi_display_title(),
            Some("Pet Store API Server".to_string())
        );
    }

    #[test]
    fn test_display_title_keeps_existing_server_word_case_insensitively() {
        let server = server_for_spec(json!({
            "info": { "title": "My SERVER api" }
        }));

        assert_eq!(
            server.extract_openapi_display_title(),
            Some("My SERVER api".to_string())
        );
    }

    #[test]
    fn test_display_title_prefers_extension() {
        let server = server_for_spec(json!({
            "info": { "title": "Pet Store API", "x-display-title": "Pets" }
        }));

        assert_eq!(
            server.extract_openapi_display_title(),
            Some("Pets".to_string())
        );
        assert_eq!(server.get_info().server_info.title, Some("Pets".to_string()));
    }

    #[test]
    fn test_display_title_absent_without_info() {
        let server = server_for_spec(serde_json::Value::Null);

        assert_eq!(server.extract_openapi_display_title(), None);
    }

    #[test]
    fn test_explicit_title_overrides_display_title() {
        let mut server = server_for_spec(json!({
            "info": { "title": "Pet Store API", "x-display-title": "Pets" }
        }));
        server.title = Some("Custom Title".to_string());

        assert_eq!(
            server.get_info().server_info.title,
            Some("Custom Title".to_string())
        );
    }

    #[test]
    fn test_validate_registry_requires_tools() {
        let mut server = server_for_spec(serde_json::Value::Null);
        assert!(server.validate_registry().is_err());

        let tool = Tool::new(get_pet_by_id_metadata(), HttpClient::new()).unwrap();
        server.tool_collection = ToolCollection::from_tools(vec![tool]);

        assert!(server.validate_registry().is_ok());
        assert!(server.has_tool("getPetById"));
    }
}