ai_agents_context/
render.rs1use std::collections::HashMap;
2
3use minijinja::{Environment, Value as MJValue};
4use serde_json::Value;
5
6use ai_agents_core::{AgentError, Result};
7
8pub struct TemplateRenderer {
9 env: Environment<'static>,
10}
11
12impl Default for TemplateRenderer {
13 fn default() -> Self {
14 Self::new()
15 }
16}
17
18impl TemplateRenderer {
19 pub fn new() -> Self {
20 let mut env = Environment::new();
21 env.set_auto_escape_callback(|_| minijinja::AutoEscape::None);
22 Self { env }
23 }
24
25 pub fn render(&self, template: &str, context: &HashMap<String, Value>) -> Result<String> {
26 let mut ctx = HashMap::new();
27
28 let mut context_map = serde_json::Map::new();
30 for (key, value) in context {
31 context_map.insert(key.clone(), value.clone());
32 }
33 ctx.insert("context", json_to_minijinja(&Value::Object(context_map)));
34
35 if let Some(env_vars) = context.get("env") {
36 ctx.insert("env", json_to_minijinja(env_vars));
37 }
38
39 if let Some(state) = context.get("state") {
40 ctx.insert("state", json_to_minijinja(state));
41 }
42
43 let tmpl = self
44 .env
45 .template_from_str(template)
46 .map_err(|e| AgentError::TemplateError(e.to_string()))?;
47
48 tmpl.render(&ctx)
49 .map_err(|e| AgentError::TemplateError(e.to_string()))
50 }
51
52 pub fn render_path(
53 &self,
54 path_template: &str,
55 context: &HashMap<String, Value>,
56 ) -> Result<String> {
57 self.render(path_template, context)
58 }
59
60 pub fn render_with_state(
61 &self,
62 template: &str,
63 context: &HashMap<String, Value>,
64 state_name: &str,
65 previous_state: Option<&str>,
66 turn_count: u32,
67 max_turns: Option<u32>,
68 ) -> Result<String> {
69 let mut full_context = context.clone();
70
71 let mut state_ctx = serde_json::Map::new();
72 state_ctx.insert("name".into(), Value::String(state_name.to_string()));
73 state_ctx.insert(
74 "previous".into(),
75 Value::String(previous_state.unwrap_or("none").to_string()),
76 );
77 state_ctx.insert("turn_count".into(), Value::Number(turn_count.into()));
78 if let Some(max) = max_turns {
79 state_ctx.insert("max_turns".into(), Value::Number(max.into()));
80 }
81 full_context.insert("state".into(), Value::Object(state_ctx));
82
83 self.render(template, &full_context)
84 }
85}
86
87fn json_to_minijinja(value: &Value) -> MJValue {
88 match value {
89 Value::Null => MJValue::from(()),
90 Value::Bool(b) => MJValue::from(*b),
91 Value::Number(n) => {
92 if let Some(i) = n.as_i64() {
93 MJValue::from(i)
94 } else if let Some(u) = n.as_u64() {
95 MJValue::from(u)
96 } else if let Some(f) = n.as_f64() {
97 MJValue::from(f)
98 } else {
99 MJValue::from(())
100 }
101 }
102 Value::String(s) => MJValue::from(s.as_str()),
103 Value::Array(arr) => {
104 let items: Vec<MJValue> = arr.iter().map(json_to_minijinja).collect();
105 MJValue::from(items)
106 }
107 Value::Object(obj) => {
108 let map: std::collections::BTreeMap<String, MJValue> = obj
109 .iter()
110 .map(|(k, v)| (k.clone(), json_to_minijinja(v)))
111 .collect();
112 MJValue::from_iter(map)
113 }
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_simple_variable() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert("user".into(), json!({"name": "Alice", "tier": "premium"}));

        let template = "Hello, {{ context.user.name }}!";
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "Hello, Alice!");
    }

    #[test]
    fn test_nested_variable() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert(
            "user".into(),
            json!({"preferences": {"theme": "dark", "language": "ko"}}),
        );

        let template = "Theme: {{ context.user.preferences.theme }}";
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "Theme: dark");
    }

    #[test]
    fn test_conditional() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert("user".into(), json!({"tier": "premium"}));

        let template = r#"{% if context.user.tier == "premium" %}Premium user{% else %}Regular user{% endif %}"#;
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "Premium user");
    }

    #[test]
    fn test_loop() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert("items".into(), json!([{"name": "A"}, {"name": "B"}]));

        let template = "{% for item in context.items %}{{ item.name }}{% endfor %}";
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "AB");
    }

    #[test]
    fn test_state_variables() {
        let renderer = TemplateRenderer::new();
        let context = HashMap::new();

        let template = "State: {{ state.name }}, Turn: {{ state.turn_count }}";
        let result = renderer
            .render_with_state(template, &context, "support", Some("greeting"), 2, Some(5))
            .unwrap();
        assert_eq!(result, "State: support, Turn: 2");
    }

    /// `previous` falls back to the literal "none", and `max_turns` is omitted when not given.
    #[test]
    fn test_state_defaults() {
        let renderer = TemplateRenderer::new();
        let context = HashMap::new();

        let template = "{{ state.previous }}|{{ state.max_turns | default('unset') }}";
        let result = renderer
            .render_with_state(template, &context, "start", None, 0, None)
            .unwrap();
        assert_eq!(result, "none|unset");
    }

    /// The `"env"` context entry is bound both at the top level and under `context`.
    #[test]
    fn test_env_variables() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert("env".into(), json!({"MODE": "prod"}));

        let template = "{{ env.MODE }}/{{ context.env.MODE }}";
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "prod/prod");
    }

    /// Malformed templates surface as `AgentError::TemplateError` instead of panicking.
    #[test]
    fn test_invalid_template_is_error() {
        let renderer = TemplateRenderer::new();
        let context = HashMap::new();

        let err = renderer
            .render("Hello {{ context.user.name", &context)
            .unwrap_err();
        assert!(matches!(err, AgentError::TemplateError(_)));
    }

    #[test]
    fn test_korean_content() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert("user".into(), json!({"name": "김철수", "language": "ko"}));

        let template = "안녕하세요, {{ context.user.name }}님! 언어: {{ context.user.language }}";
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "안녕하세요, 김철수님! 언어: ko");
    }

    #[test]
    fn test_path_rendering() {
        let renderer = TemplateRenderer::new();
        let mut context = HashMap::new();
        context.insert("user".into(), json!({"language": "ja"}));

        let path = "./rules/{{ context.user.language }}/support.txt";
        let result = renderer.render_path(path, &context).unwrap();
        assert_eq!(result, "./rules/ja/support.txt");
    }

    #[test]
    fn test_default_filter() {
        let renderer = TemplateRenderer::new();
        let context = HashMap::new();

        let template = "{{ context.missing | default('N/A') }}";
        let result = renderer.render(template, &context).unwrap();
        assert_eq!(result, "N/A");
    }
}