1use std::{
2 borrow::Cow,
3 future::{Future, Ready},
4 marker::PhantomData,
5};
6
7use futures::future::{BoxFuture, FutureExt};
8use serde::de::DeserializeOwned;
9
10use super::common::{AsRequestContext, FromContextPart};
11pub use super::{
12 common::{Extension, RequestId, schema_for_output, schema_for_type},
13 router::tool::{ToolRoute, ToolRouter},
14};
15use crate::{
16 RoleServer,
17 handler::server::wrapper::Parameters,
18 model::{CallToolRequestParams, CallToolResult, IntoContents, JsonObject},
19 service::RequestContext,
20};
21
22pub fn parse_json_object<T: DeserializeOwned>(input: JsonObject) -> Result<T, crate::ErrorData> {
24 serde_json::from_value(serde_json::Value::Object(input)).map_err(|e| {
25 crate::ErrorData::invalid_params(
26 format!("failed to deserialize parameters: {error}", error = e),
27 None,
28 )
29 })
30}
31pub struct ToolCallContext<'s, S> {
32 pub request_context: RequestContext<RoleServer>,
33 pub service: &'s S,
34 pub name: Cow<'static, str>,
35 pub arguments: Option<JsonObject>,
36 pub task: Option<JsonObject>,
37}
38
39impl<'s, S> ToolCallContext<'s, S> {
40 pub fn new(
41 service: &'s S,
42 CallToolRequestParams {
43 meta: _,
44 name,
45 arguments,
46 task,
47 }: CallToolRequestParams,
48 request_context: RequestContext<RoleServer>,
49 ) -> Self {
50 Self {
51 request_context,
52 service,
53 name,
54 arguments,
55 task,
56 }
57 }
58 pub fn name(&self) -> &str {
59 &self.name
60 }
61 pub fn request_context(&self) -> &RequestContext<RoleServer> {
62 &self.request_context
63 }
64}
65
66impl<S> AsRequestContext for ToolCallContext<'_, S> {
67 fn as_request_context(&self) -> &RequestContext<RoleServer> {
68 &self.request_context
69 }
70
71 fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer> {
72 &mut self.request_context
73 }
74}
75
76pub trait IntoCallToolResult {
77 fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData>;
78}
79
80impl<T: IntoContents> IntoCallToolResult for T {
81 fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
82 Ok(CallToolResult::success(self.into_contents()))
83 }
84}
85
86impl<T: IntoContents, E: IntoContents> IntoCallToolResult for Result<T, E> {
87 fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
88 match self {
89 Ok(value) => Ok(CallToolResult::success(value.into_contents())),
90 Err(error) => Ok(CallToolResult::error(error.into_contents())),
91 }
92 }
93}
94
95impl<T: IntoCallToolResult> IntoCallToolResult for Result<T, crate::ErrorData> {
96 fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
97 match self {
98 Ok(value) => value.into_call_tool_result(),
99 Err(error) => Err(error),
100 }
101 }
102}
103
pin_project_lite::pin_project! {
    // A future resolving to `Result<CallToolResult, ErrorData>`, built in one of two ways:
    // - `Pending` wraps a future whose output `R` is converted with
    //   `IntoCallToolResult` once it completes;
    // - `Ready` holds a result that is already known.
    // `#[project = ...]` names the pin-projection enum that `poll` matches on.
    #[project = IntoCallToolResultFutProj]
    pub enum IntoCallToolResultFut<F, R> {
        // `_marker` records the output type `R` without storing a value of it.
        Pending {
            #[pin]
            fut: F,
            _marker: PhantomData<R>,
        },
        // Already-computed result, yielded on the first poll.
        Ready {
            #[pin]
            result: Ready<Result<CallToolResult, crate::ErrorData>>,
        }
    }
}
118
119impl<F, R> Future for IntoCallToolResultFut<F, R>
120where
121 F: Future<Output = R>,
122 R: IntoCallToolResult,
123{
124 type Output = Result<CallToolResult, crate::ErrorData>;
125
126 fn poll(
127 self: std::pin::Pin<&mut Self>,
128 cx: &mut std::task::Context<'_>,
129 ) -> std::task::Poll<Self::Output> {
130 match self.project() {
131 IntoCallToolResultFutProj::Pending { fut, _marker } => {
132 fut.poll(cx).map(IntoCallToolResult::into_call_tool_result)
133 }
134 IntoCallToolResultFutProj::Ready { result } => result.poll(cx),
135 }
136 }
137}
138
/// A handler that already returns the final `Result<CallToolResult, ErrorData>` is
/// passed through unchanged.
impl IntoCallToolResult for Result<CallToolResult, crate::ErrorData> {
    fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
        self
    }
}
144
/// A callable that can service a tool call given a [`ToolCallContext`] for service `S`.
///
/// `A` carries no data. The impls generated by `impl_for!` in this module pick one of
/// the `*Adapter` marker types for it, so the sync/async and free-function/method
/// impls for the same closure type do not overlap.
pub trait CallToolHandler<S, A> {
    /// Runs the handler against `context`.
    ///
    /// The returned future's lifetime is tied to the context's lifetime parameter, so
    /// it may borrow from the context (e.g. the `service` reference).
    fn call(
        self,
        context: ToolCallContext<'_, S>,
    ) -> BoxFuture<'_, Result<CallToolResult, crate::ErrorData>>;
}
151
/// Type-erased tool handler.
///
/// It is a `Send + Sync` function that can be called for any context lifetime `'s`
/// and returns a boxed future whose lifetime is bounded by that same `'s`.
pub type DynCallToolHandler<S> = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFuture<'s, Result<CallToolResult, crate::ErrorData>>
    + Send
    + Sync;
