Skip to main content

mcpkit_rs/handler/server/
prompt.rs

1//! Prompt handling infrastructure for MCP servers
2//!
3//! This module provides the core types and traits for implementing prompt handlers
4//! in MCP servers. Prompts allow servers to provide reusable templates for LLM
5//! interactions with customizable arguments.
6
7use std::{future::Future, marker::PhantomData};
8
9use futures::future::{BoxFuture, FutureExt};
10use serde::de::DeserializeOwned;
11
12use super::common::{AsRequestContext, FromContextPart};
13pub use super::common::{Extension, RequestId};
14use crate::{
15    RoleServer,
16    handler::server::wrapper::Parameters,
17    model::{GetPromptResult, PromptMessage},
18    service::RequestContext,
19};
20
21/// Context for prompt retrieval operations
22pub struct PromptContext<'a, S> {
23    pub server: &'a S,
24    pub name: String,
25    pub arguments: Option<serde_json::Map<String, serde_json::Value>>,
26    pub context: RequestContext<RoleServer>,
27}
28
29impl<'a, S> PromptContext<'a, S> {
30    pub fn new(
31        server: &'a S,
32        name: String,
33        arguments: Option<serde_json::Map<String, serde_json::Value>>,
34        context: RequestContext<RoleServer>,
35    ) -> Self {
36        Self {
37            server,
38            name,
39            arguments,
40            context,
41        }
42    }
43}
44
45impl<S> AsRequestContext for PromptContext<'_, S> {
46    fn as_request_context(&self) -> &RequestContext<RoleServer> {
47        &self.context
48    }
49
50    fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer> {
51        &mut self.context
52    }
53}
54
55/// Trait for handling prompt retrieval
56pub trait GetPromptHandler<S, A> {
57    fn handle(
58        self,
59        context: PromptContext<'_, S>,
60    ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>;
61}
62
63/// Type alias for dynamic prompt handlers
64pub type DynGetPromptHandler<S> = dyn for<'a> Fn(PromptContext<'a, S>) -> BoxFuture<'a, Result<GetPromptResult, crate::ErrorData>>
65    + Send
66    + Sync;
67
/// Adapter type for async methods that return `Vec<PromptMessage>`.
///
/// NOTE(review): no `GetPromptHandler` impl in this file names this marker; presumably it
/// is used by generated code elsewhere. Verify before changing it.
pub struct AsyncMethodAdapter<T>(PhantomData<T>);

/// Adapter type for async methods with parameters that return `Vec<PromptMessage>`.
///
/// NOTE(review): no `GetPromptHandler` impl in this file names this marker; presumably it
/// is used by generated code elsewhere. Verify before changing it.
pub struct AsyncMethodWithArgsAdapter<T>(PhantomData<T>);

/// Marker selecting the `GetPromptHandler` impl for standalone async functions, i.e.
/// `FnOnce(P...) -> Fut` that do not take the server `&S`, where `Fut` resolves to
/// `Result<R, ErrorData>`.
///
/// `P` is the tuple of extractor types, `Fut` the returned future, and `R` its success value.
// The `fn(P) -> fn(Fut) -> R` pointer only mentions all three type parameters without
// storing them (fn pointers are `Send + Sync` whatever their argument types are). The
// spelling is long enough to trip `clippy::type_complexity`.
#[allow(clippy::type_complexity)]
pub struct AsyncPromptAdapter<P, Fut, R>(PhantomData<fn(P) -> fn(Fut) -> R>);

/// Marker selecting the `GetPromptHandler` impl for standalone sync functions
/// `FnOnce(P...) -> Result<R, ErrorData>` that do not take the server `&S`.
pub struct SyncPromptAdapter<P, R>(PhantomData<fn(P) -> R>);

/// Marker for async method-style handlers that take `&S`.
///
/// NOTE(review): no `GetPromptHandler` impl in this file names this marker (async methods
/// are keyed on the bare extractor tuple instead). Confirm whether it is still used.
pub struct AsyncPromptMethodAdapter<P, R>(PhantomData<fn(P) -> R>);

