1use std::{
2 borrow::Cow,
3 future::{Future, Ready},
4 marker::PhantomData,
5};
6
7use futures::future::{BoxFuture, FutureExt};
8use serde::de::DeserializeOwned;
9
10use super::common::{AsRequestContext, FromContextPart};
11#[cfg(feature = "schemars")]
12pub use super::common::{schema_for_output, schema_for_type};
13pub use super::{
14 common::{Extension, RequestId},
15 router::tool::{ToolRoute, ToolRouter},
16};
17use crate::{
18 RoleServer,
19 handler::server::wrapper::Parameters,
20 model::{CallToolRequestParams, CallToolResult, IntoContents, JsonObject},
21 service::RequestContext,
22};
23
24pub fn parse_json_object<T: DeserializeOwned>(input: JsonObject) -> Result<T, crate::ErrorData> {
26 serde_json::from_value(serde_json::Value::Object(input)).map_err(|e| {
27 crate::ErrorData::invalid_params(
28 format!("failed to deserialize parameters: {error}", error = e),
29 None,
30 )
31 })
32}
33pub struct ToolCallContext<'s, S> {
34 pub request_context: RequestContext<RoleServer>,
35 pub service: &'s S,
36 pub name: Cow<'static, str>,
37 pub arguments: Option<JsonObject>,
38 pub task: Option<JsonObject>,
39}
40
41impl<'s, S> ToolCallContext<'s, S> {
42 pub fn new(
43 service: &'s S,
44 CallToolRequestParams {
45 meta: _,
46 name,
47 arguments,
48 task,
49 }: CallToolRequestParams,
50 request_context: RequestContext<RoleServer>,
51 ) -> Self {
52 Self {
53 request_context,
54 service,
55 name,
56 arguments,
57 task,
58 }
59 }
60 pub fn name(&self) -> &str {
61 &self.name
62 }
63 pub fn request_context(&self) -> &RequestContext<RoleServer> {
64 &self.request_context
65 }
66}
67
68impl<S> AsRequestContext for ToolCallContext<'_, S> {
69 fn as_request_context(&self) -> &RequestContext<RoleServer> {
70 &self.request_context
71 }
72
73 fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer> {
74 &mut self.request_context
75 }
76}
77
78pub trait IntoCallToolResult {
79 fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData>;
80}
81
82impl<T: IntoContents> IntoCallToolResult for T {
83 fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
84 Ok(CallToolResult::success(self.into_contents()))
85 }
86}
87
88impl<T: IntoContents, E: IntoContents> IntoCallToolResult for Result<T, E> {
89 fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
90 match self {
91 Ok(value) => Ok(CallToolResult::success(value.into_contents())),
92 Err(error) => Ok(CallToolResult::error(error.into_contents())),
93 }
94 }
95}
96
97impl<T: IntoCallToolResult> IntoCallToolResult for Result<T, crate::ErrorData> {
98 fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
99 match self {
100 Ok(value) => value.into_call_tool_result(),
101 Err(error) => Err(error),
102 }
103 }
104}
105
106pin_project_lite::pin_project! {
107 #[project = IntoCallToolResultFutProj]
108 pub enum IntoCallToolResultFut<F, R> {
109 Pending {
110 #[pin]
111 fut: F,
112 _marker: PhantomData<R>,
113 },
114 Ready {
115 #[pin]
116 result: Ready<Result<CallToolResult, crate::ErrorData>>,
117 }
118 }
119}
120
121impl<F, R> Future for IntoCallToolResultFut<F, R>
122where
123 F: Future<Output = R>,
124 R: IntoCallToolResult,
125{
126 type Output = Result<CallToolResult, crate::ErrorData>;
127
128 fn poll(
129 self: std::pin::Pin<&mut Self>,
130 cx: &mut std::task::Context<'_>,
131 ) -> std::task::Poll<Self::Output> {
132 match self.project() {
133 IntoCallToolResultFutProj::Pending { fut, _marker } => {
134 fut.poll(cx).map(IntoCallToolResult::into_call_tool_result)
135 }
136 IntoCallToolResultFutProj::Ready { result } => result.poll(cx),
137 }
138 }
139}
140
141impl IntoCallToolResult for Result<CallToolResult, crate::ErrorData> {
142 fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
143 self
144 }
145}
146
/// Something that can service a tool call given a [`ToolCallContext`].
///
/// `A` is a marker type (one of the `*Adapter` structs in this module). It lets
/// several blanket implementations for function types coexist without
/// overlapping; callers normally leave it to inference.
pub trait CallToolHandler<S, A> {
    /// Consumes the handler and the context and returns the boxed future
    /// producing the call result. The future's lifetime is tied to the
    /// context's lifetime, so it may borrow the context's service.
    fn call(
        self,
        context: ToolCallContext<'_, S>,
    ) -> BoxFuture<'_, Result<CallToolResult, crate::ErrorData>>;
}