155
156pub struct ToolName(pub Cow<'static, str>);
158
159impl<S> FromContextPart<ToolCallContext<'_, S>> for ToolName {
160 fn from_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::ErrorData> {
161 Ok(Self(context.name.clone()))
162 }
163}
164
165impl<S, P> FromContextPart<ToolCallContext<'_, S>> for Parameters<P>
167where
168 P: DeserializeOwned,
169{
170 fn from_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::ErrorData> {
171 let arguments = context.arguments.take().unwrap_or_default();
172 let value: P =
173 serde_json::from_value(serde_json::Value::Object(arguments)).map_err(|e| {
174 crate::ErrorData::invalid_params(
175 format!("failed to deserialize parameters: {error}", error = e),
176 None,
177 )
178 })?;
179 Ok(Parameters(value))
180 }
181}
182
183impl<S> FromContextPart<ToolCallContext<'_, S>> for JsonObject {
185 fn from_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::ErrorData> {
186 let object = context.arguments.take().unwrap_or_default();
187 Ok(object)
188 }
189}
190
191impl<'s, S> ToolCallContext<'s, S> {
192 pub fn invoke<H, A>(self, h: H) -> BoxFuture<'s, Result<CallToolResult, crate::ErrorData>>
193 where
194 H: CallToolHandler<S, A>,
195 {
196 h.call(self)
197 }
198}
/// Marker for async free-function handlers: `FnOnce(T0, ..) -> Fut` where `Fut: Future`.
///
/// The adapter types are never constructed. They only serve as the `A` parameter of
/// [`CallToolHandler`]. Wrapping the parameters in a `fn` pointer type inside
/// `PhantomData` records them without owning values of them.
// Presumably allowed because the nested `fn(P) -> fn(Fut) -> R` trips clippy's
// type_complexity heuristic.
#[allow(clippy::type_complexity)]
pub struct AsyncAdapter<P, Fut, R>(PhantomData<fn(P) -> fn(Fut) -> R>);
/// Marker for sync free-function handlers: `FnOnce(T0, ..) -> R`.
pub struct SyncAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Marker for async method-style handlers: `FnOnce(&S, T0, ..) -> BoxFuture<'_, R>`.
pub struct AsyncMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Marker for sync method-style handlers: `FnOnce(&S, T0, ..) -> R`.
pub struct SyncMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
205
// Generates `CallToolHandler` impls for handlers taking a fixed number of extractor
// arguments. Called as `impl_for!(T0 T1 ...)`. The recursive arms move identifiers
// one at a time from the "rest" list into the accumulator, emitting an `@impl`
// expansion for every prefix (including the empty one). A list of N identifiers
// therefore yields impls for arities 0 through N.
macro_rules! impl_for {
    // Entry point: empty accumulator, full identifier list.
    ($($T: ident)*) => {
        impl_for!([] [$($T)*]);
    };
    // Nothing left to move: emit impls for the complete accumulated list.
    ([$($Tn: ident)*] []) => {
        impl_for!(@impl $($Tn)*);
    };
    // Emit impls for the current prefix, then move the next identifier into it.
    ([$($Tn: ident)*] [$Tn_1: ident $($Rest: ident)*]) => {
        impl_for!(@impl $($Tn)*);
        impl_for!([$($Tn)* $Tn_1] [$($Rest)*]);
    };
    // Emits the four handler shapes for one arity. Each `$Tn` is both an extractor
    // type parameter and, inside `call`, the name of the local holding its value.
    // In every shape the extractors run in declaration order against the same
    // `context`. The first extractor error short-circuits with an already-resolved
    // error future, and the handler is never called.
    (@impl $($Tn: ident)*) => {
        // Async method-style handler: `FnOnce(&S, T0, ..) -> BoxFuture<'_, R>`.
        // The handler's future may borrow the service, so the returned future is tied
        // to the context's lifetime instead of being `'static`.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, AsyncMethodAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> ,
            )*
            F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R>,

            R: IntoCallToolResult + Send + 'static,
            S: Send + Sync + 'static,
        {
            // The locals reuse the uppercase `$Tn` names (non_snake_case). At arity 0
            // `context` is never passed to an extractor (unused_mut), and in the
            // non-method shapes it is not read at all (unused_variables).
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<'_, S>,
            ) -> BoxFuture<'_, Result<CallToolResult, crate::ErrorData>>{
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let service = context.service;
                let fut = self(service, $($Tn,)*);
                async move {
                    let result = fut.await;
                    result.into_call_tool_result()
                }.boxed()
            }
        }

        // Async free-function handler: `FnOnce(T0, ..) -> Fut`. Because `Fut: 'static`,
        // the future cannot borrow the context, and `call` returns a `'static` future.
        impl<$($Tn,)* S, F, Fut, R> CallToolHandler<S, AsyncAdapter<($($Tn,)*), Fut, R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> ,
            )*
            F: FnOnce($($Tn,)*) -> Fut + Send + ,
            Fut: Future<Output = R> + Send + 'static,
            R: IntoCallToolResult + Send + 'static,
            S: Send + Sync,
        {
            // Same lint rationale as the method-style async impl.
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>>{
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let fut = self($($Tn,)*);
                async move {
                    let result = fut.await;
                    result.into_call_tool_result()
                }.boxed()
            }
        }

        // Sync method-style handler: `FnOnce(&S, T0, ..) -> R`. The handler runs
        // eagerly inside `call`, before the returned future is polled. Its converted
        // result is then wrapped in an already-ready future.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, SyncMethodAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> + ,
            )*
            F: FnOnce(&S, $($Tn,)*) -> R + Send + ,
            R: IntoCallToolResult + Send + ,
            S: Send + Sync,
        {
            // Same lint rationale as the method-style async impl.
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>> {
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                std::future::ready(self(context.service, $($Tn,)*).into_call_tool_result()).boxed()
            }
        }

        // Sync free-function handler: `FnOnce(T0, ..) -> R`. Like the sync method
        // shape, it runs eagerly inside `call` and returns an already-ready future.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, SyncAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<ToolCallContext<'a, S>> + ,
            )*
            F: FnOnce($($Tn,)*) -> R + Send + ,
            R: IntoCallToolResult + Send + ,
            S: Send + Sync,
        {
            // Same lint rationale as the method-style async impl.
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>> {
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                std::future::ready(self($($Tn,)*).into_call_tool_result()).boxed()
            }
        }
    };
}
// 16 identifiers: generates handler impls for 0 through 16 extractor arguments.
impl_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);