Skip to main content

mcpkit_rs/handler/server/
tool.rs

1use std::{
2    borrow::Cow,
3    future::{Future, Ready},
4    marker::PhantomData,
5};
6
7use futures::future::{BoxFuture, FutureExt};
8use serde::de::DeserializeOwned;
9
10use super::common::{AsRequestContext, FromContextPart};
11#[cfg(feature = "schemars")]
12pub use super::common::{schema_for_output, schema_for_type};
13pub use super::{
14    common::{Extension, RequestId},
15    router::tool::{ToolRoute, ToolRouter},
16};
17use crate::{
18    RoleServer,
19    handler::server::wrapper::Parameters,
20    model::{CallToolRequestParams, CallToolResult, IntoContents, JsonObject},
21    service::RequestContext,
22};
23
24/// Deserialize a JSON object into a type
25pub fn parse_json_object<T: DeserializeOwned>(input: JsonObject) -> Result<T, crate::ErrorData> {
26    serde_json::from_value(serde_json::Value::Object(input)).map_err(|e| {
27        crate::ErrorData::invalid_params(
28            format!("failed to deserialize parameters: {error}", error = e),
29            None,
30        )
31    })
32}
33pub struct ToolCallContext<'s, S> {
34    pub request_context: RequestContext<RoleServer>,
35    pub service: &'s S,
36    pub name: Cow<'static, str>,
37    pub arguments: Option<JsonObject>,
38    pub task: Option<JsonObject>,
39}
40
41impl<'s, S> ToolCallContext<'s, S> {
42    pub fn new(
43        service: &'s S,
44        CallToolRequestParams {
45            meta: _,
46            name,
47            arguments,
48            task,
49        }: CallToolRequestParams,
50        request_context: RequestContext<RoleServer>,
51    ) -> Self {
52        Self {
53            request_context,
54            service,
55            name,
56            arguments,
57            task,
58        }
59    }
60    pub fn name(&self) -> &str {
61        &self.name
62    }
63    pub fn request_context(&self) -> &RequestContext<RoleServer> {
64        &self.request_context
65    }
66}
67
68impl<S> AsRequestContext for ToolCallContext<'_, S> {
69    fn as_request_context(&self) -> &RequestContext<RoleServer> {
70        &self.request_context
71    }
72
73    fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer> {
74        &mut self.request_context
75    }
76}
77
78pub trait IntoCallToolResult {
79    fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData>;
80}
81
82impl<T: IntoContents> IntoCallToolResult for T {
83    fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
84        Ok(CallToolResult::success(self.into_contents()))
85    }
86}
87
88impl<T: IntoContents, E: IntoContents> IntoCallToolResult for Result<T, E> {
89    fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
90        match self {
91            Ok(value) => Ok(CallToolResult::success(value.into_contents())),
92            Err(error) => Ok(CallToolResult::error(error.into_contents())),
93        }
94    }
95}
96
97impl<T: IntoCallToolResult> IntoCallToolResult for Result<T, crate::ErrorData> {
98    fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
99        match self {
100            Ok(value) => value.into_call_tool_result(),
101            Err(error) => Err(error),
102        }
103    }
104}
105
// Future resolving to `Result<CallToolResult, crate::ErrorData>`, either by
// driving an inner future and converting its output with `IntoCallToolResult`
// (`Pending`), or by yielding a result that is already available (`Ready`).
// Plain `//` comments are used because `pin_project!` restricts which
// attributes (including doc comments) it accepts on fields.
pin_project_lite::pin_project! {
    #[project = IntoCallToolResultFutProj]
    pub enum IntoCallToolResultFut<F, R> {
        // An inner future whose output `R` is converted once it completes.
        Pending {
            #[pin]
            fut: F,
            // Only records the inner future's output type `R`.
            _marker: PhantomData<R>,
        },
        // A result that needs no further work.
        Ready {
            #[pin]
            result: Ready<Result<CallToolResult, crate::ErrorData>>,
        }
    }
}
120
121impl<F, R> Future for IntoCallToolResultFut<F, R>
122where
123    F: Future<Output = R>,
124    R: IntoCallToolResult,
125{
126    type Output = Result<CallToolResult, crate::ErrorData>;
127
128    fn poll(
129        self: std::pin::Pin<&mut Self>,
130        cx: &mut std::task::Context<'_>,
131    ) -> std::task::Poll<Self::Output> {
132        match self.project() {
133            IntoCallToolResultFutProj::Pending { fut, _marker } => {
134                fut.poll(cx).map(IntoCallToolResult::into_call_tool_result)
135            }
136            IntoCallToolResultFutProj::Ready { result } => result.poll(cx),
137        }
138    }
139}
140
141impl IntoCallToolResult for Result<CallToolResult, crate::ErrorData> {
142    fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
143        self
144    }
145}
146
147pub trait CallToolHandler<S, A> {
148    fn call(
149        self,
150        context: ToolCallContext<'_, S>,
151    ) -> BoxFuture<'_, Result<CallToolResult, crate::ErrorData>>;
152}
153
154pub type DynCallToolHandler<S> = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFuture<'s, Result<CallToolResult, crate::ErrorData>>
155    + Send
156    + Sync;
157
158// Tool-specific extractor for tool name
159pub struct ToolName(pub Cow<'static, str>);
160
161impl<S> FromContextPart<ToolCallContext<'_, S>> for ToolName {
162    fn from_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::ErrorData> {
163        Ok(Self(context.name.clone()))
164    }
165}
166
167// Special implementation for Parameters that handles tool arguments
168impl<S, P> FromContextPart<ToolCallContext<'_, S>> for Parameters<P>
169where
170    P: DeserializeOwned,
171{
172    fn from_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::ErrorData> {
173        let arguments = context.arguments.take().unwrap_or_default();
174        let value: P =
175            serde_json::from_value(serde_json::Value::Object(arguments)).map_err(|e| {
176                crate::ErrorData::invalid_params(
177                    format!("failed to deserialize parameters: {error}", error = e),
178                    None,
179                )
180            })?;
181        Ok(Parameters(value))
182    }
183}
184
185// Special implementation for JsonObject that takes tool arguments
186impl<S> FromContextPart<ToolCallContext<'_, S>> for JsonObject {
187    fn from_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::ErrorData> {
188        let object = context.arguments.take().unwrap_or_default();
189        Ok(object)
190    }
191}
192
193impl<'s, S> ToolCallContext<'s, S> {
194    pub fn invoke<H, A>(self, h: H) -> BoxFuture<'s, Result<CallToolResult, crate::ErrorData>>
195    where
196        H: CallToolHandler<S, A>,
197    {
198        h.call(self)
199    }
200}
// Adapter markers. They only appear as the `A` parameter of `CallToolHandler`,
// giving each handler shape generated by `impl_for!` a distinct trait header
// so the blanket impls do not overlap. `P` is the tuple of extractor parameter
// types. Wrapping the parameters in `fn(..)` pointer types inside `PhantomData`
// keeps each marker zero-sized and `Send + Sync` whatever `P`, `Fut` and `R` are.