/// Marker selecting the `GetPromptHandler` impl for sync method-style handlers
/// `FnOnce(&S, P...) -> R`.
pub struct SyncPromptMethodAdapter<P, R>(PhantomData<fn(P) -> R>);
80
81/// Trait for types that can be converted into GetPromptResult
82pub trait IntoGetPromptResult {
83    fn into_get_prompt_result(self) -> Result<GetPromptResult, crate::ErrorData>;
84}
85
86impl IntoGetPromptResult for GetPromptResult {
87    fn into_get_prompt_result(self) -> Result<GetPromptResult, crate::ErrorData> {
88        Ok(self)
89    }
90}
91
92impl IntoGetPromptResult for Vec<PromptMessage> {
93    fn into_get_prompt_result(self) -> Result<GetPromptResult, crate::ErrorData> {
94        Ok(GetPromptResult {
95            description: None,
96            messages: self,
97        })
98    }
99}
100
101impl<T: IntoGetPromptResult> IntoGetPromptResult for Result<T, crate::ErrorData> {
102    fn into_get_prompt_result(self) -> Result<GetPromptResult, crate::ErrorData> {
103        self.and_then(|v| v.into_get_prompt_result())
104    }
105}
106
// Future wrapper that automatically handles IntoGetPromptResult conversion
pin_project_lite::pin_project! {
    /// Future that resolves to `Result<GetPromptResult, ErrorData>`.
    ///
    /// `Pending` drives an inner future and converts its output with
    /// [`IntoGetPromptResult`]. `Ready` holds a result that is already known.
    #[project = IntoGetPromptResultFutProj]
    pub enum IntoGetPromptResultFut<F, R> {
        /// Wraps an inner future whose output is converted when it completes.
        Pending {
            #[pin]
            fut: F,
            // Only ties `R` to the type. The `Future` impl requires `F: Future<Output = R>`.
            _marker: PhantomData<R>,
        },
        /// Holds an already-available result; polling it completes immediately.
        Ready {
            #[pin]
            result: futures::future::Ready<Result<GetPromptResult, crate::ErrorData>>,
        }
    }
}
122
123impl<F, R> Future for IntoGetPromptResultFut<F, R>
124where
125    F: Future<Output = R>,
126    R: IntoGetPromptResult,
127{
128    type Output = Result<GetPromptResult, crate::ErrorData>;
129
130    fn poll(
131        self: std::pin::Pin<&mut Self>,
132        cx: &mut std::task::Context<'_>,
133    ) -> std::task::Poll<Self::Output> {
134        match self.project() {
135            IntoGetPromptResultFutProj::Pending { fut, _marker } => fut
136                .poll(cx)
137                .map(IntoGetPromptResult::into_get_prompt_result),
138            IntoGetPromptResultFutProj::Ready { result } => result.poll(cx),
139        }
140    }
141}
142
143// Prompt-specific extractor for prompt name
144pub struct PromptName(pub String);
145
146impl<S> FromContextPart<PromptContext<'_, S>> for PromptName {
147    fn from_context_part(context: &mut PromptContext<S>) -> Result<Self, crate::ErrorData> {
148        Ok(Self(context.name.clone()))
149    }
150}
151
152// Special implementation for Parameters that handles prompt arguments
153impl<S, P> FromContextPart<PromptContext<'_, S>> for Parameters<P>
154where
155    P: DeserializeOwned,
156{
157    fn from_context_part(context: &mut PromptContext<S>) -> Result<Self, crate::ErrorData> {
158        let params = if let Some(args_map) = context.arguments.take() {
159            let args_value = serde_json::Value::Object(args_map);
160            serde_json::from_value::<P>(args_value).map_err(|e| {
161                crate::ErrorData::invalid_params(format!("Failed to parse parameters: {}", e), None)
162            })?
163        } else {
164            // Try to deserialize from empty object for optional fields
165            serde_json::from_value::<P>(serde_json::json!({})).map_err(|e| {
166                crate::ErrorData::invalid_params(
167                    format!("Missing required parameters: {}", e),
168                    None,
169                )
170            })?
171        };
172        Ok(Parameters(params))
173    }
174}
175
// Macro to generate GetPromptHandler implementations for various parameter combinations.
//
// It generates one set of impls per extractor arity, from zero up to the number of type
// identifiers it is given. Expansion works on an accumulator `[...]` (identifiers already
// used) and a remaining list `[...]`:
// - the entry arm starts with an empty accumulator and the full identifier list;
// - the step arm emits `@impl` for the current accumulator, then moves the head of the
//   remaining list into the accumulator and recurses;
// - the "finished" arm emits `@impl` for the full list once nothing remains.
// For example, `impl_prompt_handler_for!(T0 T1)` emits `@impl` for `()`, `(T0)` and `(T0 T1)`.
//
// Each `@impl` expansion produces four impls that differ in the marker type `A` of
// `GetPromptHandler<S, A>`: the bare extractor tuple, `SyncPromptMethodAdapter`,
// `AsyncPromptAdapter` and `SyncPromptAdapter`. This keeps them distinct impls even when
// they apply to the same handler type `F`.
macro_rules! impl_prompt_handler_for {
    // Entry point: seed an empty accumulator with the full identifier list.
    ($($T: ident)*) => {
        impl_prompt_handler_for!([] [$($T)*]);
    };
    // finished: no identifiers remain, so emit the impls for the full list.
    ([$($Tn: ident)*] []) => {
        impl_prompt_handler_for!(@impl $($Tn)*);
    };
    // Step: emit the impls for the current prefix, then grow the prefix by one identifier.
    ([$($Tn: ident)*] [$Tn_1: ident $($Rest: ident)*]) => {
        impl_prompt_handler_for!(@impl $($Tn)*);
        impl_prompt_handler_for!([$($Tn)* $Tn_1] [$($Rest)*]);
    };
    // Emits the four handler impls for a single arity. Every `handle` follows the same
    // pattern. Each `$Tn` is extracted from `&mut context` in declaration order, and the
    // binding reuses the type parameter's name (hence `non_snake_case`). The first
    // extractor error is returned as an already-resolved future, and the handler is
    // never called. At arity 0 nothing borrows `context` mutably, and the standalone
    // variants never read it at all (hence `unused_mut` / `unused_variables`).
    (@impl $($Tn: ident)*) => {
        // Implementation for async methods (transformed by #[prompt] macro)
        // Handlers take `&S` plus the extractors and return a `BoxFuture`. The elided
        // lifetime in `BoxFuture<'_, R>` ties the future to the `&S` argument, so the
        // returned future may borrow the server for the lifetime of `context.server`.
        // The marker type here is the bare extractor tuple.
        impl<$($Tn,)* S, F, R> GetPromptHandler<S, ($($Tn,)*)> for F
        where
            $(
                $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + Send,
            )*
            F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R> + Send,
            R: IntoGetPromptResult + Send + 'static,
            S: Send + Sync + 'static,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn handle(
                self,
                mut context: PromptContext<'_, S>,
            ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>
            {
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                // Copy the `&S` out of the context; the handler's future borrows it.
                let service = context.server;
                let fut = self(service, $($Tn,)*);
                async move {
                    let result = fut.await;
                    result.into_get_prompt_result()
                }.boxed()
            }
        }


        // Implementation for sync methods
        // Handlers take `&S` plus the extractors and return `R` directly. Note that
        // `self` is invoked inside `handle`, before the returned (already-resolved) future
        // is ever polled.
        impl<$($Tn,)* S, F, R> GetPromptHandler<S, SyncPromptMethodAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + Send,
            )*
            F: FnOnce(&S, $($Tn,)*) -> R + Send,
            R: IntoGetPromptResult + Send,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn handle(
                self,
                mut context: PromptContext<'_, S>,
            ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>
            {
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let service = context.server;
                let result = self(service, $($Tn,)*);
                std::future::ready(result.into_get_prompt_result()).boxed()
            }
        }


        // AsyncPromptAdapter - for standalone async functions (no `&S` argument) whose
        // future resolves to `Result<R, ErrorData>`, where `R: IntoGetPromptResult`
        // (not necessarily a `GetPromptResult` itself).
        impl<$($Tn,)* S, F, Fut, R> GetPromptHandler<S, AsyncPromptAdapter<($($Tn,)*), Fut, R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + Send + 'static,
            )*
            F: FnOnce($($Tn,)*) -> Fut + Send + 'static,
            Fut: Future<Output = Result<R, crate::ErrorData>> + Send + 'static,
            R: IntoGetPromptResult + Send + 'static,
            S: Send + Sync + 'static,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn handle(
                self,
                mut context: PromptContext<'_, S>,
            ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>
            {
                // Extract all parameters before moving into the async block
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*

                // Since we're dealing with standalone functions that don't take &S,
                // we can return a 'static future
                // (the async block owns `self` and the extracted values and never borrows
                // `context`; an `Err` from the handler's future short-circuits via `?`).
                Box::pin(async move {
                    let result = self($($Tn,)*).await?;
                    result.into_get_prompt_result()
                })
            }
        }


        // SyncPromptAdapter - for standalone sync functions returning Result
        // The function runs inside `handle`. An `Err` is passed through unchanged, and an
        // `Ok` value is converted with `IntoGetPromptResult`.
        impl<$($Tn,)* S, F, R> GetPromptHandler<S, SyncPromptAdapter<($($Tn,)*), R>> for F
        where
            $(
                $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + Send + 'static,
            )*
            F: FnOnce($($Tn,)*) -> Result<R, crate::ErrorData> + Send + 'static,
            R: IntoGetPromptResult + Send + 'static,
            S: Send + Sync,
        {
            #[allow(unused_variables, non_snake_case, unused_mut)]
            fn handle(
                self,
                mut context: PromptContext<'_, S>,
            ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>
            {
                $(
                    let result = $Tn::from_context_part(&mut context);
                    let $Tn = match result {
                        Ok(value) => value,
                        Err(e) => return std::future::ready(Err(e)).boxed(),
                    };
                )*
                let result = self($($Tn,)*);
                std::future::ready(result.and_then(|r| r.into_get_prompt_result())).boxed()
            }
        }

    };
}

