mcpkit_server/capability/
prompts.rs

1//! Prompt capability implementation.
2//!
3//! This module provides utilities for managing and rendering prompts
4//! in an MCP server.
5
6use crate::context::Context;
7use crate::handler::PromptHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::prompt::{GetPromptResult, Prompt, PromptArgument, PromptMessage};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14
15/// A boxed async function for prompt rendering.
16pub type BoxedPromptFn = Box<
17    dyn for<'a> Fn(
18            Option<Value>,
19            &'a Context<'a>,
20        ) -> Pin<Box<dyn Future<Output = Result<GetPromptResult, McpError>> + Send + 'a>>
21        + Send
22        + Sync,
23>;
24
25/// A registered prompt with metadata and handler.
26pub struct RegisteredPrompt {
27    /// Prompt metadata.
28    pub prompt: Prompt,
29    /// Handler function for rendering.
30    pub handler: BoxedPromptFn,
31}
32
33/// Service for managing prompts.
34///
35/// This provides a registry for prompts and handles rendering
36/// them with arguments.
37pub struct PromptService {
38    prompts: HashMap<String, RegisteredPrompt>,
39}
40
41impl Default for PromptService {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl PromptService {
48    /// Create a new empty prompt service.
49    pub fn new() -> Self {
50        Self {
51            prompts: HashMap::new(),
52        }
53    }
54
55    /// Register a prompt with a handler function.
56    pub fn register<F, Fut>(&mut self, prompt: Prompt, handler: F)
57    where
58        F: Fn(Option<Value>, &Context<'_>) -> Fut + Send + Sync + 'static,
59        Fut: Future<Output = Result<GetPromptResult, McpError>> + Send + 'static,
60    {
61        let name = prompt.name.clone();
62        let boxed: BoxedPromptFn = Box::new(move |args, ctx| Box::pin(handler(args, ctx)));
63        self.prompts.insert(
64            name,
65            RegisteredPrompt {
66                prompt,
67                handler: boxed,
68            },
69        );
70    }
71
72    /// Get a prompt by name.
73    pub fn get(&self, name: &str) -> Option<&RegisteredPrompt> {
74        self.prompts.get(name)
75    }
76
77    /// Check if a prompt exists.
78    pub fn contains(&self, name: &str) -> bool {
79        self.prompts.contains_key(name)
80    }
81
82    /// List all registered prompts.
83    pub fn list(&self) -> Vec<&Prompt> {
84        self.prompts.values().map(|r| &r.prompt).collect()
85    }
86
87    /// Get the number of registered prompts.
88    pub fn len(&self) -> usize {
89        self.prompts.len()
90    }
91
92    /// Check if the service has no prompts.
93    pub fn is_empty(&self) -> bool {
94        self.prompts.is_empty()
95    }
96
97    /// Render a prompt by name with arguments.
98    pub async fn render(
99        &self,
100        name: &str,
101        arguments: Option<Value>,
102        ctx: &Context<'_>,
103    ) -> Result<GetPromptResult, McpError> {
104        let registered = self.prompts.get(name).ok_or_else(|| {
105            McpError::invalid_params("prompts/get", format!("Unknown prompt: {name}"))
106        })?;
107
108        (registered.handler)(arguments, ctx).await
109    }
110}
111
112impl PromptHandler for PromptService {
113    async fn list_prompts(&self, _ctx: &Context<'_>) -> Result<Vec<Prompt>, McpError> {
114        Ok(self.list().into_iter().cloned().collect())
115    }
116
117    async fn get_prompt(
118        &self,
119        name: &str,
120        arguments: Option<serde_json::Map<String, Value>>,
121        ctx: &Context<'_>,
122    ) -> Result<GetPromptResult, McpError> {
123        let args = arguments.map(Value::Object);
124        self.render(name, args, ctx).await
125    }
126}
127
128/// Builder for creating prompts with a fluent API.
129pub struct PromptBuilder {
130    name: String,
131    description: Option<String>,
132    arguments: Vec<PromptArgument>,
133}
134
135impl PromptBuilder {
136    /// Create a new prompt builder.
137    pub fn new(name: impl Into<String>) -> Self {
138        Self {
139            name: name.into(),
140            description: None,
141            arguments: Vec::new(),
142        }
143    }
144
145    /// Set the prompt description.
146    pub fn description(mut self, desc: impl Into<String>) -> Self {
147        self.description = Some(desc.into());
148        self
149    }
150
151    /// Add a required argument.
152    pub fn required_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
153        self.arguments.push(PromptArgument {
154            name: name.into(),
155            description: Some(description.into()),
156            required: Some(true),
157        });
158        self
159    }
160
161    /// Add an optional argument.
162    pub fn optional_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
163        self.arguments.push(PromptArgument {
164            name: name.into(),
165            description: Some(description.into()),
166            required: Some(false),
167        });
168        self
169    }
170
171    /// Add a custom argument.
172    pub fn argument(mut self, arg: PromptArgument) -> Self {
173        self.arguments.push(arg);
174        self
175    }
176
177    /// Build the prompt.
178    pub fn build(self) -> Prompt {
179        Prompt {
180            name: self.name,
181            description: self.description,
182            arguments: if self.arguments.is_empty() {
183                None
184            } else {
185                Some(self.arguments)
186            },
187        }
188    }
189}
190
191/// Builder for creating prompt results.
192pub struct PromptResultBuilder {
193    description: Option<String>,
194    messages: Vec<PromptMessage>,
195}
196
197impl Default for PromptResultBuilder {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
203impl PromptResultBuilder {
204    /// Create a new result builder.
205    pub fn new() -> Self {
206        Self {
207            description: None,
208            messages: Vec::new(),
209        }
210    }
211
212    /// Set the result description.
213    pub fn description(mut self, desc: impl Into<String>) -> Self {
214        self.description = Some(desc.into());
215        self
216    }
217
218    /// Add a user message with text content.
219    pub fn user_text(mut self, text: impl Into<String>) -> Self {
220        self.messages.push(PromptMessage::user(text.into()));
221        self
222    }
223
224    /// Add an assistant message with text content.
225    pub fn assistant_text(mut self, text: impl Into<String>) -> Self {
226        self.messages.push(PromptMessage::assistant(text.into()));
227        self
228    }
229
230    /// Add a custom message.
231    pub fn message(mut self, msg: PromptMessage) -> Self {
232        self.messages.push(msg);
233        self
234    }
235
236    /// Build the result.
237    pub fn build(self) -> GetPromptResult {
238        GetPromptResult {
239            description: self.description,
240            messages: self.messages,
241        }
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;

    /// Owned pieces that a `Context` borrows for a single test request.
    fn make_context() -> (RequestId, ClientCapabilities, ServerCapabilities, NoOpPeer) {
        (
            RequestId::Number(1),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
            NoOpPeer,
        )
    }

    #[test]
    fn test_prompt_builder() {
        let prompt = PromptBuilder::new("code-review")
            .description("Review code for issues")
            .required_arg("code", "The code to review")
            .optional_arg("language", "Programming language")
            .build();

        assert_eq!(prompt.name, "code-review");
        assert_eq!(prompt.description.as_deref(), Some("Review code for issues"));

        let args = prompt
            .arguments
            .as_ref()
            .expect("two arguments were declared");
        assert_eq!(args.len(), 2);
        assert_eq!(args[0].name, "code");
        assert_eq!(args[0].required, Some(true));
        assert_eq!(args[1].name, "language");
        assert_eq!(args[1].required, Some(false));
    }

    #[test]
    fn test_prompt_builder_without_arguments() {
        let prompt = PromptBuilder::new("bare").build();

        assert_eq!(prompt.name, "bare");
        assert!(prompt.description.is_none());
        // No declared arguments must serialize as "absent", not an empty list.
        assert!(prompt.arguments.is_none());
    }

    #[test]
    fn test_prompt_result_builder() {
        let result = PromptResultBuilder::new()
            .description("Generated review")
            .user_text("Please review this code")
            .assistant_text("I'll analyze the code...")
            .build();

        assert_eq!(result.description.as_deref(), Some("Generated review"));
        assert_eq!(result.messages.len(), 2);
    }

    #[test]
    fn test_register_same_name_replaces() {
        let mut service = PromptService::new();

        for desc in ["first", "second"] {
            let prompt = PromptBuilder::new("dup").description(desc).build();
            service.register(prompt, |_args, _ctx| async move {
                Ok(PromptResultBuilder::new().build())
            });
        }

        assert_eq!(service.len(), 1);
        assert_eq!(service.list().len(), 1);
        assert_eq!(
            service
                .get("dup")
                .and_then(|r| r.prompt.description.as_deref()),
            Some("second")
        );
    }

    #[tokio::test]
    async fn test_prompt_service() {
        let mut service = PromptService::new();

        let prompt = PromptBuilder::new("greeting")
            .description("Generate a greeting")
            .required_arg("name", "Name to greet")
            .build();

        service.register(prompt, |args, _ctx| async move {
            let name = args
                .and_then(|v| v.get("name").and_then(|n| n.as_str()).map(String::from))
                .unwrap_or_else(|| "World".to_string());

            Ok(PromptResultBuilder::new()
                .user_text(format!("Generate a greeting for {name}"))
                .build())
        });

        assert!(service.contains("greeting"));
        assert_eq!(service.len(), 1);

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let result = service
            .render("greeting", Some(serde_json::json!({"name": "Alice"})), &ctx)
            .await
            .unwrap();

        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_render_unknown_prompt_is_error() {
        let service = PromptService::new();
        assert!(service.is_empty());
        assert!(!service.contains("missing"));
        assert!(service.get("missing").is_none());

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let result = service.render("missing", None, &ctx).await;
        assert!(result.is_err());
    }
}