1use crate::{error::Error, peer::Peer};
2use aws_sdk_bedrockruntime::types::{ToolInputSchema, ToolSpecification};
3use aws_smithy_types::Document;
4pub use mcp_sdk_rs::{MessageContent, ResourceContent, Tool as McpTool, ToolResult};
5use serde_json::Value;
6use std::{fmt, future::Future, pin::Pin, sync::Arc};
7
/// Shared, type-erased handle to an asynchronous tool implementation.
///
/// The callable receives the (optional) JSON arguments by value and returns a
/// boxed, pinned future that resolves to either the produced [`MessageContent`]
/// or an [`Error`]. Both the callable (`Send + Sync`) and the future it returns
/// (`Send`) are bounded so that they can be moved across threads, and the
/// `Arc` makes cloning the handle a reference-count bump.
type AsyncExecutorFn = Arc<
    dyn Fn(Option<Value>) -> Pin<Box<dyn Future<Output = Result<MessageContent, Error>> + Send>>
        + Send
        + Sync,
>;
14
/// The mechanism used to carry out a locally hosted tool.
///
/// `Debug` is implemented by hand for this type rather than derived.
#[derive(Clone)]
pub enum Executor {
    /// An in-process asynchronous callable (see [`AsyncExecutorFn`]).
    Fn(AsyncExecutorFn),
    /// A command described by a string plus an optional list of arguments.
    ///
    /// NOTE(review): presumably an external program to launch; the code that
    /// actually runs it is not part of this file — confirm against the caller.
    Cmd {
        /// The command itself.
        cmd: String,
        /// Optional arguments for `cmd`; `None` when no list was supplied.
        args: Option<Vec<String>>,
    },
}
23
24impl fmt::Debug for Executor {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 match self {
27 Executor::Fn(_) => write!(f, "Executor::Fn(<async function>)"),
28 Executor::Cmd { cmd, args } => f
29 .debug_struct("Cmd")
30 .field("cmd", cmd)
31 .field("args", args)
32 .finish(),
33 }
34 }
35}
36
37pub fn sync_fn_executor(func: fn(&Option<Value>) -> Result<MessageContent, Error>) -> Executor {
39 let async_fn = Arc::new(move |params: Option<Value>| {
40 let result = func(¶ms);
41 Box::pin(async move { result })
42 as Pin<Box<dyn Future<Output = Result<MessageContent, Error>> + Send>>
43 });
44 Executor::Fn(async_fn)
45}
46
47pub fn async_fn_executor<F, Fut>(func: F) -> Executor
49where
50 F: Fn(Option<Value>) -> Fut + Send + Sync + 'static,
51 Fut: Future<Output = Result<MessageContent, Error>> + Send + 'static,
52{
53 let async_fn = Arc::new(move |params: Option<Value>| {
54 Box::pin(func(params))
55 as Pin<Box<dyn Future<Output = Result<MessageContent, Error>> + Send>>
56 });
57 Executor::Fn(async_fn)
58}
59
/// A tool description paired with the means of reaching its implementation.
///
/// Both variants carry the same [`McpTool`] metadata (name, description and
/// optional input schema); they differ only in what accompanies it.
#[derive(Clone, Debug)]
pub enum Tool {
    /// A tool backed by an [`Executor`] held in this process.
    Local { executor: Executor, tool: McpTool },
    /// A tool associated with a [`Peer`].
    ///
    /// NOTE(review): presumably the peer is the remote endpoint that hosts the
    /// tool — `Peer` is defined elsewhere, confirm there.
    Remote { peer: Peer, tool: McpTool },
}
65impl fmt::Display for Tool {
66 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67 match self {
68 Tool::Remote { peer: _, tool } => write!(f, "{}", tool.description),
69 Tool::Local { executor: _, tool } => write!(f, "{}", tool.description),
70 }
71 }
72}
73impl From<Tool> for ToolSpecification {
74 fn from(tool: Tool) -> ToolSpecification {
75 let name: String;
76 let description: String;
77 let mut input_schema: std::collections::HashMap<String, Document> =
78 std::collections::HashMap::default();
79 input_schema.insert(
80 "properties".to_string(),
81 Document::Object(std::collections::HashMap::default()),
82 );
83 input_schema.insert("required".to_string(), Document::Array(vec![]));
84 match tool {
85 Tool::Remote { peer: _, tool } => {
86 name = tool.name.clone();
87 description = tool.description.clone();
88 if let Some(schema) = &tool.input_schema {
89 if let Some(props) = &schema.properties {
90 let props_val =
91 serde_json::to_value(props).expect("a serializable tool schema");
92 let props_doc: Document =
93 serde_json::from_value(props_val).expect("a valid tool schema");
94 input_schema.insert("properties".to_string(), props_doc);
95 }
96 if let Some(req) = &schema.required {
97 let required_val =
98 serde_json::to_value(req).expect("serializable required params");
99 let required_doc: Document = serde_json::from_value(required_val)
100 .expect("valid required parameters");
101 input_schema.insert("required".to_string(), required_doc);
102 }
103 }
104 }
105 Tool::Local { executor: _, tool } => {
106 name = tool.name.clone();
107 description = tool.description.clone();
108 if let Some(schema) = &tool.input_schema {
109 if let Some(props) = &schema.properties {
110 let props_val =
111 serde_json::to_value(props).expect("a serializable tool schema");
112 let props_doc: Document =
113 serde_json::from_value(props_val).expect("a valid tool schema");
114 input_schema.insert("properties".to_string(), props_doc);
115 }
116 if let Some(req) = &schema.required {
117 let required_val =
118 serde_json::to_value(req).expect("serializable required params");
119 let required_doc: Document = serde_json::from_value(required_val)
120 .expect("valid required parameters");
121 input_schema.insert("required".to_string(), required_doc);
122 }
123 }
124 }
125 }
126 input_schema.insert("type".to_string(), Document::String("object".to_string()));
127 let input_schema_doc = Document::from(input_schema);
128 ToolSpecification::builder()
129 .set_name(Some(name))
130 .set_description(Some(description))
131 .set_input_schema(Some(ToolInputSchema::Json(input_schema_doc)))
132 .build()
133 .expect("a valid tool specification")
134 }
135}