mcpkit_server/capability/
prompts.rs1use crate::context::Context;
7use crate::handler::PromptHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::prompt::{GetPromptResult, Prompt, PromptArgument, PromptMessage};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::pin::Pin;
14
15pub type BoxedPromptFn = Box<
17 dyn for<'a> Fn(
18 Option<Value>,
19 &'a Context<'a>,
20 ) -> Pin<Box<dyn Future<Output = Result<GetPromptResult, McpError>> + Send + 'a>>
21 + Send
22 + Sync,
23>;
24
25pub struct RegisteredPrompt {
27 pub prompt: Prompt,
29 pub handler: BoxedPromptFn,
31}
32
33pub struct PromptService {
38 prompts: HashMap<String, RegisteredPrompt>,
39}
40
41impl Default for PromptService {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl PromptService {
48 pub fn new() -> Self {
50 Self {
51 prompts: HashMap::new(),
52 }
53 }
54
55 pub fn register<F, Fut>(&mut self, prompt: Prompt, handler: F)
57 where
58 F: Fn(Option<Value>, &Context<'_>) -> Fut + Send + Sync + 'static,
59 Fut: Future<Output = Result<GetPromptResult, McpError>> + Send + 'static,
60 {
61 let name = prompt.name.clone();
62 let boxed: BoxedPromptFn = Box::new(move |args, ctx| Box::pin(handler(args, ctx)));
63 self.prompts.insert(
64 name,
65 RegisteredPrompt {
66 prompt,
67 handler: boxed,
68 },
69 );
70 }
71
72 pub fn get(&self, name: &str) -> Option<&RegisteredPrompt> {
74 self.prompts.get(name)
75 }
76
77 pub fn contains(&self, name: &str) -> bool {
79 self.prompts.contains_key(name)
80 }
81
82 pub fn list(&self) -> Vec<&Prompt> {
84 self.prompts.values().map(|r| &r.prompt).collect()
85 }
86
87 pub fn len(&self) -> usize {
89 self.prompts.len()
90 }
91
92 pub fn is_empty(&self) -> bool {
94 self.prompts.is_empty()
95 }
96
97 pub async fn render(
99 &self,
100 name: &str,
101 arguments: Option<Value>,
102 ctx: &Context<'_>,
103 ) -> Result<GetPromptResult, McpError> {
104 let registered = self.prompts.get(name).ok_or_else(|| {
105 McpError::invalid_params("prompts/get", format!("Unknown prompt: {name}"))
106 })?;
107
108 (registered.handler)(arguments, ctx).await
109 }
110}
111
112impl PromptHandler for PromptService {
113 async fn list_prompts(&self, _ctx: &Context<'_>) -> Result<Vec<Prompt>, McpError> {
114 Ok(self.list().into_iter().cloned().collect())
115 }
116
117 async fn get_prompt(
118 &self,
119 name: &str,
120 arguments: Option<serde_json::Map<String, Value>>,
121 ctx: &Context<'_>,
122 ) -> Result<GetPromptResult, McpError> {
123 let args = arguments.map(Value::Object);
124 self.render(name, args, ctx).await
125 }
126}
127
128pub struct PromptBuilder {
130 name: String,
131 description: Option<String>,
132 arguments: Vec<PromptArgument>,
133}
134
135impl PromptBuilder {
136 pub fn new(name: impl Into<String>) -> Self {
138 Self {
139 name: name.into(),
140 description: None,
141 arguments: Vec::new(),
142 }
143 }
144
145 pub fn description(mut self, desc: impl Into<String>) -> Self {
147 self.description = Some(desc.into());
148 self
149 }
150
151 pub fn required_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
153 self.arguments.push(PromptArgument {
154 name: name.into(),
155 description: Some(description.into()),
156 required: Some(true),
157 });
158 self
159 }
160
161 pub fn optional_arg(mut self, name: impl Into<String>, description: impl Into<String>) -> Self {
163 self.arguments.push(PromptArgument {
164 name: name.into(),
165 description: Some(description.into()),
166 required: Some(false),
167 });
168 self
169 }
170
171 pub fn argument(mut self, arg: PromptArgument) -> Self {
173 self.arguments.push(arg);
174 self
175 }
176
177 pub fn build(self) -> Prompt {
179 Prompt {
180 name: self.name,
181 description: self.description,
182 arguments: if self.arguments.is_empty() {
183 None
184 } else {
185 Some(self.arguments)
186 },
187 }
188 }
189}
190
191pub struct PromptResultBuilder {
193 description: Option<String>,
194 messages: Vec<PromptMessage>,
195}
196
197impl Default for PromptResultBuilder {
198 fn default() -> Self {
199 Self::new()
200 }
201}
202
203impl PromptResultBuilder {
204 pub fn new() -> Self {
206 Self {
207 description: None,
208 messages: Vec::new(),
209 }
210 }
211
212 pub fn description(mut self, desc: impl Into<String>) -> Self {
214 self.description = Some(desc.into());
215 self
216 }
217
218 pub fn user_text(mut self, text: impl Into<String>) -> Self {
220 self.messages.push(PromptMessage::user(text.into()));
221 self
222 }
223
224 pub fn assistant_text(mut self, text: impl Into<String>) -> Self {
226 self.messages.push(PromptMessage::assistant(text.into()));
227 self
228 }
229
230 pub fn message(mut self, msg: PromptMessage) -> Self {
232 self.messages.push(msg);
233 self
234 }
235
236 pub fn build(self) -> GetPromptResult {
238 GetPromptResult {
239 description: self.description,
240 messages: self.messages,
241 }
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;

    /// Owned values a test borrows from when constructing a [`Context`].
    fn make_context() -> (RequestId, ClientCapabilities, ServerCapabilities, NoOpPeer) {
        (
            RequestId::Number(1),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
            NoOpPeer,
        )
    }

    #[test]
    fn test_prompt_builder() {
        let prompt = PromptBuilder::new("code-review")
            .description("Review code for issues")
            .required_arg("code", "The code to review")
            .optional_arg("language", "Programming language")
            .build();

        assert_eq!(prompt.name, "code-review");
        assert_eq!(prompt.description.as_deref(), Some("Review code for issues"));
        assert_eq!(prompt.arguments.as_ref().map(|a| a.len()), Some(2));
    }

    #[test]
    fn test_prompt_builder_without_arguments() {
        let prompt = PromptBuilder::new("bare").build();

        assert_eq!(prompt.name, "bare");
        assert!(prompt.description.is_none());
        // No declared arguments are represented as `None`, not an empty list.
        assert!(prompt.arguments.is_none());
    }

    #[test]
    fn test_prompt_result_builder() {
        let result = PromptResultBuilder::new()
            .description("Generated review")
            .user_text("Please review this code")
            .assistant_text("I'll analyze the code...")
            .build();

        assert_eq!(result.description.as_deref(), Some("Generated review"));
        assert_eq!(result.messages.len(), 2);
    }

    #[test]
    fn test_register_same_name_replaces() {
        let mut service = PromptService::new();
        for description in ["first", "second"] {
            service.register(
                PromptBuilder::new("dup").description(description).build(),
                |_args, _ctx| async move { Ok(PromptResultBuilder::new().build()) },
            );
        }

        assert_eq!(service.len(), 1);
        let registered = service.get("dup").expect("prompt should be registered");
        assert_eq!(registered.prompt.description.as_deref(), Some("second"));
    }

    #[tokio::test]
    async fn test_prompt_service() {
        let mut service = PromptService::new();

        let prompt = PromptBuilder::new("greeting")
            .description("Generate a greeting")
            .required_arg("name", "Name to greet")
            .build();

        service.register(prompt, |args, _ctx| async move {
            let name = args
                .and_then(|v| v.get("name").and_then(|n| n.as_str()).map(String::from))
                .unwrap_or_else(|| "World".to_string());

            Ok(PromptResultBuilder::new()
                .user_text(format!("Generate a greeting for {name}"))
                .build())
        });

        assert!(service.contains("greeting"));
        assert_eq!(service.len(), 1);

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let result = service
            .render("greeting", Some(serde_json::json!({"name": "Alice"})), &ctx)
            .await
            .unwrap();

        assert!(!result.messages.is_empty());
    }

    #[tokio::test]
    async fn test_render_unknown_prompt_is_error() {
        let service = PromptService::new();
        assert!(service.is_empty());

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let result = service.render("missing", None, &ctx).await;
        assert!(result.is_err());
    }
}