1use async_trait::async_trait;
27use chrono::{DateTime, Utc};
28use serde::{Deserialize, Serialize};
29use std::collections::HashMap;
30use std::sync::atomic::{AtomicU64, Ordering};
31use std::sync::Arc;
32use std::time::Duration;
33use tokio::sync::RwLock;
34use uuid::Uuid;
35
36use crate::mcp::connection_manager::ConnectionManager;
37use crate::mcp::error::{McpError, McpResult};
38use crate::mcp::transport::McpRequest;
39use crate::mcp::types::JsonObject;
40
/// A tool advertised by an MCP server.
///
/// `McpToolManager` builds these from a server's `tools/list` response; they
/// can also be created directly with [`McpTool::new`] or
/// [`McpTool::with_description`].
///
/// Note: serde uses the Rust field names here (e.g. `input_schema`), not the
/// MCP wire names (e.g. `inputSchema`), since no `rename` attributes are set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpTool {
    /// Tool name as reported by the server; sent as `name` in `tools/call`.
    pub name: String,
    /// Optional human-readable description from the server.
    pub description: Option<String>,
    /// JSON Schema for the tool's arguments. `McpToolManager` stores `{}`
    /// when the server omits `inputSchema`.
    pub input_schema: serde_json::Value,
    /// Name of the server that provides this tool.
    pub server_name: String,
}
56
57impl McpTool {
58 pub fn new(
60 name: impl Into<String>,
61 server_name: impl Into<String>,
62 input_schema: serde_json::Value,
63 ) -> Self {
64 Self {
65 name: name.into(),
66 description: None,
67 input_schema,
68 server_name: server_name.into(),
69 }
70 }
71
72 pub fn with_description(
74 name: impl Into<String>,
75 server_name: impl Into<String>,
76 description: impl Into<String>,
77 input_schema: serde_json::Value,
78 ) -> Self {
79 Self {
80 name: name.into(),
81 description: Some(description.into()),
82 input_schema,
83 server_name: server_name.into(),
84 }
85 }
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
92#[serde(tag = "type", rename_all = "lowercase")]
93pub enum ToolResultContent {
94 Text {
96 text: String,
98 },
99 Image {
101 data: String,
103 #[serde(rename = "mimeType")]
105 mime_type: String,
106 },
107 Resource {
109 uri: String,
111 text: Option<String>,
113 #[serde(rename = "blob")]
115 data: Option<String>,
116 #[serde(rename = "mimeType")]
118 mime_type: Option<String>,
119 },
120}
121
122impl ToolResultContent {
123 pub fn text(text: impl Into<String>) -> Self {
125 Self::Text { text: text.into() }
126 }
127
128 pub fn image(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
130 Self::Image {
131 data: data.into(),
132 mime_type: mime_type.into(),
133 }
134 }
135
136 pub fn resource(uri: impl Into<String>) -> Self {
138 Self::Resource {
139 uri: uri.into(),
140 text: None,
141 data: None,
142 mime_type: None,
143 }
144 }
145}
146
/// Outcome of a `tools/call` request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallResult {
    /// Content items returned by the tool, in the order received.
    pub content: Vec<ToolResultContent>,
    /// Whether the tool flagged its result as an error (`isError` on the wire).
    /// Defaults to `false` when the key is absent.
    #[serde(rename = "isError", default)]
    pub is_error: bool,
}
159
160impl ToolCallResult {
161 pub fn success_text(text: impl Into<String>) -> Self {
163 Self {
164 content: vec![ToolResultContent::text(text)],
165 is_error: false,
166 }
167 }
168
169 pub fn success(content: Vec<ToolResultContent>) -> Self {
171 Self {
172 content,
173 is_error: false,
174 }
175 }
176
177 pub fn error(message: impl Into<String>) -> Self {
179 Self {
180 content: vec![ToolResultContent::text(message)],
181 is_error: true,
182 }
183 }
184
185 pub fn is_empty(&self) -> bool {
187 self.content.is_empty()
188 }
189
190 pub fn first_text(&self) -> Option<&str> {
192 self.content.iter().find_map(|c| match c {
193 ToolResultContent::Text { text } => Some(text.as_str()),
194 _ => None,
195 })
196 }
197}
198
/// Outcome of checking tool-call arguments against a tool's input schema.
///
/// The constructors and [`add_error`](Self::add_error) keep the fields
/// consistent: whenever `errors` is non-empty, `valid` is `false`.
#[derive(Debug, Clone)]
pub struct ArgValidationResult {
    /// `true` when no problem was found.
    pub valid: bool,
    /// Human-readable description of each problem found.
    pub errors: Vec<String>,
}

