Skip to main content

rmcp/handler/server/
tool.rs

1use std::{
2    borrow::Cow,
3    future::{Future, Ready},
4    marker::PhantomData,
5};
6
7use futures::future::{BoxFuture, FutureExt};
8use serde::de::DeserializeOwned;
9
10use super::common::{AsRequestContext, FromContextPart};
11pub use super::{
12    common::{Extension, RequestId, schema_for_output, schema_for_type},
13    router::tool::{ToolRoute, ToolRouter},
14};
15use crate::{
16    RoleServer,
17    handler::server::wrapper::Parameters,
18    model::{CallToolRequestParams, CallToolResult, IntoContents, JsonObject},
19    service::RequestContext,
20};
21
22/// Deserialize a JSON object into a type
23pub fn parse_json_object<T: DeserializeOwned>(input: JsonObject) -> Result<T, crate::ErrorData> {
24    serde_json::from_value(serde_json::Value::Object(input)).map_err(|e| {
25        crate::ErrorData::invalid_params(
26            format!("failed to deserialize parameters: {error}", error = e),
27            None,
28        )
29    })
30}
31pub struct ToolCallContext<'s, S> {
32    pub request_context: RequestContext<RoleServer>,
33    pub service: &'s S,
34    pub name: Cow<'static, str>,
35    pub arguments: Option<JsonObject>,
36    pub task: Option<JsonObject>,
37}
38
39impl<'s, S> ToolCallContext<'s, S> {
40    pub fn new(
41        service: &'s S,
42        CallToolRequestParams {
43            meta: _,
44            name,
45            arguments,
46            task,
47        }: CallToolRequestParams,
48        request_context: RequestContext<RoleServer>,
49    ) -> Self {
50        Self {
51            request_context,
52            service,
53            name,
54            arguments,
55            task,
56        }
57    }
58    pub fn name(&self) -> &str {
59        &self.name
60    }
61    pub fn request_context(&self) -> &RequestContext<RoleServer> {
62        &self.request_context
63    }
64}
65
66impl<S> AsRequestContext for ToolCallContext<'_, S> {
67    fn as_request_context(&self) -> &RequestContext<RoleServer> {
68        &self.request_context
69    }
70
71    fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer> {
72        &mut self.request_context
73    }
74}
75
76pub trait IntoCallToolResult {
77    fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData>;
78}
79
80impl<T: IntoContents> IntoCallToolResult for T {
81    fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
82        Ok(CallToolResult::success(self.into_contents()))
83    }
84}
85
86impl<T: IntoContents, E: IntoContents> IntoCallToolResult for Result<T, E> {
87    fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
88        match self {
89            Ok(value) => Ok(CallToolResult::success(value.into_contents())),
90            Err(error) => Ok(CallToolResult::error(error.into_contents())),
91        }
92    }
93}
94
95impl<T: IntoCallToolResult> IntoCallToolResult for Result<T, crate::ErrorData> {
96    fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
97        match self {
98            Ok(value) => value.into_call_tool_result(),
99            Err(error) => Err(error),
100        }
101    }
102}
103
// Future that yields `Result<CallToolResult, ErrorData>`.
//
// - `Pending` wraps a future `fut` whose output `R` is converted with
//   `IntoCallToolResult` once it completes (see the `Future` impl below).
// - `Ready` holds an already computed result.
//
// `pin_project!` generates the `IntoCallToolResultFutProj` projection enum used
// by `poll`; `#[pin]` marks the fields that are projected as pinned.
pin_project_lite::pin_project! {
    #[project = IntoCallToolResultFutProj]
    pub enum IntoCallToolResultFut<F, R> {
        Pending {
            #[pin]
            fut: F,
            // Records the output type `R` without storing a value of it.
            _marker: PhantomData<R>,
        },
        Ready {
            #[pin]
            result: Ready<Result<CallToolResult, crate::ErrorData>>,
        }
    }
}
118
119impl<F, R> Future for IntoCallToolResultFut<F, R>
120where
121    F: Future<Output = R>,
122    R: IntoCallToolResult,
123{
124    type Output = Result<CallToolResult, crate::ErrorData>;
125
126    fn poll(
127        self: std::pin::Pin<&mut Self>,
128        cx: &mut std::task::Context<'_>,
129    ) -> std::task::Poll<Self::Output> {
130        match self.project() {
131            IntoCallToolResultFutProj::Pending { fut, _marker } => {
132                fut.poll(cx).map(IntoCallToolResult::into_call_tool_result)
133            }
134            IntoCallToolResultFutProj::Ready { result } => result.poll(cx),
135        }
136    }
137}
138
139impl IntoCallToolResult for Result<CallToolResult, crate::ErrorData> {
140    fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
141        self
142    }
143}
144
145pub trait CallToolHandler<S, A> {
146    fn call(
147        self,
148        context: ToolCallContext<'_, S>,
149    ) -> BoxFuture<'_, Result<CallToolResult, crate::ErrorData>>;
150}
151
152pub type DynCallToolHandler<S> = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFuture<'s, Result<CallToolResult, crate::ErrorData>>
153    + Send
154    + Sync;
155
156// Tool-specific extractor for tool name
157pub struct ToolName(pub Cow<'static, str>);
158
159impl<S> FromContextPart<ToolCallContext<'_, S>> for ToolName {
160    fn from_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::ErrorData> {
161        Ok(Self(context.name.clone()))
162    }
163}
164
165// Special implementation for Parameters that handles tool arguments
166impl<S, P> FromContextPart<ToolCallContext<'_, S>> for Parameters<P>
167where
168    P: DeserializeOwned,
169{
170    fn from_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::ErrorData> {
171        let arguments = context.arguments.take().unwrap_or_default();
172        let value: P =
173            serde_json::from_value(serde_json::Value::Object(arguments)).map_err(|e| {
174                crate::ErrorData::invalid_params(
175                    format!("failed to deserialize parameters: {error}", error = e),
176                    None,
177                )
178            })?;
179        Ok(Parameters(value))
180    }
181}
182
183// Special implementation for JsonObject that takes tool arguments
184impl<S> FromContextPart<ToolCallContext<'_, S>> for JsonObject {
185    fn from_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::ErrorData> {
186        let object = context.arguments.take().unwrap_or_default();
187        Ok(object)
188    }
189}
190
191impl<'s, S> ToolCallContext<'s, S> {
192    pub fn invoke<H, A>(self, h: H) -> BoxFuture<'s, Result<CallToolResult, crate::ErrorData>>
193    where
194        H: CallToolHandler<S, A>,
195    {
196        h.call(self)
197    }
198}
/// Marker selecting the [`CallToolHandler`] impl for `FnOnce(T0, T1, ..) -> Fut`
/// handlers, where `Fut: Future<Output = R>` and `P` is the tuple of extractor types.
// Clippy flags the nested fn-pointer type; it exists only to carry the type
// parameters inside `PhantomData`, so the lint is silenced here.
#[allow(clippy::type_complexity)]
pub struct AsyncAdapter<P, Fut, R>(PhantomData<fn(P) -> fn(Fut) -> R>);

