1use std::fmt;
2
3use async_trait::async_trait;
4use rust_mcp_sdk::schema::{CallToolResult, TextContent, schema_utils::CallToolError};
5use serde::Serialize;
6
7pub trait TextTool {
8 type Output: IntoTextToolResult;
9
10 fn call(&self) -> Self::Output;
11}
12
13#[async_trait]
14pub trait AsyncTextTool {
15 type Output: IntoTextToolResult;
16
17 async fn call(&self) -> Self::Output;
18}
19
20pub trait IntoTextToolResult {
21 fn result(self) -> Result<String, ToolError>;
22}
23
24impl IntoTextToolResult for String {
25 fn result(self) -> Result<String, ToolError> {
26 Ok(self)
27 }
28}
29
30impl IntoTextToolResult for &String {
31 fn result(self) -> Result<String, ToolError> {
32 Ok(self.clone())
33 }
34}
35
36impl IntoTextToolResult for &str {
37 fn result(self) -> Result<String, ToolError> {
38 Ok(self.to_string())
39 }
40}
41
42impl<T, E> IntoTextToolResult for Result<T, E>
43where
44 T: Into<String>,
45 E: Into<ToolError>,
46{
47 fn result(self) -> Result<String, ToolError> {
48 self.map(|value| value.into()).map_err(|err| err.into())
49 }
50}
51
52pub trait IntoStructuredToolResult {
53 fn result(self) -> Result<serde_json::Value, ToolError>;
54}
55
56impl<T> IntoStructuredToolResult for T
57where
58 T: Serialize,
59{
60 fn result(self) -> Result<serde_json::Value, ToolError> {
61 serde_json::to_value(self).map_err(|e| ToolError::from(e.to_string()))
62 }
63}
64
65pub trait StructuredTool {
66 type Output: IntoStructuredToolResult;
67
68 fn call(&self) -> Self::Output;
69}
70
71#[async_trait]
72pub trait AsyncStructuredTool {
73 type Output: IntoStructuredToolResult;
74
75 async fn call(&self) -> Self::Output;
76}
77
/// Error reported by a tool, carrying a human-readable message.
///
/// Build one from any string (`&str`, `String` or `&String`) via [`From`].
/// Its [`Display`](fmt::Display) output is exactly that message. `Clone` and
/// `PartialEq`/`Eq` are derived so callers can copy and compare errors
/// (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolError {
    /// Message shown to callers through `Display`.
    display: String,
}

impl fmt::Display for ToolError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Writes the message verbatim; width/fill flags are ignored, as before.
        f.write_str(&self.display)
    }
}

impl From<String> for ToolError {
    fn from(value: String) -> Self {
        Self { display: value }
    }
}

impl From<&str> for ToolError {
    fn from(value: &str) -> Self {
        Self::from(value.to_owned())
    }
}

impl From<&String> for ToolError {
    fn from(value: &String) -> Self {
        Self::from(value.as_str())
    }
}

