Skip to main content

agenterra_rmcp/handler/server/
tool.rs

1use std::{
2    any::TypeId, borrow::Cow, collections::HashMap, future::Ready, marker::PhantomData, sync::Arc,
3};
4
5use futures::future::{BoxFuture, FutureExt};
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize, de::DeserializeOwned};
8use tokio_util::sync::CancellationToken;
9
10pub use super::router::tool::{ToolRoute, ToolRouter};
11use crate::{
12    RoleServer,
13    model::{CallToolRequestParam, CallToolResult, IntoContents, JsonObject},
14    service::RequestContext,
15};
16
17/// A shortcut for generating a JSON schema for a type.
18pub fn schema_for_type<T: JsonSchema>() -> JsonObject {
19    // explicitly to align json schema version to official specifications.
20    // https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-03-26/schema.json
21    let mut settings = schemars::r#gen::SchemaSettings::draft07();
22    settings.option_nullable = true;
23    settings.option_add_null_type = false;
24    settings.visitors = Vec::default();
25    let generator = settings.into_generator();
26    let schema = generator.into_root_schema_for::<T>();
27    let object = serde_json::to_value(schema).expect("failed to serialize schema");
28    match object {
29        serde_json::Value::Object(object) => object,
30        _ => panic!("unexpected schema value"),
31    }
32}
33
34/// Call [`schema_for_type`] with a cache
35pub fn cached_schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
36    thread_local! {
37        static CACHE_FOR_TYPE: std::sync::RwLock<HashMap<TypeId, Arc<JsonObject>>> = Default::default();
38    };
39    CACHE_FOR_TYPE.with(|cache| {
40        if let Some(x) = cache
41            .read()
42            .expect("schema cache lock poisoned")
43            .get(&TypeId::of::<T>())
44        {
45            x.clone()
46        } else {
47            let schema = schema_for_type::<T>();
48            let schema = Arc::new(schema);
49            cache
50                .write()
51                .expect("schema cache lock poisoned")
52                .insert(TypeId::of::<T>(), schema.clone());
53            schema
54        }
55    })
56}
57
58/// Deserialize a JSON object into a type
59pub fn parse_json_object<T: DeserializeOwned>(input: JsonObject) -> Result<T, crate::Error> {
60    serde_json::from_value(serde_json::Value::Object(input)).map_err(|e| {
61        crate::Error::invalid_params(
62            format!("failed to deserialize parameters: {error}", error = e),
63            None,
64        )
65    })
66}
67pub struct ToolCallContext<'s, S> {
68    pub request_context: RequestContext<RoleServer>,
69    pub service: &'s S,
70    pub name: Cow<'static, str>,
71    pub arguments: Option<JsonObject>,
72}
73
74impl<'s, S> ToolCallContext<'s, S> {
75    pub fn new(
76        service: &'s S,
77        CallToolRequestParam { name, arguments }: CallToolRequestParam,
78        request_context: RequestContext<RoleServer>,
79    ) -> Self {
80        Self {
81            request_context,
82            service,
83            name,
84            arguments,
85        }
86    }
87    pub fn name(&self) -> &str {
88        &self.name
89    }
90    pub fn request_context(&self) -> &RequestContext<RoleServer> {
91        &self.request_context
92    }
93}
94
95pub trait FromToolCallContextPart<S>: Sized {
96    fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error>;
97}
98
99pub trait IntoCallToolResult {
100    fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error>;
101}
102impl IntoCallToolResult for () {
103    fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error> {
104        Ok(CallToolResult::success(vec![]))
105    }
106}
107
108impl<T: IntoContents> IntoCallToolResult for T {
109    fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error> {
110        Ok(CallToolResult::success(self.into_contents()))
111    }
112}
113
114impl<T: IntoContents, E: IntoContents> IntoCallToolResult for Result<T, E> {
115    fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error> {
116        match self {
117            Ok(value) => Ok(CallToolResult::success(value.into_contents())),
118            Err(error) => Ok(CallToolResult::error(error.into_contents())),
119        }
120    }
121}
122
// Future that yields a `Result<CallToolResult, crate::Error>`.
// `Pending` drives a wrapped future `F` and converts its `R` output through
// `IntoCallToolResult`; `Ready` yields a result that is already known.
// NOTE(review): nothing in this module constructs it — presumably used elsewhere in the crate.
pin_project_lite::pin_project! {
    #[project = IntoCallToolResultFutProj]
    pub enum IntoCallToolResultFut<F, R> {
        // Wraps the handler future; `_marker` records the output type `R` being converted.
        Pending {
            #[pin]
            fut: F,
            _marker: PhantomData<R>,
        },
        // Holds an already-computed result.
        Ready {
            #[pin]
            result: Ready<Result<CallToolResult, crate::Error>>,
        }
    }
}
137
138impl<F, R> Future for IntoCallToolResultFut<F, R>
139where
140    F: Future<Output = R>,
141    R: IntoCallToolResult,
142{
143    type Output = Result<CallToolResult, crate::Error>;
144
145    fn poll(
146        self: std::pin::Pin<&mut Self>,
147        cx: &mut std::task::Context<'_>,
148    ) -> std::task::Poll<Self::Output> {
149        match self.project() {
150            IntoCallToolResultFutProj::Pending { fut, _marker } => {
151                fut.poll(cx).map(IntoCallToolResult::into_call_tool_result)
152            }
153            IntoCallToolResultFutProj::Ready { result } => result.poll(cx),
154        }
155    }
156}
157
158impl IntoCallToolResult for Result<CallToolResult, crate::Error> {
159    fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error> {
160        self
161    }
162}
163
164pub trait CallToolHandler<S, A> {
165    fn call(
166        self,
167        context: ToolCallContext<'_, S>,
168    ) -> BoxFuture<'_, Result<CallToolResult, crate::Error>>;
169}
170
171pub type DynCallToolHandler<S> = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFuture<'s, Result<CallToolResult, crate::Error>>
172    + Send
173    + Sync;
174
175/// Parameter Extractor
176///
177#[derive(Debug, Clone, Serialize, Deserialize)]
178#[serde(transparent)]
179pub struct Parameters<P>(pub P);
180
181impl<P: JsonSchema> JsonSchema for Parameters<P> {
182    fn schema_name() -> String {
183        P::schema_name()
184    }
185
186    fn json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
187        P::json_schema(generator)
188    }
189}
190
191impl<S> FromToolCallContextPart<S> for CancellationToken {
192    fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
193        Ok(context.request_context.ct.clone())
194    }
195}
196
197pub struct ToolName(pub Cow<'static, str>);
198
199impl<S> FromToolCallContextPart<S> for ToolName {
200    fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
201        Ok(Self(context.name.clone()))
202    }
203}
204
205impl<S, P> FromToolCallContextPart<S> for Parameters<P>
206where
207    P: DeserializeOwned,
208{
209    fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
210        let arguments = context.arguments.take().unwrap_or_default();
211        let value: P =
212            serde_json::from_value(serde_json::Value::Object(arguments)).map_err(|e| {
213                crate::Error::invalid_params(
214                    format!("failed to deserialize parameters: {error}", error = e),
215                    None,
216                )
217            })?;
218        Ok(Parameters(value))
219    }
220}
221
222impl<S> FromToolCallContextPart<S> for JsonObject {
223    fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
224        let object = context.arguments.take().unwrap_or_default();
225        Ok(object)
226    }
227}
228
229impl<S> FromToolCallContextPart<S> for crate::model::Extensions {
230    fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
231        let extensions = context.request_context.extensions.clone();
232        Ok(extensions)
233    }
234}
235
236pub struct Extension<T>(pub T);
237
238impl<S, T> FromToolCallContextPart<S> for Extension<T>
239where
240    T: Send + Sync + 'static + Clone,
241{
242    fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
243        let extension = context
244            .request_context
245            .extensions
246            .get::<T>()
247            .cloned()
248            .ok_or_else(|| {
249                crate::Error::invalid_params(
250                    format!("missing extension {}", std::any::type_name::<T>()),
251                    None,
252                )
253            })?;
254        Ok(Extension(extension))
255    }
256}
257
258impl<S> FromToolCallContextPart<S> for crate::Peer<RoleServer> {
259    fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
260        let peer = context.request_context.peer.clone();
261        Ok(peer)
262    }
263}
264
265impl<S> FromToolCallContextPart<S> for crate::model::Meta {
266    fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
267        let mut meta = crate::model::Meta::default();
268        std::mem::swap(&mut meta, &mut context.request_context.meta);
269        Ok(meta)
270    }
271}
272
273pub struct RequestId(pub crate::model::RequestId);
274impl<S> FromToolCallContextPart<S> for RequestId {
275    fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
276        Ok(RequestId(context.request_context.id.clone()))
277    }
278}
279
280impl<S> FromToolCallContextPart<S> for RequestContext<RoleServer> {
281    fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
282        Ok(context.request_context.clone())
283    }
284}
285
286impl<'s, S> ToolCallContext<'s, S> {
287    pub fn invoke<H, A>(self, h: H) -> BoxFuture<'s, Result<CallToolResult, crate::Error>>
288    where
289        H: CallToolHandler<S, A>,
290    {
291        h.call(self)
292    }
293}
/// Adapter marker for async handlers `F(extractors...) -> Fut` with a `'static` future.
///
/// Like the other adapter markers, it is used only as the `A` parameter of
/// [`CallToolHandler`] to select an impl and carries no data. The `fn(..)` types inside
/// `PhantomData` keep the marker `Send + Sync` whatever `P`, `Fut` and `R` are.
// Silences `clippy::type_complexity` for the nested fn-pointer tag type.
#[allow(clippy::type_complexity)]
pub struct AsyncAdapter<P, Fut, R>(PhantomData<fn(P) -> fn(Fut) -> R>);
/// Adapter marker for synchronous handlers `F(extractors...) -> R`.
pub struct SyncAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Adapter marker for async handlers taking the service first:
/// `F(&S, extractors...) -> BoxFuture<'_, R>`.
pub struct AsyncMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Adapter marker for synchronous handlers taking the service first: `F(&S, extractors...) -> R`.
pub struct SyncMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
300
// Generates the four `CallToolHandler` impls (one per adapter marker) for every handler
// arity from 0 up to the number of identifiers passed in.
//
// The first three arms form a tail-recursive accumulator. The first bracket holds the
// identifiers consumed so far and the second the ones still to go. Every step emits
// `@impl` for the current prefix, so `impl_for!(A B)` produces impls for `()`, `(A)` and
// `(A B)`.
macro_rules! impl_for {
    // Entry point: start with an empty accumulator.
    ($($T: ident)*) => {
        impl_for!([] [$($T)*]);
    };
    // finished: nothing left to move, so emit the impls for the full list.
    ([$($Tn: ident)*] []) => {
        impl_for!(@impl $($Tn)*);
    };
    // Emit impls for the current prefix, then move one identifier into the accumulator.
    ([$($Tn: ident)*] [$Tn_1: ident $($Rest: ident)*]) => {
        impl_for!(@impl $($Tn)*);
        impl_for!([$($Tn)* $Tn_1] [$($Rest)*]);
    };
    // Emit the impls for one arity. Each `$Tn` is used both as a type parameter (an
    // extractor implementing `FromToolCallContextPart<S>`) and as the local variable
    // holding the extracted value. Extractors run in declaration order, each seeing the
    // context as left by the previous ones. The first extraction error is returned as a
    // ready future and the handler is never called.
    (@impl $($Tn: ident)*) => {
        // Async "method" handler: `F(&S, extractors...) -> BoxFuture<'_, R>`. The future may
        // borrow the service, so the returned future is tied to the context's lifetime.
        impl<$($Tn,)* S, F,  R> CallToolHandler<S, AsyncMethodAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: FromToolCallContextPart<S> ,
            )*
            F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R>,

            // Ideally this would accept any `Fut: Future<Output = R> + Send + 'a` tied to
            // the `&'a S` borrow. Naming that future type appears to need return-type
            // notation (https://github.com/rust-lang/rust/pull/138424), so handlers
            // return a `BoxFuture` instead.
            R: IntoCallToolResult + Send + 'static,
            S: Send + Sync + 'static,
        {
            // `context` is unused (and needlessly `mut`) at arity 0; `$Tn` locals are upper-case.
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<'_, S>,
            ) -> BoxFuture<'_, Result<CallToolResult, crate::Error>>{
                $(
                    let result = $Tn::from_tool_call_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let service = context.service;
                let fut = self(service, $($Tn,)*);
                async move {
                    let result = fut.await;
                    result.into_call_tool_result()
                }.boxed()
            }
        }

        // Async free-function/closure handler: `F(extractors...) -> Fut`. `Fut: 'static`
        // means the future cannot borrow the context, so the result is `BoxFuture<'static, _>`.
        // The handler is called inside `call`; its future is awaited when polled.
        impl<$($Tn,)* S, F, Fut, R> CallToolHandler<S, AsyncAdapter<($($Tn,)*), Fut, R>> for F
        where
            $(
                $Tn: FromToolCallContextPart<S> ,
            )*
            F: FnOnce($($Tn,)*) -> Fut + Send + ,
            Fut: Future<Output = R> + Send + 'static,
            R: IntoCallToolResult + Send + 'static,
            S: Send + Sync,
        {
            // `context` is unused (and needlessly `mut`) at arity 0; `$Tn` locals are upper-case.
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::Error>>{
                $(
                    let result = $Tn::from_tool_call_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let fut = self($($Tn,)*);
                async move {
                    let result = fut.await;
                    result.into_call_tool_result()
                }.boxed()
            }
        }

        // Sync "method" handler: `F(&S, extractors...) -> R`. The handler runs eagerly inside
        // `call`, before the returned future is polled; its converted result is wrapped in
        // a ready future.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, SyncMethodAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: FromToolCallContextPart<S> + ,
            )*
            F: FnOnce(&S, $($Tn,)*) -> R + Send + ,
            R: IntoCallToolResult + Send + ,
            S: Send + Sync,
        {
            // `context` is unused (and needlessly `mut`) at arity 0; `$Tn` locals are upper-case.
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::Error>> {
                $(
                    let result = $Tn::from_tool_call_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                std::future::ready(self(context.service, $($Tn,)*).into_call_tool_result()).boxed()
            }
        }

        // Sync free-function/closure handler: `F(extractors...) -> R`, also run eagerly
        // inside `call` and returned as a ready future.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, SyncAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: FromToolCallContextPart<S> + ,
            )*
            F: FnOnce($($Tn,)*) -> R + Send + ,
            R: IntoCallToolResult + Send + ,
            S: Send + Sync,
        {
            // `context` is unused (and needlessly `mut`) at arity 0; `$Tn` locals are upper-case.
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::Error>>  {
                $(
                    let result = $Tn::from_tool_call_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                std::future::ready(self($($Tn,)*).into_call_tool_result()).boxed()
            }
        }
    };
}
// 16 identifiers → impls for handlers taking 0 through 16 extractor arguments.
impl_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);