langextract_rust/
prompting.rs

1//! Advanced prompt template system with dynamic variables and provider adaptation.
2
3use crate::{
4    data::{ExampleData, FormatType},
5    exceptions::{LangExtractError, LangExtractResult},
6    providers::ProviderType,
7};
8use std::collections::HashMap;
9
/// Error types for template operations.
///
/// The user-facing text of each variant comes from its `#[error(...)]`
/// attribute. A `From` impl in this file converts any of these into
/// `LangExtractError::InvalidInput`, carrying that text.
#[derive(Debug, thiserror::Error)]
pub enum TemplateError {
    /// A variable needed by a template was not supplied; `variable` is its name.
    #[error("Missing required variable: {variable}")]
    MissingVariable { variable: String },
    /// The template text itself is unacceptable; `message` says why.
    #[error("Invalid template syntax: {message}")]
    InvalidSyntax { message: String },
    /// An example could not be formatted; `message` says why.
    #[error("Example formatting error: {message}")]
    ExampleError { message: String },
}
20
21impl From<TemplateError> for LangExtractError {
22    fn from(err: TemplateError) -> Self {
23        LangExtractError::InvalidInput(err.to_string())
24    }
25}
26
/// Context for rendering prompts.
///
/// Holds everything a `PromptTemplate` substitutes into its base template
/// for one rendering call. All fields are public, so a context can be built
/// either through `PromptContext::new` plus the `with_*` builders or by
/// assigning fields directly.
#[derive(Debug, Clone)]
pub struct PromptContext {
    /// Task description for what to extract (fills `{task_description}`)
    pub task_description: String,
    /// Example data to guide extraction (formatted into `{examples}`)
    pub examples: Vec<ExampleData>,
    /// Input text to process (fills `{input_text}`)
    pub input_text: String,
    /// Additional context information; when `None`, `{additional_context}`
    /// is rendered as an empty string
    pub additional_context: Option<String>,
    /// Schema hint for structured output; when `None`, `{schema_hint}` is
    /// rendered as an empty string
    pub schema_hint: Option<String>,
    /// Custom variables for template substitution. These are inserted after
    /// the built-in variables, so a custom key with a built-in name replaces
    /// the built-in value.
    pub variables: HashMap<String, String>,
}
43
44impl PromptContext {
45    /// Create a new prompt context
46    pub fn new(task_description: String, input_text: String) -> Self {
47        Self {
48            task_description,
49            input_text,
50            examples: Vec::new(),
51            additional_context: None,
52            schema_hint: None,
53            variables: HashMap::new(),
54        }
55    }
56
57    /// Add examples to the context
58    pub fn with_examples(mut self, examples: Vec<ExampleData>) -> Self {
59        self.examples = examples;
60        self
61    }
62
63    /// Add additional context
64    pub fn with_context(mut self, context: String) -> Self {
65        self.additional_context = Some(context);
66        self
67    }
68
69    /// Add a custom variable
70    pub fn with_variable(mut self, key: String, value: String) -> Self {
71        self.variables.insert(key, value);
72        self
73    }
74
75    /// Add schema hint
76    pub fn with_schema_hint(mut self, hint: String) -> Self {
77        self.schema_hint = Some(hint);
78        self
79    }
80}
81
/// Trait for rendering prompt templates.
///
/// Implemented in this file by `PromptTemplate`.
pub trait TemplateRenderer {
    /// Render the template with the given context.
    ///
    /// Returns the finished prompt text, or an error if rendering fails.
    fn render(&self, context: &PromptContext) -> LangExtractResult<String>;
    
    /// Validate the template structure.
    ///
    /// Returns `Ok(())` when the template is usable and an error describing
    /// the problem otherwise.
    fn validate(&self) -> LangExtractResult<()>;
    
