Skip to main content

aster/
prompt_template.rs

1use include_dir::{include_dir, Dir};
2use minijinja::{Environment, Error as MiniJinjaError, Value as MJValue};
3use once_cell::sync::Lazy;
4use serde::Serialize;
5use std::path::PathBuf;
6use std::sync::{Arc, RwLock};
7
8/// This directory will be embedded into the final binary.
9/// Typically used to store "core" or "system" prompts.
10static CORE_PROMPTS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/src/prompts");
11
12/// A global MiniJinja environment storing the "core" prompts.
13///
14/// - Loaded at startup from the `CORE_PROMPTS_DIR`.
15/// - Ideal for "system" templates that don't change often.
16/// - *Not* used for extension prompts (which are ephemeral).
17static GLOBAL_ENV: Lazy<Arc<RwLock<Environment<'static>>>> = Lazy::new(|| {
18    let mut env = Environment::new();
19    env.set_trim_blocks(true);
20    env.set_lstrip_blocks(true);
21
22    // Pre-load all core templates from the embedded dir.
23    for file in CORE_PROMPTS_DIR.files() {
24        let name = file.path().to_string_lossy().to_string();
25        let source = String::from_utf8_lossy(file.contents()).to_string();
26
27        // Since we're using 'static lifetime for the Environment, we need to ensure
28        // the strings we add as templates live for the entire program duration.
29        // We can achieve this by leaking the strings (acceptable for initialization).
30        let static_name: &'static str = Box::leak(name.into_boxed_str());
31        let static_source: &'static str = Box::leak(source.into_boxed_str());
32
33        if let Err(e) = env.add_template(static_name, static_source) {
34            tracing::error!("Failed to add template {}: {}", static_name, e);
35        }
36    }
37
38    Arc::new(RwLock::new(env))
39});
40
41/// Renders a prompt from the global environment by name.
42///
43/// # Arguments
44/// * `template_name` - The name of the template (usually the file path or a custom ID).
45/// * `context_data`  - Data to be inserted into the template (must be `Serialize`).
46pub fn render_global_template<T: Serialize>(
47    template_name: &str,
48    context_data: &T,
49) -> Result<String, MiniJinjaError> {
50    let env = GLOBAL_ENV.read().expect("GLOBAL_ENV lock poisoned");
51    let tmpl = env.get_template(template_name)?;
52    let ctx = MJValue::from_serialize(context_data);
53    let rendered = tmpl.render(ctx)?;
54    Ok(rendered.trim().to_string())
55}
56
57/// Renders a file from `CORE_PROMPTS_DIR` within the global environment.
58///
59/// # Arguments
60/// * `template_file` - The file path within the embedded directory (e.g. "system.md").
61/// * `context_data`  - Data to be inserted into the template (must be `Serialize`).
62///
63/// This function **assumes** the file is already in `CORE_PROMPTS_DIR`. If it wasn't
64/// added to the global environment at startup (due to parse errors, etc.), this will error out.
65pub fn render_global_file<T: Serialize>(
66    template_file: impl Into<PathBuf>,
67    context_data: &T,
68) -> Result<String, MiniJinjaError> {
69    let file_path = template_file.into();
70    let template_name = file_path.to_string_lossy().to_string();
71
72    render_global_template(&template_name, context_data)
73}
74
75/// Alias for render_global_file for backward compatibility
76pub fn render_global_from_file<T: Serialize>(
77    template_file: impl Into<PathBuf>,
78    context_data: &T,
79) -> Result<String, MiniJinjaError> {
80    render_global_file(template_file, context_data)
81}
82
83/// Renders a **one-off ephemeral** template (inline string).
84///
85/// This does *not* store anything in the global environment and is best for
86/// extension prompts or user-supplied templates that are used infrequently.
87///
88/// # Arguments
89/// * `template_str`  - The raw template string.
90/// * `context_data`  - Data to be inserted into the template (must be `Serialize`).
91pub fn render_inline_once<T: Serialize>(
92    template_str: &str,
93    context_data: &T,
94) -> Result<String, MiniJinjaError> {
95    let mut env = Environment::new();
96    env.add_template("inline_ephemeral", template_str)?;
97    let tmpl = env.get_template("inline_ephemeral")?;
98    let ctx = MJValue::from_serialize(context_data);
99    let rendered = tmpl.render(ctx)?;
100    Ok(rendered.trim().to_string())
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    /// For convenience in tests, define a small struct or use a HashMap to provide context.
    #[derive(Serialize)]
    struct TestContext {
        name: String,
        age: u32,
    }

    /// Builds a context map that contains only the fields that are `Some`, so tests
    /// can exercise missing or partial data.
    fn build_context(name: Option<&str>, age: Option<u32>) -> HashMap<String, serde_json::Value> {
        let mut ctx = HashMap::new();
        if let Some(n) = name {
            ctx.insert("name".to_string(), json!(n));
        }
        if let Some(a) = age {
            ctx.insert("age".to_string(), json!(a));
        }
        ctx
    }

    #[test]
    fn test_render_inline_once_basic() {
        let template_str = "Hello, {{ name }}! You are {{ age }} years old.";
        let context = TestContext {
            name: "Alice".to_string(),
            age: 30,
        };

        let result = render_inline_once(template_str, &context).unwrap();
        assert_eq!(result, "Hello, Alice! You are 30 years old.");
    }

    #[test]
    fn test_render_inline_missing_variable() {
        let template_str = "Hello, {{ name }}! You are {{ age }} years old.";
        let context = build_context(Some("Alice"), None);
        // With MiniJinja's default (lenient) undefined behavior, a missing variable
        // renders as an empty string instead of producing an error. The output is
        // therefore fully determined and can be compared exactly.
        let result = render_inline_once(template_str, &context).unwrap();
        assert_eq!(result, "Hello, Alice! You are  years old.");
    }

    #[test]
    fn test_render_inline_syntax_error() {
        let context = build_context(Some("Alice"), None);
        // An unterminated expression must surface as an error, not a panic.
        let result = render_inline_once("Hello, {{ name", &context);
        assert!(result.is_err(), "Should fail because the template is malformed");
    }

    #[test]
    fn test_global_file_render() {
        // "mock.md" should exist in the embedded CORE_PROMPTS_DIR
        // and have placeholders for `name` and `age`.
        let context = TestContext {
            name: "Alice".to_string(),
            age: 30,
        };

        let result = render_global_file("mock.md", &context).unwrap();
        // Assume mock.md content is something like:
        //  "This prompt is only used for testing.\n\nHello, {{ name }}! You are {{ age }} years old."
        assert_eq!(
            result,
            "This prompt is only used for testing.\n\nHello, Alice! You are 30 years old."
        );
    }

    #[test]
    fn test_global_from_file_alias_matches() {
        // The backward-compatible alias must render exactly like `render_global_file`.
        let context = TestContext {
            name: "Alice".to_string(),
            age: 30,
        };

        assert_eq!(
            render_global_from_file("mock.md", &context).unwrap(),
            render_global_file("mock.md", &context).unwrap()
        );
    }

    #[test]
    fn test_global_file_not_found() {
        let context = TestContext {
            name: "Unused".to_string(),
            age: 99,
        };

        let result = render_global_file("non_existent.md", &context);
        assert!(result.is_err(), "Should fail because file is missing");
    }

    #[test]
    fn test_inline_complex_object() {
        // Example with more complex data.
        #[derive(Serialize)]
        struct Tool {
            name: String,
            description: String,
        }

        #[derive(Serialize)]
        struct ToolsContext {
            tools: Vec<Tool>,
        }

        let template_str = "\
### Tool Descriptions
{% for tool in tools %}
- {{ tool.name }}: {{ tool.description }}
{% endfor %}";

        let context = ToolsContext {
            tools: vec![
                Tool {
                    name: "calculator".to_string(),
                    description: "Performs basic math operations".to_string(),
                },
                Tool {
                    name: "weather".to_string(),
                    description: "Gets weather information".to_string(),
                },
            ],
        };

        let rendered = render_inline_once(template_str, &context).unwrap();
        let expected = "\
### Tool Descriptions

- calculator: Performs basic math operations

- weather: Gets weather information";
        assert_eq!(rendered, expected);
    }

    #[test]
    fn test_inline_with_empty_list() {
        let template_str = "\
### Tool Descriptions
{% for tool in tools %}
- {{ tool.name }}: {{ tool.description }}
{% endfor %}";

        #[derive(Serialize)]
        struct ToolsContext {
            tools: Vec<String>, // or a struct if needed
        }

        let context = ToolsContext { tools: vec![] };
        let rendered = render_inline_once(template_str, &context).unwrap();
        let expected = "### Tool Descriptions";
        assert_eq!(rendered, expected);
    }
}