1use crate::{data::{ExampleData, FormatType}, exceptions::{LangExtractError, LangExtractResult}};
7use std::collections::HashMap;
8
9#[derive(Debug, thiserror::Error)]
11pub enum TemplateError {
12 #[error("Missing required variable: {variable}")]
13 MissingVariable { variable: String },
14 #[error("Invalid template syntax: {message}")]
15 InvalidSyntax { message: String },
16 #[error("Variable substitution failed: {message}")]
17 SubstitutionError { message: String },
18}
19
20impl From<TemplateError> for LangExtractError {
21 fn from(err: TemplateError) -> Self {
22 LangExtractError::InvalidInput(err.to_string())
23 }
24}
25
/// Substitutes delimited placeholders in template text with variable values.
///
/// A placeholder is a variable name enclosed between `var_start` and `var_end`.
#[derive(Debug, Clone)]
pub struct TemplateEngine {
    /// Text that opens a placeholder.
    pub var_start: String,

    /// Text that closes a placeholder.
    pub var_end: String,

    /// When `true`, a placeholder with no matching variable is removed from
    /// the rendered output instead of producing an error.
    pub allow_missing: bool,
}
36
37impl Default for TemplateEngine {
38 fn default() -> Self {
39 Self {
40 var_start: "{".to_string(),
41 var_end: "}".to_string(),
42 allow_missing: false,
43 }
44 }
45}
46
47impl TemplateEngine {
48 pub fn new() -> Self {
50 Self::default()
51 }
52
53 pub fn lenient() -> Self {
55 Self {
56 allow_missing: true,
57 ..Default::default()
58 }
59 }
60
61 pub fn render(&self, template: &str, variables: &HashMap<String, String>) -> LangExtractResult<String> {
63 let mut result = template.to_string();
64 let mut pos = 0;
65
66 while pos < result.len() {
67 if let Some(start) = result[pos..].find(&self.var_start) {
68 let abs_start = pos + start;
69 let search_from = abs_start + self.var_start.len();
70
71 if let Some(end) = result[search_from..].find(&self.var_end) {
72 let abs_end = search_from + end;
73 let var_name = &result[abs_start + self.var_start.len()..abs_end];
74
75 if let Some(value) = variables.get(var_name) {
76 result.replace_range(abs_start..abs_end + self.var_end.len(), value);
77 pos = abs_start + value.len();
78 } else if self.allow_missing {
79 result.replace_range(abs_start..abs_end + self.var_end.len(), "");
80 pos = abs_start;
81 } else {
82 return Err(TemplateError::MissingVariable {
83 variable: var_name.to_string(),
84 }.into());
85 }
86 } else {
87 return Err(TemplateError::InvalidSyntax {
88 message: format!("Unclosed variable at position {}", abs_start),
89 }.into());
90 }
91 } else {
92 break;
93 }
94 }
95
96 Ok(result)
97 }
98
99 pub fn extract_variables(&self, template: &str) -> Vec<String> {
101 let mut variables = Vec::new();
102 let mut pos = 0;
103
104 while pos < template.len() {
105 if let Some(start) = template[pos..].find(&self.var_start) {
106 let abs_start = pos + start;
107 let search_from = abs_start + self.var_start.len();
108
109 if let Some(end) = template[search_from..].find(&self.var_end) {
110 let abs_end = search_from + end;
111 let var_name = &template[abs_start + self.var_start.len()..abs_end];
112
113 if !var_name.is_empty() && !variables.contains(&var_name.to_string()) {
114 variables.push(var_name.to_string());
115 }
116 pos = abs_end + self.var_end.len();
117 } else {
118 break;
119 }
120 } else {
121 break;
122 }
123 }
124
125 variables
126 }
127
128 pub fn validate(&self, template: &str, variables: &HashMap<String, String>) -> LangExtractResult<()> {
130 if self.allow_missing {
131 return Ok(());
132 }
133
134 let required = self.extract_variables(template);
135 for var in required {
136 if !variables.contains_key(&var) {
137 return Err(TemplateError::MissingVariable { variable: var }.into());
138 }
139 }
140 Ok(())
141 }
142}
143
/// Namespace type for the fixed text fragments from which prompts are assembled.
pub struct TemplateFragments;
146
147impl TemplateFragments {
148 pub fn instruction_prefix() -> &'static str {
150 "You are an expert information extraction assistant. "
151 }
152
153 pub fn json_format_instruction() -> &'static str {
155 "Respond with valid JSON that matches the structure shown in the examples."
156 }
157
158 pub fn yaml_format_instruction() -> &'static str {
160 "Respond with valid YAML that matches the structure shown in the examples."
161 }
162
163 pub fn reasoning_instruction() -> &'static str {
165 "\n\nThink step by step:\n1. Read the text carefully\n2. Identify the requested information\n3. Extract it in the exact format shown in examples"
166 }
167
168 pub fn examples_header() -> &'static str {
170 "\n\nExamples:\n"
171 }
172
173 pub fn input_header() -> &'static str {
175 "\n\nNow extract information from this text:\n\nInput: "
176 }
177
178 pub fn output_header(format: FormatType) -> String {
180 match format {
181 FormatType::Json => "\n\nOutput (JSON format):".to_string(),
182 FormatType::Yaml => "\n\nOutput (YAML format):".to_string(),
183 }
184 }
185}
186
187pub struct ExampleFormatter {
189 format_type: FormatType,
190 max_examples: Option<usize>,
191}
192
193impl ExampleFormatter {
194 pub fn new(format_type: FormatType) -> Self {
195 Self {
196 format_type,
197 max_examples: None,
198 }
199 }
200
201 pub fn with_max_examples(mut self, max: usize) -> Self {
202 self.max_examples = Some(max);
203 self
204 }
205
206 pub fn format_examples(&self, examples: &[ExampleData]) -> LangExtractResult<String> {
208 if examples.is_empty() {
209 return Ok(String::new());
210 }
211
212 let examples_to_use = if let Some(max) = self.max_examples {
213 &examples[..examples.len().min(max)]
214 } else {
215 examples
216 };
217
218 let mut result = String::new();
219 result.push_str(TemplateFragments::examples_header());
220
221 for (i, example) in examples_to_use.iter().enumerate() {
222 result.push_str(&format!("\nExample {}:\n", i + 1));
223 result.push_str(&format!("Input: {}\n", example.text));
224 result.push_str("Output: ");
225 result.push_str(&self.format_single_example(example)?);
226 result.push('\n');
227 }
228
229 Ok(result)
230 }
231
232 fn format_single_example(&self, example: &ExampleData) -> LangExtractResult<String> {
234 match self.format_type {
235 FormatType::Json => self.format_as_json(example),
236 FormatType::Yaml => self.format_as_yaml(example),
237 }
238 }
239
240 fn format_as_json(&self, example: &ExampleData) -> LangExtractResult<String> {
241 let mut json_obj = serde_json::Map::new();
242
243 for extraction in &example.extractions {
244 json_obj.insert(
245 extraction.extraction_class.clone(),
246 serde_json::Value::String(extraction.extraction_text.clone()),
247 );
248 }
249
250 serde_json::to_string_pretty(&json_obj)
251 .map_err(|e| TemplateError::SubstitutionError {
252 message: format!("Failed to format JSON: {}", e),
253 }.into())
254 }
255
256 fn format_as_yaml(&self, example: &ExampleData) -> LangExtractResult<String> {
257 let mut yaml_map = std::collections::BTreeMap::new();
258
259 for extraction in &example.extractions {
260 yaml_map.insert(
261 extraction.extraction_class.clone(),
262 extraction.extraction_text.clone(),
263 );
264 }
265
266 serde_yaml::to_string(&yaml_map)
267 .map_err(|e| TemplateError::SubstitutionError {
268 message: format!("Failed to format YAML: {}", e),
269 }.into())
270 }
271}
272
273pub struct TemplateBuilder {
275 instruction: String,
276 format_instruction: String,
277 reasoning: String,
278 examples_section: String,
279 context_section: String,
280 input_section: String,
281 _output_section: String,
282 engine: TemplateEngine,
283}
284
285impl TemplateBuilder {
286 pub fn new(format_type: FormatType) -> Self {
287 let format_instruction = match format_type {
288 FormatType::Json => TemplateFragments::json_format_instruction(),
289 FormatType::Yaml => TemplateFragments::yaml_format_instruction(),
290 };
291
292 Self {
293 instruction: TemplateFragments::instruction_prefix().to_string(),
294 format_instruction: format_instruction.to_string(),
295 reasoning: String::new(),
296 examples_section: "{examples}".to_string(),
297 context_section: "{additional_context}".to_string(),
298 input_section: format!("{}{}{}",
299 TemplateFragments::input_header(),
300 "{input_text}",
301 TemplateFragments::output_header(format_type)
302 ),
303 _output_section: String::new(),
304 engine: TemplateEngine::lenient(),
305 }
306 }
307
308 pub fn with_instruction(mut self, instruction: &str) -> Self {
309 self.instruction = instruction.to_string();
310 self
311 }
312
313 pub fn with_reasoning(mut self, include: bool) -> Self {
314 if include {
315 self.reasoning = TemplateFragments::reasoning_instruction().to_string();
316 } else {
317 self.reasoning.clear();
318 }
319 self
320 }
321
322 pub fn with_custom_examples_section(mut self, section: &str) -> Self {
323 self.examples_section = section.to_string();
324 self
325 }
326
327 pub fn build(&self) -> String {
328 format!(
329 "{{task_description}}\n\n{}{}{}{}{}{}\n",
330 self.instruction,
331 self.format_instruction,
332 self.context_section,
333 self.examples_section,
334 self.reasoning,
335 self.input_section,
336 )
337 }
338
339 pub fn build_with_variables(self, variables: HashMap<String, String>) -> LangExtractResult<String> {
340 let template = self.build();
341 self.engine.render(&template, &variables)
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::Extraction;

    /// Template shared by the basic, missing-variable and lenient tests.
    const GREETING: &str = "Hello {name}, welcome to {place}!";

    /// Builds an owned variable map from borrowed key/value pairs.
    fn vars_from(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    /// All placeholders are replaced when every variable is supplied.
    #[test]
    fn test_template_engine_basic() {
        let vars = vars_from(&[("name", "John"), ("place", "LangExtract")]);

        let rendered = TemplateEngine::new().render(GREETING, &vars).unwrap();

        assert_eq!(rendered, "Hello John, welcome to LangExtract!");
    }

    /// A strict engine fails when a placeholder has no value.
    #[test]
    fn test_template_engine_missing_var() {
        let vars = vars_from(&[("name", "John")]);

        let outcome = TemplateEngine::new().render(GREETING, &vars);

        assert!(outcome.is_err());
    }

    /// A lenient engine drops placeholders that have no value.
    #[test]
    fn test_template_engine_lenient() {
        let vars = vars_from(&[("name", "John")]);

        let rendered = TemplateEngine::lenient().render(GREETING, &vars).unwrap();

        assert_eq!(rendered, "Hello John, welcome to !");
    }

    /// Placeholder names are reported in order of appearance.
    #[test]
    fn test_variable_extraction() {
        let template = "Hello {name}, welcome to {place}! Your ID is {id}.";

        let names = TemplateEngine::new().extract_variables(template);

        assert_eq!(names, vec!["name", "place", "id"]);
    }

    /// JSON-formatted examples include the header, classes and extracted text.
    #[test]
    fn test_example_formatter_json() {
        let extractions = vec![
            Extraction::new("person".to_string(), "John Doe".to_string()),
            Extraction::new("age".to_string(), "30".to_string()),
        ];
        let example = ExampleData::new("John Doe is 30 years old".to_string(), extractions);

        let rendered = ExampleFormatter::new(FormatType::Json)
            .format_examples(&[example])
            .unwrap();

        for expected in ["Examples:", "John Doe", "person", "age"].iter() {
            assert!(rendered.contains(expected));
        }
    }

    /// The built template keeps its placeholders and configured sections.
    #[test]
    fn test_template_builder() {
        let builder = TemplateBuilder::new(FormatType::Json).with_reasoning(true);
        let template = builder.build();

        assert!(template.contains("You are an expert"));
        assert!(template.contains("JSON"));
        assert!(template.contains("Think step by step"));
        assert!(template.contains("{task_description}"));
        assert!(template.contains("{examples}"));
        assert!(template.contains("{input_text}"));
    }

    /// Supplied variables appear in the rendered prompt.
    #[test]
    fn test_template_builder_with_variables() {
        let vars = vars_from(&[
            ("task_description", "Extract names"),
            ("examples", "Example: John -> person: John"),
            ("input_text", "Alice Smith"),
            ("additional_context", ""),
        ]);

        let rendered = TemplateBuilder::new(FormatType::Json)
            .build_with_variables(vars)
            .unwrap();

        assert!(rendered.contains("Extract names"));
        assert!(rendered.contains("Alice Smith"));
        assert!(rendered.contains("Example: John"));
    }
}
447}