use crate::{
    data::FormatType,
    logging::ProgressHandler,
    providers::ProviderConfig,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;

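/// Top-level configuration grouping the per-concern sub-configs below.
///
/// A minimal usage sketch (illustrative only; assumes this type is reachable from
/// your crate root, adjust the path to your setup):
///
/// ```ignore
/// let config = LangExtractConfig::for_ollama("mistral", None)
///     .with_debug(true)
///     .with_workers(8);
/// ```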
#[derive(Clone, Serialize, Deserialize)]
pub struct LangExtractConfig {
    pub processing: ProcessingConfig,
    pub provider: ProviderConfig,
    pub validation: ValidationConfig,
    pub chunking: ChunkingConfig,
    pub alignment: AlignmentConfig,
    pub multipass: MultiPassConfig,
    pub visualization: VisualizationConfig,
    pub inference: InferenceConfig,
    #[serde(skip)]
    pub progress: ProgressConfig,
}

/// Core text-processing options: output format, buffering, batching, and worker counts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessingConfig {
    pub format_type: FormatType,
    pub max_char_buffer: usize,
    pub batch_length: usize,
    pub max_workers: usize,
    pub additional_context: Option<String>,
    pub debug: bool,
    pub extraction_passes: usize,
    pub fence_output: Option<bool>,
    pub use_schema_constraints: bool,
    pub custom_params: HashMap<String, serde_json::Value>,
}

/// Controls schema validation, type coercion, and raw-output capture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationConfig {
    pub enable_schema_validation: bool,
    pub enable_type_coercion: bool,
    pub require_all_fields: bool,
    pub save_raw_outputs: bool,
    pub raw_outputs_dir: String,
    pub quality_threshold: f32,
}

/// Controls how input text is split into chunks before extraction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkingConfig {
    pub strategy: ChunkingStrategy,
    pub target_size: usize,
    pub max_size: usize,
    pub overlap: usize,
    pub min_size: usize,
    pub preserve_sentences: bool,
    pub preserve_paragraphs: bool,
}

/// Strategy used to split input text into chunks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChunkingStrategy {
    Token,
    Semantic,
    Sentence,
    Paragraph,
    Fixed,
}

/// Controls how extractions are aligned back to positions in the source text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlignmentConfig {
    pub enable_fuzzy_alignment: bool,
    pub fuzzy_alignment_threshold: f32,
    pub accept_match_lesser: bool,
    pub case_sensitive: bool,
    pub max_search_window: usize,
}

/// Controls optional multi-pass extraction and targeted reprocessing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiPassConfig {
    pub enable_multipass: bool,
    pub max_passes: usize,
    pub min_extractions_per_chunk: usize,
    pub enable_targeted_reprocessing: bool,
    pub enable_refinement_passes: bool,
    pub quality_threshold: f32,
    pub max_reprocess_chunks: usize,
    pub temperature_decay: f32,
}

/// Controls how extraction results are exported and visualized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationConfig {
    pub default_format: ExportFormat,
    pub show_char_intervals: bool,
    pub include_text: bool,
    pub highlight_extractions: bool,
    pub include_statistics: bool,
    pub custom_css: Option<String>,
    pub default_title: Option<String>,
}

/// Output format for exported extraction results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ExportFormat {
    Text,
    Html,
    Markdown,
    Json,
    Csv,
}

/// Model inference parameters passed through to the provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    pub temperature: f32,
    pub max_tokens: Option<usize>,
    pub num_candidates: usize,
    pub stop_sequences: Vec<String>,
    pub extra_params: HashMap<String, serde_json::Value>,
}

/// Progress-reporting settings; not serialized.
#[derive(Clone)]
pub struct ProgressConfig {
    pub handler: Option<Arc<dyn ProgressHandler>>,
    pub show_progress: bool,
    pub show_debug: bool,
    pub use_styling: bool,
}

impl Default for LangExtractConfig {
    fn default() -> Self {
        Self {
            processing: ProcessingConfig::default(),
            provider: ProviderConfig::ollama("mistral", None),
            validation: ValidationConfig::default(),
            chunking: ChunkingConfig::default(),
            alignment: AlignmentConfig::default(),
            multipass: MultiPassConfig::default(),
            visualization: VisualizationConfig::default(),
            inference: InferenceConfig::default(),
            progress: ProgressConfig::default(),
        }
    }
}

impl Default for ProcessingConfig {
    fn default() -> Self {
        Self {
            format_type: FormatType::Json,
            max_char_buffer: 8000,
            batch_length: 4,
            max_workers: 6,
            additional_context: None,
            debug: false,
            extraction_passes: 1,
            fence_output: None,
            use_schema_constraints: true,
            custom_params: HashMap::new(),
        }
    }
}

impl Default for ValidationConfig {
    fn default() -> Self {
        Self {
            enable_schema_validation: true,
            enable_type_coercion: true,
            require_all_fields: false,
            save_raw_outputs: true,
            raw_outputs_dir: "./raw_outputs".to_string(),
            quality_threshold: 0.0,
        }
    }
}

impl Default for ChunkingConfig {
    fn default() -> Self {
        Self {
            strategy: ChunkingStrategy::Token,
            target_size: 8000,
            max_size: 10000,
            overlap: 200,
            min_size: 500,
            preserve_sentences: true,
            preserve_paragraphs: true,
        }
    }
}

impl Default for AlignmentConfig {
    fn default() -> Self {
        Self {
            enable_fuzzy_alignment: true,
            fuzzy_alignment_threshold: 0.4,
            accept_match_lesser: true,
            case_sensitive: false,
            max_search_window: 100,
        }
    }
}

impl Default for MultiPassConfig {
    fn default() -> Self {
        Self {
            enable_multipass: false,
            max_passes: 2,
            min_extractions_per_chunk: 1,
            enable_targeted_reprocessing: true,
            enable_refinement_passes: true,
            quality_threshold: 0.3,
            max_reprocess_chunks: 10,
            temperature_decay: 0.9,
        }
    }
}

impl Default for VisualizationConfig {
    fn default() -> Self {
        Self {
            default_format: ExportFormat::Text,
            show_char_intervals: false,
            include_text: true,
            highlight_extractions: true,
            include_statistics: true,
            custom_css: None,
            default_title: None,
        }
    }
}

impl Default for InferenceConfig {
    fn default() -> Self {
        Self {
            temperature: 0.3,
            max_tokens: None,
            num_candidates: 1,
            stop_sequences: vec![],
            extra_params: HashMap::new(),
        }
    }
}

impl Default for ProgressConfig {
    fn default() -> Self {
        Self {
            handler: None,
            show_progress: true,
            show_debug: false,
            use_styling: true,
        }
    }
}

impl LangExtractConfig {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_provider(mut self, provider: ProviderConfig) -> Self {
        self.provider = provider;
        self
    }

    pub fn with_processing(mut self, processing: ProcessingConfig) -> Self {
        self.processing = processing;
        self
    }

    pub fn with_validation(mut self, validation: ValidationConfig) -> Self {
        self.validation = validation;
        self
    }

    pub fn with_chunking(mut self, chunking: ChunkingConfig) -> Self {
        self.chunking = chunking;
        self
    }

    pub fn with_alignment(mut self, alignment: AlignmentConfig) -> Self {
        self.alignment = alignment;
        self
    }

    pub fn with_multipass(mut self, multipass: MultiPassConfig) -> Self {
        self.multipass = multipass;
        self
    }

    pub fn with_visualization(mut self, visualization: VisualizationConfig) -> Self {
        self.visualization = visualization;
        self
    }

    pub fn with_inference(mut self, inference: InferenceConfig) -> Self {
        self.inference = inference;
        self
    }

    pub fn with_progress(mut self, progress: ProgressConfig) -> Self {
        self.progress = progress;
        self
    }

    pub fn with_debug(mut self, enabled: bool) -> Self {
        self.processing.debug = enabled;
        self.progress.show_debug = enabled;
        self
    }

    pub fn with_max_char_buffer(mut self, size: usize) -> Self {
        self.processing.max_char_buffer = size;
        self.chunking.target_size = size;
        self
    }

    pub fn with_workers(mut self, workers: usize) -> Self {
        self.processing.max_workers = workers;
        self
    }

    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.inference.temperature = temperature.clamp(0.0, 1.0);
        self
    }

    pub fn with_multipass_enabled(mut self, enabled: bool) -> Self {
        self.multipass.enable_multipass = enabled;
        self
    }

    pub fn with_progress_handler(mut self, handler: Arc<dyn ProgressHandler>) -> Self {
        self.progress.handler = Some(handler);
        self
    }

    pub fn with_quiet_mode(mut self) -> Self {
        self.progress.show_progress = false;
        self.progress.show_debug = false;
        self
    }

    pub fn with_verbose_mode(mut self) -> Self {
        self.progress.show_progress = true;
        self.progress.show_debug = true;
        self
    }
}

impl LangExtractConfig {
    /// Preset tuned for OpenAI models.
    pub fn for_openai(model: &str, api_key: Option<String>) -> Self {
        Self::new()
            .with_provider(ProviderConfig::openai(model, api_key))
            .with_inference(InferenceConfig {
                temperature: 0.2,
                max_tokens: Some(2000),
                ..Default::default()
            })
    }

    /// Preset tuned for local Ollama models.
    pub fn for_ollama(model: &str, base_url: Option<String>) -> Self {
        Self::new()
            .with_provider(ProviderConfig::ollama(model, base_url))
            .with_inference(InferenceConfig {
                temperature: 0.3,
                max_tokens: Some(1500),
                ..Default::default()
            })
            .with_chunking(ChunkingConfig {
                target_size: 6000,
                max_size: 8000,
                ..Default::default()
            })
    }

    /// Preset favoring throughput: more workers, larger batches, multi-pass enabled.
    pub fn for_high_performance() -> Self {
        Self::new()
            .with_processing(ProcessingConfig {
                max_workers: 12,
                batch_length: 8,
                max_char_buffer: 10000,
                ..Default::default()
            })
            .with_multipass(MultiPassConfig {
                enable_multipass: true,
                max_passes: 3,
                ..Default::default()
            })
    }

    /// Preset favoring low memory use: fewer workers, smaller batches and chunks.
    pub fn for_memory_efficient() -> Self {
        Self::new()
            .with_processing(ProcessingConfig {
                max_workers: 4,
                batch_length: 2,
                max_char_buffer: 6000,
                ..Default::default()
            })
            .with_chunking(ChunkingConfig {
                target_size: 4000,
                max_size: 6000,
                overlap: 100,
                ..Default::default()
            })
    }
}

impl From<LangExtractConfig> for crate::ExtractConfig {
    fn from(config: LangExtractConfig) -> Self {
        let provider_config_value = serde_json::to_value(&config.provider).unwrap_or_default();

        Self {
            model_id: config.provider.model.clone(),
            api_key: config.provider.api_key.clone(),
            format_type: config.processing.format_type,
            max_char_buffer: config.processing.max_char_buffer,
            temperature: config.inference.temperature,
            fence_output: config.processing.fence_output,
            use_schema_constraints: config.processing.use_schema_constraints,
            batch_length: config.processing.batch_length,
            max_workers: config.processing.max_workers,
            additional_context: config.processing.additional_context.clone(),
            resolver_params: HashMap::new(),
            language_model_params: {
                let mut params = HashMap::new();
                params.insert("provider_config".to_string(), provider_config_value);
                params
            },
            debug: config.processing.debug,
            model_url: Some(config.provider.base_url.clone()),
            extraction_passes: config.processing.extraction_passes,
            enable_multipass: config.multipass.enable_multipass,
            multipass_min_extractions: config.multipass.min_extractions_per_chunk,
            multipass_quality_threshold: config.multipass.quality_threshold,
            progress_handler: config.progress.handler,
        }
    }
}

impl std::fmt::Debug for LangExtractConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LangExtractConfig")
            .field("processing", &self.processing)
            .field("provider", &self.provider)
            .field("validation", &self.validation)
            .field("chunking", &self.chunking)
            .field("alignment", &self.alignment)
            .field("multipass", &self.multipass)
            .field("visualization", &self.visualization)
            .field("inference", &self.inference)
            .field("progress", &"<ProgressConfig>")
            .finish()
    }
}

impl std::fmt::Debug for ProgressConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProgressConfig")
            .field("handler", &"<ProgressHandler>")
            .field("show_progress", &self.show_progress)
            .field("show_debug", &self.show_debug)
            .field("use_styling", &self.use_styling)
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = LangExtractConfig::default();
        assert_eq!(config.processing.format_type, FormatType::Json);
        assert_eq!(config.processing.max_char_buffer, 8000);
        assert_eq!(config.chunking.strategy, ChunkingStrategy::Token);
    }

    #[test]
    fn test_builder_pattern() {
        let config = LangExtractConfig::new()
            .with_debug(true)
            .with_max_char_buffer(10000)
            .with_workers(8)
            .with_temperature(0.5);

        assert!(config.processing.debug);
        assert_eq!(config.processing.max_char_buffer, 10000);
        assert_eq!(config.processing.max_workers, 8);
        assert_eq!(config.inference.temperature, 0.5);
    }

596
597 #[test]
598 fn test_specialized_configs() {
599 use crate::providers::ProviderType;
600
601 let openai_config = LangExtractConfig::for_openai("gpt-4o", Some("test-key".to_string()));
602 assert_eq!(openai_config.provider.provider_type, ProviderType::OpenAI);
603 assert_eq!(openai_config.inference.temperature, 0.2);
604
605 let ollama_config = LangExtractConfig::for_ollama("mistral", None);
606 assert_eq!(ollama_config.provider.provider_type, ProviderType::Ollama);
607 assert_eq!(ollama_config.chunking.target_size, 6000);
608
609 let hp_config = LangExtractConfig::for_high_performance();
610 assert_eq!(hp_config.processing.max_workers, 12);
611 assert!(hp_config.multipass.enable_multipass);
612 }
613
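    // Added check (sketch): covers the memory-efficient preset, mirroring the values
    // set in `for_memory_efficient` above (smaller worker pool, batches, and chunks).
    #[test]
    fn test_memory_efficient_config() {
        let config = LangExtractConfig::for_memory_efficient();
        assert_eq!(config.processing.max_workers, 4);
        assert_eq!(config.processing.batch_length, 2);
        assert_eq!(config.chunking.target_size, 4000);
        assert_eq!(config.chunking.overlap, 100);
    }
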
    #[test]
    fn test_backward_compatibility() {
        let new_config = LangExtractConfig::for_ollama("mistral", None)
            .with_debug(true)
            .with_temperature(0.4);

        let old_config: crate::ExtractConfig = new_config.into();
        assert_eq!(old_config.model_id, "mistral");
        assert!(old_config.debug);
        assert_eq!(old_config.temperature, 0.4);
    }

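    // Added check (sketch): the `From` conversion above stashes the full provider
    // configuration under the "provider_config" key of `language_model_params` and
    // carries processing values through unchanged.
    #[test]
    fn test_conversion_carries_provider_config() {
        let config = LangExtractConfig::for_ollama("mistral", None);
        let old_config: crate::ExtractConfig = config.into();

        assert!(old_config.language_model_params.contains_key("provider_config"));
        assert_eq!(old_config.max_char_buffer, 8000);
    }
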
    #[test]
    fn test_serialization() {
        let config = LangExtractConfig::for_openai("gpt-4o", Some("test-key".to_string()));
        let serialized = serde_json::to_string(&config).unwrap();
        let deserialized: LangExtractConfig = serde_json::from_str(&serialized).unwrap();

        assert_eq!(config.provider.model, deserialized.provider.model);
        assert_eq!(config.processing.format_type, deserialized.processing.format_type);
    }
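
    // Added check (sketch): verifies the quiet/verbose convenience toggles on the
    // progress settings, matching `with_quiet_mode` and `with_verbose_mode` above.
    #[test]
    fn test_progress_modes() {
        let quiet = LangExtractConfig::new().with_quiet_mode();
        assert!(!quiet.progress.show_progress);
        assert!(!quiet.progress.show_debug);

        let verbose = LangExtractConfig::new().with_verbose_mode();
        assert!(verbose.progress.show_progress);
        assert!(verbose.progress.show_debug);
    }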
}