/// Marker for async free-function handlers: `FnOnce(P..) -> Fut` with
/// `Fut: Future<Output = R>`.
// The nested `fn` pointer type is presumably what trips `clippy::type_complexity`.
#[allow(clippy::type_complexity)]
pub struct AsyncAdapter<P, Fut, R>(PhantomData<fn(P) -> fn(Fut) -> R>);
/// Marker for synchronous free-function handlers: `FnOnce(P..) -> R`.
pub struct SyncAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Marker for async method-style handlers: `FnOnce(&S, P..) -> BoxFuture<'_, R>`.
pub struct AsyncMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Marker for synchronous method-style handlers: `FnOnce(&S, P..) -> R`.
pub struct SyncMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
207
// Generates the `CallToolHandler` blanket impls for every handler arity.
//
// Invoked as `impl_for!(T0 T1 ... Tk)`, it recursively moves one identifier at
// a time from the "remaining" list into the "accumulated" list and emits the
// `@impl` arm for every accumulated prefix. Impls therefore exist for
// 0 through k+1 extractor parameters (0..=16 for the invocation below).
macro_rules! impl_for {
    // Entry point: start with an empty accumulator.
    ($($T: ident)*) => {
        impl_for!([] [$($T)*]);
    };
    // finished: nothing remains, emit the impl for the full list.
    ([$($Tn: ident)*] []) => {
        impl_for!(@impl $($Tn)*);
    };
    // Emit the impl for the current prefix, then grow the prefix by one.
    ([$($Tn: ident)*] [$Tn_1: ident $($Rest: ident)*]) => {
        impl_for!(@impl $($Tn)*);
        impl_for!([$($Tn)* $Tn_1] [$($Rest)*]);
    };
    // One arity: four impls, one per handler shape / adapter marker.
    //
    // In every `call` below, each `$Tn` is first used as a type (to run its
    // extractor) and is then shadowed by a variable of the same name holding the
    // extracted value. Extractors run left to right; the first failure returns an
    // immediately-ready `Err` future, and neither the remaining extractors nor the
    // handler run. The `#[allow]`s cover the non-snake-case value bindings and the
    // zero-arity expansion, where `context` is never mutated (and, for the
    // free-function shapes, never used).
    (@impl $($Tn: ident)*) => {
        // Async method-style handler: `FnOnce(&S, T0, ..) -> BoxFuture<'_, R>`.
        // The returned future may borrow the service, so it only lives as long
        // as the context's borrow of `S`.
        impl<$($Tn,)* S, F,  R> CallToolHandler<S, AsyncMethodAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> ,
            )*
            F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R>,

            // Ideally the handler could return any future bounded like the line
            // below, but that future's type depends on the `&S` lifetime, which
            // presumably needs return-type notation (RTN), see
            // https://github.com/rust-lang/rust/pull/138424. Until then the
            // handler has to return a `BoxFuture` itself.
            // Fut: Future<Output = R> + Send + 'a,
            R: IntoCallToolResult + Send + 'static,
            S: Send + Sync + 'static,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<'_, S>,
            ) -> BoxFuture<'_, Result<CallToolResult, crate::ErrorData>>{
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let service = context.service;
                let fut = self(service, $($Tn,)*);
                async move {
                    let result = fut.await;
                    result.into_call_tool_result()
                }.boxed()
            }
        }

        // Async free-function handler: `FnOnce(T0, ..) -> Fut`. `Fut` must be
        // `'static`, so the returned future does not borrow the context.
        impl<$($Tn,)* S, F, Fut, R> CallToolHandler<S, AsyncAdapter<($($Tn,)*), Fut, R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> ,
            )*
            F: FnOnce($($Tn,)*) -> Fut + Send + ,
            Fut: Future<Output = R> + Send + 'static,
            R: IntoCallToolResult + Send + 'static,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>>{
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let fut = self($($Tn,)*);
                async move {
                    let result = fut.await;
                    result.into_call_tool_result()
                }.boxed()
            }
        }

        // Sync method-style handler: `FnOnce(&S, T0, ..) -> R`. The handler runs
        // eagerly inside `call`; its converted result is wrapped in a ready future.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, SyncMethodAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> + ,
            )*
            F: FnOnce(&S, $($Tn,)*) -> R + Send + ,
            R: IntoCallToolResult + Send + ,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>> {
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                std::future::ready(self(context.service, $($Tn,)*).into_call_tool_result()).boxed()
            }
        }

        // Sync free-function handler: `FnOnce(T0, ..) -> R`, also run eagerly
        // inside `call`.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, SyncAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> + ,
            )*
            F: FnOnce($($Tn,)*) -> R + Send + ,
            R: IntoCallToolResult + Send + ,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>>  {
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                std::future::ready(self($($Tn,)*).into_call_tool_result()).boxed()
            }
        }
    };
}
// Handlers with 0 to 16 extractor parameters.
impl_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);
336
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::model::NumberOrString;

    #[derive(Debug, Clone)]
    struct TestService {
        #[allow(dead_code)]
        value: String,
    }

    #[derive(Debug, Deserialize, Serialize)]
    struct TestParams {
        message: String,
        count: i32,
    }

    /// Service instance shared by the handler tests.
    fn make_service() -> TestService {
        TestService {
            value: "test".to_string(),
        }
    }

    /// Argument object `{ "message": message, "count": count }` matching `TestParams`.
    fn make_args(message: &str, count: i32) -> JsonObject {
        let mut args = JsonObject::new();
        args.insert("message".to_string(), json!(message));
        args.insert("count".to_string(), json!(count));
        args
    }

    /// Request context with a fresh peer (only the first element returned by
    /// `Peer::new` is kept), a new cancellation token and request id 1.
    fn make_request_context() -> RequestContext<RoleServer> {
        RequestContext {
            peer: crate::service::Peer::new(
                std::sync::Arc::new(crate::service::AtomicU32RequestIdProvider::default()),
                None,
            )
            .0,
            ct: CancellationToken::new(),
            id: NumberOrString::Number(1),
            meta: Default::default(),
            extensions: Default::default(),
        }
    }

    /// Tool-call context for `name` with the given arguments and no `meta`/`task`.
    fn make_context<'s>(
        service: &'s TestService,
        name: &'static str,
        arguments: Option<JsonObject>,
    ) -> ToolCallContext<'s, TestService> {
        ToolCallContext::new(
            service,
            CallToolRequestParams {
                meta: None,
                name: name.into(),
                arguments,
                task: None,
            },
            make_request_context(),
        )
    }

    #[test]
    fn test_parse_json_object_valid() {
        let params: TestParams = parse_json_object(make_args("hello", 42))
            .expect("valid arguments should deserialize");
        assert_eq!(params.message, "hello");
        assert_eq!(params.count, 42);
    }

    #[test]
    fn test_parse_json_object_invalid() {
        let mut args = JsonObject::new();
        args.insert("message".to_string(), json!("hello"));
        // Missing required field 'count'

        let err = parse_json_object::<TestParams>(args).unwrap_err();
        assert!(err.message.contains("failed to deserialize"));
    }

    #[test]
    fn test_parse_json_object_type_mismatch() {
        let mut args = JsonObject::new();
        args.insert("message".to_string(), json!("hello"));
        args.insert("count".to_string(), json!("not a number")); // Wrong type

        assert!(parse_json_object::<TestParams>(args).is_err());
    }

    #[test]
    fn test_into_call_tool_result_string() {
        let tool_result = "success".to_string().into_call_tool_result().unwrap();
        assert_eq!(tool_result.is_error, Some(false));
        assert_eq!(tool_result.content.len(), 1);
        let text = tool_result.content[0]
            .as_text()
            .expect("Expected text content");
        assert_eq!(text.text, "success");
    }

    #[test]
    fn test_into_call_tool_result_ok_variant() {
        let result: Result<String, String> = Ok("success".to_string());
        let tool_result = result.into_call_tool_result().unwrap();
        assert_eq!(tool_result.is_error, Some(false));
        assert_eq!(tool_result.content.len(), 1);
    }

    #[test]
    fn test_into_call_tool_result_err_variant() {
        let result: Result<String, String> = Err("error".to_string());
        let tool_result = result.into_call_tool_result().unwrap();
        assert_eq!(tool_result.is_error, Some(true));
        assert_eq!(tool_result.content.len(), 1);
        let text = tool_result.content[0]
            .as_text()
            .expect("Expected text content");
        assert_eq!(text.text, "error");
    }

    #[test]
    fn test_into_call_tool_result_error_data() {
        let error = crate::ErrorData::invalid_params("bad params".to_string(), None);
        let result: Result<String, crate::ErrorData> = Err(error);
        let tool_result = result.into_call_tool_result();
        assert!(tool_result.is_err());
        assert!(tool_result.unwrap_err().message.contains("bad params"));
    }

    #[tokio::test]
    async fn test_tool_name_extraction() {
        let service = make_service();
        let mut context = make_context(&service, "test_tool", None);

        let tool_name = ToolName::from_context_part(&mut context).unwrap();
        assert_eq!(tool_name.0, "test_tool");
    }

    #[tokio::test]
    async fn test_parameters_extraction() {
        let service = make_service();
        let mut context = make_context(&service, "test_tool", Some(make_args("hello", 42)));

        let params: Parameters<TestParams> = Parameters::from_context_part(&mut context).unwrap();
        assert_eq!(params.0.message, "hello");
        assert_eq!(params.0.count, 42);
        // Arguments should be consumed
        assert!(context.arguments.is_none());
    }

    #[tokio::test]
    async fn test_parameters_extraction_missing_arguments() {
        let service = make_service();
        let mut context = make_context(&service, "test_tool", None);

        // Absent arguments are treated as `{}`, which lacks TestParams' required fields.
        let result: Result<Parameters<TestParams>, _> = Parameters::from_context_part(&mut context);
        let err = result
            .err()
            .expect("missing required fields should be rejected");
        assert!(err.message.contains("failed to deserialize parameters"));
    }

    #[tokio::test]
    async fn test_json_object_extraction_empty() {
        let service = make_service();
        let mut context = make_context(&service, "test_tool", None);

        // Should use empty object when no arguments
        let json_obj: JsonObject = JsonObject::from_context_part(&mut context).unwrap();
        assert!(json_obj.is_empty());
    }

    #[tokio::test]
    async fn test_async_handler_success() {
        async fn async_tool(params: Parameters<TestParams>) -> String {
            format!("{} x {}", params.0.message, params.0.count)
        }

        let service = make_service();
        let context = make_context(&service, "async_tool", Some(make_args("hello", 3)));

        let tool_result = context
            .invoke(async_tool)
            .await
            .expect("handler should succeed");
        assert_eq!(tool_result.is_error, Some(false));
        let text = tool_result.content[0]
            .as_text()
            .expect("Expected text content");
        assert_eq!(text.text, "hello x 3");
    }

    #[tokio::test]
    async fn test_sync_handler_success() {
        fn sync_tool(params: Parameters<TestParams>) -> String {
            format!("{} x {}", params.0.message, params.0.count)
        }

        let service = make_service();
        let context = make_context(&service, "sync_tool", Some(make_args("test", 5)));

        let tool_result = context
            .invoke(sync_tool)
            .await
            .expect("handler should succeed");
        assert_eq!(tool_result.is_error, Some(false));
        let text = tool_result.content[0]
            .as_text()
            .expect("Expected text content");
        assert_eq!(text.text, "test x 5");
    }

    #[tokio::test]
    async fn test_handler_rejects_invalid_arguments() {
        fn sync_tool(params: Parameters<TestParams>) -> String {
            format!("{} x {}", params.0.message, params.0.count)
        }

        let service = make_service();
        let mut args = JsonObject::new();
        // `count` is missing, so the `Parameters` extractor fails before the
        // handler would run.
        args.insert("message".to_string(), json!("test"));
        let context = make_context(&service, "sync_tool", Some(args));

        let error = context
            .invoke(sync_tool)
            .await
            .err()
            .expect("extraction failure should be returned as an error");
        assert!(error.message.contains("failed to deserialize parameters"));
    }

    #[tokio::test]
    async fn test_handler_with_result_error() {
        async fn failing_tool(_params: Parameters<TestParams>) -> Result<String, String> {
            Err("Tool execution failed".to_string())
        }

        let service = make_service();
        let context = make_context(&service, "failing_tool", Some(make_args("test", 1)));

        let tool_result = context
            .invoke(failing_tool)
            .await
            .expect("tool-level errors are reported inside Ok");
        assert_eq!(tool_result.is_error, Some(true));
        let text = tool_result.content[0]
            .as_text()
            .expect("Expected text content");
        assert_eq!(text.text, "Tool execution failed");
    }

    #[tokio::test]
    async fn test_handler_with_json_string_output() {
        async fn json_tool(params: Parameters<TestParams>) -> String {
            let result = json!({
                "message": params.0.message,
                "count": params.0.count,
                "computed": params.0.count * 2
            });
            result.to_string()
        }

        let service = make_service();
        let context = make_context(&service, "json_tool", Some(make_args("hello", 10)));

        let tool_result = context
            .invoke(json_tool)
            .await
            .expect("handler should succeed");
        assert_eq!(tool_result.is_error, Some(false));
        let text = tool_result.content[0]
            .as_text()
            .expect("Expected text content");
        let parsed: serde_json::Value = serde_json::from_str(&text.text).unwrap();
        assert_eq!(parsed["message"], "hello");
        assert_eq!(parsed["count"], 10);
        assert_eq!(parsed["computed"], 20);
    }
}