langextract_rust/config.rs

//! Unified configuration system for LangExtract.
//!
//! This module provides a centralized configuration system that unifies all the
//! various configuration structures used throughout the library.
//!
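//! # Example
//!
//! A minimal usage sketch (kept as an `ignore` block because the public
//! re-export path for these types may differ from this module path):
//!
//! ```ignore
//! let config = LangExtractConfig::for_ollama("mistral", None)
//!     .with_debug(true)
//!     .with_max_char_buffer(6000);
//! ```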
use crate::{
    data::FormatType,
    logging::ProgressHandler,
    providers::ProviderConfig,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;

/// The main unified configuration for LangExtract operations
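///
/// Each sub-configuration implements `Default`, so individual sections can be
/// overridden with struct-update syntax. Illustrative sketch:
///
/// ```ignore
/// let config = LangExtractConfig {
///     processing: ProcessingConfig { max_workers: 8, ..Default::default() },
///     ..LangExtractConfig::default()
/// };
/// ```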
#[derive(Clone, Serialize, Deserialize)]
pub struct LangExtractConfig {
    /// Core processing configuration
    pub processing: ProcessingConfig,
    /// Provider configuration
    pub provider: ProviderConfig,
    /// Validation and output processing
    pub validation: ValidationConfig,
    /// Text chunking configuration
    pub chunking: ChunkingConfig,
    /// Alignment configuration
    pub alignment: AlignmentConfig,
    /// Multi-pass extraction configuration
    pub multipass: MultiPassConfig,
    /// Visualization and export configuration
    pub visualization: VisualizationConfig,
    /// Inference-specific parameters
    pub inference: InferenceConfig,
    /// Progress reporting configuration (not serialized)
    #[serde(skip)]
    pub progress: ProgressConfig,
}

/// Core processing configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessingConfig {
    /// Output format type
    pub format_type: FormatType,
    /// Maximum characters per chunk for processing
    pub max_char_buffer: usize,
    /// Batch size for processing chunks
    pub batch_length: usize,
    /// Maximum number of concurrent workers
    pub max_workers: usize,
    /// Additional context for the prompt
    pub additional_context: Option<String>,
    /// Enable debug mode
    pub debug: bool,
    /// Number of extraction passes to improve recall
    pub extraction_passes: usize,
    /// Whether to wrap output in code fences
    pub fence_output: Option<bool>,
    /// Whether to use schema constraints
    pub use_schema_constraints: bool,
    /// Custom parameters for extensibility
    pub custom_params: HashMap<String, serde_json::Value>,
}

/// Configuration for validation and output processing
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationConfig {
    /// Whether to enable schema validation
    pub enable_schema_validation: bool,
    /// Whether to enable type coercion (e.g., string "25" -> number 25)
    pub enable_type_coercion: bool,
    /// Whether to require all expected fields to be present
    pub require_all_fields: bool,
    /// Whether to save raw model outputs to files
    pub save_raw_outputs: bool,
    /// Directory to save raw outputs (defaults to "./raw_outputs")
    pub raw_outputs_dir: String,
    /// Quality threshold for extractions (0.0 to 1.0)
    pub quality_threshold: f32,
}

/// Configuration for text chunking
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkingConfig {
    /// Chunking strategy to use
    pub strategy: ChunkingStrategy,
    /// Target chunk size in characters
    pub target_size: usize,
    /// Maximum chunk size in characters
    pub max_size: usize,
    /// Overlap between chunks in characters
    pub overlap: usize,
    /// Minimum chunk size in characters
    pub min_size: usize,
    /// Whether to preserve sentence boundaries
    pub preserve_sentences: bool,
    /// Whether to preserve paragraph boundaries
    pub preserve_paragraphs: bool,
}

/// Chunking strategy enumeration
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChunkingStrategy {
    /// Token-based chunking (recommended)
    Token,
    /// Semantic chunking using embeddings
    Semantic,
    /// Sentence-based chunking
    Sentence,
    /// Paragraph-based chunking
    Paragraph,
    /// Fixed character-based chunking
    Fixed,
}

/// Configuration for text alignment
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlignmentConfig {
    /// Enable fuzzy alignment when exact matching fails
    pub enable_fuzzy_alignment: bool,
    /// Minimum overlap ratio for fuzzy alignment (0.0 to 1.0)
    pub fuzzy_alignment_threshold: f32,
    /// Accept partial exact matches (MATCH_LESSER status)
    pub accept_match_lesser: bool,
    /// Case-sensitive matching
    pub case_sensitive: bool,
    /// Maximum search window size for fuzzy matching
    pub max_search_window: usize,
}

/// Configuration for multi-pass extraction
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiPassConfig {
    /// Enable multi-pass extraction for improved recall
    pub enable_multipass: bool,
    /// Number of extraction passes to perform
    pub max_passes: usize,
    /// Minimum extraction count per chunk to avoid re-processing
    pub min_extractions_per_chunk: usize,
    /// Enable targeted re-processing of low-yield chunks
    pub enable_targeted_reprocessing: bool,
    /// Enable refinement passes using previous results
    pub enable_refinement_passes: bool,
    /// Minimum quality score to keep extractions (0.0 to 1.0)
    pub quality_threshold: f32,
    /// Maximum number of chunks to re-process per pass
    pub max_reprocess_chunks: usize,
    /// Temperature adjustment for subsequent passes
    pub temperature_decay: f32,
}

/// Configuration for visualization and export
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationConfig {
    /// Default export format
    pub default_format: ExportFormat,
    /// Show character intervals in output
    pub show_char_intervals: bool,
    /// Include original text in export
    pub include_text: bool,
    /// Highlight extractions in text (for HTML/Markdown)
    pub highlight_extractions: bool,
    /// Include extraction statistics
    pub include_statistics: bool,
    /// Custom CSS for HTML export
    pub custom_css: Option<String>,
    /// Default title for exports
    pub default_title: Option<String>,
}

/// Export format enumeration
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ExportFormat {
    Text,
    Html,
    Markdown,
    Json,
    Csv,
}

/// Configuration for language model inference
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    /// Sampling temperature (0.0 to 1.0)
    pub temperature: f32,
    /// Maximum number of tokens to generate
    pub max_tokens: Option<usize>,
    /// Number of candidate outputs to generate
    pub num_candidates: usize,
    /// Stop sequences to halt generation
    pub stop_sequences: Vec<String>,
    /// Additional provider-specific parameters
    pub extra_params: HashMap<String, serde_json::Value>,
}

/// Configuration for progress reporting
#[derive(Clone)]
pub struct ProgressConfig {
    /// Progress handler for reporting extraction progress
    pub handler: Option<Arc<dyn ProgressHandler>>,
    /// Whether to show progress messages
    pub show_progress: bool,
    /// Whether to show debug information
    pub show_debug: bool,
    /// Whether to use emoji and colors in output
    pub use_styling: bool,
}

impl Default for LangExtractConfig {
    fn default() -> Self {
        Self {
            processing: ProcessingConfig::default(),
            provider: ProviderConfig::ollama("mistral", None), // Safe default
            validation: ValidationConfig::default(),
            chunking: ChunkingConfig::default(),
            alignment: AlignmentConfig::default(),
            multipass: MultiPassConfig::default(),
            visualization: VisualizationConfig::default(),
            inference: InferenceConfig::default(),
            progress: ProgressConfig::default(),
        }
    }
}

impl Default for ProcessingConfig {
    fn default() -> Self {
        Self {
            format_type: FormatType::Json,
            max_char_buffer: 8000,
            batch_length: 4,
            max_workers: 6,
            additional_context: None,
            debug: false,
            extraction_passes: 1,
            fence_output: None,
            use_schema_constraints: true,
            custom_params: HashMap::new(),
        }
    }
}

impl Default for ValidationConfig {
    fn default() -> Self {
        Self {
            enable_schema_validation: true,
            enable_type_coercion: true,
            require_all_fields: false,
            save_raw_outputs: true,
            raw_outputs_dir: "./raw_outputs".to_string(),
            quality_threshold: 0.0,
        }
    }
}

impl Default for ChunkingConfig {
    fn default() -> Self {
        Self {
            strategy: ChunkingStrategy::Token,
            target_size: 8000,
            max_size: 10000,
            overlap: 200,
            min_size: 500,
            preserve_sentences: true,
            preserve_paragraphs: true,
        }
    }
}

impl Default for AlignmentConfig {
    fn default() -> Self {
        Self {
            enable_fuzzy_alignment: true,
            fuzzy_alignment_threshold: 0.4,
            accept_match_lesser: true,
            case_sensitive: false,
            max_search_window: 100,
        }
    }
}

impl Default for MultiPassConfig {
    fn default() -> Self {
        Self {
            enable_multipass: false,
            max_passes: 2,
            min_extractions_per_chunk: 1,
            enable_targeted_reprocessing: true,
            enable_refinement_passes: true,
            quality_threshold: 0.3,
            max_reprocess_chunks: 10,
            temperature_decay: 0.9,
        }
    }
}

impl Default for VisualizationConfig {
    fn default() -> Self {
        Self {
            default_format: ExportFormat::Text,
            show_char_intervals: false,
            include_text: true,
            highlight_extractions: true,
            include_statistics: true,
            custom_css: None,
            default_title: None,
        }
    }
}

impl Default for InferenceConfig {
    fn default() -> Self {
        Self {
            temperature: 0.3,
            max_tokens: None,
            num_candidates: 1,
            stop_sequences: vec![],
            extra_params: HashMap::new(),
        }
    }
}

impl Default for ProgressConfig {
    fn default() -> Self {
        Self {
            handler: None,
            show_progress: true,
            show_debug: false,
            use_styling: true,
        }
    }
}

// Builder pattern implementation for easier configuration
impl LangExtractConfig {
    /// Create a new configuration with default values
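    ///
    /// Builder methods can be chained to adjust individual settings.
    /// Illustrative sketch:
    ///
    /// ```ignore
    /// let config = LangExtractConfig::new()
    ///     .with_workers(8)
    ///     .with_temperature(0.2);
    /// ```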
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the provider configuration
    pub fn with_provider(mut self, provider: ProviderConfig) -> Self {
        self.provider = provider;
        self
    }

    /// Set the processing configuration
    pub fn with_processing(mut self, processing: ProcessingConfig) -> Self {
        self.processing = processing;
        self
    }

    /// Set validation configuration
    pub fn with_validation(mut self, validation: ValidationConfig) -> Self {
        self.validation = validation;
        self
    }

    /// Set chunking configuration
    pub fn with_chunking(mut self, chunking: ChunkingConfig) -> Self {
        self.chunking = chunking;
        self
    }

    /// Set alignment configuration
    pub fn with_alignment(mut self, alignment: AlignmentConfig) -> Self {
        self.alignment = alignment;
        self
    }

    /// Set multi-pass configuration
    pub fn with_multipass(mut self, multipass: MultiPassConfig) -> Self {
        self.multipass = multipass;
        self
    }

    /// Set visualization configuration
    pub fn with_visualization(mut self, visualization: VisualizationConfig) -> Self {
        self.visualization = visualization;
        self
    }

    /// Set inference configuration
    pub fn with_inference(mut self, inference: InferenceConfig) -> Self {
        self.inference = inference;
        self
    }

    /// Set progress configuration
    pub fn with_progress(mut self, progress: ProgressConfig) -> Self {
        self.progress = progress;
        self
    }

    /// Enable debug mode
    pub fn with_debug(mut self, enabled: bool) -> Self {
        self.processing.debug = enabled;
        self.progress.show_debug = enabled;
        self
    }

    /// Set maximum characters per chunk (also keeps `chunking.target_size` in sync)
    pub fn with_max_char_buffer(mut self, size: usize) -> Self {
        self.processing.max_char_buffer = size;
        self.chunking.target_size = size;
        self
    }

    /// Set the number of workers
    pub fn with_workers(mut self, workers: usize) -> Self {
        self.processing.max_workers = workers;
        self
    }

    /// Set temperature for inference (values are clamped to 0.0..=1.0)
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.inference.temperature = temperature.clamp(0.0, 1.0);
        self
    }

    /// Enable multi-pass extraction
    pub fn with_multipass_enabled(mut self, enabled: bool) -> Self {
        self.multipass.enable_multipass = enabled;
        self
    }

    /// Set progress handler
    pub fn with_progress_handler(mut self, handler: Arc<dyn ProgressHandler>) -> Self {
        self.progress.handler = Some(handler);
        self
    }

    /// Enable quiet mode (no progress output)
    pub fn with_quiet_mode(mut self) -> Self {
        self.progress.show_progress = false;
        self.progress.show_debug = false;
        self
    }

    /// Enable verbose mode (show all output)
    pub fn with_verbose_mode(mut self) -> Self {
        self.progress.show_progress = true;
        self.progress.show_debug = true;
        self
    }
}

// Specialized builder methods for common configurations
impl LangExtractConfig {
    /// Create a configuration optimized for OpenAI
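    ///
    /// Illustrative sketch (reading the key from `OPENAI_API_KEY` is an assumed
    /// convention here, not something this method requires):
    ///
    /// ```ignore
    /// let config = LangExtractConfig::for_openai("gpt-4o", std::env::var("OPENAI_API_KEY").ok());
    /// ```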
    pub fn for_openai(model: &str, api_key: Option<String>) -> Self {
        Self::new()
            .with_provider(ProviderConfig::openai(model, api_key))
            .with_inference(InferenceConfig {
                temperature: 0.2,
                max_tokens: Some(2000),
                ..Default::default()
            })
    }

    /// Create a configuration optimized for Ollama
    pub fn for_ollama(model: &str, base_url: Option<String>) -> Self {
        Self::new()
            .with_provider(ProviderConfig::ollama(model, base_url))
            .with_inference(InferenceConfig {
                temperature: 0.3,
                max_tokens: Some(1500),
                ..Default::default()
            })
            .with_chunking(ChunkingConfig {
                target_size: 6000, // Smaller chunks for local models
                max_size: 8000,
                ..Default::default()
            })
    }

    /// Create a configuration for high-performance processing
    pub fn for_high_performance() -> Self {
        Self::new()
            .with_processing(ProcessingConfig {
                max_workers: 12,
                batch_length: 8,
                max_char_buffer: 10000,
                ..Default::default()
            })
            .with_multipass(MultiPassConfig {
                enable_multipass: true,
                max_passes: 3,
                ..Default::default()
            })
    }

    /// Create a configuration for memory-efficient processing
    pub fn for_memory_efficient() -> Self {
        Self::new()
            .with_processing(ProcessingConfig {
                max_workers: 4,
                batch_length: 2,
                max_char_buffer: 6000,
                ..Default::default()
            })
            .with_chunking(ChunkingConfig {
                target_size: 4000,
                max_size: 6000,
                overlap: 100,
                ..Default::default()
            })
    }
}

// Conversion traits for backward compatibility
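//
// Illustrative sketch of converting the unified config into the legacy type:
//
//     let unified = LangExtractConfig::for_ollama("mistral", None);
//     let legacy: crate::ExtractConfig = unified.into();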
impl From<LangExtractConfig> for crate::ExtractConfig {
    fn from(config: LangExtractConfig) -> Self {
        let provider_config_value = serde_json::to_value(&config.provider).unwrap_or_default();

        Self {
            model_id: config.provider.model.clone(),
            api_key: config.provider.api_key.clone(),
            format_type: config.processing.format_type,
            max_char_buffer: config.processing.max_char_buffer,
            temperature: config.inference.temperature,
            fence_output: config.processing.fence_output,
            use_schema_constraints: config.processing.use_schema_constraints,
            batch_length: config.processing.batch_length,
            max_workers: config.processing.max_workers,
            additional_context: config.processing.additional_context.clone(),
            resolver_params: HashMap::new(), // Legacy field
            language_model_params: {
                let mut params = HashMap::new();
                params.insert("provider_config".to_string(), provider_config_value);
                params
            },
            debug: config.processing.debug,
            model_url: Some(config.provider.base_url.clone()),
            extraction_passes: config.processing.extraction_passes,
            enable_multipass: config.multipass.enable_multipass,
            multipass_min_extractions: config.multipass.min_extractions_per_chunk,
            multipass_quality_threshold: config.multipass.quality_threshold,
            progress_handler: config.progress.handler,
        }
    }
}

impl std::fmt::Debug for LangExtractConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LangExtractConfig")
            .field("processing", &self.processing)
            .field("provider", &self.provider)
            .field("validation", &self.validation)
            .field("chunking", &self.chunking)
            .field("alignment", &self.alignment)
            .field("multipass", &self.multipass)
            .field("visualization", &self.visualization)
            .field("inference", &self.inference)
            .field("progress", &"<ProgressConfig>")
            .finish()
    }
}

impl std::fmt::Debug for ProgressConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProgressConfig")
            .field("handler", &"<ProgressHandler>")
            .field("show_progress", &self.show_progress)
            .field("show_debug", &self.show_debug)
            .field("use_styling", &self.use_styling)
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = LangExtractConfig::default();
        assert_eq!(config.processing.format_type, FormatType::Json);
        assert_eq!(config.processing.max_char_buffer, 8000);
        assert_eq!(config.chunking.strategy, ChunkingStrategy::Token);
    }

    #[test]
    fn test_builder_pattern() {
        let config = LangExtractConfig::new()
            .with_debug(true)
            .with_max_char_buffer(10000)
            .with_workers(8)
            .with_temperature(0.5);

        assert!(config.processing.debug);
        assert_eq!(config.processing.max_char_buffer, 10000);
        assert_eq!(config.processing.max_workers, 8);
        assert_eq!(config.inference.temperature, 0.5);
    }

    #[test]
    fn test_specialized_configs() {
        use crate::providers::ProviderType;

        let openai_config = LangExtractConfig::for_openai("gpt-4o", Some("test-key".to_string()));
        assert_eq!(openai_config.provider.provider_type, ProviderType::OpenAI);
        assert_eq!(openai_config.inference.temperature, 0.2);

        let ollama_config = LangExtractConfig::for_ollama("mistral", None);
        assert_eq!(ollama_config.provider.provider_type, ProviderType::Ollama);
        assert_eq!(ollama_config.chunking.target_size, 6000);

        let hp_config = LangExtractConfig::for_high_performance();
        assert_eq!(hp_config.processing.max_workers, 12);
        assert!(hp_config.multipass.enable_multipass);
    }

    #[test]
    fn test_backward_compatibility() {
        let new_config = LangExtractConfig::for_ollama("mistral", None)
            .with_debug(true)
            .with_temperature(0.4);

        let old_config: crate::ExtractConfig = new_config.into();
        assert_eq!(old_config.model_id, "mistral");
        assert!(old_config.debug);
        assert_eq!(old_config.temperature, 0.4);
    }

    #[test]
    fn test_serialization() {
        let config = LangExtractConfig::for_openai("gpt-4o", Some("test-key".to_string()));
        let serialized = serde_json::to_string(&config).unwrap();
        let deserialized: LangExtractConfig = serde_json::from_str(&serialized).unwrap();

        assert_eq!(config.provider.model, deserialized.provider.model);
        assert_eq!(config.processing.format_type, deserialized.processing.format_type);
634    }
635}