1use std::{future::Future, marker::PhantomData};
8
9use futures::future::{BoxFuture, FutureExt};
10use serde::de::DeserializeOwned;
11
12use super::common::{AsRequestContext, FromContextPart};
13pub use super::common::{Extension, RequestId};
14use crate::{
15 RoleServer,
16 handler::server::wrapper::Parameters,
17 model::{GetPromptResult, PromptMessage},
18 service::RequestContext,
19};
20
21pub struct PromptContext<'a, S> {
23 pub server: &'a S,
24 pub name: String,
25 pub arguments: Option<serde_json::Map<String, serde_json::Value>>,
26 pub context: RequestContext<RoleServer>,
27}
28
29impl<'a, S> PromptContext<'a, S> {
30 pub fn new(
31 server: &'a S,
32 name: String,
33 arguments: Option<serde_json::Map<String, serde_json::Value>>,
34 context: RequestContext<RoleServer>,
35 ) -> Self {
36 Self {
37 server,
38 name,
39 arguments,
40 context,
41 }
42 }
43}
44
45impl<S> AsRequestContext for PromptContext<'_, S> {
46 fn as_request_context(&self) -> &RequestContext<RoleServer> {
47 &self.context
48 }
49
50 fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer> {
51 &mut self.context
52 }
53}
54
/// A callable that can serve a get-prompt request for server type `S`.
///
/// `A` is a marker type parameter: the blanket impls generated by
/// `impl_prompt_handler_for!` all target a generic `F`, and they differ only in
/// `A` (an extractor tuple or one of the `*PromptAdapter` markers). Without it
/// they would overlap.
pub trait GetPromptHandler<S, A> {
    /// Consume the handler and return a boxed future that resolves to the
    /// prompt result, or to an `ErrorData` if extraction or the handler fails.
    fn handle(
        self,
        context: PromptContext<'_, S>,
    ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>;
}

/// Type-erased prompt handler: a `Send + Sync` function that takes a
/// `PromptContext` and returns a boxed future allowed to borrow from it
/// (lifetime `'a`).
pub type DynGetPromptHandler<S> = dyn for<'a> Fn(PromptContext<'a, S>) -> BoxFuture<'a, Result<GetPromptResult, crate::ErrorData>>
    + Send
    + Sync;
67
/// Marker type.
// NOTE(review): not referenced by any handler impl in this module.
pub struct AsyncMethodAdapter<T>(PhantomData<T>);

/// Marker type.
// NOTE(review): not referenced by any handler impl in this module.
pub struct AsyncMethodWithArgsAdapter<T>(PhantomData<T>);

