Skip to main content

langextract_rust/
config.rs

//! Unified configuration system for LangExtract.
//!
//! This module provides [`LangExtractConfig`], a single configuration type that
//! unifies the various configuration structures used throughout the library.

6use crate::{
7    data::FormatType,
8    logging::ProgressHandler,
9    providers::ProviderConfig,
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14
15/// The main unified configuration for LangExtract operations
16#[derive(Clone, Serialize, Deserialize)]
17pub struct LangExtractConfig {
18    /// Core processing configuration
19    pub processing: ProcessingConfig,
20    /// Provider configuration
21    pub provider: ProviderConfig,
22    /// Validation and output processing
23    pub validation: ValidationConfig,
24    /// Text chunking configuration  
25    pub chunking: ChunkingConfig,
26    /// Alignment configuration
27    pub alignment: AlignmentConfig,
28    /// Multi-pass extraction configuration
29    pub multipass: MultiPassConfig,
30    /// Visualization and export configuration
31    pub visualization: VisualizationConfig,
32    /// Inference-specific parameters
33    pub inference: InferenceConfig,
34    /// Progress reporting configuration (not serialized)
35    #[serde(skip)]
36    pub progress: ProgressConfig,
37}
38
39/// Core processing configuration
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct ProcessingConfig {
42    /// Output format type
43    pub format_type: FormatType,
44    /// Maximum characters per chunk for processing
45    pub max_char_buffer: usize,
46    /// Batch size for processing chunks
47    pub batch_length: usize,
48    /// Maximum number of concurrent workers
49    pub max_workers: usize,
50    /// Additional context for the prompt
51    pub additional_context: Option<String>,
52    /// Enable debug mode
53    pub debug: bool,
54    /// Whether to wrap output in code fences
55    pub fence_output: Option<bool>,
56    /// Whether to use schema constraints
57    pub use_schema_constraints: bool,
58    /// Custom parameters for extensibility
59    pub custom_params: HashMap<String, serde_json::Value>,
60}
61
62/// Configuration for validation and output processing
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct ValidationConfig {
65    /// Whether to enable schema validation
66    pub enable_schema_validation: bool,
67    /// Whether to enable type coercion (e.g., string "25" -> number 25)
68    pub enable_type_coercion: bool,
69    /// Whether to require all expected fields to be present
70    pub require_all_fields: bool,
71    /// Whether to save raw model outputs to files
72    pub save_raw_outputs: bool,
73    /// Directory to save raw outputs (defaults to "./raw_outputs")
74    pub raw_outputs_dir: String,
75    /// Quality threshold for extractions (0.0 to 1.0)
76    pub quality_threshold: f32,
77}
78
79/// Configuration for text chunking
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct ChunkingConfig {
82    /// Chunking strategy to use
83    pub strategy: ChunkingStrategy,
84    /// Target chunk size in characters
85    pub target_size: usize,
86    /// Maximum chunk size in characters
87    pub max_size: usize,
88    /// Overlap between chunks in characters
89    pub overlap: usize,
90    /// Minimum chunk size in characters
91    pub min_size: usize,
92    /// Whether to preserve sentence boundaries
93    pub preserve_sentences: bool,
94    /// Whether to preserve paragraph boundaries
95    pub preserve_paragraphs: bool,
96}
97
98/// Chunking strategy enumeration
99#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
100#[serde(rename_all = "snake_case")]
101pub enum ChunkingStrategy {
102    /// Token-based chunking (recommended)
103    Token,
104    /// Semantic chunking using embeddings
105    Semantic,
106    /// Sentence-based chunking
107    Sentence,
108    /// Paragraph-based chunking
109    Paragraph,
110    /// Fixed character-based chunking
111    Fixed,
112}
113
114/// Configuration for text alignment
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct AlignmentConfig {
117    /// Enable fuzzy alignment when exact matching fails
118    pub enable_fuzzy_alignment: bool,
119    /// Minimum overlap ratio for fuzzy alignment (0.0 to 1.0)
120    pub fuzzy_alignment_threshold: f32,
121    /// Accept partial exact matches (MATCH_LESSER status)
122    pub accept_match_lesser: bool,
123    /// Case-sensitive matching
124    pub case_sensitive: bool,
125    /// Maximum search window size for fuzzy matching
126    pub max_search_window: usize,
127}
128
129/// Configuration for multi-pass extraction
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct MultiPassConfig {
132    /// Enable multi-pass extraction for improved recall
133    pub enable_multipass: bool,
134    /// Number of extraction passes to perform
135    pub max_passes: usize,
136    /// Minimum extraction count per chunk to avoid re-processing
137    pub min_extractions_per_chunk: usize,
138    /// Enable targeted re-processing of low-yield chunks
139    pub enable_targeted_reprocessing: bool,
140    /// Enable refinement passes using previous results
141    pub enable_refinement_passes: bool,
142    /// Minimum quality score to keep extractions (0.0 to 1.0)
143    pub quality_threshold: f32,
144    /// Maximum number of chunks to re-process per pass
145    pub max_reprocess_chunks: usize,
146    /// Temperature adjustment for subsequent passes
147    pub temperature_decay: f32,
148}
149
150/// Configuration for visualization and export
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct VisualizationConfig {
153    /// Default export format
154    pub default_format: ExportFormat,
155    /// Show character intervals in output
156    pub show_char_intervals: bool,
157    /// Include original text in export
158    pub include_text: bool,
159    /// Highlight extractions in text (for HTML/Markdown)
160    pub highlight_extractions: bool,
161    /// Include extraction statistics
162    pub include_statistics: bool,
163    /// Custom CSS for HTML export
164    pub custom_css: Option<String>,
165    /// Default title for exports
166    pub default_title: Option<String>,
167}
168
169/// Export format enumeration
170#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
171#[serde(rename_all = "snake_case")]
172pub enum ExportFormat {
173    Text,
174    Html,
175    Markdown,
176    Json,
177    Csv,
178}
179
180/// Configuration for language model inference
181#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct InferenceConfig {
183    /// Sampling temperature (0.0 to 1.0)
184    pub temperature: f32,
185    /// Maximum number of tokens to generate
186    pub max_tokens: Option<usize>,
187    /// Number of candidate outputs to generate
188    pub num_candidates: usize,
189    /// Stop sequences to halt generation
190    pub stop_sequences: Vec<String>,
191    /// Additional provider-specific parameters
192    pub extra_params: HashMap<String, serde_json::Value>,
193}
194
195/// Configuration for progress reporting
196#[derive(Clone)]
197pub struct ProgressConfig {
198    /// Progress handler for reporting extraction progress
199    pub handler: Option<Arc<dyn ProgressHandler>>,
200    /// Whether to show progress messages
201    pub show_progress: bool,
202    /// Whether to show debug information
203    pub show_debug: bool,
204    /// Whether to use emoji and colors in output
205    pub use_styling: bool,
206}
207
208impl Default for LangExtractConfig {
209    fn default() -> Self {
210        Self {
211            processing: ProcessingConfig::default(),
212            provider: ProviderConfig::ollama("mistral", None), // Safe default
213            validation: ValidationConfig::default(),
214            chunking: ChunkingConfig::default(),
215            alignment: AlignmentConfig::default(),
216            multipass: MultiPassConfig::default(),
217            visualization: VisualizationConfig::default(),
218            inference: InferenceConfig::default(),
219            progress: ProgressConfig::default(),
220        }
221    }
222}
223
224impl Default for ProcessingConfig {
225    fn default() -> Self {
226        Self {
227            format_type: FormatType::Json,
228            max_char_buffer: 8000,
229            batch_length: 4,
230            max_workers: 6,
231            additional_context: None,
232            debug: false,
233            fence_output: None,
234            use_schema_constraints: true,
235            custom_params: HashMap::new(),
236        }
237    }
238}
239
240impl Default for ValidationConfig {
241    fn default() -> Self {
242        Self {
243            enable_schema_validation: true,
244            enable_type_coercion: true,
245            require_all_fields: false,
246            save_raw_outputs: true,
247            raw_outputs_dir: "./raw_outputs".to_string(),
248            quality_threshold: 0.0,
249        }
250    }
251}
252
253impl Default for ChunkingConfig {
254    fn default() -> Self {
255        Self {
256            strategy: ChunkingStrategy::Token,
257            target_size: 8000,
258            max_size: 10000,
259            overlap: 200,
260            min_size: 500,
261            preserve_sentences: true,
262            preserve_paragraphs: true,
263        }
264    }
265}
266
267impl Default for AlignmentConfig {
268    fn default() -> Self {
269        Self {
270            enable_fuzzy_alignment: true,
271            fuzzy_alignment_threshold: 0.4,
272            accept_match_lesser: true,
273            case_sensitive: false,
274            max_search_window: 100,
275        }
276    }
277}
278
279impl Default for MultiPassConfig {
280    fn default() -> Self {
281        Self {
282            enable_multipass: false,
283            max_passes: 2,
284            min_extractions_per_chunk: 1,
285            enable_targeted_reprocessing: true,
286            enable_refinement_passes: true,
287            quality_threshold: 0.3,
288            max_reprocess_chunks: 10,
289            temperature_decay: 0.9,
290        }
291    }
292}
293
294impl Default for VisualizationConfig {
295    fn default() -> Self {
296        Self {
297            default_format: ExportFormat::Text,
298            show_char_intervals: false,
299            include_text: true,
300            highlight_extractions: true,
301            include_statistics: true,
302            custom_css: None,
303            default_title: None,
304        }
305    }
306}
307
308impl Default for InferenceConfig {
309    fn default() -> Self {
310        Self {
311            temperature: 0.3,
312            max_tokens: None,
313            num_candidates: 1,
314            stop_sequences: vec![],
315            extra_params: HashMap::new(),
316        }
317    }
318}
319
320impl Default for ProgressConfig {
321    fn default() -> Self {
322        Self {
323            handler: None,
324            show_progress: true,
325            show_debug: false,
326            use_styling: true,
327        }
328    }
329}
330
331// Builder pattern implementation for easier configuration
332impl LangExtractConfig {
333    /// Create a new configuration with default values
334    pub fn new() -> Self {
335        Self::default()
336    }
337
338    /// Set the provider configuration
339    pub fn with_provider(mut self, provider: ProviderConfig) -> Self {
340        self.provider = provider;
341        self
342    }
343
344    /// Set the processing configuration
345    pub fn with_processing(mut self, processing: ProcessingConfig) -> Self {
346        self.processing = processing;
347        self
348    }
349
350    /// Set validation configuration
351    pub fn with_validation(mut self, validation: ValidationConfig) -> Self {
352        self.validation = validation;
353        self
354    }
355
356    /// Set chunking configuration
357    pub fn with_chunking(mut self, chunking: ChunkingConfig) -> Self {
358        self.chunking = chunking;
359        self
360    }
361
362    /// Set alignment configuration
363    pub fn with_alignment(mut self, alignment: AlignmentConfig) -> Self {
364        self.alignment = alignment;
365        self
366    }
367
368    /// Set multi-pass configuration
369    pub fn with_multipass(mut self, multipass: MultiPassConfig) -> Self {
370        self.multipass = multipass;
371        self
372    }
373
374    /// Set visualization configuration
375    pub fn with_visualization(mut self, visualization: VisualizationConfig) -> Self {
376        self.visualization = visualization;
377        self
378    }
379
380    /// Set inference configuration
381    pub fn with_inference(mut self, inference: InferenceConfig) -> Self {
382        self.inference = inference;
383        self
384    }
385
386    /// Set progress configuration
387    pub fn with_progress(mut self, progress: ProgressConfig) -> Self {
388        self.progress = progress;
389        self
390    }
391
392    /// Enable debug mode
393    pub fn with_debug(mut self, enabled: bool) -> Self {
394        self.processing.debug = enabled;
395        self.progress.show_debug = enabled;
396        self
397    }
398
399    /// Set maximum characters per chunk
400    pub fn with_max_char_buffer(mut self, size: usize) -> Self {
401        self.processing.max_char_buffer = size;
402        self.chunking.target_size = size;
403        self
404    }
405
406    /// Set the number of workers
407    pub fn with_workers(mut self, workers: usize) -> Self {
408        self.processing.max_workers = workers;
409        self
410    }
411
412    /// Set temperature for inference
413    pub fn with_temperature(mut self, temperature: f32) -> Self {
414        self.inference.temperature = temperature.clamp(0.0, 1.0);
415        self
416    }
417
418    /// Enable multi-pass extraction
419    pub fn with_multipass_enabled(mut self, enabled: bool) -> Self {
420        self.multipass.enable_multipass = enabled;
421        self
422    }
423
424    /// Set progress handler
425    pub fn with_progress_handler(mut self, handler: Arc<dyn ProgressHandler>) -> Self {
426        self.progress.handler = Some(handler);
427        self
428    }
429
430    /// Enable quiet mode (no progress output)
431    pub fn with_quiet_mode(mut self) -> Self {
432        self.progress.show_progress = false;
433        self.progress.show_debug = false;
434        self
435    }
436
437    /// Enable verbose mode (show all output)
438    pub fn with_verbose_mode(mut self) -> Self {
439        self.progress.show_progress = true;
440        self.progress.show_debug = true;
441        self
442    }
443}
444
445// Specialized builder methods for common configurations
446impl LangExtractConfig {
447    /// Create a configuration optimized for OpenAI
448    pub fn for_openai(model: &str, api_key: Option<String>) -> Self {
449        Self::new()
450            .with_provider(ProviderConfig::openai(model, api_key))
451            .with_inference(InferenceConfig {
452                temperature: 0.2,
453                max_tokens: Some(2000),
454                ..Default::default()
455            })
456    }
457
458    /// Create a configuration optimized for Ollama
459    pub fn for_ollama(model: &str, base_url: Option<String>) -> Self {
460        Self::new()
461            .with_provider(ProviderConfig::ollama(model, base_url))
462            .with_inference(InferenceConfig {
463                temperature: 0.3,
464                max_tokens: Some(1500),
465                ..Default::default()
466            })
467            .with_chunking(ChunkingConfig {
468                target_size: 6000, // Smaller chunks for local models
469                max_size: 8000,
470                ..Default::default()
471            })
472    }
473
474    /// Create a configuration for high-performance processing
475    pub fn for_high_performance() -> Self {
476        Self::new()
477            .with_processing(ProcessingConfig {
478                max_workers: 12,
479                batch_length: 8,
480                max_char_buffer: 10000,
481                ..Default::default()
482            })
483            .with_multipass(MultiPassConfig {
484                enable_multipass: true,
485                max_passes: 3,
486                ..Default::default()
487            })
488    }
489
490    /// Create a configuration for memory-efficient processing
491    pub fn for_memory_efficient() -> Self {
492        Self::new()
493            .with_processing(ProcessingConfig {
494                max_workers: 4,
495                batch_length: 2,
496                max_char_buffer: 6000,
497                ..Default::default()
498            })
499            .with_chunking(ChunkingConfig {
500                target_size: 4000,
501                max_size: 6000,
502                overlap: 100,
503                ..Default::default()
504            })
505    }
506}
507
508// Conversion traits for backward compatibility
509impl From<LangExtractConfig> for crate::ExtractConfig {
510    fn from(config: LangExtractConfig) -> Self {
511        let provider_config_value = serde_json::to_value(&config.provider).unwrap_or_default();
512        
513        Self {
514            model_id: config.provider.model.clone(),
515            api_key: config.provider.api_key.clone(),
516            format_type: config.processing.format_type,
517            max_char_buffer: config.processing.max_char_buffer,
518            temperature: config.inference.temperature,
519            fence_output: config.processing.fence_output,
520            use_schema_constraints: config.processing.use_schema_constraints,
521            batch_length: config.processing.batch_length,
522            max_workers: config.processing.max_workers,
523            additional_context: config.processing.additional_context.clone(),
524            resolver_params: HashMap::new(), // Legacy field
525            language_model_params: {
526                let mut params = HashMap::new();
527                params.insert("provider_config".to_string(), provider_config_value);
528                params
529            },
530            debug: config.processing.debug,
531            model_url: Some(config.provider.base_url.clone()),
532            enable_multipass: config.multipass.enable_multipass,
533            multipass_max_passes: config.multipass.max_passes,
534            multipass_min_extractions: config.multipass.min_extractions_per_chunk,
535            multipass_quality_threshold: config.multipass.quality_threshold,
536            progress_handler: config.progress.handler,
537        }
538    }
539}
540
541impl std::fmt::Debug for LangExtractConfig {
542    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
543        f.debug_struct("LangExtractConfig")
544            .field("processing", &self.processing)
545            .field("provider", &self.provider)
546            .field("validation", &self.validation)
547            .field("chunking", &self.chunking)
548            .field("alignment", &self.alignment)
549            .field("multipass", &self.multipass)
550            .field("visualization", &self.visualization)
551            .field("inference", &self.inference)
552            .field("progress", &"<ProgressConfig>")
553            .finish()
554    }
555}
556
557impl std::fmt::Debug for ProgressConfig {
558    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
559        f.debug_struct("ProgressConfig")
560            .field("handler", &"<ProgressHandler>")
561            .field("show_progress", &self.show_progress)
562            .field("show_debug", &self.show_debug)
563            .field("use_styling", &self.use_styling)
564            .finish()
565    }
566}
567
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::ProviderType;

    #[test]
    fn test_default_config() {
        let config = LangExtractConfig::default();
        assert_eq!(config.processing.format_type, FormatType::Json);
        assert_eq!(config.processing.max_char_buffer, 8000);
        assert_eq!(config.chunking.strategy, ChunkingStrategy::Token);
        assert_eq!(config.provider.provider_type, ProviderType::Ollama);
        assert_eq!(config.provider.model, "mistral");
        // The default chunk bounds must be mutually consistent.
        assert!(config.chunking.min_size <= config.chunking.target_size);
        assert!(config.chunking.target_size <= config.chunking.max_size);
    }

    #[test]
    fn test_builder_pattern() {
        let config = LangExtractConfig::new()
            .with_debug(true)
            .with_max_char_buffer(10000)
            .with_workers(8)
            .with_temperature(0.5);

        assert!(config.processing.debug);
        assert_eq!(config.processing.max_char_buffer, 10000);
        assert_eq!(config.chunking.target_size, 10000);
        assert_eq!(config.processing.max_workers, 8);
        assert_eq!(config.inference.temperature, 0.5);
    }

    #[test]
    fn test_temperature_is_clamped() {
        let high = LangExtractConfig::new().with_temperature(1.5);
        assert_eq!(high.inference.temperature, 1.0);

        let low = LangExtractConfig::new().with_temperature(-0.2);
        assert_eq!(low.inference.temperature, 0.0);
    }

    #[test]
    fn test_debug_flag_keeps_progress_in_sync() {
        let config = LangExtractConfig::new().with_debug(true);
        assert!(config.progress.show_debug);

        let config = config.with_debug(false);
        assert!(!config.processing.debug);
        assert!(!config.progress.show_debug);
    }

    #[test]
    fn test_quiet_and_verbose_modes() {
        let quiet = LangExtractConfig::new().with_verbose_mode().with_quiet_mode();
        assert!(!quiet.progress.show_progress);
        assert!(!quiet.progress.show_debug);

        let verbose = LangExtractConfig::new().with_quiet_mode().with_verbose_mode();
        assert!(verbose.progress.show_progress);
        assert!(verbose.progress.show_debug);
    }

    #[test]
    fn test_specialized_configs() {
        let openai_config = LangExtractConfig::for_openai("gpt-4o", Some("test-key".to_string()));
        assert_eq!(openai_config.provider.provider_type, ProviderType::OpenAI);
        assert_eq!(openai_config.inference.temperature, 0.2);

        let ollama_config = LangExtractConfig::for_ollama("mistral", None);
        assert_eq!(ollama_config.provider.provider_type, ProviderType::Ollama);
        assert_eq!(ollama_config.chunking.target_size, 6000);

        let hp_config = LangExtractConfig::for_high_performance();
        assert_eq!(hp_config.processing.max_workers, 12);
        assert!(hp_config.multipass.enable_multipass);
    }

    #[test]
    fn test_backward_compatibility() {
        let new_config = LangExtractConfig::for_ollama("mistral", None)
            .with_debug(true)
            .with_temperature(0.4);

        let old_config: crate::ExtractConfig = new_config.into();
        assert_eq!(old_config.model_id, "mistral");
        assert!(old_config.debug);
        assert_eq!(old_config.temperature, 0.4);
        assert_eq!(old_config.max_char_buffer, 8000);
        assert!(old_config.language_model_params.contains_key("provider_config"));
    }

    #[test]
    fn test_serialization() {
        let config = LangExtractConfig::for_openai("gpt-4o", Some("test-key".to_string()));
        let serialized = serde_json::to_string(&config).unwrap();
        let deserialized: LangExtractConfig = serde_json::from_str(&serialized).unwrap();

        assert_eq!(config.provider.model, deserialized.provider.model);
        assert_eq!(config.processing.format_type, deserialized.processing.format_type);
    }

    #[test]
    fn test_progress_is_not_serialized() {
        let config = LangExtractConfig::new().with_quiet_mode();
        let value = serde_json::to_value(&config).unwrap();
        assert!(value.get("progress").is_none());

        // `#[serde(skip)]` fields come back as `ProgressConfig::default()`.
        let restored: LangExtractConfig = serde_json::from_value(value).unwrap();
        assert!(restored.progress.show_progress);
        assert!(restored.progress.handler.is_none());
    }
}