1use bon::Builder;
2use rmcp::{
3 RoleServer, ServerHandler,
4 model::{
5 CallToolRequestParam, CallToolResult, ErrorData, Implementation, InitializeResult,
6 ListToolsResult, PaginatedRequestParam, ProtocolVersion, ServerCapabilities,
7 ToolsCapability,
8 },
9 service::RequestContext,
10};
11use rmcp_actix_web::transport::AuthorizationHeader;
12use serde_json::Value;
13
14use reqwest::header::HeaderMap;
15use url::Url;
16
17use crate::config::{Authorization, AuthorizationMode};
18use crate::error::Error;
19use crate::tool::{Tool, ToolCollection, ToolMetadata};
20use tracing::{debug, info, info_span, warn};
21
22#[derive(Clone, Builder)]
23pub struct Server {
24 pub openapi_spec: serde_json::Value,
25 #[builder(default)]
26 pub tool_collection: ToolCollection,
27 pub base_url: Url,
28 pub default_headers: Option<HeaderMap>,
29 pub tag_filter: Option<Vec<String>>,
30 pub method_filter: Option<Vec<reqwest::Method>>,
31 #[builder(default)]
32 pub authorization_mode: AuthorizationMode,
33 pub server_name: Option<String>,
34 pub server_version: Option<String>,
35 pub server_instructions: Option<String>,
36}
37
38impl Server {
39 pub fn new(
41 openapi_spec: serde_json::Value,
42 base_url: Url,
43 default_headers: Option<HeaderMap>,
44 tag_filter: Option<Vec<String>>,
45 method_filter: Option<Vec<reqwest::Method>>,
46 ) -> Self {
47 Self {
48 openapi_spec,
49 tool_collection: ToolCollection::new(),
50 base_url,
51 default_headers,
52 tag_filter,
53 method_filter,
54 authorization_mode: AuthorizationMode::default(),
55 server_name: None,
56 server_version: None,
57 server_instructions: None,
58 }
59 }
60
61 pub fn load_openapi_spec(&mut self) -> Result<(), Error> {
67 let span = info_span!("tool_registration");
68 let _enter = span.enter();
69
70 let spec = crate::spec::Spec::from_value(self.openapi_spec.clone())?;
72
73 let tools = spec.to_openapi_tools(
75 self.tag_filter.as_deref(),
76 self.method_filter.as_deref(),
77 Some(self.base_url.clone()),
78 self.default_headers.clone(),
79 )?;
80
81 self.tool_collection = ToolCollection::from_tools(tools);
82
83 info!(
84 tool_count = self.tool_collection.len(),
85 "Loaded tools from OpenAPI spec"
86 );
87
88 Ok(())
89 }
90
91 #[must_use]
93 pub fn tool_count(&self) -> usize {
94 self.tool_collection.len()
95 }
96
97 #[must_use]
99 pub fn get_tool_names(&self) -> Vec<String> {
100 self.tool_collection.get_tool_names()
101 }
102
103 #[must_use]
105 pub fn has_tool(&self, name: &str) -> bool {
106 self.tool_collection.has_tool(name)
107 }
108
109 #[must_use]
111 pub fn get_tool(&self, name: &str) -> Option<&Tool> {
112 self.tool_collection.get_tool(name)
113 }
114
115 #[must_use]
117 pub fn get_tool_metadata(&self, name: &str) -> Option<&ToolMetadata> {
118 self.get_tool(name).map(|tool| &tool.metadata)
119 }
120
121 pub fn set_authorization_mode(&mut self, mode: AuthorizationMode) {
123 self.authorization_mode = mode;
124 }
125
126 pub fn authorization_mode(&self) -> AuthorizationMode {
128 self.authorization_mode
129 }
130
131 #[must_use]
133 pub fn get_tool_stats(&self) -> String {
134 self.tool_collection.get_stats()
135 }
136
137 pub fn validate_registry(&self) -> Result<(), Error> {
143 if self.tool_collection.is_empty() {
144 return Err(Error::McpError("No tools loaded".to_string()));
145 }
146 Ok(())
147 }
148
149 fn extract_openapi_title(&self) -> Option<String> {
151 self.openapi_spec
152 .get("info")?
153 .get("title")?
154 .as_str()
155 .map(|s| s.to_string())
156 }
157
158 fn extract_openapi_version(&self) -> Option<String> {
160 self.openapi_spec
161 .get("info")?
162 .get("version")?
163 .as_str()
164 .map(|s| s.to_string())
165 }
166
167 fn extract_openapi_description(&self) -> Option<String> {
169 self.openapi_spec
170 .get("info")?
171 .get("description")?
172 .as_str()
173 .map(|s| s.to_string())
174 }
175}
176
177impl ServerHandler for Server {
178 fn get_info(&self) -> InitializeResult {
179 let server_name = self
181 .server_name
182 .clone()
183 .or_else(|| self.extract_openapi_title())
184 .unwrap_or_else(|| "OpenAPI MCP Server".to_string());
185
186 let server_version = self
188 .server_version
189 .clone()
190 .or_else(|| self.extract_openapi_version())
191 .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string());
192
193 let instructions = self
195 .server_instructions
196 .clone()
197 .or_else(|| self.extract_openapi_description())
198 .or_else(|| Some("Exposes OpenAPI endpoints as MCP tools".to_string()));
199
200 InitializeResult {
201 protocol_version: ProtocolVersion::V_2024_11_05,
202 server_info: Implementation {
203 name: server_name,
204 version: server_version,
205 },
206 capabilities: ServerCapabilities {
207 tools: Some(ToolsCapability {
208 list_changed: Some(false),
209 }),
210 ..Default::default()
211 },
212 instructions,
213 }
214 }
215
216 async fn list_tools(
217 &self,
218 _request: Option<PaginatedRequestParam>,
219 _context: RequestContext<RoleServer>,
220 ) -> Result<ListToolsResult, ErrorData> {
221 let span = info_span!("list_tools", tool_count = self.tool_collection.len());
222 let _enter = span.enter();
223
224 debug!("Processing MCP list_tools request");
225
226 let tools = self.tool_collection.to_mcp_tools();
228
229 info!(
230 returned_tools = tools.len(),
231 "MCP list_tools request completed successfully"
232 );
233
234 Ok(ListToolsResult {
235 tools,
236 next_cursor: None,
237 })
238 }
239
240 async fn call_tool(
241 &self,
242 request: CallToolRequestParam,
243 context: RequestContext<RoleServer>,
244 ) -> Result<CallToolResult, ErrorData> {
245 let span = info_span!(
246 "call_tool",
247 tool_name = %request.name
248 );
249 let _enter = span.enter();
250
251 debug!(
252 tool_name = %request.name,
253 has_arguments = !request.arguments.as_ref().unwrap_or(&serde_json::Map::new()).is_empty(),
254 "Processing MCP call_tool request"
255 );
256
257 let arguments = request.arguments.unwrap_or_default();
258 let arguments_value = Value::Object(arguments);
259
260 let auth_header = context.extensions.get::<AuthorizationHeader>().cloned();
262
263 if auth_header.is_some() {
264 debug!("Authorization header is present");
265 }
266
267 let authorization = Authorization::from_mode(self.authorization_mode, auth_header);
269
270 match self
272 .tool_collection
273 .call_tool(&request.name, &arguments_value, authorization)
274 .await
275 {
276 Ok(result) => {
277 info!(
278 tool_name = %request.name,
279 success = true,
280 "MCP call_tool request completed successfully"
281 );
282 Ok(result)
283 }
284 Err(e) => {
285 warn!(
286 tool_name = %request.name,
287 success = false,
288 error = %e,
289 "MCP call_tool request failed"
290 );
291 Err(e.into())
293 }
294 }
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolCallValidationError;
    use crate::{ToolCallError, ToolMetadata};
    use serde_json::json;

    /// Builds a server around `openapi_spec` with the fixed test base URL and
    /// no headers or filters.
    fn server_for_spec(openapi_spec: serde_json::Value) -> Server {
        Server::new(
            openapi_spec,
            url::Url::parse("http://example.com").unwrap(),
            None,
            None,
            None,
        )
    }

    /// Builds a server without a spec whose tool collection holds `tools`.
    fn server_with_tools(tools: Vec<Tool>) -> Server {
        let mut server = server_for_spec(serde_json::Value::Null);
        server.tool_collection = ToolCollection::from_tools(tools);
        server
    }

    /// The `getPetById` tool fixture shared by the tool-not-found tests.
    fn pet_by_id_tool() -> Tool {
        let metadata = ToolMetadata {
            name: "getPetById".to_string(),
            title: Some("Get Pet by ID".to_string()),
            description: "Find pet by ID".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "petId": {
                        "type": "integer"
                    }
                },
                "required": ["petId"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/{petId}".to_string(),
            security: None,
        };

        Tool::new(metadata, None, None).unwrap()
    }

    /// The `getPetsByStatus` tool fixture.
    fn pets_by_status_tool() -> Tool {
        let metadata = ToolMetadata {
            name: "getPetsByStatus".to_string(),
            title: Some("Find Pets by Status".to_string()),
            description: "Find pets by status".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "status": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["status"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: "/pet/findByStatus".to_string(),
            security: None,
        };

        Tool::new(metadata, None, None).unwrap()
    }

    /// Serializes the tool-not-found error produced for `requested` against
    /// the tool names registered on `server`.
    fn tool_not_found_json(server: &Server, requested: &str) -> serde_json::Value {
        let known_names = server.get_tool_names();
        let known_refs: Vec<&str> = known_names.iter().map(String::as_str).collect();

        let error = ToolCallError::Validation(ToolCallValidationError::tool_not_found(
            requested.to_string(),
            &known_refs,
        ));
        let error_data: ErrorData = error.into();
        serde_json::to_value(&error_data).unwrap()
    }

    #[test]
    fn test_tool_not_found_error_with_suggestions() {
        let server = server_with_tools(vec![pet_by_id_tool(), pets_by_status_tool()]);

        // Requested name differs from a registered tool only by letter case.
        let error_json = tool_not_found_json(&server, "getPetByID");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_tool_not_found_error_no_suggestions() {
        let server = server_with_tools(vec![pet_by_id_tool()]);

        let error_json = tool_not_found_json(&server, "completelyUnrelatedToolName");

        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_validation_error_converted_to_error_data() {
        let violation = crate::error::ValidationError::invalid_parameter(
            "page".to_string(),
            &["page_number".to_string(), "page_size".to_string()],
        );
        let error = ToolCallError::Validation(ToolCallValidationError::InvalidParameters {
            violations: vec![violation],
        });

        let error_data: ErrorData = error.into();
        let error_json = serde_json::to_value(&error_data).unwrap();

        assert_eq!(error_json["code"], -32602);
        insta::assert_json_snapshot!(error_json);
    }

    #[test]
    fn test_extract_openapi_info_with_full_spec() {
        let server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "Pet Store API",
                "version": "2.1.0",
                "description": "A sample API for managing pets"
            },
            "paths": {}
        }));

        assert_eq!(
            server.extract_openapi_title(),
            Some("Pet Store API".to_string())
        );
        assert_eq!(server.extract_openapi_version(), Some("2.1.0".to_string()));
        assert_eq!(
            server.extract_openapi_description(),
            Some("A sample API for managing pets".to_string())
        );
    }

    #[test]
    fn test_extract_openapi_info_with_minimal_spec() {
        let server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "My API",
                "version": "1.0.0"
            },
            "paths": {}
        }));

        assert_eq!(server.extract_openapi_title(), Some("My API".to_string()));
        assert_eq!(server.extract_openapi_version(), Some("1.0.0".to_string()));
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_extract_openapi_info_with_invalid_spec() {
        let server = server_for_spec(json!({
            "invalid": "spec"
        }));

        assert_eq!(server.extract_openapi_title(), None);
        assert_eq!(server.extract_openapi_version(), None);
        assert_eq!(server.extract_openapi_description(), None);
    }

    #[test]
    fn test_get_info_fallback_hierarchy_custom_metadata() {
        let mut server = server_for_spec(serde_json::Value::Null);
        server.server_name = Some("Custom Server".to_string());
        server.server_version = Some("3.0.0".to_string());
        server.server_instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "3.0.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }

    #[test]
    fn test_get_info_fallback_hierarchy_openapi_spec() {
        let server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "1.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI Server");
        assert_eq!(result.server_info.version, "1.5.0");
        assert_eq!(
            result.instructions,
            Some("Server from OpenAPI spec".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_defaults() {
        let server = server_for_spec(serde_json::Value::Null);

        let result = server.get_info();

        assert_eq!(result.server_info.name, "OpenAPI MCP Server");
        assert_eq!(result.server_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            result.instructions,
            Some("Exposes OpenAPI endpoints as MCP tools".to_string())
        );
    }

    #[test]
    fn test_get_info_fallback_hierarchy_mixed() {
        let mut server = server_for_spec(json!({
            "openapi": "3.0.0",
            "info": {
                "title": "OpenAPI Server",
                "version": "2.5.0",
                "description": "Server from OpenAPI spec"
            },
            "paths": {}
        }));

        // Name and instructions are overridden; the version is left to the spec.
        server.server_name = Some("Custom Server".to_string());
        server.server_instructions = Some("Custom instructions".to_string());

        let result = server.get_info();

        assert_eq!(result.server_info.name, "Custom Server");
        assert_eq!(result.server_info.version, "2.5.0");
        assert_eq!(result.instructions, Some("Custom instructions".to_string()));
    }
}