use bon::Builder;
use rmcp::{
    handler::server::ServerHandler,
    model::{
        CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
        ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
        ToolsCapability,
    },
    service::{RequestContext, RoleServer},
};
use rmcp_actix_web::transport::AuthorizationHeader;
use serde_json::Value;

use reqwest::header::HeaderMap;
use url::Url;

use crate::error::Error;
use crate::tool::{Tool, ToolCollection, ToolMetadata};
use crate::{
    config::{Authorization, AuthorizationMode},
    spec::Filters,
};
use tracing::{debug, info, info_span, warn, Instrument};
24
25#[derive(Clone, Builder)]
26pub struct Server {
27 pub openapi_spec: serde_json::Value,
28 #[builder(default)]
29 pub tool_collection: ToolCollection,
30 pub base_url: Url,
31 pub default_headers: Option<HeaderMap>,
32 pub filters: Option<Filters>,
33 #[builder(default)]
34 pub authorization_mode: AuthorizationMode,
35 pub name: Option<String>,
36 pub version: Option<String>,
37 pub title: Option<String>,
38 pub instructions: Option<String>,
39 #[builder(default)]
40 pub skip_tool_descriptions: bool,
41 #[builder(default)]
42 pub skip_parameter_descriptions: bool,
43}
44
45impl Server {
46 pub fn new(
48 openapi_spec: serde_json::Value,
49 base_url: Url,
50 default_headers: Option<HeaderMap>,
51 filters: Option<Filters>,
52 skip_tool_descriptions: bool,
53 skip_parameter_descriptions: bool,
54 ) -> Self {
55 Self {
56 openapi_spec,
57 tool_collection: ToolCollection::new(),
58 base_url,
59 default_headers,
60 filters,
61 authorization_mode: AuthorizationMode::default(),
62 name: None,
63 version: None,
64 title: None,
65 instructions: None,
66 skip_tool_descriptions,
67 skip_parameter_descriptions,
68 }
69 }
70
71 pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
77 let span = info_span!("tool_registration");
78 let _enter = span.enter();
79
80 let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
82
83 let tools = spec.to_openapi_tools(
85 self.filters.as_ref(),
86 Some(self.base_url.clone()),
87 self.default_headers.clone(),
88 self.skip_tool_descriptions,
89 self.skip_parameter_descriptions,
90 )?;
91
92 self.tool_collection = ToolCollection::from_tools(tools);
93
94 info!(
95 tool_count = self.tool_collection.len(),
96 "Loaded tools from OpenAPI spec"
97 );
98
99 Ok(())
100 }
101
102 #[must_use]
104 pub fn tool_count(&self) -> usize {
105 self.tool_collection.len()
106 }
107
108 #[must_use]
110 pub fn get_tool_names(&self) -> Vec<String> {
111 self.tool_collection.get_tool_names()
112 }
113
114 #[must_use]
116 pub fn has_tool(&self, name: &str) -> bool {
117 self.tool_collection.has_tool(name)
118 }
119
120 #[must_use]
122 pub fn get_tool(&self, name: &str) -> Option<&Tool> {
123 self.tool_collection.get_tool(name)
124 }
125
126 #[must_use]
128 pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
129 self.get_tool(name).map(|tool| &tool.metadata)
130 }
131
132 pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
134 self.authorization_mode = mode;
135 }
136
137 pub fn authorization_mode(&self) -> AuthorizationMode {
139 self.authorization_mode
140 }
141
142 #[must_use]
144 pub fn get_tool_stats(&self) -> String {
145 self.tool_collection.get_stats()
146 }
147
148 pub fn validate_registry(&self) -> Result<(), Error> {
154 if self.tool_collection.is_empty() {
155 return Err(Error::McpError("No tools loaded".to_string()));
156 }
157 Ok(())
158 }
159
160 fn extract_openapi_title(&self) -> Option<String> {
162 self.openapi_spec
163 .get("info")?
164 .get("title")?
165 .as_str()
166 .map(|s| s.to_string())
167 }
168
169 fn extract_openapi_version(&self) -> Option<String> {
171 self.openapi_spec
172 .get("info")?
173 .get("version")?
174 .as_str()
175 .map(|s| s.to_string())
176 }
177
178 fn extract_openapi_description(&self) -> Option<String> {
180 self.openapi_spec
181 .get("info")?
182 .get("description")?
183 .as_str()
184 .map(|s| s.to_string())
185 }
186
187 fn extract_openapi_display_title(&self) -> Option<String> {
190 if let Some(display_title) = self
192 .openapi_spec
193 .get("info")
194 .and_then(|info| info.get("x-display-title"))
195 .and_then(|t| t.as_str())
196 {
197 return Some(display_title.to_string());
198 }
199
200 self.extract_openapi_title().map(|title| {
202 if title.to_lowercase().contains("server") {
203 title
204 } else {
205 format!("{} Server", title)
206 }
207 })
208 }
209}
210
211impl ServerHandler for Server {
212 fn get_info(&self) -> InitializeResult {
213 let server_name = self
215 .name
216 .clone()
217 .or_else(|| self.extract_openapi_title())
218 .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
219
220 let server_version = self
222 .version
223 .clone()
224 .or_else(|| self.extract_openapi_version())
225 .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
226
227 let server_title = self
229 .title
230 .clone()
231 .or_else(|| self.extract_openapi_display_title());
232
233 let instructions = self
235 .instructions
236 .clone()
237 .or_else(|| self.extract_openapi_description())
238 .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
239
240 InitializeResult {
241 protocol_version: ProtocolVersion::V_2024_11_05,
242 server_info: Implementation {
243 name: server_name,
244 version: server_version,
245 title: server_title,
246 icons: None,
247 website_url: None,
248 },
249 capabilities: ServerCapabilities {
250 tools: Some(ToolsCapability {
251 list_changed: Some(false),
252 }),
253 ..Default::default()
254 },
255 instructions,
256 }
257 }
258
259 async fn list_tools(
260 &self,
261 _request: Option<PaginatedRequestParam>,
262 _context: RequestContext<RoleServer>,
263 ) -> Result<ListToolsResult, ErrorData> {
264 let span = info_span!("list_tools", tool_count = self.tool_collection.len());
265 let _enter = span.enter();
266
267 debug!("Processing MCP list_tools request");
268
269 let tools = self.tool_collection.to_mcp_tools();
271
272 info!(
273 returned_tools = tools.len(),
274 "MCP list_tools request completed successfully"
275 );
276
277 Ok(ListToolsResult {
278 meta: None,
279 tools,
280 next_cursor: None,
281 })
282 }
283
284 async fn call_tool(
285 &self,
286 request: CallToolRequestParam,
287 context: RequestContext<RoleServer>,
288 ) -> Result<CallToolResult, ErrorData> {
289 let span = info_span!(
290 "call_tool",
291 tool_name = %request.name
292 );
293 let _enter = span.enter();
294
295 debug!(
296 tool_name = %request.name,
297 has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
298 "Processing MCP call_tool request"
299 );
300
301 let arguments = request.arguments.unwrap_or_default();
302 let arguments_value = Value::Object(arguments);
303
304 let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
306
307 if auth_header.is_some() {
308 debug!("Authorization header is present");
309 }
310
311 let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
313
314 match self
316 .tool_collection
317 .call_tool(&request.name, &arguments_value, authorization)
318 .await
319 {
320 Ok(result) => {
321 info!(
322 tool_name = %request.name,
323 success = true,
324 "MCP call_tool request completed successfully"
325 );
326 Ok(result)
327 }
328 Err(e) => {
329 warn!(
330 tool_name = %request.name,
331 success = false,
332 error = %e,
333 "MCP call_tool request failed"
334 );
335 Err(e.into())
337 }
338 }
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{HttpClient, ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Creates a `Server` for `openapi_spec` targeting `http://example.com`, with no
    /// default headers or filters and all descriptions kept.
    fn server_with_spec(openapi_spec: serde_json::Value) -> Server {
        Server::new(
            openapi_spec,
            url::Url::parse("http://example.com").unwrap(),
            None,
            None,
            false,
            false,
        )
    }

    /// Creates a server with a `Null` spec whose collection holds one tool per entry
    /// of `metadata`, in order.
    fn server_with_tools(metadata: Vec<ToolMetadata>) -> Server {
        let http_client = HttpClient::new();
        let tools: Vec<Tool> = metadata
            .into_iter()
            .map(|tool_metadata| Tool::new(tool_metadata, http_client.clone()).unwrap())
            .collect();

        let mut server = server_with_spec(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(tools);
        server
    }

    /// Metadata for the `getPetById` tool (`GET /pet/{petId}`).
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: Some("Find pet by ID".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    /// Metadata for the `getPetsByStatus` tool (`GET /pet/findByStatus`).
    fn get_pets_by_status_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: Some("Find pets by status".to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        }
    }

    // Snapshot assertions stay directly inside each `#[test]` function because insta
    // derives the snapshot name from the enclosing function.

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let server = server_with_tools(vec![
            get_pet_by_id_metadata(),
            get_pets_by_status_metadata(),
        ]);

        let tool_names = server.get_tool_names();
        let tool_name_refs: Vec<&str> = tool_names.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            "getPetByID".to_string(),
            &tool_name_refs,
        ));
        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let server = server_with_tools(vec![get_pet_by_id_metadata()]);

        let tool_names = server.get_tool_names();
        let tool_name_refs: Vec<&str> = tool_names.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            "completelyUnrelatedToolName".to_string(),
            &tool_name_refs,
        ));
        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![crate::error::ValidationError::invalid_parameter(
                "page".to_string(),
                &["page_number".to_string(), "page_size".to_string()],
            )],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        assert_eq!(error_json["code"], -32602);
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_registry_accessors_with_loaded_tools() {
        let server = server_with_tools(vec![get_pet_by_id_metadata()]);

        assert!(server.validate_registry().is_ok());
        assert_eq!(server.tool_count(), 1);
        assert!(server.has_tool("getPetById"));
        assert_eq!(
            server
                .get_tool_metadata("getPetById")
                .map(|metadata| metadata.path.as_str()),
            Some("/pet/{petId}")
        );
    }

    #[test]
    fn test_validate_registry_rejects_empty_collection() {
        let server = server_with_spec(serde_json::Value::Null);

        assert_eq!(server.tool_count(), 0);
        assert!(server.validate_registry().is_err());
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "2.1.0",
                "description": "A sample API for managing pets"
            },
            "paths": {}
        }));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "My API",
                "version": "1.0.0"
            },
            "paths": {}
        }));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = server_with_spec(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = server_with_spec(serde_json::Value::Null);
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "3.0.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "1.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI Server");
        assert_eq!(result.server_info.version, "1.5.0");
        assert_eq!(
            result.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let server = server_with_spec(serde_json::Value::Null);

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI MCP Server");
        assert_eq!(result.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            result.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "2.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));
        // Override name and instructions only; the version must come from the spec.
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "2.5.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_title_appends_server_suffix() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": { "title": "Pet Store API", "version": "1.0.0" },
            "paths": {}
        }));

        assert_eq!(
            server.get_info().server_info.title,
            Some("Pet Store API Server".to_string())
        );
    }

    #[test]
    fn test_get_info_title_keeps_existing_server_word() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": { "title": "OpenAPI Server", "version": "1.0.0" },
            "paths": {}
        }));

        assert_eq!(
            server.get_info().server_info.title,
            Some("OpenAPI Server".to_string())
        );
    }

    #[test]
    fn test_get_info_title_prefers_display_title_extension() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "1.0.0",
                "x-display-title": "Pet Store"
            },
            "paths": {}
        }));

        assert_eq!(
            server.get_info().server_info.title,
            Some("Pet Store".to_string())
        );
    }

    #[test]
    fn test_get_info_title_override_and_absence() {
        let mut server = server_with_spec(serde_json::Value::Null);
        assert_eq!(server.get_info().server_info.title, None);

        server.title = Some("Custom Title".to_string());
        assert_eq!(
            server.get_info().server_info.title,
            Some("Custom Title".to_string())
        );
    }
}