mcpkit_server/capability/
prompts.rs1use crate::context::Context;
7use crate::handler::PromptHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::prompt::{GetPromptResult, Prompt, PromptArgument, PromptMessage};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14
15pub type BoxedPromptFn = Box<
17 dyn for<'a> Fn(
18 Option<Value>,
19 &'a Context<'a>,
20 )
21 -> Pin<Box<dyn Future<Output = Result<GetPromptResult, McpError>> + Send + 'a>>
22 + Send
23 + Sync,
24>;
25
26pub struct RegisteredPrompt {
28 pub prompt: Prompt,
30 pub handler: BoxedPromptFn,
32}
33
34pub struct PromptService {
39 prompts: HashMap<String, RegisteredPrompt>,
40}
41
42impl Default for PromptService {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl PromptService {
49 #[must_use]
51 pub fn new() -> Self {
52 Self {
53 prompts: HashMap::new(),
54 }
55 }
56
57 pub fn register<F, Fut>(&mut self, prompt: Prompt, handler: F)
59 where
60 F: Fn(Option<Value>, &Context<'_>) -> Fut + Send + Sync + 'static,
61 Fut: Future<Output = Result<GetPromptResult, McpError>> + Send + 'static,
62 {
63 let name = prompt.name.clone();
64 let boxed: BoxedPromptFn = Box::new(move |args, ctx| Box::pin(handler(args, ctx)));
65 self.prompts.insert(
66 name,
67 RegisteredPrompt {
68 prompt,
69 handler: boxed,
70 },
71 );
72 }
73
74 #[must_use]
76 pub fn get(&self, name: &str) -> Option<&RegisteredPrompt> {
77 self.prompts.get(name)
78 }
79
80 #[must_use]
82 pub fn contains(&self, name: &str) -> bool {
83 self.prompts.contains_key(name)
84 }
85
86 #[must_use]
88 pub fn list(&self) -> Vec<&Prompt> {
89 self.prompts.values().map(|r| &r.prompt).collect()
90 }
91
92 #[must_use]
94 pub fn len(&self) -> usize {
95 self.prompts.len()
96 }
97
98 #[must_use]
100 pub fn is_empty(&self) -> bool {
101 self.prompts.is_empty()
102 }
103
104 pub async fn render(
106 &self,
107 name: &str,
108 arguments: Option<Value>,
109 ctx: &Context<'_>,
110 ) -> Result<GetPromptResult, McpError> {
111 let registered = self.prompts.get(name).ok_or_else(|| {
112 McpError::invalid_params("prompts/get", format!("Unknown prompt: {name}"))
113 })?;
114
115 (registered.handler)(arguments, ctx).await
116 }
117}
118
119impl PromptHandler for PromptService {
120 async fn list_prompts(&self, _ctx: &Context<'_>) -> Result<Vec<Prompt>, McpError> {
121 Ok(self.list().into_iter().cloned().collect())
122 }
123
124 async fn get_prompt(
125 &self,
126 name: &str,
127 arguments: Option<serde_json::Map<String, Value>>,
128 ctx: &Context<'_>,
129 ) -> Result<GetPromptResult, McpError> {
130 let args = arguments.map(Value::Object);
131 self.render(name, args, ctx).await
132 }
133}
134
135pub struct PromptBuilder {
137 name: String,
138 description: Option<String>,
139 arguments: Vec<PromptArgument>,
140}
141
142impl PromptBuilder {
143 pub fn new(name: impl Into<String>) -> Self {
145 Self {
146 name: name.into(),
147 description: None,
148 arguments: Vec::new(),
149 }
150 }
151
152 pub fn description(mut self, desc: impl Into<String>) -> Self {
154 self.description = Some(desc.into());
155 self
156 }
157
158 pub fn required_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
160 self.arguments.push(PromptArgument {
161 name: name.into(),
162 description: Some(description.into()),
163 required: Some(true),
164 });
165 self
166 }
167
168 pub fn optional_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
170 self.arguments.push(PromptArgument {
171 name: name.into(),
172 description: Some(description.into()),
173 required: Some(false),
174 });
175 self
176 }
177
178 #[must_use]
180 pub fn argument(mut self, arg: PromptArgument) -> Self {
181 self.arguments.push(arg);
182 self
183 }
184
185 #[must_use]
187 pub fn build(self) -> Prompt {
188 Prompt {
189 name: self.name,
190 description: self.description,
191 arguments: if self.arguments.is_empty() {
192 None
193 } else {
194 Some(self.arguments)
195 },
196 }
197 }
198}
199
200pub struct PromptResultBuilder {
202 description: Option<String>,
203 messages: Vec<PromptMessage>,
204}
205
206impl Default for PromptResultBuilder {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212impl PromptResultBuilder {
213 #[must_use]
215 pub const fn new() -> Self {
216 Self {
217 description: None,
218 messages: Vec::new(),
219 }
220 }
221
222 pub fn description(mut self, desc: impl Into<String>) -> Self {
224 self.description = Some(desc.into());
225 self
226 }
227
228 pub fn user_text(mut self, text: impl Into<String>) -> Self {
230 self.messages.push(PromptMessage::user(text.into()));
231 self
232 }
233
234 pub fn assistant_text(mut self, text: impl Into<String>) -> Self {
236 self.messages.push(PromptMessage::assistant(text.into()));
237 self
238 }
239
240 #[must_use]
242 pub fn message(mut self, msg: PromptMessage) -> Self {
243 self.messages.push(msg);
244 self
245 }
246
247 #[must_use]
249 pub fn build(self) -> GetPromptResult {
250 GetPromptResult {
251 description: self.description,
252 messages: self.messages,
253 }
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;

    /// Owned pieces needed to build a `Context` that borrows from them.
    fn make_context() -> (RequestId, ClientCapabilities, ServerCapabilities, NoOpPeer) {
        (
            RequestId::Number(1),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
            NoOpPeer,
        )
    }

    /// Registers a `greeting` prompt whose handler falls back to "World".
    fn greeting_service() -> PromptService {
        let mut service = PromptService::new();

        let prompt = PromptBuilder::new("greeting")
            .description("Generate a greeting")
            .required_arg("name", "Name to greet")
            .build();

        service.register(prompt, |args, _ctx| async move {
            let name = args
                .and_then(|v| v.get("name").and_then(|n| n.as_str()).map(String::from))
                .unwrap_or_else(|| "World".to_string());

            Ok(PromptResultBuilder::new()
                .user_text(format!("Generate a greeting for {name}"))
                .build())
        });

        service
    }

    #[test]
    fn test_prompt_builder() {
        let prompt = PromptBuilder::new("code-review")
            .description("Review code for issues")
            .required_arg("code", "The code to review")
            .optional_arg("language", "Programming language")
            .build();

        assert_eq!(prompt.name, "code-review");
        assert_eq!(
            prompt.description.as_deref(),
            Some("Review code for issues")
        );
        assert_eq!(prompt.arguments.as_ref().map(std::vec::Vec::len), Some(2));

        let args = prompt.arguments.as_ref().expect("two arguments were added");
        assert_eq!(args[0].name, "code");
        assert_eq!(args[0].required, Some(true));
        assert_eq!(args[1].name, "language");
        assert_eq!(args[1].required, Some(false));
    }

    #[test]
    fn test_prompt_builder_without_arguments() {
        let prompt = PromptBuilder::new("bare").build();

        assert_eq!(prompt.name, "bare");
        assert!(prompt.description.is_none());
        assert!(prompt.arguments.is_none());
    }

    #[test]
    fn test_prompt_result_builder() {
        let result = PromptResultBuilder::new()
            .description("Generated review")
            .user_text("Please review this code")
            .assistant_text("I'll analyze the code...")
            .build();

        assert_eq!(result.description.as_deref(), Some("Generated review"));
        assert_eq!(result.messages.len(), 2);
    }

    #[tokio::test]
    async fn test_prompt_service() {
        let service = greeting_service();

        assert!(service.contains("greeting"));
        assert!(service.get("greeting").is_some());
        assert!(!service.is_empty());
        assert_eq!(service.len(), 1);
        assert_eq!(service.list().len(), 1);

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let result = service
            .render("greeting", Some(serde_json::json!({"name": "Alice"})), &ctx)
            .await
            .unwrap();

        assert!(!result.messages.is_empty());

        // With no arguments, the handler falls back to its default and still succeeds.
        let fallback = service.render("greeting", None, &ctx).await.unwrap();
        assert_eq!(fallback.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_prompt_service_unknown_prompt() {
        let service = PromptService::new();
        assert!(service.is_empty());
        assert!(!service.contains("missing"));
        assert!(service.get("missing").is_none());

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let result = service.render("missing", None, &ctx).await;
        assert!(result.is_err());
    }
}