1use crate::{
7 data::FormatType,
8 logging::ProgressHandler,
9 providers::ProviderConfig,
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14
15#[derive(Clone, Serialize, Deserialize)]
17pub struct LangExtractConfig {
18 pub processing: ProcessingConfig,
20 pub provider: ProviderConfig,
22 pub validation: ValidationConfig,
24 pub chunking: ChunkingConfig,
26 pub alignment: AlignmentConfig,
28 pub multipass: MultiPassConfig,
30 pub visualization: VisualizationConfig,
32 pub inference: InferenceConfig,
34 #[serde(skip)]
36 pub progress: ProgressConfig,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct ProcessingConfig {
42 pub format_type: FormatType,
44 pub max_char_buffer: usize,
46 pub batch_length: usize,
48 pub max_workers: usize,
50 pub additional_context: Option<String>,
52 pub debug: bool,
54 pub fence_output: Option<bool>,
56 pub use_schema_constraints: bool,
58 pub custom_params: HashMap<String, serde_json::Value>,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct ValidationConfig {
65 pub enable_schema_validation: bool,
67 pub enable_type_coercion: bool,
69 pub require_all_fields: bool,
71 pub save_raw_outputs: bool,
73 pub raw_outputs_dir: String,
75 pub quality_threshold: f32,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct ChunkingConfig {
82 pub strategy: ChunkingStrategy,
84 pub target_size: usize,
86 pub max_size: usize,
88 pub overlap: usize,
90 pub min_size: usize,
92 pub preserve_sentences: bool,
94 pub preserve_paragraphs: bool,
96}
97
98#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
100#[serde(rename_all = "snake_case")]
101pub enum ChunkingStrategy {
102 Token,
104 Semantic,
106 Sentence,
108 Paragraph,
110 Fixed,
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct AlignmentConfig {
117 pub enable_fuzzy_alignment: bool,
119 pub fuzzy_alignment_threshold: f32,
121 pub accept_match_lesser: bool,
123 pub case_sensitive: bool,
125 pub max_search_window: usize,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct MultiPassConfig {
132 pub enable_multipass: bool,
134 pub max_passes: usize,
136 pub min_extractions_per_chunk: usize,
138 pub enable_targeted_reprocessing: bool,
140 pub enable_refinement_passes: bool,
142 pub quality_threshold: f32,
144 pub max_reprocess_chunks: usize,
146 pub temperature_decay: f32,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct VisualizationConfig {
153 pub default_format: ExportFormat,
155 pub show_char_intervals: bool,
157 pub include_text: bool,
159 pub highlight_extractions: bool,
161 pub include_statistics: bool,
163 pub custom_css: Option<String>,
165 pub default_title: Option<String>,
167}
168
169#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
171#[serde(rename_all = "snake_case")]
172pub enum ExportFormat {
173 Text,
174 Html,
175 Markdown,
176 Json,
177 Csv,
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct InferenceConfig {
183 pub temperature: f32,
185 pub max_tokens: Option<usize>,
187 pub num_candidates: usize,
189 pub stop_sequences: Vec<String>,
191 pub extra_params: HashMap<String, serde_json::Value>,
193}
194
/// Progress-reporting settings.
///
/// This type has no serde impls and is marked `#[serde(skip)]` inside
/// `LangExtractConfig`. `Debug` is implemented by hand elsewhere in this
/// file rather than derived.
#[derive(Clone)]
pub struct ProgressConfig {
    /// Optional shared progress handler; forwarded to
    /// `ExtractConfig::progress_handler` by the legacy conversion.
    pub handler: Option<Arc<dyn ProgressHandler>>,
    /// Whether progress output is shown.
    pub show_progress: bool,
    /// Whether debug output is shown; toggled together with
    /// `ProcessingConfig::debug` by `LangExtractConfig::with_debug`.
    pub show_debug: bool,
    /// Whether styled output is used.
    pub use_styling: bool,
}
207
208impl Default for LangExtractConfig {
209 fn default() -> Self {
210 Self {
211 processing: ProcessingConfig::default(),
212 provider: ProviderConfig::ollama("mistral", None), validation: ValidationConfig::default(),
214 chunking: ChunkingConfig::default(),
215 alignment: AlignmentConfig::default(),
216 multipass: MultiPassConfig::default(),
217 visualization: VisualizationConfig::default(),
218 inference: InferenceConfig::default(),
219 progress: ProgressConfig::default(),
220 }
221 }
222}
223
224impl Default for ProcessingConfig {
225 fn default() -> Self {
226 Self {
227 format_type: FormatType::Json,
228 max_char_buffer: 8000,
229 batch_length: 4,
230 max_workers: 6,
231 additional_context: None,
232 debug: false,
233 fence_output: None,
234 use_schema_constraints: true,
235 custom_params: HashMap::new(),
236 }
237 }
238}
239
240impl Default for ValidationConfig {
241 fn default() -> Self {
242 Self {
243 enable_schema_validation: true,
244 enable_type_coercion: true,
245 require_all_fields: false,
246 save_raw_outputs: true,
247 raw_outputs_dir: "./raw_outputs".to_string(),
248 quality_threshold: 0.0,
249 }
250 }
251}
252
253impl Default for ChunkingConfig {
254 fn default() -> Self {
255 Self {
256 strategy: ChunkingStrategy::Token,
257 target_size: 8000,
258 max_size: 10000,
259 overlap: 200,
260 min_size: 500,
261 preserve_sentences: true,
262 preserve_paragraphs: true,
263 }
264 }
265}
266
267impl Default for AlignmentConfig {
268 fn default() -> Self {
269 Self {
270 enable_fuzzy_alignment: true,
271 fuzzy_alignment_threshold: 0.4,
272 accept_match_lesser: true,
273 case_sensitive: false,
274 max_search_window: 100,
275 }
276 }
277}
278
279impl Default for MultiPassConfig {
280 fn default() -> Self {
281 Self {
282 enable_multipass: false,
283 max_passes: 2,
284 min_extractions_per_chunk: 1,
285 enable_targeted_reprocessing: true,
286 enable_refinement_passes: true,
287 quality_threshold: 0.3,
288 max_reprocess_chunks: 10,
289 temperature_decay: 0.9,
290 }
291 }
292}
293
294impl Default for VisualizationConfig {
295 fn default() -> Self {
296 Self {
297 default_format: ExportFormat::Text,
298 show_char_intervals: false,
299 include_text: true,
300 highlight_extractions: true,
301 include_statistics: true,
302 custom_css: None,
303 default_title: None,
304 }
305 }
306}
307
308impl Default for InferenceConfig {
309 fn default() -> Self {
310 Self {
311 temperature: 0.3,
312 max_tokens: None,
313 num_candidates: 1,
314 stop_sequences: vec![],
315 extra_params: HashMap::new(),
316 }
317 }
318}
319
320impl Default for ProgressConfig {
321 fn default() -> Self {
322 Self {
323 handler: None,
324 show_progress: true,
325 show_debug: false,
326 use_styling: true,
327 }
328 }
329}
330
331impl LangExtractConfig {
333 pub fn new() -> Self {
335 Self::default()
336 }
337
338 pub fn with_provider(mut self, provider: ProviderConfig) -> Self {
340 self.provider = provider;
341 self
342 }
343
344 pub fn with_processing(mut self, processing: ProcessingConfig) -> Self {
346 self.processing = processing;
347 self
348 }
349
350 pub fn with_validation(mut self, validation: ValidationConfig) -> Self {
352 self.validation = validation;
353 self
354 }
355
356 pub fn with_chunking(mut self, chunking: ChunkingConfig) -> Self {
358 self.chunking = chunking;
359 self
360 }
361
362 pub fn with_alignment(mut self, alignment: AlignmentConfig) -> Self {
364 self.alignment = alignment;
365 self
366 }
367
368 pub fn with_multipass(mut self, multipass: MultiPassConfig) -> Self {
370 self.multipass = multipass;
371 self
372 }
373
374 pub fn with_visualization(mut self, visualization: VisualizationConfig) -> Self {
376 self.visualization = visualization;
377 self
378 }
379
380 pub fn with_inference(mut self, inference: InferenceConfig) -> Self {
382 self.inference = inference;
383 self
384 }
385
386 pub fn with_progress(mut self, progress: ProgressConfig) -> Self {
388 self.progress = progress;
389 self
390 }
391
392 pub fn with_debug(mut self, enabled: bool) -> Self {
394 self.processing.debug = enabled;
395 self.progress.show_debug = enabled;
396 self
397 }
398
399 pub fn with_max_char_buffer(mut self, size: usize) -> Self {
401 self.processing.max_char_buffer = size;
402 self.chunking.target_size = size;
403 self
404 }
405
406 pub fn with_workers(mut self, workers: usize) -> Self {
408 self.processing.max_workers = workers;
409 self
410 }
411
412 pub fn with_temperature(mut self, temperature: f32) -> Self {
414 self.inference.temperature = temperature.clamp(0.0, 1.0);
415 self
416 }
417
418 pub fn with_multipass_enabled(mut self, enabled: bool) -> Self {
420 self.multipass.enable_multipass = enabled;
421 self
422 }
423
424 pub fn with_progress_handler(mut self, handler: Arc<dyn ProgressHandler>) -> Self {
426 self.progress.handler = Some(handler);
427 self
428 }
429
430 pub fn with_quiet_mode(mut self) -> Self {
432 self.progress.show_progress = false;
433 self.progress.show_debug = false;
434 self
435 }
436
437 pub fn with_verbose_mode(mut self) -> Self {
439 self.progress.show_progress = true;
440 self.progress.show_debug = true;
441 self
442 }
443}
444
445impl LangExtractConfig {
447 pub fn for_openai(model: &str, api_key: Option<String>) -> Self {
449 Self::new()
450 .with_provider(ProviderConfig::openai(model, api_key))
451 .with_inference(InferenceConfig {
452 temperature: 0.2,
453 max_tokens: Some(2000),
454 ..Default::default()
455 })
456 }
457
458 pub fn for_ollama(model: &str, base_url: Option<String>) -> Self {
460 Self::new()
461 .with_provider(ProviderConfig::ollama(model, base_url))
462 .with_inference(InferenceConfig {
463 temperature: 0.3,
464 max_tokens: Some(1500),
465 ..Default::default()
466 })
467 .with_chunking(ChunkingConfig {
468 target_size: 6000, max_size: 8000,
470 ..Default::default()
471 })
472 }
473
474 pub fn for_high_performance() -> Self {
476 Self::new()
477 .with_processing(ProcessingConfig {
478 max_workers: 12,
479 batch_length: 8,
480 max_char_buffer: 10000,
481 ..Default::default()
482 })
483 .with_multipass(MultiPassConfig {
484 enable_multipass: true,
485 max_passes: 3,
486 ..Default::default()
487 })
488 }
489
490 pub fn for_memory_efficient() -> Self {
492 Self::new()
493 .with_processing(ProcessingConfig {
494 max_workers: 4,
495 batch_length: 2,
496 max_char_buffer: 6000,
497 ..Default::default()
498 })
499 .with_chunking(ChunkingConfig {
500 target_size: 4000,
501 max_size: 6000,
502 overlap: 100,
503 ..Default::default()
504 })
505 }
506}
507
508impl From<LangExtractConfig> for crate::ExtractConfig {
510 fn from(config: LangExtractConfig) -> Self {
511 let provider_config_value = serde_json::to_value(&config.provider).unwrap_or_default();
512
513 Self {
514 model_id: config.provider.model.clone(),
515 api_key: config.provider.api_key.clone(),
516 format_type: config.processing.format_type,
517 max_char_buffer: config.processing.max_char_buffer,
518 temperature: config.inference.temperature,
519 fence_output: config.processing.fence_output,
520 use_schema_constraints: config.processing.use_schema_constraints,
521 batch_length: config.processing.batch_length,
522 max_workers: config.processing.max_workers,
523 additional_context: config.processing.additional_context.clone(),
524 resolver_params: HashMap::new(), language_model_params: {
526 let mut params = HashMap::new();
527 params.insert("provider_config".to_string(), provider_config_value);
528 params
529 },
530 debug: config.processing.debug,
531 model_url: Some(config.provider.base_url.clone()),
532 enable_multipass: config.multipass.enable_multipass,
533 multipass_max_passes: config.multipass.max_passes,
534 multipass_min_extractions: config.multipass.min_extractions_per_chunk,
535 multipass_quality_threshold: config.multipass.quality_threshold,
536 progress_handler: config.progress.handler,
537 }
538 }
539}
540
541impl std::fmt::Debug for LangExtractConfig {
542 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
543 f.debug_struct("LangExtractConfig")
544 .field("processing", &self.processing)
545 .field("provider", &self.provider)
546 .field("validation", &self.validation)
547 .field("chunking", &self.chunking)
548 .field("alignment", &self.alignment)
549 .field("multipass", &self.multipass)
550 .field("visualization", &self.visualization)
551 .field("inference", &self.inference)
552 .field("progress", &"<ProgressConfig>")
553 .finish()
554 }
555}
556
557impl std::fmt::Debug for ProgressConfig {
558 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
559 f.debug_struct("ProgressConfig")
560 .field("handler", &"<ProgressHandler>")
561 .field("show_progress", &self.show_progress)
562 .field("show_debug", &self.show_debug)
563 .field("use_styling", &self.use_styling)
564 .finish()
565 }
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock configuration uses JSON, an 8000-character buffer and the
    /// token chunking strategy.
    #[test]
    fn test_default_config() {
        let defaults = LangExtractConfig::default();

        assert_eq!(defaults.processing.format_type, FormatType::Json);
        assert_eq!(defaults.processing.max_char_buffer, 8000);
        assert_eq!(defaults.chunking.strategy, ChunkingStrategy::Token);
    }

    /// Chained convenience setters land in the expected fields.
    #[test]
    fn test_builder_pattern() {
        let built = LangExtractConfig::new()
            .with_debug(true)
            .with_max_char_buffer(10000)
            .with_workers(8)
            .with_temperature(0.5);

        let processing = &built.processing;
        assert!(processing.debug);
        assert_eq!(processing.max_char_buffer, 10000);
        assert_eq!(processing.max_workers, 8);
        assert_eq!(built.inference.temperature, 0.5);
    }

    /// Each preset selects the right provider and applies its overrides.
    #[test]
    fn test_specialized_configs() {
        use crate::providers::ProviderType;

        let openai = LangExtractConfig::for_openai("gpt-4o", Some("test-key".to_string()));
        assert_eq!(openai.provider.provider_type, ProviderType::OpenAI);
        assert_eq!(openai.inference.temperature, 0.2);

        let ollama = LangExtractConfig::for_ollama("mistral", None);
        assert_eq!(ollama.provider.provider_type, ProviderType::Ollama);
        assert_eq!(ollama.chunking.target_size, 6000);

        let high_perf = LangExtractConfig::for_high_performance();
        assert_eq!(high_perf.processing.max_workers, 12);
        assert!(high_perf.multipass.enable_multipass);
    }

    /// Converting to the legacy `ExtractConfig` carries over model id, debug
    /// flag and temperature.
    #[test]
    fn test_backward_compatibility() {
        let modern = LangExtractConfig::for_ollama("mistral", None)
            .with_debug(true)
            .with_temperature(0.4);

        let legacy = crate::ExtractConfig::from(modern);

        assert_eq!(legacy.model_id, "mistral");
        assert!(legacy.debug);
        assert_eq!(legacy.temperature, 0.4);
    }

    /// A JSON round trip preserves the provider model and the format type.
    #[test]
    fn test_serialization() {
        let original = LangExtractConfig::for_openai("gpt-4o", Some("test-key".to_string()));

        let json = serde_json::to_string(&original).unwrap();
        let restored: LangExtractConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(original.provider.model, restored.provider.model);
        assert_eq!(original.processing.format_type, restored.processing.format_type);
    }
}