Skip to main content

ai_agents_context/
render.rs

1use std::collections::HashMap;
2
3use minijinja::{Environment, Value as MJValue};
4use serde_json::Value;
5
6use ai_agents_core::{AgentError, Result};
7
8pub struct TemplateRenderer {
9    env: Environment<'static>,
10}
11
12impl Default for TemplateRenderer {
13    fn default() -> Self {
14        Self::new()
15    }
16}
17
18impl TemplateRenderer {
19    pub fn new() -> Self {
20        let mut env = Environment::new();
21        env.set_auto_escape_callback(|_| minijinja::AutoEscape::None);
22        Self { env }
23    }
24
25    pub fn render(&self, template: &str, context: &HashMap<String, Value>) -> Result<String> {
26        let mut ctx = HashMap::new();
27
28        // Build context map
29        let mut context_map = serde_json::Map::new();
30        for (key, value) in context {
31            context_map.insert(key.clone(), value.clone());
32        }
33        ctx.insert("context", json_to_minijinja(&Value::Object(context_map)));
34
35        if let Some(env_vars) = context.get("env") {
36            ctx.insert("env", json_to_minijinja(env_vars));
37        }
38
39        if let Some(state) = context.get("state") {
40            ctx.insert("state", json_to_minijinja(state));
41        }
42
43        let tmpl = self
44            .env
45            .template_from_str(template)
46            .map_err(|e| AgentError::TemplateError(e.to_string()))?;
47
48        tmpl.render(&ctx)
49            .map_err(|e| AgentError::TemplateError(e.to_string()))
50    }
51
52    pub fn render_path(
53        &self,
54        path_template: &str,
55        context: &HashMap<String, Value>,
56    ) -> Result<String> {
57        self.render(path_template, context)
58    }
59
60    pub fn render_with_state(
61        &self,
62        template: &str,
63        context: &HashMap<String, Value>,
64        state_name: &str,
65        previous_state: Option<&str>,
66        turn_count: u32,
67        max_turns: Option<u32>,
68    ) -> Result<String> {
69        let mut full_context = context.clone();
70
71        let mut state_ctx = serde_json::Map::new();
72        state_ctx.insert("name".into(), Value::String(state_name.to_string()));
73        state_ctx.insert(
74            "previous".into(),
75            Value::String(previous_state.unwrap_or("none").to_string()),
76        );
77        state_ctx.insert("turn_count".into(), Value::Number(turn_count.into()));
78        if let Some(max) = max_turns {
79            state_ctx.insert("max_turns".into(), Value::Number(max.into()));
80        }
81        full_context.insert("state".into(), Value::Object(state_ctx));
82
83        self.render(template, &full_context)
84    }
85}
86
87fn json_to_minijinja(value: &Value) -> MJValue {
88    match value {
89        Value::Null => MJValue::from(()),
90        Value::Bool(b) => MJValue::from(*b),
91        Value::Number(n) => {
92            if let Some(i) = n.as_i64() {
93                MJValue::from(i)
94            } else if let Some(u) = n.as_u64() {
95                MJValue::from(u)
96            } else if let Some(f) = n.as_f64() {
97                MJValue::from(f)
98            } else {
99                MJValue::from(())
100            }
101        }
102        Value::String(s) => MJValue::from(s.as_str()),
103        Value::Array(arr) => {
104            let items: Vec<MJValue> = arr.iter().map(json_to_minijinja).collect();
105            MJValue::from(items)
106        }
107        Value::Object(obj) => {
108            let map: std::collections::BTreeMap<String, MJValue> = obj
109                .iter()
110                .map(|(k, v)| (k.clone(), json_to_minijinja(v)))
111                .collect();
112            MJValue::from_iter(map)
113        }
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_simple_variable() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert("user".into(), json!({"name": "Alice", "tier": "premium"}));

        let template = "Hello, {{ context.user.name }}!";
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "Hello, Alice!");
    }

    #[test]
    fn test_nested_variable() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert(
            "user".into(),
            json!({"preferences": {"theme": "dark", "language": "ko"}}),
        );

        let template = "Theme: {{ context.user.preferences.theme }}";
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "Theme: dark");
    }

    #[test]
    fn test_conditional() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert("user".into(), json!({"tier": "premium"}));

        let template = r#"{% if context.user.tier == "premium" %}Premium user{% else %}Regular user{% endif %}"#;
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "Premium user");
    }

    #[test]
    fn test_loop() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert("items".into(), json!([{"name": "A"}, {"name": "B"}]));

        let template = "{% for item in context.items %}{{ item.name }}{% endfor %}";
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "AB");
    }

    #[test]
    fn test_state_variables() {
        let renderer = TemplateRenderer::new();
        let context = HashMap::new();

        let template = "State: {{ state.name }}, Turn: {{ state.turn_count }}";
        let result = renderer
            .render_with_state(template, &context, "support", Some("greeting"), 2, Some(5))
            .unwrap();
        assert_eq!(result, "State: support, Turn: 2");
    }

    #[test]
    fn test_korean_content() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert("user".into(), json!({"name": "김철수", "language": "ko"}));

        let template = "안녕하세요, {{ context.user.name }}님! 언어: {{ context.user.language }}";
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "안녕하세요, 김철수님! 언어: ko");
    }

    #[test]
    fn test_path_rendering() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert("user".into(), json!({"language": "ja"}));

        let path = "./rules/{{ context.user.language }}/support.txt";
        let result = renderer.render_path(path, &context).unwrap();
        assert_eq!(result, "./rules/ja/support.txt");
    }

    #[test]
    fn test_default_filter() {
        let renderer = TemplateRenderer::new();
        let context = HashMap::new();

        let template = "{{ context.missing | default('N/A') }}";
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "N/A");
    }

    /// `env` and `state` entries are reachable both top-level and under `context`.
    #[test]
    fn test_env_and_state_promoted_to_top_level() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert("env".into(), json!({"region": "eu"}));
        context.insert("state".into(), json!({"name": "intake"}));

        let template = "{{ env.region }}/{{ context.env.region }}/{{ state.name }}";
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "eu/eu/intake");
    }

    /// Without a previous state the string "none" is used; `max_turns` is omitted.
    #[test]
    fn test_state_defaults() {
        let renderer = TemplateRenderer::new();
        let context = HashMap::new();

        let template = "{{ state.previous }}|{{ state.max_turns | default('unlimited') }}";
        let result = renderer
            .render_with_state(template, &context, "greeting", None, 0, None)
            .unwrap();
        assert_eq!(result, "none|unlimited");
    }

    /// The generated `state` object replaces one supplied by the caller.
    #[test]
    fn test_state_overrides_context_state() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert("state".into(), json!({"name": "stale"}));

        let result = renderer
            .render_with_state("{{ state.name }}", &context, "fresh", None, 1, Some(3))
            .unwrap();
        assert_eq!(result, "fresh");
    }

    /// Covers the i64, u64, f64 and null conversion paths.
    #[test]
    fn test_number_and_null_conversion() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert(
            "n".into(),
            json!({"big": u64::MAX, "neg": -3, "half": 0.5, "nothing": null}),
        );

        let template = "{{ context.n.big }} {{ context.n.neg }} {{ context.n.half }} \
                        {% if context.n.nothing is none %}null{% endif %}";
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "18446744073709551615 -3 0.5 null");
    }

    /// Output is never HTML-escaped.
    #[test]
    fn test_no_auto_escape() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert("snippet".into(), json!("<b>a & b</b>"));

        let result = renderer.render("{{ context.snippet }}", &context).unwrap();
        assert_eq!(result, "<b>a & b</b>");
    }

    /// Template syntax errors surface as `Err` rather than panicking.
    #[test]
    fn test_invalid_template_is_error() {
        let renderer = TemplateRenderer::new();
        let context = HashMap::new();

        assert!(renderer.render("{% if %}", &context).is_err());
    }
}