1use bon::Builder;
2use rmcp::{
3 RoleServer, ServerHandler,
4 model::{
5 CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
6 ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
7 ToolsCapability,
8 },
9 service::RequestContext,
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13
14use reqwest::header::HeaderMap;
15use url::Url;
16
17use crate::config::{Authorization, AuthorizationMode};
18use crate::error::Error;
19use crate::tool::{Tool, ToolCollection, ToolMetadata};
20use tracing::{debug, info, info_span, warn};
21
/// MCP server that exposes the operations of an OpenAPI specification as tools.
///
/// Instances are created with [`Server::new`] or through the builder generated
/// by the `Builder` derive. The tool collection stays empty until
/// [`Server::load_openapi_spec`] is called.
#[derive(Clone, Builder)]
pub struct Server {
    /// Raw OpenAPI document. It is parsed by `load_openapi_spec` and its
    /// `info` object is read when building the server metadata in `get_info`.
    pub openapi_spec: serde_json::Value,
    /// Tools served by `list_tools` / `call_tool`; replaced wholesale by
    /// `load_openapi_spec`. The builder falls back to the type's default.
    #[builder(default)]
    pub tool_collection: ToolCollection,
    /// Base URL handed to the spec-to-tool conversion in `load_openapi_spec`.
    pub base_url: Url,
    /// Optional headers handed to the spec-to-tool conversion in
    /// `load_openapi_spec`.
    pub default_headers: Option<HeaderMap>,
    /// Optional tag filter forwarded to the spec-to-tool conversion.
    pub tag_filter: Option<Vec<String>>,
    /// Optional HTTP method filter forwarded to the spec-to-tool conversion.
    pub method_filter: Option<Vec<reqwest::Method>>,
    /// Mode combined with the request's authorization header (if any) to build
    /// the `Authorization` value passed to each tool call.
    #[builder(default)]
    pub authorization_mode: AuthorizationMode,
    /// Server name override; `get_info` falls back to the spec's `info.title`,
    /// then to `"OpenAPI MCP Server"`.
    pub name: Option<String>,
    /// Server version override; `get_info` falls back to the spec's
    /// `info.version`, then to this crate's `CARGO_PKG_VERSION`.
    pub version: Option<String>,
    /// Display title override; `get_info` falls back to the spec's
    /// `info.x-display-title`, then to a title derived from `info.title`.
    pub title: Option<String>,
    /// Instructions override; `get_info` falls back to the spec's
    /// `info.description`, then to a fixed default sentence.
    pub instructions: Option<String>,
}
38
39impl Server {
40 pub fn new(
42 openapi_spec: serde_json::Value,
43 base_url: Url,
44 default_headers: Option<HeaderMap>,
45 tag_filter: Option<Vec<String>>,
46 method_filter: Option<Vec<reqwest::Method>>,
47 ) -> Self {
48 Self {
49 openapi_spec,
50 tool_collection: ToolCollection::new(),
51 base_url,
52 default_headers,
53 tag_filter,
54 method_filter,
55 authorization_mode: AuthorizationMode::default(),
56 name: None,
57 version: None,
58 title: None,
59 instructions: None,
60 }
61 }
62
63 pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
69 let span = info_span!("tool_registration");
70 let _enter = span.enter();
71
72 let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
74
75 let tools = spec.to_openapi_tools(
77 self.tag_filter.as_deref(),
78 self.method_filter.as_deref(),
79 Some(self.base_url.clone()),
80 self.default_headers.clone(),
81 )?;
82
83 self.tool_collection = ToolCollection::from_tools(tools);
84
85 info!(
86 tool_count = self.tool_collection.len(),
87 "Loaded tools from OpenAPI spec"
88 );
89
90 Ok(())
91 }
92
93 #[must_use]
95 pub fn tool_count(&self) -> usize {
96 self.tool_collection.len()
97 }
98
99 #[must_use]
101 pub fn get_tool_names(&self) -> Vec<String> {
102 self.tool_collection.get_tool_names()
103 }
104
105 #[must_use]
107 pub fn has_tool(&self, name: &str) -> bool {
108 self.tool_collection.has_tool(name)
109 }
110
111 #[must_use]
113 pub fn get_tool(&self, name: &str) -> Option<&Tool> {
114 self.tool_collection.get_tool(name)
115 }
116
117 #[must_use]
119 pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
120 self.get_tool(name).map(|tool| &tool.metadata)
121 }
122
123 pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
125 self.authorization_mode = mode;
126 }
127
128 pub fn authorization_mode(&self) -> AuthorizationMode {
130 self.authorization_mode
131 }
132
133 #[must_use]
135 pub fn get_tool_stats(&self) -> String {
136 self.tool_collection.get_stats()
137 }
138
139 pub fn validate_registry(&self) -> Result<(), Error> {
145 if self.tool_collection.is_empty() {
146 return Err(Error::McpError("No tools loaded".to_string()));
147 }
148 Ok(())
149 }
150
151 fn extract_openapi_title(&self) -> Option<String> {
153 self.openapi_spec
154 .get("info")?
155 .get("title")?
156 .as_str()
157 .map(|s| s.to_string())
158 }
159
160 fn extract_openapi_version(&self) -> Option<String> {
162 self.openapi_spec
163 .get("info")?
164 .get("version")?
165 .as_str()
166 .map(|s| s.to_string())
167 }
168
169 fn extract_openapi_description(&self) -> Option<String> {
171 self.openapi_spec
172 .get("info")?
173 .get("description")?
174 .as_str()
175 .map(|s| s.to_string())
176 }
177
178 fn extract_openapi_display_title(&self) -> Option<String> {
181 if let Some(display_title) = self
183 .openapi_spec
184 .get("info")
185 .and_then(|info| info.get("x-display-title"))
186 .and_then(|t| t.as_str())
187 {
188 return Some(display_title.to_string());
189 }
190
191 self.extract_openapi_title().map(|title| {
193 if title.to_lowercase().contains("server") {
194 title
195 } else {
196 format!("{} Server", title)
197 }
198 })
199 }
200}
201
202impl ServerHandler for Server {
203 fn get_info(&self) -> InitializeResult {
204 let server_name = self
206 .name
207 .clone()
208 .or_else(|| self.extract_openapi_title())
209 .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
210
211 let server_version = self
213 .version
214 .clone()
215 .or_else(|| self.extract_openapi_version())
216 .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
217
218 let server_title = self
220 .title
221 .clone()
222 .or_else(|| self.extract_openapi_display_title());
223
224 let instructions = self
226 .instructions
227 .clone()
228 .or_else(|| self.extract_openapi_description())
229 .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
230
231 InitializeResult {
232 protocol_version: ProtocolVersion::V_2024_11_05,
233 server_info: Implementation {
234 name: server_name,
235 version: server_version,
236 title: server_title,
237 icons: None,
238 website_url: None,
239 },
240 capabilities: ServerCapabilities {
241 tools: Some(ToolsCapability {
242 list_changed: Some(false),
243 }),
244 ..Default::default()
245 },
246 instructions,
247 }
248 }
249
250 async fn list_tools(
251 &self,
252 _request: Option<PaginatedRequestParam>,
253 _context: RequestContext<RoleServer>,
254 ) -> Result<ListToolsResult, ErrorData> {
255 let span = info_span!("list_tools", tool_count = self.tool_collection.len());
256 let _enter = span.enter();
257
258 debug!("Processing MCP list_tools request");
259
260 let tools = self.tool_collection.to_mcp_tools();
262
263 info!(
264 returned_tools = tools.len(),
265 "MCP list_tools request completed successfully"
266 );
267
268 Ok(ListToolsResult {
269 tools,
270 next_cursor: None,
271 })
272 }
273
274 async fn call_tool(
275 &self,
276 request: CallToolRequestParam,
277 context: RequestContext<RoleServer>,
278 ) -> Result<CallToolResult, ErrorData> {
279 let span = info_span!(
280 "call_tool",
281 tool_name = %request.name
282 );
283 let _enter = span.enter();
284
285 debug!(
286 tool_name = %request.name,
287 has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
288 "Processing MCP call_tool request"
289 );
290
291 let arguments = request.arguments.unwrap_or_default();
292 let arguments_value = Value::Object(arguments);
293
294 let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
296
297 if auth_header.is_some() {
298 debug!("Authorization header is present");
299 }
300
301 let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
303
304 match self
306 .tool_collection
307 .call_tool(&request.name, &arguments_value, authorization)
308 .await
309 {
310 Ok(result) => {
311 info!(
312 tool_name = %request.name,
313 success = true,
314 "MCP call_tool request completed successfully"
315 );
316 Ok(result)
317 }
318 Err(e) => {
319 warn!(
320 tool_name = %request.name,
321 success = false,
322 error = %e,
323 "MCP call_tool request failed"
324 );
325 Err(e.into())
327 }
328 }
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Builds a server around `spec` using the example base URL and no
    /// default headers, tag filter or method filter.
    fn server_with_spec(spec: serde_json::Value) -> Server {
        Server::new(
            spec,
            url::Url::parse("http://example.com").unwrap(),
            None,
            None,
            None,
        )
    }

    /// Builds a server with a `Null` spec whose collection holds `tools`.
    fn server_with_tools(tools: Vec<Tool>) -> Server {
        let mut server = server_with_spec(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(tools);
        server
    }

    /// Metadata for the `getPetById` fixture tool shared by several tests.
    fn get_pet_by_id_metadata() -> ToolMetadata {
        ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: "Find pet by ID".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
            security: None,
        }
    }

    /// Serializes the "tool not found" error produced for `requested` against
    /// the tool names registered on `server`.
    fn tool_not_found_json(server: &Server, requested: &str) -> serde_json::Value {
        let registered = server.get_tool_names();
        let registered_refs: Vec<&str> = registered.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &registered_refs,
        ));
        let error_data: ErrorData = error.into();
        serde_json::to_value(&error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let find_by_status_metadata = ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: "Find pets by status".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
            security: None,
        };

        let server = server_with_tools(vec![
            Tool::new(get_pet_by_id_metadata(), None, None).unwrap(),
            Tool::new(find_by_status_metadata, None, None).unwrap(),
        ]);

        // The requested name differs from a registered tool only in letter case.
        let error_json = tool_not_found_json(&server, "getPetByID");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let server =
            server_with_tools(vec![Tool::new(get_pet_by_id_metadata(), None, None).unwrap()]);

        // The requested name shares nothing obvious with the registered tool.
        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        // -32602 is the code this test pins for invalid-parameter errors.
        assert_eq!(error_json["code"], -32602);
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "2.1.0",
                "description": "A sample API for managing pets"
            },
            "paths": {}
        }));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        // No `description` in `info`: only title and version can be extracted.
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "My API",
                "version": "1.0.0"
            },
            "paths": {}
        }));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        // No `info` object at all: every extractor yields `None`.
        let server = server_with_spec(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = server_with_spec(serde_json::Value::Null);
        server.name = Some("Custom Server".to_string());
        server.version = Some("3.0.0".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "3.0.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        // No overrides set: values come from the spec's `info` object.
        let server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "1.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI Server");
        assert_eq!(result.server_info.version, "1.5.0");
        assert_eq!(
            result.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        // Neither overrides nor spec metadata: built-in defaults apply.
        let server = server_with_spec(serde_json::Value::Null);

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI MCP Server");
        assert_eq!(result.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            result.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = server_with_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "2.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        // Override name and instructions only; version is left to the spec.
        server.name = Some("Custom Server".to_string());
        server.instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "2.5.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }
}