1use bon::Builder;
2use rmcp::{
3 handler::server::ServerHandler,
4 model::{
5 CallToolRequestParams, CallToolResult, ErrorData, Implementation, InitializeResult,
6 ListToolsResult, PaginatedRequestParams, ProtocolVersion, ServerCapabilities,
7 ToolsCapability,
8 },
9 service::{RequestContext, RoleServer},
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13use std::sync::Arc;
14
15use reqwest::header::HeaderMap;
16use url::Url;
17
18use crate::error::Error;
19use crate::filter::ToolFilter;
20use crate::tool::{Tool, ToolCollection, ToolMetadata};
21use crate::transformer::ResponseTransformer;
22use crate::{
23 config::{Authorization, AuthorizationMode},
24 spec::Filters,
25};
26use tracing::{debug, info, info_span, warn};
27
28#[derive(Clone, Builder)]
29pub struct Server {
30 pub openapi_spec: serde_json::Value,
31 #[builder(default)]
32 pub tool_collection: ToolCollection,
33 pub base_url: Url,
34 pub default_headers: Option<HeaderMap>,
35 pub filters: Option<Filters>,
36 #[builder(default)]
37 pub authorization_mode: AuthorizationMode,
38 pub name: Option<String>,
39 pub version: Option<String>,
40 pub title: Option<String>,
41 pub instructions: Option<String>,
42 #[builder(default)]
43 pub skip_tool_descriptions: bool,
44 #[builder(default)]
45 pub skip_parameter_descriptions: bool,
46 pub response_transformer: Option<Arc<dyn ResponseTransformer>>,
53 pub tool_filter: Option<Arc<dyn ToolFilter>>,
56}
57
58impl Server {
59 pub fn new(
61 openapi_spec: serde_json::Value,
62 base_url: Url,
63 default_headers: Option<HeaderMap>,
64 filters: Option<Filters>,
65 skip_tool_descriptions: bool,
66 skip_parameter_descriptions: bool,
67 ) -> Self {
68 Self {
69 openapi_spec,
70 tool_collection: ToolCollection::new(),
71 base_url,
72 default_headers,
73 filters,
74 authorization_mode: AuthorizationMode::default(),
75 name: None,
76 version: None,
77 title: None,
78 instructions: None,
79 skip_tool_descriptions,
80 skip_parameter_descriptions,
81 response_transformer: None,
82 tool_filter: None,
83 }
84 }
85
86 pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
92 let span = info_span!("tool_registration");
93 let _enter = span.enter();
94
95 let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
97
98 let tools = spec.to_openapi_tools(
100 self.filters.as_ref(),
101 Some(self.base_url.clone()),
102 self.default_headers.clone(),
103 self.skip_tool_descriptions,
104 self.skip_parameter_descriptions,
105 )?;
106
107 let tools = if let Some(ref transformer) = self.response_transformer {
109 tools
110 .into_iter()
111 .map(|mut tool| {
112 if let Some(schema) = tool.metadata.output_schema.take() {
113 tool.metadata.output_schema = Some(transformer.transform_schema(schema));
114 }
115 tool
116 })
117 .collect()
118 } else {
119 tools
120 };
121
122 self.tool_collection = ToolCollection::from_tools(tools);
123
124 info!(
125 tool_count = self.tool_collection.len(),
126 "Loaded tools from OpenAPI spec"
127 );
128
129 Ok(())
130 }
131
132 pub fn set_tool_transformer(
142 &mut self,
143 tool_name: &str,
144 transformer: Arc<dyn ResponseTransformer>,
145 ) -> Result<(), Error> {
146 self.tool_collection
147 .set_tool_transformer(tool_name, transformer)
148 }
149
150 pub fn set_tool_filter(&mut self, filter: Arc<dyn ToolFilter>) {
152 self.tool_filter = Some(filter);
153 }
154
155 #[must_use]
157 pub fn tool_count(&self) -> usize {
158 self.tool_collection.len()
159 }
160
161 #[must_use]
163 pub fn get_tool_names(&self) -> Vec<String> {
164 self.tool_collection.get_tool_names()
165 }
166
167 #[must_use]
169 pub fn has_tool(&self, name: &str) -> bool {
170 self.tool_collection.has_tool(name)
171 }
172
173 #[must_use]
175 pub fn get_tool(&self, name: &str) -> Option<&Tool> {
176 self.tool_collection.get_tool(name)
177 }
178
179 #[must_use]
181 pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
182 self.get_tool(name).map(|tool| &tool.metadata)
183 }
184
185 pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
187 self.authorization_mode = mode;
188 }
189
190 pub fn authorization_mode(&self) -> AuthorizationMode {
192 self.authorization_mode
193 }
194
195 #[must_use]
197 pub fn get_tool_stats(&self) -> String {
198 self.tool_collection.get_stats()
199 }
200
201 pub fn validate_registry(&self) -> Result<(), Error> {
207 if self.tool_collection.is_empty() {
208 return Err(Error::McpError("No tools loaded".to_string()));
209 }
210 Ok(())
211 }
212
213 fn extract_openapi_title(&self) -> Option<String> {
215 self.openapi_spec
216 .get("info")?
217 .get("title")?
218 .as_str()
219 .map(|s| s.to_string())
220 }
221
222 fn extract_openapi_version(&self) -> Option<String> {
224 self.openapi_spec
225 .get("info")?
226 .get("version")?
227 .as_str()
228 .map(|s| s.to_string())
229 }
230
231 fn extract_openapi_description(&self) -> Option<String> {
233 self.openapi_spec
234 .get("info")?
235 .get("description")?
236 .as_str()
237 .map(|s| s.to_string())
238 }
239
240 fn extract_openapi_display_title(&self) -> Option<String> {
243 if let Some(display_title) = self
245 .openapi_spec
246 .get("info")
247 .and_then(|info| info.get("x-display-title"))
248 .and_then(|t| t.as_str())
249 {
250 return Some(display_title.to_string());
251 }
252
253 self.extract_openapi_title().map(|title| {
255 if title.to_lowercase().contains("server") {
256 title
257 } else {
258 format!("{} Server", title)
259 }
260 })
261 }
262}
263
264impl ServerHandler for Server {
265 fn get_info(&self) -> InitializeResult {
266 let server_name = self
268 .name
269 .clone()
270 .or_else(|| self.extract_openapi_title())
271 .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
272
273 let server_version = self
275 .version
276 .clone()
277 .or_else(|| self.extract_openapi_version())
278 .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
279
280 let server_title = self
282 .title
283 .clone()
284 .or_else(|| self.extract_openapi_display_title());
285
286 let instructions = self
288 .instructions
289 .clone()
290 .or_else(|| self.extract_openapi_description())
291 .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
292
293 InitializeResult {
294 protocol_version: ProtocolVersion::V_2024_11_05,
295 server_info: Implementation {
296 name: server_name,
297 version: server_version,
298 title: server_title,
299 icons: None,
300 website_url: None,
301 },
302 capabilities: ServerCapabilities {
303 tools: Some(ToolsCapability {
304 list_changed: Some(false),
305 }),
306 ..Default::default()
307 },
308 instructions,
309 }
310 }
311
312 async fn list_tools(
313 &self,
314 _request: Option<PaginatedRequestParams>,
315 context: RequestContext<RoleServer>,
316 ) -> Result<ListToolsResult, ErrorData> {
317 let span = info_span!("list_tools", tool_count = self.tool_collection.len());
318 let _enter = span.enter();
319
320 debug!("Processing MCP list_tools request");
321
322 let mut tools = self.tool_collection.to_mcp_tools();
324
325 if let Some(filter) = &self.tool_filter {
327 let mut filtered = Vec::with_capacity(tools.len());
328 for mcp_tool in tools {
329 if let Some(tool) = self.tool_collection.get_tool(&mcp_tool.name)
330 && filter.allow(tool, &context).await
331 {
332 filtered.push(mcp_tool);
333 }
334 }
335 tools = filtered;
336 }
337
338 info!(
339 returned_tools = tools.len(),
340 "MCP list_tools request completed successfully"
341 );
342
343 Ok(ListToolsResult {
344 meta: None,
345 tools,
346 next_cursor: None,
347 })
348 }
349
350 async fn call_tool(
351 &self,
352 request: CallToolRequestParams,
353 context: RequestContext<RoleServer>,
354 ) -> Result<CallToolResult, ErrorData> {
355 use crate::error::{ToolCallError, ToolCallValidationError};
356
357 let span = info_span!(
358 "call_tool",
359 tool_name = %request.name
360 );
361 let _enter = span.enter();
362
363 debug!(
364 tool_name = %request.name,
365 has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
366 "Processing MCP call_tool request"
367 );
368
369 let allowed_tools: Vec<&Tool> = match &self.tool_filter {
371 None => self.tool_collection.iter().collect(),
372 Some(filter) => {
373 let mut allowed = Vec::new();
374 for tool in self.tool_collection.iter() {
375 if filter.allow(tool, &context).await {
376 allowed.push(tool);
377 }
378 }
379 allowed
380 }
381 };
382
383 let tool = allowed_tools
385 .iter()
386 .find(|t| t.metadata.name == request.name);
387
388 let tool = match tool {
389 Some(t) => *t,
390 None => {
391 let available_names: Vec<&str> = allowed_tools
392 .iter()
393 .map(|t| t.metadata.name.as_str())
394 .collect();
395
396 let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
398 request.name.to_string(),
399 &available_names,
400 ));
401
402 warn!(
403 tool_name = %request.name,
404 success = false,
405 error = %error,
406 "MCP call_tool request failed - tool not found or filtered"
407 );
408
409 return Err(error.into());
410 }
411 };
412
413 let arguments = request.arguments.unwrap_or_default();
414 let arguments_value = Value::Object(arguments);
415
416 let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
418
419 if auth_header.is_some() {
420 debug!("Authorization header is present");
421 }
422
423 let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
425
426 let server_transformer = self
428 .response_transformer
429 .as_ref()
430 .map(|t| t.as_ref() as &dyn ResponseTransformer);
431
432 match tool
434 .call(&arguments_value, authorization, server_transformer)
435 .await
436 {
437 Ok(result) => {
438 info!(
439 tool_name = %request.name,
440 success = true,
441 "MCP call_tool request completed successfully"
442 );
443 Ok(result)
444 }
445 Err(e) => {
446 warn!(
447 tool_name = %request.name,
448 success = false,
449 error = %e,
450 "MCP call_tool request failed"
451 );
452 Err(e.into())
454 }
455 }
456 }
457}
458
459#[cfg(test)]
460mod tests {
461 use super::*;
462 use crate::error::ToolCallValidationError;
463 use crate::{HttpClient, ToolCallError, ToolMetadata};
464 use serde_json::json;
465
466 #[test]
467 fn test_tool_not_found_error_with_suggestions() {
468 let tool1_metadata = ToolMetadata {
470 name: "getPetById".to_string(),
471 title: Some("Get Pet by ID".to_string()),
472 description: Some("Find pet by ID".to_string()),
473 parameters: json!({
474 "type": "object",
475 "properties": {
476 "petId": {
477 "type": "integer"
478 }
479 },
480 "required": ["petId"]
481 }),
482 output_schema: None,
483 method: "GET".to_string(),
484 path: "/pet/{petId}".to_string(),
485 security: None,
486 parameter_mappings: std::collections::HashMap::new(),
487 };
488
489 let tool2_metadata = ToolMetadata {
490 name: "getPetsByStatus".to_string(),
491 title: Some("Find Pets by Status".to_string()),
492 description: Some("Find pets by status".to_string()),
493 parameters: json!({
494 "type": "object",
495 "properties": {
496 "status": {
497 "type": "array",
498 "items": {
499 "type": "string"
500 }
501 }
502 },
503 "required": ["status"]
504 }),
505 output_schema: None,
506 method: "GET".to_string(),
507 path: "/pet/findByStatus".to_string(),
508 security: None,
509 parameter_mappings: std::collections::HashMap::new(),
510 };
511
512 let http_client = HttpClient::new();
514 let tool1 = Tool::new(tool1_metadata, http_client.clone()).unwrap();
515 let tool2 = Tool::new(tool2_metadata, http_client.clone()).unwrap();
516
517 let mut server = Server::new(
519 serde_json::Value::Null,
520 url::Url::parse("http://example.com").unwrap(),
521 None,
522 None,
523 false,
524 false,
525 );
526 server.tool_collection = ToolCollection::from_tools(vec![tool1, tool2]);
527
528 let tool_names = server.get_tool_names();
530 let tool_name_refs: Vec<&str> = tool_names.iter().map(|s| s.as_str()).collect();
531
532 let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
533 "getPetByID".to_string(),
534 &tool_name_refs,
535 ));
536 let error_data: ErrorData = error.into();
537 let error_json = serde_json::to_value(&error_data).unwrap();
538
539 insta::assert_json_snapshot!(error_json);
541 }
542
543 #[test]
544 fn test_tool_not_found_error_no_suggestions() {
545 let tool_metadata = ToolMetadata {
547 name: "getPetById".to_string(),
548 title: Some("Get Pet by ID".to_string()),
549 description: Some("Find pet by ID".to_string()),
550 parameters: json!({
551 "type": "object",
552 "properties": {
553 "petId": {
554 "type": "integer"
555 }
556 },
557 "required": ["petId"]
558 }),
559 output_schema: None,
560 method: "GET".to_string(),
561 path: "/pet/{petId}".to_string(),
562 security: None,
563 parameter_mappings: std::collections::HashMap::new(),
564 };
565
566 let tool = Tool::new(tool_metadata, HttpClient::new()).unwrap();
568
569 let mut server = Server::new(
571 serde_json::Value::Null,
572 url::Url::parse("http://example.com").unwrap(),
573 None,
574 None,
575 false,
576 false,
577 );
578 server.tool_collection = ToolCollection::from_tools(vec![tool]);
579
580 let tool_names = server.get_tool_names();
582 let tool_name_refs: Vec<&str> = tool_names.iter().map(|s| s.as_str()).collect();
583
584 let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
585 "completelyUnrelatedToolName".to_string(),
586 &tool_name_refs,
587 ));
588 let error_data: ErrorData = error.into();
589 let error_json = serde_json::to_value(&error_data).unwrap();
590
591 insta::assert_json_snapshot!(error_json);
593 }
594
595 #[test]
596 fn test_validation_error_converted_to_error_data() {
597 let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
599 violations: vec![crate::error::ValidationError::invalid_parameter(
600 "page".to_string(),
601 &["page_number".to_string(), "page_size".to_string()],
602 )],
603 });
604
605 let error_data: ErrorData = error.into();
606 let error_json = serde_json::to_value(&error_data).unwrap();
607
608 assert_eq!(error_json["code"], -32602); insta::assert_json_snapshot!(error_json);
613 }
614
615 #[test]
616 fn test_extract_openapi_info_with_full_spec() {
617 let openapi_spec = json!({
618 "openapi": "3.0.0",
619 "info": {
620 "title": "Pet Store API",
621 "version": "2.1.0",
622 "description": "A sample API for managing pets"
623 },
624 "paths": {}
625 });
626
627 let server = Server::new(
628 openapi_spec,
629 url::Url::parse("http://example.com").unwrap(),
630 None,
631 None,
632 false,
633 false,
634 );
635
636 assert_eq!(
637 server.extract_openapi_title(),
638 Some("Pet Store API".to_string())
639 );
640 assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
641 assert_eq!(
642 server.extract_openapi_description(),
643 Some("A sample API for managing pets".to_string())
644 );
645 }
646
647 #[test]
648 fn test_extract_openapi_info_with_minimal_spec() {
649 let openapi_spec = json!({
650 "openapi": "3.0.0",
651 "info": {
652 "title": "My API",
653 "version": "1.0.0"
654 },
655 "paths": {}
656 });
657
658 let server = Server::new(
659 openapi_spec,
660 url::Url::parse("http://example.com").unwrap(),
661 None,
662 None,
663 false,
664 false,
665 );
666
667 assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
668 assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
669 assert_eq!(server.extract_openapi_description(), None);
670 }
671
672 #[test]
673 fn test_extract_openapi_info_with_invalid_spec() {
674 let openapi_spec = json!({
675 "invalid": "spec"
676 });
677
678 let server = Server::new(
679 openapi_spec,
680 url::Url::parse("http://example.com").unwrap(),
681 None,
682 None,
683 false,
684 false,
685 );
686
687 assert_eq!(server.extract_openapi_title(), None);
688 assert_eq!(server.extract_openapi_version(), None);
689 assert_eq!(server.extract_openapi_description(), None);
690 }
691
692 #[test]
693 fn test_get_info_fallback_hierarchy_custom_metadata() {
694 let server = Server::new(
695 serde_json::Value::Null,
696 url::Url::parse("http://example.com").unwrap(),
697 None,
698 None,
699 false,
700 false,
701 );
702
703 let mut server = server;
705 server.name = Some("Custom Server".to_string());
706 server.version = Some("3.0.0".to_string());
707 server.instructions = Some("Custom instructions".to_string());
708
709 let result = server.get_info();
710
711 assert_eq!(result.server_info.name, "Custom Server");
712 assert_eq!(result.server_info.version, "3.0.0");
713 assert_eq!(result.instructions, Some("Custom instructions".to_string()));
714 }
715
716 #[test]
717 fn test_get_info_fallback_hierarchy_openapi_spec() {
718 let openapi_spec = json!({
719 "openapi": "3.0.0",
720 "info": {
721 "title": "OpenAPI Server",
722 "version": "1.5.0",
723 "description": "Server from OpenAPI spec"
724 },
725 "paths": {}
726 });
727
728 let server = Server::new(
729 openapi_spec,
730 url::Url::parse("http://example.com").unwrap(),
731 None,
732 None,
733 false,
734 false,
735 );
736
737 let result = server.get_info();
738
739 assert_eq!(result.server_info.name, "OpenAPI Server");
740 assert_eq!(result.server_info.version, "1.5.0");
741 assert_eq!(
742 result.instructions,
743 Some("Server from OpenAPI spec".to_string())
744 );
745 }
746
747 #[test]
748 fn test_get_info_fallback_hierarchy_defaults() {
749 let server = Server::new(
750 serde_json::Value::Null,
751 url::Url::parse("http://example.com").unwrap(),
752 None,
753 None,
754 false,
755 false,
756 );
757
758 let result = server.get_info();
759
760 assert_eq!(result.server_info.name, "OpenAPI MCP Server");
761 assert_eq!(result.server_info.version, env!("CARGO_PKG_VERSION"));
762 assert_eq!(
763 result.instructions,
764 Some("Exposes OpenAPI endpoints as MCP tools".to_string())
765 );
766 }
767
768 #[test]
769 fn test_get_info_fallback_hierarchy_mixed() {
770 let openapi_spec = json!({
771 "openapi": "3.0.0",
772 "info": {
773 "title": "OpenAPI Server",
774 "version": "2.5.0",
775 "description": "Server from OpenAPI spec"
776 },
777 "paths": {}
778 });
779
780 let mut server = Server::new(
781 openapi_spec,
782 url::Url::parse("http://example.com").unwrap(),
783 None,
784 None,
785 false,
786 false,
787 );
788
789 server.name = Some("Custom Server".to_string());
791 server.instructions = Some("Custom instructions".to_string());
792
793 let result = server.get_info();
794
795 assert_eq!(result.server_info.name, "Custom Server");
797 assert_eq!(result.server_info.version, "2.5.0");
799 assert_eq!(result.instructions, Some("Custom instructions".to_string()));
801 }
802}