1use minijinja::Environment;
4use serde::Serialize;
5
/// One piece of (possibly multimodal) prompt content.
///
/// A prompt is assembled from a sequence of parts. Text-only renderers in this module keep the
/// [`PromptPart::Text`] parts and skip every other variant.
///
/// `PartialEq`/`Eq` are derived so parts can be compared directly (e.g. with `assert_eq!`)
/// instead of being destructured by hand.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptPart {
    /// Plain text content.
    Text(String),
    /// Binary image content tagged with its media type.
    Image {
        /// Media type string describing `data` (e.g. `"image/png"`); not validated by this type.
        media_type: String,
        /// The image bytes.
        data: Vec<u8>,
    },
}
23
24pub trait ToPrompt {
75 fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<PromptPart> {
86 let _ = mode; self.to_prompt_parts()
89 }
90
91 fn to_prompt_with_mode(&self, mode: &str) -> String {
95 self.to_prompt_parts_with_mode(mode)
96 .iter()
97 .filter_map(|part| match part {
98 PromptPart::Text(text) => Some(text.as_str()),
99 _ => None,
100 })
101 .collect::<Vec<_>>()
102 .join("")
103 }
104
105 fn to_prompt_parts(&self) -> Vec<PromptPart> {
111 self.to_prompt_parts_with_mode("full")
112 }
113
114 fn to_prompt(&self) -> String {
119 self.to_prompt_with_mode("full")
120 }
121
122 fn prompt_schema() -> String {
140 String::new() }
142}
143
144impl ToPrompt for String {
147 fn to_prompt_parts(&self) -> Vec<PromptPart> {
148 vec![PromptPart::Text(self.clone())]
149 }
150
151 fn to_prompt(&self) -> String {
152 self.clone()
153 }
154}
155
156impl ToPrompt for &str {
157 fn to_prompt_parts(&self) -> Vec<PromptPart> {
158 vec![PromptPart::Text(self.to_string())]
159 }
160
161 fn to_prompt(&self) -> String {
162 self.to_string()
163 }
164}
165
166impl ToPrompt for bool {
167 fn to_prompt_parts(&self) -> Vec<PromptPart> {
168 vec![PromptPart::Text(self.to_string())]
169 }
170
171 fn to_prompt(&self) -> String {
172 self.to_string()
173 }
174}
175
176impl ToPrompt for char {
177 fn to_prompt_parts(&self) -> Vec<PromptPart> {
178 vec![PromptPart::Text(self.to_string())]
179 }
180
181 fn to_prompt(&self) -> String {
182 self.to_string()
183 }
184}
185
186macro_rules! impl_to_prompt_for_numbers {
187 ($($t:ty),*) => {
188 $(
189 impl ToPrompt for $t {
190 fn to_prompt_parts(&self) -> Vec<PromptPart> {
191 vec![PromptPart::Text(self.to_string())]
192 }
193
194 fn to_prompt(&self) -> String {
195 self.to_string()
196 }
197 }
198 )*
199 };
200}
201
202impl_to_prompt_for_numbers!(
203 i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64
204);
205
206impl<T: ToPrompt> ToPrompt for Vec<T> {
208 fn to_prompt_parts(&self) -> Vec<PromptPart> {
209 vec![PromptPart::Text(self.to_prompt())]
210 }
211
212 fn to_prompt(&self) -> String {
213 format!(
214 "[{}]",
215 self.iter()
216 .map(|item| item.to_prompt())
217 .collect::<Vec<_>>()
218 .join(", ")
219 )
220 }
221}
222
223impl<T: ToPrompt> ToPrompt for Option<T> {
225 fn to_prompt_parts(&self) -> Vec<PromptPart> {
226 vec![PromptPart::Text(self.to_prompt())]
227 }
228
229 fn to_prompt(&self) -> String {
230 match self {
231 Some(value) => value.to_prompt(),
232 None => String::new(),
233 }
234 }
235}
236
237pub fn render_prompt<T: Serialize>(template: &str, context: T) -> Result<String, minijinja::Error> {
241 let mut env = Environment::new();
242 env.add_template("prompt", template)?;
243 let tmpl = env.get_template("prompt")?;
244 tmpl.render(context)
245}
246
/// Renders a minijinja template string with named context variables.
///
/// Expands to `$crate::prompt::render_prompt(template, minijinja::context!(key => value, ...))`,
/// so the result is a `Result<String, minijinja::Error>`.
///
/// The template may be followed by zero or more `key = value` pairs. A trailing comma is
/// optional in every form, including when there are no pairs.
///
/// ```ignore
/// let greeting = prompt!("Hello {{ name }}!", name = "Yui")?;
/// let fixed = prompt!("No variables here.")?;
/// ```
///
/// NOTE(review): the expansion names `minijinja::context!` directly rather than going through
/// `$crate`. Crates invoking this macro therefore appear to need `minijinja` as a direct
/// dependency; confirm whether the crate re-exports it.
#[macro_export]
macro_rules! prompt {
    ($template:expr $(, $key:ident = $value:expr)* $(,)?) => {
        $crate::prompt::render_prompt($template, minijinja::context!($($key => $value),*))
    };
}
281
// Unit tests for the `ToPrompt` implementations, `render_prompt`, and the `prompt!` macro.
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;
    use std::fmt::Display;

    // A user-defined enum whose prompt text comes from its `Display` implementation.
    enum TestEnum {
        VariantA,
        VariantB,
    }

    impl Display for TestEnum {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                TestEnum::VariantA => write!(f, "Variant A"),
                TestEnum::VariantB => write!(f, "Variant B"),
            }
        }
    }

    impl ToPrompt for TestEnum {
        fn to_prompt_parts(&self) -> Vec<PromptPart> {
            vec![PromptPart::Text(self.to_string())]
        }

        fn to_prompt(&self) -> String {
            self.to_string()
        }
    }

    // --- Scalar and enum `ToPrompt` implementations ---

    #[test]
    fn test_to_prompt_for_enum() {
        let variant = TestEnum::VariantA;
        assert_eq!(variant.to_prompt(), "Variant A");
    }

    #[test]
    fn test_to_prompt_for_enum_variant_b() {
        let variant = TestEnum::VariantB;
        assert_eq!(variant.to_prompt(), "Variant B");
    }

    #[test]
    fn test_to_prompt_for_string() {
        let s = "hello world";
        assert_eq!(s.to_prompt(), "hello world");
    }

    #[test]
    fn test_to_prompt_for_number() {
        let n = 42;
        assert_eq!(n.to_prompt(), "42");
    }

    // --- `Option<T>`: `Some` renders the inner value, `None` renders as empty text ---

    #[test]
    fn test_to_prompt_for_option_some() {
        let opt: Option<String> = Some("hello".to_string());
        assert_eq!(opt.to_prompt(), "hello");
    }

    #[test]
    fn test_to_prompt_for_option_none() {
        let opt: Option<String> = None;
        assert_eq!(opt.to_prompt(), "");
    }

    #[test]
    fn test_to_prompt_for_option_number() {
        let opt_some: Option<i32> = Some(42);
        assert_eq!(opt_some.to_prompt(), "42");

        let opt_none: Option<i32> = None;
        assert_eq!(opt_none.to_prompt(), "");
    }

    #[test]
    fn test_to_prompt_parts_for_option() {
        let opt: Option<String> = Some("test".to_string());
        let parts = opt.to_prompt_parts();
        assert_eq!(parts.len(), 1);
        match &parts[0] {
            PromptPart::Text(text) => assert_eq!(text, "test"),
            _ => panic!("Expected PromptPart::Text"),
        }
    }

    // --- `prompt!` macro: templates rendered with named context variables ---

    // Serializable fixture used to check dotted field access in templates.
    #[derive(Serialize)]
    struct SystemInfo {
        version: &'static str,
        os: &'static str,
    }

    #[test]
    fn test_prompt_macro_simple() {
        let user = "Yui";
        let task = "implementation";
        let prompt = prompt!(
            "User {{user}} is working on the {{task}}.",
            user = user,
            task = task
        )
        .unwrap();
        assert_eq!(prompt, "User Yui is working on the implementation.");
    }

    #[test]
    fn test_prompt_macro_with_struct() {
        let sys = SystemInfo {
            version: "0.1.0",
            os: "Rust",
        };
        let prompt = prompt!("System: {{sys.version}} on {{sys.os}}", sys = sys).unwrap();
        assert_eq!(prompt, "System: 0.1.0 on Rust");
    }

    #[test]
    fn test_prompt_macro_mixed() {
        let user = "Mai";
        let sys = SystemInfo {
            version: "0.1.0",
            os: "Rust",
        };
        let prompt = prompt!(
            "User {{user}} is using {{sys.os}} v{{sys.version}}.",
            user = user,
            sys = sys
        )
        .unwrap();
        assert_eq!(prompt, "User Mai is using Rust v0.1.0.");
    }

    // --- `Vec<T>`: bracketed, comma-separated rendering ---

    #[test]
    fn test_to_prompt_for_vec_of_strings() {
        let items = vec!["apple", "banana", "cherry"];
        assert_eq!(items.to_prompt(), "[apple, banana, cherry]");
    }

    #[test]
    fn test_to_prompt_for_vec_of_numbers() {
        let numbers = vec![1, 2, 3, 42];
        assert_eq!(numbers.to_prompt(), "[1, 2, 3, 42]");
    }

    #[test]
    fn test_to_prompt_for_empty_vec() {
        let empty: Vec<String> = vec![];
        assert_eq!(empty.to_prompt(), "[]");
    }

    #[test]
    fn test_to_prompt_for_nested_vec() {
        let nested = vec![vec![1, 2], vec![3, 4]];
        assert_eq!(nested.to_prompt(), "[[1, 2], [3, 4]]");
    }

    // The whole list is returned as one text part.
    #[test]
    fn test_to_prompt_parts_for_vec() {
        let items = vec!["a", "b", "c"];
        let parts = items.to_prompt_parts();
        assert_eq!(parts.len(), 1);
        match &parts[0] {
            PromptPart::Text(text) => assert_eq!(text, "[a, b, c]"),
            _ => panic!("Expected Text variant"),
        }
    }

    // --- Combinations of `Option` and `Vec` ---

    #[test]
    fn test_to_prompt_for_option_vec() {
        let opt_vec_some: Option<Vec<String>> = Some(vec!["a".to_string(), "b".to_string()]);
        assert_eq!(opt_vec_some.to_prompt(), "[a, b]");

        let opt_vec_none: Option<Vec<String>> = None;
        assert_eq!(opt_vec_none.to_prompt(), "");
    }

    // A `None` element renders as an empty string, leaving an empty slot between separators.
    #[test]
    fn test_to_prompt_for_vec_option() {
        let vec_opts = vec![Some("hello".to_string()), None, Some("world".to_string())];
        assert_eq!(vec_opts.to_prompt(), "[hello, , world]");
    }

    // `None` still yields exactly one (empty) text part.
    #[test]
    fn test_to_prompt_for_option_none_with_parts() {
        let opt: Option<String> = None;
        let parts = opt.to_prompt_parts();
        assert_eq!(parts.len(), 1);
        match &parts[0] {
            PromptPart::Text(text) => assert_eq!(text, ""),
            _ => panic!("Expected PromptPart::Text"),
        }
    }

    // With no variables the macro is written with a trailing comma after the template.
    #[test]
    fn test_prompt_macro_no_args() {
        let prompt = prompt!("This is a static prompt.",).unwrap();
        assert_eq!(prompt, "This is a static prompt.");
    }

    // --- `render_prompt` with `serde_json` contexts ---

    #[test]
    fn test_render_prompt_with_json_value_dot_notation() {
        use serde_json::json;

        let context = json!({
            "user": {
                "name": "Alice",
                "age": 30,
                "profile": {
                    "role": "Developer"
                }
            }
        });

        let template =
            "{{ user.name }} is {{ user.age }} years old and works as {{ user.profile.role }}";
        let result = render_prompt(template, &context).unwrap();

        assert_eq!(result, "Alice is 30 years old and works as Developer");
    }

    #[test]
    fn test_render_prompt_with_hashmap_json_value() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "step_1_output".to_string(),
            json!({
                "result": "success",
                "data": {
                    "count": 42
                }
            }),
        );
        context.insert("task".to_string(), json!("analysis"));

        let template = "Task: {{ task }}, Result: {{ step_1_output.result }}, Count: {{ step_1_output.data.count }}";
        let result = render_prompt(template, &context).unwrap();

        assert_eq!(result, "Task: analysis, Result: success, Count: 42");
    }

    // The next three tests interpolate arrays and objects into JSON-shaped templates and check
    // that the rendered output parses back with `serde_json`.
    #[test]
    fn test_render_prompt_with_array_in_json_template() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "user_request".to_string(),
            json!({
                "narrative_keywords": ["betrayal", "redemption", "sacrifice"]
            }),
        );

        let template = r#"{"keywords": {{ user_request.narrative_keywords }}}"#;
        let result = render_prompt(template, &context).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["keywords"][0], "betrayal");
        assert_eq!(parsed["keywords"][1], "redemption");
        assert_eq!(parsed["keywords"][2], "sacrifice");
    }

    #[test]
    fn test_render_prompt_with_object_in_json_template() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "user_request".to_string(),
            json!({
                "config": {
                    "theme": "dark_fantasy",
                    "complexity": 5
                }
            }),
        );

        let template = r#"{"settings": {{ user_request.config }}}"#;
        let result = render_prompt(template, &context).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["settings"]["theme"], "dark_fantasy");
        assert_eq!(parsed["settings"]["complexity"], 5);
    }

    #[test]
    fn test_render_prompt_mixed_json_template() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "world_concept".to_string(),
            json!({
                "concept": "A world where identity is volatile"
            }),
        );
        context.insert(
            "user_request".to_string(),
            json!({
                "narrative_keywords": ["betrayal", "redemption"],
                "theme": "dark fantasy"
            }),
        );

        let template = r#"{"concept": "{{ world_concept.concept }}", "keywords": {{ user_request.narrative_keywords }}, "theme": "{{ user_request.theme }}"}"#;
        let result = render_prompt(template, &context).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["concept"], "A world where identity is volatile");
        assert_eq!(parsed["keywords"][0], "betrayal");
        assert_eq!(parsed["theme"], "dark fantasy");
    }
}
608
/// Errors produced when building a prompt for a named target via `ToPromptSet`.
#[derive(Debug, thiserror::Error)]
pub enum PromptSetError {
    /// The requested target name is not one the value can render.
    #[error("Target '{target}' not found. Available targets: {available:?}")]
    TargetNotFound {
        /// The target name that was requested.
        target: String,
        /// The target names that are available, as reported by the implementor.
        available: Vec<String>,
    },
    /// Rendering the template for a target failed.
    ///
    /// Because the field is named `source`, `thiserror` also exposes it as the error's
    /// `std::error::Error::source`. NOTE(review): the message interpolates `{source}` as well,
    /// so reporters that walk the source chain will likely print the minijinja error twice.
    #[error("Failed to render prompt for target '{target}': {source}")]
    RenderFailed {
        /// The target whose prompt failed to render.
        target: String,
        /// The underlying minijinja error.
        source: minijinja::Error,
    },
}
622
623pub trait ToPromptSet {
665 fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<PromptPart>, PromptSetError>;
667
668 fn to_prompt_for(&self, target: &str) -> Result<String, PromptSetError> {
673 let parts = self.to_prompt_parts_for(target)?;
674 let text = parts
675 .iter()
676 .filter_map(|part| match part {
677 PromptPart::Text(text) => Some(text.as_str()),
678 _ => None,
679 })
680 .collect::<Vec<_>>()
681 .join("\n");
682 Ok(text)
683 }
684}
685
/// Renders `self` as a prompt tailored to a particular target value of type `T`.
pub trait ToPromptFor<T> {
    /// Renders `self` for `target` in the given `mode`.
    fn to_prompt_for_with_mode(&self, target: &T, mode: &str) -> String;

    /// Renders `self` for `target` in the default `"full"` mode.
    fn to_prompt_for(&self, target: &T) -> String {
        Self::to_prompt_for_with_mode(self, target, "full")
    }
}
701}