pub mod config;
pub mod data;
pub mod exceptions;
pub mod schema;
pub mod alignment;
pub mod annotation;
pub mod chunking;
pub mod inference;
pub mod multipass;
pub mod tokenizer;
pub mod providers;
pub mod factory;
pub mod io;
pub mod logging;
pub mod pipeline;
pub mod progress;
pub mod prompting;
pub mod resolver;
pub mod templates;
pub mod visualization;
pub use config::{
LangExtractConfig, ProcessingConfig, ValidationConfig as NewValidationConfig,
ChunkingConfig, AlignmentConfig as NewAlignmentConfig, MultiPassConfig as NewMultiPassConfig,
VisualizationConfig, InferenceConfig as NewInferenceConfig, ProgressConfig,
ChunkingStrategy, ExportFormat as NewExportFormat
};
pub use data::{
AlignmentStatus, AnnotatedDocument, CharInterval, Document, ExampleData, Extraction,
FormatType,
};
pub use exceptions::{LangExtractError, LangExtractResult};
pub use inference::{BaseLanguageModel, ScoredOutput};
pub use logging::{ProgressHandler, ProgressEvent, ConsoleProgressHandler, SilentProgressHandler, LogProgressHandler};
pub use providers::{ProviderConfig, ProviderType, UniversalProvider};
pub use resolver::{ValidationConfig, ValidationResult, ValidationError, ValidationWarning, CoercionSummary, CoercionDetail, CoercionTargetType};
pub use visualization::{ExportFormat, ExportConfig, export_document};
pub use pipeline::{PipelineConfig, PipelineStep, PipelineResult, PipelineExecutor};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Flat configuration consumed by [`extract`].
///
/// Build it directly (typically with `..ExtractConfig::default()`), convert it
/// from a [`LangExtractConfig`] through [`extract_with_config`], or deserialize
/// it with serde. When deserializing, any field absent from the input falls
/// back to the value from [`ExtractConfig::default`], so partial configuration
/// documents are accepted.
///
/// Note: `api_key` is included by `Serialize` (only the `Debug` output redacts
/// it), so serialized configurations should be handled as secrets.
#[derive(Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct ExtractConfig {
    /// Model identifier; also recorded on the `extract` tracing span.
    /// Default: `"gpt-4o-mini"`.
    pub model_id: String,
    /// Provider API key. Redacted in `Debug` output.
    /// NOTE(review): presumably consumed by `factory::create_model` — confirm.
    pub api_key: Option<String>,
    /// Requested output format. Default: `FormatType::Json`.
    pub format_type: FormatType,
    /// Forwarded as `max_char_buffer` to the annotator / multi-pass processor
    /// (presumably the per-chunk character limit). Default: 1000.
    pub max_char_buffer: usize,
    /// Temperature handed to `annotation::Annotator::with_config`. Default: 0.5.
    pub temperature: f32,
    /// Optional fenced-output override; `None` by default.
    /// NOTE(review): not read in this file — presumably used by the resolver or factory.
    pub fence_output: Option<bool>,
    /// Default: `true`.
    /// NOTE(review): not read in this file — presumably used by the factory/resolver.
    pub use_schema_constraints: bool,
    /// Batch size forwarded to the annotator / multi-pass processor. `extract`
    /// logs a warning when this is smaller than `max_workers`. Default: 10.
    pub batch_length: usize,
    /// Worker count forwarded to the annotator / multi-pass processor. Default: 10.
    pub max_workers: usize,
    /// Extra context text forwarded with every annotation call. Default: `None`.
    pub additional_context: Option<String>,
    /// Free-form resolver options.
    /// NOTE(review): presumably read by `resolver::Resolver::new` — confirm.
    pub resolver_params: HashMap<String, serde_json::Value>,
    /// Free-form model options; `extract` reads `"max_output_tokens"` (as an
    /// unsigned integer) from here.
    pub language_model_params: HashMap<String, serde_json::Value>,
    /// When `true` and no `progress_handler` is set, `extract` installs a
    /// console progress handler; also forwarded to the annotation run.
    pub debug: bool,
    /// Optional model endpoint URL.
    /// NOTE(review): not read in this file — presumably used by the factory.
    pub model_url: Option<String>,
    /// Run multi-pass extraction instead of a single annotation pass. Default: `false`.
    pub enable_multipass: bool,
    /// Forwarded as `MultiPassConfig::max_passes`. Default: 2.
    pub multipass_max_passes: usize,
    /// Forwarded as `MultiPassConfig::min_extractions_per_chunk`. Default: 1.
    pub multipass_min_extractions: usize,
    /// Forwarded as `MultiPassConfig::quality_threshold`. Default: 0.3.
    pub multipass_quality_threshold: f32,
    /// Progress reporting sink; never serialized. When `None`, `extract`
    /// installs a console handler if `debug` is set, otherwise a silent one.
    #[serde(skip)]
    pub progress_handler: Option<std::sync::Arc<dyn ProgressHandler>>,
}
impl Default for ExtractConfig {
fn default() -> Self {
Self {
model_id: "gpt-4o-mini".to_string(),
api_key: None,
format_type: FormatType::Json,
max_char_buffer: 1000,
temperature: 0.5,
fence_output: None,
use_schema_constraints: true,
batch_length: 10,
max_workers: 10,
additional_context: None,
resolver_params: HashMap::new(),
language_model_params: HashMap::new(),
debug: false,
model_url: None,
enable_multipass: false,
multipass_max_passes: 2,
multipass_min_extractions: 1,
multipass_quality_threshold: 0.3,
progress_handler: None,
}
}
}
impl std::fmt::Debug for ExtractConfig {
    /// Formats the configuration for diagnostics without exposing opaque or
    /// secret values.
    ///
    /// `api_key` prints as `Some("[REDACTED]")` / `None`. `progress_handler` (a
    /// trait object without a `Debug` bound) prints as `Some("<ProgressHandler>")`
    /// / `None`, so the output reflects whether a handler is actually set.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ExtractConfig")
            .field("model_id", &self.model_id)
            .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
            .field("format_type", &self.format_type)
            .field("max_char_buffer", &self.max_char_buffer)
            .field("temperature", &self.temperature)
            .field("fence_output", &self.fence_output)
            .field("use_schema_constraints", &self.use_schema_constraints)
            .field("batch_length", &self.batch_length)
            .field("max_workers", &self.max_workers)
            .field("additional_context", &self.additional_context)
            .field("resolver_params", &self.resolver_params)
            .field("language_model_params", &self.language_model_params)
            .field("debug", &self.debug)
            .field("model_url", &self.model_url)
            .field("enable_multipass", &self.enable_multipass)
            .field("multipass_max_passes", &self.multipass_max_passes)
            .field("multipass_min_extractions", &self.multipass_min_extractions)
            .field("multipass_quality_threshold", &self.multipass_quality_threshold)
            // Previously always printed "<ProgressHandler>", even when unset.
            .field(
                "progress_handler",
                &self.progress_handler.as_ref().map(|_| "<ProgressHandler>"),
            )
            .finish()
    }
}
impl ExtractConfig {
    /// Returns the config with `handler` installed as its progress sink,
    /// replacing any previously configured handler.
    #[must_use]
    pub fn with_progress_handler(mut self, handler: std::sync::Arc<dyn ProgressHandler>) -> Self {
        self.progress_handler = Some(handler);
        self
    }

