1use crate::{
4 data::{ExampleData, FormatType},
5 exceptions::{LangExtractError, LangExtractResult},
6 providers::ProviderType,
7};
8use std::collections::HashMap;
9
10#[derive(Debug, thiserror::Error)]
12pub enum TemplateError {
13 #[error("Missing required variable: {variable}")]
14 MissingVariable { variable: String },
15 #[error("Invalid template syntax: {message}")]
16 InvalidSyntax { message: String },
17 #[error("Example formatting error: {message}")]
18 ExampleError { message: String },
19}
20
21impl From<TemplateError> for LangExtractError {
22 fn from(err: TemplateError) -> Self {
23 LangExtractError::InvalidInput(err.to_string())
24 }
25}
26
27#[derive(Debug, Clone)]
29pub struct PromptContext {
30 pub task_description: String,
32 pub examples: Vec<ExampleData>,
34 pub input_text: String,
36 pub additional_context: Option<String>,
38 pub schema_hint: Option<String>,
40 pub variables: HashMap<String, String>,
42}
43
44impl PromptContext {
45 pub fn new(task_description: String, input_text: String) -> Self {
47 Self {
48 task_description,
49 input_text,
50 examples: Vec::new(),
51 additional_context: None,
52 schema_hint: None,
53 variables: HashMap::new(),
54 }
55 }
56
57 pub fn with_examples(mut self, examples: Vec<ExampleData>) -> Self {
59 self.examples = examples;
60 self
61 }
62
63 pub fn with_context(mut self, context: String) -> Self {
65 self.additional_context = Some(context);
66 self
67 }
68
69 pub fn with_variable(mut self, key: String, value: String) -> Self {
71 self.variables.insert(key, value);
72 self
73 }
74
75 pub fn with_schema_hint(mut self, hint: String) -> Self {
77 self.schema_hint = Some(hint);
78 self
79 }
80}
81
82pub trait TemplateRenderer {
84 fn render(&self, context: &PromptContext) -> LangExtractResult<String>;
86
87 fn validate(&self) -> LangExtractResult<()>;
89
90 fn required_variables(&self) -> Vec<String>;
92}
93
94#[derive(Debug, Clone)]
96pub struct PromptTemplate {
97 pub base_template: String,
99 pub system_message: Option<String>,
101 pub example_template: String,
103 pub format_type: FormatType,
105 pub provider_type: ProviderType,
107 pub max_examples: Option<usize>,
109 pub include_reasoning: bool,
111}
112
113impl PromptTemplate {
114 pub fn new(format_type: FormatType, provider_type: ProviderType) -> Self {
116 let base_template = Self::default_base_template(format_type, provider_type);
117 let example_template = Self::default_example_template(format_type);
118
119 Self {
120 base_template,
121 system_message: None,
122 example_template,
123 format_type,
124 provider_type,
125 max_examples: Some(5),
126 include_reasoning: false,
127 }
128 }
129
130 pub fn for_provider(provider_type: ProviderType, format_type: FormatType) -> Self {
132 let mut template = Self::new(format_type, provider_type);
133
134 match provider_type {
135 ProviderType::OpenAI => {
136 template.system_message = Some(
137 "You are an expert information extraction assistant. Extract structured information exactly as shown in the examples.".to_string()
138 );
139 template.include_reasoning = false; }
141 ProviderType::Ollama => {
142 template.include_reasoning = true; template.max_examples = Some(3); }
145 ProviderType::Custom => {
146 template.max_examples = Some(3);
148 template.include_reasoning = true;
149 }
150 }
151
152 template
153 }
154
155 pub fn with_max_examples(mut self, max: usize) -> Self {
157 self.max_examples = Some(max);
158 self
159 }
160
161 pub fn with_system_message(mut self, message: String) -> Self {
163 self.system_message = Some(message);
164 self
165 }
166
167 pub fn with_reasoning(mut self, enable: bool) -> Self {
169 self.include_reasoning = enable;
170 self
171 }
172
173 pub fn with_base_template(mut self, template: String) -> Self {
175 self.base_template = template;
176 self
177 }
178
179 fn default_base_template(format_type: FormatType, provider_type: ProviderType) -> String {
181 use crate::templates::TemplateBuilder;
182
183 let include_reasoning = matches!(provider_type, ProviderType::Ollama | ProviderType::Custom);
184
185 TemplateBuilder::new(format_type)
186 .with_reasoning(include_reasoning)
187 .build()
188 }
189
190 fn default_example_template(format_type: FormatType) -> String {
192 match format_type {
193 FormatType::Json => {
194 "Input: {input}\nOutput: {output_json}\n".to_string()
195 }
196 FormatType::Yaml => {
197 "Input: {input}\nOutput:\n{output_yaml}\n".to_string()
198 }
199 }
200 }
201
202 fn format_examples(&self, examples: &[ExampleData]) -> LangExtractResult<String> {
204 use crate::templates::ExampleFormatter;
205
206 let formatter = if let Some(max) = self.max_examples {
207 ExampleFormatter::new(self.format_type).with_max_examples(max)
208 } else {
209 ExampleFormatter::new(self.format_type)
210 };
211
212 formatter.format_examples(examples)
213 }
214
215 fn substitute_variables(&self, template: &str, context: &PromptContext) -> LangExtractResult<String> {
220 use crate::templates::TemplateEngine;
221 use std::collections::HashMap;
222
223 let mut variables = HashMap::new();
224
225 variables.insert("task_description".to_string(), context.task_description.clone());
227 variables.insert("input_text".to_string(), context.input_text.clone());
228
229 if let Some(context_text) = &context.additional_context {
231 variables.insert("additional_context".to_string(),
232 format!("\n\nAdditional Context: {}\n", context_text));
233 } else {
234 variables.insert("additional_context".to_string(), String::new());
235 }
236
237 let examples_text = self.format_examples(&context.examples)?;
239 variables.insert("examples".to_string(), examples_text);
240
241 if self.include_reasoning {
243 variables.insert("reasoning".to_string(),
244 "\n\nPlease think through this step by step before providing your answer.".to_string());
245 } else {
246 variables.insert("reasoning".to_string(), String::new());
247 }
248
249 if let Some(hint) = &context.schema_hint {
251 variables.insert("schema_hint".to_string(),
252 format!("\n\nSchema guidance: {}\n", hint));
253 } else {
254 variables.insert("schema_hint".to_string(), String::new());
255 }
256
257 for (key, value) in &context.variables {
259 variables.insert(key.clone(), value.clone());
260 }
261
262 let engine = TemplateEngine::lenient();
264 engine.render(template, &variables)
265 }
266}
267
268impl TemplateRenderer for PromptTemplate {
269 fn render(&self, context: &PromptContext) -> LangExtractResult<String> {
270 self.substitute_variables(&self.base_template, context)
271 }
272
273 fn validate(&self) -> LangExtractResult<()> {
274 if !self.base_template.contains("{task_description}") {
276 return Err(TemplateError::InvalidSyntax {
277 message: "Base template must contain {task_description} placeholder".to_string()
278 }.into());
279 }
280
281 if !self.base_template.contains("{input_text}") {
282 return Err(TemplateError::InvalidSyntax {
283 message: "Base template must contain {input_text} placeholder".to_string()
284 }.into());
285 }
286
287 Ok(())
288 }
289
290 fn required_variables(&self) -> Vec<String> {
291 let mut vars = vec!["task_description".to_string(), "input_text".to_string()];
292
293 let mut i = 0;
295 while i < self.base_template.len() {
296 if let Some(start) = self.base_template[i..].find('{') {
297 let start = start + i;
298 if let Some(end) = self.base_template[start..].find('}') {
299 let end = end + start;
300 let var_name = &self.base_template[start+1..end];
301 if !var_name.is_empty() && !vars.contains(&var_name.to_string()) {
302 vars.push(var_name.to_string());
303 }
304 i = end + 1;
305 } else {
306 break;
307 }
308 } else {
309 break;
310 }
311 }
312
313 vars
314 }
315}
316
317#[derive(Debug, Clone)]
319pub struct PromptTemplateStructured {
320 pub description: Option<String>,
322 pub examples: Vec<ExampleData>,
324 template: PromptTemplate,
326}
327
328impl PromptTemplateStructured {
329 pub fn new(description: Option<&str>) -> Self {
331 Self {
332 description: description.map(|s| s.to_string()),
333 examples: Vec::new(),
334 template: PromptTemplate::new(FormatType::Json, ProviderType::Ollama),
335 }
336 }
337
338 pub fn with_format_and_provider(
340 description: Option<&str>,
341 format_type: FormatType,
342 provider_type: ProviderType,
343 ) -> Self {
344 Self {
345 description: description.map(|s| s.to_string()),
346 examples: Vec::new(),
347 template: PromptTemplate::for_provider(provider_type, format_type),
348 }
349 }
350
351 pub fn render(&self, input_text: &str, additional_context: Option<&str>) -> LangExtractResult<String> {
353 let mut context = PromptContext::new(
354 self.description.clone().unwrap_or_default(),
355 input_text.to_string(),
356 );
357
358 context.examples = self.examples.clone();
359
360 if let Some(ctx) = additional_context {
361 context.additional_context = Some(ctx.to_string());
362 }
363
364 self.template.render(&context)
365 }
366
367 pub fn template(&self) -> &PromptTemplate {
369 &self.template
370 }
371
372 pub fn template_mut(&mut self) -> &mut PromptTemplate {
374 &mut self.template
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::Extraction;

    /// The OpenAI + JSON template most of these tests start from.
    fn openai_json_template() -> PromptTemplate {
        PromptTemplate::new(FormatType::Json, ProviderType::OpenAI)
    }

    #[test]
    fn test_prompt_context_creation() {
        let ctx = PromptContext::new("Extract names".to_string(), "John is here".to_string())
            .with_context("Additional info".to_string())
            .with_variable("custom".to_string(), "value".to_string())
            .with_schema_hint("Use proper format".to_string());

        assert_eq!(ctx.task_description, "Extract names");
        assert_eq!(ctx.input_text, "John is here");
        assert_eq!(ctx.additional_context.as_deref(), Some("Additional info"));
        assert_eq!(ctx.variables.get("custom").map(String::as_str), Some("value"));
        assert_eq!(ctx.schema_hint.as_deref(), Some("Use proper format"));
    }

    #[test]
    fn test_template_validation() {
        let valid = openai_json_template();
        assert!(valid.validate().is_ok());

        let broken = valid.with_base_template("No required placeholders".to_string());
        assert!(broken.validate().is_err());
    }

    #[test]
    fn test_required_variables() {
        let vars = openai_json_template().required_variables();

        for expected in &["task_description", "input_text", "examples"] {
            assert!(
                vars.iter().any(|v| v.as_str() == *expected),
                "expected `{}` in {:?}",
                expected,
                vars
            );
        }
    }

    #[test]
    fn test_example_formatting_json() {
        let extractions = vec![
            Extraction::new("name".to_string(), "John".to_string()),
            Extraction::new("age".to_string(), "30".to_string()),
        ];
        let sample = ExampleData::new("John is 30".to_string(), extractions);
        let ctx = PromptContext::new("Extract information".to_string(), "Test input".to_string())
            .with_examples(vec![sample]);

        let prompt = openai_json_template()
            .render(&ctx)
            .expect("rendering should succeed");

        assert!(prompt.contains("Extract information"));
        assert!(prompt.contains("Test input"));
    }

    #[test]
    fn test_template_rendering() {
        let ctx = PromptContext::new(
            "Extract names and ages".to_string(),
            "Alice is 25 years old".to_string(),
        );

        let prompt = openai_json_template()
            .render(&ctx)
            .expect("rendering should succeed");

        for needle in &["Extract names and ages", "Alice is 25 years old", "JSON format"] {
            assert!(prompt.contains(*needle), "prompt is missing `{}`", needle);
        }
    }

    #[test]
    fn test_provider_specific_templates() {
        let openai = PromptTemplate::for_provider(ProviderType::OpenAI, FormatType::Json);
        assert!(openai.system_message.is_some());
        assert!(!openai.include_reasoning);

        let ollama = PromptTemplate::for_provider(ProviderType::Ollama, FormatType::Json);
        assert!(ollama.include_reasoning);
        assert_eq!(ollama.max_examples, Some(3));
    }

    #[test]
    fn test_backward_compatibility() {
        let mut structured = PromptTemplateStructured::new(Some("Extract info"));
        structured.examples.push(ExampleData::new(
            "Test".to_string(),
            vec![Extraction::new("test".to_string(), "value".to_string())],
        ));

        let prompt = structured
            .render("Input text", None)
            .expect("rendering should succeed");

        assert!(prompt.contains("Extract info"));
        assert!(prompt.contains("Input text"));
    }
}