// crate: langextract_rust
// file:  src/lib.rs

1//! # LangExtract
2//!
3//! A Rust library for extracting structured and grounded information from text using LLMs.
4//!
5//! This library provides a clean, async API for working with various language model providers
6//! to extract structured data from unstructured text.
7//!
8//! ## Features
9//!
10//! - Support for multiple LLM providers (Gemini, OpenAI, Ollama)
11//! - Async/await API for concurrent processing
12//! - Schema-driven extraction with validation
13//! - Text chunking and tokenization
14//! - Flexible output formats (JSON, YAML)
15//! - Built-in visualization and progress tracking
16//!
17//! ## Quick Start
18//!
19//! ```rust,no_run
20//! use langextract_rust::{extract, ExampleData, Extraction, FormatType};
21//!
22//! #[tokio::main]
23//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
24//!     let examples = vec![
25//!         ExampleData {
26//!             text: "John Doe is 30 years old".to_string(),
27//!             extractions: vec![
28//!                 Extraction::new("person".to_string(), "John Doe".to_string()),
29//!                 Extraction::new("age".to_string(), "30".to_string()),
30//!             ],
31//!         }
32//!     ];
33//!
34//!     let result = extract(
35//!         "Alice Smith is 25 years old and works as a doctor",
36//!         Some("Extract person names and ages from the text"),
37//!         &examples,
38//!         Default::default(),
39//!     ).await?;
40//!
41//!     println!("{:?}", result);
42//!     Ok(())
43//! }
44//! ```
45
46// Core modules
47pub mod data;
48pub mod exceptions;
49pub mod schema;
50
51// Processing modules
52pub mod alignment;
53pub mod annotation;
54pub mod chunking;
55pub mod inference;
56pub mod multipass;
57pub mod tokenizer;
58
59// Provider modules
60pub mod providers;
61pub mod factory;
62
63// Utility modules
64pub mod io;
65pub mod progress;
66pub mod prompting;
67pub mod resolver;
68pub mod visualization;
69
70// Re-export key types for convenience
71pub use data::{
72    AlignmentStatus, AnnotatedDocument, CharInterval, Document, ExampleData, Extraction,
73    FormatType,
74};
75pub use exceptions::{LangExtractError, LangExtractResult};
76pub use inference::{BaseLanguageModel, ScoredOutput};
77pub use providers::{ProviderConfig, ProviderType, UniversalProvider};
78pub use resolver::{ValidationConfig, ValidationResult, ValidationError, ValidationWarning, CoercionSummary, CoercionDetail, CoercionTargetType};
79pub use visualization::{ExportFormat, ExportConfig, export_document};
80
81use std::collections::HashMap;
82
/// Configuration for the extract function
///
/// Every field has a default (see the [`Default`] impl below), so callers
/// typically override only a few fields via struct-update syntax:
/// `ExtractConfig { model_id: "gpt-4o".into(), ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct ExtractConfig {
    /// The model ID to use (e.g., "gemini-2.5-flash", "gpt-4o")
    pub model_id: String,
    /// API key for the language model service; when `None`, providers may
    /// pick credentials up from environment variables (a `.env` file is
    /// loaded inside `extract` via `dotenvy`)
    pub api_key: Option<String>,
    /// Output format type
    pub format_type: FormatType,
    /// Maximum characters per chunk for processing
    pub max_char_buffer: usize,
    /// Sampling temperature (0.0 to 1.0)
    pub temperature: f32,
    /// Whether to wrap output in code fences; `None` defers to the
    /// provider's own preference (`requires_fence_output`)
    pub fence_output: Option<bool>,
    /// Whether to use schema constraints
    pub use_schema_constraints: bool,
    /// Batch size for processing chunks; effectively caps concurrency,
    /// so keep it >= `max_workers` (see the warning in `extract`)
    pub batch_length: usize,
    /// Maximum number of concurrent workers
    pub max_workers: usize,
    /// Additional context for the prompt
    pub additional_context: Option<String>,
    /// Custom resolver parameters
    pub resolver_params: HashMap<String, serde_json::Value>,
    /// Custom language model parameters
    pub language_model_params: HashMap<String, serde_json::Value>,
    /// Enable debug mode
    pub debug: bool,
    /// Custom model URL for self-hosted models
    pub model_url: Option<String>,
    /// Number of extraction passes to improve recall
    pub extraction_passes: usize,
    /// Enable multi-pass extraction for improved recall
    /// (multi-pass only activates when `extraction_passes > 1` as well)
    pub enable_multipass: bool,
    /// Minimum extractions per chunk to avoid re-processing
    pub multipass_min_extractions: usize,
    /// Quality threshold for keeping extractions (0.0 to 1.0)
    pub multipass_quality_threshold: f32,
}
123
124impl Default for ExtractConfig {
125    fn default() -> Self {
126        Self {
127            model_id: "gemini-2.5-flash".to_string(),
128            api_key: None,
129            format_type: FormatType::Json,
130            max_char_buffer: 1000,
131            temperature: 0.5,
132            fence_output: None,
133            use_schema_constraints: true,
134            batch_length: 10,
135            max_workers: 10,
136            additional_context: None,
137            resolver_params: HashMap::new(),
138            language_model_params: HashMap::new(),
139            debug: true,
140            model_url: None,
141            extraction_passes: 1,
142            enable_multipass: false,
143            multipass_min_extractions: 1,
144            multipass_quality_threshold: 0.3,
145        }
146    }
147}
148
/// Main extraction function that mirrors the Python API
///
/// Extracts structured information from text using a language model based on
/// the provided examples and configuration.
///
/// # Arguments
///
/// * `text_or_documents` - The source text to extract information from, or a URL starting with http/https
/// * `prompt_description` - Instructions for what to extract from the text
/// * `examples` - Example data to guide the extraction
/// * `config` - Configuration parameters for the extraction
///
/// # Returns
///
/// An `AnnotatedDocument` with the extracted information
///
/// # Errors
///
/// Returns an error if:
/// * Examples are empty
/// * No API key is provided
/// * URL download fails
/// * Language model inference fails
pub async fn extract(
    text_or_documents: &str,
    prompt_description: Option<&str>,
    examples: &[ExampleData],
    config: ExtractConfig,
) -> LangExtractResult<AnnotatedDocument> {
    // Validate inputs: the few-shot examples drive the prompt, so an empty
    // slice cannot yield a usable extraction schema — fail fast.
    if examples.is_empty() {
        return Err(LangExtractError::InvalidInput(
            "Examples are required for reliable extraction. Please provide at least one ExampleData object with sample extractions.".to_string()
        ));
    }

    // Chunks are dispatched `batch_length` at a time, so at most that many
    // workers can ever be busy; this is a misconfiguration, not an error,
    // hence only a warning.
    if config.batch_length < config.max_workers {
        log::warn!(
            "batch_length ({}) < max_workers ({}). Only {} workers will be used. Set batch_length >= max_workers for optimal parallelization.",
            config.batch_length,
            config.max_workers,
            config.batch_length
        );
    }

    // Load environment variables (e.g. provider API keys from a .env file).
    // A missing .env is not an error — `.ok()` discards that case.
    dotenvy::dotenv().ok();

    // Handle URL input: http/https sources are downloaded first so the rest
    // of the pipeline always works on plain text.
    let text = if io::is_url(text_or_documents) {
        io::download_text_from_url(text_or_documents).await?
    } else {
        text_or_documents.to_string()
    };

    // Create prompt template seeded with the caller's instructions + examples
    let mut prompt_template = prompting::PromptTemplateStructured::new(prompt_description);
    prompt_template.examples.extend(examples.iter().cloned());

    // Create language model; the examples are forwarded so the factory can
    // inspect them (presumably for schema constraints — see factory module)
    let language_model = factory::create_model(&config, Some(&prompt_template.examples)).await?;

    // Create resolver; fencing behavior follows the provider's requirement
    let resolver = resolver::Resolver::new(&config, language_model.requires_fence_output())?;

    // Create annotator tying model, prompt, and output format together
    let annotator = annotation::Annotator::new(
        language_model,
        prompt_template,
        config.format_type,
        resolver.fence_output(),
    );

    // Perform annotation - use multi-pass if enabled. Note both flags must be
    // set: `enable_multipass` AND `extraction_passes > 1`.
    if config.enable_multipass && config.extraction_passes > 1 {
        // Use multi-pass extraction
        let multipass_config = multipass::MultiPassConfig {
            max_passes: config.extraction_passes,
            min_extractions_per_chunk: config.multipass_min_extractions,
            enable_targeted_reprocessing: true,
            enable_refinement_passes: true,
            quality_threshold: config.multipass_quality_threshold,
            max_reprocess_chunks: 10,
            temperature_decay: 0.9,
        };

        let processor = multipass::MultiPassProcessor::new(
            multipass_config,
            annotator,
            resolver,
        );

        // Per-pass statistics are discarded here; use the multipass module
        // directly if they are needed.
        let (result, _stats) = processor.extract_multipass(
            &text,
            config.additional_context.as_deref(),
            config.debug,
        ).await?;

        if config.debug {
            println!("🎯 Multi-pass extraction completed with {} total extractions", 
                result.extraction_count());
        }

        Ok(result)
    } else {
        // Use single-pass extraction
        annotator
            .annotate_text(
                &text,
                &resolver,
                config.max_char_buffer,
                config.batch_length,
                config.additional_context.as_deref(),
                config.debug,
                config.extraction_passes,
                config.max_workers,
            )
            .await
    }
}
269
/// Visualize function that mirrors the Python API
///
/// Renders an [`AnnotatedDocument`] as a human-readable string (see the
/// `visualization` module for the exact output format).
///
/// # Arguments
///
/// * `annotated_document` - The extraction result to render
/// * `show_char_intervals` - When `true`, include character-offset information
///
/// # Errors
///
/// Propagates any error from `visualization::visualize`.
pub fn visualize(
    annotated_document: &AnnotatedDocument,
    show_char_intervals: bool,
) -> LangExtractResult<String> {
    // Thin convenience wrapper; all logic lives in the visualization module.
    visualization::visualize(annotated_document, show_char_intervals)
}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config must match the documented stock settings.
    #[test]
    fn test_extract_config_default() {
        let cfg = ExtractConfig::default();
        assert_eq!(cfg.model_id, "gemini-2.5-flash");
        assert_eq!(cfg.format_type, FormatType::Json);
        assert_eq!(cfg.max_char_buffer, 1000);
        assert_eq!(cfg.temperature, 0.5);
    }

    /// `extract` must reject an empty example slice with `InvalidInput`.
    #[test]
    fn test_extraction_validation() {
        let no_examples: Vec<ExampleData> = Vec::new();

        tokio_test::block_on(async {
            let outcome = extract("test text", None, &no_examples, ExtractConfig::default()).await;
            let err = outcome.expect_err("empty examples must be rejected");
            match err {
                LangExtractError::InvalidInput(msg) => {
                    assert!(msg.contains("Examples are required"));
                }
                _ => panic!("Expected InvalidInput error"),
            }
        });
    }
}