use serde::{Deserialize, Serialize};
use super::formats::OutputFormat;
use crate::core::config_validation::validate_ocr_backend;
use crate::error::KreuzbergError;
use crate::types::OcrElementConfig;
/// Tunable thresholds used when judging the quality of extracted / OCR'd text.
///
/// Every field carries a serde default (the `default_*` functions in this
/// module), so a partial JSON object deserializes with all missing fields
/// filled in. The code that consumes these thresholds lives outside this file;
/// NOTE(review): the per-field notes below only restate names and types —
/// confirm exact semantics against the quality-scoring code.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OcrQualityThresholds {
    /// Minimum total count of non-whitespace characters.
    #[serde(default = "default_min_total_non_whitespace")]
    pub min_total_non_whitespace: usize,
    /// Minimum non-whitespace characters per page (fractional).
    #[serde(default = "default_min_non_whitespace_per_page")]
    pub min_non_whitespace_per_page: f64,
    /// Word length at or above which a word presumably counts as "meaningful".
    #[serde(default = "default_min_meaningful_word_len")]
    pub min_meaningful_word_len: usize,
    /// Minimum number of meaningful words.
    #[serde(default = "default_min_meaningful_words")]
    pub min_meaningful_words: usize,
    /// Minimum alphanumeric-character ratio.
    #[serde(default = "default_min_alnum_ratio")]
    pub min_alnum_ratio: f64,
    /// Garbage-character count threshold.
    #[serde(default = "default_min_garbage_chars")]
    pub min_garbage_chars: usize,
    /// Fragmented-word ratio threshold.
    #[serde(default = "default_max_fragmented_word_ratio")]
    pub max_fragmented_word_ratio: f64,
    /// Higher fragmented-word ratio threshold, presumably treated as critical.
    #[serde(default = "default_critical_fragmented_word_ratio")]
    pub critical_fragmented_word_ratio: f64,
    /// Minimum average word length.
    #[serde(default = "default_min_avg_word_length")]
    pub min_avg_word_length: f64,
    /// Word count presumably required before the average-length check applies.
    #[serde(default = "default_min_words_for_avg_length_check")]
    pub min_words_for_avg_length_check: usize,
    /// Consecutive-repeat ratio threshold.
    #[serde(default = "default_min_consecutive_repeat_ratio")]
    pub min_consecutive_repeat_ratio: f64,
    /// Word count presumably required before the repeat check applies.
    #[serde(default = "default_min_words_for_repeat_check")]
    pub min_words_for_repeat_check: usize,
    /// Character count for "substantive" text.
    #[serde(default = "default_substantive_min_chars")]
    pub substantive_min_chars: usize,
    /// Character count for non-text content.
    #[serde(default = "default_non_text_min_chars")]
    pub non_text_min_chars: usize,
    /// Alphanumeric-plus-whitespace ratio threshold.
    #[serde(default = "default_alnum_ws_ratio_threshold")]
    pub alnum_ws_ratio_threshold: f64,
    /// Minimum quality score used by the OCR pipeline.
    #[serde(default = "default_pipeline_min_quality")]
    pub pipeline_min_quality: f64,
}
impl Default for OcrQualityThresholds {
fn default() -> Self {
Self {
min_total_non_whitespace: 64,
min_non_whitespace_per_page: 32.0,
min_meaningful_word_len: 4,
min_meaningful_words: 3,
min_alnum_ratio: 0.3,
min_garbage_chars: 5,
max_fragmented_word_ratio: 0.6,
critical_fragmented_word_ratio: 0.80,
min_avg_word_length: 2.0,
min_words_for_avg_length_check: 50,
min_consecutive_repeat_ratio: 0.08,
min_words_for_repeat_check: 50,
substantive_min_chars: 100,
non_text_min_chars: 20,
alnum_ws_ratio_threshold: 0.4,
pipeline_min_quality: 0.5,
}
}
}
// Serde `default = "..."` providers for `OcrQualityThresholds`. Each returns
// the value used when the matching field is absent from the input.

// ---- count thresholds (usize) ----

/// Default for `min_total_non_whitespace`.
fn default_min_total_non_whitespace() -> usize {
    64_usize
}

/// Default for `min_meaningful_word_len`.
fn default_min_meaningful_word_len() -> usize {
    4_usize
}

/// Default for `min_meaningful_words`.
fn default_min_meaningful_words() -> usize {
    3_usize
}

/// Default for `min_garbage_chars`.
fn default_min_garbage_chars() -> usize {
    5_usize
}

