1use serde_json::Value;
8use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11
12use crate::traits::{
14 HasPromptAnnotations, HasPromptArguments, HasPromptDescription,
15 HasPromptMeta, HasPromptMetadata,
16};
17use turul_mcp_protocol::prompts::{
19 ContentBlock, GetPromptResult, PromptArgument, PromptMessage,
20};
21
22pub type DynamicPromptFn = Box<
24 dyn Fn(
25 HashMap<String, String>,
26 ) -> Pin<Box<dyn Future<Output = Result<GetPromptResult, String>> + Send>>
27 + Send
28 + Sync,
29>;
30
31pub struct PromptBuilder {
33 name: String,
34 title: Option<String>,
35 description: Option<String>,
36 arguments: Vec<PromptArgument>,
37 messages: Vec<PromptMessage>,
38 meta: Option<HashMap<String, Value>>,
39 get_fn: Option<DynamicPromptFn>,
40}
41
42impl PromptBuilder {
43 pub fn new(name: impl Into<String>) -> Self {
45 Self {
46 name: name.into(),
47 title: None,
48 description: None,
49 arguments: Vec::new(),
50 messages: Vec::new(),
51 meta: None,
52 get_fn: None,
53 }
54 }
55
56 pub fn title(mut self, title: impl Into<String>) -> Self {
58 self.title = Some(title.into());
59 self
60 }
61
62 pub fn description(mut self, description: impl Into<String>) -> Self {
64 self.description = Some(description.into());
65 self
66 }
67
68 pub fn argument(mut self, argument: PromptArgument) -> Self {
70 self.arguments.push(argument);
71 self
72 }
73
74 pub fn string_argument(
76 mut self,
77 name: impl Into<String>,
78 description: impl Into<String>,
79 ) -> Self {
80 let arg = PromptArgument::new(name)
81 .with_description(description)
82 .required();
83 self.arguments.push(arg);
84 self
85 }
86
87 pub fn optional_string_argument(
89 mut self,
90 name: impl Into<String>,
91 description: impl Into<String>,
92 ) -> Self {
93 let arg = PromptArgument::new(name)
94 .with_description(description)
95 .optional();
96 self.arguments.push(arg);
97 self
98 }
99
100 pub fn message(mut self, message: PromptMessage) -> Self {
102 self.messages.push(message);
103 self
104 }
105
106 pub fn system_message(mut self, text: impl Into<String>) -> Self {
108 self.messages
110 .push(PromptMessage::user_text(format!("System: {}", text.into())));
111 self
112 }
113
114 pub fn user_message(mut self, text: impl Into<String>) -> Self {
116 self.messages.push(PromptMessage::user_text(text));
117 self
118 }
119
120 pub fn assistant_message(mut self, text: impl Into<String>) -> Self {
122 self.messages.push(PromptMessage::assistant_text(text));
123 self
124 }
125
126 pub fn user_image(mut self, data: impl Into<String>, mime_type: impl Into<String>) -> Self {
128 self.messages
129 .push(PromptMessage::user_image(data, mime_type));
130 self
131 }
132
133 pub fn template_user_message(mut self, template: impl Into<String>) -> Self {
135 self.messages.push(PromptMessage::user_text(template));
136 self
137 }
138
139 pub fn template_assistant_message(mut self, template: impl Into<String>) -> Self {
141 self.messages.push(PromptMessage::assistant_text(template));
142 self
143 }
144
145 pub fn meta(mut self, meta: HashMap<String, Value>) -> Self {
147 self.meta = Some(meta);
148 self
149 }
150
151 pub fn get<F, Fut>(mut self, f: F) -> Self
153 where
154 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
155 Fut: Future<Output = Result<GetPromptResult, String>> + Send + 'static,
156 {
157 self.get_fn = Some(Box::new(move |args| Box::pin(f(args))));
158 self
159 }
160
161 pub fn build(self) -> Result<DynamicPrompt, String> {
163 let get_fn = if let Some(f) = self.get_fn {
165 f
166 } else {
167 let messages = self.messages.clone();
169 let description = self.description.clone();
170 Box::new(move |args| {
171 let messages = messages.clone();
172 let description = description.clone();
173 Box::pin(async move {
174 let processed_messages = process_template_messages(messages, &args)?;
175 let mut result = GetPromptResult::new(processed_messages);
176 if let Some(desc) = description {
177 result = result.with_description(desc);
178 }
179 Ok(result)
180 })
181 as Pin<Box<dyn Future<Output = Result<GetPromptResult, String>> + Send>>
182 })
183 };
184
185 Ok(DynamicPrompt {
186 name: self.name,
187 title: self.title,
188 description: self.description,
189 arguments: self.arguments,
190 messages: self.messages,
191 meta: self.meta,
192 get_fn,
193 })
194 }
195}
196
197pub struct DynamicPrompt {
199 name: String,
200 title: Option<String>,
201 description: Option<String>,
202 arguments: Vec<PromptArgument>,
203 #[allow(dead_code)]
204 messages: Vec<PromptMessage>,
205 meta: Option<HashMap<String, Value>>,
206 get_fn: DynamicPromptFn,
207}
208
209impl DynamicPrompt {
210 pub async fn get(&self, args: HashMap<String, String>) -> Result<GetPromptResult, String> {
212 (self.get_fn)(args).await
213 }
214}
215
216impl HasPromptMetadata for DynamicPrompt {
218 fn name(&self) -> &str {
219 &self.name
220 }
221
222 fn title(&self) -> Option<&str> {
223 self.title.as_deref()
224 }
225}
226
227impl HasPromptDescription for DynamicPrompt {
228 fn description(&self) -> Option<&str> {
229 self.description.as_deref()
230 }
231}
232
233impl HasPromptArguments for DynamicPrompt {
234 fn arguments(&self) -> Option<&Vec<PromptArgument>> {
235 if self.arguments.is_empty() {
236 None
237 } else {
238 Some(&self.arguments)
239 }
240 }
241}
242
243impl HasPromptAnnotations for DynamicPrompt {
244 fn annotations(&self) -> Option<&turul_mcp_protocol::prompts::PromptAnnotations> {
245 None }
247}
248
249impl HasPromptMeta for DynamicPrompt {
250 fn prompt_meta(&self) -> Option<&HashMap<String, Value>> {
251 self.meta.as_ref()
252 }
253}
254
255fn process_template_messages(
259 messages: Vec<PromptMessage>,
260 args: &HashMap<String, String>,
261) -> Result<Vec<PromptMessage>, String> {
262 let mut processed = Vec::new();
263
264 for message in messages {
265 let processed_message = match message.content {
266 ContentBlock::Text { text, .. } => {
267 let processed_text = process_template_string(&text, args);
268 PromptMessage {
269 role: message.role,
270 content: ContentBlock::text(processed_text),
271 }
272 }
273 other_content => PromptMessage {
275 role: message.role,
276 content: other_content,
277 },
278 };
279 processed.push(processed_message);
280 }
281
282 Ok(processed)
283}
284
/// Substitutes `{key}` placeholders in `template` with values from `args`.
///
/// The template is scanned once, left to right:
/// - Substituted values are emitted verbatim. A value that itself contains
///   `{other}` is never expanded again, so the output does not depend on the
///   map's iteration order.
/// - A placeholder's key is the text between a `{` and the first `}` after it.
///   Keys containing `}` therefore cannot match.
/// - When that key is not in `args`, the `{` is kept and scanning resumes right
///   after it. Unknown placeholders stay as-is, and an inner placeholder is still
///   found, e.g. `{{name}}` becomes `{Alice}`.
/// - A `{` with no closing `}` after it, and everything following it, is copied
///   unchanged.
fn process_template_string(template: &str, args: &HashMap<String, String>) -> String {
    let mut result = String::with_capacity(template.len());
    let mut rest = template;

    while let Some(open) = rest.find('{') {
        result.push_str(&rest[..open]);
        // `{` and `}` are single-byte ASCII, so these offsets stay on char boundaries.
        let after_open = &rest[open + 1..];

        let close = match after_open.find('}') {
            Some(close) => close,
            None => {
                // No `}` remains, so nothing further can be a placeholder.
                result.push_str(&rest[open..]);
                return result;
            }
        };

        match args.get(&after_open[..close]) {
            Some(value) => {
                result.push_str(value);
                rest = &after_open[close + 1..];
            }
            None => {
                result.push('{');
                rest = after_open;
            }
        }
    }

    result.push_str(rest);
    result
}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text of a `ContentBlock::Text`, panicking on any other variant.
    fn expect_text(content: &ContentBlock) -> &str {
        match content {
            ContentBlock::Text { text, .. } => text,
            _ => panic!("Expected text content"),
        }
    }

    #[test]
    fn test_prompt_builder_basic() {
        let prompt = PromptBuilder::new("greeting_prompt")
            .title("Greeting Generator")
            .description("Generate personalized greetings")
            .string_argument("name", "The person's name")
            .user_message("Hello {name}! How are you today?")
            .build()
            .expect("Failed to build prompt");

        assert_eq!(prompt.name(), "greeting_prompt");
        assert_eq!(prompt.title(), Some("Greeting Generator"));
        assert_eq!(prompt.description(), Some("Generate personalized greetings"));

        let declared = prompt
            .arguments()
            .expect("prompt should declare arguments");
        assert_eq!(declared.len(), 1);
    }

    #[tokio::test]
    async fn test_prompt_builder_template_processing() {
        let prompt = PromptBuilder::new("conversation_starter")
            .description("Start a conversation with someone")
            .string_argument("name", "Person's name")
            .optional_string_argument("topic", "Optional conversation topic")
            .user_message("Hi {name}! Nice to meet you.")
            .template_assistant_message("Hello! What would you like to talk about?")
            .user_message("Let's discuss {topic}")
            .build()
            .expect("Failed to build prompt");

        let args = HashMap::from([
            ("name".to_string(), "Alice".to_string()),
            ("topic".to_string(), "Rust programming".to_string()),
        ]);

        let result = prompt.get(args).await.expect("Failed to get prompt");
        let messages = &result.messages;

        assert_eq!(messages.len(), 3);
        assert_eq!(
            expect_text(&messages[0].content),
            "Hi Alice! Nice to meet you."
        );
        assert_eq!(
            expect_text(&messages[2].content),
            "Let's discuss Rust programming"
        );
    }

    #[tokio::test]
    async fn test_prompt_builder_custom_get_function() {
        let prompt = PromptBuilder::new("dynamic_prompt")
            .description("Dynamic prompt with custom logic")
            .string_argument("mood", "Current mood")
            .get(|args| async move {
                let mood = args.get("mood").map(String::as_str).unwrap_or("neutral");
                let reply = match mood {
                    "happy" => "That's wonderful! Tell me more about what's making you happy.",
                    "sad" => "I'm sorry to hear that. Would you like to talk about it?",
                    _ => "How are you feeling today?",
                };

                Ok(GetPromptResult::new(vec![
                    PromptMessage::user_text(format!("I'm feeling {}", mood)),
                    PromptMessage::assistant_text(reply),
                ])
                .with_description("Mood-based conversation"))
            })
            .build()
            .expect("Failed to build prompt");

        let result = prompt
            .get(HashMap::from([("mood".to_string(), "happy".to_string())]))
            .await
            .expect("Failed to get prompt");

        assert_eq!(result.messages.len(), 2);
        assert_eq!(
            result.description.as_deref(),
            Some("Mood-based conversation")
        );
        assert!(expect_text(&result.messages[1].content).contains("wonderful"));
    }

    #[test]
    fn test_prompt_builder_arguments() {
        let custom = PromptArgument::new("custom_arg")
            .with_title("Custom Argument")
            .with_description("A custom argument")
            .required();

        let prompt = PromptBuilder::new("complex_prompt")
            .string_argument("required_arg", "This is required")
            .optional_string_argument("optional_arg", "This is optional")
            .argument(custom)
            .build()
            .expect("Failed to build prompt");

        let args = prompt
            .arguments()
            .expect("prompt should declare arguments");
        assert_eq!(args.len(), 3);

        let (first, second, third) = (&args[0], &args[1], &args[2]);
        assert_eq!(first.name, "required_arg");
        assert_eq!(first.required, Some(true));
        assert_eq!(second.name, "optional_arg");
        assert_eq!(second.required, Some(false));
        assert_eq!(third.name, "custom_arg");
        assert_eq!(third.title, Some("Custom Argument".to_string()));
    }

    #[test]
    fn test_template_string_processing() {
        let args = HashMap::from([
            ("name".to_string(), "Alice".to_string()),
            ("place".to_string(), "Wonderland".to_string()),
        ]);

        assert_eq!(
            process_template_string("Hello {name}, welcome to {place}!", &args),
            "Hello Alice, welcome to Wonderland!"
        );
    }
}