Skip to main content

langextract_rust/
templates.rs

//! Template engine and utilities for LangExtract.
//!
//! This module provides a unified template system that eliminates duplication
//! across different prompt templates and formats.

6use crate::{data::{ExampleData, FormatType}, exceptions::{LangExtractError, LangExtractResult}};
7use std::collections::HashMap;
8
9/// Template error types
10#[derive(Debug, thiserror::Error)]
11pub enum TemplateError {
12    #[error("Missing required variable: {variable}")]
13    MissingVariable { variable: String },
14    #[error("Invalid template syntax: {message}")]
15    InvalidSyntax { message: String },
16    #[error("Variable substitution failed: {message}")]
17    SubstitutionError { message: String },
18}
19
20impl From<TemplateError> for LangExtractError {
21    fn from(err: TemplateError) -> Self {
22        LangExtractError::InvalidInput(err.to_string())
23    }
24}
25
/// A small placeholder-substitution engine.
///
/// A placeholder is written as `var_start`, a variable name, then `var_end`
/// (by default `{name}`) and is replaced by the matching entry of a variables map.
#[derive(Debug, Clone)]
pub struct TemplateEngine {
    /// Text that opens a placeholder (default: `"{"`).
    pub var_start: String,
    /// Text that closes a placeholder (default: `"}"`).
    pub var_end: String,
    /// When `true`, placeholders with no matching variable render as an empty
    /// string instead of producing an error.
    pub allow_missing: bool,
}

