use std::sync::Arc;

use bon::Builder;
use reqwest::header::HeaderMap;
use rmcp::{
    handler::server::ServerHandler,
    model::{
        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
        ToolsCapability,
    },
    service::{RequestContext, RoleServer},
};
use rmcp_actix_web::transport::AuthorizationHeader;
use serde_json::Value;
use tracing::{debug, info, info_span, warn, Instrument};
use url::Url;

use crate::config::{Authorization, AuthorizationMode};
use crate::error::Error;
use crate::spec::Filters;
use crate::tool::{Tool, ToolCollection, ToolMetadata};
use crate::transformer::ResponseTransformer;
26
/// MCP server that exposes the operations of an OpenAPI document as MCP tools.
///
/// Construct it with [`Server::new`] or the `bon`-generated builder, then call
/// [`Server::load_openapi_spec`] to populate [`Server::tool_collection`].
#[derive(Clone, Builder)]
pub struct Server {
    /// Raw OpenAPI document as JSON. It is parsed by `load_openapi_spec`, and its
    /// `info` object supplies fallback metadata for `get_info`.
    pub openapi_spec: serde_json::Value,
    /// Tools generated from the spec; empty until `load_openapi_spec` runs.
    #[builder(default)]
    pub tool_collection: ToolCollection,
    /// Base URL handed to tool generation; presumably the upstream API root.
    pub base_url: Url,
    /// Headers handed to every generated tool, if any.
    pub default_headers: Option<HeaderMap>,
    /// Optional filters applied when generating tools from the spec.
    pub filters: Option<Filters>,
    /// How the incoming `Authorization` header is turned into an
    /// [`Authorization`] for each tool call.
    #[builder(default)]
    pub authorization_mode: AuthorizationMode,
    /// Overrides the server name reported by `get_info` (else `info.title`).
    pub name: Option<String>,
    /// Overrides the server version reported by `get_info` (else `info.version`).
    pub version: Option<String>,
    /// Overrides the display title reported by `get_info`.
    pub title: Option<String>,
    /// Overrides the instructions reported by `get_info` (else `info.description`).
    pub instructions: Option<String>,
    /// Forwarded to tool generation to omit tool descriptions.
    #[builder(default)]
    pub skip_tool_descriptions: bool,
    /// Forwarded to tool generation to omit parameter descriptions.
    #[builder(default)]
    pub skip_parameter_descriptions: bool,
    /// Server-wide transformer. It is applied to output schemas at load time and
    /// passed to every tool call.
    pub response_transformer: Option<Arc<dyn ResponseTransformer>>,
}
53
54impl Server {
55 pub fn new(
57 openapi_spec: serde_json::Value,
58 base_url: Url,
59 default_headers: Option<HeaderMap>,
60 filters: Option<Filters>,
61 skip_tool_descriptions: bool,
62 skip_parameter_descriptions: bool,
63 ) -> Self {
64 Self {
65 openapi_spec,
66 tool_collection: ToolCollection::new(),
67 base_url,
68 default_headers,
69 filters,
70 authorization_mode: AuthorizationMode::default(),
71 name: None,
72 version: None,
73 title: None,
74 instructions: None,
75 skip_tool_descriptions,
76 skip_parameter_descriptions,
77 response_transformer: None,
78 }
79 }
80
81 pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
87 let span = info_span!("tool_registration");
88 let _enter = span.enter();
89
90 let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
92
93 let tools = spec.to_openapi_tools(
95 self.filters.as_ref(),
96 Some(self.base_url.clone()),
97 self.default_headers.clone(),
98 self.skip_tool_descriptions,
99 self.skip_parameter_descriptions,
100 )?;
101
102 let tools = if let Some(ref transformer) = self.response_transformer {
104 tools
105 .into_iter()
106 .map(|mut tool| {
107 if let Some(schema) = tool.metadata.output_schema.take() {
108 tool.metadata.output_schema = Some(transformer.transform_schema(schema));
109 }
110 tool
111 })
112 .collect()
113 } else {
114 tools
115 };
116
117 self.tool_collection = ToolCollection::from_tools(tools);
118
119 info!(
120 tool_count = self.tool_collection.len(),
121 "Loaded tools from OpenAPI spec"
122 );
123
124 Ok(())
125 }
126
127 pub fn set_tool_transformer(
137 &mut self,
138 tool_name: &str,
139 transformer: Arc<dyn ResponseTransformer>,
140 ) -> Result<(), Error> {
141 self.tool_collection
142 .set_tool_transformer(tool_name, transformer)
143 }
144
145 #[must_use]
147 pub fn tool_count(&self) -> usize {
148 self.tool_collection.len()
149 }
150
151 #[must_use]
153 pub fn get_tool_names(&self) -> Vec<String> {
154 self.tool_collection.get_tool_names()
155 }
156
157 #[must_use]
159 pub fn has_tool(&self, name: &str) -> bool {
160 self.tool_collection.has_tool(name)
161 }
162
163 #[must_use]
165 pub fn get_tool(&self, name: &str) -> Option<&Tool> {
166 self.tool_collection.get_tool(name)
167 }
168
169 #[must_use]
171 pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
172 self.get_tool(name).map(|tool| &tool.metadata)
173 }
174
175 pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
177 self.authorization_mode = mode;
178 }
179
180 pub fn authorization_mode(&self) -> AuthorizationMode {
182 self.authorization_mode
183 }
184
185 #[must_use]
187 pub fn get_tool_stats(&self) -> String {
188 self.tool_collection.get_stats()
189 }
190
191 pub fn validate_registry(&self) -> Result<(), Error> {
197 if self.tool_collection.is_empty() {
198 return Err(Error::McpError("No tools loaded".to_string()));
199 }
200 Ok(())
201 }
202
203 fn extract_openapi_title(&self) -> Option<String> {
205 self.openapi_spec
206 .get("info")?
207 .get("title")?
208 .as_str()
209 .map(|s| s.to_string())
210 }
211
212 fn extract_openapi_version(&self) -> Option<String> {
214 self.openapi_spec
215 .get("info")?
216 .get("version")?
217 .as_str()
218 .map(|s| s.to_string())
219 }
220
221 fn extract_openapi_description(&self) -> Option<String> {
223 self.openapi_spec
224 .get("info")?
225 .get("description")?
226 .as_str()
227 .map(|s| s.to_string())
228 }
229
230 fn extract_openapi_display_title(&self) -> Option<String> {
233 if let Some(display_title) = self
235 .openapi_spec
236 .get("info")
237 .and_then(|info| info.get("x-display-title"))
238 .and_then(|t| t.as_str())
239 {
240 return Some(display_title.to_string());
241 }
242
243 self.extract_openapi_title().map(|title| {
245 if title.to_lowercase().contains("server") {
246 title
247 } else {
248 format!("{} Server", title)
249 }
250 })
251 }
252}
253
254impl ServerHandler for Server {
255 fn get_info(&self) -> InitializeResult {
256 let server_name = self
258 .name
259 .clone()
260 .or_else(|| self.extract_openapi_title())
261 .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
262
263 let server_version = self
265 .version
266 .clone()
267 .or_else(|| self.extract_openapi_version())
268 .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
269
270 let server_title = self
272 .title
273 .clone()
274 .or_else(|| self.extract_openapi_display_title());
275
276 let instructions = self
278 .instructions
279 .clone()
280 .or_else(|| self.extract_openapi_description())
281 .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
282
283 InitializeResult {
284 protocol_version: ProtocolVersion::V_2024_11_05,
285 server_info: Implementation {
286 name: server_name,
287 version: server_version,
288 title: server_title,
289 icons: None,
290 website_url: None,
291 },
292 capabilities: ServerCapabilities {
293 tools: Some(ToolsCapability {
294 list_changed: Some(false),
295 }),
296 ..Default::default()
297 },
298 instructions,
299 }
300 }
301
302 async fn list_tools(
303 &self,
304 _request: Option<PaginatedRequestParam>,
305 _context: RequestContext<RoleServer>,
306 ) -> Result<ListToolsResult, ErrorData> {
307 let span = info_span!("list_tools", tool_count = self.tool_collection.len());
308 let _enter = span.enter();
309
310 debug!("Processing MCP list_tools request");
311
312 let tools = self.tool_collection.to_mcp_tools();
314
315 info!(
316 returned_tools = tools.len(),
317 "MCP list_tools request completed successfully"
318 );
319
320 Ok(ListToolsResult {
321 meta: None,
322 tools,
323 next_cursor: None,
324 })
325 }
326
327 async fn call_tool(
328 &self,
329 request: CallToolRequestParam,
330 context: RequestContext<RoleServer>,
331 ) -> Result<CallToolResult, ErrorData> {
332 let span = info_span!(
333 "call_tool",
334 tool_name = %request.name
335 );
336 let _enter = span.enter();
337
338 debug!(
339 tool_name = %request.name,
340 has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
341 "Processing MCP call_tool request"
342 );
343
344 let arguments = request.arguments.unwrap_or_default();
345 let arguments_value = Value::Object(arguments);
346
347 let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
349
350 if auth_header.is_some() {
351 debug!("Authorization header is present");
352 }
353
354 let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
356
357 let server_transformer = self
359 .response_transformer
360 .as_ref()
361 .map(|t| t.as_ref() as &dyn ResponseTransformer);
362
363 match self
365 .tool_collection
366 .call_tool(
367 &request.name,
368 &arguments_value,
369 authorization,
370 server_transformer,
371 )
372 .await
373 {
374 Ok(result) => {
375 info!(
376 tool_name = %request.name,
377 success = true,
378 "MCP call_tool request completed successfully"
379 );
380 Ok(result)
381 }
382 Err(e) => {
383 warn!(
384 tool_name = %request.name,
385 success = false,
386 error = %e,
387 "MCP call_tool request failed"
388 );
389 Err(e.into())
391 }
392 }
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{HttpClient, ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Server over `openapi_spec` with no headers or filters and all descriptions kept.
    fn server_for_spec(openapi_spec: serde_json::Value) -> Server {
        Server::new(
            openapi_spec,
            url::Url::parse("http://example.com").unwrap(),
            None,
            None,
            false,
            false,
        )
    }

    /// Spec-less server whose tool collection is exactly `tools`.
    fn server_with_tools(tools: Vec<Tool>) -> Server {
        let mut server = server_for_spec(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(tools);
        server
    }

    /// Metadata for the `GET /pet/{petId}` operation.
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: Some("Find pet by ID".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// Metadata for the `GET /pet/findByStatus` operation.
    fn get_pets_by_status_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: Some("Find pets by status".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// Converts a tool-call error into the serialized MCP `ErrorData` JSON.
    fn error_json_of(error: ToolCallError) -> serde_json::Value {
        let error_data: ErrorData = error.into();
        serde_json::to_value(&error_data).unwrap()
    }

    /// JSON for the "tool not found" error raised when `requested` is looked up
    /// against the tools registered on `server`.
    fn tool_not_found_json(server: &Server, requested: &str) -> serde_json::Value {
        let known_names = server.get_tool_names();
        let known_refs: Vec<&str> = known_names.iter().map(String::as_str).collect();
        error_json_of(ToolCallError::Validation(
            ToolCallValidationError::tool_not_found(requested.to_string(), &known_refs),
        ))
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let http_client = HttpClient::new();
        let pet_by_id = Tool::new(get_pet_by_id_metadata(), http_client.clone()).unwrap();
        let pets_by_status =
            Tool::new(get_pets_by_status_metadata(), http_client.clone()).unwrap();

        let server = server_with_tools(vec![pet_by_id, pets_by_status]);

        let error_json = tool_not_found_json(&server, "getPetByID");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let only_tool = Tool::new(get_pet_by_id_metadata(), HttpClient::new()).unwrap();

        let server = server_with_tools(vec![only_tool]);

        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error_json = error_json_of(ToolCallError::Validation(
            ToolCallValidationError::InvalidParameters {
                violations: vec![violation],
            },
        ));

        assert_eq!(error_json["code"], -32602);
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "2.1.0",
                "description": "A sample API for managing pets"
            },
            "paths": {}
        }));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "My API",
                "version": "1.0.0"
            },
            "paths": {}
        }));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = server_for_spec(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = server_for_spec(serde_json::Value::Null);
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let info = server.get_info();

        assert_eq!(info.server_info.name, "Custom Server");
        assert_eq!(info.server_info.version, "3.0.0");
        assert_eq!(info.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "1.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        let info = server.get_info();

        assert_eq!(info.server_info.name, "OpenAPI Server");
        assert_eq!(info.server_info.version, "1.5.0");
        assert_eq!(
            info.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let info = server_for_spec(serde_json::Value::Null).get_info();

        assert_eq!(info.server_info.name, "OpenAPI MCP Server");
        assert_eq!(info.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            info.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "2.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));
        // Override name and instructions only; the version must still come from the spec.
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let info = server.get_info();

        assert_eq!(info.server_info.name, "Custom Server");
        assert_eq!(info.server_info.version, "2.5.0");
        assert_eq!(info.instructions, Some("Custom instructions".to_string()));
    }
}