impl Default for ArgValidationResult {
    /// An empty result is a passing result (`valid == true`, no errors).
    ///
    /// Implemented by hand because a derived `Default` would yield
    /// `valid == false` with no errors, a contradictory state.
    fn default() -> Self {
        Self::valid()
    }
}

impl ArgValidationResult {
    /// A passing result with no errors.
    pub fn valid() -> Self {
        Self {
            valid: true,
            errors: Vec::new(),
        }
    }

    /// A failing result carrying `errors`.
    pub fn invalid(errors: Vec<String>) -> Self {
        Self {
            valid: false,
            errors,
        }
    }

    /// Records `error` and marks the result as failing.
    pub fn add_error(&mut self, error: impl Into<String>) {
        self.valid = false;
        self.errors.push(error.into());
    }
}
233
/// Bookkeeping record for a tool call that has been started.
#[derive(Debug, Clone)]
pub struct CallInfo {
    /// Identifier of the call. `McpToolManager` also uses it as the JSON-RPC
    /// request id of the `tools/call` request.
    pub call_id: String,
    /// Server the call was sent to.
    pub server_name: String,
    /// Name of the tool being called.
    pub tool_name: String,
    /// Arguments passed to the tool.
    pub args: JsonObject,
    /// Wall-clock time at which the record was created.
    pub start_time: DateTime<Utc>,
    /// Set by `mark_completed`.
    pub completed: bool,
    /// Set by `mark_cancelled`.
    pub cancelled: bool,
}
255
256impl CallInfo {
257 pub fn new(
259 call_id: impl Into<String>,
260 server_name: impl Into<String>,
261 tool_name: impl Into<String>,
262 args: JsonObject,
263 ) -> Self {
264 Self {
265 call_id: call_id.into(),
266 server_name: server_name.into(),
267 tool_name: tool_name.into(),
268 args,
269 start_time: Utc::now(),
270 completed: false,
271 cancelled: false,
272 }
273 }
274
275 pub fn mark_completed(&mut self) {
277 self.completed = true;
278 }
279
280 pub fn mark_cancelled(&mut self) {
282 self.cancelled = true;
283 }
284
285 pub fn elapsed(&self) -> chrono::Duration {
287 Utc::now() - self.start_time
288 }
289}
290
/// A single request in a batch passed to `ToolManager::call_tools_batch`.
#[derive(Debug, Clone)]
pub struct ToolCall {
    /// Server that exposes the tool.
    pub server_name: String,
    /// Name of the tool to call.
    pub tool_name: String,
    /// Arguments to pass to the tool.
    pub args: JsonObject,
}
303
304impl ToolCall {
305 pub fn new(
307 server_name: impl Into<String>,
308 tool_name: impl Into<String>,
309 args: JsonObject,
310 ) -> Self {
311 Self {
312 server_name: server_name.into(),
313 tool_name: tool_name.into(),
314 args,
315 }
316 }
317}
318
/// Discovery, argument validation and invocation of tools exposed by MCP servers.
#[async_trait]
pub trait ToolManager: Send + Sync {
    /// Lists the tools of one server (`Some(name)`) or of all known servers (`None`).
    async fn list_tools(&self, server_name: Option<&str>) -> McpResult<Vec<McpTool>>;

    /// Looks up one tool by server and name; `Ok(None)` when the server's
    /// tool list does not contain it.
    async fn get_tool(&self, server_name: &str, tool_name: &str) -> McpResult<Option<McpTool>>;

    /// Drops cached tool lists for one server, or for all servers when `None`.
    fn clear_cache(&self, server_name: Option<&str>);

    /// Calls a tool using the implementation's default timeout.
    async fn call_tool(
        &self,
        server_name: &str,
        tool_name: &str,
        args: JsonObject,
    ) -> McpResult<ToolCallResult>;

    /// Calls a tool, bounding the wait for the response by `timeout`.
    async fn call_tool_with_timeout(
        &self,
        server_name: &str,
        tool_name: &str,
        args: JsonObject,
        timeout: Duration,
    ) -> McpResult<ToolCallResult>;

    /// Checks `args` against `tool.input_schema` locally, without contacting a server.
    fn validate_args(&self, tool: &McpTool, args: &JsonObject) -> ArgValidationResult;