// Invoke the macro to generate implementations for handlers with 0 through 16 extractor
// parameters (the method-style variants additionally take `&S` first).
impl_prompt_handler_for!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);
322
/// Builds the prompt argument list for a prompt whose parameters are described by `T`.
///
/// Every entry in the top-level `"properties"` object of `T`'s JSON schema becomes one
/// [`PromptArgument`](crate::model::PromptArgument):
/// - `name`: the property key;
/// - `description`: the property's `"description"` string, if it has one;
/// - `required`: `Some(true)` if the key is listed in the schema's top-level `"required"`
///   array, otherwise `Some(false)`;
/// - `title`: always `None`.
///
/// Returns `None` when the schema has no `"properties"` object or that object is empty.
/// Only the top level is inspected; nested constructs such as `$ref` or `allOf` are not
/// followed.
///
/// The schema comes from `super::common::schema_for_type`, and any caching happens there
/// (presumably what the function name refers to). The argument list itself is rebuilt
/// on every call.
#[cfg(feature = "schemars")]
pub fn cached_arguments_from_schema<T: schemars::JsonSchema + std::any::Any>()
-> Option<Vec<crate::model::PromptArgument>> {
    let schema = super::common::schema_for_type::<T>();

    // Borrow straight from the shared schema map. Deep-cloning it into a
    // `serde_json::Value` just to read two keys is unnecessary.
    let props = schema.get("properties")?.as_object()?;

    let required: std::collections::HashSet<&str> = schema
        .get("required")
        .and_then(|r| r.as_array())
        .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
        .unwrap_or_default();

    let arguments: Vec<crate::model::PromptArgument> = props
        .iter()
        .map(|(name, prop_schema)| crate::model::PromptArgument {
            name: name.clone(),
            title: None,
            description: prop_schema
                .get("description")
                .and_then(|d| d.as_str())
                .map(str::to_string),
            required: Some(required.contains(name.as_str())),
        })
        .collect();

    // An empty property list is reported as "no arguments" rather than `Some(vec![])`.
    (!arguments.is_empty()).then_some(arguments)
}