/// Type-erased tool handler: a `Send + Sync` callable that accepts a context of
/// any lifetime `'s` and returns a future bounded by that same `'s`.
pub type DynCallToolHandler<S> = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFuture<'s, Result<CallToolResult, crate::ErrorData>>
    + Send
    + Sync;
157
158pub struct ToolName(pub Cow<'static, str>);
160
161impl<S> FromContextPart<ToolCallContext<'_, S>> for ToolName {
162 fn from_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::ErrorData> {
163 Ok(Self(context.name.clone()))
164 }
165}
166
167impl<S, P> FromContextPart<ToolCallContext<'_, S>> for Parameters<P>
169where
170 P: DeserializeOwned,
171{
172 fn from_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::ErrorData> {
173 let arguments = context.arguments.take().unwrap_or_default();
174 let value: P =
175 serde_json::from_value(serde_json::Value::Object(arguments)).map_err(|e| {
176 crate::ErrorData::invalid_params(
177 format!("failed to deserialize parameters: {error}", error = e),
178 None,
179 )
180 })?;
181 Ok(Parameters(value))
182 }
183}
184
185impl<S> FromContextPart<ToolCallContext<'_, S>> for JsonObject {
187 fn from_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::ErrorData> {
188 let object = context.arguments.take().unwrap_or_default();
189 Ok(object)
190 }
191}
192
193impl<'s, S> ToolCallContext<'s, S> {
194 pub fn invoke<H, A>(self, h: H) -> BoxFuture<'s, Result<CallToolResult, crate::ErrorData>>
195 where
196 H: CallToolHandler<S, A>,
197 {
198 h.call(self)
199 }
200}
// Marker types used as the `A` parameter of `CallToolHandler`, so that the four
// blanket impls generated by `impl_for!` below do not overlap. `P` is the tuple
// of extractor types. They are only used as type-level tags in this file.
// Using `fn(..)` pointer types inside `PhantomData` keeps the markers `Send` and
// `Sync` whatever `P`/`R` are, because function pointers always are.

