mcpkit_server/capability/
prompts.rs

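//! Prompt capability implementation.
//!
//! This module provides utilities for managing and rendering prompts in an
//! MCP server: [`PromptService`], a name-keyed prompt registry that implements
//! [`PromptHandler`], plus [`PromptBuilder`] and [`PromptResultBuilder`] for
//! constructing prompt metadata and rendered prompt results.
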
6use crate::context::Context;
7use crate::handler::PromptHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::prompt::{GetPromptResult, Prompt, PromptArgument, PromptMessage};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14
15/// A boxed async function for prompt rendering.
16pub type BoxedPromptFn = Box<
17    dyn for<'a> Fn(
18            Option<Value>,
19            &'a Context<'a>,
20        )
21            -> Pin<Box<dyn Future<Output = Result<GetPromptResult, McpError>> + Send + 'a>>
22        + Send
23        + Sync,
24>;
25
26/// A registered prompt with metadata and handler.
27pub struct RegisteredPrompt {
28    /// Prompt metadata.
29    pub prompt: Prompt,
30    /// Handler function for rendering.
31    pub handler: BoxedPromptFn,
32}
33
34/// Service for managing prompts.
35///
36/// This provides a registry for prompts and handles rendering
37/// them with arguments.
38pub struct PromptService {
39    prompts: HashMap<String, RegisteredPrompt>,
40}
41
42impl Default for PromptService {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl PromptService {
49    /// Create a new empty prompt service.
50    #[must_use]
51    pub fn new() -> Self {
52        Self {
53            prompts: HashMap::new(),
54        }
55    }
56
57    /// Register a prompt with a handler function.
58    pub fn register<F, Fut>(&mut self, prompt: Prompt, handler: F)
59    where
60        F: Fn(Option<Value>, &Context<'_>) -> Fut + Send + Sync + 'static,
61        Fut: Future<Output = Result<GetPromptResult, McpError>> + Send + 'static,
62    {
63        let name = prompt.name.clone();
64        let boxed: BoxedPromptFn = Box::new(move |args, ctx| Box::pin(handler(args, ctx)));
65        self.prompts.insert(
66            name,
67            RegisteredPrompt {
68                prompt,
69                handler: boxed,
70            },
71        );
72    }
73
74    /// Get a prompt by name.
75    #[must_use]
76    pub fn get(&self, name: &str) -> Option<&RegisteredPrompt> {
77        self.prompts.get(name)
78    }
79
80    /// Check if a prompt exists.
81    #[must_use]
82    pub fn contains(&self, name: &str) -> bool {
83        self.prompts.contains_key(name)
84    }
85
86    /// List all registered prompts.
87    #[must_use]
88    pub fn list(&self) -> Vec<&Prompt> {
89        self.prompts.values().map(|r| &r.prompt).collect()
90    }
91
92    /// Get the number of registered prompts.
93    #[must_use]
94    pub fn len(&self) -> usize {
95        self.prompts.len()
96    }
97
98    /// Check if the service has no prompts.
99    #[must_use]
100    pub fn is_empty(&self) -> bool {
101        self.prompts.is_empty()
102    }
103
104    /// Render a prompt by name with arguments.
105    pub async fn render(
106        &self,
107        name: &str,
108        arguments: Option<Value>,
109        ctx: &Context<'_>,
110    ) -> Result<GetPromptResult, McpError> {
111        let registered = self.prompts.get(name).ok_or_else(|| {
112            McpError::invalid_params("prompts/get", format!("Unknown prompt: {name}"))
113        })?;
114
115        (registered.handler)(arguments, ctx).await
116    }
117}
118
119impl PromptHandler for PromptService {
120    async fn list_prompts(&self, _ctx: &Context<'_>) -> Result<Vec<Prompt>, McpError> {
121        Ok(self.list().into_iter().cloned().collect())
122    }
123
124    async fn get_prompt(
125        &self,
126        name: &str,
127        arguments: Option<serde_json::Map<String, Value>>,
128        ctx: &Context<'_>,
129    ) -> Result<GetPromptResult, McpError> {
130        let args = arguments.map(Value::Object);
131        self.render(name, args, ctx).await
132    }
133}
134
135/// Builder for creating prompts with a fluent API.
136pub struct PromptBuilder {
137    name: String,
138    description: Option<String>,
139    arguments: Vec<PromptArgument>,
140}
141
142impl PromptBuilder {
143    /// Create a new prompt builder.
144    pub fn new(name: impl Into<String>) -> Self {
145        Self {
146            name: name.into(),
147            description: None,
148            arguments: Vec::new(),
149        }
150    }
151
152    /// Set the prompt description.
153    pub fn description(mut self, desc: impl Into<String>) -> Self {
154        self.description = Some(desc.into());
155        self
156    }
157
158    /// Add a required argument.
159    pub fn required_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
160        self.arguments.push(PromptArgument {
161            name: name.into(),
162            description: Some(description.into()),
163            required: Some(true),
164        });
165        self
166    }
167
168    /// Add an optional argument.
169    pub fn optional_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
170        self.arguments.push(PromptArgument {
171            name: name.into(),
172            description: Some(description.into()),
173            required: Some(false),
174        });
175        self
176    }
177
178    /// Add a custom argument.
179    #[must_use]
180    pub fn argument(mut self, arg: PromptArgument) -> Self {
181        self.arguments.push(arg);
182        self
183    }
184
185    /// Build the prompt.
186    #[must_use]
187    pub fn build(self) -> Prompt {
188        Prompt {
189            name: self.name,
190            description: self.description,
191            arguments: if self.arguments.is_empty() {
192                None
193            } else {
194                Some(self.arguments)
195            },
196        }
197    }
198}
199
200/// Builder for creating prompt results.
201pub struct PromptResultBuilder {
202    description: Option<String>,
203    messages: Vec<PromptMessage>,
204}
205
206impl Default for PromptResultBuilder {
207    fn default() -> Self {
208        Self::new()
209    }
210}
211
212impl PromptResultBuilder {
213    /// Create a new result builder.
214    #[must_use]
215    pub const fn new() -> Self {
216        Self {
217            description: None,
218            messages: Vec::new(),
219        }
220    }
221
222    /// Set the result description.
223    pub fn description(mut self, desc: impl Into<String>) -> Self {
224        self.description = Some(desc.into());
225        self
226    }
227
228    /// Add a user message with text content.
229    pub fn user_text(mut self, text: impl Into<String>) -> Self {
230        self.messages.push(PromptMessage::user(text.into()));
231        self
232    }
233
234    /// Add an assistant message with text content.
235    pub fn assistant_text(mut self, text: impl Into<String>) -> Self {
236        self.messages.push(PromptMessage::assistant(text.into()));
237        self
238    }
239
240    /// Add a custom message.
241    #[must_use]
242    pub fn message(mut self, msg: PromptMessage) -> Self {
243        self.messages.push(msg);
244        self
245    }
246
247    /// Build the result.
248    #[must_use]
249    pub fn build(self) -> GetPromptResult {
250        GetPromptResult {
251            description: self.description,
252            messages: self.messages,
253        }
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;
    use mcpkit_core::protocol_version::ProtocolVersion;

    /// Build the owned values that a `Context` borrows from.
    ///
    /// `Context::new` takes references, so tests destructure this tuple into
    /// locals and build the context from those.
    fn make_context() -> (
        RequestId,
        ClientCapabilities,
        ServerCapabilities,
        ProtocolVersion,
        NoOpPeer,
    ) {
        (
            RequestId::Number(1),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
            ProtocolVersion::LATEST,
            NoOpPeer,
        )
    }

    #[test]
    fn test_prompt_builder() {
        let prompt = PromptBuilder::new("code-review")
            .description("Review code for issues")
            .required_arg("code", "The code to review")
            .optional_arg("language", "Programming language")
            .build();

        assert_eq!(prompt.name, "code-review");
        assert_eq!(
            prompt.description.as_deref(),
            Some("Review code for issues")
        );
        assert_eq!(prompt.arguments.as_ref().map(std::vec::Vec::len), Some(2));
    }

    #[test]
    fn test_prompt_builder_without_arguments() {
        let prompt = PromptBuilder::new("bare").build();

        assert_eq!(prompt.name, "bare");
        assert!(prompt.description.is_none());
        // An argument-less builder yields `None`, not `Some(vec![])`.
        assert!(prompt.arguments.is_none());
    }

    #[test]
    fn test_prompt_result_builder() {
        let result = PromptResultBuilder::new()
            .description("Generated review")
            .user_text("Please review this code")
            .assistant_text("I'll analyze the code...")
            .build();

        assert_eq!(result.description.as_deref(), Some("Generated review"));
        assert_eq!(result.messages.len(), 2);
    }

    #[tokio::test]
    async fn test_prompt_service() {
        let mut service = PromptService::new();

        let prompt = PromptBuilder::new("greeting")
            .description("Generate a greeting")
            .required_arg("name", "Name to greet")
            .build();

        service.register(prompt, |args, _ctx| async move {
            let name = args
                .and_then(|v| v.get("name").and_then(|n| n.as_str()).map(String::from))
                .unwrap_or_else(|| "World".to_string());

            Ok(PromptResultBuilder::new()
                .user_text(format!("Generate a greeting for {name}"))
                .build())
        });

        assert!(service.contains("greeting"));
        assert_eq!(service.len(), 1);

        let (req_id, client_caps, server_caps, protocol_version, peer) = make_context();
        let ctx = Context::new(
            &req_id,
            None,
            &client_caps,
            &server_caps,
            protocol_version,
            &peer,
        );

        let result = service
            .render("greeting", Some(serde_json::json!({"name": "Alice"})), &ctx)
            .await
            .unwrap();

        assert!(!result.messages.is_empty());
    }

    #[tokio::test]
    async fn test_render_unknown_prompt_is_error() {
        let service = PromptService::new();

        assert!(service.is_empty());
        assert!(!service.contains("missing"));
        assert!(service.get("missing").is_none());

        let (req_id, client_caps, server_caps, protocol_version, peer) = make_context();
        let ctx = Context::new(
            &req_id,
            None,
            &client_caps,
            &server_caps,
            protocol_version,
            &peer,
        );

        // Rendering an unregistered prompt must fail instead of panicking.
        let result = service.render("missing", None, &ctx).await;
        assert!(result.is_err());
    }
}