// The nested fn-pointer phantom type below trips clippy's `type_complexity` lint.
#[allow(clippy::type_complexity)]
/// Marker selecting the impl for free async functions:
/// `FnOnce(T0, ..) -> Fut` where `Fut: Future<Output = Result<R, ErrorData>>`.
// `PhantomData<fn(..)>` keeps the marker `Send + Sync` whatever `P`/`Fut`/`R` are.
pub struct AsyncPromptAdapter<P, Fut, R>(PhantomData<fn(P) -> fn(Fut) -> R>);
/// Marker selecting the impl for free sync functions:
/// `FnOnce(T0, ..) -> Result<R, ErrorData>`.
pub struct SyncPromptAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Marker type.
// NOTE(review): not referenced by the impls in this module; the async method
// impl uses the bare extractor tuple as its marker instead.
pub struct AsyncPromptMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
/// Marker selecting the impl for sync methods: `FnOnce(&S, T0, ..) -> R`.
pub struct SyncPromptMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
80
81pub trait IntoGetPromptResult {
83 fn into_get_prompt_result(self) -> Result<GetPromptResult, crate::ErrorData>;
84}
85
86impl IntoGetPromptResult for GetPromptResult {
87 fn into_get_prompt_result(self) -> Result<GetPromptResult, crate::ErrorData> {
88 Ok(self)
89 }
90}
91
92impl IntoGetPromptResult for Vec<PromptMessage> {
93 fn into_get_prompt_result(self) -> Result<GetPromptResult, crate::ErrorData> {
94 Ok(GetPromptResult {
95 description: None,
96 messages: self,
97 })
98 }
99}
100
101impl<T: IntoGetPromptResult> IntoGetPromptResult for Result<T, crate::ErrorData> {
102 fn into_get_prompt_result(self) -> Result<GetPromptResult, crate::ErrorData> {
103 self.and_then(|v| v.into_get_prompt_result())
104 }
105}
106
pin_project_lite::pin_project! {
    /// Future adapter that yields `Result<GetPromptResult, ErrorData>`: either
    /// wraps an inner future whose output is converted via
    /// `IntoGetPromptResult`, or holds an already-computed result.
    #[project = IntoGetPromptResultFutProj]
    pub enum IntoGetPromptResultFut<F, R> {
        // Still waiting on `fut`; `_marker` ties the output type `R` to the enum.
        Pending {
            #[pin]
            fut: F,
            _marker: PhantomData<R>,
        },
        // The result is already known and is returned on first poll.
        Ready {
            #[pin]
            result: futures::future::Ready<Result<GetPromptResult, crate::ErrorData>>,
        }
    }
}
122
123impl<F, R> Future for IntoGetPromptResultFut<F, R>
124where
125 F: Future<Output = R>,
126 R: IntoGetPromptResult,
127{
128 type Output = Result<GetPromptResult, crate::ErrorData>;
129
130 fn poll(
131 self: std::pin::Pin<&mut Self>,
132 cx: &mut std::task::Context<'_>,
133 ) -> std::task::Poll<Self::Output> {
134 match self.project() {
135 IntoGetPromptResultFutProj::Pending { fut, _marker } => fut
136 .poll(cx)
137 .map(IntoGetPromptResult::into_get_prompt_result),
138 IntoGetPromptResultFutProj::Ready { result } => result.poll(cx),
139 }
140 }
141}
142
143pub struct PromptName(pub String);
145
146impl<S> FromContextPart<PromptContext<'_, S>> for PromptName {
147 fn from_context_part(context: &mut PromptContext<S>) -> Result<Self, crate::ErrorData> {
148 Ok(Self(context.name.clone()))
149 }
150}
151
152impl<S, P> FromContextPart<PromptContext<'_, S>> for Parameters<P>
154where
155 P: DeserializeOwned,
156{
157 fn from_context_part(context: &mut PromptContext<S>) -> Result<Self, crate::ErrorData> {
158 let params = if let Some(args_map) = context.arguments.take() {
159 let args_value = serde_json::Value::Object(args_map);
160 serde_json::from_value::<P>(args_value).map_err(|e| {
161 crate::ErrorData::invalid_params(format!("Failed to parse parameters: {}", e), None)
162 })?
163 } else {
164 serde_json::from_value::<P>(serde_json::json!({})).map_err(|e| {
166 crate::ErrorData::invalid_params(
167 format!("Missing required parameters: {}", e),
168 None,
169 )
170 })?
171 };
172 Ok(Parameters(params))
173 }
174}
175
// Generates `GetPromptHandler` impls for handlers taking a variable number of
// extractor arguments.
//
// `impl_prompt_handler_for!(A B C)` walks the identifier list and emits an
// `@impl` expansion for every prefix, including the empty one: `()`, `(A)`,
// `(A B)`, `(A B C)`. Each `@impl` expansion produces four blanket impls that
// are distinguished by their marker type (see the adapter structs).
macro_rules! impl_prompt_handler_for {
    // Entry point: start with an empty accumulator and the full list.
    ($($T: ident)*) => {
        impl_prompt_handler_for!([] [$($T)*]);
    };
    // List exhausted: emit impls for the complete accumulated prefix.
    ([$($Tn: ident)*] []) => {
        impl_prompt_handler_for!(@impl $($Tn)*);
    };
    // Emit impls for the current prefix, then move one identifier into it.
    ([$($Tn: ident)*] [$Tn_1: ident $($Rest: ident)*]) => {
        impl_prompt_handler_for!(@impl $($Tn)*);
        impl_prompt_handler_for!([$($Tn)* $Tn_1] [$($Rest)*]);
    };
    (@impl $($Tn: ident)*) => {
        // Async method: `FnOnce(&S, T0, ..) -> BoxFuture<'_, R>`, keyed by the
        // bare extractor tuple.
        impl<$($Tn,)* S, F, R> GetPromptHandler<S, ($($Tn,)*)> for F
        where
            $(
                $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + Send,
            )*
            F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R> + Send,
            R: IntoGetPromptResult + Send + 'static,
            S: Send + Sync + 'static,
        {
            // Type-parameter names are reused as local bindings (non_snake_case);
            // with zero extractors `context` is never mutated (unused_mut).
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn handle(
                self,
                mut context: PromptContext<'_, S>,
            ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>
            {
                // Run extractors in declaration order; the first failure
                // short-circuits with an immediately-ready error future.
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let service = context.server;
                let fut = self(service, $($Tn,)*);
                async move {
                    let result = fut.await;
                    result.into_get_prompt_result()
                }.boxed()
            }
        }

        // Sync method: `FnOnce(&S, T0, ..) -> R`; the result is wrapped in a
        // ready future.
        impl<$($Tn,)* S, F, R> GetPromptHandler<S, SyncPromptMethodAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + Send,
            )*
            F: FnOnce(&S, $($Tn,)*) -> R + Send,
            R: IntoGetPromptResult + Send,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn handle(
                self,
                mut context: PromptContext<'_, S>,
            ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>
            {
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let service = context.server;
                let result = self(service, $($Tn,)*);
                std::future::ready(result.into_get_prompt_result()).boxed()
            }
        }

        // Free async function (no `&S`): `FnOnce(T0, ..) -> Fut` where `Fut`
        // resolves to `Result<R, ErrorData>`.
        impl<$($Tn,)* S, F, Fut, R> GetPromptHandler<S, AsyncPromptAdapter<($($Tn,)*), Fut, R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + Send + 'static,
            )*
            F: FnOnce($($Tn,)*) -> Fut + Send + 'static,
            Fut: Future<Output = Result<R, crate::ErrorData>> + Send + 'static,
            R: IntoGetPromptResult + Send + 'static,
            S: Send + Sync + 'static,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn handle(
                self,
                mut context: PromptContext<'_, S>,
            ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>
            {
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*

                // The handler is only invoked inside the returned future, so
                // nothing runs until that future is first polled.
                Box::pin(async move {
                    let result = self($($Tn,)*).await?;
                    result.into_get_prompt_result()
                })
            }
        }

        // Free sync function (no `&S`): `FnOnce(T0, ..) -> Result<R, ErrorData>`.
        impl<$($Tn,)* S, F, R> GetPromptHandler<S, SyncPromptAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + Send + 'static,
            )*
            F: FnOnce($($Tn,)*) -> Result<R, crate::ErrorData> + Send + 'static,
            R: IntoGetPromptResult + Send + 'static,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn handle(
                self,
                mut context: PromptContext<'_, S>,
            ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>
            {
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let result = self($($Tn,)*);
                std::future::ready(result.and_then(|r| r.into_get_prompt_result())).boxed()
            }
        }

    };
}

