1use std::{
2 any::TypeId, borrow::Cow, collections::HashMap, future::Ready, marker::PhantomData, sync::Arc,
3};
4
5use futures::future::{BoxFuture, FutureExt};
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize, de::DeserializeOwned};
8use tokio_util::sync::CancellationToken;
9
10pub use super::router::tool::{ToolRoute, ToolRouter};
11use crate::{
12 RoleServer,
13 model::{CallToolRequestParam, CallToolResult, IntoContents, JsonObject},
14 service::RequestContext,
15};
16
17pub fn schema_for_type<T: JsonSchema>() -> JsonObject {
19 let mut settings = schemars::r#gen::SchemaSettings::draft07();
22 settings.option_nullable = true;
23 settings.option_add_null_type = false;
24 settings.visitors = Vec::default();
25 let generator = settings.into_generator();
26 let schema = generator.into_root_schema_for::<T>();
27 let object = serde_json::to_value(schema).expect("failed to serialize schema");
28 match object {
29 serde_json::Value::Object(object) => object,
30 _ => panic!("unexpected schema value"),
31 }
32}
33
34pub fn cached_schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
36 thread_local! {
37 static CACHE_FOR_TYPE: std::sync::RwLock<HashMap<TypeId, Arc<JsonObject>>> = Default::default();
38 };
39 CACHE_FOR_TYPE.with(|cache| {
40 if let Some(x) = cache
41 .read()
42 .expect("schema cache lock poisoned")
43 .get(&TypeId::of::<T>())
44 {
45 x.clone()
46 } else {
47 let schema = schema_for_type::<T>();
48 let schema = Arc::new(schema);
49 cache
50 .write()
51 .expect("schema cache lock poisoned")
52 .insert(TypeId::of::<T>(), schema.clone());
53 schema
54 }
55 })
56}
57
58pub fn parse_json_object<T: DeserializeOwned>(input: JsonObject) -> Result<T, crate::Error> {
60 serde_json::from_value(serde_json::Value::Object(input)).map_err(|e| {
61 crate::Error::invalid_params(
62 format!("failed to deserialize parameters: {error}", error = e),
63 None,
64 )
65 })
66}
67pub struct ToolCallContext<'s, S> {
68 pub request_context: RequestContext<RoleServer>,
69 pub service: &'s S,
70 pub name: Cow<'static, str>,
71 pub arguments: Option<JsonObject>,
72}
73
74impl<'s, S> ToolCallContext<'s, S> {
75 pub fn new(
76 service: &'s S,
77 CallToolRequestParam { name, arguments }: CallToolRequestParam,
78 request_context: RequestContext<RoleServer>,
79 ) -> Self {
80 Self {
81 request_context,
82 service,
83 name,
84 arguments,
85 }
86 }
87 pub fn name(&self) -> &str {
88 &self.name
89 }
90 pub fn request_context(&self) -> &RequestContext<RoleServer> {
91 &self.request_context
92 }
93}
94
95pub trait FromToolCallContextPart<S>: Sized {
96 fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error>;
97}
98
99pub trait IntoCallToolResult {
100 fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error>;
101}
102impl IntoCallToolResult for () {
103 fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error> {
104 Ok(CallToolResult::success(vec![]))
105 }
106}
107
108impl<T: IntoContents> IntoCallToolResult for T {
109 fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error> {
110 Ok(CallToolResult::success(self.into_contents()))
111 }
112}
113
114impl<T: IntoContents, E: IntoContents> IntoCallToolResult for Result<T, E> {
115 fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error> {
116 match self {
117 Ok(value) => Ok(CallToolResult::success(value.into_contents())),
118 Err(error) => Ok(CallToolResult::error(error.into_contents())),
119 }
120 }
121}
122
// Future that resolves to `Result<CallToolResult, crate::Error>`.
//
// `Pending` wraps a future whose output `R` is converted with
// `IntoCallToolResult` when it completes. `Ready` wraps a result that is
// already known. `pin_project!` generates the `IntoCallToolResultFutProj`
// projection type, which the `Future` impl matches on to poll the pinned
// field of whichever variant is active.
pin_project_lite::pin_project! {
    #[project = IntoCallToolResultFutProj]
    pub enum IntoCallToolResultFut<F, R> {
        // Still waiting on `fut`; its output is converted on completion.
        Pending {
            #[pin]
            fut: F,
            // Mentions `R` in the type without storing a value of it.
            _marker: PhantomData<R>,
        },
        // The result is already available and is yielded on first poll.
        Ready {
            #[pin]
            result: Ready<Result<CallToolResult, crate::Error>>,
        }
    }
}
137
138impl<F, R> Future for IntoCallToolResultFut<F, R>
139where
140 F: Future<Output = R>,
141 R: IntoCallToolResult,
142{
143 type Output = Result<CallToolResult, crate::Error>;
144
145 fn poll(
146 self: std::pin::Pin<&mut Self>,
147 cx: &mut std::task::Context<'_>,
148 ) -> std::task::Poll<Self::Output> {
149 match self.project() {
150 IntoCallToolResultFutProj::Pending { fut, _marker } => {
151 fut.poll(cx).map(IntoCallToolResult::into_call_tool_result)
152 }
153 IntoCallToolResultFutProj::Ready { result } => result.poll(cx),
154 }
155 }
156}
157
158impl IntoCallToolResult for Result<CallToolResult, crate::Error> {
159 fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error> {
160 self
161 }
162}
163
164pub trait CallToolHandler<S, A> {
165 fn call(
166 self,
167 context: ToolCallContext<'_, S>,
168 ) -> BoxFuture<'_, Result<CallToolResult, crate::Error>>;
169}
170
171pub type DynCallToolHandler<S> = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFuture<'s, Result<CallToolResult, crate::Error>>
172 + Send
173 + Sync;
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
178#[serde(transparent)]
179pub struct Parameters<P>(pub P);
180
181impl<P: JsonSchema> JsonSchema for Parameters<P> {
182 fn schema_name() -> String {
183 P::schema_name()
184 }
185
186 fn json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
187 P::json_schema(generator)
188 }
189}
190
191impl<S> FromToolCallContextPart<S> for CancellationToken {
192 fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
193 Ok(context.request_context.ct.clone())
194 }
195}
196
197pub struct ToolName(pub Cow<'static, str>);
198
199impl<S> FromToolCallContextPart<S> for ToolName {
200 fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
201 Ok(Self(context.name.clone()))
202 }
203}
204
205impl<S, P> FromToolCallContextPart<S> for Parameters<P>
206where
207 P: DeserializeOwned,
208{
209 fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
210 let arguments = context.arguments.take().unwrap_or_default();
211 let value: P =
212 serde_json::from_value(serde_json::Value::Object(arguments)).map_err(|e| {
213 crate::Error::invalid_params(
214 format!("failed to deserialize parameters: {error}", error = e),
215 None,
216 )
217 })?;
218 Ok(Parameters(value))
219 }
220}
221
222impl<S> FromToolCallContextPart<S> for JsonObject {
223 fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
224 let object = context.arguments.take().unwrap_or_default();
225 Ok(object)
226 }
227}
228
229impl<S> FromToolCallContextPart<S> for crate::model::Extensions {
230 fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
231 let extensions = context.request_context.extensions.clone();
232 Ok(extensions)
233 }
234}
235
236pub struct Extension<T>(pub T);
237
238impl<S, T> FromToolCallContextPart<S> for Extension<T>
239where
240 T: Send + Sync + 'static + Clone,
241{
242 fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
243 let extension = context
244 .request_context
245 .extensions
246 .get::<T>()
247 .cloned()
248 .ok_or_else(|| {
249 crate::Error::invalid_params(
250 format!("missing extension {}", std::any::type_name::<T>()),
251 None,
252 )
253 })?;
254 Ok(Extension(extension))
255 }
256}
257
258impl<S> FromToolCallContextPart<S> for crate::Peer<RoleServer> {
259 fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
260 let peer = context.request_context.peer.clone();
261 Ok(peer)
262 }
263}
264
265impl<S> FromToolCallContextPart<S> for crate::model::Meta {
266 fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
267 let mut meta = crate::model::Meta::default();
268 std::mem::swap(&mut meta, &mut context.request_context.meta);
269 Ok(meta)
270 }
271}
272
273pub struct RequestId(pub crate::model::RequestId);
274impl<S> FromToolCallContextPart<S> for RequestId {
275 fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
276 Ok(RequestId(context.request_context.id.clone()))
277 }
278}
279
280impl<S> FromToolCallContextPart<S> for RequestContext<RoleServer> {
281 fn from_tool_call_context_part(context: &mut ToolCallContext<S>) -> Result<Self, crate::Error> {
282 Ok(context.request_context.clone())
283 }
284}
285
286impl<'s, S> ToolCallContext<'s, S> {
287 pub fn invoke<H, A>(self, h: H) -> BoxFuture<'s, Result<CallToolResult, crate::Error>>
288 where
289 H: CallToolHandler<S, A>,
290 {
291 h.call(self)
292 }
293}
// Marker types used as the `A` parameter of `CallToolHandler`. Each one lets
// a separate blanket impl (generated by `impl_for!` below) cover a different
// closure shape. `P` is the tuple of extracted argument types, `R` the
// handler's return type, and `Fut` its future type. They hold only a
// `PhantomData`, so they carry no data.

