1use minijinja::Environment;
4use serde::Serialize;
5
/// A single piece of prompt content.
///
/// A prompt is represented as a sequence of parts so that non-text content (images) can be
/// carried alongside text. Text-only renderings such as [`ToPrompt::to_prompt`] keep the
/// `Text` parts and drop every other variant.
///
/// Parts compare by value (`PartialEq`/`Eq`), which allows asserting on whole part lists.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptPart {
    /// Plain text content.
    Text(String),
    /// Binary image content.
    Image {
        /// Media type of the image; presumably a MIME string such as `"image/png"` — confirm
        /// with callers.
        media_type: String,
        /// Raw image payload. NOTE(review): assumed to be the encoded image bytes rather than
        /// base64 text — confirm.
        data: Vec<u8>,
    },
}
23
24pub trait ToPrompt {
75 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<PromptPart> {
86 let _ = mode; self.to_prompt_parts()
89 }
90
91 fn to_prompt_with_mode(&self, mode: &str) -> String {
95 self.to_prompt_parts_with_mode(mode)
96 .iter()
97 .filter_map(|part| match part {
98 PromptPart::Text(text) => Some(text.as_str()),
99 _ => None,
100 })
101 .collect::<Vec<_>>()
102 .join("")
103 }
104
105 fn to_prompt_parts(&self) -> Vec<PromptPart> {
111 self.to_prompt_parts_with_mode("full")
112 }
113
114 fn to_prompt(&self) -> String {
119 self.to_prompt_with_mode("full")
120 }
121
122 fn prompt_schema() -> String {
140 String::new() }
142}
143
144impl ToPrompt for String {
147 fn to_prompt_parts(&self) -> Vec<PromptPart> {
148 vec![PromptPart::Text(self.clone())]
149 }
150
151 fn to_prompt(&self) -> String {
152 self.clone()
153 }
154}
155
156impl ToPrompt for &str {
157 fn to_prompt_parts(&self) -> Vec<PromptPart> {
158 vec![PromptPart::Text(self.to_string())]
159 }
160
161 fn to_prompt(&self) -> String {
162 self.to_string()
163 }
164}
165
166impl ToPrompt for bool {
167 fn to_prompt_parts(&self) -> Vec<PromptPart> {
168 vec![PromptPart::Text(self.to_string())]
169 }
170
171 fn to_prompt(&self) -> String {
172 self.to_string()
173 }
174}
175
176impl ToPrompt for char {
177 fn to_prompt_parts(&self) -> Vec<PromptPart> {
178 vec![PromptPart::Text(self.to_string())]
179 }
180
181 fn to_prompt(&self) -> String {
182 self.to_string()
183 }
184}
185
186macro_rules! impl_to_prompt_for_numbers {
187 ($($t:ty),*) => {
188 $(
189 impl ToPrompt for $t {
190 fn to_prompt_parts(&self) -> Vec<PromptPart> {
191 vec![PromptPart::Text(self.to_string())]
192 }
193
194 fn to_prompt(&self) -> String {
195 self.to_string()
196 }
197 }
198 )*
199 };
200}
201
202impl_to_prompt_for_numbers!(
203 i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64
204);
205
206impl<T: ToPrompt> ToPrompt for Vec<T> {
208 fn to_prompt_parts(&self) -> Vec<PromptPart> {
209 vec![PromptPart::Text(self.to_prompt())]
210 }
211
212 fn to_prompt(&self) -> String {
213 format!(
214 "[{}]",
215 self.iter()
216 .map(|item| item.to_prompt())
217 .collect::<Vec<_>>()
218 .join(", ")
219 )
220 }
221}
222
223pub fn render_prompt<T: Serialize>(template: &str, context: T) -> Result<String, minijinja::Error> {
227 let mut env = Environment::new();
228 env.add_template("prompt", template)?;
229 let tmpl = env.get_template("prompt")?;
230 tmpl.render(context)
231}
232
/// Renders an inline minijinja template with named variables.
///
/// Expands to a call to `$crate::prompt::render_prompt`, building the context with
/// `minijinja::context!`, so it evaluates to `Result<String, minijinja::Error>`.
///
/// All of these forms are accepted; the comma after the template is optional when no
/// variables are passed, and a trailing comma is always allowed:
///
/// ```ignore
/// prompt!("static text");
/// prompt!("static text",);
/// prompt!("Hello {{name}}", name = "Yui");
/// prompt!("Hello {{name}}", name = "Yui",);
/// ```
///
/// NOTE(review): the expansion names `minijinja` directly, so the calling crate must itself
/// depend on `minijinja`; it also assumes this module is reachable as `crate::prompt`.
#[macro_export]
macro_rules! prompt {
    ($template:expr $(, $key:ident = $value:expr)* $(,)?) => {
        $crate::prompt::render_prompt($template, minijinja::context!($($key => $value),*))
    };
}
267
#[cfg(test)]
mod tests {
    // Unit tests for the prompt traits (`ToPrompt`, `ToPromptSet`, `ToPromptFor`) and the
    // minijinja-backed `render_prompt` / `prompt!` helpers.

