1use minijinja::Environment;
4use serde::Serialize;
5
/// A single fragment of a (possibly multimodal) prompt.
///
/// Prompts are assembled as an ordered list of parts. Text-only renderers,
/// such as the default `ToPrompt::to_prompt_with_mode`, keep the `Text`
/// fragments and drop every other kind of part.
///
/// `PartialEq`/`Eq` are derived so parts can be compared directly
/// (e.g. with `assert_eq!`) instead of pattern-matching on each variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptPart {
    /// Plain text content.
    Text(String),
    /// Binary image content.
    Image {
        /// Media type describing `data` (presumably a MIME type such as
        /// `"image/png"` — not validated here).
        media_type: String,
        /// The image payload bytes; no particular encoding is enforced here.
        data: Vec<u8>,
    },
}
23
24pub trait ToPrompt {
75    fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<PromptPart> {
86        let _ = mode; self.to_prompt_parts()
89    }
90
91    fn to_prompt_with_mode(&self, mode: &str) -> String {
95        self.to_prompt_parts_with_mode(mode)
96            .iter()
97            .filter_map(|part| match part {
98                PromptPart::Text(text) => Some(text.as_str()),
99                _ => None,
100            })
101            .collect::<Vec<_>>()
102            .join("")
103    }
104
105    fn to_prompt_parts(&self) -> Vec<PromptPart> {
111        self.to_prompt_parts_with_mode("full")
112    }
113
114    fn to_prompt(&self) -> String {
119        self.to_prompt_with_mode("full")
120    }
121
122    fn prompt_schema() -> String {
140        String::new() }
142}
143
144impl ToPrompt for String {
147    fn to_prompt_parts(&self) -> Vec<PromptPart> {
148        vec![PromptPart::Text(self.clone())]
149    }
150
151    fn to_prompt(&self) -> String {
152        self.clone()
153    }
154}
155
156impl ToPrompt for &str {
157    fn to_prompt_parts(&self) -> Vec<PromptPart> {
158        vec![PromptPart::Text(self.to_string())]
159    }
160
161    fn to_prompt(&self) -> String {
162        self.to_string()
163    }
164}
165
166impl ToPrompt for bool {
167    fn to_prompt_parts(&self) -> Vec<PromptPart> {
168        vec![PromptPart::Text(self.to_string())]
169    }
170
171    fn to_prompt(&self) -> String {
172        self.to_string()
173    }
174}
175
176impl ToPrompt for char {
177    fn to_prompt_parts(&self) -> Vec<PromptPart> {
178        vec![PromptPart::Text(self.to_string())]
179    }
180
181    fn to_prompt(&self) -> String {
182        self.to_string()
183    }
184}
185
/// Implements `ToPrompt` for each listed type by rendering its `Display`
/// output as a single text part.
macro_rules! impl_to_prompt_for_numbers {
    ($($num:ty),*) => {
        $(
            impl ToPrompt for $num {
                fn to_prompt(&self) -> String {
                    ToString::to_string(self)
                }

                fn to_prompt_parts(&self) -> Vec<PromptPart> {
                    let text = ToString::to_string(self);
                    vec![PromptPart::Text(text)]
                }
            }
        )*
    };
}
201
202impl_to_prompt_for_numbers!(
203    i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64
204);
205
206impl<T: ToPrompt> ToPrompt for Vec<T> {
208    fn to_prompt_parts(&self) -> Vec<PromptPart> {
209        vec![PromptPart::Text(self.to_prompt())]
210    }
211
212    fn to_prompt(&self) -> String {
213        format!(
214            "[{}]",
215            self.iter()
216                .map(|item| item.to_prompt())
217                .collect::<Vec<_>>()
218                .join(", ")
219        )
220    }
221}
222
223pub fn render_prompt<T: Serialize>(template: &str, context: T) -> Result<String, minijinja::Error> {
227    let mut env = Environment::new();
228    env.add_template("prompt", template)?;
229    let tmpl = env.get_template("prompt")?;
230    tmpl.render(context)
231}
232
/// Renders a minijinja template string with named context values.
///
/// Accepted forms:
///
/// - `prompt!(template)` or `prompt!(template,)`: render with an empty context.
/// - `prompt!(template, key = value, ...)`: each `key` becomes a template
///   variable bound to `value`. A trailing comma is allowed.
///
/// The macro expands to a call to `render_prompt`, so it evaluates to
/// `Result<String, minijinja::Error>`.
///
/// NOTE(review): the expansion names `minijinja::context!` directly rather
/// than through `$crate`. Crates invoking this macro therefore presumably need
/// their own `minijinja` dependency. Confirm this, and consider re-exporting
/// minijinja from this crate.
#[macro_export]
macro_rules! prompt {
    // Template only (optionally followed by one comma): no context variables.
    // Previously `prompt!(template)` without a comma failed to match any arm.
    ($template:expr $(,)?) => {
        $crate::prompt::render_prompt($template, minijinja::context!())
    };
    // Template followed by `key = value` pairs.
    ($template:expr, $($key:ident = $value:expr),* $(,)?) => {
        $crate::prompt::render_prompt($template, minijinja::context!($($key => $value),*))
    };
}
267
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;
    use std::fmt::Display;

    // A user-defined enum whose prompt text comes from its `Display` impl.
    enum TestEnum {
        VariantA,
        VariantB,
    }

    impl Display for TestEnum {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                TestEnum::VariantA => write!(f, "Variant A"),
                TestEnum::VariantB => write!(f, "Variant B"),
            }
        }
    }

    impl ToPrompt for TestEnum {
        fn to_prompt_parts(&self) -> Vec<PromptPart> {
            vec![PromptPart::Text(self.to_string())]
        }

        fn to_prompt(&self) -> String {
            self.to_string()
        }
    }

    #[test]
    fn test_to_prompt_for_enum() {
        let variant = TestEnum::VariantA;
        assert_eq!(variant.to_prompt(), "Variant A");
    }

    #[test]
    fn test_to_prompt_for_enum_variant_b() {
        let variant = TestEnum::VariantB;
        assert_eq!(variant.to_prompt(), "Variant B");
    }

    #[test]
    fn test_to_prompt_for_string() {
        let s = "hello world";
        assert_eq!(s.to_prompt(), "hello world");
    }

    #[test]
    fn test_to_prompt_for_number() {
        let n = 42;
        assert_eq!(n.to_prompt(), "42");
    }

    #[derive(Serialize)]
    struct SystemInfo {
        version: &'static str,
        os: &'static str,
    }

    #[test]
    fn test_prompt_macro_simple() {
        let user = "Yui";
        let task = "implementation";
        let prompt = prompt!(
            "User {{user}} is working on the {{task}}.",
            user = user,
            task = task
        )
        .unwrap();
        assert_eq!(prompt, "User Yui is working on the implementation.");
    }

    #[test]
    fn test_prompt_macro_with_struct() {
        let sys = SystemInfo {
            version: "0.1.0",
            os: "Rust",
        };
        let prompt = prompt!("System: {{sys.version}} on {{sys.os}}", sys = sys).unwrap();
        assert_eq!(prompt, "System: 0.1.0 on Rust");
    }

    #[test]
    fn test_prompt_macro_mixed() {
        let user = "Mai";
        let sys = SystemInfo {
            version: "0.1.0",
            os: "Rust",
        };
        let prompt = prompt!(
            "User {{user}} is using {{sys.os}} v{{sys.version}}.",
            user = user,
            sys = sys
        )
        .unwrap();
        assert_eq!(prompt, "User Mai is using Rust v0.1.0.");
    }

    #[test]
    fn test_to_prompt_for_vec_of_strings() {
        let items = vec!["apple", "banana", "cherry"];
        assert_eq!(items.to_prompt(), "[apple, banana, cherry]");
    }

    #[test]
    fn test_to_prompt_for_vec_of_numbers() {
        let numbers = vec![1, 2, 3, 42];
        assert_eq!(numbers.to_prompt(), "[1, 2, 3, 42]");
    }

    #[test]
    fn test_to_prompt_for_empty_vec() {
        let empty: Vec<String> = vec![];
        assert_eq!(empty.to_prompt(), "[]");
    }

    #[test]
    fn test_to_prompt_for_nested_vec() {
        let nested = vec![vec![1, 2], vec![3, 4]];
        assert_eq!(nested.to_prompt(), "[[1, 2], [3, 4]]");
    }

    #[test]
    fn test_to_prompt_parts_for_vec() {
        let items = vec!["a", "b", "c"];
        let parts = items.to_prompt_parts();
        assert_eq!(parts.len(), 1);
        match &parts[0] {
            PromptPart::Text(text) => assert_eq!(text, "[a, b, c]"),
            _ => panic!("Expected Text variant"),
        }
    }

    #[test]
    fn test_prompt_macro_no_args() {
        let prompt = prompt!("This is a static prompt.",).unwrap();
        assert_eq!(prompt, "This is a static prompt.");
    }

    #[test]
    fn test_render_prompt_with_json_value_dot_notation() {
        use serde_json::json;

        let context = json!({
            "user": {
                "name": "Alice",
                "age": 30,
                "profile": {
                    "role": "Developer"
                }
            }
        });

        let template =
            "{{ user.name }} is {{ user.age }} years old and works as {{ user.profile.role }}";
        let result = render_prompt(template, &context).unwrap();

        assert_eq!(result, "Alice is 30 years old and works as Developer");
    }

    #[test]
    fn test_render_prompt_with_hashmap_json_value() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "step_1_output".to_string(),
            json!({
                "result": "success",
                "data": {
                    "count": 42
                }
            }),
        );
        context.insert("task".to_string(), json!("analysis"));

        let template = "Task: {{ task }}, Result: {{ step_1_output.result }}, Count: {{ step_1_output.data.count }}";
        let result = render_prompt(template, &context).unwrap();

        assert_eq!(result, "Task: analysis, Result: success, Count: 42");
    }

    #[test]
    fn test_render_prompt_with_array_in_json_template() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "user_request".to_string(),
            json!({
                "narrative_keywords": ["betrayal", "redemption", "sacrifice"]
            }),
        );

        // The array is interpolated directly into a JSON document.
        let template = r#"{"keywords": {{ user_request.narrative_keywords }}}"#;
        let result = render_prompt(template, &context).unwrap();

        // The rendered output must itself be valid JSON.
        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["keywords"][0], "betrayal");
        assert_eq!(parsed["keywords"][1], "redemption");
        assert_eq!(parsed["keywords"][2], "sacrifice");
    }

    #[test]
    fn test_render_prompt_with_object_in_json_template() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "user_request".to_string(),
            json!({
                "config": {
                    "theme": "dark_fantasy",
                    "complexity": 5
                }
            }),
        );

        // The object is interpolated directly into a JSON document.
        let template = r#"{"settings": {{ user_request.config }}}"#;
        let result = render_prompt(template, &context).unwrap();

        // The rendered output must itself be valid JSON.
        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["settings"]["theme"], "dark_fantasy");
        assert_eq!(parsed["settings"]["complexity"], 5);
    }

    #[test]
    fn test_render_prompt_mixed_json_template() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "world_concept".to_string(),
            json!({
                "concept": "A world where identity is volatile"
            }),
        );
        context.insert(
            "user_request".to_string(),
            json!({
                "narrative_keywords": ["betrayal", "redemption"],
                "theme": "dark fantasy"
            }),
        );

        // Strings are quoted in the template; the array is interpolated raw.
        let template = r#"{"concept": "{{ world_concept.concept }}", "keywords": {{ user_request.narrative_keywords }}, "theme": "{{ user_request.theme }}"}"#;
        let result = render_prompt(template, &context).unwrap();

        // The rendered output must itself be valid JSON.
        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["concept"], "A world where identity is volatile");
        assert_eq!(parsed["keywords"][0], "betrayal");
        assert_eq!(parsed["theme"], "dark fantasy");
    }

    // --- Coverage for the remaining built-in impls and trait defaults ---

    #[test]
    fn test_to_prompt_for_owned_string() {
        let s = String::from("owned text");
        assert_eq!(s.to_prompt(), "owned text");
        match &s.to_prompt_parts()[..] {
            [PromptPart::Text(text)] => assert_eq!(text, "owned text"),
            other => panic!("Expected a single Text part, got {:?}", other),
        }
    }

    #[test]
    fn test_to_prompt_for_bool_char_and_float() {
        assert_eq!(true.to_prompt(), "true");
        assert_eq!(false.to_prompt(), "false");
        assert_eq!('x'.to_prompt(), "x");
        assert_eq!(1.5f64.to_prompt(), "1.5");
        assert_eq!((-7i64).to_prompt(), "-7");
    }

    // A type that only overrides `to_prompt_parts` and mixes text with an
    // image, so the default text renderers must filter.
    struct Captioned;

    impl ToPrompt for Captioned {
        fn to_prompt_parts(&self) -> Vec<PromptPart> {
            vec![
                PromptPart::Text("caption: ".to_string()),
                PromptPart::Image {
                    media_type: "image/png".to_string(),
                    data: vec![0x89, 0x50, 0x4e, 0x47],
                },
                PromptPart::Text("a cat".to_string()),
            ]
        }
    }

    #[test]
    fn test_default_to_prompt_keeps_only_text_parts() {
        let value = Captioned;
        assert_eq!(value.to_prompt(), "caption: a cat");
        assert_eq!(value.to_prompt_with_mode("summary"), "caption: a cat");
        // The default `_with_mode` variant ignores the mode and keeps all parts.
        assert_eq!(value.to_prompt_parts_with_mode("summary").len(), 3);
    }

    #[test]
    fn test_default_prompt_schema_is_empty() {
        assert_eq!(Captioned::prompt_schema(), "");
    }
}
533
534#[derive(Debug, thiserror::Error)]
535pub enum PromptSetError {
536    #[error("Target '{target}' not found. Available targets: {available:?}")]
537    TargetNotFound {
538        target: String,
539        available: Vec<String>,
540    },
541    #[error("Failed to render prompt for target '{target}': {source}")]
542    RenderFailed {
543        target: String,
544        source: minijinja::Error,
545    },
546}
547
548pub trait ToPromptSet {
590    fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<PromptPart>, PromptSetError>;
592
593    fn to_prompt_for(&self, target: &str) -> Result<String, PromptSetError> {
598        let parts = self.to_prompt_parts_for(target)?;
599        let text = parts
600            .iter()
601            .filter_map(|part| match part {
602                PromptPart::Text(text) => Some(text.as_str()),
603                _ => None,
604            })
605            .collect::<Vec<_>>()
606            .join("\n");
607        Ok(text)
608    }
609}
610
/// Types that can render a prompt tailored to a particular target value of
/// type `T`.
pub trait ToPromptFor<T> {
    /// Renders the prompt for `target` in the given `mode`.
    fn to_prompt_for_with_mode(&self, target: &T, mode: &str) -> String;

    /// Renders the prompt for `target` using the `"full"` mode.
    fn to_prompt_for(&self, target: &T) -> String {
        let default_mode = "full";
        self.to_prompt_for_with_mode(target, default_mode)
    }
}