/// Marker for handlers `FnOnce(T0, ..) -> Fut` where `Fut` yields `R`.
// clippy flags the nested fn-pointer type inside `PhantomData`; it is only a
// type-level marker, so the complexity is accepted here.
#[allow(clippy::type_complexity)]
pub struct AsyncAdapter<P, Fut, R>(PhantomData<fn(P) -> fn(Fut) -> R>);
/// Marker for handlers `FnOnce(T0, ..) -> R`, called synchronously.
pub struct SyncAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Marker for handlers `FnOnce(&S, T0, ..) -> BoxFuture<'_, R>`.
pub struct AsyncMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Marker for handlers `FnOnce(&S, T0, ..) -> R`, called synchronously.
pub struct SyncMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
300
// Generates the `CallToolHandler` impls for handlers taking 0 to 16 extracted
// arguments. The invocation `impl_for!(T0 ... T15)` is at the bottom.
//
// The three non-`@impl` arms recurse over two bracketed lists. The left list
// holds the type-parameter names already emitted; the right list holds the
// names still to come. Each step emits `@impl` for the current left list,
// then moves one name from right to left. Every prefix, including the empty
// one, therefore gets impls.
//
// For each prefix, the `@impl` arm emits four impls, one per adapter marker:
// - `AsyncMethodAdapter`: `FnOnce(&S, T0, ..) -> BoxFuture<'_, R>`. The
//   returned future may borrow the service.
// - `AsyncAdapter`: `FnOnce(T0, ..) -> Fut` with `Fut: 'static`.
// - `SyncMethodAdapter`: `FnOnce(&S, T0, ..) -> R`.
// - `SyncAdapter`: `FnOnce(T0, ..) -> R`.
//
// In all four impls:
// - Arguments are extracted from the context in declaration order (T0 first).
//   Extractors that move data out of the context therefore affect later ones.
// - The first extraction error is returned as an immediately-ready future, and
//   the handler is not called.
// - The sync variants run the handler inside `call` itself and wrap its
//   converted result in a ready future.
macro_rules! impl_for {
    // Entry point: nothing emitted yet, every name pending.
    ($($T: ident)*) => {
        impl_for!([] [$($T)*]);
    };
    // All names consumed: emit impls for the full list.
    ([$($Tn: ident)*] []) => {
        impl_for!(@impl $($Tn)*);
    };
    // Emit impls for the current prefix, then extend it by one name.
    ([$($Tn: ident)*] [$Tn_1: ident $($Rest: ident)*]) => {
        impl_for!(@impl $($Tn)*);
        impl_for!([$($Tn)* $Tn_1] [$($Rest)*]);
    };
    (@impl $($Tn: ident)*) => {
        // Async method-style handler: receives `&S`; the future may borrow it.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, AsyncMethodAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: FromToolCallContextPart<S> ,
            )*
            F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R>,

            R: IntoCallToolResult + Send + 'static,
            S: Send + Sync + 'static,
        {
            // Locals reuse the type-parameter idents (`non_snake_case`). The
            // zero-argument expansion never mutates `context` (`unused_mut`).
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<'_, S>,
            ) -> BoxFuture<'_, Result<CallToolResult, crate::Error>>{
                // Extract each argument in order; bail out on the first error.
                $(
                    let result = $Tn::from_tool_call_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let service = context.service;
                let fut = self(service, $($Tn,)*);
                async move {
                    let result = fut.await;
                    result.into_call_tool_result()
                }.boxed()
            }
        }

        // Async free-function handler: no service access; future is `'static`.
        impl<$($Tn,)* S, F, Fut, R> CallToolHandler<S, AsyncAdapter<($($Tn,)*), Fut, R>> for F
        where
            $(
                $Tn: FromToolCallContextPart<S> ,
            )*
            F: FnOnce($($Tn,)*) -> Fut + Send + ,
            Fut: Future<Output = R> + Send + 'static,
            R: IntoCallToolResult + Send + 'static,
            S: Send + Sync,
        {
            // With zero arguments `context` is never read or mutated.
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::Error>>{
                $(
                    let result = $Tn::from_tool_call_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let fut = self($($Tn,)*);
                async move {
                    let result = fut.await;
                    result.into_call_tool_result()
                }.boxed()
            }
        }

        // Sync method-style handler: runs inside `call`, before any polling.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, SyncMethodAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: FromToolCallContextPart<S> + ,
            )*
            F: FnOnce(&S, $($Tn,)*) -> R + Send + ,
            R: IntoCallToolResult + Send + ,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::Error>> {
                $(
                    let result = $Tn::from_tool_call_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                // The handler has already run; the future only carries its result.
                std::future::ready(self(context.service, $($Tn,)*).into_call_tool_result()).boxed()
            }
        }

        // Sync free-function handler: runs inside `call`, before any polling.
        impl<$($Tn,)* S, F, R> CallToolHandler<S, SyncAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: FromToolCallContextPart<S> + ,
            )*
            F: FnOnce($($Tn,)*) -> R + Send + ,
            R: IntoCallToolResult + Send + ,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn call(
                self,
                mut context: ToolCallContext<S>,
            ) -> BoxFuture<'static, Result<CallToolResult, crate::Error>> {
                $(
                    let result = $Tn::from_tool_call_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                std::future::ready(self($($Tn,)*).into_call_tool_result()).boxed()
            }
        }
    };
}
// Emits impls for every arity from 0 up to 16 extracted arguments.
impl_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);