promptforge/
template.rs

1use handlebars::Handlebars;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5use crate::formatting::{Formattable, Templatable};
6use crate::placeholder::extract_variables;
7use crate::template_format::{
8    detect_template, merge_vars, validate_template, TemplateError, TemplateFormat,
9};
10
11#[derive(Debug, Serialize, Deserialize, Clone)]
12pub struct Template {
13    template: String,
14    template_format: TemplateFormat,
15    input_variables: Vec<String>,
16    #[serde(skip, default)]
17    handlebars: Option<Handlebars<'static>>,
18    #[serde(skip)]
19    partials: HashMap<String, String>,
20}
21
22impl Template {
23    pub const MUSTACHE_TEMPLATE: &'static str = "mustache_template";
24
25    pub fn new(tmpl: &str) -> Result<Self, TemplateError> {
26        Self::new_with_config(tmpl, None, None)
27    }
28
29    pub fn new_with_config(
30        tmpl: &str,
31        template_format: Option<TemplateFormat>,
32        input_variables: Option<Vec<String>>,
33    ) -> Result<Self, TemplateError> {
34        validate_template(tmpl)?;
35
36        let template_format = template_format
37            .or_else(|| detect_template(tmpl).ok())
38            .ok_or_else(|| {
39                TemplateError::UnsupportedFormat("Unable to detect template format".into())
40            })?;
41        let input_variables = input_variables.unwrap_or_else(|| {
42            extract_variables(tmpl)
43                .into_iter()
44                .map(|var| var.to_string())
45                .collect()
46        });
47
48        let handlebars = if template_format == TemplateFormat::Mustache {
49            let handle = Self::initialize_handlebars(tmpl)?;
50            Some(handle)
51        } else {
52            None
53        };
54
55        Ok(Template {
56            template: tmpl.to_string(),
57            template_format,
58            input_variables,
59            handlebars,
60            partials: HashMap::new(),
61        })
62    }
63
64    pub fn from_template(tmpl: &str) -> Result<Self, TemplateError> {
65        Self::new(tmpl)
66    }
67
68    pub fn partial(&mut self, var: &str, value: &str) -> &mut Self {
69        self.partials.insert(var.to_string(), value.to_string());
70        self
71    }
72
73    pub fn clear_partials(&mut self) -> &mut Self {
74        self.partials.clear();
75        self
76    }
77
78    pub fn partial_vars(&self) -> &HashMap<String, String> {
79        &self.partials
80    }
81
82    fn initialize_handlebars(tmpl: &str) -> Result<Handlebars<'static>, TemplateError> {
83        let mut handlebars = Handlebars::new();
84        handlebars
85            .register_template_string(Self::MUSTACHE_TEMPLATE, tmpl)
86            .map_err(|e| {
87                TemplateError::MalformedTemplate(format!("Failed to register template: {}", e))
88            })?;
89        Ok(handlebars)
90    }
91
92    fn validate_variables(
93        &self,
94        variables: &std::collections::HashMap<&str, &str>,
95    ) -> Result<(), TemplateError> {
96        for var in &self.input_variables {
97            let has_key = variables.contains_key(var.as_str());
98            if !has_key {
99                return Err(TemplateError::MissingVariable(format!(
100                    "Variable '{}' is missing. Expected: {:?}, but received: {:?}",
101                    var,
102                    self.input_variables,
103                    variables.keys().collect::<Vec<_>>()
104                )));
105            }
106        }
107        Ok(())
108    }
109
110    fn format_fmtstring(&self, variables: &HashMap<&str, &str>) -> Result<String, TemplateError> {
111        let mut result = self.template.clone();
112
113        for var in &self.input_variables {
114            let placeholder = format!("{{{}}}", var);
115
116            if let Some(value) = variables.get(var.as_str()) {
117                result = result.replace(&placeholder, value);
118            } else {
119                return Err(TemplateError::MissingVariable(var.clone()));
120            }
121        }
122
123        Ok(result)
124    }
125
126    fn format_mustache(&self, variables: &HashMap<&str, &str>) -> Result<String, TemplateError> {
127        match &self.handlebars {
128            None => Err(TemplateError::UnsupportedFormat(
129                "Handlebars not initialized".to_string(),
130            )),
131            Some(handlebars) => handlebars
132                .render(Self::MUSTACHE_TEMPLATE, variables)
133                .map_err(TemplateError::RuntimeError),
134        }
135    }
136}
137
138impl Formattable for Template {
139    fn format(&self, variables: &HashMap<&str, &str>) -> Result<String, TemplateError> {
140        let merged_variables = merge_vars(&self.partials, variables);
141        self.validate_variables(&merged_variables)?;
142
143        match self.template_format {
144            TemplateFormat::FmtString => self.format_fmtstring(&merged_variables),
145            TemplateFormat::Mustache => self.format_mustache(&merged_variables),
146            TemplateFormat::PlainText => Ok(self.template.clone()),
147        }
148    }
149}
150
151impl Templatable for Template {
152    fn template(&self) -> &str {
153        &self.template
154    }
155
156    fn template_format(&self) -> TemplateFormat {
157        self.template_format.clone()
158    }
159
160    fn input_variables(&self) -> Vec<String> {
161        self.input_variables.clone()
162    }
163}
164
165impl TryFrom<String> for Template {
166    type Error = TemplateError;
167
168    fn try_from(value: String) -> Result<Self, Self::Error> {
169        Template::new(&value)
170    }
171}
172
// Unit tests for `Template`: construction and format detection, rendering for each
// `TemplateFormat`, partial (pre-bound) variables, and `TryFrom<String>`.
// NOTE(review): nothing here covers substituted values that themselves contain
// `{...}` placeholder text, mustache values containing HTML-special characters,
// or a serde round-trip of a Mustache template — consider adding such cases.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vars;

    // `new` detects FmtString / Mustache / PlainText and extracts the input
    // variables in the order they appear in the template.
    #[test]
    fn test_prompt_template_new_success() {
        let valid_template = "Tell me a {adjective} joke about {content}.";
        let tmpl = Template::new(valid_template);
        assert!(tmpl.is_ok());
        let tmpl = tmpl.unwrap();
        assert_eq!(tmpl.template, valid_template);
        assert_eq!(tmpl.template_format, TemplateFormat::FmtString);
        assert_eq!(tmpl.input_variables, vec!["adjective", "content"]);

        let valid_mustache_template = "Tell me a {{adjective}} joke about {{content}}.";
        let tmpl = Template::new(valid_mustache_template);
        assert!(tmpl.is_ok());
        let tmpl = tmpl.unwrap();
        assert_eq!(tmpl.template, valid_mustache_template);
        assert_eq!(tmpl.template_format, TemplateFormat::Mustache);
        assert_eq!(tmpl.input_variables, vec!["adjective", "content"]);

        let no_placeholder_template = "Tell me a joke.";
        let tmpl = Template::new(no_placeholder_template);
        assert!(tmpl.is_ok());
        let tmpl = tmpl.unwrap();
        assert_eq!(tmpl.template, no_placeholder_template);
        assert_eq!(tmpl.template_format, TemplateFormat::PlainText);
        assert_eq!(tmpl.input_variables.len(), 0);
    }

    // Mixed `{x}`/`{{x}}` syntax and unclosed braces in either style are all
    // rejected at construction as `MalformedTemplate`.
    #[test]
    fn test_prompt_template_new_error() {
        let mixed_template = "Tell me a {adjective} joke about {{content}}.";
        let tmpl_err = Template::new(mixed_template).unwrap_err();
        assert!(matches!(tmpl_err, TemplateError::MalformedTemplate(_)));

        let malformed_fmtstring = "Tell me a {adjective joke about {content}.";
        let tmpl_err = Template::new(malformed_fmtstring).unwrap_err();
        assert!(matches!(tmpl_err, TemplateError::MalformedTemplate(_)));

        let malformed_mustache = "Tell me a {{adjective joke about {{content}}.";
        let tmpl_err = Template::new(malformed_mustache).unwrap_err();
        assert!(matches!(tmpl_err, TemplateError::MalformedTemplate(_)));
    }

    // FmtString rendering: single and multiple variables, extra variables are
    // ignored, and a missing declared variable yields `MissingVariable`.
    #[test]
    fn test_fmtstring_formatting() {
        let tmpl = Template::new("Hello, {name}!").unwrap();
        let variables = &vars!(name = "John");
        let formatted = tmpl.format(variables).unwrap();
        assert_eq!(formatted, "Hello, John!");

        let tmpl = Template::new("Hi {name}, you are {age} years old!").unwrap();
        let variables = &vars!(name = "Alice", age = "30");
        let formatted = tmpl.format(variables).unwrap();
        assert_eq!(formatted, "Hi Alice, you are 30 years old!");

        let tmpl = Template::new("Hello World!").unwrap();
        let variables = &vars!();
        let formatted = tmpl.format(variables).unwrap();
        assert_eq!(formatted, "Hello World!");

        let tmpl = Template::new("Goodbye, {name}!").unwrap();
        let variables = &vars!(name = "John", extra = "data");
        let formatted = tmpl.format(variables).unwrap();
        assert_eq!(formatted, "Goodbye, John!");

        let tmpl = Template::new("Goodbye, {name}!").unwrap();
        let variables = &vars!(wrong_name = "John");
        let result = tmpl.format(variables);
        assert!(result.is_err());

        let tmpl = Template::new("Hi {name}, you are {age} years old!").unwrap();
        let variables = &vars!(name = "Alice");
        let result = tmpl.format(variables).unwrap_err();
        assert!(matches!(result, TemplateError::MissingVariable(_)));
    }

    // Mustache rendering: extra variables are ignored and repeated placeholders
    // are all substituted.
    #[test]
    fn test_format_mustache_success() {
        let tmpl = Template::new("Hello, {{name}}!").unwrap();
        let variables = &vars!(name = "John");
        let result = tmpl.format(variables).unwrap();
        assert_eq!(result, "Hello, John!");

        let variables = &vars!(name = "John", extra = "data");
        let result = tmpl.format(variables).unwrap();
        assert_eq!(result, "Hello, John!");

        let tmpl_multiple_vars = Template::new("Hello, {{name}}! You are {{adjective}}.").unwrap();
        let variables = &vars!(name = "John", adjective = "awesome");
        let result = tmpl_multiple_vars.format(variables).unwrap();
        assert_eq!(result, "Hello, John! You are awesome.");

        let tmpl_multiple_instances =
            Template::new("{{greeting}}, {{name}}! {{greeting}}, again!").unwrap();
        let variables = &vars!(greeting = "Hello", name = "John");
        let result = tmpl_multiple_instances.format(variables).unwrap();
        assert_eq!(result, "Hello, John! Hello, again!");
    }

    // A missing mustache variable is reported as `MissingVariable` (by the
    // up-front validation in `format`) rather than rendered as empty text.
    #[test]
    fn test_format_mustache_error() {
        let tmpl_missing_var = Template::new("Hello, {{name}}!").unwrap();
        let variables = &vars!(adjective = "cool");
        let err = tmpl_missing_var.format(variables).unwrap_err();
        assert!(matches!(err, TemplateError::MissingVariable(_)));
    }

    // Plain-text templates are returned verbatim (whitespace and newlines kept)
    // regardless of which variables are passed.
    #[test]
    fn test_format_plaintext() {
        let tmpl = Template::new("Hello, world!").unwrap();
        let variables = &vars!();
        let result = tmpl.format(variables).unwrap();
        assert_eq!(result, "Hello, world!");

        let tmpl = Template::new("Welcome to the Rust world!").unwrap();
        let variables = &vars!(name = "John", adjective = "awesome");
        let result = tmpl.format(variables).unwrap();
        assert_eq!(result, "Welcome to the Rust world!");

        let tmpl_no_placeholders = Template::new("No placeholders here").unwrap();
        let variables = &vars!(name = "ignored");
        let result = tmpl_no_placeholders.format(variables).unwrap();
        assert_eq!(result, "No placeholders here");

        let tmpl_extra_spaces = Template::new("  Just some text   ").unwrap();
        let variables = &vars!();
        let result = tmpl_extra_spaces.format(variables).unwrap();
        assert_eq!(result, "  Just some text   ");

        let tmpl_with_newlines = Template::new("Text with\nmultiple lines\n").unwrap();
        let result = tmpl_with_newlines.format(&vars!()).unwrap();
        assert_eq!(result, "Text with\nmultiple lines\n");
    }

    // A partial supplies a default that a runtime variable of the same name overrides.
    #[test]
    fn test_partial_adds_variables() {
        let mut template = Template::new("Hello, {name}").unwrap();

        template.partial("name", "Jill");

        let partial_vars = template.partial_vars();
        assert_eq!(partial_vars.get("name"), Some(&"Jill".to_string()));

        let variables = &vars!();
        let formatted = template.format(variables).unwrap();
        assert_eq!(formatted, "Hello, Jill");

        let variables = &vars!(name = "Alice");
        let formatted = template.format(variables).unwrap();
        assert_eq!(formatted, "Hello, Alice");
    }

    // `partial` chains, and runtime values override individual partials.
    #[test]
    fn test_multiple_partials() {
        let mut template = Template::new("Hello, {name}. You are feeling {mood}.").unwrap();

        template.partial("name", "Jill").partial("mood", "happy");

        let partial_vars = template.partial_vars();
        assert_eq!(partial_vars.get("name"), Some(&"Jill".to_string()));
        assert_eq!(partial_vars.get("mood"), Some(&"happy".to_string()));

        let variables = &vars!();
        let formatted = template.format(variables).unwrap();
        assert_eq!(formatted, "Hello, Jill. You are feeling happy.");

        let variables = &vars!(mood = "excited");
        let formatted = template.format(variables).unwrap();
        assert_eq!(formatted, "Hello, Jill. You are feeling excited.");
    }

    // After `clear_partials`, formatting needs every variable at runtime again.
    #[test]
    fn test_clear_partials() {
        let mut template = Template::new("Hello, {name}.").unwrap();

        template.partial("name", "Jill").clear_partials();

        let partial_vars = template.partial_vars();
        assert!(partial_vars.is_empty());

        let variables = &vars!(name = "John");
        let formatted = template.format(variables).unwrap();
        assert_eq!(formatted, "Hello, John.");

        let variables = &vars!();
        let result = template.format(variables);
        assert!(result.is_err());
    }

    // Re-binding a partial overwrites the previous value; clearing empties the map.
    #[test]
    fn test_partial_vars() {
        let mut template = Template::new("Hello, {name}!").unwrap();
        template.partial("name", "Alice");

        assert_eq!(
            template.partial_vars().get("name"),
            Some(&"Alice".to_string())
        );

        template.partial("name", "Bob");
        assert_eq!(
            template.partial_vars().get("name"),
            Some(&"Bob".to_string())
        );

        template.clear_partials();
        assert!(template.partial_vars().is_empty());

        let variables = &vars!(name = "Charlie");
        let formatted = template.format(variables).unwrap();
        assert_eq!(formatted, "Hello, Charlie!");

        let variables = &vars!();
        let result = template.format(variables);
        assert!(result.is_err());
    }

    // Every combination of partial-only, runtime-only and overridden values.
    #[test]
    fn test_format_with_partials_and_runtime_vars() {
        let mut template = Template::new("Hello, {name}. You are feeling {mood}.").unwrap();

        template.partial("name", "Alice").partial("mood", "calm");

        let variables = &vars!();
        let formatted = template.format(variables).unwrap();
        assert_eq!(formatted, "Hello, Alice. You are feeling calm.");

        let variables = &vars!(mood = "excited");
        let formatted = template.format(variables).unwrap();
        assert_eq!(formatted, "Hello, Alice. You are feeling excited.");

        let variables = &vars!(name = "Bob");
        let formatted = template.format(variables).unwrap();
        assert_eq!(formatted, "Hello, Bob. You are feeling calm.");

        let variables = &vars!(name = "Charlie", mood = "joyful");
        let formatted = template.format(variables).unwrap();
        assert_eq!(formatted, "Hello, Charlie. You are feeling joyful.");
    }

    // A variable covered neither by a partial nor at runtime makes `format` fail.
    #[test]
    fn test_format_with_missing_variables_in_partials() {
        let mut template = Template::new("Hello, {name}. You are feeling {mood}.").unwrap();

        template.partial("name", "Alice");

        let variables = &vars!();
        let result = template.format(variables);
        assert!(result.is_err());

        let variables = &vars!(mood = "happy");
        let formatted = template.format(variables).unwrap();
        assert_eq!(formatted, "Hello, Alice. You are feeling happy.");
    }

    // When both are set, runtime values win over partials.
    #[test]
    fn test_format_with_conflicting_partial_and_runtime_vars() {
        let mut template = Template::new("Hello, {name}. You are feeling {mood}.").unwrap();

        template.partial("name", "Alice").partial("mood", "calm");

        let variables = &vars!(name = "Bob", mood = "excited");
        let formatted = template.format(variables).unwrap();
        assert_eq!(formatted, "Hello, Bob. You are feeling excited.");
    }

    // `TryFrom<String>` behaves like `new` for a fmt-string template.
    #[test]
    fn test_try_from_string_valid_template() {
        let valid_template = "Hello, {name}! Your order number is {order_id}.".to_string();

        let template = Template::try_from(valid_template.clone());
        assert!(template.is_ok());
        let template = template.unwrap();

        assert_eq!(template.template, valid_template);
        assert_eq!(template.template_format, TemplateFormat::FmtString);
        assert_eq!(
            template.input_variables,
            vec!["name".to_string(), "order_id".to_string()]
        );
    }

    // `TryFrom<String>` behaves like `new` for a mustache template.
    #[test]
    fn test_try_from_string_valid_mustache_template() {
        let valid_mustache_template =
            "Hello, {{name}}! Your favorite color is {{color}}.".to_string();

        let template = Template::try_from(valid_mustache_template.clone());
        assert!(template.is_ok());
        let template = template.unwrap();

        assert_eq!(template.template, valid_mustache_template);
        assert_eq!(template.template_format, TemplateFormat::Mustache);
        assert_eq!(
            template.input_variables,
            vec!["name".to_string(), "color".to_string()]
        );
    }

    // `TryFrom<String>` detects plain text and extracts no variables.
    #[test]
    fn test_try_from_string_plaintext_template() {
        let plaintext_template = "Hello, world!".to_string();

        let template = Template::try_from(plaintext_template.clone());
        assert!(template.is_ok());
        let template = template.unwrap();

        assert_eq!(template.template, plaintext_template);
        assert_eq!(template.template_format, TemplateFormat::PlainText);
        assert!(template.input_variables.is_empty());
    }

    // An unclosed brace is reported as `MalformedTemplate` via `TryFrom`.
    #[test]
    fn test_try_from_string_malformed_template() {
        let invalid_template = "Hello, {name!".to_string();

        let template = Template::try_from(invalid_template.clone());
        assert!(template.is_err());
        if let Err(TemplateError::MalformedTemplate(msg)) = template {
            println!("{}", msg);
        } else {
            panic!("Expected TemplateError::MalformedTemplate");
        }
    }

    // Mixing `{x}` and `{{x}}` is reported as `MalformedTemplate` via `TryFrom`.
    #[test]
    fn test_try_from_string_mixed_format_template() {
        let mixed_format_template = "Hello, {name} and {{color}}.".to_string();

        let template = Template::try_from(mixed_format_template.clone());
        assert!(template.is_err());
        if let Err(TemplateError::MalformedTemplate(msg)) = template {
            println!("{}", msg);
        } else {
            panic!("Expected TemplateError::MalformedTemplate");
        }
    }
}