llm_toolkit/
prompt.rs

//! Traits, a template-rendering helper, and the `prompt!` macro for powerful, type-safe prompt generation.
2
3use minijinja::Environment;
4use serde::Serialize;
5
/// Represents a part of a multimodal prompt.
///
/// This enum allows prompts to contain different types of content,
/// such as text and images, enabling multimodal LLM interactions.
///
/// Parts compare structurally (`PartialEq`/`Eq`), so callers and tests can
/// use `==` / `assert_eq!` on them instead of pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptPart {
    /// Text content in the prompt.
    Text(String),
    /// Image content with media type and binary data.
    Image {
        /// The MIME media type (e.g., "image/jpeg", "image/png").
        media_type: String,
        /// The raw image data.
        data: Vec<u8>,
    },
    // Future variants like Audio or Video can be added here
}
23
24/// A trait for converting any type into a string suitable for an LLM prompt.
25///
26/// This trait provides a standard interface for converting various types
27/// into strings that can be used as prompts for language models.
28///
29/// # Example
30///
31/// ```
32/// use llm_toolkit::prompt::ToPrompt;
33///
34/// // Common types have ToPrompt implementations
35/// let number = 42;
36/// assert_eq!(number.to_prompt(), "42");
37///
38/// let text = "Hello, LLM!";
39/// assert_eq!(text.to_prompt(), "Hello, LLM!");
40/// ```
41///
42/// # Custom Implementation
43///
44/// You can also implement `ToPrompt` directly for your own types:
45///
46/// ```
47/// use llm_toolkit::prompt::{ToPrompt, PromptPart};
48/// use std::fmt;
49///
50/// struct CustomType {
51///     value: String,
52/// }
53///
54/// impl fmt::Display for CustomType {
55///     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56///         write!(f, "{}", self.value)
57///     }
58/// }
59///
60/// // By implementing ToPrompt directly, you can control the conversion.
61/// impl ToPrompt for CustomType {
62///     fn to_prompt_parts(&self) -> Vec<PromptPart> {
63///         vec![PromptPart::Text(self.to_string())]
64///     }
65///
66///     fn to_prompt(&self) -> String {
67///         self.to_string()
68///     }
69/// }
70///
71/// let custom = CustomType { value: "custom".to_string() };
72/// assert_eq!(custom.to_prompt(), "custom");
73/// ```
74pub trait ToPrompt {
75    /// Converts the object into a vector of `PromptPart`s based on a mode.
76    ///
77    /// This is the core method that `derive(ToPrompt)` will implement.
78    /// The `mode` argument allows for different prompt representations, such as:
79    /// - "full": A comprehensive prompt with schema and examples.
80    /// - "schema_only": Just the data structure's schema.
81    /// - "example_only": Just a concrete example.
82    ///
83    /// The default implementation ignores the mode and calls `to_prompt_parts`
84    /// for backward compatibility with manual implementations.
85    fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<PromptPart> {
86        // Default implementation for backward compatibility
87        let _ = mode; // Unused in default impl
88        self.to_prompt_parts()
89    }
90
91    /// Converts the object into a prompt string based on a mode.
92    ///
93    /// This method extracts only the text portions from `to_prompt_parts_with_mode()`.
94    fn to_prompt_with_mode(&self, mode: &str) -> String {
95        self.to_prompt_parts_with_mode(mode)
96            .iter()
97            .filter_map(|part| match part {
98                PromptPart::Text(text) => Some(text.as_str()),
99                _ => None,
100            })
101            .collect::<Vec<_>>()
102            .join("")
103    }
104
105    /// Converts the object into a vector of `PromptPart`s using the default "full" mode.
106    ///
107    /// This method enables multimodal prompt generation by returning
108    /// a collection of prompt parts that can include text, images, and
109    /// other media types.
110    fn to_prompt_parts(&self) -> Vec<PromptPart> {
111        self.to_prompt_parts_with_mode("full")
112    }
113
114    /// Converts the object into a prompt string using the default "full" mode.
115    ///
116    /// This method provides backward compatibility by extracting only
117    /// the text portions from `to_prompt_parts()` and joining them.
118    fn to_prompt(&self) -> String {
119        self.to_prompt_with_mode("full")
120    }
121
122    /// Returns a schema-level prompt for the type itself.
123    ///
124    /// For enums, this returns all possible variants with their descriptions.
125    /// For structs, this returns the field schema.
126    ///
127    /// Unlike instance methods like `to_prompt()`, this is a type-level method
128    /// that doesn't require an instance.
129    ///
130    /// # Examples
131    ///
132    /// ```ignore
133    /// // Enum: get all variants
134    /// let schema = MyEnum::prompt_schema();
135    ///
136    /// // Struct: get field schema
137    /// let schema = MyStruct::prompt_schema();
138    /// ```
139    fn prompt_schema() -> String {
140        String::new() // Default implementation returns empty string
141    }
142}
143
144// Add implementations for common types
145
146impl ToPrompt for String {
147    fn to_prompt_parts(&self) -> Vec<PromptPart> {
148        vec![PromptPart::Text(self.clone())]
149    }
150
151    fn to_prompt(&self) -> String {
152        self.clone()
153    }
154}
155
156impl ToPrompt for &str {
157    fn to_prompt_parts(&self) -> Vec<PromptPart> {
158        vec![PromptPart::Text(self.to_string())]
159    }
160
161    fn to_prompt(&self) -> String {
162        self.to_string()
163    }
164}
165
166impl ToPrompt for bool {
167    fn to_prompt_parts(&self) -> Vec<PromptPart> {
168        vec![PromptPart::Text(self.to_string())]
169    }
170
171    fn to_prompt(&self) -> String {
172        self.to_string()
173    }
174}
175
176impl ToPrompt for char {
177    fn to_prompt_parts(&self) -> Vec<PromptPart> {
178        vec![PromptPart::Text(self.to_string())]
179    }
180
181    fn to_prompt(&self) -> String {
182        self.to_string()
183    }
184}
185
186macro_rules! impl_to_prompt_for_numbers {
187    ($($t:ty),*) => {
188        $(
189            impl ToPrompt for $t {
190                fn to_prompt_parts(&self) -> Vec<PromptPart> {
191                    vec![PromptPart::Text(self.to_string())]
192                }
193
194                fn to_prompt(&self) -> String {
195                    self.to_string()
196                }
197            }
198        )*
199    };
200}
201
202impl_to_prompt_for_numbers!(
203    i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64
204);
205
206// Implement ToPrompt for Vec<T> where T: ToPrompt
207impl<T: ToPrompt> ToPrompt for Vec<T> {
208    fn to_prompt_parts(&self) -> Vec<PromptPart> {
209        vec![PromptPart::Text(self.to_prompt())]
210    }
211
212    fn to_prompt(&self) -> String {
213        format!(
214            "[{}]",
215            self.iter()
216                .map(|item| item.to_prompt())
217                .collect::<Vec<_>>()
218                .join(", ")
219        )
220    }
221}
222
223// Implement ToPrompt for Option<T> where T: ToPrompt
224impl<T: ToPrompt> ToPrompt for Option<T> {
225    fn to_prompt_parts(&self) -> Vec<PromptPart> {
226        vec![PromptPart::Text(self.to_prompt())]
227    }
228
229    fn to_prompt(&self) -> String {
230        match self {
231            Some(value) => value.to_prompt(),
232            None => String::new(),
233        }
234    }
235}
236
237/// Renders a prompt from a template string and a serializable context.
238///
239/// This is the underlying function for the `prompt!` macro.
240pub fn render_prompt<T: Serialize>(template: &str, context: T) -> Result<String, minijinja::Error> {
241    let mut env = Environment::new();
242    env.add_template("prompt", template)?;
243    let tmpl = env.get_template("prompt")?;
244    tmpl.render(context)
245}
246
247/// Creates a prompt string from a template and key-value pairs.
248///
249/// This macro provides a `println!`-like experience for building prompts
250/// from various data sources. It leverages `minijinja` for templating.
251///
252/// # Example
253///
254/// ```
255/// use llm_toolkit::prompt;
256/// use serde::Serialize;
257///
258/// #[derive(Serialize)]
259/// struct User {
260///     name: &'static str,
261///     role: &'static str,
262/// }
263///
264/// let user = User { name: "Mai", role: "UX Engineer" };
265/// let task = "designing a new macro";
266///
267/// let p = prompt!(
268///     "User {{user.name}} ({{user.role}}) is currently {{task}}.",
269///     user = user,
270///     task = task
271/// ).unwrap();
272///
273/// assert_eq!(p, "User Mai (UX Engineer) is currently designing a new macro.");
274/// ```
275#[macro_export]
276macro_rules! prompt {
277    ($template:expr, $($key:ident = $value:expr),* $(,)?) => {
278        $crate::prompt::render_prompt($template, minijinja::context!($($key => $value),*))
279    };
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;
    use std::fmt::Display;

    // ----------------------------------------------------------------------
    // Fixtures
    // ----------------------------------------------------------------------

    /// Manual `ToPrompt` impl that overrides `to_prompt_parts` and `to_prompt`
    /// but not the `_with_mode` methods.
    enum TestEnum {
        VariantA,
        VariantB,
    }

    impl Display for TestEnum {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                TestEnum::VariantA => write!(f, "Variant A"),
                TestEnum::VariantB => write!(f, "Variant B"),
            }
        }
    }

    impl ToPrompt for TestEnum {
        fn to_prompt_parts(&self) -> Vec<PromptPart> {
            vec![PromptPart::Text(self.to_string())]
        }

        fn to_prompt(&self) -> String {
            self.to_string()
        }
    }

    /// Overrides only `to_prompt_parts_with_mode` and echoes the requested mode.
    struct ModeEcho;

    impl ToPrompt for ModeEcho {
        fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<PromptPart> {
            vec![PromptPart::Text(mode.to_string())]
        }
    }

    /// Mixes text and image parts to exercise the text-only string methods.
    struct TextAndImage;

    impl ToPrompt for TextAndImage {
        fn to_prompt_parts(&self) -> Vec<PromptPart> {
            vec![
                PromptPart::Text("before ".to_string()),
                PromptPart::Image {
                    media_type: "image/png".to_string(),
                    data: vec![0x89, 0x50, 0x4e, 0x47],
                },
                PromptPart::Text("after".to_string()),
            ]
        }
    }

    /// `ToPromptSet` fixture that only knows the "Visual" target.
    struct VisualCard;

    impl ToPromptSet for VisualCard {
        fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<PromptPart>, PromptSetError> {
            if target != "Visual" {
                return Err(PromptSetError::TargetNotFound {
                    target: target.to_string(),
                    available: vec!["Visual".to_string()],
                });
            }
            Ok(vec![
                PromptPart::Text("## Title".to_string()),
                PromptPart::Image {
                    media_type: "image/jpeg".to_string(),
                    data: vec![0xff, 0xd8],
                },
                PromptPart::Text("> Description".to_string()),
            ])
        }
    }

    /// `ToPromptFor` fixture that renders the target together with the mode it received.
    struct Tool;

    impl ToPromptFor<TestEnum> for Tool {
        fn to_prompt_for_with_mode(&self, target: &TestEnum, mode: &str) -> String {
            format!("{} ({})", target, mode)
        }
    }

    // ----------------------------------------------------------------------
    // ToPrompt for built-in and manual types
    // ----------------------------------------------------------------------

    #[test]
    fn test_to_prompt_for_enum() {
        let variant = TestEnum::VariantA;
        assert_eq!(variant.to_prompt(), "Variant A");
    }

    #[test]
    fn test_to_prompt_for_enum_variant_b() {
        let variant = TestEnum::VariantB;
        assert_eq!(variant.to_prompt(), "Variant B");
    }

    #[test]
    fn test_to_prompt_for_string() {
        let s = "hello world";
        assert_eq!(s.to_prompt(), "hello world");
    }

    #[test]
    fn test_to_prompt_for_number() {
        let n = 42;
        assert_eq!(n.to_prompt(), "42");
    }

    #[test]
    fn test_to_prompt_for_option_some() {
        let opt: Option<String> = Some("hello".to_string());
        assert_eq!(opt.to_prompt(), "hello");
    }

    #[test]
    fn test_to_prompt_for_option_none() {
        let opt: Option<String> = None;
        assert_eq!(opt.to_prompt(), "");
    }

    #[test]
    fn test_to_prompt_for_option_number() {
        let opt_some: Option<i32> = Some(42);
        assert_eq!(opt_some.to_prompt(), "42");

        let opt_none: Option<i32> = None;
        assert_eq!(opt_none.to_prompt(), "");
    }

    #[test]
    fn test_to_prompt_parts_for_option() {
        let opt: Option<String> = Some("test".to_string());
        let parts = opt.to_prompt_parts();
        assert_eq!(parts.len(), 1);
        match &parts[0] {
            PromptPart::Text(text) => assert_eq!(text, "test"),
            _ => panic!("Expected PromptPart::Text"),
        }
    }

    // ----------------------------------------------------------------------
    // prompt! macro
    // ----------------------------------------------------------------------

    #[derive(Serialize)]
    struct SystemInfo {
        version: &'static str,
        os: &'static str,
    }

    #[test]
    fn test_prompt_macro_simple() {
        let user = "Yui";
        let task = "implementation";
        let prompt = prompt!(
            "User {{user}} is working on the {{task}}.",
            user = user,
            task = task
        )
        .unwrap();
        assert_eq!(prompt, "User Yui is working on the implementation.");
    }

    #[test]
    fn test_prompt_macro_with_struct() {
        let sys = SystemInfo {
            version: "0.1.0",
            os: "Rust",
        };
        let prompt = prompt!("System: {{sys.version}} on {{sys.os}}", sys = sys).unwrap();
        assert_eq!(prompt, "System: 0.1.0 on Rust");
    }

    #[test]
    fn test_prompt_macro_mixed() {
        let user = "Mai";
        let sys = SystemInfo {
            version: "0.1.0",
            os: "Rust",
        };
        let prompt = prompt!(
            "User {{user}} is using {{sys.os}} v{{sys.version}}.",
            user = user,
            sys = sys
        )
        .unwrap();
        assert_eq!(prompt, "User Mai is using Rust v0.1.0.");
    }

    // ----------------------------------------------------------------------
    // ToPrompt for Vec / Option combinations
    // ----------------------------------------------------------------------

    #[test]
    fn test_to_prompt_for_vec_of_strings() {
        let items = vec!["apple", "banana", "cherry"];
        assert_eq!(items.to_prompt(), "[apple, banana, cherry]");
    }

    #[test]
    fn test_to_prompt_for_vec_of_numbers() {
        let numbers = vec![1, 2, 3, 42];
        assert_eq!(numbers.to_prompt(), "[1, 2, 3, 42]");
    }

    #[test]
    fn test_to_prompt_for_empty_vec() {
        let empty: Vec<String> = vec![];
        assert_eq!(empty.to_prompt(), "[]");
    }

    #[test]
    fn test_to_prompt_for_nested_vec() {
        let nested = vec![vec![1, 2], vec![3, 4]];
        assert_eq!(nested.to_prompt(), "[[1, 2], [3, 4]]");
    }

    #[test]
    fn test_to_prompt_parts_for_vec() {
        let items = vec!["a", "b", "c"];
        let parts = items.to_prompt_parts();
        assert_eq!(parts.len(), 1);
        match &parts[0] {
            PromptPart::Text(text) => assert_eq!(text, "[a, b, c]"),
            _ => panic!("Expected Text variant"),
        }
    }

    #[test]
    fn test_to_prompt_for_option_vec() {
        // Option<Vec<T>>
        let opt_vec_some: Option<Vec<String>> = Some(vec!["a".to_string(), "b".to_string()]);
        assert_eq!(opt_vec_some.to_prompt(), "[a, b]");

        let opt_vec_none: Option<Vec<String>> = None;
        assert_eq!(opt_vec_none.to_prompt(), "");
    }

    #[test]
    fn test_to_prompt_for_vec_option() {
        // Vec<Option<T>>
        let vec_opts = vec![Some("hello".to_string()), None, Some("world".to_string())];
        // Each Option is converted: Some("hello") -> "hello", None -> ""
        assert_eq!(vec_opts.to_prompt(), "[hello, , world]");
    }

    #[test]
    fn test_to_prompt_for_option_none_with_parts() {
        let opt: Option<String> = None;
        let parts = opt.to_prompt_parts();
        assert_eq!(parts.len(), 1);
        match &parts[0] {
            PromptPart::Text(text) => assert_eq!(text, ""),
            _ => panic!("Expected PromptPart::Text"),
        }
    }

    #[test]
    fn test_prompt_macro_no_args() {
        let prompt = prompt!("This is a static prompt.",).unwrap();
        assert_eq!(prompt, "This is a static prompt.");
    }

    // ----------------------------------------------------------------------
    // render_prompt with JSON contexts
    // ----------------------------------------------------------------------

    #[test]
    fn test_render_prompt_with_json_value_dot_notation() {
        use serde_json::json;

        let context = json!({
            "user": {
                "name": "Alice",
                "age": 30,
                "profile": {
                    "role": "Developer"
                }
            }
        });

        let template =
            "{{ user.name }} is {{ user.age }} years old and works as {{ user.profile.role }}";
        let result = render_prompt(template, &context).unwrap();

        assert_eq!(result, "Alice is 30 years old and works as Developer");
    }

    #[test]
    fn test_render_prompt_with_hashmap_json_value() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "step_1_output".to_string(),
            json!({
                "result": "success",
                "data": {
                    "count": 42
                }
            }),
        );
        context.insert("task".to_string(), json!("analysis"));

        let template = "Task: {{ task }}, Result: {{ step_1_output.result }}, Count: {{ step_1_output.data.count }}";
        let result = render_prompt(template, &context).unwrap();

        assert_eq!(result, "Task: analysis, Result: success, Count: 42");
    }

    #[test]
    fn test_render_prompt_with_array_in_json_template() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "user_request".to_string(),
            json!({
                "narrative_keywords": ["betrayal", "redemption", "sacrifice"]
            }),
        );

        // Test: Embedding array directly in JSON template (common pattern in strategy generation)
        let template = r#"{"keywords": {{ user_request.narrative_keywords }}}"#;
        let result = render_prompt(template, &context).unwrap();

        // Verify the result is valid JSON
        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["keywords"][0], "betrayal");
        assert_eq!(parsed["keywords"][1], "redemption");
        assert_eq!(parsed["keywords"][2], "sacrifice");
    }

    #[test]
    fn test_render_prompt_with_object_in_json_template() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "user_request".to_string(),
            json!({
                "config": {
                    "theme": "dark_fantasy",
                    "complexity": 5
                }
            }),
        );

        // Test: Embedding object directly in JSON template
        let template = r#"{"settings": {{ user_request.config }}}"#;
        let result = render_prompt(template, &context).unwrap();

        // Verify the result is valid JSON
        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["settings"]["theme"], "dark_fantasy");
        assert_eq!(parsed["settings"]["complexity"], 5);
    }

    #[test]
    fn test_render_prompt_mixed_json_template() {
        use serde_json::json;
        use std::collections::HashMap;

        let mut context = HashMap::new();
        context.insert(
            "world_concept".to_string(),
            json!({
                "concept": "A world where identity is volatile"
            }),
        );
        context.insert(
            "user_request".to_string(),
            json!({
                "narrative_keywords": ["betrayal", "redemption"],
                "theme": "dark fantasy"
            }),
        );

        // Test: Complex case with both array and quoted string (like the actual error case)
        let template = r#"{"concept": "{{ world_concept.concept }}", "keywords": {{ user_request.narrative_keywords }}, "theme": "{{ user_request.theme }}"}"#;
        let result = render_prompt(template, &context).unwrap();

        // Verify the result is valid JSON
        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["concept"], "A world where identity is volatile");
        assert_eq!(parsed["keywords"][0], "betrayal");
        assert_eq!(parsed["theme"], "dark fantasy");
    }

    // ----------------------------------------------------------------------
    // Mode handling and default-method wiring
    // ----------------------------------------------------------------------

    #[test]
    fn test_to_prompt_with_mode_falls_back_to_to_prompt_parts() {
        // TestEnum only overrides `to_prompt_parts`, so every mode yields the same text.
        let variant = TestEnum::VariantB;
        assert_eq!(variant.to_prompt_with_mode("schema_only"), "Variant B");
        assert_eq!(variant.to_prompt_parts_with_mode("example_only").len(), 1);
    }

    #[test]
    fn test_default_mode_is_full() {
        assert_eq!(ModeEcho.to_prompt(), "full");
        assert_eq!(ModeEcho.to_prompt_with_mode("schema_only"), "schema_only");
        match ModeEcho.to_prompt_parts().as_slice() {
            [PromptPart::Text(text)] => assert_eq!(text, "full"),
            _ => panic!("Expected a single PromptPart::Text"),
        }
    }

    #[test]
    fn test_to_prompt_drops_non_text_parts() {
        // Text parts are concatenated with no separator; the image is skipped.
        assert_eq!(TextAndImage.to_prompt(), "before after");
        let parts = TextAndImage.to_prompt_parts();
        assert_eq!(parts.len(), 3);
        assert!(matches!(parts[1], PromptPart::Image { .. }));
    }

    #[test]
    fn test_to_prompt_set_joins_text_parts_with_newlines() {
        let text = ToPromptSet::to_prompt_for(&VisualCard, "Visual").unwrap();
        assert_eq!(text, "## Title\n> Description");
    }

    #[test]
    fn test_to_prompt_set_unknown_target() {
        let err = ToPromptSet::to_prompt_for(&VisualCard, "Agent").unwrap_err();
        assert_eq!(
            err.to_string(),
            r#"Target 'Agent' not found. Available targets: ["Visual"]"#
        );
        match err {
            PromptSetError::TargetNotFound { target, available } => {
                assert_eq!(target, "Agent");
                assert_eq!(available, vec!["Visual".to_string()]);
            }
            other => panic!("Expected TargetNotFound, got {:?}", other),
        }
    }

    #[test]
    fn test_to_prompt_for_defaults_to_full_mode() {
        let target = TestEnum::VariantA;
        assert_eq!(ToPromptFor::to_prompt_for(&Tool, &target), "Variant A (full)");
        assert_eq!(
            Tool.to_prompt_for_with_mode(&target, "schema_only"),
            "Variant A (schema_only)"
        );
    }
}
608
609#[derive(Debug, thiserror::Error)]
610pub enum PromptSetError {
611    #[error("Target '{target}' not found. Available targets: {available:?}")]
612    TargetNotFound {
613        target: String,
614        available: Vec<String>,
615    },
616    #[error("Failed to render prompt for target '{target}': {source}")]
617    RenderFailed {
618        target: String,
619        source: minijinja::Error,
620    },
621}
622
623/// A trait for types that can generate multiple named prompt targets.
624///
625/// This trait enables a single data structure to produce different prompt formats
626/// for various use cases (e.g., human-readable vs. machine-parsable formats).
627///
628/// # Example
629///
630/// ```ignore
631/// use llm_toolkit::prompt::{ToPromptSet, PromptPart};
632/// use serde::Serialize;
633///
634/// #[derive(ToPromptSet, Serialize)]
635/// #[prompt_for(name = "Visual", template = "## {{title}}\n\n> {{description}}")]
636/// struct Task {
637///     title: String,
638///     description: String,
639///
640///     #[prompt_for(name = "Agent")]
641///     priority: u8,
642///
643///     #[prompt_for(name = "Agent", rename = "internal_id")]
644///     id: u64,
645///
646///     #[prompt_for(skip)]
647///     is_dirty: bool,
648/// }
649///
650/// let task = Task {
651///     title: "Implement feature".to_string(),
652///     description: "Add new functionality".to_string(),
653///     priority: 1,
654///     id: 42,
655///     is_dirty: false,
656/// };
657///
658/// // Generate visual prompt
659/// let visual_prompt = task.to_prompt_for("Visual")?;
660///
661/// // Generate agent prompt
662/// let agent_prompt = task.to_prompt_for("Agent")?;
663/// ```
664pub trait ToPromptSet {
665    /// Generates multimodal prompt parts for the specified target.
666    fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<PromptPart>, PromptSetError>;
667
668    /// Generates a text prompt for the specified target.
669    ///
670    /// This method extracts only the text portions from `to_prompt_parts_for()`
671    /// and joins them together.
672    fn to_prompt_for(&self, target: &str) -> Result<String, PromptSetError> {
673        let parts = self.to_prompt_parts_for(target)?;
674        let text = parts
675            .iter()
676            .filter_map(|part| match part {
677                PromptPart::Text(text) => Some(text.as_str()),
678                _ => None,
679            })
680            .collect::<Vec<_>>()
681            .join("\n");
682        Ok(text)
683    }
684}
685
/// A trait for generating a prompt for a specific target type.
///
/// Lets a type (e.g., a `Tool`) decide how it is described in a prompt when
/// it is presented to a particular target (e.g., an `Agent`).
pub trait ToPromptFor<T> {
    /// Renders `self` as prompt text tailored to `target`, using the
    /// representation named by `mode` (e.g. `"full"`).
    fn to_prompt_for_with_mode(&self, target: &T, mode: &str) -> String;

    /// Renders `self` for `target` in the default `"full"` mode.
    ///
    /// Shorthand for `to_prompt_for_with_mode(target, "full")`, kept so that
    /// callers that do not care about modes need not pass one.
    fn to_prompt_for(&self, target: &T) -> String {
        const DEFAULT_MODE: &str = "full";
        Self::to_prompt_for_with_mode(self, target, DEFAULT_MODE)
    }
}