// 16 identifiers => impls for handlers with 0 through 16 extractor arguments.
impl_prompt_handler_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);
322
/// Build the prompt argument list for `T` from its JSON schema.
///
/// Each entry in the schema's top-level `properties` object becomes a
/// [`crate::model::PromptArgument`]:
/// - `name` is the property key;
/// - `description` is copied from the property's `description` string, if any;
/// - `required` is `Some(true)` when the key appears in the top-level `required`
///   array, and `Some(false)` otherwise.
///
/// Returns `None` when the schema has no `properties` object or when that
/// object is empty.
#[cfg(feature = "schemars")]
pub fn cached_arguments_from_schema<T: schemars::JsonSchema + std::any::Any>()
-> Option<Vec<crate::model::PromptArgument>> {
    let schema = super::common::schema_for_type::<T>();

    // Look keys up directly on the schema object. Deep-cloning the whole schema
    // into a `serde_json::Value` just to read two keys is unnecessary.
    let props = schema.get("properties").and_then(|p| p.as_object())?;

    let required = schema
        .get("required")
        .and_then(|r| r.as_array())
        .map(|arr| {
            arr.iter()
                .filter_map(|v| v.as_str())
                .collect::<std::collections::HashSet<_>>()
        })
        .unwrap_or_default();

    let arguments: Vec<crate::model::PromptArgument> = props
        .iter()
        .map(|(name, prop_schema)| crate::model::PromptArgument {
            name: name.clone(),
            title: None,
            description: prop_schema
                .get("description")
                .and_then(|d| d.as_str())
                .map(|s| s.to_string()),
            required: Some(required.contains(name.as_str())),
        })
        .collect();

    if arguments.is_empty() {
        None
    } else {
        Some(arguments)
    }
}