/// Marker for handlers `FnOnce(T0, T1, ..) -> Fut` where `Fut: Future<Output = R>`.
// The nested fn-pointer type inside `PhantomData` trips clippy's complexity lint.
#[allow(clippy::type_complexity)]
pub struct AsyncAdapter<P, Fut, R>(PhantomData<fn(P) -> fn(Fut) -> R>);
/// Marker for synchronous handlers `FnOnce(T0, T1, ..) -> R`.
pub struct SyncAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Marker for service methods `FnOnce(&S, T0, ..) -> BoxFuture<'_, R>`.
pub struct AsyncMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Marker for synchronous service methods `FnOnce(&S, T0, ..) -> R`.
pub struct SyncMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
207
// Generates `CallToolHandler` impls for handlers taking 0 through 16 extractor
// arguments. The entry rule starts a token muncher that emits `@impl` for every
// prefix of `T0 .. T15`. That covers the empty prefix through the full list of
// 16, and each `@impl` expands to four impls (one per adapter marker).
macro_rules! impl_for {
    // Entry point: begin with an empty prefix and the full list remaining.
    ($($T: ident)*) => {
        impl_for!([] [$($T)*]);
    };
    // List exhausted: emit impls for the complete prefix.
    ([$($Tn: ident)*] []) => {
        impl_for!(@impl $($Tn)*);
    };
    // Emit impls for the current prefix, then move one identifier into it.
    ([$($Tn: ident)*] [$Tn_1: ident $($Rest: ident)*]) => {
        impl_for!(@impl $($Tn)*);
        impl_for!([$($Tn)* $Tn_1] [$($Rest)*]);
    };
    // The four impls for one arity. In every impl the extractors run in
    // parameter order against the same `&mut context`, so an earlier extractor
    // that takes the arguments leaves none for a later one. The first extractor
    // error short-circuits into an immediately-ready `Err` future. Each extracted
    // value is bound to a local named after its type parameter, hence the
    // `non_snake_case` allowance.
    (@impl $($Tn: ident)*) => {
        // Async service method: `fn(&S, T..) -> BoxFuture<'_, R>`. The returned
        // future may borrow the service for the context's lifetime.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, AsyncMethodAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> ,
            )*
            F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R>,

            R: IntoCallToolResult + Send + 'static,
            S: Send + Sync + 'static,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<'_, S>,
            ) -> BoxFuture<'_, Result<CallToolResult, crate::ErrorData>>{
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let service = context.service;
                let fut = self(service, $($Tn,)*);
                async move {
                    let result = fut.await;
                    result.into_call_tool_result()
                }.boxed()
            }
        }

        // Free async function or closure: `fn(T..) -> Fut`. `Fut: 'static`, so
        // the handler's future cannot borrow the service or the context.
        impl<$($Tn,)* S, F, Fut, R> CallToolHandler<S, AsyncAdapter<($($Tn,)*), Fut, R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> ,
            )*
            F: FnOnce($($Tn,)*) -> Fut + Send + ,
            Fut: Future<Output = R> + Send + 'static,
            R: IntoCallToolResult + Send + 'static,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>>{
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let fut = self($($Tn,)*);
                async move {
                    let result = fut.await;
                    result.into_call_tool_result()
                }.boxed()
            }
        }

        // Synchronous service method: `fn(&S, T..) -> R`. The handler runs
        // eagerly inside `call`, before the returned (already-ready) future is
        // polled.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, SyncMethodAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> + ,
            )*
            F: FnOnce(&S, $($Tn,)*) -> R + Send + ,
            R: IntoCallToolResult + Send + ,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>> {
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                std::future::ready(self(context.service, $($Tn,)*).into_call_tool_result()).boxed()
            }
        }

        // Synchronous free function or closure: `fn(T..) -> R`. Also runs
        // eagerly inside `call`.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, SyncAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> + ,
            )*
            F: FnOnce($($Tn,)*) -> R + Send + ,
            R: IntoCallToolResult + Send + ,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>> {
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                std::future::ready(self($($Tn,)*).into_call_tool_result()).boxed()
            }
        }
    };
}
// Instantiate for handlers with up to 16 extractor arguments.
impl_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);
336
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::model::NumberOrString;

    #[derive(Debug, Clone)]
    struct TestService {
        // Never read; it only gives the test service some state.
        #[allow(dead_code)]
        value: String,
    }

    #[derive(Debug, Deserialize, Serialize)]
    struct TestParams {
        message: String,
        count: i32,
    }

    /// A service instance for tests that need a `&S`.
    fn test_service() -> TestService {
        TestService {
            value: "test".to_string(),
        }
    }

    /// Builds an argument object matching `TestParams`.
    fn test_args(message: &str, count: i32) -> JsonObject {
        let mut args = JsonObject::new();
        args.insert("message".to_string(), json!(message));
        args.insert("count".to_string(), json!(count));
        args
    }

    /// Builds a server-side request context with id 1 and default meta/extensions.
    fn test_request_context() -> RequestContext<RoleServer> {
        RequestContext {
            // Only the first element of the tuple returned by `Peer::new` is needed.
            peer: crate::service::Peer::new(
                std::sync::Arc::new(crate::service::AtomicU32RequestIdProvider::default()),
                None,
            )
            .0,
            ct: CancellationToken::new(),
            id: NumberOrString::Number(1),
            meta: Default::default(),
            extensions: Default::default(),
        }
    }

    /// Builds a `ToolCallContext` for `tool_name` with the given arguments.
    fn test_context<'s>(
        service: &'s TestService,
        tool_name: &'static str,
        arguments: Option<JsonObject>,
    ) -> ToolCallContext<'s, TestService> {
        ToolCallContext::new(
            service,
            CallToolRequestParams {
                meta: None,
                name: tool_name.into(),
                arguments,
                task: None,
            },
            test_request_context(),
        )
    }

    #[test]
    fn test_parse_json_object_valid() {
        let result: Result<TestParams, _> = parse_json_object(test_args("hello", 42));
        let params = result.expect("well-formed arguments should deserialize");
        assert_eq!(params.message, "hello");
        assert_eq!(params.count, 42);
    }

    #[test]
    fn test_parse_json_object_invalid() {
        let mut json = JsonObject::new();
        json.insert("message".to_string(), json!("hello"));
        let result: Result<TestParams, _> = parse_json_object(json);
        let err = result.unwrap_err();
        assert!(err.message.contains("failed to deserialize"));
    }

    #[test]
    fn test_parse_json_object_type_mismatch() {
        let mut json = JsonObject::new();
        json.insert("message".to_string(), json!("hello"));
        json.insert("count".to_string(), json!("not a number"));
        let result: Result<TestParams, _> = parse_json_object(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_into_call_tool_result_string() {
        let tool_result = "success".to_string().into_call_tool_result().unwrap();
        assert_eq!(tool_result.is_error, Some(false));
        assert_eq!(tool_result.content.len(), 1);
        let text = tool_result.content[0]
            .as_text()
            .expect("Expected text content");
        assert_eq!(text.text, "success");
    }

    #[test]
    fn test_into_call_tool_result_ok_variant() {
        let result: Result<String, String> = Ok("success".to_string());
        let tool_result = result.into_call_tool_result().unwrap();
        assert_eq!(tool_result.is_error, Some(false));
        assert_eq!(tool_result.content.len(), 1);
    }

    #[test]
    fn test_into_call_tool_result_err_variant() {
        let result: Result<String, String> = Err("error".to_string());
        let tool_result = result.into_call_tool_result().unwrap();
        assert_eq!(tool_result.is_error, Some(true));
        assert_eq!(tool_result.content.len(), 1);
        let text = tool_result.content[0]
            .as_text()
            .expect("Expected text content");
        assert_eq!(text.text, "error");
    }

    #[test]
    fn test_into_call_tool_result_error_data() {
        let error = crate::ErrorData::invalid_params("bad params".to_string(), None);
        let result: Result<String, crate::ErrorData> = Err(error);
        let tool_result = result.into_call_tool_result();
        assert!(tool_result.is_err());
        assert!(tool_result.unwrap_err().message.contains("bad params"));
    }

    #[tokio::test]
    async fn test_tool_name_extraction() {
        let service = test_service();
        let mut context = test_context(&service, "test_tool", None);

        let tool_name = ToolName::from_context_part(&mut context).unwrap();
        assert_eq!(tool_name.0, "test_tool");
    }

    #[tokio::test]
    async fn test_parameters_extraction() {
        let service = test_service();
        let mut context = test_context(&service, "test_tool", Some(test_args("hello", 42)));

        let params: Parameters<TestParams> = Parameters::from_context_part(&mut context).unwrap();
        assert_eq!(params.0.message, "hello");
        assert_eq!(params.0.count, 42);
        // The extractor takes the arguments out of the context.
        assert!(context.arguments.is_none());
    }

    #[tokio::test]
    async fn test_parameters_extraction_missing_arguments() {
        let service = test_service();
        let mut context = test_context(&service, "test_tool", None);

        // Absent arguments are treated as `{}`, which lacks TestParams' required fields.
        let result: Result<Parameters<TestParams>, _> = Parameters::from_context_part(&mut context);
        let err = result
            .err()
            .expect("missing required fields should be rejected");
        assert!(err.message.contains("failed to deserialize"));
    }

    #[tokio::test]
    async fn test_json_object_extraction_empty() {
        let service = test_service();
        let mut context = test_context(&service, "test_tool", None);

        let json_obj: JsonObject = JsonObject::from_context_part(&mut context).unwrap();
        assert!(json_obj.is_empty());
    }

    #[tokio::test]
    async fn test_async_handler_success() {
        async fn async_tool(params: Parameters<TestParams>) -> String {
            format!("{} x {}", params.0.message, params.0.count)
        }

        let service = test_service();
        let context = test_context(&service, "async_tool", Some(test_args("hello", 3)));

        let tool_result = context.invoke(async_tool).await.unwrap();
        assert_eq!(tool_result.is_error, Some(false));
        let text = tool_result.content[0]
            .as_text()
            .expect("Expected text content");
        assert_eq!(text.text, "hello x 3");
    }

    #[tokio::test]
    async fn test_sync_handler_success() {
        fn sync_tool(params: Parameters<TestParams>) -> String {
            format!("{} x {}", params.0.message, params.0.count)
        }

        let service = test_service();
        let context = test_context(&service, "sync_tool", Some(test_args("test", 5)));

        let tool_result = context.invoke(sync_tool).await.unwrap();
        assert_eq!(tool_result.is_error, Some(false));
        let text = tool_result.content[0]
            .as_text()
            .expect("Expected text content");
        assert_eq!(text.text, "test x 5");
    }

    #[tokio::test]
    async fn test_handler_with_result_error() {
        async fn failing_tool(_params: Parameters<TestParams>) -> Result<String, String> {
            Err("Tool execution failed".to_string())
        }

        let service = test_service();
        let context = test_context(&service, "failing_tool", Some(test_args("test", 1)));

        // A tool-level `Err` is still a successful call, flagged with `is_error`.
        let tool_result = context.invoke(failing_tool).await.unwrap();
        assert_eq!(tool_result.is_error, Some(true));
        let text = tool_result.content[0]
            .as_text()
            .expect("Expected text content");
        assert_eq!(text.text, "Tool execution failed");
    }

    #[tokio::test]
    async fn test_handler_with_error_data_propagates() {
        async fn rejecting_tool(
            _params: Parameters<TestParams>,
        ) -> Result<String, crate::ErrorData> {
            Err(crate::ErrorData::invalid_params("rejected".to_string(), None))
        }

        let service = test_service();
        let context = test_context(&service, "rejecting_tool", Some(test_args("test", 1)));

        // An `ErrorData` is not wrapped into a tool result; it surfaces as `Err`.
        let err = context.invoke(rejecting_tool).await.unwrap_err();
        assert!(err.message.contains("rejected"));
    }

    #[tokio::test]
    async fn test_handler_rejects_invalid_arguments() {
        async fn async_tool(params: Parameters<TestParams>) -> String {
            params.0.message
        }

        let service = test_service();
        let mut args = JsonObject::new();
        args.insert("message".to_string(), json!("hello"));
        let context = test_context(&service, "async_tool", Some(args));

        // Extraction fails before the handler runs.
        let err = context.invoke(async_tool).await.unwrap_err();
        assert!(err.message.contains("failed to deserialize"));
    }

    #[tokio::test]
    async fn test_handler_with_json_string_output() {
        async fn json_tool(params: Parameters<TestParams>) -> String {
            let result = json!({
                "message": params.0.message,
                "count": params.0.count,
                "computed": params.0.count * 2
            });
            result.to_string()
        }

        let service = test_service();
        let context = test_context(&service, "json_tool", Some(test_args("hello", 10)));

        let tool_result = context.invoke(json_tool).await.unwrap();
        assert_eq!(tool_result.is_error, Some(false));
        let text = tool_result.content[0]
            .as_text()
            .expect("Expected text content");
        let parsed: serde_json::Value = serde_json::from_str(&text.text).unwrap();
        assert_eq!(parsed["message"], "hello");
        assert_eq!(parsed["count"], 10);
        assert_eq!(parsed["computed"], 20);
    }
}
737}