use std::collections::HashMap;
use std::path::Path;
use minijinja::Environment;
use crate::error::{Error, Result};
/// A single named Jinja template bundled with the minijinja environment
/// that holds it.
///
/// Instances are created with [`JinjaTemplate::from_file`] or
/// [`JinjaTemplate::from_str`] and rendered with one of the `render*` methods.
#[derive(Clone)]
pub struct JinjaTemplate {
/// Environment in which the template source has been registered.
env: Environment<'static>,
/// Name under which the template is registered in `env`.
name: String,
}
impl JinjaTemplate {
pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let content = std::fs::read_to_string(path).map_err(|e| {
Error::Other(format!(
"Failed to read template file '{}': {}",
path.display(),
e
))
})?;
let name = path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("template")
.to_string();
Self::from_str(&name, &content)
}
pub fn from_str(name: &str, content: &str) -> Result<Self> {
let mut env = Environment::new();
env.add_template_owned(name.to_string(), content.to_string())
.map_err(|e| Error::Other(format!("Failed to parse template '{}': {}", name, e)))?;
Ok(Self {
env,
name: name.to_string(),
})
}
pub fn name(&self) -> &str {
&self.name
}
pub fn render(&self, vars: &HashMap<String, minijinja::Value>) -> Result<String> {
let tmpl = self
.env
.get_template(&self.name)
.map_err(|e| Error::Other(format!("Template not found: {}", e)))?;
tmpl.render(vars)
.map_err(|e| Error::Other(format!("Failed to render template: {}", e)))
}
pub fn render_strings(&self, vars: &[(&str, &str)]) -> Result<String> {
let map: HashMap<String, minijinja::Value> = vars
.iter()
.map(|(k, v)| (k.to_string(), minijinja::Value::from(*v)))
.collect();
self.render(&map)
}
pub fn render_context(&self, ctx: minijinja::Value) -> Result<String> {
let tmpl = self
.env
.get_template(&self.name)
.map_err(|e| Error::Other(format!("Template not found: {}", e)))?;
tmpl.render(ctx)
.map_err(|e| Error::Other(format!("Failed to render template: {}", e)))
}
}
/// Prompt formatter backed by a [`JinjaTemplate`].
#[derive(Clone)]
pub struct JinjaFormatter {
/// Template rendered for every call to `format`.
template: JinjaTemplate,
}
impl JinjaFormatter {
pub fn new(template: JinjaTemplate) -> Self {
Self { template }
}
}
impl crate::recursive::formatter::PromptFormatter for JinjaFormatter {
    /// Renders the wrapped template for one prompt.
    ///
    /// Three string variables are exposed to the template:
    ///
    /// * `task` – the raw `prompt`;
    /// * `feedback` – the feedback text, or an empty string when `feedback`
    ///   is `None`;
    /// * `iteration` – `iteration` in decimal form.
    ///
    /// Rendering is best-effort: if it fails, the error is discarded and the
    /// unmodified `prompt` is returned as a borrowed value.
    fn format<'a>(
        &'a self,
        prompt: &'a str,
        feedback: Option<&str>,
        iteration: u32,
    ) -> std::borrow::Cow<'a, str> {
        use std::borrow::Cow;

        let iteration_text = iteration.to_string();
        let feedback_text = feedback.unwrap_or("");
        let bindings = [
            ("task", prompt),
            ("feedback", feedback_text),
            ("iteration", iteration_text.as_str()),
        ];

        self.template
            .render_strings(&bindings)
            .map_or(Cow::Borrowed(prompt), Cow::Owned)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain `{{ var }}` substitution through `render_strings`.
    #[test]
    fn test_jinja_template_from_str() {
        let template = JinjaTemplate::from_str(
            "test",
            r#"
## Task
{{ question }}
## Answer
{{ answer }}
"#,
        )
        .unwrap();
        assert_eq!(template.name(), "test");
        let output = template
            .render_strings(&[("question", "What is 2+2?"), ("answer", "4")])
            .unwrap();
        assert!(output.contains("What is 2+2?"));
        assert!(output.contains("4"));
    }

    /// A template with unbalanced block tags must be rejected at construction.
    #[test]
    fn test_jinja_template_invalid_syntax_is_error() {
        let result = JinjaTemplate::from_str("broken", "{% if notes %}never closed");
        assert!(result.is_err());
    }

    /// `from_file` names the template after the file stem and renders its content.
    #[test]
    fn test_jinja_template_from_file_uses_stem_as_name() {
        // The process id keeps the file name unique across concurrent test runs.
        let stem = format!("jinja_from_file_{}", std::process::id());
        let path = std::env::temp_dir().join(format!("{}.j2", stem));
        std::fs::write(&path, "Hi {{ who }}").unwrap();
        let loaded = JinjaTemplate::from_file(&path);
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&path);
        let template = loaded.unwrap();
        assert_eq!(template.name(), stem);
        let output = template.render_strings(&[("who", "there")]).unwrap();
        assert_eq!(output, "Hi there");
    }

    /// A path that cannot be read is reported as an error, not a panic.
    #[test]
    fn test_jinja_template_from_missing_file_is_error() {
        let result = JinjaTemplate::from_file("this/path/does/not/exist.jinja");
        assert!(result.is_err());
    }

    /// `{% for %}` loops over a sequence value passed through `render`.
    #[test]
    fn test_jinja_template_with_loop() {
        let template = JinjaTemplate::from_str(
            "loop_test",
            r#"
## Errors
{% for error in errors %}
- {{ error }}
{% endfor %}
"#,
        )
        .unwrap();
        let mut vars = HashMap::new();
        vars.insert(
            "errors".to_string(),
            minijinja::Value::from(vec!["Error 1", "Error 2", "Error 3"]),
        );
        let output = template.render(&vars).unwrap();
        assert!(output.contains("- Error 1"));
        assert!(output.contains("- Error 2"));
        assert!(output.contains("- Error 3"));
    }

    /// `{% if %}` includes its body for a non-empty value and omits it for an
    /// empty one.
    #[test]
    fn test_jinja_template_with_conditional() {
        let template = JinjaTemplate::from_str(
            "conditional_test",
            r#"
{% if notes %}
## Notes
{{ notes }}
{% endif %}
Done.
"#,
        )
        .unwrap();
        let output = template
            .render_strings(&[("notes", "Some important notes")])
            .unwrap();
        assert!(output.contains("## Notes"));
        assert!(output.contains("Some important notes"));
        let mut vars = HashMap::new();
        vars.insert("notes".to_string(), minijinja::Value::from(""));
        let output = template.render(&vars).unwrap();
        assert!(output.contains("Done."));
        // The falsy branch must actually drop the guarded section.
        assert!(!output.contains("## Notes"));
    }

    /// The `default` filter supplies a value only when the variable is missing.
    #[test]
    fn test_jinja_template_with_default() {
        let template = JinjaTemplate::from_str(
            "default_test",
            r#"
Language: {{ language | default("yaml") }}
"#,
        )
        .unwrap();
        let vars: HashMap<String, minijinja::Value> = HashMap::new();
        let output = template.render(&vars).unwrap();
        assert!(output.contains("yaml"));
        let output = template.render_strings(&[("language", "rust")]).unwrap();
        assert!(output.contains("rust"));
    }

    /// Markdown code fences and multi-line values survive rendering.
    #[test]
    fn test_jinja_template_code_block() {
        let template = JinjaTemplate::from_str(
            "code_test",
            r#"
```{{ language }}
{{ code }}
```
"#,
        )
        .unwrap();
        let output = template
            .render_strings(&[
                ("language", "yaml"),
                (
                    "code",
                    "resources:\n bucket:\n type: gcp:storage:Bucket",
                ),
            ])
            .unwrap();
        assert!(output.contains("```yaml"));
        assert!(output.contains("gcp:storage:Bucket"));
        assert!(output.contains("```\n") || output.ends_with("```"));
    }

    /// `render_context` accepts a prebuilt minijinja context value.
    #[test]
    fn test_jinja_template_render_context() {
        let template = JinjaTemplate::from_str("ctx_test", "Hello {{ name }}!").unwrap();
        let output = template
            .render_context(minijinja::context! { name => "world" })
            .unwrap();
        assert_eq!(output, "Hello world!");
    }

    /// The formatter renders `task` even when no feedback is supplied.
    #[test]
    fn test_jinja_formatter_no_feedback() {
        use crate::recursive::formatter::PromptFormatter;
        let template = JinjaTemplate::from_str(
            "fmt_test",
            r#"## Task
{{ task }}
## Rules
- Be concise"#,
        )
        .unwrap();
        let fmt = JinjaFormatter::new(template);
        let result = fmt.format("Write hello world", None, 0);
        assert!(result.contains("Write hello world"));
        assert!(result.contains("Be concise"));
    }

    /// The formatter exposes `feedback` and `iteration` to the template.
    #[test]
    fn test_jinja_formatter_with_feedback() {
        use crate::recursive::formatter::PromptFormatter;
        let template = JinjaTemplate::from_str(
            "fmt_fb",
            r#"{{ task }}
{% if feedback %}
Feedback: {{ feedback }}
{% endif %}
Iteration: {{ iteration }}"#,
        )
        .unwrap();
        let fmt = JinjaFormatter::new(template);
        let result = fmt.format("task", Some("improve it"), 2);
        assert!(result.contains("task"));
        assert!(result.contains("Feedback: improve it"));
        assert!(result.contains("Iteration: 2"));
    }

    /// The prompt is bound to the `task` template variable.
    #[test]
    fn test_jinja_formatter_uses_task_variable() {
        use crate::recursive::formatter::PromptFormatter;
        let template = JinjaTemplate::from_str("task_test", "Task: {{ task }}").unwrap();
        let fmt = JinjaFormatter::new(template);
        let result = fmt.format("original prompt", None, 0);
        assert_eq!(result.as_ref(), "Task: original prompt");
    }

    /// When rendering fails the formatter returns the prompt unchanged.
    #[test]
    fn test_jinja_formatter_falls_back_to_prompt_on_render_error() {
        use crate::recursive::formatter::PromptFormatter;
        // Attribute access on an undefined variable parses fine but fails at
        // render time under minijinja's default undefined handling.
        let template = JinjaTemplate::from_str("fallback_test", "{{ missing.attr }}").unwrap();
        let fmt = JinjaFormatter::new(template);
        let result = fmt.format("original prompt", None, 0);
        assert_eq!(result.as_ref(), "original prompt");
    }
}