/// Marker selecting the [`CallToolHandler`] impl for synchronous
/// `FnOnce(T0, T1, ..) -> R` handlers; `P` is the tuple of extractor types.
pub struct SyncAdapter<P, R>(PhantomData<fn(P) -> R>);

/// Marker selecting the [`CallToolHandler`] impl for `FnOnce(&S, T0, ..) -> BoxFuture<'_, R>`
/// handlers that borrow the service; `P` is the tuple of extractor types.
pub struct AsyncMethodAdapter<P, R>(PhantomData<fn(P) -> R>);

/// Marker selecting the [`CallToolHandler`] impl for synchronous
/// `FnOnce(&S, T0, ..) -> R` handlers; `P` is the tuple of extractor types.
pub struct SyncMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
205
// Generates `CallToolHandler` impls for handlers taking 0..=N extractor
// arguments, in four shapes selected by the adapter marker types:
//   - `AsyncMethodAdapter`: `FnOnce(&S, T0, ..) -> BoxFuture<'_, R>`
//   - `AsyncAdapter`:       `FnOnce(T0, ..) -> Fut`, `Fut: Future<Output = R>`
//   - `SyncMethodAdapter`:  `FnOnce(&S, T0, ..) -> R`
//   - `SyncAdapter`:        `FnOnce(T0, ..) -> R`
//
// The recursive rules keep a growing prefix list `[T0 .. Tk]` and emit
// `@impl` for every prefix. Invoked with `T0 .. T15`, this yields impls for
// arities 0 through 16.
macro_rules! impl_for {
    // Entry point: start with an empty prefix and all identifiers pending.
    ($($T: ident)*) => {
        impl_for!([] [$($T)*]);
    };
    // finished
    // (all identifiers are in the prefix: emit the full-arity impl and stop)
    ([$($Tn: ident)*] []) => {
        impl_for!(@impl $($Tn)*);
    };
    // Emit the impl for the current prefix, then move one more identifier into it.
    ([$($Tn: ident)*] [$Tn_1: ident $($Rest: ident)*]) => {
        impl_for!(@impl $($Tn)*);
        impl_for!([$($Tn)* $Tn_1] [$($Rest)*]);
    };
    // In every impl below, the extractors run in argument order on the shared
    // context. The first extractor error short-circuits into an
    // already-completed `Err` future, and the handler is never called.
    (@impl $($Tn: ident)*) => {
        // Async method-style handler: the returned future borrows the service, so
        // the boxed future is tied to the context's lifetime rather than `'static`.
        impl<$($Tn,)* S, F,  R> CallToolHandler<S, AsyncMethodAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> ,
            )*
            F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R>,

            // Need RTN support here(I guess), https://github.com/rust-lang/rust/pull/138424
            // Fut: Future<Output = R> + Send + 'a,
            // (the handler's future type therefore has to be spelled as `BoxFuture` above)
            R: IntoCallToolResult + Send + 'static,
            S: Send + Sync + 'static,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<'_, S>,
            ) -> BoxFuture<'_, Result<CallToolResult, crate::ErrorData>>{
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let service = context.service;
                let fut = self(service, $($Tn,)*);
                async move {
                    let result = fut.await;
                    result.into_call_tool_result()
                }.boxed()
            }
        }

        // Async free-function handler: `Fut: 'static`, so the boxed future is `'static`.
        impl<$($Tn,)* S, F, Fut, R> CallToolHandler<S, AsyncAdapter<($($Tn,)*), Fut, R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> ,
            )*
            F: FnOnce($($Tn,)*) -> Fut + Send + ,
            Fut: Future<Output = R> + Send + 'static,
            R: IntoCallToolResult + Send + 'static,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>>{
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let fut = self($($Tn,)*);
                async move {
                    let result = fut.await;
                    result.into_call_tool_result()
                }.boxed()
            }
        }

        // Sync method-style handler: runs synchronously inside `call`; the
        // returned future is already complete.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, SyncMethodAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> + ,
            )*
            F: FnOnce(&S, $($Tn,)*) -> R + Send + ,
            R: IntoCallToolResult + Send + ,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>> {
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                std::future::ready(self(context.service, $($Tn,)*).into_call_tool_result()).boxed()
            }
        }

        // Sync free-function handler: runs synchronously inside `call`; the
        // returned future is already complete.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, SyncAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> + ,
            )*
            F: FnOnce($($Tn,)*) -> R + Send + ,
            R: IntoCallToolResult + Send + ,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>>  {
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                std::future::ready(self($($Tn,)*).into_call_tool_result()).boxed()
            }
        }
    };
}
// Generates handler impls for 0 through 16 extractor arguments.
impl_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);