/// Default for `min_words_for_avg_length_check`.
fn default_min_words_for_avg_length_check() -> usize {
    50_usize
}

/// Default for `min_words_for_repeat_check`.
fn default_min_words_for_repeat_check() -> usize {
    50_usize
}

/// Default for `substantive_min_chars`.
fn default_substantive_min_chars() -> usize {
    100_usize
}

/// Default for `non_text_min_chars`.
fn default_non_text_min_chars() -> usize {
    20_usize
}

// ---- ratio / fractional thresholds (f64) ----

/// Default for `min_non_whitespace_per_page`.
fn default_min_non_whitespace_per_page() -> f64 {
    32.0_f64
}

/// Default for `min_alnum_ratio`.
fn default_min_alnum_ratio() -> f64 {
    0.3_f64
}

/// Default for `max_fragmented_word_ratio`.
fn default_max_fragmented_word_ratio() -> f64 {
    0.6_f64
}

/// Default for `critical_fragmented_word_ratio`.
fn default_critical_fragmented_word_ratio() -> f64 {
    0.8_f64
}

/// Default for `min_avg_word_length`.
fn default_min_avg_word_length() -> f64 {
    2.0_f64
}

/// Default for `min_consecutive_repeat_ratio`.
fn default_min_consecutive_repeat_ratio() -> f64 {
    0.08_f64
}

/// Default for `alnum_ws_ratio_threshold`.
fn default_alnum_ws_ratio_threshold() -> f64 {
    0.4_f64
}

/// Default for `pipeline_min_quality`.
fn default_pipeline_min_quality() -> f64 {
    0.5_f64
}
/// One stage of a multi-backend OCR pipeline.
///
/// Only `backend` is required when deserializing; `priority` falls back to
/// `default_priority()` and the optional fields to `None`.
/// NOTE(review): how stages are ordered by `priority` is decided by the
/// pipeline runner, which is not in this file — confirm direction there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrPipelineStage {
/// Backend name for this stage. `OcrConfig::validate` checks it with
/// `validate_ocr_backend` when the stage belongs to `OcrConfig::pipeline`.
pub backend: String,
/// Stage priority; `default_priority()` when omitted from the input.
#[serde(default = "default_priority")]
pub priority: u32,
/// Optional language for this stage; omitted from serialized output when `None`.
/// NOTE(review): presumably falls back to `OcrConfig::language` — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub language: Option<String>,
/// Optional Tesseract settings for this stage; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tesseract_config: Option<crate::types::TesseractConfig>,
/// Optional PaddleOCR settings as free-form JSON; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub paddle_ocr_config: Option<serde_json::Value>,
}
/// Serde default for `OcrPipelineStage::priority` when a stage omits it.
fn default_priority() -> u32 {
    100_u32
}
/// An explicit OCR pipeline: an ordered list of stages plus the quality
/// thresholds shared by the pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrPipelineConfig {
/// Pipeline stages. Required when deserializing (no serde default).
/// NOTE(review): an empty list is currently accepted by `OcrConfig::validate`.
pub stages: Vec<OcrPipelineStage>,
/// Quality thresholds; `OcrQualityThresholds::default()` when omitted.
#[serde(default)]
pub quality_thresholds: OcrQualityThresholds,
}
/// Top-level OCR configuration.
///
/// Every field has a serde default, so deserializing `{}` yields the same
/// values as `OcrConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrConfig {
/// Primary OCR backend name (default `"tesseract"`); checked by `validate`.
#[serde(default = "default_tesseract_backend")]
pub backend: String,
/// OCR language code (default `"eng"`).
#[serde(default = "default_eng")]
pub language: String,
/// Optional Tesseract settings.
/// NOTE(review): unlike most optional fields here this has no
/// `skip_serializing_if`, so `None` serializes as `null` — confirm intended.
#[serde(default)]
pub tesseract_config: Option<crate::types::TesseractConfig>,
/// Optional output format for OCR results.
/// NOTE(review): also serializes `None` as `null` (no `skip_serializing_if`).
#[serde(default)]
pub output_format: Option<OutputFormat>,
/// Optional PaddleOCR settings as free-form JSON; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub paddle_ocr_config: Option<serde_json::Value>,
/// Optional OCR element settings; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub element_config: Option<OcrElementConfig>,
/// Custom quality thresholds; `effective_thresholds` substitutes
/// `OcrQualityThresholds::default()` when this is `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub quality_thresholds: Option<OcrQualityThresholds>,
/// Explicit multi-stage pipeline. When `None`, `effective_pipeline`
/// synthesizes one if built with the `paddle-ocr` feature, else returns `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub pipeline: Option<OcrPipelineConfig>,
/// Auto-rotation flag (default `false`).
/// NOTE(review): the rotation behavior itself is implemented elsewhere.
#[serde(default)]
pub auto_rotate: bool,
}
impl Default for OcrConfig {
    /// Tesseract with English (`"eng"`), auto-rotation off, and every optional
    /// section unset — the same values produced by deserializing `{}`.
    fn default() -> Self {
        // The two required-with-default string fields come from the same
        // providers serde uses.
        let backend = default_tesseract_backend();
        let language = default_eng();

        Self {
            backend,
            language,
            auto_rotate: false,
            // Backend-specific settings.
            tesseract_config: None,
            paddle_ocr_config: None,
            // Output / post-processing settings.
            output_format: None,
            element_config: None,
            // Quality control and multi-stage pipeline.
            quality_thresholds: None,
            pipeline: None,
        }
    }
}
impl OcrConfig {
    /// Validates the configured OCR backend names.
    ///
    /// Checks the top-level `backend` first and then, when an explicit
    /// `pipeline` is configured, the `backend` of each of its stages in order.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `validate_ocr_backend`.
    pub fn validate(&self) -> Result<(), KreuzbergError> {
        validate_ocr_backend(&self.backend)?;
        // `Option::iter` yields nothing when no explicit pipeline is set.
        for stage in self.pipeline.iter().flat_map(|pipeline| &pipeline.stages) {
            validate_ocr_backend(&stage.backend)?;
        }
        Ok(())
    }

