Skip to main content

tower_mcp/
prompt.rs

1//! Prompt definition and builder API
2//!
3//! Provides ergonomic ways to define MCP prompts:
4//!
5//! 1. **Builder pattern** - Fluent API for defining prompts
6//! 2. **Trait-based** - Implement `McpPrompt` for full control
7//! 3. **Per-prompt middleware** - Apply tower middleware layers to individual prompts
8//!
9//! # Per-Prompt Middleware
10//!
11//! The `.layer()` method on `PromptBuilder` (after `.handler()`) allows applying
12//! tower middleware to a single prompt. This is useful for prompt-specific concerns
13//! like timeouts, rate limiting, or caching.
14//!
15//! ```rust
16//! use std::collections::HashMap;
17//! use std::time::Duration;
18//! use tower::timeout::TimeoutLayer;
19//! use tower_mcp::prompt::PromptBuilder;
20//! use tower_mcp::protocol::{GetPromptResult, PromptMessage, PromptRole, Content};
21//!
22//! let prompt = PromptBuilder::new("slow_prompt")
23//!     .description("A prompt that might take a while")
24//!     .handler(|args: HashMap<String, String>| async move {
25//!         // Slow prompt generation logic...
26//!         Ok(GetPromptResult {
27//!             description: Some("Generated prompt".to_string()),
28//!             messages: vec![PromptMessage {
29//!                 role: PromptRole::User,
30//!                 content: Content::Text {
31//!                     text: "Hello!".to_string(),
32//!                     annotations: None,
33//!                     meta: None,
34//!                 },
35//!                 meta: None,
36//!             }],
37//!             meta: None,
38//!         })
39//!     })
40//!     .layer(TimeoutLayer::new(Duration::from_secs(5)));
41//!
42//! assert_eq!(prompt.name, "slow_prompt");
43//! ```
44
45use std::collections::HashMap;
46use std::convert::Infallible;
47use std::fmt;
48use std::future::Future;
49use std::pin::Pin;
50use std::sync::Arc;
51use std::task::{Context, Poll};
52
53use pin_project_lite::pin_project;
54
55use tokio::sync::Mutex;
56use tower::util::BoxCloneService;
57use tower::{Layer, ServiceExt};
58use tower_service::Service;
59
60use crate::context::RequestContext;
61use crate::error::{Error, Result};
62use crate::protocol::{
63    Content, GetPromptResult, PromptArgument, PromptDefinition, PromptMessage, PromptRole,
64    RequestId, ToolIcon,
65};
66
67/// A boxed future for prompt handlers
68pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
69
70// =============================================================================
71// Per-Prompt Middleware Types
72// =============================================================================
73
74/// Request type for prompt middleware.
75///
76/// Contains the request context and prompt arguments, allowing middleware
77/// to access and modify the request before it reaches the prompt handler.
78#[derive(Debug, Clone)]
79pub struct PromptRequest {
80    /// The request context with progress reporting, cancellation, etc.
81    pub context: RequestContext,
82    /// The prompt arguments (name -> value)
83    pub arguments: HashMap<String, String>,
84}
85
86impl PromptRequest {
87    /// Create a new prompt request with the given context and arguments.
88    pub fn new(context: RequestContext, arguments: HashMap<String, String>) -> Self {
89        Self { context, arguments }
90    }
91
92    /// Create a prompt request with a default context (for testing or simple use cases).
93    pub fn with_arguments(arguments: HashMap<String, String>) -> Self {
94        Self {
95            context: RequestContext::new(RequestId::Number(0)),
96            arguments,
97        }
98    }
99}
100
101/// A boxed, cloneable prompt service with `Error = Infallible`.
102///
103/// This is the service type used internally after applying middleware layers.
104/// It wraps any `Service<PromptRequest>` implementation so that the prompt
105/// handler can consume it without knowing the concrete middleware stack.
106pub type BoxPromptService = BoxCloneService<PromptRequest, GetPromptResult, Infallible>;
107
108/// A service wrapper that catches errors from middleware and converts them
109/// into prompt errors, maintaining the `Error = Infallible` contract.
110///
111/// When a middleware layer (e.g., `TimeoutLayer`) produces an error, this
112/// wrapper converts it into a prompt error. This allows error information to
113/// flow through the normal response path rather than requiring special
114/// error handling.
115#[doc(hidden)]
116pub struct PromptCatchError<S> {
117    inner: S,
118}
119
120impl<S> PromptCatchError<S> {
121    /// Create a new `PromptCatchError` wrapping the given service.
122    pub fn new(inner: S) -> Self {
123        Self { inner }
124    }
125}
126
127impl<S: Clone> Clone for PromptCatchError<S> {
128    fn clone(&self) -> Self {
129        Self {
130            inner: self.inner.clone(),
131        }
132    }
133}
134
135impl<S: fmt::Debug> fmt::Debug for PromptCatchError<S> {
136    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137        f.debug_struct("PromptCatchError")
138            .field("inner", &self.inner)
139            .finish()
140    }
141}
142
143pin_project! {
144    /// Future for [`PromptCatchError`].
145    #[doc(hidden)]
146    pub struct PromptCatchErrorFuture<F> {
147        #[pin]
148        inner: F,
149    }
150}
151
152impl<F, E> Future for PromptCatchErrorFuture<F>
153where
154    F: Future<Output = std::result::Result<GetPromptResult, E>>,
155    E: fmt::Display,
156{
157    type Output = std::result::Result<GetPromptResult, Infallible>;
158
159    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
160        match self.project().inner.poll(cx) {
161            Poll::Pending => Poll::Pending,
162            Poll::Ready(Ok(response)) => Poll::Ready(Ok(response)),
163            Poll::Ready(Err(err)) => Poll::Ready(Ok(GetPromptResult {
164                description: Some(format!("Prompt error: {}", err)),
165                messages: vec![PromptMessage {
166                    role: PromptRole::Assistant,
167                    content: Content::Text {
168                        text: format!("Error generating prompt: {}", err),
169                        annotations: None,
170                        meta: None,
171                    },
172                    meta: None,
173                }],
174                meta: None,
175            })),
176        }
177    }
178}
179
180impl<S> Service<PromptRequest> for PromptCatchError<S>
181where
182    S: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
183    S::Error: fmt::Display + Send,
184    S::Future: Send,
185{
186    type Response = GetPromptResult;
187    type Error = Infallible;
188    type Future = PromptCatchErrorFuture<S::Future>;
189
190    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
191        self.inner.poll_ready(cx).map_err(|_| unreachable!())
192    }
193
194    fn call(&mut self, req: PromptRequest) -> Self::Future {
195        PromptCatchErrorFuture {
196            inner: self.inner.call(req),
197        }
198    }
199}
200
201/// Adapts a prompt handler function into a `Service<PromptRequest>`.
202///
203/// This allows the handler to be wrapped with tower middleware layers.
204/// Used by `.layer()` on `PromptBuilderWithHandler`.
205#[doc(hidden)]
206pub struct PromptHandlerService<F> {
207    handler: F,
208}
209
210impl<F> Clone for PromptHandlerService<F>
211where
212    F: Clone,
213{
214    fn clone(&self) -> Self {
215        Self {
216            handler: self.handler.clone(),
217        }
218    }
219}
220
221impl<F, Fut> Service<PromptRequest> for PromptHandlerService<F>
222where
223    F: Fn(HashMap<String, String>) -> Fut + Clone + Send + Sync + 'static,
224    Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
225{
226    type Response = GetPromptResult;
227    type Error = Error;
228    type Future = Pin<Box<dyn Future<Output = std::result::Result<GetPromptResult, Error>> + Send>>;
229
230    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
231        Poll::Ready(Ok(()))
232    }
233
234    fn call(&mut self, req: PromptRequest) -> Self::Future {
235        let handler = self.handler.clone();
236        Box::pin(async move { handler(req.arguments).await })
237    }
238}
239
240/// Adapts a context-aware prompt handler function into a `Service<PromptRequest>`.
241///
242/// Used by `.layer()` on `PromptBuilderWithContextHandler`.
243#[doc(hidden)]
244pub struct PromptContextHandlerService<F> {
245    handler: F,
246}
247
248impl<F> Clone for PromptContextHandlerService<F>
249where
250    F: Clone,
251{
252    fn clone(&self) -> Self {
253        Self {
254            handler: self.handler.clone(),
255        }
256    }
257}
258
259impl<F, Fut> Service<PromptRequest> for PromptContextHandlerService<F>
260where
261    F: Fn(RequestContext, HashMap<String, String>) -> Fut + Clone + Send + Sync + 'static,
262    Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
263{
264    type Response = GetPromptResult;
265    type Error = Error;
266    type Future = Pin<Box<dyn Future<Output = std::result::Result<GetPromptResult, Error>> + Send>>;
267
268    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
269        Poll::Ready(Ok(()))
270    }
271
272    fn call(&mut self, req: PromptRequest) -> Self::Future {
273        let handler = self.handler.clone();
274        Box::pin(async move { handler(req.context, req.arguments).await })
275    }
276}
277
278/// Prompt handler trait - the core abstraction for prompt generation
279pub trait PromptHandler: Send + Sync {
280    /// Get the prompt with the given arguments
281    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>>;
282
283    /// Get the prompt with request context
284    ///
285    /// The default implementation ignores the context and calls `get`.
286    /// Override this to receive context for progress reporting, cancellation, etc.
287    fn get_with_context(
288        &self,
289        _ctx: RequestContext,
290        arguments: HashMap<String, String>,
291    ) -> BoxFuture<'_, Result<GetPromptResult>> {
292        self.get(arguments)
293    }
294
295    /// Returns true if this handler uses context (for optimization)
296    fn uses_context(&self) -> bool {
297        false
298    }
299}
300
301/// A complete prompt definition with handler
302pub struct Prompt {
303    /// The prompt name (must be unique within the router).
304    pub name: String,
305    /// Optional human-readable title.
306    pub title: Option<String>,
307    /// Optional description of the prompt.
308    pub description: Option<String>,
309    /// Optional icons for the prompt.
310    pub icons: Option<Vec<ToolIcon>>,
311    /// The arguments this prompt accepts.
312    pub arguments: Vec<PromptArgument>,
313    handler: Arc<dyn PromptHandler>,
314}
315
316impl Clone for Prompt {
317    fn clone(&self) -> Self {
318        Self {
319            name: self.name.clone(),
320            title: self.title.clone(),
321            description: self.description.clone(),
322            icons: self.icons.clone(),
323            arguments: self.arguments.clone(),
324            handler: self.handler.clone(),
325        }
326    }
327}
328
329impl std::fmt::Debug for Prompt {
330    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331        f.debug_struct("Prompt")
332            .field("name", &self.name)
333            .field("title", &self.title)
334            .field("description", &self.description)
335            .field("icons", &self.icons)
336            .field("arguments", &self.arguments)
337            .finish_non_exhaustive()
338    }
339}
340
341impl Prompt {
342    /// Create a new prompt builder
343    pub fn builder(name: impl Into<String>) -> PromptBuilder {
344        PromptBuilder::new(name)
345    }
346
347    /// Get the prompt definition for prompts/list
348    pub fn definition(&self) -> PromptDefinition {
349        PromptDefinition {
350            name: self.name.clone(),
351            title: self.title.clone(),
352            description: self.description.clone(),
353            icons: self.icons.clone(),
354            arguments: self.arguments.clone(),
355            meta: None,
356        }
357    }
358
359    /// Get the prompt with arguments
360    pub fn get(
361        &self,
362        arguments: HashMap<String, String>,
363    ) -> BoxFuture<'_, Result<GetPromptResult>> {
364        self.handler.get(arguments)
365    }
366
367    /// Get the prompt with request context
368    ///
369    /// Use this when you have a RequestContext available for progress/cancellation.
370    pub fn get_with_context(
371        &self,
372        ctx: RequestContext,
373        arguments: HashMap<String, String>,
374    ) -> BoxFuture<'_, Result<GetPromptResult>> {
375        self.handler.get_with_context(ctx, arguments)
376    }
377
378    /// Returns true if this prompt uses context
379    pub fn uses_context(&self) -> bool {
380        self.handler.uses_context()
381    }
382}
383
384// =============================================================================
385// Builder API
386// =============================================================================
387
388/// Builder for creating prompts with a fluent API
389///
390/// # Example
391///
392/// ```rust
393/// use tower_mcp::prompt::PromptBuilder;
394/// use tower_mcp::protocol::{GetPromptResult, PromptMessage, PromptRole, Content};
395///
396/// let prompt = PromptBuilder::new("greet")
397///     .description("Generate a greeting")
398///     .required_arg("name", "The name to greet")
399///     .handler(|args| async move {
400///         let name = args.get("name").map(|s| s.as_str()).unwrap_or("World");
401///         Ok(GetPromptResult {
402///             description: Some("A greeting prompt".to_string()),
403///             messages: vec![PromptMessage {
404///                 role: PromptRole::User,
405///                 content: Content::Text {
406///                     text: format!("Please greet {}", name),
407///                     annotations: None,
408///                     meta: None,
409///                 },
410///                 meta: None,
411///             }],
412///             meta: None,
413///         })
414///     })
415///     .build();
416///
417/// assert_eq!(prompt.name, "greet");
418/// ```
419pub struct PromptBuilder {
420    name: String,
421    title: Option<String>,
422    description: Option<String>,
423    icons: Option<Vec<ToolIcon>>,
424    arguments: Vec<PromptArgument>,
425}
426
427impl PromptBuilder {
428    /// Create a new prompt builder with the given name.
429    pub fn new(name: impl Into<String>) -> Self {
430        Self {
431            name: name.into(),
432            title: None,
433            description: None,
434            icons: None,
435            arguments: Vec::new(),
436        }
437    }
438
439    /// Set a human-readable title for the prompt
440    pub fn title(mut self, title: impl Into<String>) -> Self {
441        self.title = Some(title.into());
442        self
443    }
444
445    /// Set the prompt description
446    pub fn description(mut self, description: impl Into<String>) -> Self {
447        self.description = Some(description.into());
448        self
449    }
450
451    /// Add an icon for the prompt
452    pub fn icon(mut self, src: impl Into<String>) -> Self {
453        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
454            src: src.into(),
455            mime_type: None,
456            sizes: None,
457            theme: None,
458        });
459        self
460    }
461
462    /// Add an icon with metadata
463    pub fn icon_with_meta(
464        mut self,
465        src: impl Into<String>,
466        mime_type: Option<String>,
467        sizes: Option<Vec<String>>,
468    ) -> Self {
469        self.icons.get_or_insert_with(Vec::new).push(ToolIcon {
470            src: src.into(),
471            mime_type,
472            sizes,
473            theme: None,
474        });
475        self
476    }
477
478    /// Add a required argument
479    pub fn required_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
480        self.arguments.push(PromptArgument {
481            name: name.into(),
482            description: Some(description.into()),
483            required: true,
484        });
485        self
486    }
487
488    /// Add an optional argument
489    pub fn optional_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
490        self.arguments.push(PromptArgument {
491            name: name.into(),
492            description: Some(description.into()),
493            required: false,
494        });
495        self
496    }
497
498    /// Add an argument with full control
499    pub fn argument(mut self, arg: PromptArgument) -> Self {
500        self.arguments.push(arg);
501        self
502    }
503
504    /// Set the handler function for getting the prompt.
505    ///
506    /// Returns a `PromptBuilderWithHandler` which can be finalized with `.build()`
507    /// or have middleware applied with `.layer()`.
508    ///
509    /// # Sharing State
510    ///
511    /// Capture an [`Arc`] in the closure to share state across handler
512    /// invocations or with other parts of your application:
513    ///
514    /// ```rust
515    /// use std::collections::HashMap;
516    /// use std::sync::Arc;
517    /// use tokio::sync::RwLock;
518    /// use tower_mcp::prompt::PromptBuilder;
519    /// use tower_mcp::protocol::{GetPromptResult, PromptMessage, PromptRole, Content};
520    ///
521    /// let templates = Arc::new(RwLock::new(HashMap::from([
522    ///     ("greeting".to_string(), "Hello, {name}!".to_string()),
523    /// ])));
524    ///
525    /// let tpl = Arc::clone(&templates);
526    /// let prompt = PromptBuilder::new("greet")
527    ///     .description("Greet a user by name")
528    ///     .required_arg("name", "The user's name")
529    ///     .handler(move |args: HashMap<String, String>| {
530    ///         let tpl = Arc::clone(&tpl);
531    ///         async move {
532    ///             let templates = tpl.read().await;
533    ///             let greeting = templates.get("greeting").unwrap();
534    ///             let name = args.get("name").unwrap();
535    ///             let text = greeting.replace("{name}", name);
536    ///             Ok(GetPromptResult {
537    ///                 description: Some("A greeting".to_string()),
538    ///                 messages: vec![PromptMessage {
539    ///                     role: PromptRole::User,
540    ///                     content: Content::text(text),
541    ///                     meta: None,
542    ///                 }],
543    ///                 meta: None,
544    ///             })
545    ///         }
546    ///     })
547    ///     .build();
548    /// ```
549    ///
550    /// [`Arc`]: std::sync::Arc
551    pub fn handler<F, Fut>(self, handler: F) -> PromptBuilderWithHandler<F>
552    where
553        F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
554        Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
555    {
556        PromptBuilderWithHandler {
557            name: self.name,
558            title: self.title,
559            description: self.description,
560            icons: self.icons,
561            arguments: self.arguments,
562            handler,
563        }
564    }
565
566    /// Set a context-aware handler function for getting the prompt
567    ///
568    /// The handler receives a `RequestContext` for progress reporting and
569    /// cancellation checking, along with the prompt arguments.
570    pub fn handler_with_context<F, Fut>(self, handler: F) -> PromptBuilderWithContextHandler<F>
571    where
572        F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
573        Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
574    {
575        PromptBuilderWithContextHandler {
576            name: self.name,
577            title: self.title,
578            description: self.description,
579            icons: self.icons,
580            arguments: self.arguments,
581            handler,
582        }
583    }
584
585    /// Create a static prompt (no arguments needed)
586    pub fn static_prompt(self, messages: Vec<PromptMessage>) -> Prompt {
587        let description = self.description.clone();
588        self.handler(move |_| {
589            let messages = messages.clone();
590            let description = description.clone();
591            async move {
592                Ok(GetPromptResult {
593                    description,
594                    messages,
595                    meta: None,
596                })
597            }
598        })
599        .build()
600    }
601
602    /// Create a simple text prompt with a user message
603    pub fn user_message(self, text: impl Into<String>) -> Prompt {
604        let text = text.into();
605        self.static_prompt(vec![PromptMessage {
606            role: PromptRole::User,
607            content: Content::Text {
608                text,
609                annotations: None,
610                meta: None,
611            },
612            meta: None,
613        }])
614    }
615
616    /// Finalize the builder into a Prompt
617    ///
618    /// This is an alias for `handler(...).build()` for when you want to
619    /// explicitly mark the build step.
620    pub fn build<F, Fut>(self, handler: F) -> Prompt
621    where
622        F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
623        Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
624    {
625        self.handler(handler).build()
626    }
627}
628
629/// Builder state after handler is specified
630///
631/// This allows either calling `.build()` to create the prompt directly,
632/// or `.layer()` to apply middleware before building.
633#[doc(hidden)]
634pub struct PromptBuilderWithHandler<F> {
635    name: String,
636    title: Option<String>,
637    description: Option<String>,
638    icons: Option<Vec<ToolIcon>>,
639    arguments: Vec<PromptArgument>,
640    handler: F,
641}
642
643impl<F, Fut> PromptBuilderWithHandler<F>
644where
645    F: Fn(HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
646    Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
647{
648    /// Build the prompt without any middleware
649    pub fn build(self) -> Prompt {
650        Prompt {
651            name: self.name,
652            title: self.title,
653            description: self.description,
654            icons: self.icons,
655            arguments: self.arguments,
656            handler: Arc::new(FnHandler {
657                handler: self.handler,
658            }),
659        }
660    }
661
662    /// Apply a tower middleware layer to this prompt
663    ///
664    /// The layer wraps the prompt handler, allowing middleware like timeouts,
665    /// rate limiting, or retries to be applied to this specific prompt.
666    ///
667    /// # Example
668    ///
669    /// ```rust
670    /// use std::collections::HashMap;
671    /// use std::time::Duration;
672    /// use tower::timeout::TimeoutLayer;
673    /// use tower_mcp::prompt::PromptBuilder;
674    /// use tower_mcp::protocol::{GetPromptResult, PromptMessage, PromptRole, Content};
675    ///
676    /// let prompt = PromptBuilder::new("slow_prompt")
677    ///     .description("A prompt that might take a while")
678    ///     .handler(|_args: HashMap<String, String>| async move {
679    ///         Ok(GetPromptResult {
680    ///             description: Some("Generated prompt".to_string()),
681    ///             messages: vec![PromptMessage {
682    ///                 role: PromptRole::User,
683    ///                 content: Content::Text {
684    ///                     text: "Hello!".to_string(),
685    ///                     annotations: None,
686    ///                     meta: None,
687    ///                 },
688    ///                 meta: None,
689    ///             }],
690    ///             meta: None,
691    ///         })
692    ///     })
693    ///     .layer(TimeoutLayer::new(Duration::from_secs(5)));
694    /// ```
695    pub fn layer<L>(self, layer: L) -> Prompt
696    where
697        L: Layer<PromptHandlerService<F>> + Send + Sync + 'static,
698        L::Service: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
699        <L::Service as Service<PromptRequest>>::Error: fmt::Display + Send,
700        <L::Service as Service<PromptRequest>>::Future: Send,
701    {
702        let service = PromptHandlerService {
703            handler: self.handler,
704        };
705        let wrapped = layer.layer(service);
706        let boxed = BoxCloneService::new(PromptCatchError::new(wrapped));
707
708        Prompt {
709            name: self.name,
710            title: self.title,
711            description: self.description,
712            icons: self.icons,
713            arguments: self.arguments,
714            handler: Arc::new(ServiceHandler {
715                service: Mutex::new(boxed),
716            }),
717        }
718    }
719}
720
721/// Builder state after context-aware handler is specified
722#[doc(hidden)]
723pub struct PromptBuilderWithContextHandler<F> {
724    name: String,
725    title: Option<String>,
726    description: Option<String>,
727    icons: Option<Vec<ToolIcon>>,
728    arguments: Vec<PromptArgument>,
729    handler: F,
730}
731
732impl<F, Fut> PromptBuilderWithContextHandler<F>
733where
734    F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + Clone + 'static,
735    Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
736{
737    /// Build the prompt without any middleware
738    pub fn build(self) -> Prompt {
739        Prompt {
740            name: self.name,
741            title: self.title,
742            description: self.description,
743            icons: self.icons,
744            arguments: self.arguments,
745            handler: Arc::new(ContextAwareHandler {
746                handler: self.handler,
747            }),
748        }
749    }
750
751    /// Apply a tower middleware layer to this prompt
752    pub fn layer<L>(self, layer: L) -> Prompt
753    where
754        L: Layer<PromptContextHandlerService<F>> + Send + Sync + 'static,
755        L::Service: Service<PromptRequest, Response = GetPromptResult> + Clone + Send + 'static,
756        <L::Service as Service<PromptRequest>>::Error: fmt::Display + Send,
757        <L::Service as Service<PromptRequest>>::Future: Send,
758    {
759        let service = PromptContextHandlerService {
760            handler: self.handler,
761        };
762        let wrapped = layer.layer(service);
763        let boxed = BoxCloneService::new(PromptCatchError::new(wrapped));
764
765        Prompt {
766            name: self.name,
767            title: self.title,
768            description: self.description,
769            icons: self.icons,
770            arguments: self.arguments,
771            handler: Arc::new(ServiceContextHandler {
772                service: Mutex::new(boxed),
773            }),
774        }
775    }
776}
777
778// =============================================================================
779// Handler implementations
780// =============================================================================
781
782/// Handler wrapping a function
783struct FnHandler<F> {
784    handler: F,
785}
786
787impl<F, Fut> PromptHandler for FnHandler<F>
788where
789    F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
790    Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
791{
792    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
793        Box::pin((self.handler)(arguments))
794    }
795}
796
797/// Handler that receives request context
798struct ContextAwareHandler<F> {
799    handler: F,
800}
801
802impl<F, Fut> PromptHandler for ContextAwareHandler<F>
803where
804    F: Fn(RequestContext, HashMap<String, String>) -> Fut + Send + Sync + 'static,
805    Fut: Future<Output = Result<GetPromptResult>> + Send + 'static,
806{
807    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
808        // When called without context, create a dummy context
809        let ctx = RequestContext::new(RequestId::Number(0));
810        self.get_with_context(ctx, arguments)
811    }
812
813    fn get_with_context(
814        &self,
815        ctx: RequestContext,
816        arguments: HashMap<String, String>,
817    ) -> BoxFuture<'_, Result<GetPromptResult>> {
818        Box::pin((self.handler)(ctx, arguments))
819    }
820
821    fn uses_context(&self) -> bool {
822        true
823    }
824}
825
826/// Handler wrapping a boxed service (used when middleware is applied)
827///
828/// Uses a Mutex to make the BoxCloneService (which is Send but not Sync) safe
829/// for use in a Sync context. Since we clone the service before each call,
830/// the lock is only held briefly during the clone.
831struct ServiceHandler {
832    service: Mutex<BoxPromptService>,
833}
834
835impl PromptHandler for ServiceHandler {
836    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
837        Box::pin(async move {
838            let req = PromptRequest::with_arguments(arguments);
839            let mut service = self.service.lock().await.clone();
840            match service.ready().await {
841                Ok(svc) => svc.call(req).await.map_err(|e| match e {}),
842                Err(e) => match e {},
843            }
844        })
845    }
846
847    fn get_with_context(
848        &self,
849        ctx: RequestContext,
850        arguments: HashMap<String, String>,
851    ) -> BoxFuture<'_, Result<GetPromptResult>> {
852        Box::pin(async move {
853            let req = PromptRequest::new(ctx, arguments);
854            let mut service = self.service.lock().await.clone();
855            match service.ready().await {
856                Ok(svc) => svc.call(req).await.map_err(|e| match e {}),
857                Err(e) => match e {},
858            }
859        })
860    }
861}
862
863/// Handler wrapping a boxed service for context-aware prompts
864struct ServiceContextHandler {
865    service: Mutex<BoxPromptService>,
866}
867
868impl PromptHandler for ServiceContextHandler {
869    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
870        let ctx = RequestContext::new(RequestId::Number(0));
871        self.get_with_context(ctx, arguments)
872    }
873
874    fn get_with_context(
875        &self,
876        ctx: RequestContext,
877        arguments: HashMap<String, String>,
878    ) -> BoxFuture<'_, Result<GetPromptResult>> {
879        Box::pin(async move {
880            let req = PromptRequest::new(ctx, arguments);
881            let mut service = self.service.lock().await.clone();
882            match service.ready().await {
883                Ok(svc) => svc.call(req).await.map_err(|e| match e {}),
884                Err(e) => match e {},
885            }
886        })
887    }
888
889    fn uses_context(&self) -> bool {
890        true
891    }
892}
893
894// =============================================================================
895// Trait-based prompt definition
896// =============================================================================
897
898/// Trait for defining prompts with full control
899///
900/// Implement this trait when you need more control than the builder provides,
901/// or when you want to define prompts as standalone types.
902///
903/// # Example
904///
905/// ```rust
906/// use std::collections::HashMap;
907/// use tower_mcp::prompt::McpPrompt;
908/// use tower_mcp::protocol::{GetPromptResult, PromptArgument, PromptMessage, PromptRole, Content};
909/// use tower_mcp::error::Result;
910///
911/// struct CodeReviewPrompt;
912///
913/// impl McpPrompt for CodeReviewPrompt {
914///     const NAME: &'static str = "code_review";
915///     const DESCRIPTION: &'static str = "Review code for issues";
916///
917///     fn arguments(&self) -> Vec<PromptArgument> {
918///         vec![
919///             PromptArgument {
920///                 name: "code".to_string(),
921///                 description: Some("The code to review".to_string()),
922///                 required: true,
923///             },
924///             PromptArgument {
925///                 name: "language".to_string(),
926///                 description: Some("Programming language".to_string()),
927///                 required: false,
928///             },
929///         ]
930///     }
931///
932///     async fn get(&self, args: HashMap<String, String>) -> Result<GetPromptResult> {
933///         let code = args.get("code").map(|s| s.as_str()).unwrap_or("");
934///         let lang = args.get("language").map(|s| s.as_str()).unwrap_or("unknown");
935///
936///         Ok(GetPromptResult {
937///             description: Some("Code review prompt".to_string()),
938///             messages: vec![PromptMessage {
939///                 role: PromptRole::User,
940///                 content: Content::Text {
941///                     text: format!("Please review this {} code:\n\n```{}\n{}\n```", lang, lang, code),
942///                     annotations: None,
943///                     meta: None,
944///                 },
945///                 meta: None,
946///             }],
947///             meta: None,
948///         })
949///     }
950/// }
951///
952/// let prompt = CodeReviewPrompt.into_prompt();
953/// assert_eq!(prompt.name, "code_review");
954/// ```
955pub trait McpPrompt: Send + Sync + 'static {
956    /// The prompt name (must be unique within the router).
957    const NAME: &'static str;
958    /// A human-readable description of the prompt.
959    const DESCRIPTION: &'static str;
960
961    /// Define the arguments for this prompt
962    fn arguments(&self) -> Vec<PromptArgument> {
963        Vec::new()
964    }
965
966    /// Generate the prompt messages for the given arguments.
967    fn get(
968        &self,
969        arguments: HashMap<String, String>,
970    ) -> impl Future<Output = Result<GetPromptResult>> + Send;
971
972    /// Convert to a Prompt instance
973    fn into_prompt(self) -> Prompt
974    where
975        Self: Sized,
976    {
977        let arguments = self.arguments();
978        let prompt = Arc::new(self);
979        Prompt {
980            name: Self::NAME.to_string(),
981            title: None,
982            description: Some(Self::DESCRIPTION.to_string()),
983            icons: None,
984            arguments,
985            handler: Arc::new(McpPromptHandler { prompt }),
986        }
987    }
988}
989
990/// Wrapper to make McpPrompt implement PromptHandler
991struct McpPromptHandler<T: McpPrompt> {
992    prompt: Arc<T>,
993}
994
995impl<T: McpPrompt> PromptHandler for McpPromptHandler<T> {
996    fn get(&self, arguments: HashMap<String, String>) -> BoxFuture<'_, Result<GetPromptResult>> {
997        let prompt = self.prompt.clone();
998        Box::pin(async move { prompt.get(arguments).await })
999    }
1000}
1001
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `GetPromptResult` holding exactly one user-role text message.
    fn user_text_result(description: &str, text: String) -> GetPromptResult {
        GetPromptResult {
            description: Some(description.to_string()),
            messages: vec![PromptMessage {
                role: PromptRole::User,
                content: Content::Text {
                    text,
                    annotations: None,
                    meta: None,
                },
                meta: None,
            }],
            meta: None,
        }
    }

    /// Returns the text of the first message, panicking if it is not text content.
    fn first_text(result: &GetPromptResult) -> &str {
        match &result.messages[0].content {
            Content::Text { text, .. } => text.as_str(),
            _ => panic!("Expected text content"),
        }
    }

    #[tokio::test]
    async fn test_builder_prompt() {
        let prompt = PromptBuilder::new("greet")
            .description("A greeting prompt")
            .required_arg("name", "Name to greet")
            .handler(|args| async move {
                let who = args.get("name").map(String::as_str).unwrap_or("World");
                Ok(user_text_result("Greeting", format!("Hello, {who}!")))
            })
            .build();

        assert_eq!(prompt.name, "greet");
        assert_eq!(prompt.description.as_deref(), Some("A greeting prompt"));
        assert_eq!(prompt.arguments.len(), 1);
        assert!(prompt.arguments[0].required);

        let args = HashMap::from([("name".to_string(), "Alice".to_string())]);
        let result = prompt.get(args).await.unwrap();

        assert_eq!(result.messages.len(), 1);
        assert_eq!(first_text(&result), "Hello, Alice!");
    }

    #[tokio::test]
    async fn test_static_prompt() {
        let prompt = PromptBuilder::new("help")
            .description("Help prompt")
            .user_message("How can I help you today?");

        let result = prompt.get(HashMap::new()).await.unwrap();
        assert_eq!(result.messages.len(), 1);
        assert_eq!(first_text(&result), "How can I help you today?");
    }

    #[tokio::test]
    async fn test_trait_prompt() {
        struct TestPrompt;

        impl McpPrompt for TestPrompt {
            const NAME: &'static str = "test";
            const DESCRIPTION: &'static str = "A test prompt";

            fn arguments(&self) -> Vec<PromptArgument> {
                vec![PromptArgument {
                    name: "input".to_string(),
                    description: Some("Test input".to_string()),
                    required: true,
                }]
            }

            async fn get(&self, args: HashMap<String, String>) -> Result<GetPromptResult> {
                let input = args.get("input").map(String::as_str).unwrap_or("default");
                Ok(user_text_result("Test", format!("Input: {input}")))
            }
        }

        let prompt = TestPrompt.into_prompt();
        assert_eq!(prompt.name, "test");
        assert_eq!(prompt.arguments.len(), 1);

        let args = HashMap::from([("input".to_string(), "hello".to_string())]);
        let result = prompt.get(args).await.unwrap();

        assert_eq!(first_text(&result), "Input: hello");
    }

    #[test]
    fn test_prompt_definition() {
        let def = PromptBuilder::new("test")
            .description("Test description")
            .required_arg("arg1", "First arg")
            .optional_arg("arg2", "Second arg")
            .user_message("Test")
            .definition();

        assert_eq!(def.name, "test");
        assert_eq!(def.description.as_deref(), Some("Test description"));
        assert_eq!(def.arguments.len(), 2);
        assert!(def.arguments[0].required);
        assert!(!def.arguments[1].required);
    }

    #[tokio::test]
    async fn test_handler_with_context() {
        let prompt = PromptBuilder::new("context_prompt")
            .description("A prompt with context")
            .handler_with_context(|ctx: RequestContext, args| async move {
                // Touch the context to prove the handler receives it.
                let _ = ctx.is_cancelled();
                let who = args.get("name").map(String::as_str).unwrap_or("World");
                Ok(user_text_result("Context prompt", format!("Hello, {who}!")))
            })
            .build();

        assert_eq!(prompt.name, "context_prompt");
        assert!(prompt.uses_context());

        let ctx = RequestContext::new(RequestId::Number(1));
        let args = HashMap::from([("name".to_string(), "Alice".to_string())]);
        let result = prompt.get_with_context(ctx, args).await.unwrap();

        assert_eq!(first_text(&result), "Hello, Alice!");
    }

    #[tokio::test]
    async fn test_prompt_with_timeout_layer() {
        use std::time::Duration;
        use tower::timeout::TimeoutLayer;

        let prompt = PromptBuilder::new("timeout_prompt")
            .description("A prompt with timeout")
            .handler(|args: HashMap<String, String>| async move {
                let who = args.get("name").map(String::as_str).unwrap_or("World");
                Ok(user_text_result("Timeout prompt", format!("Hello, {who}!")))
            })
            .layer(TimeoutLayer::new(Duration::from_secs(5)));

        assert_eq!(prompt.name, "timeout_prompt");

        let args = HashMap::from([("name".to_string(), "Alice".to_string())]);
        let result = prompt.get(args).await.unwrap();

        assert_eq!(first_text(&result), "Hello, Alice!");
    }

    #[tokio::test]
    async fn test_prompt_timeout_expires() {
        use std::time::Duration;
        use tower::timeout::TimeoutLayer;

        let prompt = PromptBuilder::new("slow_prompt")
            .description("A slow prompt")
            .handler(|_args: HashMap<String, String>| async move {
                // Sleep far past the timeout so it fires reliably even on slow CI runners.
                tokio::time::sleep(Duration::from_secs(1)).await;
                Ok(user_text_result(
                    "Slow prompt",
                    "This should not appear".to_string(),
                ))
            })
            .layer(TimeoutLayer::new(Duration::from_millis(50)));

        // The timeout is not surfaced as `Err`: `get` still yields a result,
        // whose content describes the error.
        let result = prompt.get(HashMap::new()).await.unwrap();

        assert!(result.description.as_deref().unwrap().contains("error"));
        assert!(first_text(&result).contains("Error generating prompt"));
    }

    #[tokio::test]
    async fn test_context_handler_with_layer() {
        use std::time::Duration;
        use tower::timeout::TimeoutLayer;

        let prompt = PromptBuilder::new("context_timeout")
            .description("Context prompt with timeout")
            .handler_with_context(
                |_ctx: RequestContext, args: HashMap<String, String>| async move {
                    let who = args.get("name").map(String::as_str).unwrap_or("World");
                    Ok(user_text_result("Context timeout", format!("Hello, {who}!")))
                },
            )
            .layer(TimeoutLayer::new(Duration::from_secs(5)));

        assert_eq!(prompt.name, "context_timeout");
        assert!(prompt.uses_context());

        let ctx = RequestContext::new(RequestId::Number(1));
        let args = HashMap::from([("name".to_string(), "Bob".to_string())]);
        let result = prompt.get_with_context(ctx, args).await.unwrap();

        assert_eq!(first_text(&result), "Hello, Bob!");
    }

    #[test]
    fn test_prompt_request_construction() {
        let args: HashMap<String, String> =
            HashMap::from([("key".to_string(), "value".to_string())]);

        let without_ctx = PromptRequest::with_arguments(args.clone());
        assert_eq!(without_ctx.arguments.get("key"), Some(&"value".to_string()));

        let ctx = RequestContext::new(RequestId::Number(42));
        let with_ctx = PromptRequest::new(ctx, args);
        assert_eq!(with_ctx.arguments.get("key"), Some(&"value".to_string()));
    }

    #[test]
    fn test_prompt_catch_error_clone() {
        // Only checks that the wrapper can be constructed and cloned.
        // No Debug assertion: the handler closure does not implement Debug,
        // so neither does PromptCatchError<PromptHandlerService<_>>.
        let inner = PromptHandlerService {
            handler: |_args: HashMap<String, String>| async {
                Ok::<GetPromptResult, Error>(GetPromptResult {
                    description: None,
                    messages: vec![],
                    meta: None,
                })
            },
        };
        let wrapped = PromptCatchError::new(inner);
        let _copy = wrapped.clone();
    }
}