1use include_dir::{include_dir, Dir};
2use minijinja::{Environment, Error as MiniJinjaError, Value as MJValue};
3use once_cell::sync::Lazy;
4use serde::Serialize;
5use std::path::PathBuf;
6use std::sync::{Arc, RwLock};
7
8static CORE_PROMPTS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/src/prompts");
11
12static GLOBAL_ENV: Lazy<Arc<RwLock<Environment<'static>>>> = Lazy::new(|| {
18 let mut env = Environment::new();
19 env.set_trim_blocks(true);
20 env.set_lstrip_blocks(true);
21
22 for file in CORE_PROMPTS_DIR.files() {
24 let name = file.path().to_string_lossy().to_string();
25 let source = String::from_utf8_lossy(file.contents()).to_string();
26
27 let static_name: &'static str = Box::leak(name.into_boxed_str());
31 let static_source: &'static str = Box::leak(source.into_boxed_str());
32
33 if let Err(e) = env.add_template(static_name, static_source) {
34 tracing::error!("Failed to add template {}: {}", static_name, e);
35 }
36 }
37
38 Arc::new(RwLock::new(env))
39});
40
41pub fn render_global_template<T: Serialize>(
47 template_name: &str,
48 context_data: &T,
49) -> Result<String, MiniJinjaError> {
50 let env = GLOBAL_ENV.read().expect("GLOBAL_ENV lock poisoned");
51 let tmpl = env.get_template(template_name)?;
52 let ctx = MJValue::from_serialize(context_data);
53 let rendered = tmpl.render(ctx)?;
54 Ok(rendered.trim().to_string())
55}
56
57pub fn render_global_file<T: Serialize>(
66 template_file: impl Into<PathBuf>,
67 context_data: &T,
68) -> Result<String, MiniJinjaError> {
69 let file_path = template_file.into();
70 let template_name = file_path.to_string_lossy().to_string();
71
72 render_global_template(&template_name, context_data)
73}
74
75pub fn render_global_from_file<T: Serialize>(
77 template_file: impl Into<PathBuf>,
78 context_data: &T,
79) -> Result<String, MiniJinjaError> {
80 render_global_file(template_file, context_data)
81}
82
83pub fn render_inline_once<T: Serialize>(
92 template_str: &str,
93 context_data: &T,
94) -> Result<String, MiniJinjaError> {
95 let mut env = Environment::new();
96 env.add_template("inline_ephemeral", template_str)?;
97 let tmpl = env.get_template("inline_ephemeral")?;
98 let ctx = MJValue::from_serialize(context_data);
99 let rendered = tmpl.render(ctx)?;
100 Ok(rendered.trim().to_string())
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    #[derive(Serialize)]
    struct TestContext {
        name: String,
        age: u32,
    }

    /// Builds a loosely-typed context map so tests can omit individual keys.
    fn build_context(name: Option<&str>, age: Option<u32>) -> HashMap<String, serde_json::Value> {
        let mut ctx = HashMap::new();
        if let Some(n) = name {
            ctx.insert("name".to_string(), json!(n));
        }
        if let Some(a) = age {
            ctx.insert("age".to_string(), json!(a));
        }
        ctx
    }

    #[test]
    fn test_render_inline_once_basic() {
        let template_str = "Hello, {{ name }}! You are {{ age }} years old.";
        let context = TestContext {
            name: "Alice".to_string(),
            age: 30,
        };

        let result = render_inline_once(template_str, &context).unwrap();
        assert_eq!(result, "Hello, Alice! You are 30 years old.");
    }

    #[test]
    fn test_render_inline_missing_variable() {
        let template_str = "Hello, {{ name }}! You are {{ age }} years old.";
        let context = build_context(Some("Alice"), None);
        let result = render_inline_once(template_str, &context).unwrap();
        // With MiniJinja's default (lenient) undefined behavior, a missing
        // variable renders as an empty string. The spaces on both sides of
        // `{{ age }}` therefore survive, leaving a double space. The previous
        // single-space `contains` check could never match that output.
        assert_eq!(result, "Hello, Alice! You are  years old.");
    }

    #[test]
    fn test_global_file_render() {
        let context = TestContext {
            name: "Alice".to_string(),
            age: 30,
        };

        let result = render_global_file("mock.md", &context).unwrap();
        assert_eq!(
            result,
            "This prompt is only used for testing.\n\nHello, Alice! You are 30 years old."
        );
    }

    #[test]
    fn test_global_file_not_found() {
        let context = TestContext {
            name: "Unused".to_string(),
            age: 99,
        };

        let result = render_global_file("non_existent.md", &context);
        assert!(result.is_err(), "Should fail because file is missing");
    }

    #[test]
    fn test_inline_complex_object() {
        #[derive(Serialize)]
        struct Tool {
            name: String,
            description: String,
        }

        #[derive(Serialize)]
        struct ToolsContext {
            tools: Vec<Tool>,
        }

        let template_str = "\
### Tool Descriptions
{% for tool in tools %}
- {{ tool.name }}: {{ tool.description }}
{% endfor %}";

        let context = ToolsContext {
            tools: vec![
                Tool {
                    name: "calculator".to_string(),
                    description: "Performs basic math operations".to_string(),
                },
                Tool {
                    name: "weather".to_string(),
                    description: "Gets weather information".to_string(),
                },
            ],
        };

        let rendered = render_inline_once(template_str, &context).unwrap();
        let expected = "\
### Tool Descriptions

- calculator: Performs basic math operations

- weather: Gets weather information";
        assert_eq!(rendered, expected);
    }

    #[test]
    fn test_inline_with_empty_list() {
        let template_str = "\
### Tool Descriptions
{% for tool in tools %}
- {{ tool.name }}: {{ tool.description }}
{% endfor %}";

        #[derive(Serialize)]
        struct ToolsContext {
            tools: Vec<String>,
        }

        let context = ToolsContext { tools: vec![] };
        let rendered = render_inline_once(template_str, &context).unwrap();
        let expected = "### Tool Descriptions";
        assert_eq!(rendered, expected);
    }
}