    /// Returns the configured quality thresholds, or
    /// `OcrQualityThresholds::default()` when none are set.
    pub fn effective_thresholds(&self) -> OcrQualityThresholds {
        self.quality_thresholds.clone().unwrap_or_default()
    }

    /// Returns the OCR pipeline that should run for this configuration.
    ///
    /// - An explicit `pipeline` is returned as-is (cloned).
    /// - Otherwise, when built with the `paddle-ocr` feature, a pipeline is
    ///   synthesized: the configured `backend` at priority 100, followed by a
    ///   `"paddleocr"` stage at priority 50 unless the primary backend already
    ///   is `"paddleocr"`. It uses `effective_thresholds()`.
    /// - Without that feature and with no explicit pipeline, returns `None`.
    ///
    /// NOTE(review): synthesized stages leave `language` as `None`; presumably
    /// the runner falls back to `self.language` — confirm.
    pub fn effective_pipeline(&self) -> Option<OcrPipelineConfig> {
        if let Some(pipeline) = &self.pipeline {
            return Some(pipeline.clone());
        }
        #[cfg(feature = "paddle-ocr")]
        {
            let primary_is_paddle = self.backend == "paddleocr";
            // When PaddleOCR is the primary backend, its settings must travel
            // with the primary stage; otherwise they belong to the fallback
            // stage pushed below.
            let primary_paddle_config = if primary_is_paddle {
                self.paddle_ocr_config.clone()
            } else {
                None
            };
            let mut stages = vec![OcrPipelineStage {
                backend: self.backend.clone(),
                priority: 100,
                language: None,
                tesseract_config: self.tesseract_config.clone(),
                paddle_ocr_config: primary_paddle_config,
            }];
            // Avoid a duplicate PaddleOCR stage when it is already primary.
            if !primary_is_paddle {
                stages.push(OcrPipelineStage {
                    backend: "paddleocr".to_string(),
                    priority: 50,
                    language: None,
                    tesseract_config: None,
                    paddle_ocr_config: self.paddle_ocr_config.clone(),
                });
            }
            Some(OcrPipelineConfig {
                stages,
                quality_thresholds: self.effective_thresholds(),
            })
        }
        #[cfg(not(feature = "paddle-ocr"))]
        {
            None
        }
    }
}
/// Serde default (and `Default` value) for `OcrConfig::backend`.
fn default_tesseract_backend() -> String {
    String::from("tesseract")
}

/// Serde default (and `Default` value) for `OcrConfig::language`.
fn default_eng() -> String {
    String::from("eng")
}
#[cfg(test)]
mod tests {
    // Unit tests for OCR configuration: defaults, backend validation,
    // pipeline synthesis, threshold resolution and serde behaviour.
    // Tests touching `effective_pipeline` assert both `paddle-ocr` branches.
    use super::*;