    /// Requests cancellation of the in-flight call identified by `call_id`.
    fn cancel_call(&self, call_id: &str);

    /// Returns a snapshot of calls that are currently being tracked as in flight.
    fn get_pending_calls(&self) -> Vec<CallInfo>;

    /// Runs several calls; `McpToolManager` returns one result per input, in
    /// the same order as `calls`.
    async fn call_tools_batch(&self, calls: Vec<ToolCall>) -> Vec<McpResult<ToolCallResult>>;
}
380
/// A server's tool list together with the time it was stored in the cache.
struct ToolCacheEntry {
    /// Tools returned by the server.
    tools: Vec<McpTool>,
    /// When the entry was stored; compared against the manager's `cache_ttl`.
    cached_at: DateTime<Utc>,
}
388
/// [`ToolManager`] implementation that talks to servers through a [`ConnectionManager`].
///
/// Each server's tool list is cached for `cache_ttl`. Calls are tracked in
/// `pending_calls` while awaiting a response, so they can be listed and cancelled.
pub struct McpToolManager<C: ConnectionManager> {
    /// Provides connections and sends requests.
    connection_manager: Arc<C>,
    /// Cached tool lists keyed by server name.
    tool_cache: Arc<RwLock<HashMap<String, ToolCacheEntry>>>,
    /// Calls awaiting a response, keyed by call id.
    pending_calls: Arc<RwLock<HashMap<String, CallInfo>>>,
    /// Counter appended to generated call ids; starts at 1.
    call_counter: AtomicU64,
    /// Timeout used by `call_tool`.
    default_timeout: Duration,
    /// Maximum age of a cached tool list.
    cache_ttl: Duration,
}
404
405impl<C: ConnectionManager> McpToolManager<C> {
406 pub fn new(connection_manager: Arc<C>) -> Self {
408 Self {
409 connection_manager,
410 tool_cache: Arc::new(RwLock::new(HashMap::new())),
411 pending_calls: Arc::new(RwLock::new(HashMap::new())),
412 call_counter: AtomicU64::new(1),
413 default_timeout: Duration::from_secs(30),
414 cache_ttl: Duration::from_secs(300), }
416 }
417
418 pub fn with_settings(
420 connection_manager: Arc<C>,
421 default_timeout: Duration,
422 cache_ttl: Duration,
423 ) -> Self {
424 Self {
425 connection_manager,
426 tool_cache: Arc::new(RwLock::new(HashMap::new())),
427 pending_calls: Arc::new(RwLock::new(HashMap::new())),
428 call_counter: AtomicU64::new(1),
429 default_timeout,
430 cache_ttl,
431 }
432 }
433
434 pub fn generate_call_id(&self) -> String {
436 let counter = self.call_counter.fetch_add(1, Ordering::SeqCst);
437 format!("call-{}-{}", Uuid::new_v4(), counter)
438 }
439
440 fn is_cache_valid(&self, entry: &ToolCacheEntry) -> bool {
442 let age = Utc::now() - entry.cached_at;
443 age.num_seconds() < self.cache_ttl.as_secs() as i64
444 }
445
446 async fn fetch_tools_from_server(&self, server_name: &str) -> McpResult<Vec<McpTool>> {
448 let connection = self
450 .connection_manager
451 .get_connection_by_server(server_name)
452 .ok_or_else(|| {
453 McpError::connection(format!("No connection found for server: {}", server_name))
454 })?;
455
456 let request = McpRequest::new(
458 serde_json::json!(format!("tools-list-{}", Uuid::new_v4())),
459 "tools/list",
460 );
461
462 let response = self
463 .connection_manager
464 .send(&connection.id, request)
465 .await?;
466
467 let result = response.into_result()?;
469
470 let tools_value = result
472 .get("tools")
473 .ok_or_else(|| McpError::protocol("Response missing 'tools' field"))?;
474
475 let raw_tools: Vec<serde_json::Value> = serde_json::from_value(tools_value.clone())
476 .map_err(|e| McpError::protocol(format!("Failed to parse tools: {}", e)))?;
477
478 let tools: Vec<McpTool> = raw_tools
480 .into_iter()
481 .filter_map(|t| {
482 let name = t.get("name")?.as_str()?.to_string();
483 let description = t
484 .get("description")
485 .and_then(|d| d.as_str())
486 .map(String::from);
487 let input_schema = t
488 .get("inputSchema")
489 .cloned()
490 .unwrap_or(serde_json::json!({}));
491
492 Some(McpTool {
493 name,
494 description,
495 input_schema,
496 server_name: server_name.to_string(),
497 })
498 })
499 .collect();
500
501 Ok(tools)
502 }
503
504 async fn register_call(&self, call_info: CallInfo) {
506 let mut calls = self.pending_calls.write().await;
507 calls.insert(call_info.call_id.clone(), call_info);
508 }
509
510 async fn complete_call(&self, call_id: &str) {
512 let mut calls = self.pending_calls.write().await;
513 if let Some(info) = calls.get_mut(call_id) {
514 info.mark_completed();
515 }
516 calls.remove(call_id);
517 }
518
519 fn convert_result(&self, result: serde_json::Value) -> McpResult<ToolCallResult> {
523 if let Some(content) = result.get("content") {
525 let content_items: Vec<ToolResultContent> = serde_json::from_value(content.clone())
526 .map_err(|e| {
527 McpError::protocol(format!("Failed to parse tool result content: {}", e))
528 })?;
529
530 let is_error = result
531 .get("isError")
532 .and_then(|v| v.as_bool())
533 .unwrap_or(false);
534
535 return Ok(ToolCallResult {
536 content: content_items,
537 is_error,
538 });
539 }
540
541 if let Some(text) = result.as_str() {
543 return Ok(ToolCallResult::success_text(text));
544 }
545
546 Ok(ToolCallResult::success_text(result.to_string()))
548 }
549}
550
551#[async_trait]
552impl<C: ConnectionManager + 'static> ToolManager for McpToolManager<C> {
553 async fn list_tools(&self, server_name: Option<&str>) -> McpResult<Vec<McpTool>> {
554 match server_name {
555 Some(name) => {
556 {
558 let cache = self.tool_cache.read().await;
559 if let Some(entry) = cache.get(name) {
560 if self.is_cache_valid(entry) {
561 return Ok(entry.tools.clone());
562 }
563 }
564 }
565
566 let tools = self.fetch_tools_from_server(name).await?;
568
569 {
571 let mut cache = self.tool_cache.write().await;
572 cache.insert(
573 name.to_string(),
574 ToolCacheEntry {
575 tools: tools.clone(),
576 cached_at: Utc::now(),
577 },
578 );
579 }
580
581 Ok(tools)
582 }
583 None => {
584 let connections = self.connection_manager.get_all_connections();
586 let mut all_tools = Vec::new();
587
588 for conn in connections {
589 match self.list_tools(Some(&conn.server_name)).await {
590 Ok(tools) => all_tools.extend(tools),
591 Err(e) => {
592 tracing::warn!(
593 "Failed to list tools from server {}: {}",
594 conn.server_name,
595 e
596 );
597 }
598 }
599 }
600
601 Ok(all_tools)
602 }
603 }
604 }
605
606 async fn get_tool(&self, server_name: &str, tool_name: &str) -> McpResult<Option<McpTool>> {
607 let tools = self.list_tools(Some(server_name)).await?;
608 Ok(tools.into_iter().find(|t| t.name == tool_name))
609 }
610
611 fn clear_cache(&self, server_name: Option<&str>) {
612 let server_name_owned = server_name.map(|s| s.to_string());
614 let cache = self.tool_cache.clone();
615 tokio::spawn(async move {
616 let mut cache = cache.write().await;
617 match server_name_owned {
618 Some(name) => {
619 cache.remove(&name);
620 }
621 None => {
622 cache.clear();
623 }
624 }
625 });
626 }
627
628 async fn call_tool(
629 &self,
630 server_name: &str,
631 tool_name: &str,
632 args: JsonObject,
633 ) -> McpResult<ToolCallResult> {
634 self.call_tool_with_timeout(server_name, tool_name, args, self.default_timeout)
635 .await
636 }
637
638 async fn call_tool_with_timeout(
639 &self,
640 server_name: &str,
641 tool_name: &str,
642 args: JsonObject,
643 timeout: Duration,
644 ) -> McpResult<ToolCallResult> {
645 let tool = self
647 .get_tool(server_name, tool_name)
648 .await?
649 .ok_or_else(|| {
650 McpError::tool(
651 format!("Tool not found: {}/{}", server_name, tool_name),
652 Some(tool_name.to_string()),
653 )
654 })?;
655
656 let validation = self.validate_args(&tool, &args);
658 if !validation.valid {
659 return Err(McpError::validation(
660 format!(
661 "Invalid arguments for tool {}: {}",
662 tool_name,
663 validation.errors.join(", ")
664 ),
665 validation.errors,
666 ));
667 }
668
669 let connection = self
671 .connection_manager
672 .get_connection_by_server(server_name)
673 .ok_or_else(|| {
674 McpError::connection(format!("No connection found for server: {}", server_name))
675 })?;
676
677 let call_id = self.generate_call_id();
679 let call_info = CallInfo::new(&call_id, server_name, tool_name, args.clone());
680 self.register_call(call_info).await;
681
682 let request = McpRequest::with_params(
684 serde_json::json!(call_id.clone()),
685 "tools/call",
686 serde_json::json!({
687 "name": tool_name,
688 "arguments": args
689 }),
690 );
691
692 let result = self
694 .connection_manager
695 .send_with_timeout(&connection.id, request, timeout)
696 .await;
697
698 self.complete_call(&call_id).await;
700
701 match result {
703 Ok(response) => {
704 let result_value = response.into_result()?;
705 self.convert_result(result_value)
706 }
707 Err(e) => Err(e),
708 }
709 }
710
711 fn validate_args(&self, tool: &McpTool, args: &JsonObject) -> ArgValidationResult {
712 let schema = &tool.input_schema;
713
714 if schema.is_null()
716 || (schema.is_object() && schema.as_object().is_none_or(|o| o.is_empty()))
717 {
718 return ArgValidationResult::valid();
719 }
720
721 let mut result = ArgValidationResult::valid();
722
723 if let Some(required) = schema.get("required").and_then(|r| r.as_array()) {
725 for req in required {
726 if let Some(field_name) = req.as_str() {
727 if !args.contains_key(field_name) {
728 result.add_error(format!("Missing required field: {}", field_name));
729 }
730 }
731 }
732 }
733
734 if let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) {
736 for (key, value) in args.iter() {
737 if let Some(prop_schema) = properties.get(key) {
738 if let Some(expected_type) = prop_schema.get("type").and_then(|t| t.as_str()) {
740 let actual_type = get_json_type(value);
741 if !types_compatible(expected_type, &actual_type) {
742 result.add_error(format!(
743 "Field '{}' has wrong type: expected {}, got {}",
744 key, expected_type, actual_type
745 ));
746 }
747 }
748 }
749 }
750 }
751
752 if let Some(additional) = schema.get("additionalProperties") {
754 if additional == &serde_json::Value::Bool(false) {
755 if let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) {
756 for key in args.keys() {
757 if !properties.contains_key(key) {
758 result.add_error(format!("Unknown field: {}", key));
759 }
760 }
761 }
762 }
763 }
764
765 result
766 }
767
768 fn cancel_call(&self, call_id: &str) {
769 let pending_calls = self.pending_calls.clone();
770 let connection_manager = self.connection_manager.clone();
771 let call_id = call_id.to_string();
772
773 tokio::spawn(async move {
774 let mut calls = pending_calls.write().await;
775 if let Some(info) = calls.get_mut(&call_id) {
776 info.mark_cancelled();
777
778 if let Some(conn) = connection_manager.get_connection_by_server(&info.server_name) {
780 let _ = connection_manager.cancel_request(&conn.id, &call_id).await;
781 }
782 }
783 });
784 }
785
786 fn get_pending_calls(&self) -> Vec<CallInfo> {
787 self.pending_calls
789 .try_read()
790 .map(|calls| calls.values().cloned().collect())
791 .unwrap_or_default()
792 }
793
794 async fn call_tools_batch(&self, calls: Vec<ToolCall>) -> Vec<McpResult<ToolCallResult>> {
795 use futures::future::join_all;
796
797 let futures: Vec<_> = calls
798 .into_iter()
799 .map(|call| {
800 let server_name = call.server_name.clone();
801 let tool_name = call.tool_name.clone();
802 let args = call.args;
803 async move { self.call_tool(&server_name, &tool_name, args).await }
804 })
805 .collect();
806
807 join_all(futures).await
808 }
809}
810
811fn get_json_type(value: &serde_json::Value) -> String {
813 match value {
814 serde_json::Value::Null => "null".to_string(),
815 serde_json::Value::Bool(_) => "boolean".to_string(),
816 serde_json::Value::Number(n) => {
817 if n.is_i64() || n.is_u64() {
818 "integer".to_string()
819 } else {
820 "number".to_string()
821 }
822 }
823 serde_json::Value::String(_) => "string".to_string(),
824 serde_json::Value::Array(_) => "array".to_string(),
825 serde_json::Value::Object(_) => "object".to_string(),
826 }
827}
828
/// Whether a value of JSON type `actual` satisfies a schema `type` of `expected`.
///
/// Types must match exactly, except that an `integer` also satisfies `number`.
fn types_compatible(expected: &str, actual: &str) -> bool {
    matches!((expected, actual), ("number", "integer")) || expected == actual
}
840
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Empty argument object for the call-tracking tests.
    fn no_args() -> serde_json::Map<String, serde_json::Value> {
        serde_json::Map::new()
    }

    #[test]
    fn test_mcp_tool_new() {
        let tool = McpTool::new("test_tool", "test_server", json!({}));
        assert_eq!(tool.name, "test_tool");
        assert_eq!(tool.server_name, "test_server");
        assert_eq!(tool.description, None);
    }

    #[test]
    fn test_mcp_tool_with_description() {
        let tool =
            McpTool::with_description("test_tool", "test_server", "A test tool", json!({}));
        assert_eq!(tool.description.as_deref(), Some("A test tool"));
    }

    #[test]
    fn test_tool_result_content_text() {
        let ToolResultContent::Text { text } = ToolResultContent::text("Hello, world!") else {
            panic!("Expected Text content");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_tool_result_content_image() {
        let ToolResultContent::Image { data, mime_type } =
            ToolResultContent::image("base64data", "image/png")
        else {
            panic!("Expected Image content");
        };
        assert_eq!(data, "base64data");
        assert_eq!(mime_type, "image/png");
    }

    #[test]
    fn test_tool_call_result_success() {
        let outcome = ToolCallResult::success_text("Success!");
        assert!(!outcome.is_error);
        assert_eq!(outcome.first_text(), Some("Success!"));
    }

    #[test]
    fn test_tool_call_result_error() {
        let outcome = ToolCallResult::error("Something went wrong");
        assert!(outcome.is_error);
        assert_eq!(outcome.first_text(), Some("Something went wrong"));
    }

    #[test]
    fn test_arg_validation_result_valid() {
        let ArgValidationResult { valid, errors } = ArgValidationResult::valid();
        assert!(valid);
        assert!(errors.is_empty());
    }

    #[test]
    fn test_arg_validation_result_invalid() {
        let ArgValidationResult { valid, errors } =
            ArgValidationResult::invalid(vec!["Missing field".to_string()]);
        assert!(!valid);
        assert_eq!(errors.len(), 1);
    }

    #[test]
    fn test_call_info_new() {
        let info = CallInfo::new("call-1", "server", "tool", no_args());
        assert_eq!(info.call_id, "call-1");
        assert_eq!(info.server_name, "server");
        assert_eq!(info.tool_name, "tool");
        assert!(!info.completed);
        assert!(!info.cancelled);
    }

    #[test]
    fn test_call_info_mark_completed() {
        let mut info = CallInfo::new("call-1", "server", "tool", no_args());
        info.mark_completed();
        assert!(info.completed);
    }

    #[test]
    fn test_call_info_mark_cancelled() {
        let mut info = CallInfo::new("call-1", "server", "tool", no_args());
        info.mark_cancelled();
        assert!(info.cancelled);
    }

    #[test]
    fn test_tool_call_new() {
        let call = ToolCall::new("server", "tool", no_args());
        assert_eq!(call.server_name, "server");
        assert_eq!(call.tool_name, "tool");
    }

    #[test]
    fn test_get_json_type() {
        let cases = [
            (serde_json::Value::Null, "null"),
            (json!(true), "boolean"),
            (json!(42), "integer"),
            (json!(3.15), "number"),
            (json!("hello"), "string"),
            (json!([1, 2, 3]), "array"),
            (json!({"key": "value"}), "object"),
        ];
        for (value, expected) in &cases {
            assert_eq!(get_json_type(value), *expected, "value: {}", value);
        }
    }

    #[test]
    fn test_types_compatible() {
        let accepted = [("string", "string"), ("number", "integer")];
        let rejected = [("string", "number"), ("integer", "number")];
        for (expected, actual) in accepted {
            assert!(types_compatible(expected, actual));
        }
        for (expected, actual) in rejected {
            assert!(!types_compatible(expected, actual));
        }
    }
}