impl std::error::Error for ToolError {}
112
113#[async_trait]
114trait CustomTextTool {
115 async fn call(&self) -> Result<CallToolResult, CallToolError>;
116}
117
118#[async_trait]
119trait CustomStructuredTool {
120 async fn call(&self) -> Result<CallToolResult, CallToolError>;
121}
122
123#[async_trait]
124trait AsyncCustomTextTool {
125 async fn call(&self) -> Result<CallToolResult, CallToolError>;
126}
127
128#[async_trait]
129trait AsyncCustomStructuredTool {
130 async fn call(&self) -> Result<CallToolResult, CallToolError>;
131}
132
133#[async_trait]
134impl<T, O> CustomTextTool for T
135where
136 T: TextTool<Output = O> + Send + Sync,
137 O: IntoTextToolResult,
138{
139 async fn call(&self) -> Result<CallToolResult, CallToolError> {
140 let result = TextTool::call(self).result().map_err(CallToolError::new)?;
141 Ok(CallToolResult::text_content(vec![TextContent::new(
142 result, None, None,
143 )]))
144 }
145}
146
147#[async_trait]
148impl<T, O> AsyncCustomTextTool for T
149where
150 T: AsyncTextTool<Output = O> + Send + Sync,
151 O: IntoTextToolResult,
152{
153 async fn call(&self) -> Result<CallToolResult, CallToolError> {
154 let result = AsyncTextTool::call(self)
155 .await
156 .result()
157 .map_err(CallToolError::new)?;
158 Ok(CallToolResult::text_content(vec![TextContent::new(
159 result, None, None,
160 )]))
161 }
162}
163
164#[async_trait]
165impl<T> CustomStructuredTool for T
166where
167 T: StructuredTool + Send + Sync,
168 T::Output: IntoStructuredToolResult,
169{
170 async fn call(&self) -> Result<CallToolResult, CallToolError> {
171 let value = StructuredTool::call(self)
172 .result()
173 .map_err(CallToolError::new)?;
174 Ok(
175 CallToolResult::text_content(vec![]).with_structured_content(match value {
176 serde_json::Value::Object(map) => map,
177 value => {
178 let mut map = serde_json::Map::new();
179 map.insert("result".to_string(), value);
180 map
181 }
182 }),
183 )
184 }
185}
186
187#[async_trait]
188impl<T> AsyncCustomStructuredTool for T
189where
190 T: AsyncStructuredTool + Send + Sync,
191 T::Output: IntoStructuredToolResult,
192{
193 async fn call(&self) -> Result<CallToolResult, CallToolError> {
194 let value = AsyncStructuredTool::call(self)
195 .await
196 .result()
197 .map_err(CallToolError::new)?;
198 Ok(
199 CallToolResult::text_content(vec![]).with_structured_content(match value {
200 serde_json::Value::Object(map) => map,
201 value => {
202 let mut map = serde_json::Map::new();
203 map.insert("result".to_string(), value);
204 map
205 }
206 }),
207 )
208 }
209}
210
211enum CustomToolInner<'a> {
212 Text(&'a (dyn CustomTextTool + Send + Sync)),
213 Structured(&'a (dyn CustomStructuredTool + Send + Sync)),
214 AsyncText(&'a (dyn AsyncCustomTextTool + Send + Sync)),
215 AsyncStructured(&'a (dyn AsyncCustomStructuredTool + Send + Sync)),
216}
217
218pub struct CustomTool<'a> {
219 inner: CustomToolInner<'a>,
220}
221
222impl<'a> CustomTool<'a> {
223 pub fn text<T, O>(tool: &'a T) -> Self
224 where
225 T: TextTool<Output = O> + Send + Sync,
226 O: IntoTextToolResult,
227 {
228 Self {
229 inner: CustomToolInner::Text(tool),
230 }
231 }
232
233 pub fn structured<T>(tool: &'a T) -> Self
234 where
235 T: StructuredTool + Send + Sync,
236 T::Output: IntoStructuredToolResult,
237 {
238 Self {
239 inner: CustomToolInner::Structured(tool),
240 }
241 }
242
243 pub fn async_text<T, O>(tool: &'a T) -> Self
244 where
245 T: AsyncTextTool<Output = O> + Send + Sync,
246 O: IntoTextToolResult,
247 {
248 Self {
249 inner: CustomToolInner::AsyncText(tool),
250 }
251 }
252
253 pub fn async_structured<T>(tool: &'a T) -> Self
254 where
255 T: AsyncStructuredTool + Send + Sync,
256 T::Output: IntoStructuredToolResult,
257 {
258 Self {
259 inner: CustomToolInner::AsyncStructured(tool),
260 }
261 }
262
263 pub async fn call(&self) -> Result<CallToolResult, CallToolError> {
264 match self.inner {
265 CustomToolInner::Text(tool) => tool.call().await,
266 CustomToolInner::Structured(tool) => tool.call().await,
267 CustomToolInner::AsyncText(tool) => tool.call().await,
268 CustomToolInner::AsyncStructured(tool) => tool.call().await,
269 }
270 }
271}