// mcp_utils/tool.rs

1use std::fmt;
2
3use async_trait::async_trait;
4use rust_mcp_sdk::schema::{CallToolResult, TextContent, schema_utils::CallToolError};
5use serde::Serialize;
6
7pub trait TextTool {
8    type Output: IntoTextToolResult;
9
10    fn call(&self) -> Self::Output;
11}
12
13#[async_trait]
14pub trait AsyncTextTool {
15    type Output: IntoTextToolResult;
16
17    async fn call(&self) -> Self::Output;
18}
19
20pub trait IntoTextToolResult {
21    fn result(self) -> Result<String, ToolError>;
22}
23
24impl IntoTextToolResult for String {
25    fn result(self) -> Result<String, ToolError> {
26        Ok(self)
27    }
28}
29
30impl IntoTextToolResult for &String {
31    fn result(self) -> Result<String, ToolError> {
32        Ok(self.clone())
33    }
34}
35
36impl IntoTextToolResult for &str {
37    fn result(self) -> Result<String, ToolError> {
38        Ok(self.to_string())
39    }
40}
41
42impl<T, E> IntoTextToolResult for Result<T, E>
43where
44    T: Into<String>,
45    E: Into<ToolError>,
46{
47    fn result(self) -> Result<String, ToolError> {
48        self.map(|value| value.into()).map_err(|err| err.into())
49    }
50}
51
52pub trait IntoStructuredToolResult {
53    fn result(self) -> Result<serde_json::Value, ToolError>;
54}
55
56impl<T> IntoStructuredToolResult for T
57where
58    T: Serialize,
59{
60    fn result(self) -> Result<serde_json::Value, ToolError> {
61        serde_json::to_value(self).map_err(|e| ToolError::from(e.to_string()))
62    }
63}
64
65pub trait StructuredTool {
66    type Output: IntoStructuredToolResult;
67
68    fn call(&self) -> Self::Output;
69}
70
71#[async_trait]
72pub trait AsyncStructuredTool {
73    type Output: IntoStructuredToolResult;
74
75    async fn call(&self) -> Self::Output;
76}
77
/// Error reported by a tool; it carries only a human-readable message.
///
/// It can be built from `String`, `&str`, or `&String`. Its `Display` output is the message
/// verbatim.
#[derive(Debug)]
pub struct ToolError {
    // The message written by `Display`. The field name also shows up in the derived `Debug`
    // output.
    display: String,
}

impl fmt::Display for ToolError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Written as-is; like the inner `{}` of a `write!`, formatter flags are not applied.
        f.write_str(&self.display)
    }
}

impl From<String> for ToolError {
    fn from(display: String) -> Self {
        Self { display }
    }
}

impl From<&str> for ToolError {
    fn from(message: &str) -> Self {
        Self::from(message.to_owned())
    }
}

impl From<&String> for ToolError {
    fn from(message: &String) -> Self {
        Self::from(message.as_str())
    }
}

