langextract_rust/
lib.rs

1//! # LangExtract
2//!
3//! A Rust library for extracting structured and grounded information from text using LLMs.
4//!
5//! This library provides a clean, async API for working with various language model providers
6//! to extract structured data from unstructured text.
7//!
8//! ## Features
9//!
10//! - Support for multiple LLM providers (Gemini, OpenAI, Ollama)
11//! - Async/await API for concurrent processing
12//! - Schema-driven extraction with validation
13//! - Text chunking and tokenization
14//! - Flexible output formats (JSON, YAML)
15//! - Built-in visualization and progress tracking
16//!
17//! ## Quick Start
18//!
19//! ```rust,no_run
20//! use langextract_rust::{extract, ExampleData, Extraction, FormatType};
21//!
22//! #[tokio::main]
23//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
24//!     let examples = vec![
25//!         ExampleData {
26//!             text: "John Doe is 30 years old".to_string(),
27//!             extractions: vec![
28//!                 Extraction::new("person".to_string(), "John Doe".to_string()),
29//!                 Extraction::new("age".to_string(), "30".to_string()),
30//!             ],
31//!         }
32//!     ];
33//!
34//!     let result = extract(
35//!         "Alice Smith is 25 years old and works as a doctor",
36//!         Some("Extract person names and ages from the text"),
37//!         &examples,
38//!         Default::default(),
39//!     ).await?;
40//!
41//!     println!("{:?}", result);
42//!     Ok(())
43//! }
44//! ```
45
46// Core modules
47pub mod config;
48pub mod data;
49pub mod exceptions;
50pub mod schema;
51
52// Processing modules
53pub mod alignment;
54pub mod annotation;
55pub mod chunking;
56pub mod inference;
57pub mod multipass;
58pub mod tokenizer;
59
60// Provider modules
61pub mod providers;
62pub mod factory;
63
64// Utility modules
65pub mod http_client;
66pub mod io;
67pub mod logging;
68pub mod pipeline;
69pub mod progress;
70pub mod prompting;
71pub mod resolver;
72pub mod templates;
73pub mod visualization;
74
75// Re-export key types for convenience
76pub use config::{
77    LangExtractConfig, ProcessingConfig, ValidationConfig as NewValidationConfig, 
78    ChunkingConfig, AlignmentConfig as NewAlignmentConfig, MultiPassConfig as NewMultiPassConfig, 
79    VisualizationConfig, InferenceConfig as NewInferenceConfig, ProgressConfig, 
80    ChunkingStrategy, ExportFormat as NewExportFormat
81};
82pub use data::{
83    AlignmentStatus, AnnotatedDocument, CharInterval, Document, ExampleData, Extraction,
84    FormatType,
85};
86pub use exceptions::{LangExtractError, LangExtractResult};
87pub use inference::{BaseLanguageModel, ScoredOutput};
88pub use logging::{ProgressHandler, ProgressEvent, ConsoleProgressHandler, SilentProgressHandler, LogProgressHandler};
89pub use providers::{ProviderConfig, ProviderType, UniversalProvider};
90pub use resolver::{ValidationConfig, ValidationResult, ValidationError, ValidationWarning, CoercionSummary, CoercionDetail, CoercionTargetType};
91pub use visualization::{ExportFormat, ExportConfig, export_document};
92pub use pipeline::{PipelineConfig, PipelineStep, PipelineResult, PipelineExecutor};
93
94use serde::{Deserialize, Serialize};
95use std::collections::HashMap;
96
/// Configuration for the extract function
///
/// All fields have sensible defaults via [`Default`]; construct with
/// `ExtractConfig { model_id: ..., ..Default::default() }` or use the
/// `with_*` builder methods for progress handling.
#[derive(Clone, Serialize, Deserialize)]
pub struct ExtractConfig {
    /// The model ID to use (e.g., "gemini-2.5-flash", "gpt-4o")
    pub model_id: String,
    /// API key for the language model service
    pub api_key: Option<String>,
    /// Output format type
    pub format_type: FormatType,
    /// Maximum characters per chunk for processing
    pub max_char_buffer: usize,
    /// Sampling temperature (0.0 to 1.0)
    pub temperature: f32,
    /// Whether to wrap output in code fences; `None` lets the provider decide
    pub fence_output: Option<bool>,
    /// Whether to use schema constraints
    pub use_schema_constraints: bool,
    /// Batch size for processing chunks
    pub batch_length: usize,
    /// Maximum number of concurrent workers
    pub max_workers: usize,
    /// Additional context for the prompt
    pub additional_context: Option<String>,
    /// Custom resolver parameters
    pub resolver_params: HashMap<String, serde_json::Value>,
    /// Custom language model parameters
    pub language_model_params: HashMap<String, serde_json::Value>,
    /// Enable debug mode
    pub debug: bool,
    /// Custom model URL for self-hosted models
    pub model_url: Option<String>,
    /// Number of extraction passes to improve recall
    pub extraction_passes: usize,
    /// Enable multi-pass extraction for improved recall
    pub enable_multipass: bool,
    /// Minimum extractions per chunk to avoid re-processing
    pub multipass_min_extractions: usize,
    /// Quality threshold for keeping extractions (0.0 to 1.0)
    pub multipass_quality_threshold: f32,
    /// Progress handler for reporting extraction progress (not serialized;
    /// deserializes to `None` because of `#[serde(skip)]`)
    #[serde(skip)]
    pub progress_handler: Option<std::sync::Arc<dyn ProgressHandler>>,
}
140
141impl Default for ExtractConfig {
142    fn default() -> Self {
143        Self {
144            model_id: "gemini-2.5-flash".to_string(),
145            api_key: None,
146            format_type: FormatType::Json,
147            max_char_buffer: 1000,
148            temperature: 0.5,
149            fence_output: None,
150            use_schema_constraints: true,
151            batch_length: 10,
152            max_workers: 10,
153            additional_context: None,
154            resolver_params: HashMap::new(),
155            language_model_params: HashMap::new(),
156            debug: true,
157            model_url: None,
158            extraction_passes: 1,
159            enable_multipass: false,
160            multipass_min_extractions: 1,
161            multipass_quality_threshold: 0.3,
162            progress_handler: None,
163        }
164    }
165}
166
167impl std::fmt::Debug for ExtractConfig {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        f.debug_struct("ExtractConfig")
170            .field("model_id", &self.model_id)
171            .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
172            .field("format_type", &self.format_type)
173            .field("max_char_buffer", &self.max_char_buffer)
174            .field("temperature", &self.temperature)
175            .field("fence_output", &self.fence_output)
176            .field("use_schema_constraints", &self.use_schema_constraints)
177            .field("batch_length", &self.batch_length)
178            .field("max_workers", &self.max_workers)
179            .field("additional_context", &self.additional_context)
180            .field("resolver_params", &self.resolver_params)
181            .field("language_model_params", &self.language_model_params)
182            .field("debug", &self.debug)
183            .field("model_url", &self.model_url)
184            .field("extraction_passes", &self.extraction_passes)
185            .field("enable_multipass", &self.enable_multipass)
186            .field("multipass_min_extractions", &self.multipass_min_extractions)
187            .field("multipass_quality_threshold", &self.multipass_quality_threshold)
188            .field("progress_handler", &"<ProgressHandler>")
189            .finish()
190    }
191}
192
193impl ExtractConfig {
194    /// Set a progress handler for this configuration
195    pub fn with_progress_handler(mut self, handler: std::sync::Arc<dyn ProgressHandler>) -> Self {
196        self.progress_handler = Some(handler);
197        self
198    }
199
200    /// Enable console progress output with default settings
201    pub fn with_console_progress(mut self) -> Self {
202        self.progress_handler = Some(std::sync::Arc::new(ConsoleProgressHandler::new()));
203        self
204    }
205
206    /// Enable quiet mode (no progress output)
207    pub fn with_quiet_mode(mut self) -> Self {
208        self.progress_handler = Some(std::sync::Arc::new(SilentProgressHandler));
209        self
210    }
211
212    /// Enable verbose console output
213    pub fn with_verbose_progress(mut self) -> Self {
214        self.progress_handler = Some(std::sync::Arc::new(ConsoleProgressHandler::verbose()));
215        self
216    }
217}
218
219/// Convenient extraction function using the new unified configuration
220pub async fn extract_with_config(
221    text_or_documents: &str,
222    prompt_description: Option<&str>,
223    examples: &[ExampleData],
224    config: LangExtractConfig,
225) -> LangExtractResult<AnnotatedDocument> {
226    // Convert to legacy config for now
227    let legacy_config: ExtractConfig = config.into();
228    extract(text_or_documents, prompt_description, examples, legacy_config).await
229}
230
/// Main extraction function that mirrors the Python API
///
/// Extracts structured information from text using a language model based on
/// the provided examples and configuration. Environment variables are loaded
/// from `.env` (best-effort), a progress handler is installed, and the input
/// is downloaded first when it is a URL.
///
/// # Arguments
///
/// * `text_or_documents` - The source text to extract information from, or a URL starting with http/https
/// * `prompt_description` - Instructions for what to extract from the text
/// * `examples` - Example data to guide the extraction
/// * `config` - Configuration parameters for the extraction
///
/// # Returns
///
/// An `AnnotatedDocument` with the extracted information
///
/// # Errors
///
/// Returns an error if:
/// * Examples are empty
/// * No API key is provided
/// * URL download fails
/// * Language model inference fails
pub async fn extract(
    text_or_documents: &str,
    prompt_description: Option<&str>,
    examples: &[ExampleData],
    config: ExtractConfig,
) -> LangExtractResult<AnnotatedDocument> {
    // Validate inputs: few-shot examples are mandatory for reliable output.
    if examples.is_empty() {
        return Err(LangExtractError::InvalidInput(
            "Examples are required for reliable extraction. Please provide at least one ExampleData object with sample extractions.".to_string()
        ));
    }

    // Warn (but do not fail) on a configuration that under-utilizes workers:
    // effective parallelism is capped by the number of chunks per batch.
    if config.batch_length < config.max_workers {
        log::warn!(
            "batch_length ({}) < max_workers ({}). Only {} workers will be used. Set batch_length >= max_workers for optimal parallelization.",
            config.batch_length,
            config.max_workers,
            config.batch_length
        );
    }

    // Load environment variables; `.ok()` makes a missing `.env` non-fatal.
    dotenvy::dotenv().ok();

    // Initialize progress handler
    if let Some(handler) = &config.progress_handler {
        logging::init_progress_handler(handler.clone());
    } else {
        // Default to console handler if debug is enabled, otherwise silent
        let default_handler: std::sync::Arc<dyn ProgressHandler> = if config.debug {
            std::sync::Arc::new(ConsoleProgressHandler::new())
        } else {
            std::sync::Arc::new(SilentProgressHandler)
        };
        logging::init_progress_handler(default_handler);
    }

    // Handle URL input: fetch remote content, otherwise use the text as-is.
    let text = if io::is_url(text_or_documents) {
        io::download_text_from_url(text_or_documents).await?
    } else {
        text_or_documents.to_string()
    };

    // Create prompt template seeded with the caller's few-shot examples.
    let mut prompt_template = prompting::PromptTemplateStructured::new(prompt_description);
    prompt_template.examples.extend(examples.iter().cloned());

    // Create language model (examples may inform schema constraints).
    let language_model = factory::create_model(&config, Some(&prompt_template.examples)).await?;

    // Create resolver; fencing follows the provider's requirement.
    let resolver = resolver::Resolver::new(&config, language_model.requires_fence_output())?;

    // Create annotator
    let annotator = annotation::Annotator::new(
        language_model,
        prompt_template,
        config.format_type,
        resolver.fence_output(),
    );

    // Perform annotation - use multi-pass if enabled
    if config.enable_multipass && config.extraction_passes > 1 {
        // Use multi-pass extraction
        let multipass_config = multipass::MultiPassConfig {
            max_passes: config.extraction_passes,
            min_extractions_per_chunk: config.multipass_min_extractions,
            enable_targeted_reprocessing: true,
            enable_refinement_passes: true,
            quality_threshold: config.multipass_quality_threshold,
            // NOTE(review): reprocess cap and temperature decay are fixed
            // here rather than exposed on ExtractConfig.
            max_reprocess_chunks: 10,
            temperature_decay: 0.9,
        };

        let processor = multipass::MultiPassProcessor::new(
            multipass_config,
            annotator,
            resolver,
        );

        // Statistics from the multi-pass run are discarded here.
        let (result, _stats) = processor.extract_multipass(
            &text,
            config.additional_context.as_deref(),
            config.debug,
        ).await?;

        if config.debug {
            println!("🎯 Multi-pass extraction completed with {} total extractions", 
                result.extraction_count());
        }

        Ok(result)
    } else {
        // Use single-pass extraction
        annotator
            .annotate_text(
                &text,
                &resolver,
                config.max_char_buffer,
                config.batch_length,
                config.additional_context.as_deref(),
                config.debug,
                config.extraction_passes,
                config.max_workers,
            )
            .await
    }
}
364
365/// Visualize function that mirrors the Python API
366pub fn visualize(
367    annotated_document: &AnnotatedDocument,
368    show_char_intervals: bool,
369) -> LangExtractResult<String> {
370    visualization::visualize(annotated_document, show_char_intervals)
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration targets Gemini 2.5 Flash with JSON output.
    #[test]
    fn test_extract_config_default() {
        let cfg = ExtractConfig::default();
        assert_eq!("gemini-2.5-flash", cfg.model_id);
        assert_eq!(FormatType::Json, cfg.format_type);
        assert_eq!(1000, cfg.max_char_buffer);
        assert_eq!(0.5, cfg.temperature);
    }

    /// Calling `extract` without examples must fail with `InvalidInput`.
    #[test]
    fn test_extraction_validation() {
        let no_examples: Vec<ExampleData> = Vec::new();

        tokio_test::block_on(async {
            let outcome =
                extract("test text", None, &no_examples, ExtractConfig::default()).await;
            assert!(outcome.is_err());
            match outcome {
                Err(LangExtractError::InvalidInput(msg)) => {
                    assert!(msg.contains("Examples are required"));
                }
                Err(_) => panic!("Expected InvalidInput error"),
                Ok(_) => panic!("Expected an error, got a successful extraction"),
            }
        });
    }
}