mcpkit_server/capability/
prompts.rs1use crate::context::Context;
7use crate::handler::PromptHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::prompt::{GetPromptResult, Prompt, PromptArgument, PromptMessage};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14
15pub type BoxedPromptFn = Box<
17 dyn for<'a> Fn(
18 Option<Value>,
19 &'a Context<'a>,
20 )
21 -> Pin<Box<dyn Future<Output = Result<GetPromptResult, McpError>> + Send + 'a>>
22 + Send
23 + Sync,
24>;
25
26pub struct RegisteredPrompt {
28 pub prompt: Prompt,
30 pub handler: BoxedPromptFn,
32}
33
34pub struct PromptService {
39 prompts: HashMap<String, RegisteredPrompt>,
40}
41
42impl Default for PromptService {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl PromptService {
49 #[must_use]
51 pub fn new() -> Self {
52 Self {
53 prompts: HashMap::new(),
54 }
55 }
56
57 pub fn register<F, Fut>(&mut self, prompt: Prompt, handler: F)
59 where
60 F: Fn(Option<Value>, &Context<'_>) -> Fut + Send + Sync + 'static,
61 Fut: Future<Output = Result<GetPromptResult, McpError>> + Send + 'static,
62 {
63 let name = prompt.name.clone();
64 let boxed: BoxedPromptFn = Box::new(move |args, ctx| Box::pin(handler(args, ctx)));
65 self.prompts.insert(
66 name,
67 RegisteredPrompt {
68 prompt,
69 handler: boxed,
70 },
71 );
72 }
73
74 #[must_use]
76 pub fn get(&self, name: &str) -> Option<&RegisteredPrompt> {
77 self.prompts.get(name)
78 }
79
80 #[must_use]
82 pub fn contains(&self, name: &str) -> bool {
83 self.prompts.contains_key(name)
84 }
85
86 #[must_use]
88 pub fn list(&self) -> Vec<&Prompt> {
89 self.prompts.values().map(|r| &r.prompt).collect()
90 }
91
92 #[must_use]
94 pub fn len(&self) -> usize {
95 self.prompts.len()
96 }
97
98 #[must_use]
100 pub fn is_empty(&self) -> bool {
101 self.prompts.is_empty()
102 }
103
104 pub async fn render(
106 &self,
107 name: &str,
108 arguments: Option<Value>,
109 ctx: &Context<'_>,
110 ) -> Result<GetPromptResult, McpError> {
111 let registered = self.prompts.get(name).ok_or_else(|| {
112 McpError::invalid_params("prompts/get", format!("Unknown prompt: {name}"))
113 })?;
114
115 (registered.handler)(arguments, ctx).await
116 }
117}
118
119impl PromptHandler for PromptService {
120 async fn list_prompts(&self, _ctx: &Context<'_>) -> Result<Vec<Prompt>, McpError> {
121 Ok(self.list().into_iter().cloned().collect())
122 }
123
124 async fn get_prompt(
125 &self,
126 name: &str,
127 arguments: Option<serde_json::Map<String, Value>>,
128 ctx: &Context<'_>,
129 ) -> Result<GetPromptResult, McpError> {
130 let args = arguments.map(Value::Object);
131 self.render(name, args, ctx).await
132 }
133}
134
135pub struct PromptBuilder {
137 name: String,
138 description: Option<String>,
139 arguments: Vec<PromptArgument>,
140}
141
142impl PromptBuilder {
143 pub fn new(name: impl Into<String>) -> Self {
145 Self {
146 name: name.into(),
147 description: None,
148 arguments: Vec::new(),
149 }
150 }
151
152 pub fn description(mut self, desc: impl Into<String>) -> Self {
154 self.description = Some(desc.into());
155 self
156 }
157
158 pub fn required_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
160 self.arguments.push(PromptArgument {
161 name: name.into(),
162 description: Some(description.into()),
163 required: Some(true),
164 });
165 self
166 }
167
168 pub fn optional_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
170 self.arguments.push(PromptArgument {
171 name: name.into(),
172 description: Some(description.into()),
173 required: Some(false),
174 });
175 self
176 }
177
178 #[must_use]
180 pub fn argument(mut self, arg: PromptArgument) -> Self {
181 self.arguments.push(arg);
182 self
183 }
184
185 #[must_use]
187 pub fn build(self) -> Prompt {
188 Prompt {
189 name: self.name,
190 description: self.description,
191 arguments: if self.arguments.is_empty() {
192 None
193 } else {
194 Some(self.arguments)
195 },
196 }
197 }
198}
199
200pub struct PromptResultBuilder {
202 description: Option<String>,
203 messages: Vec<PromptMessage>,
204}
205
206impl Default for PromptResultBuilder {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212impl PromptResultBuilder {
213 #[must_use]
215 pub const fn new() -> Self {
216 Self {
217 description: None,
218 messages: Vec::new(),
219 }
220 }
221
222 pub fn description(mut self, desc: impl Into<String>) -> Self {
224 self.description = Some(desc.into());
225 self
226 }
227
228 pub fn user_text(mut self, text: impl Into<String>) -> Self {
230 self.messages.push(PromptMessage::user(text.into()));
231 self
232 }
233
234 pub fn assistant_text(mut self, text: impl Into<String>) -> Self {
236 self.messages.push(PromptMessage::assistant(text.into()));
237 self
238 }
239
240 #[must_use]
242 pub fn message(mut self, msg: PromptMessage) -> Self {
243 self.messages.push(msg);
244 self
245 }
246
247 #[must_use]
249 pub fn build(self) -> GetPromptResult {
250 GetPromptResult {
251 description: self.description,
252 messages: self.messages,
253 }
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;
    use mcpkit_core::protocol_version::ProtocolVersion;

    /// Produces the owned values that a test `Context` borrows from.
    ///
    /// They are returned as a tuple so the caller keeps them alive for as long
    /// as the context that references them.
    fn make_context() -> (
        RequestId,
        ClientCapabilities,
        ServerCapabilities,
        ProtocolVersion,
        NoOpPeer,
    ) {
        let request_id = RequestId::Number(1);
        let client = ClientCapabilities::default();
        let server = ServerCapabilities::default();

        (request_id, client, server, ProtocolVersion::LATEST, NoOpPeer)
    }

    /// The builder carries name, description and both arguments into the prompt.
    #[test]
    fn test_prompt_builder() {
        let builder = PromptBuilder::new("code-review")
            .description("Review code for issues")
            .required_arg("code", "The code to review")
            .optional_arg("language", "Programming language");
        let prompt = builder.build();

        assert_eq!(prompt.name, "code-review");
        assert_eq!(
            prompt.description.as_deref(),
            Some("Review code for issues")
        );

        let argument_count = prompt.arguments.as_ref().map(Vec::len);
        assert_eq!(argument_count, Some(2));
    }

    /// The result builder keeps the description and both messages.
    #[test]
    fn test_prompt_result_builder() {
        let builder = PromptResultBuilder::new()
            .description("Generated review")
            .user_text("Please review this code")
            .assistant_text("I'll analyze the code...");
        let result = builder.build();

        assert_eq!(result.description.as_deref(), Some("Generated review"));
        assert_eq!(result.messages.len(), 2);
    }

    /// A registered prompt can be found and rendered through the service.
    #[tokio::test]
    async fn test_prompt_service() {
        let greeting = PromptBuilder::new("greeting")
            .description("Generate a greeting")
            .required_arg("name", "Name to greet")
            .build();

        let mut service = PromptService::new();
        service.register(greeting, |args, _ctx| async move {
            // Fall back to "World" when no string `name` argument is supplied.
            let name = args
                .as_ref()
                .and_then(|v| v.get("name"))
                .and_then(Value::as_str)
                .map_or_else(|| "World".to_string(), String::from);

            let rendered = PromptResultBuilder::new()
                .user_text(format!("Generate a greeting for {name}"))
                .build();
            Ok(rendered)
        });

        assert!(service.contains("greeting"));
        assert_eq!(service.len(), 1);

        let (req_id, client_caps, server_caps, protocol_version, peer) = make_context();
        let ctx = Context::new(
            &req_id,
            None,
            &client_caps,
            &server_caps,
            protocol_version,
            &peer,
        );

        let arguments = Some(serde_json::json!({"name": "Alice"}));
        let result = service.render("greeting", arguments, &ctx).await.unwrap();

        assert!(!result.messages.is_empty());
    }
}