    /// Returns the config with a `ConsoleProgressHandler::new()` progress sink.
    #[must_use]
    pub fn with_console_progress(self) -> Self {
        self.with_progress_handler(std::sync::Arc::new(ConsoleProgressHandler::new()))
    }

    /// Returns the config with a `SilentProgressHandler` progress sink.
    #[must_use]
    pub fn with_quiet_mode(self) -> Self {
        self.with_progress_handler(std::sync::Arc::new(SilentProgressHandler))
    }

    /// Returns the config with a `ConsoleProgressHandler::verbose()` progress sink.
    #[must_use]
    pub fn with_verbose_progress(self) -> Self {
        self.with_progress_handler(std::sync::Arc::new(ConsoleProgressHandler::verbose()))
    }
}
/// Runs [`extract`] with a [`LangExtractConfig`] instead of an [`ExtractConfig`].
///
/// `config` is converted into the flat [`ExtractConfig`] via `Into` before the
/// call; arguments and errors are otherwise exactly those of [`extract`].
pub async fn extract_with_config(
    text_or_documents: &str,
    prompt_description: Option<&str>,
    examples: &[ExampleData],
    config: LangExtractConfig,
) -> LangExtractResult<AnnotatedDocument> {
    extract(
        text_or_documents,
        prompt_description,
        examples,
        config.into(),
    )
    .await
}
#[tracing::instrument(skip_all, fields(text_len = text_or_documents.len(), num_examples = examples.len(), model = %config.model_id, multipass = config.enable_multipass))]
pub async fn extract(
text_or_documents: &str,
prompt_description: Option<&str>,
examples: &[ExampleData],
config: ExtractConfig,
) -> LangExtractResult<AnnotatedDocument> {
if examples.is_empty() {
return Err(LangExtractError::InvalidInput(
"Examples are required for reliable extraction. Please provide at least one ExampleData object with sample extractions.".to_string()
));
}
if config.batch_length < config.max_workers {
log::warn!(
"batch_length ({}) < max_workers ({}). Only {} workers will be used. Set batch_length >= max_workers for optimal parallelization.",
config.batch_length,
config.max_workers,
config.batch_length
);
}
dotenvy::dotenv().ok();
if let Some(handler) = &config.progress_handler {
logging::init_progress_handler(handler.clone());
} else {
let default_handler: std::sync::Arc<dyn ProgressHandler> = if config.debug {
std::sync::Arc::new(ConsoleProgressHandler::new())
} else {
std::sync::Arc::new(SilentProgressHandler)
};
logging::init_progress_handler(default_handler);
}
let text = if io::is_url(text_or_documents) {
io::download_text_from_url(text_or_documents).await?
} else {
text_or_documents.to_string()
};
let mut prompt_template = prompting::PromptTemplateStructured::new(prompt_description);
prompt_template.examples.extend(examples.iter().cloned());
let language_model = factory::create_model(&config, Some(&prompt_template.examples)).await?;
let resolver = resolver::Resolver::new(&config, language_model.requires_fence_output())?;
let annotator = annotation::Annotator::with_config(
language_model,
prompt_template,
config.temperature,
config.language_model_params.get("max_output_tokens")
.and_then(|v| v.as_u64())
.map(|v| v as usize),
);
if config.enable_multipass {
let multipass_config = multipass::MultiPassConfig {
max_passes: config.multipass_max_passes,
min_extractions_per_chunk: config.multipass_min_extractions,
enable_targeted_reprocessing: true,
enable_refinement_passes: true,
quality_threshold: config.multipass_quality_threshold,
max_reprocess_chunks: 10,
temperature_decay: 0.9,
max_char_buffer: config.max_char_buffer,
batch_length: config.batch_length,
max_workers: config.max_workers,
};
let processor = multipass::MultiPassProcessor::new(
multipass_config,
annotator,
resolver,
);
let (result, _stats) = processor.extract_multipass(
&text,
config.additional_context.as_deref(),
config.debug,
).await?;
if config.debug {
log::info!("Multi-pass extraction completed with {} total extractions",
result.extraction_count());
}
Ok(result)
} else {
annotator
.annotate_text(
&text,
&resolver,
config.max_char_buffer,
config.batch_length,
config.additional_context.as_deref(),
config.debug,
config.max_workers,
)
.await
}
}
/// Renders `annotated_document` using `visualization::visualize`.
///
/// `show_char_intervals` is forwarded unchanged; any rendering error is
/// returned to the caller as-is.
pub fn visualize(
    annotated_document: &AnnotatedDocument,
    show_char_intervals: bool,
) -> LangExtractResult<String> {
    let rendered = visualization::visualize(annotated_document, show_char_intervals)?;
    Ok(rendered)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_config_default() {
        let config = ExtractConfig::default();
        assert_eq!(config.model_id, "gpt-4o-mini");
        assert_eq!(config.format_type, FormatType::Json);
        assert_eq!(config.max_char_buffer, 1000);
        assert_eq!(config.temperature, 0.5);
        assert!(config.api_key.is_none());
        assert!(!config.enable_multipass);
        assert!(config.progress_handler.is_none());
    }

    #[test]
    fn test_debug_redacts_api_key() {
        let config = ExtractConfig {
            api_key: Some("sk-super-secret".to_string()),
            ..ExtractConfig::default()
        };
        let rendered = format!("{:?}", config);
        assert!(!rendered.contains("sk-super-secret"));
        assert!(rendered.contains("[REDACTED]"));
    }

    #[test]
    fn test_progress_builders_install_handler() {
        assert!(ExtractConfig::default().with_quiet_mode().progress_handler.is_some());
        assert!(ExtractConfig::default().with_console_progress().progress_handler.is_some());
        assert!(ExtractConfig::default().with_verbose_progress().progress_handler.is_some());
    }

    #[test]
    fn test_extraction_validation() {
        let examples: Vec<ExampleData> = vec![];
        let config = ExtractConfig::default();
        let result = tokio_test::block_on(extract("test text", None, &examples, config));
        match result {
            Err(LangExtractError::InvalidInput(msg)) => {
                assert!(msg.contains("Examples are required"));
            }
            Err(_) => panic!("Expected InvalidInput error"),
            Ok(_) => panic!("Expected extraction without examples to fail"),
        }
    }
}