    // ---- defaults ----

    #[test]
    fn test_ocr_config_default() {
        let config = OcrConfig::default();
        assert_eq!(config.backend, "tesseract");
        assert_eq!(config.language, "eng");
        assert!(config.tesseract_config.is_none());
        assert!(config.output_format.is_none());
    }

    #[test]
    fn test_ocr_config_with_tesseract() {
        let config = OcrConfig {
            backend: "tesseract".to_string(),
            language: "fra".to_string(),
            ..Default::default()
        };
        assert_eq!(config.backend, "tesseract");
        assert_eq!(config.language, "fra");
    }

    // ---- backend validation ----

    #[test]
    fn test_validate_tesseract_backend() {
        let config = OcrConfig {
            backend: "tesseract".to_string(),
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_easyocr_backend() {
        let config = OcrConfig {
            backend: "easyocr".to_string(),
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_paddleocr_backend() {
        let config = OcrConfig {
            backend: "paddleocr".to_string(),
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_invalid_backend_typo() {
        let config = OcrConfig {
            backend: "tesseract_typo".to_string(),
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Invalid OCR backend"));
    }

    #[test]
    fn test_validate_invalid_backend_completely_wrong() {
        let config = OcrConfig {
            backend: "ocr_lib".to_string(),
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Invalid OCR backend") || err_msg.contains("Valid options are"));
    }

    #[test]
    fn test_validate_default_backend() {
        let config = OcrConfig::default();
        assert!(config.validate().is_ok());
    }

    // ---- effective_pipeline ----

    #[test]
    fn test_effective_pipeline_explicit_pipeline_returned_unchanged() {
        let explicit_pipeline = OcrPipelineConfig {
            stages: vec![OcrPipelineStage {
                backend: "easyocr".to_string(),
                priority: 200,
                language: Some("fra".to_string()),
                tesseract_config: None,
                paddle_ocr_config: None,
            }],
            quality_thresholds: OcrQualityThresholds::default(),
        };
        let config = OcrConfig {
            pipeline: Some(explicit_pipeline.clone()),
            ..Default::default()
        };
        let result = config.effective_pipeline().unwrap();
        assert_eq!(result.stages.len(), 1);
        assert_eq!(result.stages[0].backend, "easyocr");
        assert_eq!(result.stages[0].priority, 200);
        assert_eq!(result.stages[0].language, Some("fra".to_string()));
    }

    #[test]
    fn test_effective_pipeline_paddleocr_backend_no_duplicate() {
        let config = OcrConfig {
            backend: "paddleocr".to_string(),
            ..Default::default()
        };
        let result = config.effective_pipeline();
        #[cfg(feature = "paddle-ocr")]
        {
            let pipeline = result.unwrap();
            let paddle_count = pipeline.stages.iter().filter(|s| s.backend == "paddleocr").count();
            assert_eq!(
                paddle_count, 1,
                "Should not have duplicate paddleocr stages, found {paddle_count}"
            );
        }
        #[cfg(not(feature = "paddle-ocr"))]
        {
            assert!(result.is_none());
        }
    }

    #[test]
    fn test_effective_pipeline_default_tesseract_backend() {
        let config = OcrConfig::default();
        let result = config.effective_pipeline();
        #[cfg(feature = "paddle-ocr")]
        {
            let pipeline = result.unwrap();
            assert_eq!(pipeline.stages.len(), 2);
            assert_eq!(pipeline.stages[0].backend, "tesseract");
            assert_eq!(pipeline.stages[0].priority, 100);
            assert_eq!(pipeline.stages[1].backend, "paddleocr");
            assert_eq!(pipeline.stages[1].priority, 50);
        }
        #[cfg(not(feature = "paddle-ocr"))]
        {
            assert!(result.is_none());
        }
    }

    // ---- effective_thresholds ----

    #[test]
    fn test_effective_thresholds_custom_vs_default() {
        let custom = OcrQualityThresholds {
            min_total_non_whitespace: 128,
            min_meaningful_words: 10,
            ..Default::default()
        };
        let config_custom = OcrConfig {
            quality_thresholds: Some(custom.clone()),
            ..Default::default()
        };
        let eff = config_custom.effective_thresholds();
        assert_eq!(eff.min_total_non_whitespace, 128);
        assert_eq!(eff.min_meaningful_words, 10);
        let config_default = OcrConfig::default();
        let eff_default = config_default.effective_thresholds();
        assert_eq!(eff_default.min_total_non_whitespace, 64);
        assert_eq!(eff_default.min_meaningful_words, 3);
    }

    // ---- serde ----

    #[test]
    fn test_pipeline_config_serde_roundtrip() {
        let pipeline = OcrPipelineConfig {
            stages: vec![
                OcrPipelineStage {
                    backend: "tesseract".to_string(),
                    priority: 100,
                    language: Some("eng".to_string()),
                    tesseract_config: None,
                    paddle_ocr_config: None,
                },
                OcrPipelineStage {
                    backend: "paddleocr".to_string(),
                    priority: 50,
                    language: None,
                    tesseract_config: None,
                    paddle_ocr_config: Some(serde_json::json!({"use_gpu": false})),
                },
            ],
            quality_thresholds: OcrQualityThresholds::default(),
        };
        let json = serde_json::to_string(&pipeline).unwrap();
        let deserialized: OcrPipelineConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.stages.len(), 2);
        assert_eq!(deserialized.stages[0].backend, "tesseract");
        assert_eq!(deserialized.stages[0].priority, 100);
        assert_eq!(deserialized.stages[1].backend, "paddleocr");
        assert_eq!(deserialized.stages[1].priority, 50);
        assert!(deserialized.stages[1].paddle_ocr_config.is_some());
    }

    #[test]
    fn test_pipeline_stage_deserialization_missing_optional_fields() {
        let json = r#"{"backend": "tesseract"}"#;
        let stage: OcrPipelineStage = serde_json::from_str(json).unwrap();
        assert_eq!(stage.backend, "tesseract");
        assert_eq!(stage.priority, 100);
        assert!(stage.language.is_none());
        assert!(stage.tesseract_config.is_none());
        assert!(stage.paddle_ocr_config.is_none());
    }

    #[test]
    fn test_pipeline_stage_default_priority_is_100() {
        let json = r#"{"backend": "easyocr"}"#;
        let stage: OcrPipelineStage = serde_json::from_str(json).unwrap();
        assert_eq!(stage.priority, 100);
    }

    #[test]
    fn test_ocr_config_deserialization_missing_optional_fields() {
        let json = r#"{}"#;
        let config: OcrConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.backend, "tesseract");
        assert_eq!(config.language, "eng");
        assert!(config.pipeline.is_none());
        assert!(config.quality_thresholds.is_none());
        assert!(config.element_config.is_none());
    }

    #[test]
    fn test_quality_thresholds_deserialization_partial() {
        let json = r#"{"min_total_non_whitespace": 256}"#;
        let thresholds: OcrQualityThresholds = serde_json::from_str(json).unwrap();
        assert_eq!(thresholds.min_total_non_whitespace, 256);
        assert_eq!(thresholds.min_meaningful_words, 3);
        assert_eq!(thresholds.min_garbage_chars, 5);
        assert!((thresholds.pipeline_min_quality - 0.5).abs() < f64::EPSILON);
    }

    // ---- pipeline-stage validation ----

    #[test]
    fn test_validate_catches_invalid_pipeline_stage_backend() {
        let config = OcrConfig {
            pipeline: Some(OcrPipelineConfig {
                stages: vec![
                    OcrPipelineStage {
                        backend: "tesseract".to_string(),
                        priority: 100,
                        language: None,
                        tesseract_config: None,
                        paddle_ocr_config: None,
                    },
                    OcrPipelineStage {
                        backend: "invalid_backend".to_string(),
                        priority: 50,
                        language: None,
                        tesseract_config: None,
                        paddle_ocr_config: None,
                    },
                ],
                quality_thresholds: OcrQualityThresholds::default(),
            }),
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_err(), "Should catch invalid backend in pipeline stages");
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Invalid OCR backend") || err_msg.contains("invalid_backend"));
    }

    #[test]
    fn test_validate_passes_with_valid_pipeline_stages() {
        let config = OcrConfig {
            pipeline: Some(OcrPipelineConfig {
                stages: vec![
                    OcrPipelineStage {
                        backend: "tesseract".to_string(),
                        priority: 100,
                        language: None,
                        tesseract_config: None,
                        paddle_ocr_config: None,
                    },
                    OcrPipelineStage {
                        backend: "paddleocr".to_string(),
                        priority: 50,
                        language: None,
                        tesseract_config: None,
                        paddle_ocr_config: None,
                    },
                ],
                quality_thresholds: OcrQualityThresholds::default(),
            }),
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }
}