use minijinja::Environment;
use serde::Serialize;

// Re-exported (hidden from docs) so macros exported from this crate can name
// `minijinja` through `$crate`, even in crates that do not depend on it directly.
#[doc(hidden)]
pub use minijinja as __minijinja;
5
/// Converts a value into the text that is inserted into a prompt.
///
/// This module implements the trait for `String`, `&str`, `bool`, `char` and the
/// primitive integer and floating-point types, each yielding the value's ordinary
/// textual form. Other types opt in by implementing [`ToPrompt::to_prompt`],
/// commonly by delegating to their `Display` output.
///
/// The trait is object safe, so mixed values can be handled as `&dyn ToPrompt`.
pub trait ToPrompt {
    /// Returns the prompt text for `self` as an owned `String`.
    fn to_prompt(&self) -> String;
}
56
57impl ToPrompt for String {
60 fn to_prompt(&self) -> String {
61 self.clone()
62 }
63}
64
65impl ToPrompt for &str {
66 fn to_prompt(&self) -> String {
67 self.to_string()
68 }
69}
70
71impl ToPrompt for bool {
72 fn to_prompt(&self) -> String {
73 self.to_string()
74 }
75}
76
77impl ToPrompt for char {
78 fn to_prompt(&self) -> String {
79 self.to_string()
80 }
81}
82
83macro_rules! impl_to_prompt_for_numbers {
84 ($($t:ty),*) => {
85 $(
86 impl ToPrompt for $t {
87 fn to_prompt(&self) -> String {
88 self.to_string()
89 }
90 }
91 )*
92 };
93}
94
95impl_to_prompt_for_numbers!(
96 i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64
97);
98
99pub fn render_prompt<T: Serialize>(template: &str, context: T) -> Result<String, minijinja::Error> {
103 let mut env = Environment::new();
104 env.add_template("prompt", template)?;
105 let tmpl = env.get_template("prompt")?;
106 tmpl.render(context)
107}
108
/// Renders a minijinja template string with `name = value` variable bindings.
///
/// Expands to a call to `render_prompt`, so it evaluates to
/// `Result<String, minijinja::Error>`. Each `key = value` pair becomes a template
/// variable named `key`. A value may be any `serde::Serialize` type, including
/// structs whose fields are reached with dotted access (`{{ sys.version }}`).
///
/// The binding list is optional and may end with a trailing comma.
///
/// # Examples
///
/// ```ignore
/// let text = prompt!("User {{user}} is working on the {{task}}.", user = "Yui", task = "docs")?;
/// let fixed = prompt!("This is a static prompt.")?;
/// ```
#[macro_export]
macro_rules! prompt {
    ($template:expr $(, $key:ident = $value:expr)* $(,)?) => {
        // Every path goes through `$crate` so expansion also works in crates that
        // do not depend on `minijinja` themselves.
        $crate::prompt::render_prompt(
            $template,
            $crate::prompt::__minijinja::context!($($key => $value),*),
        )
    };
}
143
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;
    use std::fmt::Display;

    enum TestEnum {
        VariantA,
        VariantB,
    }

    impl Display for TestEnum {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                TestEnum::VariantA => write!(f, "Variant A"),
                TestEnum::VariantB => write!(f, "Variant B"),
            }
        }
    }

    // User-defined types opt in to `ToPrompt` by delegating to their `Display` output.
    impl ToPrompt for TestEnum {
        fn to_prompt(&self) -> String {
            self.to_string()
        }
    }

    #[test]
    fn test_to_prompt_for_enum() {
        let variant = TestEnum::VariantA;
        assert_eq!(variant.to_prompt(), "Variant A");
    }

    #[test]
    fn test_to_prompt_for_enum_variant_b() {
        let variant = TestEnum::VariantB;
        assert_eq!(variant.to_prompt(), "Variant B");
    }

    #[test]
    fn test_to_prompt_for_string() {
        let s = "hello world";
        assert_eq!(s.to_prompt(), "hello world");
    }

    #[test]
    fn test_to_prompt_for_owned_string() {
        let s = String::from("owned text");
        assert_eq!(s.to_prompt(), "owned text");
    }

    #[test]
    fn test_to_prompt_for_bool_and_char() {
        assert_eq!(true.to_prompt(), "true");
        assert_eq!(false.to_prompt(), "false");
        assert_eq!('x'.to_prompt(), "x");
    }

    #[test]
    fn test_to_prompt_for_number() {
        let n = 42;
        assert_eq!(n.to_prompt(), "42");
    }

    #[test]
    fn test_to_prompt_for_float() {
        assert_eq!(1.5f64.to_prompt(), "1.5");
        assert_eq!((-0.25f32).to_prompt(), "-0.25");
    }

    #[derive(Serialize)]
    struct SystemInfo {
        version: &'static str,
        os: &'static str,
    }

    #[test]
    fn test_prompt_macro_simple() {
        let user = "Yui";
        let task = "implementation";
        let prompt = prompt!(
            "User {{user}} is working on the {{task}}.",
            user = user,
            task = task
        )
        .unwrap();
        assert_eq!(prompt, "User Yui is working on the implementation.");
    }

    #[test]
    fn test_prompt_macro_with_struct() {
        let sys = SystemInfo {
            version: "0.1.0",
            os: "Rust",
        };
        let prompt = prompt!("System: {{sys.version}} on {{sys.os}}", sys = sys).unwrap();
        assert_eq!(prompt, "System: 0.1.0 on Rust");
    }

    #[test]
    fn test_prompt_macro_mixed() {
        let user = "Mai";
        let sys = SystemInfo {
            version: "0.1.0",
            os: "Rust",
        };
        let prompt = prompt!(
            "User {{user}} is using {{sys.os}} v{{sys.version}}.",
            user = user,
            sys = sys
        )
        .unwrap();
        assert_eq!(prompt, "User Mai is using Rust v0.1.0.");
    }

    #[test]
    fn test_prompt_macro_no_args() {
        let prompt = prompt!("This is a static prompt.",).unwrap();
        assert_eq!(prompt, "This is a static prompt.");
    }

    #[test]
    fn test_render_prompt_with_struct_context() {
        // A serializable struct passed directly exposes its fields as top-level variables.
        let sys = SystemInfo {
            version: "0.2.0",
            os: "Linux",
        };
        let out = render_prompt("{{ version }} / {{ os }}", sys).unwrap();
        assert_eq!(out, "0.2.0 / Linux");
    }

    #[test]
    fn test_render_prompt_rejects_unclosed_tag() {
        // The template fails to compile, so the error surfaces before rendering.
        assert!(render_prompt("{{ unclosed", ()).is_err());
    }
}