langextract_rust/
prompting.rs

1//! Advanced prompt template system with dynamic variables and provider adaptation.
2
3use crate::{
4    data::{ExampleData, FormatType},
5    exceptions::{LangExtractError, LangExtractResult},
6    providers::ProviderType,
7};
8use std::collections::HashMap;
9
10/// Error types for template operations
11#[derive(Debug, thiserror::Error)]
12pub enum TemplateError {
13    #[error("Missing required variable: {variable}")]
14    MissingVariable { variable: String },
15    #[error("Invalid template syntax: {message}")]
16    InvalidSyntax { message: String },
17    #[error("Example formatting error: {message}")]
18    ExampleError { message: String },
19}
20
21impl From<TemplateError> for LangExtractError {
22    fn from(err: TemplateError) -> Self {
23        LangExtractError::InvalidInput(err.to_string())
24    }
25}
26
27/// Context for rendering prompts
28#[derive(Debug, Clone)]
29pub struct PromptContext {
30    /// Task description for what to extract
31    pub task_description: String,
32    /// Example data to guide extraction
33    pub examples: Vec<ExampleData>,
34    /// Input text to process
35    pub input_text: String,
36    /// Additional context information
37    pub additional_context: Option<String>,
38    /// Schema hint for structured output
39    pub schema_hint: Option<String>,
40    /// Custom variables for template substitution
41    pub variables: HashMap<String, String>,
42}
43
44impl PromptContext {
45    /// Create a new prompt context
46    pub fn new(task_description: String, input_text: String) -> Self {
47        Self {
48            task_description,
49            input_text,
50            examples: Vec::new(),
51            additional_context: None,
52            schema_hint: None,
53            variables: HashMap::new(),
54        }
55    }
56
57    /// Add examples to the context
58    pub fn with_examples(mut self, examples: Vec<ExampleData>) -> Self {
59        self.examples = examples;
60        self
61    }
62
63    /// Add additional context
64    pub fn with_context(mut self, context: String) -> Self {
65        self.additional_context = Some(context);
66        self
67    }
68
69    /// Add a custom variable
70    pub fn with_variable(mut self, key: String, value: String) -> Self {
71        self.variables.insert(key, value);
72        self
73    }
74
75    /// Add schema hint
76    pub fn with_schema_hint(mut self, hint: String) -> Self {
77        self.schema_hint = Some(hint);
78        self
79    }
80}
81
82/// Trait for rendering prompt templates
83pub trait TemplateRenderer {
84    /// Render the template with the given context
85    fn render(&self, context: &PromptContext) -> LangExtractResult<String>;
86    
87    /// Validate the template structure
88    fn validate(&self) -> LangExtractResult<()>;
89    
90    /// Get required variables for this template
91    fn required_variables(&self) -> Vec<String>;
92}
93
94/// Advanced prompt template with dynamic variables and provider adaptation
95#[derive(Debug, Clone)]
96pub struct PromptTemplate {
97    /// Base template string with variable placeholders
98    pub base_template: String,
99    /// System message for providers that support it
100    pub system_message: Option<String>,
101    /// Template for formatting examples
102    pub example_template: String,
103    /// Output format type
104    pub format_type: FormatType,
105    /// Target provider type for optimization
106    pub provider_type: ProviderType,
107    /// Maximum number of examples to include
108    pub max_examples: Option<usize>,
109    /// Whether to include reasoning instructions
110    pub include_reasoning: bool,
111}
112
113impl PromptTemplate {
114    /// Create a new prompt template
115    pub fn new(format_type: FormatType, provider_type: ProviderType) -> Self {
116        let base_template = Self::default_base_template(format_type, provider_type);
117        let example_template = Self::default_example_template(format_type);
118        
119        Self {
120            base_template,
121            system_message: None,
122            example_template,
123            format_type,
124            provider_type,
125            max_examples: Some(5),
126            include_reasoning: false,
127        }
128    }
129
130    /// Create template optimized for specific provider
131    pub fn for_provider(provider_type: ProviderType, format_type: FormatType) -> Self {
132        let mut template = Self::new(format_type, provider_type);
133        
134        match provider_type {
135            ProviderType::OpenAI => {
136                template.system_message = Some(
137                    "You are an expert information extraction assistant. Extract structured information exactly as shown in the examples.".to_string()
138                );
139                template.include_reasoning = false; // OpenAI is good with direct instructions
140            }
141            ProviderType::Ollama => {
142                template.include_reasoning = true; // Local models benefit from reasoning steps
143                template.max_examples = Some(3); // Keep prompts shorter for local models
144            }
145            ProviderType::Custom => {
146                // Conservative defaults for unknown providers
147                template.max_examples = Some(3);
148                template.include_reasoning = true;
149            }
150        }
151        
152        template
153    }
154
155    /// Set maximum number of examples
156    pub fn with_max_examples(mut self, max: usize) -> Self {
157        self.max_examples = Some(max);
158        self
159    }
160
161    /// Set system message
162    pub fn with_system_message(mut self, message: String) -> Self {
163        self.system_message = Some(message);
164        self
165    }
166
167    /// Enable or disable reasoning instructions
168    pub fn with_reasoning(mut self, enable: bool) -> Self {
169        self.include_reasoning = enable;
170        self
171    }
172
173    /// Set custom base template
174    pub fn with_base_template(mut self, template: String) -> Self {
175        self.base_template = template;
176        self
177    }
178
179    /// Default base template for different formats and providers
180    fn default_base_template(format_type: FormatType, provider_type: ProviderType) -> String {
181        let format_name = match format_type {
182            FormatType::Json => "JSON",
183            FormatType::Yaml => "YAML",
184        };
185
186        let reasoning_instruction = match provider_type {
187            ProviderType::Ollama | ProviderType::Custom => 
188                "\n\nThink step by step:\n1. Read the text carefully\n2. Identify the requested information\n3. Extract it in the exact format shown in examples\n",
189            ProviderType::OpenAI => "",
190        };
191
192        format!(
193            "{{task_description}}{{additional_context}}{{examples}}{reasoning}\nNow extract information from this text:\n\nInput: {{input_text}}\n\nOutput ({format_name} format):",
194            reasoning = reasoning_instruction
195        )
196    }
197
198    /// Default example template for different formats
199    fn default_example_template(format_type: FormatType) -> String {
200        match format_type {
201            FormatType::Json => {
202                "Input: {input}\nOutput: {output_json}\n".to_string()
203            }
204            FormatType::Yaml => {
205                "Input: {input}\nOutput:\n{output_yaml}\n".to_string()
206            }
207        }
208    }
209
210    /// Format examples according to the template
211    fn format_examples(&self, examples: &[ExampleData]) -> LangExtractResult<String> {
212        if examples.is_empty() {
213            return Ok(String::new());
214        }
215
216        let mut formatted = String::from("\n\nExamples:\n\n");
217        
218        // Limit examples if max_examples is set
219        let examples_to_use = if let Some(max) = self.max_examples {
220            &examples[..examples.len().min(max)]
221        } else {
222            examples
223        };
224
225        for (i, example) in examples_to_use.iter().enumerate() {
226            formatted.push_str(&format!("Example {}:\n", i + 1));
227            
228            let output_formatted = match self.format_type {
229                FormatType::Json => self.format_example_as_json(example)?,
230                FormatType::Yaml => self.format_example_as_yaml(example)?,
231            };
232
233            let example_text = self.example_template
234                .replace("{input}", &example.text)
235                .replace("{output_json}", &output_formatted)
236                .replace("{output_yaml}", &output_formatted);
237
238            formatted.push_str(&example_text);
239            formatted.push('\n');
240        }
241
242        Ok(formatted)
243    }
244
245    /// Format example as JSON
246    fn format_example_as_json(&self, example: &ExampleData) -> LangExtractResult<String> {
247        let mut json_obj = serde_json::Map::new();
248        
249        for extraction in &example.extractions {
250            json_obj.insert(
251                extraction.extraction_class.clone(),
252                serde_json::Value::String(extraction.extraction_text.clone()),
253            );
254        }
255
256        let json_value = serde_json::Value::Object(json_obj);
257        serde_json::to_string_pretty(&json_value)
258            .map_err(|e| TemplateError::ExampleError { 
259                message: format!("Failed to format JSON: {}", e) 
260            }.into())
261    }
262
263    /// Format example as YAML
264    fn format_example_as_yaml(&self, example: &ExampleData) -> LangExtractResult<String> {
265        let mut yaml_map = std::collections::BTreeMap::new();
266        
267        for extraction in &example.extractions {
268            yaml_map.insert(
269                extraction.extraction_class.clone(),
270                extraction.extraction_text.clone(),
271            );
272        }
273
274        serde_yaml::to_string(&yaml_map)
275            .map_err(|e| TemplateError::ExampleError { 
276                message: format!("Failed to format YAML: {}", e) 
277            }.into())
278    }
279
280    /// Substitute variables in template
281    fn substitute_variables(&self, template: &str, context: &PromptContext) -> LangExtractResult<String> {
282        let mut result = template.to_string();
283        
284        // Built-in variables
285        result = result.replace("{task_description}", &context.task_description);
286        result = result.replace("{input_text}", &context.input_text);
287        
288        // Additional context
289        if let Some(context_text) = &context.additional_context {
290            result = result.replace("{additional_context}", &format!("\n\nAdditional Context: {}\n", context_text));
291        } else {
292            result = result.replace("{additional_context}", "");
293        }
294
295        // Examples
296        let examples_text = self.format_examples(&context.examples)?;
297        result = result.replace("{examples}", &examples_text);
298
299        // Reasoning section
300        if self.include_reasoning {
301            result = result.replace("{reasoning}", "\n\nPlease think through this step by step before providing your answer.");
302        } else {
303            result = result.replace("{reasoning}", "");
304        }
305
306        // Schema hint
307        if let Some(hint) = &context.schema_hint {
308            result = result.replace("{schema_hint}", &format!("\n\nSchema guidance: {}\n", hint));
309        } else {
310            result = result.replace("{schema_hint}", "");
311        }
312
313        // Custom variables
314        for (key, value) in &context.variables {
315            result = result.replace(&format!("{{{}}}", key), value);
316        }
317
318        // Note: We skip variable validation here because JSON/YAML examples 
319        // may contain braces that look like template variables
320
321        Ok(result)
322    }
323}
324
325impl TemplateRenderer for PromptTemplate {
326    fn render(&self, context: &PromptContext) -> LangExtractResult<String> {
327        self.substitute_variables(&self.base_template, context)
328    }
329
330    fn validate(&self) -> LangExtractResult<()> {
331        // Check if base template has required placeholders
332        if !self.base_template.contains("{task_description}") {
333            return Err(TemplateError::InvalidSyntax { 
334                message: "Base template must contain {task_description} placeholder".to_string() 
335            }.into());
336        }
337        
338        if !self.base_template.contains("{input_text}") {
339            return Err(TemplateError::InvalidSyntax { 
340                message: "Base template must contain {input_text} placeholder".to_string() 
341            }.into());
342        }
343
344        Ok(())
345    }
346
347    fn required_variables(&self) -> Vec<String> {
348        let mut vars = vec!["task_description".to_string(), "input_text".to_string()];
349        
350        // Extract custom variables from template
351        let mut i = 0;
352        while i < self.base_template.len() {
353            if let Some(start) = self.base_template[i..].find('{') {
354                let start = start + i;
355                if let Some(end) = self.base_template[start..].find('}') {
356                    let end = end + start;
357                    let var_name = &self.base_template[start+1..end];
358                    if !var_name.is_empty() && !vars.contains(&var_name.to_string()) {
359                        vars.push(var_name.to_string());
360                    }
361                    i = end + 1;
362                } else {
363                    break;
364                }
365            } else {
366                break;
367            }
368        }
369        
370        vars
371    }
372}
373
374/// Backward compatibility - simplified prompt template
375#[derive(Debug, Clone)]
376pub struct PromptTemplateStructured {
377    /// Description of what to extract
378    pub description: Option<String>,
379    /// Example data for guidance
380    pub examples: Vec<ExampleData>,
381    /// Advanced template for rendering
382    template: PromptTemplate,
383}
384
385impl PromptTemplateStructured {
386    /// Create a new structured prompt template
387    pub fn new(description: Option<&str>) -> Self {
388        Self {
389            description: description.map(|s| s.to_string()),
390            examples: Vec::new(),
391            template: PromptTemplate::new(FormatType::Json, ProviderType::Ollama),
392        }
393    }
394
395    /// Create with specific format and provider
396    pub fn with_format_and_provider(
397        description: Option<&str>,
398        format_type: FormatType,
399        provider_type: ProviderType,
400    ) -> Self {
401        Self {
402            description: description.map(|s| s.to_string()),
403            examples: Vec::new(),
404            template: PromptTemplate::for_provider(provider_type, format_type),
405        }
406    }
407
408    /// Render the prompt for given text
409    pub fn render(&self, input_text: &str, additional_context: Option<&str>) -> LangExtractResult<String> {
410        let mut context = PromptContext::new(
411            self.description.clone().unwrap_or_default(),
412            input_text.to_string(),
413        );
414        
415        context.examples = self.examples.clone();
416        
417        if let Some(ctx) = additional_context {
418            context.additional_context = Some(ctx.to_string());
419        }
420
421        self.template.render(&context)
422    }
423
424    /// Get the underlying template for advanced customization
425    pub fn template(&self) -> &PromptTemplate {
426        &self.template
427    }
428
429    /// Get mutable reference to the underlying template
430    pub fn template_mut(&mut self) -> &mut PromptTemplate {
431        &mut self.template
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::Extraction;

    /// Default JSON template targeting OpenAI, shared by several tests.
    fn json_openai_template() -> PromptTemplate {
        PromptTemplate::new(FormatType::Json, ProviderType::OpenAI)
    }

    /// Shorthand for an extraction built from string literals.
    fn extraction(class: &str, text: &str) -> Extraction {
        Extraction::new(class.to_string(), text.to_string())
    }

    #[test]
    fn test_prompt_context_creation() {
        let context = PromptContext::new("Extract names".to_string(), "John is here".to_string())
            .with_context("Additional info".to_string())
            .with_variable("custom".to_string(), "value".to_string())
            .with_schema_hint("Use proper format".to_string());

        assert_eq!(context.task_description, "Extract names");
        assert_eq!(context.input_text, "John is here");
        assert_eq!(context.additional_context.as_deref(), Some("Additional info"));
        assert_eq!(
            context.variables.get("custom").map(String::as_str),
            Some("value")
        );
        assert_eq!(context.schema_hint.as_deref(), Some("Use proper format"));
    }

    #[test]
    fn test_template_validation() {
        // The default template carries both mandatory placeholders.
        assert!(json_openai_template().validate().is_ok());

        // A template without them must be rejected.
        let invalid_template =
            json_openai_template().with_base_template("No required placeholders".to_string());
        assert!(invalid_template.validate().is_err());
    }

    #[test]
    fn test_required_variables() {
        let vars = json_openai_template().required_variables();

        for expected in ["task_description", "input_text", "examples"].iter() {
            assert!(
                vars.iter().any(|name| name == expected),
                "missing variable {expected}"
            );
        }
    }

    #[test]
    fn test_example_formatting_json() {
        let example = ExampleData::new(
            "John is 30".to_string(),
            vec![extraction("name", "John"), extraction("age", "30")],
        );

        let formatted = json_openai_template()
            .format_example_as_json(&example)
            .unwrap();

        assert!(formatted.contains("\"name\": \"John\""));
        assert!(formatted.contains("\"age\": \"30\""));
    }

    #[test]
    fn test_template_rendering() {
        let context = PromptContext::new(
            "Extract names and ages".to_string(),
            "Alice is 25 years old".to_string(),
        );

        let rendered = json_openai_template().render(&context).unwrap();

        for expected in ["Extract names and ages", "Alice is 25 years old", "JSON format"].iter() {
            assert!(rendered.contains(expected), "missing text {expected:?}");
        }
    }

    #[test]
    fn test_provider_specific_templates() {
        let openai_template = PromptTemplate::for_provider(ProviderType::OpenAI, FormatType::Json);
        assert!(openai_template.system_message.is_some());
        assert!(!openai_template.include_reasoning);

        let ollama_template = PromptTemplate::for_provider(ProviderType::Ollama, FormatType::Json);
        assert!(ollama_template.include_reasoning);
        assert_eq!(ollama_template.max_examples, Some(3));
    }

    #[test]
    fn test_backward_compatibility() {
        let mut template = PromptTemplateStructured::new(Some("Extract info"));
        template.examples.push(ExampleData::new(
            "Test".to_string(),
            vec![extraction("test", "value")],
        ));

        let rendered = template.render("Input text", None).unwrap();

        assert!(rendered.contains("Extract info"));
        assert!(rendered.contains("Input text"));
    }
}