    /// Get required variables for this template.
    ///
    /// Returns variable names without the surrounding braces.
    fn required_variables(&self) -> Vec<String>;
}
93
/// Advanced prompt template with dynamic variables and provider adaptation.
///
/// Placeholders in `base_template` are written as `{name}`. Rendering is
/// done through the `TemplateRenderer` impl in this file.
#[derive(Debug, Clone)]
pub struct PromptTemplate {
    /// Base template string with variable placeholders
    pub base_template: String,
    /// System message for providers that support it.
    /// NOTE(review): not read by the rendering code visible in this file —
    /// presumably consumed by the provider layer; confirm.
    pub system_message: Option<String>,
    /// Template for formatting examples.
    /// NOTE(review): example formatting in this file is delegated to
    /// `templates::ExampleFormatter` and does not read this field — confirm
    /// whether it is still used anywhere.
    pub example_template: String,
    /// Output format type
    pub format_type: FormatType,
    /// Target provider type for optimization
    pub provider_type: ProviderType,
    /// Maximum number of examples to include; `None` leaves the formatter's
    /// own limit untouched
    pub max_examples: Option<usize>,
    /// Whether to include reasoning instructions (controls the text
    /// substituted for `{reasoning}`)
    pub include_reasoning: bool,
}
112
113impl PromptTemplate {
114    /// Create a new prompt template
115    pub fn new(format_type: FormatType, provider_type: ProviderType) -> Self {
116        let base_template = Self::default_base_template(format_type, provider_type);
117        let example_template = Self::default_example_template(format_type);
118        
119        Self {
120            base_template,
121            system_message: None,
122            example_template,
123            format_type,
124            provider_type,
125            max_examples: Some(5),
126            include_reasoning: false,
127        }
128    }
129
130    /// Create template optimized for specific provider
131    pub fn for_provider(provider_type: ProviderType, format_type: FormatType) -> Self {
132        let mut template = Self::new(format_type, provider_type);
133        
134        match provider_type {
135            ProviderType::OpenAI => {
136                template.system_message = Some(
137                    "You are an expert information extraction assistant. Extract structured information exactly as shown in the examples.".to_string()
138                );
139                template.include_reasoning = false; // OpenAI is good with direct instructions
140            }
141            ProviderType::Ollama => {
142                template.include_reasoning = true; // Local models benefit from reasoning steps
143                template.max_examples = Some(3); // Keep prompts shorter for local models
144            }
145            ProviderType::Custom => {
146                // Conservative defaults for unknown providers
147                template.max_examples = Some(3);
148                template.include_reasoning = true;
149            }
150        }
151        
152        template
153    }
154
155    /// Set maximum number of examples
156    pub fn with_max_examples(mut self, max: usize) -> Self {
157        self.max_examples = Some(max);
158        self
159    }
160
161    /// Set system message
162    pub fn with_system_message(mut self, message: String) -> Self {
163        self.system_message = Some(message);
164        self
165    }
166
167    /// Enable or disable reasoning instructions
168    pub fn with_reasoning(mut self, enable: bool) -> Self {
169        self.include_reasoning = enable;
170        self
171    }
172
173    /// Set custom base template
174    pub fn with_base_template(mut self, template: String) -> Self {
175        self.base_template = template;
176        self
177    }
178
179    /// Default base template for different formats and providers
180    fn default_base_template(format_type: FormatType, provider_type: ProviderType) -> String {
181        use crate::templates::TemplateBuilder;
182        
183        let include_reasoning = matches!(provider_type, ProviderType::Ollama | ProviderType::Custom);
184        
185        TemplateBuilder::new(format_type)
186            .with_reasoning(include_reasoning)
187            .build()
188    }
189
190    /// Default example template for different formats
191    fn default_example_template(format_type: FormatType) -> String {
192        match format_type {
193            FormatType::Json => {
194                "Input: {input}\nOutput: {output_json}\n".to_string()
195            }
196            FormatType::Yaml => {
197                "Input: {input}\nOutput:\n{output_yaml}\n".to_string()
198            }
199        }
200    }
201
202    /// Format examples according to the template
203    fn format_examples(&self, examples: &[ExampleData]) -> LangExtractResult<String> {
204        use crate::templates::ExampleFormatter;
205        
206        let formatter = if let Some(max) = self.max_examples {
207            ExampleFormatter::new(self.format_type).with_max_examples(max)
208        } else {
209            ExampleFormatter::new(self.format_type)
210        };
211        
212        formatter.format_examples(examples)
213    }
214
215    // Note: format_example_as_json and format_example_as_yaml methods have been moved
216    // to the templates::ExampleFormatter to eliminate duplication
217
218    /// Substitute variables in template
219    fn substitute_variables(&self, template: &str, context: &PromptContext) -> LangExtractResult<String> {
220        use crate::templates::TemplateEngine;
221        use std::collections::HashMap;
222        
223        let mut variables = HashMap::new();
224        
225        // Built-in variables
226        variables.insert("task_description".to_string(), context.task_description.clone());
227        variables.insert("input_text".to_string(), context.input_text.clone());
228        
229        // Additional context
230        if let Some(context_text) = &context.additional_context {
231            variables.insert("additional_context".to_string(), 
232                format!("\n\nAdditional Context: {}\n", context_text));
233        } else {
234            variables.insert("additional_context".to_string(), String::new());
235        }
236
237        // Examples
238        let examples_text = self.format_examples(&context.examples)?;
239        variables.insert("examples".to_string(), examples_text);
240
241        // Reasoning section
242        if self.include_reasoning {
243            variables.insert("reasoning".to_string(), 
244                "\n\nPlease think through this step by step before providing your answer.".to_string());
245        } else {
246            variables.insert("reasoning".to_string(), String::new());
247        }
248
249        // Schema hint
250        if let Some(hint) = &context.schema_hint {
251            variables.insert("schema_hint".to_string(), 
252                format!("\n\nSchema guidance: {}\n", hint));
253        } else {
254            variables.insert("schema_hint".to_string(), String::new());
255        }
256
257        // Custom variables
258        for (key, value) in &context.variables {
259            variables.insert(key.clone(), value.clone());
260        }
261
262        // Use lenient template engine to avoid issues with JSON/YAML examples
263        let engine = TemplateEngine::lenient();
264        engine.render(template, &variables)
265    }
266}
267
268impl TemplateRenderer for PromptTemplate {
269    fn render(&self, context: &PromptContext) -> LangExtractResult<String> {
270        self.substitute_variables(&self.base_template, context)
271    }
272
273    fn validate(&self) -> LangExtractResult<()> {
274        // Check if base template has required placeholders
275        if !self.base_template.contains("{task_description}") {
276            return Err(TemplateError::InvalidSyntax { 
277                message: "Base template must contain {task_description} placeholder".to_string() 
278            }.into());
279        }
280        
281        if !self.base_template.contains("{input_text}") {
282            return Err(TemplateError::InvalidSyntax { 
283                message: "Base template must contain {input_text} placeholder".to_string() 
284            }.into());
285        }
286
287        Ok(())
288    }
289
290    fn required_variables(&self) -> Vec<String> {
291        let mut vars = vec!["task_description".to_string(), "input_text".to_string()];
292        
293        // Extract custom variables from template
294        let mut i = 0;
295        while i < self.base_template.len() {
296            if let Some(start) = self.base_template[i..].find('{') {
297                let start = start + i;
298                if let Some(end) = self.base_template[start..].find('}') {
299                    let end = end + start;
300                    let var_name = &self.base_template[start+1..end];
301                    if !var_name.is_empty() && !vars.contains(&var_name.to_string()) {
302                        vars.push(var_name.to_string());
303                    }
304                    i = end + 1;
305                } else {
306                    break;
307                }
308            } else {
309                break;
310            }
311        }
312        
313        vars
314    }
315}
316
/// Backward compatibility - simplified prompt template.
///
/// Bundles a description and a list of examples with an internal
/// `PromptTemplate` that does the actual rendering.
#[derive(Debug, Clone)]
pub struct PromptTemplateStructured {
    /// Description of what to extract; `None` renders as an empty task
    /// description
    pub description: Option<String>,
    /// Example data for guidance
    pub examples: Vec<ExampleData>,
    /// Advanced template for rendering; reachable through `template()` and
    /// `template_mut()`
    template: PromptTemplate,
}
327
328impl PromptTemplateStructured {
329    /// Create a new structured prompt template
330    pub fn new(description: Option<&str>) -> Self {
331        Self {
332            description: description.map(|s| s.to_string()),
333            examples: Vec::new(),
334            template: PromptTemplate::new(FormatType::Json, ProviderType::Ollama),
335        }
336    }
337
338    /// Create with specific format and provider
339    pub fn with_format_and_provider(
340        description: Option<&str>,
341        format_type: FormatType,
342        provider_type: ProviderType,
343    ) -> Self {
344        Self {
345            description: description.map(|s| s.to_string()),
346            examples: Vec::new(),
347            template: PromptTemplate::for_provider(provider_type, format_type),
348        }
349    }
350
351    /// Render the prompt for given text
352    pub fn render(&self, input_text: &str, additional_context: Option<&str>) -> LangExtractResult<String> {
353        let mut context = PromptContext::new(
354            self.description.clone().unwrap_or_default(),
355            input_text.to_string(),
356        );
357        
358        context.examples = self.examples.clone();
359        
360        if let Some(ctx) = additional_context {
361            context.additional_context = Some(ctx.to_string());
362        }
363
364        self.template.render(&context)
365    }
366
367    /// Get the underlying template for advanced customization
368    pub fn template(&self) -> &PromptTemplate {
369        &self.template
370    }
371
372    /// Get mutable reference to the underlying template
373    pub fn template_mut(&mut self) -> &mut PromptTemplate {
374        &mut self.template
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::Extraction;

    /// The builder chain stores every value it is given.
    #[test]
    fn test_prompt_context_creation() {
        let ctx = PromptContext::new(String::from("Extract names"), String::from("John is here"))
            .with_context(String::from("Additional info"))
            .with_variable(String::from("custom"), String::from("value"))
            .with_schema_hint(String::from("Use proper format"));

        assert_eq!(ctx.task_description, "Extract names");
        assert_eq!(ctx.input_text, "John is here");
        assert_eq!(ctx.additional_context.as_deref(), Some("Additional info"));
        assert_eq!(ctx.variables.get("custom").map(String::as_str), Some("value"));
        assert_eq!(ctx.schema_hint.as_deref(), Some("Use proper format"));
    }

    /// The default template validates; one without placeholders does not.
    #[test]
    fn test_template_validation() {
        let valid = PromptTemplate::new(FormatType::Json, ProviderType::OpenAI);
        assert!(valid.validate().is_ok());

        let broken = valid
            .clone()
            .with_base_template("No required placeholders".to_string());
        assert!(broken.validate().is_err());
    }

    /// The default template reports the built-in placeholders.
    #[test]
    fn test_required_variables() {
        let reported = PromptTemplate::new(FormatType::Json, ProviderType::OpenAI).required_variables();

        for expected in ["task_description", "input_text", "examples"].iter() {
            assert!(reported.iter().any(|name| name.as_str() == *expected));
        }
    }

    /// Rendering with one example still carries the task and the input.
    #[test]
    fn test_example_formatting_json() {
        let extractions = vec![
            Extraction::new("name".to_string(), "John".to_string()),
            Extraction::new("age".to_string(), "30".to_string()),
        ];
        let sample = ExampleData::new("John is 30".to_string(), extractions);

        // Example formatting itself is covered by the templates::ExampleFormatter
        // tests, so this only checks the rendered prompt.
        let ctx = PromptContext::new("Extract information".to_string(), "Test input".to_string())
            .with_examples(vec![sample]);
        let prompt = PromptTemplate::new(FormatType::Json, ProviderType::OpenAI)
            .render(&ctx)
            .unwrap();

        assert!(prompt.contains("Extract information"));
        assert!(prompt.contains("Test input"));
    }

    /// Rendering without examples mentions the task, the input and the format.
    #[test]
    fn test_template_rendering() {
        let ctx = PromptContext::new(
            String::from("Extract names and ages"),
            String::from("Alice is 25 years old"),
        );
        let prompt = PromptTemplate::new(FormatType::Json, ProviderType::OpenAI)
            .render(&ctx)
            .unwrap();

        for fragment in ["Extract names and ages", "Alice is 25 years old", "JSON format"].iter() {
            assert!(prompt.contains(*fragment));
        }
    }

    /// Provider presets differ in system message, reasoning and example cap.
    #[test]
    fn test_provider_specific_templates() {
        let for_openai = PromptTemplate::for_provider(ProviderType::OpenAI, FormatType::Json);
        assert!(for_openai.system_message.is_some());
        assert!(!for_openai.include_reasoning);

        let for_ollama = PromptTemplate::for_provider(ProviderType::Ollama, FormatType::Json);
        assert!(for_ollama.include_reasoning);
        assert_eq!(for_ollama.max_examples, Some(3));
    }

    /// The simplified wrapper still renders description and input.
    #[test]
    fn test_backward_compatibility() {
        let mut structured = PromptTemplateStructured::new(Some("Extract info"));
        let sample = ExampleData::new(
            "Test".to_string(),
            vec![Extraction::new("test".to_string(), "value".to_string())],
        );
        structured.examples.push(sample);

        let prompt = structured.render("Input text", None).unwrap();
        assert!(prompt.contains("Extract info"));
        assert!(prompt.contains("Input text"));
    }
}