impl Default for TemplateEngine {
    /// A strict engine using single-brace delimiters.
    fn default() -> Self {
        let var_start = String::from("{");
        let var_end = String::from("}");
        Self {
            var_start,
            var_end,
            allow_missing: false,
        }
    }
}
46
47impl TemplateEngine {
48    /// Create a new template engine
49    pub fn new() -> Self {
50        Self::default()
51    }
52
53    /// Create a lenient template engine that allows missing variables
54    pub fn lenient() -> Self {
55        Self {
56            allow_missing: true,
57            ..Default::default()
58        }
59    }
60
61    /// Render a template with variables
62    pub fn render(&self, template: &str, variables: &HashMap<String, String>) -> LangExtractResult<String> {
63        let mut result = template.to_string();
64        let mut pos = 0;
65
66        while pos < result.len() {
67            if let Some(start) = result[pos..].find(&self.var_start) {
68                let abs_start = pos + start;
69                let search_from = abs_start + self.var_start.len();
70                
71                if let Some(end) = result[search_from..].find(&self.var_end) {
72                    let abs_end = search_from + end;
73                    let var_name = &result[abs_start + self.var_start.len()..abs_end];
74                    
75                    if let Some(value) = variables.get(var_name) {
76                        result.replace_range(abs_start..abs_end + self.var_end.len(), value);
77                        pos = abs_start + value.len();
78                    } else if self.allow_missing {
79                        result.replace_range(abs_start..abs_end + self.var_end.len(), "");
80                        pos = abs_start;
81                    } else {
82                        return Err(TemplateError::MissingVariable {
83                            variable: var_name.to_string(),
84                        }.into());
85                    }
86                } else {
87                    return Err(TemplateError::InvalidSyntax {
88                        message: format!("Unclosed variable at position {}", abs_start),
89                    }.into());
90                }
91            } else {
92                break;
93            }
94        }
95
96        Ok(result)
97    }
98
99    /// Extract all variable names from a template
100    pub fn extract_variables(&self, template: &str) -> Vec<String> {
101        let mut variables = Vec::new();
102        let mut pos = 0;
103
104        while pos < template.len() {
105            if let Some(start) = template[pos..].find(&self.var_start) {
106                let abs_start = pos + start;
107                let search_from = abs_start + self.var_start.len();
108                
109                if let Some(end) = template[search_from..].find(&self.var_end) {
110                    let abs_end = search_from + end;
111                    let var_name = &template[abs_start + self.var_start.len()..abs_end];
112                    
113                    if !var_name.is_empty() && !variables.contains(&var_name.to_string()) {
114                        variables.push(var_name.to_string());
115                    }
116                    pos = abs_end + self.var_end.len();
117                } else {
118                    break;
119                }
120            } else {
121                break;
122            }
123        }
124
125        variables
126    }
127
128    /// Validate that all required variables are present
129    pub fn validate(&self, template: &str, variables: &HashMap<String, String>) -> LangExtractResult<()> {
130        if self.allow_missing {
131            return Ok(());
132        }
133
134        let required = self.extract_variables(template);
135        for var in required {
136            if !variables.contains_key(&var) {
137                return Err(TemplateError::MissingVariable { variable: var }.into());
138            }
139        }
140        Ok(())
141    }
142}
143
144/// Common template fragments for reuse
145pub struct TemplateFragments;
146
147impl TemplateFragments {
148    /// Standard instruction prefix
149    pub fn instruction_prefix() -> &'static str {
150        "You are an expert information extraction assistant. "
151    }
152
153    /// JSON format instruction
154    pub fn json_format_instruction() -> &'static str {
155        "Respond with valid JSON that matches the structure shown in the examples."
156    }
157
158    /// YAML format instruction
159    pub fn yaml_format_instruction() -> &'static str {
160        "Respond with valid YAML that matches the structure shown in the examples."
161    }
162
163    /// Reasoning instruction for local models
164    pub fn reasoning_instruction() -> &'static str {
165        "\n\nThink step by step:\n1. Read the text carefully\n2. Identify the requested information\n3. Extract it in the exact format shown in examples"
166    }
167
168    /// Example section header
169    pub fn examples_header() -> &'static str {
170        "\n\nExamples:\n"
171    }
172
173    /// Input section header  
174    pub fn input_header() -> &'static str {
175        "\n\nNow extract information from this text:\n\nInput: "
176    }
177
178    /// Output section header
179    pub fn output_header(format: FormatType) -> String {
180        match format {
181            FormatType::Json => "\n\nOutput (JSON format):".to_string(),
182            FormatType::Yaml => "\n\nOutput (YAML format):".to_string(),
183        }
184    }
185}
186
187/// Example formatter that handles different output formats consistently
188pub struct ExampleFormatter {
189    format_type: FormatType,
190    max_examples: Option<usize>,
191}
192
193impl ExampleFormatter {
194    pub fn new(format_type: FormatType) -> Self {
195        Self {
196            format_type,
197            max_examples: None,
198        }
199    }
200
201    pub fn with_max_examples(mut self, max: usize) -> Self {
202        self.max_examples = Some(max);
203        self
204    }
205
206    /// Format examples for inclusion in prompts
207    pub fn format_examples(&self, examples: &[ExampleData]) -> LangExtractResult<String> {
208        if examples.is_empty() {
209            return Ok(String::new());
210        }
211
212        let examples_to_use = if let Some(max) = self.max_examples {
213            &examples[..examples.len().min(max)]
214        } else {
215            examples
216        };
217
218        let mut result = String::new();
219        result.push_str(TemplateFragments::examples_header());
220
221        for (i, example) in examples_to_use.iter().enumerate() {
222            result.push_str(&format!("\nExample {}:\n", i + 1));
223            result.push_str(&format!("Input: {}\n", example.text));
224            result.push_str("Output: ");
225            result.push_str(&self.format_single_example(example)?);
226            result.push('\n');
227        }
228
229        Ok(result)
230    }
231
232    /// Format a single example in the specified format
233    fn format_single_example(&self, example: &ExampleData) -> LangExtractResult<String> {
234        match self.format_type {
235            FormatType::Json => self.format_as_json(example),
236            FormatType::Yaml => self.format_as_yaml(example),
237        }
238    }
239
240    fn format_as_json(&self, example: &ExampleData) -> LangExtractResult<String> {
241        let mut json_obj = serde_json::Map::new();
242        
243        for extraction in &example.extractions {
244            json_obj.insert(
245                extraction.extraction_class.clone(),
246                serde_json::Value::String(extraction.extraction_text.clone()),
247            );
248        }
249
250        serde_json::to_string_pretty(&json_obj)
251            .map_err(|e| TemplateError::SubstitutionError {
252                message: format!("Failed to format JSON: {}", e),
253            }.into())
254    }
255
256    fn format_as_yaml(&self, example: &ExampleData) -> LangExtractResult<String> {
257        let mut yaml_map = std::collections::BTreeMap::new();
258        
259        for extraction in &example.extractions {
260            yaml_map.insert(
261                extraction.extraction_class.clone(),
262                extraction.extraction_text.clone(),
263            );
264        }
265
266        serde_yaml::to_string(&yaml_map)
267            .map_err(|e| TemplateError::SubstitutionError {
268                message: format!("Failed to format YAML: {}", e),
269            }.into())
270    }
271}
272
273/// Template builder for creating common prompt templates
274pub struct TemplateBuilder {
275    instruction: String,
276    format_instruction: String,
277    reasoning: String,
278    examples_section: String,
279    context_section: String,
280    input_section: String,
281    _output_section: String,
282    engine: TemplateEngine,
283}
284
285impl TemplateBuilder {
286    pub fn new(format_type: FormatType) -> Self {
287        let format_instruction = match format_type {
288            FormatType::Json => TemplateFragments::json_format_instruction(),
289            FormatType::Yaml => TemplateFragments::yaml_format_instruction(),
290        };
291
292        Self {
293            instruction: TemplateFragments::instruction_prefix().to_string(),
294            format_instruction: format_instruction.to_string(),
295            reasoning: String::new(),
296            examples_section: "{examples}".to_string(),
297            context_section: "{additional_context}".to_string(),
298            input_section: format!("{}{}{}", 
299                TemplateFragments::input_header(),
300                "{input_text}",
301                TemplateFragments::output_header(format_type)
302            ),
303            _output_section: String::new(),
304            engine: TemplateEngine::lenient(),
305        }
306    }
307
308    pub fn with_instruction(mut self, instruction: &str) -> Self {
309        self.instruction = instruction.to_string();
310        self
311    }
312
313    pub fn with_reasoning(mut self, include: bool) -> Self {
314        if include {
315            self.reasoning = TemplateFragments::reasoning_instruction().to_string();
316        } else {
317            self.reasoning.clear();
318        }
319        self
320    }
321
322    pub fn with_custom_examples_section(mut self, section: &str) -> Self {
323        self.examples_section = section.to_string();
324        self
325    }
326
327    pub fn build(&self) -> String {
328        format!(
329            "{{task_description}}\n\n{}{}{}{}{}{}\n",
330            self.instruction,
331            self.format_instruction,
332            self.context_section,
333            self.examples_section,
334            self.reasoning,
335            self.input_section,
336        )
337    }
338
339    pub fn build_with_variables(self, variables: HashMap<String, String>) -> LangExtractResult<String> {
340        let template = self.build();
341        self.engine.render(&template, &variables)
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::Extraction;

    /// Shared greeting template with two placeholders.
    const GREETING: &str = "Hello {name}, welcome to {place}!";

    /// Build an owned variables map from borrowed key/value pairs.
    fn vars(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_template_engine_basic() {
        let rendered = TemplateEngine::new()
            .render(GREETING, &vars(&[("name", "John"), ("place", "LangExtract")]))
            .unwrap();
        assert_eq!(rendered, "Hello John, welcome to LangExtract!");
    }

    #[test]
    fn test_template_engine_missing_var() {
        // "place" is deliberately absent and the default engine is strict.
        let outcome = TemplateEngine::new().render(GREETING, &vars(&[("name", "John")]));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_template_engine_lenient() {
        // A lenient engine drops the unknown "place" placeholder.
        let rendered = TemplateEngine::lenient()
            .render(GREETING, &vars(&[("name", "John")]))
            .unwrap();
        assert_eq!(rendered, "Hello John, welcome to !");
    }

    #[test]
    fn test_variable_extraction() {
        let names = TemplateEngine::new()
            .extract_variables("Hello {name}, welcome to {place}! Your ID is {id}.");
        assert_eq!(names, vec!["name", "place", "id"]);
    }

    #[test]
    fn test_example_formatter_json() {
        let sample = ExampleData::new(
            "John Doe is 30 years old".to_string(),
            vec![
                Extraction::new("person".to_string(), "John Doe".to_string()),
                Extraction::new("age".to_string(), "30".to_string()),
            ],
        );

        let formatted = ExampleFormatter::new(FormatType::Json)
            .format_examples(&[sample])
            .unwrap();
        for needle in ["Examples:", "John Doe", "person", "age"] {
            assert!(formatted.contains(needle), "missing {:?} in {}", needle, formatted);
        }
    }

    #[test]
    fn test_template_builder() {
        let template = TemplateBuilder::new(FormatType::Json)
            .with_reasoning(true)
            .build();

        let expected = [
            "You are an expert",
            "JSON",
            "Think step by step",
            "{task_description}",
            "{examples}",
            "{input_text}",
        ];
        for fragment in expected {
            assert!(template.contains(fragment), "missing {:?} in {}", fragment, template);
        }
    }

    #[test]
    fn test_template_builder_with_variables() {
        let rendered = TemplateBuilder::new(FormatType::Json)
            .build_with_variables(vars(&[
                ("task_description", "Extract names"),
                ("examples", "Example: John -> person: John"),
                ("input_text", "Alice Smith"),
                ("additional_context", ""),
            ]))
            .unwrap();

        for fragment in ["Extract names", "Alice Smith", "Example: John"] {
            assert!(rendered.contains(fragment), "missing {:?} in {}", fragment, rendered);
        }
    }
}