1use crate::{
4 data::{ExampleData, FormatType},
5 exceptions::{LangExtractError, LangExtractResult},
6 providers::ProviderType,
7};
8use std::collections::HashMap;
9
/// Errors produced while validating or rendering prompt templates.
///
/// Converted into `LangExtractError` through the `From` impl in this file.
#[derive(Debug, thiserror::Error)]
pub enum TemplateError {
    /// Reports the name of a variable that has no value.
    /// NOTE(review): not constructed anywhere in this file — confirm it is used elsewhere.
    #[error("Missing required variable: {variable}")]
    MissingVariable { variable: String },
    /// The template text is malformed; `validate` returns this when a mandatory
    /// placeholder is absent from the base template.
    #[error("Invalid template syntax: {message}")]
    InvalidSyntax { message: String },
    /// Serializing an example's extractions (JSON or YAML) failed.
    #[error("Example formatting error: {message}")]
    ExampleError { message: String },
}
20
21impl From<TemplateError> for LangExtractError {
22 fn from(err: TemplateError) -> Self {
23 LangExtractError::InvalidInput(err.to_string())
24 }
25}
26
/// Per-request values that get substituted into a prompt template.
#[derive(Debug, Clone)]
pub struct PromptContext {
    /// Text substituted for the `{task_description}` placeholder.
    pub task_description: String,
    /// Examples rendered into the `{examples}` placeholder.
    pub examples: Vec<ExampleData>,
    /// Text substituted for the `{input_text}` placeholder.
    pub input_text: String,
    /// Optional extra context; when present it is rendered under an
    /// "Additional Context:" label in place of `{additional_context}`.
    pub additional_context: Option<String>,
    /// Optional schema guidance; when present it is rendered under a
    /// "Schema guidance:" label in place of `{schema_hint}`.
    pub schema_hint: Option<String>,
    /// Custom variables: each key `k` supplies the value for a `{k}` placeholder.
    pub variables: HashMap<String, String>,
}
43
44impl PromptContext {
45 pub fn new(task_description: String, input_text: String) -> Self {
47 Self {
48 task_description,
49 input_text,
50 examples: Vec::new(),
51 additional_context: None,
52 schema_hint: None,
53 variables: HashMap::new(),
54 }
55 }
56
57 pub fn with_examples(mut self, examples: Vec<ExampleData>) -> Self {
59 self.examples = examples;
60 self
61 }
62
63 pub fn with_context(mut self, context: String) -> Self {
65 self.additional_context = Some(context);
66 self
67 }
68
69 pub fn with_variable(mut self, key: String, value: String) -> Self {
71 self.variables.insert(key, value);
72 self
73 }
74
75 pub fn with_schema_hint(mut self, hint: String) -> Self {
77 self.schema_hint = Some(hint);
78 self
79 }
80}
81
/// Behaviour shared by prompt templates that can be rendered against a
/// [`PromptContext`].
pub trait TemplateRenderer {
    /// Produces the final prompt text for `context`.
    fn render(&self, context: &PromptContext) -> LangExtractResult<String>;

    /// Checks that the template itself is usable, returning an error if it is not.
    fn validate(&self) -> LangExtractResult<()>;