impl std::error::Error for ToolError {}
112
113#[async_trait]
114trait CustomTextTool {
115    async fn call(&self) -> Result<CallToolResult, CallToolError>;
116}
117
118#[async_trait]
119trait CustomStructuredTool {
120    async fn call(&self) -> Result<CallToolResult, CallToolError>;
121}
122
123#[async_trait]
124trait AsyncCustomTextTool {
125    async fn call(&self) -> Result<CallToolResult, CallToolError>;
126}
127
128#[async_trait]
129trait AsyncCustomStructuredTool {
130    async fn call(&self) -> Result<CallToolResult, CallToolError>;
131}
132
133#[async_trait]
134impl<T, O> CustomTextTool for T
135where
136    T: TextTool<Output = O> + Send + Sync,
137    O: IntoTextToolResult,
138{
139    async fn call(&self) -> Result<CallToolResult, CallToolError> {
140        let result = TextTool::call(self).result().map_err(CallToolError::new)?;
141        Ok(CallToolResult::text_content(vec![TextContent::new(
142            result, None, None,
143        )]))
144    }
145}
146
147#[async_trait]
148impl<T, O> AsyncCustomTextTool for T
149where
150    T: AsyncTextTool<Output = O> + Send + Sync,
151    O: IntoTextToolResult,
152{
153    async fn call(&self) -> Result<CallToolResult, CallToolError> {
154        let result = AsyncTextTool::call(self)
155            .await
156            .result()
157            .map_err(CallToolError::new)?;
158        Ok(CallToolResult::text_content(vec![TextContent::new(
159            result, None, None,
160        )]))
161    }
162}
163
164#[async_trait]
165impl<T> CustomStructuredTool for T
166where
167    T: StructuredTool + Send + Sync,
168    T::Output: IntoStructuredToolResult,
169{
170    async fn call(&self) -> Result<CallToolResult, CallToolError> {
171        let value = StructuredTool::call(self)
172            .result()
173            .map_err(CallToolError::new)?;
174        Ok(
175            CallToolResult::text_content(vec![]).with_structured_content(match value {
176                serde_json::Value::Object(map) => map,
177                value => {
178                    let mut map = serde_json::Map::new();
179                    map.insert("result".to_string(), value);
180                    map
181                }
182            }),
183        )
184    }
185}
186
187#[async_trait]
188impl<T> AsyncCustomStructuredTool for T
189where
190    T: AsyncStructuredTool + Send + Sync,
191    T::Output: IntoStructuredToolResult,
192{
193    async fn call(&self) -> Result<CallToolResult, CallToolError> {
194        let value = AsyncStructuredTool::call(self)
195            .await
196            .result()
197            .map_err(CallToolError::new)?;
198        Ok(
199            CallToolResult::text_content(vec![]).with_structured_content(match value {
200                serde_json::Value::Object(map) => map,
201                value => {
202                    let mut map = serde_json::Map::new();
203                    map.insert("result".to_string(), value);
204                    map
205                }
206            }),
207        )
208    }
209}
210
211enum CustomToolInner<'a> {
212    Text(&'a (dyn CustomTextTool + Send + Sync)),
213    Structured(&'a (dyn CustomStructuredTool + Send + Sync)),
214    AsyncText(&'a (dyn AsyncCustomTextTool + Send + Sync)),
215    AsyncStructured(&'a (dyn AsyncCustomStructuredTool + Send + Sync)),
216}
217
218pub struct CustomTool<'a> {
219    inner: CustomToolInner<'a>,
220}
221
222impl<'a> CustomTool<'a> {
223    pub fn text<T, O>(tool: &'a T) -> Self
224    where
225        T: TextTool<Output = O> + Send + Sync,
226        O: IntoTextToolResult,
227    {
228        Self {
229            inner: CustomToolInner::Text(tool),
230        }
231    }
232
233    pub fn structured<T>(tool: &'a T) -> Self
234    where
235        T: StructuredTool + Send + Sync,
236        T::Output: IntoStructuredToolResult,
237    {
238        Self {
239            inner: CustomToolInner::Structured(tool),
240        }
241    }
242
243    pub fn async_text<T, O>(tool: &'a T) -> Self
244    where
245        T: AsyncTextTool<Output = O> + Send + Sync,
246        O: IntoTextToolResult,
247    {
248        Self {
249            inner: CustomToolInner::AsyncText(tool),
250        }
251    }
252
253    pub fn async_structured<T>(tool: &'a T) -> Self
254    where
255        T: AsyncStructuredTool + Send + Sync,
256        T::Output: IntoStructuredToolResult,
257    {
258        Self {
259            inner: CustomToolInner::AsyncStructured(tool),
260        }
261    }
262
263    pub async fn call(&self) -> Result<CallToolResult, CallToolError> {
264        match self.inner {
265            CustomToolInner::Text(tool) => tool.call().await,
266            CustomToolInner::Structured(tool) => tool.call().await,
267            CustomToolInner::AsyncText(tool) => tool.call().await,
268            CustomToolInner::AsyncStructured(tool) => tool.call().await,
269        }
270    }
271}