use anyhow::{Result, bail};
use kreuzberg::{
ChunkingConfig, ExecutionProviderType, ExtractionConfig, LanguageDetectionConfig, LlmConfig, OcrConfig,
};
use crate::ContentOutputFormatArg;
#[derive(Clone, Copy, Debug, PartialEq, Eq, clap::ValueEnum)]
pub enum AccelerationArg {
Auto,
Cpu,
#[value(name = "coreml")]
CoreMl,
Cuda,
#[value(name = "tensorrt")]
TensorRt,
}
impl From<AccelerationArg> for ExecutionProviderType {
fn from(arg: AccelerationArg) -> Self {
match arg {
AccelerationArg::Auto => ExecutionProviderType::Auto,
AccelerationArg::Cpu => ExecutionProviderType::Cpu,
AccelerationArg::CoreMl => ExecutionProviderType::CoreMl,
AccelerationArg::Cuda => ExecutionProviderType::Cuda,
AccelerationArg::TensorRt => ExecutionProviderType::TensorRt,
}
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, clap::ValueEnum)]
pub enum ReductionLevelArg {
Off,
Light,
Moderate,
Aggressive,
Maximum,
}
impl ReductionLevelArg {
fn as_mode_str(self) -> &'static str {
match self {
Self::Off => "off",
Self::Light => "light",
Self::Moderate => "moderate",
Self::Aggressive => "aggressive",
Self::Maximum => "maximum",
}
}
}
#[derive(Debug, Default, clap::Args)]
pub struct ExtractionOverrides {
#[arg(long)]
pub ocr: Option<bool>,
#[arg(long)]
pub ocr_backend: Option<String>,
#[arg(long)]
pub ocr_language: Option<String>,
#[arg(long)]
pub force_ocr: Option<bool>,
#[arg(long)]
pub disable_ocr: Option<bool>,
#[arg(long)]
pub no_cache: Option<bool>,
#[arg(long)]
pub ocr_auto_rotate: Option<bool>,
#[arg(long)]
pub vlm_model: Option<String>,
#[arg(long)]
pub vlm_api_key: Option<String>,
#[arg(long)]
pub vlm_prompt: Option<String>,
#[arg(long)]
pub chunk: Option<bool>,
#[arg(long)]
pub chunk_size: Option<usize>,
#[arg(long)]
pub chunk_overlap: Option<usize>,
#[arg(long)]
pub chunking_tokenizer: Option<String>,
#[arg(long, value_enum)]
pub content_format: Option<ContentOutputFormatArg>,
#[arg(long, value_enum, hide = true)]
pub output_format: Option<ContentOutputFormatArg>,
#[arg(long)]
pub include_structure: Option<bool>,
#[arg(long)]
pub quality: Option<bool>,
#[arg(long)]
pub detect_language: Option<bool>,
#[cfg(feature = "layout-detection")]
#[arg(long, default_missing_value = "true", num_args = 0..=1)]
pub layout: Option<bool>,
#[cfg(feature = "layout-detection")]
#[arg(long)]
pub layout_confidence: Option<f32>,
#[cfg(feature = "layout-detection")]
#[arg(
long,
help = "Table structure model: tatr (default), slanet_wired, slanet_wireless, slanet_plus, slanet_auto, disabled"
)]
pub layout_table_model: Option<String>,
#[arg(long, value_enum)]
pub acceleration: Option<AccelerationArg>,
#[arg(long, help = "Limit parallel extractions in batch mode")]
pub max_concurrent: Option<usize>,
#[arg(long, help = "Limit total threads for constrained environments")]
pub max_threads: Option<usize>,
#[arg(long)]
pub extract_pages: Option<bool>,
#[arg(long)]
pub page_markers: Option<bool>,
#[arg(long)]
pub extract_images: Option<bool>,
#[arg(long)]
pub target_dpi: Option<i32>,
#[arg(long)]
pub pdf_password: Vec<String>,
#[cfg(any(feature = "bundled-pdfium", feature = "static-pdfium"))]
#[arg(long)]
pub pdf_extract_images: Option<bool>,
#[cfg(any(feature = "bundled-pdfium", feature = "static-pdfium"))]
#[arg(long)]
pub pdf_extract_metadata: Option<bool>,
#[arg(long, value_enum)]
pub token_reduction: Option<ReductionLevelArg>,
#[arg(long)]
pub msg_codepage: Option<u32>,
#[arg(long)]
pub cache_namespace: Option<String>,
#[arg(long)]
pub cache_ttl_secs: Option<u64>,
#[cfg(feature = "html")]
#[arg(long, value_name = "THEME")]
pub html_theme: Option<String>,
#[cfg(feature = "html")]
#[arg(long, value_name = "CSS")]
pub html_css: Option<String>,
#[cfg(feature = "html")]
#[arg(long, value_name = "PATH")]
pub html_css_file: Option<std::path::PathBuf>,
#[cfg(feature = "html")]
#[arg(long, value_name = "PREFIX")]
pub html_class_prefix: Option<String>,
#[cfg(feature = "html")]
#[arg(long)]
pub html_no_embed_css: bool,
}
impl ExtractionOverrides {
pub fn validate(&self) -> Result<()> {
if let Some(size) = self.chunk_size {
if size == 0 {
bail!("Invalid chunk size: {size}. Chunk size must be greater than 0.");
}
if size > 1_000_000 {
bail!(
"Invalid chunk size: {size}. Chunk size must be less than 1,000,000 characters to avoid excessive memory usage."
);
}
}
if let Some(overlap) = self.chunk_overlap
&& let Some(size) = self.chunk_size
&& overlap >= size
{
bail!("Invalid chunk overlap: {overlap}. Overlap ({overlap}) must be less than chunk size ({size}).");
}
if let Some(dpi) = self.target_dpi
&& (!(36..=2400).contains(&dpi))
{
bail!("Invalid target DPI: {dpi}. Value must be between 36 and 2400.");
}
#[cfg(feature = "layout-detection")]
{
if let Some(conf) = self.layout_confidence
&& !(0.0..=1.0).contains(&conf)
{
bail!("Invalid layout confidence: {conf}. Value must be between 0.0 and 1.0.");
}
if self.layout == Some(false) && (self.layout_confidence.is_some() || self.layout_table_model.is_some()) {
bail!("--layout false cannot be combined with --layout-confidence or --layout-table-model");
}
}
#[cfg(not(feature = "chunking-tokenizers"))]
if self.chunking_tokenizer.is_some() {
bail!(
"--chunking-tokenizer requires the chunking-tokenizers feature. \
Rebuild with --features chunking-tokenizers"
);
}
if self.force_ocr == Some(true) && self.disable_ocr == Some(true) {
bail!("--force-ocr and --disable-ocr cannot both be true");
}
if let Some(ref backend) = self.ocr_backend
&& !["tesseract", "paddle-ocr", "easyocr", "vlm"].contains(&backend.as_str())
{
bail!(
"Invalid OCR backend '{}'. Valid backends: tesseract, paddle-ocr, easyocr, vlm",
backend
);
}
if self.vlm_api_key.is_some() && self.vlm_model.is_none() {
bail!("--vlm-api-key requires --vlm-model to be specified");
}
if self.vlm_prompt.is_some() && self.vlm_model.is_none() {
bail!("--vlm-prompt requires --vlm-model to be specified");
}
if let Some(0) = self.max_concurrent {
bail!("--max-concurrent must be at least 1");
}
if let Some(0) = self.max_threads {
bail!("--max-threads must be at least 1");
}
Ok(())
}
pub fn apply(self, config: &mut ExtractionConfig) {
self.apply_ocr(config);
self.apply_vlm_ocr(config);
self.apply_chunking(config);
self.apply_quality_and_detection(config);
self.apply_output_format(config);
self.apply_include_structure(config);
self.apply_layout(config);
self.apply_acceleration(config);
self.apply_concurrency(config);
self.apply_pages(config);
self.apply_images(config);
self.apply_pdf(config);
self.apply_token_reduction(config);
self.apply_email(config);
self.apply_cache(config);
self.apply_html_styled(config);
}
fn apply_ocr(&self, config: &mut ExtractionConfig) {
if let Some(ocr_flag) = self.ocr {
if ocr_flag {
let backend = match self.ocr_backend.as_deref() {
Some("paddle-ocr") => "paddle-ocr",
Some("easyocr") => "easyocr",
_ => "tesseract",
};
let language = match &self.ocr_language {
Some(lang) => lang.clone(),
None => match backend {
"paddle-ocr" | "easyocr" => "en".to_string(),
_ => "eng".to_string(),
},
};
let existing_paddle_config = config.ocr.as_ref().and_then(|o| o.paddle_ocr_config.clone());
let existing_element_config = config.ocr.as_ref().and_then(|o| o.element_config.clone());
let auto_rotate = self.ocr_auto_rotate.unwrap_or(false);
config.ocr = Some(OcrConfig {
enabled: true,
backend: backend.to_string(),
language,
tesseract_config: None,
output_format: None,
paddle_ocr_config: existing_paddle_config,
element_config: existing_element_config,
quality_thresholds: None,
pipeline: None,
auto_rotate,
vlm_config: None,
vlm_prompt: None,
});
} else {
config.ocr = None;
}
}
if self.ocr.is_none()
&& let Some(ref lang) = self.ocr_language
&& let Some(ref mut existing_ocr) = config.ocr
{
existing_ocr.language = lang.clone();
}
if self.ocr.is_none()
&& let Some(rotate) = self.ocr_auto_rotate
&& let Some(ref mut existing_ocr) = config.ocr
{
existing_ocr.auto_rotate = rotate;
}
if let Some(force_ocr_flag) = self.force_ocr {
config.force_ocr = force_ocr_flag;
}
if let Some(disable_ocr_flag) = self.disable_ocr {
config.disable_ocr = disable_ocr_flag;
}
if let Some(no_cache_flag) = self.no_cache {
config.use_cache = !no_cache_flag;
}
}
fn apply_vlm_ocr(&self, config: &mut ExtractionConfig) {
if let Some(ref vlm_model) = self.vlm_model {
let vlm_llm_config = LlmConfig {
model: vlm_model.clone(),
api_key: self.vlm_api_key.clone(),
base_url: None,
timeout_secs: None,
max_retries: None,
temperature: None,
max_tokens: None,
};
let ocr = config.ocr.get_or_insert_with(|| OcrConfig {
enabled: true,
backend: "vlm".to_string(),
language: "eng".to_string(),
tesseract_config: None,
output_format: None,
paddle_ocr_config: None,
element_config: None,
quality_thresholds: None,
pipeline: None,
auto_rotate: false,
vlm_config: None,
vlm_prompt: None,
});
ocr.backend = "vlm".to_string();
ocr.vlm_config = Some(vlm_llm_config);
if let Some(ref prompt) = self.vlm_prompt {
ocr.vlm_prompt = Some(prompt.clone());
}
}
}
fn apply_chunking(&self, config: &mut ExtractionConfig) {
let chunk = if self.chunking_tokenizer.is_some() && self.chunk.is_none() {
Some(true)
} else {
self.chunk
};
if let Some(chunk_flag) = chunk {
if chunk_flag {
let max_characters = self.chunk_size.unwrap_or(1000);
let overlap = self.chunk_overlap.unwrap_or(200);
let mut chunking_config = ChunkingConfig {
max_characters,
overlap,
trim: true,
chunker_type: kreuzberg::ChunkerType::Text,
..Default::default()
};
#[cfg(feature = "chunking-tokenizers")]
if let Some(ref model) = self.chunking_tokenizer {
chunking_config.sizing = kreuzberg::ChunkSizing::Tokenizer {
model: model.clone(),
cache_dir: None,
};
}
config.chunking = Some(chunking_config);
} else {
config.chunking = None;
}
} else if let Some(ref mut chunking) = config.chunking {
if let Some(max_characters) = self.chunk_size {
chunking.max_characters = max_characters;
}
if let Some(overlap) = self.chunk_overlap {
chunking.overlap = overlap;
}
if chunking.overlap >= chunking.max_characters {
chunking.overlap = chunking.max_characters / 4;
}
#[cfg(feature = "chunking-tokenizers")]
if let Some(ref model) = self.chunking_tokenizer {
chunking.sizing = kreuzberg::ChunkSizing::Tokenizer {
model: model.clone(),
cache_dir: None,
};
}
}
}
fn apply_quality_and_detection(&self, config: &mut ExtractionConfig) {
if let Some(quality_flag) = self.quality {
config.enable_quality_processing = quality_flag;
}
if let Some(detect_language_flag) = self.detect_language {
if detect_language_flag {
config.language_detection = Some(LanguageDetectionConfig {
enabled: true,
min_confidence: 0.8,
detect_multiple: false,
});
} else {
config.language_detection = None;
}
}
}
fn apply_output_format(&self, config: &mut ExtractionConfig) {
let final_format = self.content_format.or_else(|| {
if self.output_format.is_some() {
eprintln!("warning: '--output-format' is deprecated, use '--content-format' instead");
}
self.output_format
});
if let Some(content_fmt) = final_format {
config.output_format = content_fmt.into();
}
}
fn apply_include_structure(&self, config: &mut ExtractionConfig) {
if let Some(flag) = self.include_structure {
config.include_document_structure = flag;
}
}
#[allow(unused_variables)]
fn apply_layout(&self, config: &mut ExtractionConfig) {
#[cfg(feature = "layout-detection")]
{
if self.layout == Some(false) {
config.layout = None;
return;
}
let has_layout_flag =
self.layout == Some(true) || self.layout_confidence.is_some() || self.layout_table_model.is_some();
if has_layout_flag {
let mut layout = config.layout.clone().unwrap_or_default();
if let Some(confidence) = self.layout_confidence {
layout.confidence_threshold = Some(confidence);
}
if let Some(ref table_model) = self.layout_table_model {
layout.table_model = table_model.parse().unwrap_or_default();
}
config.layout = Some(layout);
}
}
}
fn apply_acceleration(&self, config: &mut ExtractionConfig) {
if let Some(accel) = self.acceleration {
let mut accel_config = config.acceleration.clone().unwrap_or_default();
accel_config.provider = accel.into();
config.acceleration = Some(accel_config);
}
}
fn apply_concurrency(&self, config: &mut ExtractionConfig) {
if let Some(max_concurrent) = self.max_concurrent {
config.max_concurrent_extractions = Some(max_concurrent);
}
if let Some(max_threads) = self.max_threads {
let concurrency = config.concurrency.get_or_insert_with(Default::default);
concurrency.max_threads = Some(max_threads);
}
}
fn apply_pages(&self, config: &mut ExtractionConfig) {
let has_page_flag = self.extract_pages.is_some() || self.page_markers.is_some();
if has_page_flag {
let mut page_config = config.pages.clone().unwrap_or_default();
if let Some(extract) = self.extract_pages {
page_config.extract_pages = extract;
}
if let Some(markers) = self.page_markers {
page_config.insert_page_markers = markers;
}
config.pages = Some(page_config);
}
}
fn apply_images(&self, config: &mut ExtractionConfig) {
let has_image_flag = self.extract_images.is_some() || self.target_dpi.is_some();
if has_image_flag {
let mut img = config.images.clone().unwrap_or_default();
if let Some(extract) = self.extract_images {
img.extract_images = extract;
}
if let Some(dpi) = self.target_dpi {
img.target_dpi = dpi;
}
config.images = Some(img);
}
}
#[allow(unused_variables)]
fn apply_pdf(&self, config: &mut ExtractionConfig) {
#[cfg(any(feature = "bundled-pdfium", feature = "static-pdfium"))]
{
let has_pdf_flag = self.pdf_extract_images.is_some()
|| self.pdf_extract_metadata.is_some()
|| !self.pdf_password.is_empty();
if has_pdf_flag {
let pdf_opts = config.pdf_options.get_or_insert_with(Default::default);
if let Some(extract_img) = self.pdf_extract_images {
pdf_opts.extract_images = extract_img;
}
if let Some(extract_meta) = self.pdf_extract_metadata {
pdf_opts.extract_metadata = extract_meta;
}
if !self.pdf_password.is_empty() {
pdf_opts.passwords = Some(self.pdf_password.clone());
}
}
}
#[cfg(not(any(feature = "bundled-pdfium", feature = "static-pdfium")))]
if !self.pdf_password.is_empty() {
let pdf_opts = config.pdf_options.get_or_insert_with(Default::default);
pdf_opts.passwords = Some(self.pdf_password.clone());
}
}
fn apply_token_reduction(&self, config: &mut ExtractionConfig) {
if let Some(level) = self.token_reduction {
config.token_reduction = Some(kreuzberg::TokenReductionConfig {
mode: level.as_mode_str().to_string(),
preserve_important_words: true,
});
}
}
fn apply_email(&self, config: &mut ExtractionConfig) {
if let Some(codepage) = self.msg_codepage {
let email = config.email.get_or_insert_with(Default::default);
email.msg_fallback_codepage = Some(codepage);
}
}
fn apply_cache(&self, config: &mut ExtractionConfig) {
if let Some(ns) = &self.cache_namespace {
config.cache_namespace = Some(ns.clone());
}
if let Some(ttl) = self.cache_ttl_secs {
config.cache_ttl_secs = Some(ttl);
}
}
#[allow(unused_variables)]
fn apply_html_styled(&self, config: &mut ExtractionConfig) {
#[cfg(feature = "html")]
{
let has_flag = self.html_theme.is_some()
|| self.html_css.is_some()
|| self.html_css_file.is_some()
|| self.html_class_prefix.is_some()
|| self.html_no_embed_css;
if has_flag {
config.output_format = kreuzberg::OutputFormat::Html;
let mut html_cfg = config.html_output.clone().unwrap_or_default();
if let Some(ref theme_str) = self.html_theme {
html_cfg.theme = match theme_str.to_lowercase().as_str() {
"github" => kreuzberg::HtmlTheme::GitHub,
"dark" => kreuzberg::HtmlTheme::Dark,
"light" => kreuzberg::HtmlTheme::Light,
"unstyled" => kreuzberg::HtmlTheme::Unstyled,
_ => kreuzberg::HtmlTheme::Default,
};
}
if let Some(ref css) = self.html_css {
html_cfg.css = Some(css.clone());
}
if let Some(ref path) = self.html_css_file {
html_cfg.css_file = Some(path.clone());
}
if let Some(ref prefix) = self.html_class_prefix {
html_cfg.class_prefix = prefix.clone();
}
if self.html_no_embed_css {
html_cfg.embed_css = false;
}
config.html_output = Some(html_cfg);
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use kreuzberg::ExtractionConfig;
fn default_overrides() -> ExtractionOverrides {
ExtractionOverrides::default()
}
#[test]
fn test_ocr_default_language_tesseract() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
ocr: Some(true),
..default_overrides()
};
overrides.apply(&mut config);
let ocr = config.ocr.unwrap();
assert_eq!(ocr.backend, "tesseract");
assert_eq!(ocr.language, "eng");
}
#[test]
fn test_ocr_default_language_paddleocr() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
ocr: Some(true),
ocr_backend: Some("paddle-ocr".to_string()),
..default_overrides()
};
overrides.apply(&mut config);
let ocr = config.ocr.unwrap();
assert_eq!(ocr.backend, "paddle-ocr");
assert_eq!(ocr.language, "en");
}
#[test]
fn test_ocr_default_language_easyocr() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
ocr: Some(true),
ocr_backend: Some("easyocr".to_string()),
..default_overrides()
};
overrides.apply(&mut config);
let ocr = config.ocr.unwrap();
assert_eq!(ocr.backend, "easyocr");
assert_eq!(ocr.language, "en");
}
#[test]
fn test_ocr_language_override_tesseract() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
ocr: Some(true),
ocr_language: Some("fra".to_string()),
..default_overrides()
};
overrides.apply(&mut config);
let ocr = config.ocr.unwrap();
assert_eq!(ocr.backend, "tesseract");
assert_eq!(ocr.language, "fra");
}
#[test]
fn test_ocr_language_override_paddleocr() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
ocr: Some(true),
ocr_backend: Some("paddle-ocr".to_string()),
ocr_language: Some("ch".to_string()),
..default_overrides()
};
overrides.apply(&mut config);
let ocr = config.ocr.unwrap();
assert_eq!(ocr.backend, "paddle-ocr");
assert_eq!(ocr.language, "ch");
}
#[test]
fn test_ocr_language_without_ocr_flag_no_existing_config() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
ocr_language: Some("deu".to_string()),
..default_overrides()
};
overrides.apply(&mut config);
assert!(config.ocr.is_none());
}
#[test]
fn test_ocr_language_without_ocr_flag_existing_config() {
let mut config = ExtractionConfig {
ocr: Some(OcrConfig {
enabled: true,
backend: "tesseract".to_string(),
language: "eng".to_string(),
tesseract_config: None,
output_format: None,
paddle_ocr_config: None,
element_config: None,
quality_thresholds: None,
pipeline: None,
auto_rotate: false,
vlm_config: None,
vlm_prompt: None,
}),
..Default::default()
};
let overrides = ExtractionOverrides {
ocr_language: Some("deu".to_string()),
..default_overrides()
};
overrides.apply(&mut config);
let ocr = config.ocr.unwrap();
assert_eq!(ocr.backend, "tesseract");
assert_eq!(ocr.language, "deu");
}
#[test]
fn test_ocr_disabled_ignores_language() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
ocr: Some(false),
ocr_language: Some("fra".to_string()),
..default_overrides()
};
overrides.apply(&mut config);
assert!(config.ocr.is_none());
}
#[test]
fn test_chunking_enabled_defaults() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
chunk: Some(true),
..default_overrides()
};
overrides.apply(&mut config);
let chunking = config.chunking.unwrap();
assert_eq!(chunking.max_characters, 1000);
assert_eq!(chunking.overlap, 200);
}
#[test]
fn test_chunking_custom_size() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
chunk: Some(true),
chunk_size: Some(500),
chunk_overlap: Some(50),
..default_overrides()
};
overrides.apply(&mut config);
let chunking = config.chunking.unwrap();
assert_eq!(chunking.max_characters, 500);
assert_eq!(chunking.overlap, 50);
}
#[test]
fn test_chunking_disabled() {
let mut config = ExtractionConfig {
chunking: Some(ChunkingConfig::default()),
..Default::default()
};
let overrides = ExtractionOverrides {
chunk: Some(false),
..default_overrides()
};
overrides.apply(&mut config);
assert!(config.chunking.is_none());
}
#[test]
fn test_validate_chunk_size_zero() {
let overrides = ExtractionOverrides {
chunk_size: Some(0),
..default_overrides()
};
assert!(overrides.validate().is_err());
}
#[test]
fn test_validate_chunk_size_too_large() {
let overrides = ExtractionOverrides {
chunk_size: Some(2_000_000),
..default_overrides()
};
assert!(overrides.validate().is_err());
}
#[test]
fn test_validate_overlap_exceeds_size() {
let overrides = ExtractionOverrides {
chunk_size: Some(100),
chunk_overlap: Some(200),
..default_overrides()
};
assert!(overrides.validate().is_err());
}
#[test]
fn test_validate_target_dpi_out_of_range() {
let overrides = ExtractionOverrides {
target_dpi: Some(5),
..default_overrides()
};
assert!(overrides.validate().is_err());
let overrides = ExtractionOverrides {
target_dpi: Some(5000),
..default_overrides()
};
assert!(overrides.validate().is_err());
}
#[test]
fn test_validate_target_dpi_valid() {
let overrides = ExtractionOverrides {
target_dpi: Some(300),
..default_overrides()
};
assert!(overrides.validate().is_ok());
}
#[cfg(feature = "layout-detection")]
#[test]
fn test_validate_layout_confidence_out_of_range() {
let overrides = ExtractionOverrides {
layout_confidence: Some(1.5),
..default_overrides()
};
assert!(overrides.validate().is_err());
let overrides = ExtractionOverrides {
layout_confidence: Some(-0.1),
..default_overrides()
};
assert!(overrides.validate().is_err());
}
#[cfg(feature = "layout-detection")]
#[test]
fn test_validate_layout_confidence_valid() {
let overrides = ExtractionOverrides {
layout_confidence: Some(0.5),
..default_overrides()
};
assert!(overrides.validate().is_ok());
}
#[cfg(feature = "layout-detection")]
#[test]
fn test_layout_table_model_applied() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
layout_table_model: Some("slanet_wired".to_string()),
..default_overrides()
};
overrides.apply(&mut config);
let layout = config.layout.unwrap();
assert_eq!(layout.table_model, kreuzberg::TableModel::SlanetWired);
}
#[cfg(feature = "layout-detection")]
#[test]
fn test_layout_confidence_applied() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
layout_confidence: Some(0.7),
..default_overrides()
};
overrides.apply(&mut config);
let layout = config.layout.unwrap();
assert_eq!(layout.confidence_threshold, Some(0.7));
}
#[test]
fn test_acceleration_applied() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
acceleration: Some(AccelerationArg::Cpu),
..default_overrides()
};
overrides.apply(&mut config);
let accel = config.acceleration.unwrap();
assert_eq!(accel.provider, ExecutionProviderType::Cpu);
}
#[test]
fn test_extract_pages_applied() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
extract_pages: Some(true),
page_markers: Some(true),
..default_overrides()
};
overrides.apply(&mut config);
let pages = config.pages.unwrap();
assert!(pages.extract_pages);
assert!(pages.insert_page_markers);
}
#[test]
fn test_extract_images_applied() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
extract_images: Some(true),
target_dpi: Some(150),
..default_overrides()
};
overrides.apply(&mut config);
let images = config.images.unwrap();
assert!(images.extract_images);
assert_eq!(images.target_dpi, 150);
}
#[test]
fn test_token_reduction_applied() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
token_reduction: Some(ReductionLevelArg::Aggressive),
..default_overrides()
};
overrides.apply(&mut config);
let reduction = config.token_reduction.unwrap();
assert_eq!(reduction.mode, "aggressive");
}
#[test]
fn test_msg_codepage_applied() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
msg_codepage: Some(1251),
..default_overrides()
};
overrides.apply(&mut config);
let email = config.email.unwrap();
assert_eq!(email.msg_fallback_codepage, Some(1251));
}
#[test]
fn test_max_concurrent_applied() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
max_concurrent: Some(4),
..default_overrides()
};
overrides.apply(&mut config);
assert_eq!(config.max_concurrent_extractions, Some(4));
}
#[test]
fn test_max_threads_applied() {
let mut config = ExtractionConfig::default();
let overrides = ExtractionOverrides {
max_threads: Some(2),
..default_overrides()
};
overrides.apply(&mut config);
let concurrency = config.concurrency.unwrap();
assert_eq!(concurrency.max_threads, Some(2));
}
#[test]
fn test_include_structure_applied() {
let mut config = ExtractionConfig::default();
assert!(!config.include_document_structure);
let overrides = ExtractionOverrides {
include_structure: Some(true),
..default_overrides()
};
overrides.apply(&mut config);
assert!(config.include_document_structure);
}
#[test]
fn test_validate_invalid_ocr_backend() {
let overrides = ExtractionOverrides {
ocr_backend: Some("invalid-backend".to_string()),
..default_overrides()
};
let err = overrides.validate().unwrap_err();
assert!(err.to_string().contains("Invalid OCR backend"));
}
#[test]
fn test_validate_max_concurrent_zero() {
let overrides = ExtractionOverrides {
max_concurrent: Some(0),
..default_overrides()
};
let err = overrides.validate().unwrap_err();
assert!(err.to_string().contains("--max-concurrent must be at least 1"));
}
#[test]
fn test_validate_max_threads_zero() {
let overrides = ExtractionOverrides {
max_threads: Some(0),
..default_overrides()
};
let err = overrides.validate().unwrap_err();
assert!(err.to_string().contains("--max-threads must be at least 1"));
}
#[test]
fn test_validate_valid_ocr_backends() {
for backend in &["tesseract", "paddle-ocr", "easyocr"] {
let overrides = ExtractionOverrides {
ocr_backend: Some(backend.to_string()),
..default_overrides()
};
assert!(overrides.validate().is_ok(), "Expected backend '{backend}' to be valid");
}
}
#[test]
fn test_chunk_overlap_clamped_on_existing_config() {
let mut config = ExtractionConfig {
chunking: Some(ChunkingConfig {
max_characters: 800,
overlap: 100,
..Default::default()
}),
..Default::default()
};
let overrides = ExtractionOverrides {
chunk_overlap: Some(1500),
..default_overrides()
};
overrides.apply(&mut config);
let chunking = config.chunking.unwrap();
assert_eq!(chunking.overlap, 800 / 4);
assert_eq!(chunking.max_characters, 800);
}
#[test]
fn test_chunk_overlap_valid_on_existing_config() {
let mut config = ExtractionConfig {
chunking: Some(ChunkingConfig {
max_characters: 800,
overlap: 100,
..Default::default()
}),
..Default::default()
};
let overrides = ExtractionOverrides {
chunk_overlap: Some(200),
..default_overrides()
};
overrides.apply(&mut config);
let chunking = config.chunking.unwrap();
assert_eq!(chunking.overlap, 200);
assert_eq!(chunking.max_characters, 800);
}
#[cfg(not(feature = "chunking-tokenizers"))]
#[test]
fn test_validate_chunking_tokenizer_requires_feature() {
let overrides = ExtractionOverrides {
chunking_tokenizer: Some("Xenova/gpt-4o".to_string()),
..default_overrides()
};
let err = overrides.validate().unwrap_err();
assert!(
err.to_string()
.contains("--chunking-tokenizer requires the chunking-tokenizers feature")
);
}
#[test]
fn test_no_overrides_leaves_config_unchanged() {
let original = ExtractionConfig::default();
let mut config = original.clone();
let overrides = default_overrides();
overrides.apply(&mut config);
assert!(config.ocr.is_none());
assert!(config.chunking.is_none());
assert!(config.use_cache);
assert!(config.enable_quality_processing);
assert!(!config.force_ocr);
assert!(config.language_detection.is_none());
assert!(config.pages.is_none());
assert!(config.images.is_none());
assert!(config.token_reduction.is_none());
assert!(config.email.is_none());
assert!(config.acceleration.is_none());
assert!(config.concurrency.is_none());
assert!(!config.include_document_structure);
}
}