    use super::*;
    use serde::Serialize;
    use std::fmt::Display;

    enum TestEnum {
        VariantA,
        VariantB,
    }

    impl Display for TestEnum {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                TestEnum::VariantA => write!(f, "Variant A"),
                TestEnum::VariantB => write!(f, "Variant B"),
            }
        }
    }

    impl ToPrompt for TestEnum {
        fn to_prompt_parts(&self) -> Vec<PromptPart> {
            vec![PromptPart::Text(self.to_string())]
        }

        fn to_prompt(&self) -> String {
            self.to_string()
        }
    }

    #[test]
    fn test_to_prompt_for_enum() {
        let variant = TestEnum::VariantA;
        assert_eq!(variant.to_prompt(), "Variant A");
    }

    #[test]
    fn test_to_prompt_for_enum_variant_b() {
        let variant = TestEnum::VariantB;
        assert_eq!(variant.to_prompt(), "Variant B");
    }

    #[test]
    fn test_to_prompt_for_string() {
        let s = "hello world";
        assert_eq!(s.to_prompt(), "hello world");
    }

    #[test]
    fn test_to_prompt_for_number() {
        let n = 42;
        assert_eq!(n.to_prompt(), "42");
    }

    #[derive(Serialize)]
    struct SystemInfo {
        version: &'static str,
        os: &'static str,
    }

    #[test]
    fn test_prompt_macro_simple() {
        let user = "Yui";
        let task = "implementation";
        let prompt = prompt!(
            "User {{user}} is working on the {{task}}.",
            user = user,
            task = task
        )
        .unwrap();
        assert_eq!(prompt, "User Yui is working on the implementation.");
    }

    #[test]
    fn test_prompt_macro_with_struct() {
        let sys = SystemInfo {
            version: "0.1.0",
            os: "Rust",
        };
        let prompt = prompt!("System: {{sys.version}} on {{sys.os}}", sys = sys).unwrap();
        assert_eq!(prompt, "System: 0.1.0 on Rust");
    }

    #[test]
    fn test_prompt_macro_mixed() {
        let user = "Mai";
        let sys = SystemInfo {
            version: "0.1.0",
            os: "Rust",
        };
        let prompt = prompt!(
            "User {{user}} is using {{sys.os}} v{{sys.version}}.",
            user = user,
            sys = sys
        )
        .unwrap();
        assert_eq!(prompt, "User Mai is using Rust v0.1.0.");
    }

    #[test]
    fn test_to_prompt_for_vec_of_strings() {
        let items = vec!["apple", "banana", "cherry"];
        assert_eq!(items.to_prompt(), "[apple, banana, cherry]");
    }

    #[test]
    fn test_to_prompt_for_vec_of_numbers() {
        let numbers = vec![1, 2, 3, 42];
        assert_eq!(numbers.to_prompt(), "[1, 2, 3, 42]");
    }

    #[test]
    fn test_to_prompt_for_empty_vec() {
        let empty: Vec<String> = vec![];
        assert_eq!(empty.to_prompt(), "[]");
    }

    #[test]
    fn test_to_prompt_for_nested_vec() {
        let nested = vec![vec![1, 2], vec![3, 4]];
        assert_eq!(nested.to_prompt(), "[[1, 2], [3, 4]]");
    }

    #[test]
    fn test_to_prompt_parts_for_vec() {
        let items = vec!["a", "b", "c"];
        let parts = items.to_prompt_parts();
        assert_eq!(parts.len(), 1);
        match &parts[0] {
            PromptPart::Text(text) => assert_eq!(text, "[a, b, c]"),
            _ => panic!("Expected Text variant"),
        }
    }

    #[test]
    fn test_prompt_macro_no_args() {
        let prompt = prompt!("This is a static prompt.",).unwrap();
        assert_eq!(prompt, "This is a static prompt.");
    }

    #[test]
    fn test_render_prompt_with_json_value_dot_notation() {
        use serde_json::json;

        let context = json!({
            "user": {
                "name": "Alice",
                "age": 30,
                "profile": {
                    "role": "Developer"
                }
            }
        });

        let template =
            "{{ user.name }} is {{ user.age }} years old and works as {{ user.profile.role }}";
        let result = render_prompt(template, &context).unwrap();

        assert_eq!(result, "Alice is 30 years old and works as Developer");
    }

    #[test]
    fn test_render_prompt_with_hashmap_json_value() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "step_1_output".to_string(),
            json!({
                "result": "success",
                "data": {
                    "count": 42
                }
            }),
        );
        context.insert("task".to_string(), json!("analysis"));

        let template = "Task: {{ task }}, Result: {{ step_1_output.result }}, Count: {{ step_1_output.data.count }}";
        let result = render_prompt(template, &context).unwrap();

        assert_eq!(result, "Task: analysis, Result: success, Count: 42");
    }

    #[test]
    fn test_render_prompt_with_array_in_json_template() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "user_request".to_string(),
            json!({
                "narrative_keywords": ["betrayal", "redemption", "sacrifice"]
            }),
        );

        let template = r#"{"keywords": {{ user_request.narrative_keywords }}}"#;
        let result = render_prompt(template, &context).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["keywords"][0], "betrayal");
        assert_eq!(parsed["keywords"][1], "redemption");
        assert_eq!(parsed["keywords"][2], "sacrifice");
    }

    #[test]
    fn test_render_prompt_with_object_in_json_template() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "user_request".to_string(),
            json!({
                "config": {
                    "theme": "dark_fantasy",
                    "complexity": 5
                }
            }),
        );

        let template = r#"{"settings": {{ user_request.config }}}"#;
        let result = render_prompt(template, &context).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["settings"]["theme"], "dark_fantasy");
        assert_eq!(parsed["settings"]["complexity"], 5);
    }

    #[test]
    fn test_render_prompt_mixed_json_template() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "world_concept".to_string(),
            json!({
                "concept": "A world where identity is volatile"
            }),
        );
        context.insert(
            "user_request".to_string(),
            json!({
                "narrative_keywords": ["betrayal", "redemption"],
                "theme": "dark fantasy"
            }),
        );

        let template = r#"{"concept": "{{ world_concept.concept }}", "keywords": {{ user_request.narrative_keywords }}, "theme": "{{ user_request.theme }}"}"#;
        let result = render_prompt(template, &context).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["concept"], "A world where identity is volatile");
        assert_eq!(parsed["keywords"][0], "betrayal");
        assert_eq!(parsed["theme"], "dark fantasy");
    }

    // ---- Coverage for impls and default methods not exercised above. ----

    #[test]
    fn test_to_prompt_for_bool_and_char() {
        assert_eq!(true.to_prompt(), "true");
        assert_eq!(false.to_prompt(), "false");
        assert_eq!('λ'.to_prompt(), "λ");
    }

    #[test]
    fn test_to_prompt_parts_for_owned_string() {
        let s = String::from("owned");
        assert_eq!(s.to_prompt(), "owned");
        let parts = s.to_prompt_parts();
        assert_eq!(parts.len(), 1);
        match &parts[0] {
            PromptPart::Text(text) => assert_eq!(text, "owned"),
            _ => panic!("Expected Text variant"),
        }
    }

    /// Emits text around an image so the text-only renderings can be checked.
    struct MixedParts;

    impl ToPrompt for MixedParts {
        fn to_prompt_parts(&self) -> Vec<PromptPart> {
            vec![
                PromptPart::Text("before ".to_string()),
                PromptPart::Image {
                    media_type: "image/png".to_string(),
                    data: vec![0x89, 0x50],
                },
                PromptPart::Text("after".to_string()),
            ]
        }
    }

    #[test]
    fn test_to_prompt_drops_images_and_concatenates_text() {
        let value = MixedParts;
        assert_eq!(value.to_prompt(), "before after");
        assert_eq!(value.to_prompt_with_mode("summary"), "before after");
        assert_eq!(value.to_prompt_parts_with_mode("summary").len(), 3);
    }

    /// A prompt set with a single known target, `"chat"`.
    struct TwoLineSet;

    impl ToPromptSet for TwoLineSet {
        fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<PromptPart>, PromptSetError> {
            match target {
                "chat" => Ok(vec![
                    PromptPart::Text("line one".to_string()),
                    PromptPart::Image {
                        media_type: "image/jpeg".to_string(),
                        data: vec![],
                    },
                    PromptPart::Text("line two".to_string()),
                ]),
                other => Err(PromptSetError::TargetNotFound {
                    target: other.to_string(),
                    available: vec!["chat".to_string()],
                }),
            }
        }
    }

    #[test]
    fn test_to_prompt_set_joins_text_with_newlines() {
        assert_eq!(
            TwoLineSet.to_prompt_for("chat").unwrap(),
            "line one\nline two"
        );
    }

    #[test]
    fn test_to_prompt_set_unknown_target_error_message() {
        let err = TwoLineSet.to_prompt_for("missing").unwrap_err();
        assert_eq!(
            err.to_string(),
            "Target 'missing' not found. Available targets: [\"chat\"]"
        );
    }

    /// Describes a `TestEnum`, echoing the requested mode.
    struct Describer;

    impl ToPromptFor<TestEnum> for Describer {
        fn to_prompt_for_with_mode(&self, target: &TestEnum, mode: &str) -> String {
            format!("{} ({})", target, mode)
        }
    }

    #[test]
    fn test_to_prompt_for_defaults_to_full_mode() {
        assert_eq!(
            Describer.to_prompt_for(&TestEnum::VariantA),
            "Variant A (full)"
        );
        assert_eq!(
            Describer.to_prompt_for_with_mode(&TestEnum::VariantB, "brief"),
            "Variant B (brief)"
        );
    }
}
533
534#[derive(Debug, thiserror::Error)]
535pub enum PromptSetError {
536 #[error("Target '{target}' not found. Available targets: {available:?}")]
537 TargetNotFound {
538 target: String,
539 available: Vec<String>,
540 },
541 #[error("Failed to render prompt for target '{target}': {source}")]
542 RenderFailed {
543 target: String,
544 source: minijinja::Error,
545 },
546}
547
548pub trait ToPromptSet {
590 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<PromptPart>, PromptSetError>;
592
593 fn to_prompt_for(&self, target: &str) -> Result<String, PromptSetError> {
598 let parts = self.to_prompt_parts_for(target)?;
599 let text = parts
600 .iter()
601 .filter_map(|part| match part {
602 PromptPart::Text(text) => Some(text.as_str()),
603 _ => None,
604 })
605 .collect::<Vec<_>>()
606 .join("\n");
607 Ok(text)
608 }
609}
610
/// Renders `Self` as a prompt tailored to a specific `target` of type `T`.
pub trait ToPromptFor<T> {
    /// Renders the prompt for `target` using the given rendering `mode`.
    fn to_prompt_for_with_mode(&self, target: &T, mode: &str) -> String;

    /// Renders the prompt for `target` in the default `"full"` mode.
    fn to_prompt_for(&self, target: &T) -> String {
        let default_mode = "full";
        <Self as ToPromptFor<T>>::to_prompt_for_with_mode(self, target, default_mode)
    }
}