1pub mod config;
48pub mod data;
49pub mod exceptions;
50pub mod schema;
51
52pub mod alignment;
54pub mod annotation;
55pub mod chunking;
56pub mod inference;
57pub mod multipass;
58pub mod tokenizer;
59
60pub mod providers;
62pub mod factory;
63
64pub mod http_client;
66pub mod io;
67pub mod logging;
68pub mod pipeline;
69pub mod progress;
70pub mod prompting;
71pub mod resolver;
72pub mod templates;
73pub mod visualization;
74
75pub use config::{
77 LangExtractConfig, ProcessingConfig, ValidationConfig as NewValidationConfig,
78 ChunkingConfig, AlignmentConfig as NewAlignmentConfig, MultiPassConfig as NewMultiPassConfig,
79 VisualizationConfig, InferenceConfig as NewInferenceConfig, ProgressConfig,
80 ChunkingStrategy, ExportFormat as NewExportFormat
81};
82pub use data::{
83 AlignmentStatus, AnnotatedDocument, CharInterval, Document, ExampleData, Extraction,
84 FormatType,
85};
86pub use exceptions::{LangExtractError, LangExtractResult};
87pub use inference::{BaseLanguageModel, ScoredOutput};
88pub use logging::{ProgressHandler, ProgressEvent, ConsoleProgressHandler, SilentProgressHandler, LogProgressHandler};
89pub use providers::{ProviderConfig, ProviderType, UniversalProvider};
90pub use resolver::{ValidationConfig, ValidationResult, ValidationError, ValidationWarning, CoercionSummary, CoercionDetail, CoercionTargetType};
91pub use visualization::{ExportFormat, ExportConfig, export_document};
92pub use pipeline::{PipelineConfig, PipelineStep, PipelineResult, PipelineExecutor};
93
94use serde::{Deserialize, Serialize};
95use std::collections::HashMap;
96
97#[derive(Clone, Serialize, Deserialize)]
99pub struct ExtractConfig {
100 pub model_id: String,
102 pub api_key: Option<String>,
104 pub format_type: FormatType,
106 pub max_char_buffer: usize,
108 pub temperature: f32,
110 pub fence_output: Option<bool>,
112 pub use_schema_constraints: bool,
114 pub batch_length: usize,
116 pub max_workers: usize,
118 pub additional_context: Option<String>,
120 pub resolver_params: HashMap<String, serde_json::Value>,
122 pub language_model_params: HashMap<String, serde_json::Value>,
124 pub debug: bool,
126 pub model_url: Option<String>,
128 pub extraction_passes: usize,
130 pub enable_multipass: bool,
132 pub multipass_min_extractions: usize,
134 pub multipass_quality_threshold: f32,
136 #[serde(skip)]
138 pub progress_handler: Option<std::sync::Arc<dyn ProgressHandler>>,
139}
140
141impl Default for ExtractConfig {
142 fn default() -> Self {
143 Self {
144 model_id: "gemini-2.5-flash".to_string(),
145 api_key: None,
146 format_type: FormatType::Json,
147 max_char_buffer: 1000,
148 temperature: 0.5,
149 fence_output: None,
150 use_schema_constraints: true,
151 batch_length: 10,
152 max_workers: 10,
153 additional_context: None,
154 resolver_params: HashMap::new(),
155 language_model_params: HashMap::new(),
156 debug: true,
157 model_url: None,
158 extraction_passes: 1,
159 enable_multipass: false,
160 multipass_min_extractions: 1,
161 multipass_quality_threshold: 0.3,
162 progress_handler: None,
163 }
164 }
165}
166
167impl std::fmt::Debug for ExtractConfig {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 f.debug_struct("ExtractConfig")
170 .field("model_id", &self.model_id)
171 .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
172 .field("format_type", &self.format_type)
173 .field("max_char_buffer", &self.max_char_buffer)
174 .field("temperature", &self.temperature)
175 .field("fence_output", &self.fence_output)
176 .field("use_schema_constraints", &self.use_schema_constraints)
177 .field("batch_length", &self.batch_length)
178 .field("max_workers", &self.max_workers)
179 .field("additional_context", &self.additional_context)
180 .field("resolver_params", &self.resolver_params)
181 .field("language_model_params", &self.language_model_params)
182 .field("debug", &self.debug)
183 .field("model_url", &self.model_url)
184 .field("extraction_passes", &self.extraction_passes)
185 .field("enable_multipass", &self.enable_multipass)
186 .field("multipass_min_extractions", &self.multipass_min_extractions)
187 .field("multipass_quality_threshold", &self.multipass_quality_threshold)
188 .field("progress_handler", &"<ProgressHandler>")
189 .finish()
190 }
191}
192
193impl ExtractConfig {
194 pub fn with_progress_handler(mut self, handler: std::sync::Arc<dyn ProgressHandler>) -> Self {
196 self.progress_handler = Some(handler);
197 self
198 }
199
200 pub fn with_console_progress(mut self) -> Self {
202 self.progress_handler = Some(std::sync::Arc::new(ConsoleProgressHandler::new()));
203 self
204 }
205
206 pub fn with_quiet_mode(mut self) -> Self {
208 self.progress_handler = Some(std::sync::Arc::new(SilentProgressHandler));
209 self
210 }
211
212 pub fn with_verbose_progress(mut self) -> Self {
214 self.progress_handler = Some(std::sync::Arc::new(ConsoleProgressHandler::verbose()));
215 self
216 }
217}
218
219pub async fn extract_with_config(
221 text_or_documents: &str,
222 prompt_description: Option<&str>,
223 examples: &[ExampleData],
224 config: LangExtractConfig,
225) -> LangExtractResult<AnnotatedDocument> {
226 let legacy_config: ExtractConfig = config.into();
228 extract(text_or_documents, prompt_description, examples, legacy_config).await
229}
230
231pub async fn extract(
255 text_or_documents: &str,
256 prompt_description: Option<&str>,
257 examples: &[ExampleData],
258 config: ExtractConfig,
259) -> LangExtractResult<AnnotatedDocument> {
260 if examples.is_empty() {
262 return Err(LangExtractError::InvalidInput(
263 "Examples are required for reliable extraction. Please provide at least one ExampleData object with sample extractions.".to_string()
264 ));
265 }
266
267 if config.batch_length < config.max_workers {
268 log::warn!(
269 "batch_length ({}) < max_workers ({}). Only {} workers will be used. Set batch_length >= max_workers for optimal parallelization.",
270 config.batch_length,
271 config.max_workers,
272 config.batch_length
273 );
274 }
275
276 dotenvy::dotenv().ok();
278
279 if let Some(handler) = &config.progress_handler {
281 logging::init_progress_handler(handler.clone());
282 } else {
283 let default_handler: std::sync::Arc<dyn ProgressHandler> = if config.debug {
285 std::sync::Arc::new(ConsoleProgressHandler::new())
286 } else {
287 std::sync::Arc::new(SilentProgressHandler)
288 };
289 logging::init_progress_handler(default_handler);
290 }
291
292 let text = if io::is_url(text_or_documents) {
294 io::download_text_from_url(text_or_documents).await?
295 } else {
296 text_or_documents.to_string()
297 };
298
299 let mut prompt_template = prompting::PromptTemplateStructured::new(prompt_description);
301 prompt_template.examples.extend(examples.iter().cloned());
302
303 let language_model = factory::create_model(&config, Some(&prompt_template.examples)).await?;
305
306 let resolver = resolver::Resolver::new(&config, language_model.requires_fence_output())?;
308
309 let annotator = annotation::Annotator::new(
311 language_model,
312 prompt_template,
313 config.format_type,
314 resolver.fence_output(),
315 );
316
317 if config.enable_multipass && config.extraction_passes > 1 {
319 let multipass_config = multipass::MultiPassConfig {
321 max_passes: config.extraction_passes,
322 min_extractions_per_chunk: config.multipass_min_extractions,
323 enable_targeted_reprocessing: true,
324 enable_refinement_passes: true,
325 quality_threshold: config.multipass_quality_threshold,
326 max_reprocess_chunks: 10,
327 temperature_decay: 0.9,
328 };
329
330 let processor = multipass::MultiPassProcessor::new(
331 multipass_config,
332 annotator,
333 resolver,
334 );
335
336 let (result, _stats) = processor.extract_multipass(
337 &text,
338 config.additional_context.as_deref(),
339 config.debug,
340 ).await?;
341
342 if config.debug {
343 println!("🎯 Multi-pass extraction completed with {} total extractions",
344 result.extraction_count());
345 }
346
347 Ok(result)
348 } else {
349 annotator
351 .annotate_text(
352 &text,
353 &resolver,
354 config.max_char_buffer,
355 config.batch_length,
356 config.additional_context.as_deref(),
357 config.debug,
358 config.extraction_passes,
359 config.max_workers,
360 )
361 .await
362 }
363}
364
365pub fn visualize(
367 annotated_document: &AnnotatedDocument,
368 show_char_intervals: bool,
369) -> LangExtractResult<String> {
370 visualization::visualize(annotated_document, show_char_intervals)
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins a few of the documented defaults of `ExtractConfig`.
    #[test]
    fn test_extract_config_default() {
        let defaults = ExtractConfig::default();
        assert_eq!(defaults.model_id, "gemini-2.5-flash");
        assert_eq!(defaults.format_type, FormatType::Json);
        assert_eq!(defaults.max_char_buffer, 1000);
        assert_eq!(defaults.temperature, 0.5);
    }

    /// `extract` must reject an empty example slice before doing any I/O.
    #[test]
    fn test_extraction_validation() {
        let no_examples: Vec<ExampleData> = Vec::new();
        let outcome = tokio_test::block_on(extract(
            "test text",
            None,
            &no_examples,
            ExtractConfig::default(),
        ));

        match outcome {
            Err(LangExtractError::InvalidInput(msg)) => {
                assert!(msg.contains("Examples are required"));
            }
            _ => panic!("Expected InvalidInput error"),
        }
    }
}