mcp_commune/
tool.rs

1use crate::{error::Error, peer::Peer};
2use aws_sdk_bedrockruntime::types::{ToolInputSchema, ToolSpecification};
3use aws_smithy_types::Document;
4pub use mcp_sdk_rs::{MessageContent, ResourceContent, Tool as McpTool, ToolResult};
5use serde_json::Value;
6use std::{fmt, future::Future, pin::Pin, sync::Arc};
7
8// Type alias for async function executor
9type AsyncExecutorFn = Arc<
10    dyn Fn(Option<Value>) -> Pin<Box<dyn Future<Output = Result<MessageContent, Error>> + Send>>
11        + Send
12        + Sync,
13>;
14
15#[derive(Clone)]
16pub enum Executor {
17    Fn(AsyncExecutorFn),
18    Cmd {
19        cmd: String,
20        args: Option<Vec<String>>,
21    },
22}
23
24impl fmt::Debug for Executor {
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        match self {
27            Executor::Fn(_) => write!(f, "Executor::Fn(<async function>)"),
28            Executor::Cmd { cmd, args } => f
29                .debug_struct("Cmd")
30                .field("cmd", cmd)
31                .field("args", args)
32                .finish(),
33        }
34    }
35}
36
37/// Helper function to convert a synchronous function to an async executor
38pub fn sync_fn_executor(func: fn(&Option<Value>) -> Result<MessageContent, Error>) -> Executor {
39    let async_fn = Arc::new(move |params: Option<Value>| {
40        let result = func(&params);
41        Box::pin(async move { result })
42            as Pin<Box<dyn Future<Output = Result<MessageContent, Error>> + Send>>
43    });
44    Executor::Fn(async_fn)
45}
46
47/// Helper function to create an async function executor
48pub fn async_fn_executor<F, Fut>(func: F) -> Executor
49where
50    F: Fn(Option<Value>) -> Fut + Send + Sync + 'static,
51    Fut: Future<Output = Result<MessageContent, Error>> + Send + 'static,
52{
53    let async_fn = Arc::new(move |params: Option<Value>| {
54        Box::pin(func(params))
55            as Pin<Box<dyn Future<Output = Result<MessageContent, Error>> + Send>>
56    });
57    Executor::Fn(async_fn)
58}
59
60#[derive(Clone, Debug)]
61pub enum Tool {
62    Local { executor: Executor, tool: McpTool },
63    Remote { peer: Peer, tool: McpTool },
64}
65impl fmt::Display for Tool {
66    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67        match self {
68            Tool::Remote { peer: _, tool } => write!(f, "{}", tool.description),
69            Tool::Local { executor: _, tool } => write!(f, "{}", tool.description),
70        }
71    }
72}
73impl From<Tool> for ToolSpecification {
74    fn from(tool: Tool) -> ToolSpecification {
75        let name: String;
76        let description: String;
77        let mut input_schema: std::collections::HashMap<String, Document> =
78            std::collections::HashMap::default();
79        input_schema.insert(
80            "properties".to_string(),
81            Document::Object(std::collections::HashMap::default()),
82        );
83        input_schema.insert("required".to_string(), Document::Array(vec![]));
84        match tool {
85            Tool::Remote { peer: _, tool } => {
86                name = tool.name.clone();
87                description = tool.description.clone();
88                if let Some(schema) = &tool.input_schema {
89                    if let Some(props) = &schema.properties {
90                        let props_val =
91                            serde_json::to_value(props).expect("a serializable tool schema");
92                        let props_doc: Document =
93                            serde_json::from_value(props_val).expect("a valid tool schema");
94                        input_schema.insert("properties".to_string(), props_doc);
95                    }
96                    if let Some(req) = &schema.required {
97                        let required_val =
98                            serde_json::to_value(req).expect("serializable required params");
99                        let required_doc: Document = serde_json::from_value(required_val)
100                            .expect("valid required parameters");
101                        input_schema.insert("required".to_string(), required_doc);
102                    }
103                }
104            }
105            Tool::Local { executor: _, tool } => {
106                name = tool.name.clone();
107                description = tool.description.clone();
108                if let Some(schema) = &tool.input_schema {
109                    if let Some(props) = &schema.properties {
110                        let props_val =
111                            serde_json::to_value(props).expect("a serializable tool schema");
112                        let props_doc: Document =
113                            serde_json::from_value(props_val).expect("a valid tool schema");
114                        input_schema.insert("properties".to_string(), props_doc);
115                    }
116                    if let Some(req) = &schema.required {
117                        let required_val =
118                            serde_json::to_value(req).expect("serializable required params");
119                        let required_doc: Document = serde_json::from_value(required_val)
120                            .expect("valid required parameters");
121                        input_schema.insert("required".to_string(), required_doc);
122                    }
123                }
124            }
125        }
126        input_schema.insert("type".to_string(), Document::String("object".to_string()));
127        let input_schema_doc = Document::from(input_schema);
128        ToolSpecification::builder()
129            .set_name(Some(name))
130            .set_description(Some(description))
131            .set_input_schema(Some(ToolInputSchema::Json(input_schema_doc)))
132            .build()
133            .expect("a valid tool specification")
134    }
135}