    /// Returns the names of the variables the template refers to.
    fn required_variables(&self) -> Vec<String>;
}
93
/// A configurable prompt template for extraction requests.
#[derive(Debug, Clone)]
pub struct PromptTemplate {
    /// Top-level prompt text; `{name}` placeholders are filled in at render time.
    pub base_template: String,
    /// Optional system message. `for_provider` sets one for OpenAI; nothing in
    /// this file reads it back — presumably consumed by provider code (confirm).
    pub system_message: Option<String>,
    /// Text used for each example; understands the `{input}`, `{output_json}`
    /// and `{output_yaml}` placeholders.
    pub example_template: String,
    /// Serialization format used for example outputs.
    pub format_type: FormatType,
    /// Provider the template was built for.
    pub provider_type: ProviderType,
    /// Upper bound on the number of examples rendered; `None` renders all of them.
    pub max_examples: Option<usize>,
    /// When `true`, a `{reasoning}` placeholder expands to a step-by-step
    /// instruction; when `false`, the placeholder is removed.
    pub include_reasoning: bool,
}
112
113impl PromptTemplate {
114 pub fn new(format_type: FormatType, provider_type: ProviderType) -> Self {
116 let base_template = Self::default_base_template(format_type, provider_type);
117 let example_template = Self::default_example_template(format_type);
118
119 Self {
120 base_template,
121 system_message: None,
122 example_template,
123 format_type,
124 provider_type,
125 max_examples: Some(5),
126 include_reasoning: false,
127 }
128 }
129
130 pub fn for_provider(provider_type: ProviderType, format_type: FormatType) -> Self {
132 let mut template = Self::new(format_type, provider_type);
133
134 match provider_type {
135 ProviderType::OpenAI => {
136 template.system_message = Some(
137 "You are an expert information extraction assistant. Extract structured information exactly as shown in the examples.".to_string()
138 );
139 template.include_reasoning = false; }
141 ProviderType::Ollama => {
142 template.include_reasoning = true; template.max_examples = Some(3); }
145 ProviderType::Custom => {
146 template.max_examples = Some(3);
148 template.include_reasoning = true;
149 }
150 }
151
152 template
153 }
154
155 pub fn with_max_examples(mut self, max: usize) -> Self {
157 self.max_examples = Some(max);
158 self
159 }
160
161 pub fn with_system_message(mut self, message: String) -> Self {
163 self.system_message = Some(message);
164 self
165 }
166
167 pub fn with_reasoning(mut self, enable: bool) -> Self {
169 self.include_reasoning = enable;
170 self
171 }
172
173 pub fn with_base_template(mut self, template: String) -> Self {
175 self.base_template = template;
176 self
177 }
178
179 fn default_base_template(format_type: FormatType, provider_type: ProviderType) -> String {
181 let format_name = match format_type {
182 FormatType::Json => "JSON",
183 FormatType::Yaml => "YAML",
184 };
185
186 let reasoning_instruction = match provider_type {
187 ProviderType::Ollama | ProviderType::Custom =>
188 "\n\nThink step by step:\n1. Read the text carefully\n2. Identify the requested information\n3. Extract it in the exact format shown in examples\n",
189 ProviderType::OpenAI => "",
190 };
191
192 format!(
193 "{{task_description}}{{additional_context}}{{examples}}{reasoning}\nNow extract information from this text:\n\nInput: {{input_text}}\n\nOutput ({format_name} format):",
194 reasoning = reasoning_instruction
195 )
196 }
197
198 fn default_example_template(format_type: FormatType) -> String {
200 match format_type {
201 FormatType::Json => {
202 "Input: {input}\nOutput: {output_json}\n".to_string()
203 }
204 FormatType::Yaml => {
205 "Input: {input}\nOutput:\n{output_yaml}\n".to_string()
206 }
207 }
208 }
209
210 fn format_examples(&self, examples: &[ExampleData]) -> LangExtractResult<String> {
212 if examples.is_empty() {
213 return Ok(String::new());
214 }
215
216 let mut formatted = String::from("\n\nExamples:\n\n");
217
218 let examples_to_use = if let Some(max) = self.max_examples {
220 &examples[..examples.len().min(max)]
221 } else {
222 examples
223 };
224
225 for (i, example) in examples_to_use.iter().enumerate() {
226 formatted.push_str(&format!("Example {}:\n", i + 1));
227
228 let output_formatted = match self.format_type {
229 FormatType::Json => self.format_example_as_json(example)?,
230 FormatType::Yaml => self.format_example_as_yaml(example)?,
231 };
232
233 let example_text = self.example_template
234 .replace("{input}", &example.text)
235 .replace("{output_json}", &output_formatted)
236 .replace("{output_yaml}", &output_formatted);
237
238 formatted.push_str(&example_text);
239 formatted.push('\n');
240 }
241
242 Ok(formatted)
243 }
244
245 fn format_example_as_json(&self, example: &ExampleData) -> LangExtractResult<String> {
247 let mut json_obj = serde_json::Map::new();
248
249 for extraction in &example.extractions {
250 json_obj.insert(
251 extraction.extraction_class.clone(),
252 serde_json::Value::String(extraction.extraction_text.clone()),
253 );
254 }
255
256 let json_value = serde_json::Value::Object(json_obj);
257 serde_json::to_string_pretty(&json_value)
258 .map_err(|e| TemplateError::ExampleError {
259 message: format!("Failed to format JSON: {}", e)
260 }.into())
261 }
262
263 fn format_example_as_yaml(&self, example: &ExampleData) -> LangExtractResult<String> {
265 let mut yaml_map = std::collections::BTreeMap::new();
266
267 for extraction in &example.extractions {
268 yaml_map.insert(
269 extraction.extraction_class.clone(),
270 extraction.extraction_text.clone(),
271 );
272 }
273
274 serde_yaml::to_string(&yaml_map)
275 .map_err(|e| TemplateError::ExampleError {
276 message: format!("Failed to format YAML: {}", e)
277 }.into())
278 }
279
280 fn substitute_variables(&self, template: &str, context: &PromptContext) -> LangExtractResult<String> {
282 let mut result = template.to_string();
283
284 result = result.replace("{task_description}", &context.task_description);
286 result = result.replace("{input_text}", &context.input_text);
287
288 if let Some(context_text) = &context.additional_context {
290 result = result.replace("{additional_context}", &format!("\n\nAdditional Context: {}\n", context_text));
291 } else {
292 result = result.replace("{additional_context}", "");
293 }
294
295 let examples_text = self.format_examples(&context.examples)?;
297 result = result.replace("{examples}", &examples_text);
298
299 if self.include_reasoning {
301 result = result.replace("{reasoning}", "\n\nPlease think through this step by step before providing your answer.");
302 } else {
303 result = result.replace("{reasoning}", "");
304 }
305
306 if let Some(hint) = &context.schema_hint {
308 result = result.replace("{schema_hint}", &format!("\n\nSchema guidance: {}\n", hint));
309 } else {
310 result = result.replace("{schema_hint}", "");
311 }
312
313 for (key, value) in &context.variables {
315 result = result.replace(&format!("{{{}}}", key), value);
316 }
317
318 Ok(result)
322 }
323}
324
325impl TemplateRenderer for PromptTemplate {
326 fn render(&self, context: &PromptContext) -> LangExtractResult<String> {
327 self.substitute_variables(&self.base_template, context)
328 }
329
330 fn validate(&self) -> LangExtractResult<()> {
331 if !self.base_template.contains("{task_description}") {
333 return Err(TemplateError::InvalidSyntax {
334 message: "Base template must contain {task_description} placeholder".to_string()
335 }.into());
336 }
337
338 if !self.base_template.contains("{input_text}") {
339 return Err(TemplateError::InvalidSyntax {
340 message: "Base template must contain {input_text} placeholder".to_string()
341 }.into());
342 }
343
344 Ok(())
345 }
346
347 fn required_variables(&self) -> Vec<String> {
348 let mut vars = vec!["task_description".to_string(), "input_text".to_string()];
349
350 let mut i = 0;
352 while i < self.base_template.len() {
353 if let Some(start) = self.base_template[i..].find('{') {
354 let start = start + i;
355 if let Some(end) = self.base_template[start..].find('}') {
356 let end = end + start;
357 let var_name = &self.base_template[start+1..end];
358 if !var_name.is_empty() && !vars.contains(&var_name.to_string()) {
359 vars.push(var_name.to_string());
360 }
361 i = end + 1;
362 } else {
363 break;
364 }
365 } else {
366 break;
367 }
368 }
369
370 vars
371 }
372}
373
/// A prompt description plus examples, rendered through an internal
/// [`PromptTemplate`].
#[derive(Debug, Clone)]
pub struct PromptTemplateStructured {
    /// Task description; an empty string is used at render time when `None`.
    pub description: Option<String>,
    /// Examples passed to the template on every render.
    pub examples: Vec<ExampleData>,
    /// Underlying template; reachable through `template()` / `template_mut()`.
    template: PromptTemplate,
}
384
385impl PromptTemplateStructured {
386 pub fn new(description: Option<&str>) -> Self {
388 Self {
389 description: description.map(|s| s.to_string()),
390 examples: Vec::new(),
391 template: PromptTemplate::new(FormatType::Json, ProviderType::Ollama),
392 }
393 }
394
395 pub fn with_format_and_provider(
397 description: Option<&str>,
398 format_type: FormatType,
399 provider_type: ProviderType,
400 ) -> Self {
401 Self {
402 description: description.map(|s| s.to_string()),
403 examples: Vec::new(),
404 template: PromptTemplate::for_provider(provider_type, format_type),
405 }
406 }
407
408 pub fn render(&self, input_text: &str, additional_context: Option<&str>) -> LangExtractResult<String> {
410 let mut context = PromptContext::new(
411 self.description.clone().unwrap_or_default(),
412 input_text.to_string(),
413 );
414
415 context.examples = self.examples.clone();
416
417 if let Some(ctx) = additional_context {
418 context.additional_context = Some(ctx.to_string());
419 }
420
421 self.template.render(&context)
422 }
423
424 pub fn template(&self) -> &PromptTemplate {
426 &self.template
427 }
428
429 pub fn template_mut(&mut self) -> &mut PromptTemplate {
431 &mut self.template
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::Extraction;

    /// The builder methods store every value they are given.
    #[test]
    fn test_prompt_context_creation() {
        let task = "Extract names".to_string();
        let input = "John is here".to_string();

        let ctx = PromptContext::new(task, input)
            .with_context("Additional info".to_string())
            .with_variable("custom".to_string(), "value".to_string())
            .with_schema_hint("Use proper format".to_string());

        assert_eq!(ctx.task_description, "Extract names");
        assert_eq!(ctx.input_text, "John is here");
        assert_eq!(ctx.additional_context.as_deref(), Some("Additional info"));
        assert_eq!(ctx.variables.get("custom").map(String::as_str), Some("value"));
        assert_eq!(ctx.schema_hint.as_deref(), Some("Use proper format"));
    }

    /// The default template validates; one without the mandatory placeholders does not.
    #[test]
    fn test_template_validation() {
        let valid = PromptTemplate::new(FormatType::Json, ProviderType::OpenAI);
        assert!(valid.validate().is_ok());

        let invalid = PromptTemplate {
            base_template: "No required placeholders".to_string(),
            ..valid.clone()
        };
        assert!(invalid.validate().is_err());
    }

    /// The default template reports the mandatory variables plus `examples`.
    #[test]
    fn test_required_variables() {
        let required = PromptTemplate::new(FormatType::Json, ProviderType::OpenAI)
            .required_variables();

        for expected in ["task_description", "input_text", "examples"] {
            assert!(required.contains(&expected.to_string()));
        }
    }

    /// JSON example formatting emits each extraction as a `"class": "text"` pair.
    #[test]
    fn test_example_formatting_json() {
        let extractions = vec![
            Extraction::new("name".to_string(), "John".to_string()),
            Extraction::new("age".to_string(), "30".to_string()),
        ];
        let sample = ExampleData::new("John is 30".to_string(), extractions);
        let json_template = PromptTemplate::new(FormatType::Json, ProviderType::OpenAI);

        let json_text = json_template.format_example_as_json(&sample).unwrap();

        assert!(json_text.contains("\"name\": \"John\""));
        assert!(json_text.contains("\"age\": \"30\""));
    }

    /// Rendering includes the task description, the input text and the format name.
    #[test]
    fn test_template_rendering() {
        let ctx = PromptContext::new(
            "Extract names and ages".to_string(),
            "Alice is 25 years old".to_string(),
        );

        let prompt = PromptTemplate::new(FormatType::Json, ProviderType::OpenAI)
            .render(&ctx)
            .unwrap();

        for fragment in ["Extract names and ages", "Alice is 25 years old", "JSON format"] {
            assert!(prompt.contains(fragment));
        }
    }

    /// `for_provider` applies the provider-specific settings.
    #[test]
    fn test_provider_specific_templates() {
        let for_openai = PromptTemplate::for_provider(ProviderType::OpenAI, FormatType::Json);
        assert!(for_openai.system_message.is_some());
        assert!(!for_openai.include_reasoning);

        let for_ollama = PromptTemplate::for_provider(ProviderType::Ollama, FormatType::Json);
        assert!(for_ollama.include_reasoning);
        assert_eq!(for_ollama.max_examples, Some(3));
    }

    /// The structured wrapper renders its description and the input text.
    #[test]
    fn test_backward_compatibility() {
        let mut structured = PromptTemplateStructured::new(Some("Extract info"));
        let sample = ExampleData::new(
            "Test".to_string(),
            vec![Extraction::new("test".to_string(), "value".to_string())],
        );
        structured.examples.push(sample);

        let prompt = structured.render("Input text", None).unwrap();

        assert!(prompt.contains("Extract info"));
        assert!(prompt.contains("Input text"));
    }
}