//! Library entry point: module declarations, re-exports of the most commonly
//! used types, and the high-level [`extract`] entry point.

pub mod data;
pub mod exceptions;
pub mod schema;

pub mod alignment;
pub mod annotation;
pub mod chunking;
pub mod inference;
pub mod multipass;
pub mod tokenizer;

pub mod providers;
pub mod factory;

pub mod io;
pub mod progress;
pub mod prompting;
pub mod resolver;
pub mod visualization;

pub use data::{
    AlignmentStatus, AnnotatedDocument, CharInterval, Document, ExampleData, Extraction,
    FormatType,
};
pub use exceptions::{LangExtractError, LangExtractResult};
pub use inference::{BaseLanguageModel, ScoredOutput};
pub use providers::{ProviderConfig, ProviderType, UniversalProvider};
pub use resolver::{
    ValidationConfig, ValidationResult, ValidationError, ValidationWarning, CoercionSummary,
    CoercionDetail, CoercionTargetType,
};
pub use visualization::{ExportFormat, ExportConfig, export_document};

use std::collections::HashMap;

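/// Configuration for [`extract`]. The field documentation below is a concise
/// summary inferred from how each field is used in this crate; see the
/// individual modules for the authoritative behavior.
///
/// Defaults can be overridden selectively with struct update syntax:
///
/// ```ignore
/// let config = ExtractConfig {
///     temperature: 0.2,
///     max_workers: 4,
///     ..ExtractConfig::default()
/// };
/// ```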
#[derive(Debug, Clone)]
pub struct ExtractConfig {
    /// Identifier of the language model to use, e.g. `"gemini-2.5-flash"`.
    pub model_id: String,
    /// API key for the model provider, if one is required.
    pub api_key: Option<String>,
    /// Output format requested from the model (JSON by default).
    pub format_type: FormatType,
    /// Maximum number of characters per chunk sent to the model.
    pub max_char_buffer: usize,
    /// Sampling temperature passed to the language model.
    pub temperature: f32,
    /// Whether model output is wrapped in code fences; `None` defers to the provider.
    pub fence_output: Option<bool>,
    /// Whether to apply schema constraints to model output.
    pub use_schema_constraints: bool,
    /// Number of chunks processed per batch.
    pub batch_length: usize,
    /// Maximum number of parallel workers.
    pub max_workers: usize,
    /// Optional extra context passed along with the prompt.
    pub additional_context: Option<String>,
    /// Additional parameters forwarded to the resolver.
    pub resolver_params: HashMap<String, serde_json::Value>,
    /// Additional parameters forwarded to the language model.
    pub language_model_params: HashMap<String, serde_json::Value>,
    /// Enable verbose debug output during extraction.
    pub debug: bool,
    /// Optional custom endpoint URL for the model provider.
    pub model_url: Option<String>,
    /// Number of extraction passes to run over the input.
    pub extraction_passes: usize,
    /// Enable the multi-pass extraction pipeline.
    pub enable_multipass: bool,
    /// Minimum number of extractions expected per chunk before it is reprocessed.
    pub multipass_min_extractions: usize,
    /// Quality threshold used by the multi-pass processor.
    pub multipass_quality_threshold: f32,
}

impl Default for ExtractConfig {
    fn default() -> Self {
        Self {
            model_id: "gemini-2.5-flash".to_string(),
            api_key: None,
            format_type: FormatType::Json,
            max_char_buffer: 1000,
            temperature: 0.5,
            fence_output: None,
            use_schema_constraints: true,
            batch_length: 10,
            max_workers: 10,
            additional_context: None,
            resolver_params: HashMap::new(),
            language_model_params: HashMap::new(),
            debug: true,
            model_url: None,
            extraction_passes: 1,
            enable_multipass: false,
            multipass_min_extractions: 1,
            multipass_quality_threshold: 0.3,
        }
    }
}

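/// Extracts structured information from `text_or_documents` (raw text, or a URL
/// whose contents are downloaded first) using the configured language model.
///
/// At least one [`ExampleData`] must be supplied; otherwise the call fails with
/// [`LangExtractError::InvalidInput`]. When `config.enable_multipass` is set and
/// `config.extraction_passes > 1`, the multi-pass pipeline is used; otherwise a
/// single annotation pass runs over the chunked input.
///
/// The sketch below is illustrative only; the `ExampleData` construction is
/// elided because its exact API lives in the `data` module.
///
/// ```ignore
/// let examples: Vec<ExampleData> = vec![/* at least one example with sample extractions */];
/// let config = ExtractConfig::default();
/// let doc = extract(
///     "Patient was given 250 mg of ibuprofen.",
///     Some("Extract medication names and dosages."),
///     &examples,
///     config,
/// )
/// .await?;
/// println!("{} extractions", doc.extraction_count());
/// ```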
pub async fn extract(
    text_or_documents: &str,
    prompt_description: Option<&str>,
    examples: &[ExampleData],
    config: ExtractConfig,
) -> LangExtractResult<AnnotatedDocument> {
    if examples.is_empty() {
        return Err(LangExtractError::InvalidInput(
            "Examples are required for reliable extraction. Please provide at least one ExampleData object with sample extractions.".to_string()
        ));
    }

    if config.batch_length < config.max_workers {
        log::warn!(
            "batch_length ({}) < max_workers ({}). Only {} workers will be used. Set batch_length >= max_workers for optimal parallelization.",
            config.batch_length,
            config.max_workers,
            config.batch_length
        );
    }

    // Load environment variables from a .env file, if present (e.g. provider API keys).
    dotenvy::dotenv().ok();

    // If the input looks like a URL, download its contents; otherwise treat it as raw text.
    let text = if io::is_url(text_or_documents) {
        io::download_text_from_url(text_or_documents).await?
    } else {
        text_or_documents.to_string()
    };

    // Build the structured prompt from the description and the provided examples.
    let mut prompt_template = prompting::PromptTemplateStructured::new(prompt_description);
    prompt_template.examples.extend(examples.iter().cloned());

    let language_model = factory::create_model(&config, Some(&prompt_template.examples)).await?;

    let resolver = resolver::Resolver::new(&config, language_model.requires_fence_output())?;

    let annotator = annotation::Annotator::new(
        language_model,
        prompt_template,
        config.format_type,
        resolver.fence_output(),
    );

    if config.enable_multipass && config.extraction_passes > 1 {
        let multipass_config = multipass::MultiPassConfig {
            max_passes: config.extraction_passes,
            min_extractions_per_chunk: config.multipass_min_extractions,
            enable_targeted_reprocessing: true,
            enable_refinement_passes: true,
            quality_threshold: config.multipass_quality_threshold,
            max_reprocess_chunks: 10,
            temperature_decay: 0.9,
        };

        let processor = multipass::MultiPassProcessor::new(multipass_config, annotator, resolver);

        let (result, _stats) = processor
            .extract_multipass(&text, config.additional_context.as_deref(), config.debug)
            .await?;

        if config.debug {
            println!(
                "🎯 Multi-pass extraction completed with {} total extractions",
                result.extraction_count()
            );
        }

        Ok(result)
    } else {
        annotator
            .annotate_text(
                &text,
                &resolver,
                config.max_char_buffer,
                config.batch_length,
                config.additional_context.as_deref(),
                config.debug,
                config.extraction_passes,
                config.max_workers,
            )
            .await
    }
}

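/// Renders an [`AnnotatedDocument`] for human inspection, optionally including
/// character intervals. This is a thin wrapper around [`visualization::visualize`].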
pub fn visualize(
    annotated_document: &AnnotatedDocument,
    show_char_intervals: bool,
) -> LangExtractResult<String> {
    visualization::visualize(annotated_document, show_char_intervals)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_config_default() {
        let config = ExtractConfig::default();
        assert_eq!(config.model_id, "gemini-2.5-flash");
        assert_eq!(config.format_type, FormatType::Json);
        assert_eq!(config.max_char_buffer, 1000);
        assert_eq!(config.temperature, 0.5);
    }

    #[test]
    fn test_extraction_validation() {
        let examples: Vec<ExampleData> = vec![];
        let config = ExtractConfig::default();

        tokio_test::block_on(async {
            let result = extract("test text", None, &examples, config).await;
            assert!(result.is_err());
            match result.err().unwrap() {
                LangExtractError::InvalidInput(msg) => {
                    assert!(msg.contains("Examples are required"));
                }
                _ => panic!("Expected InvalidInput error"),
            }
        });
    }
}