use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Settings for PaddleOCR-based recognition.
///
/// Build one with [`PaddleOcrConfig::new`] (or `Default`, which uses the `"en"`
/// language code) and refine it with the `with_*` builder methods. Because of
/// `#[serde(default)]`, any field missing from serialized input takes its value
/// from the `Default` implementation.
///
/// NOTE(review): the `with_*` builder methods clamp the numeric fields to fixed
/// ranges, but deserialization and direct field assignment bypass those clamps,
/// so values loaded from a config file are not range-checked by this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct PaddleOcrConfig {
/// Language code, e.g. `"en"` (the default) or `"ch"`; see [`PaddleLanguage::code`].
pub language: String,
/// Explicit cache directory. When `None`, [`PaddleOcrConfig::resolve_cache_dir`]
/// falls back to `KREUZBERG_CACHE_DIR` or the working directory.
pub cache_dir: Option<PathBuf>,
/// Whether angle classification is enabled (`true` by default).
pub use_angle_cls: bool,
/// Whether table detection is enabled (`false` by default).
pub enable_table_detection: bool,
/// `det_db_thresh` value; the builder clamps it to `[0.0, 1.0]` (default `0.3`).
pub det_db_thresh: f32,
/// `det_db_box_thresh` value; the builder clamps it to `[0.0, 1.0]` (default `0.5`).
pub det_db_box_thresh: f32,
/// `det_db_unclip_ratio` value; the builder clamps it to `[1.0, 3.0]` (default `1.6`).
pub det_db_unclip_ratio: f32,
/// `det_limit_side_len` value; the builder clamps it to `[64, 4096]` (default `960`).
pub det_limit_side_len: u32,
/// Recognition batch size; the builder clamps it to `[1, 64]` (default `6`).
pub rec_batch_num: u32,
}
impl PaddleOcrConfig {
    /// Creates a configuration for `language` with the stock settings.
    ///
    /// Angle classification is on, table detection is off, no explicit cache
    /// directory is set, and the numeric knobs start at `det_db_thresh = 0.3`,
    /// `det_db_box_thresh = 0.5`, `det_db_unclip_ratio = 1.6`,
    /// `det_limit_side_len = 960` and `rec_batch_num = 6`.
    pub fn new(language: impl Into<String>) -> Self {
        Self {
            language: language.into(),
            cache_dir: None,
            use_angle_cls: true,
            enable_table_detection: false,
            det_db_thresh: 0.3,
            det_db_box_thresh: 0.5,
            det_db_unclip_ratio: 1.6,
            det_limit_side_len: 960,
            rec_batch_num: 6,
        }
    }

    /// Sets an explicit cache directory, which takes precedence over every
    /// fallback in [`resolve_cache_dir`](Self::resolve_cache_dir).
    #[must_use]
    pub fn with_cache_dir(mut self, path: PathBuf) -> Self {
        self.cache_dir = Some(path);
        self
    }

    /// Enables or disables table detection.
    #[must_use]
    pub fn with_table_detection(mut self, enable: bool) -> Self {
        self.enable_table_detection = enable;
        self
    }

    /// Enables or disables angle classification.
    #[must_use]
    pub fn with_angle_cls(mut self, enable: bool) -> Self {
        self.use_angle_cls = enable;
        self
    }

    /// Sets `det_db_thresh`, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN argument is ignored and the current value is kept.
    #[must_use]
    pub fn with_det_db_thresh(mut self, threshold: f32) -> Self {
        self.det_db_thresh = Self::clamp_unless_nan(threshold, 0.0, 1.0, self.det_db_thresh);
        self
    }

    /// Sets `det_db_box_thresh`, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN argument is ignored and the current value is kept.
    #[must_use]
    pub fn with_det_db_box_thresh(mut self, threshold: f32) -> Self {
        self.det_db_box_thresh =
            Self::clamp_unless_nan(threshold, 0.0, 1.0, self.det_db_box_thresh);
        self
    }

    /// Sets `det_db_unclip_ratio`, clamped to `[1.0, 3.0]`.
    ///
    /// A NaN argument is ignored and the current value is kept.
    #[must_use]
    pub fn with_det_db_unclip_ratio(mut self, ratio: f32) -> Self {
        self.det_db_unclip_ratio = Self::clamp_unless_nan(ratio, 1.0, 3.0, self.det_db_unclip_ratio);
        self
    }

    /// Sets `det_limit_side_len`, clamped to `[64, 4096]`.
    #[must_use]
    pub fn with_det_limit_side_len(mut self, length: u32) -> Self {
        self.det_limit_side_len = length.clamp(64, 4096);
        self
    }

    /// Sets the recognition batch size, clamped to `[1, 64]`.
    #[must_use]
    pub fn with_rec_batch_num(mut self, batch_size: u32) -> Self {
        self.rec_batch_num = batch_size.clamp(1, 64);
        self
    }

    /// Returns the cache directory to use, in priority order:
    ///
    /// 1. the explicit [`cache_dir`](Self::cache_dir), returned as-is;
    /// 2. `$KREUZBERG_CACHE_DIR/paddle-ocr`, if that variable is set and non-empty;
    /// 3. `<current dir>/.kreuzberg/paddle-ocr`, using `.` when the current
    ///    directory cannot be determined.
    pub fn resolve_cache_dir(&self) -> PathBuf {
        if let Some(path) = &self.cache_dir {
            return path.clone();
        }
        // `var_os` keeps non-UTF-8 paths working (`env::var` would reject them and
        // silently fall through). An empty value is treated as unset so it cannot
        // collapse into a bare relative `paddle-ocr` directory.
        if let Some(root) = std::env::var_os("KREUZBERG_CACHE_DIR").filter(|value| !value.is_empty()) {
            return PathBuf::from(root).join("paddle-ocr");
        }
        std::env::current_dir()
            .unwrap_or_else(|_| PathBuf::from("."))
            .join(".kreuzberg")
            .join("paddle-ocr")
    }

    /// Clamps `value` into `[min, max]`, or returns `current` if `value` is NaN.
    ///
    /// `f32::clamp` propagates NaN, so without this guard a NaN argument would be
    /// stored verbatim and poison every later comparison against the threshold.
    fn clamp_unless_nan(value: f32, min: f32, max: f32, current: f32) -> f32 {
        if value.is_nan() {
            current
        } else {
            value.clamp(min, max)
        }
    }
}
impl Default for PaddleOcrConfig {
fn default() -> Self {
Self::new("en")
}
}
/// Languages with a known short code.
///
/// Convert to a code with [`PaddleLanguage::code`] and back with
/// [`PaddleLanguage::from_code`]; `Display` yields the English language name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PaddleLanguage {
/// Code `"en"`.
English,
/// Code `"ch"`.
Chinese,
/// Code `"jpn"`.
Japanese,
/// Code `"kor"`.
Korean,
/// Code `"deu"`.
German,
/// Code `"fra"`.
French,
}
impl PaddleLanguage {
pub fn code(&self) -> &'static str {
match self {
Self::English => "en",
Self::Chinese => "ch",
Self::Japanese => "jpn",
Self::Korean => "kor",
Self::German => "deu",
Self::French => "fra",
}
}
pub fn from_code(code: &str) -> Option<Self> {
match code {
"en" => Some(Self::English),
"ch" => Some(Self::Chinese),
"jpn" => Some(Self::Japanese),
"kor" => Some(Self::Korean),
"deu" => Some(Self::German),
"fra" => Some(Self::French),
_ => None,
}
}
}
impl std::fmt::Display for PaddleLanguage {
    /// Writes the language's English name (e.g. `"Chinese"`).
    ///
    /// Goes through [`Formatter::pad`](std::fmt::Formatter::pad) so width, fill,
    /// alignment and precision flags behave exactly as they do for `str`
    /// (e.g. `format!("{:>9}", PaddleLanguage::English)` is `"  English"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::English => "English",
            Self::Chinese => "Chinese",
            Self::Japanese => "Japanese",
            Self::Korean => "Korean",
            Self::German => "German",
            Self::French => "French",
        };
        f.pad(name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_config() {
        let config = PaddleOcrConfig::new("en");
        assert_eq!(config.language, "en");
        assert!(config.use_angle_cls);
        assert!(!config.enable_table_detection);
    }

    #[test]
    fn test_default_config() {
        let config = PaddleOcrConfig::default();
        assert_eq!(config.language, "en");
        assert_eq!(config.det_db_thresh, 0.3);
        assert_eq!(config.det_db_box_thresh, 0.5);
        assert_eq!(config.det_db_unclip_ratio, 1.6);
        assert_eq!(config.det_limit_side_len, 960);
        assert_eq!(config.rec_batch_num, 6);
    }

    #[test]
    fn test_builder_pattern() {
        let config = PaddleOcrConfig::new("ch")
            .with_angle_cls(false)
            .with_table_detection(true)
            .with_det_db_thresh(0.4)
            .with_rec_batch_num(12);
        assert_eq!(config.language, "ch");
        assert!(!config.use_angle_cls);
        assert!(config.enable_table_detection);
        assert_eq!(config.det_db_thresh, 0.4);
        assert_eq!(config.rec_batch_num, 12);
    }

    #[test]
    fn test_builder_clamps_out_of_range_values() {
        let low = PaddleOcrConfig::new("en")
            .with_det_db_thresh(-0.5)
            .with_det_db_box_thresh(-1.0)
            .with_det_db_unclip_ratio(0.1)
            .with_det_limit_side_len(8)
            .with_rec_batch_num(0);
        assert_eq!(low.det_db_thresh, 0.0);
        assert_eq!(low.det_db_box_thresh, 0.0);
        assert_eq!(low.det_db_unclip_ratio, 1.0);
        assert_eq!(low.det_limit_side_len, 64);
        assert_eq!(low.rec_batch_num, 1);

        let high = PaddleOcrConfig::new("en")
            .with_det_db_thresh(1.5)
            .with_det_db_box_thresh(2.0)
            .with_det_db_unclip_ratio(10.0)
            .with_det_limit_side_len(100_000)
            .with_rec_batch_num(1_000);
        assert_eq!(high.det_db_thresh, 1.0);
        assert_eq!(high.det_db_box_thresh, 1.0);
        assert_eq!(high.det_db_unclip_ratio, 3.0);
        assert_eq!(high.det_limit_side_len, 4096);
        assert_eq!(high.rec_batch_num, 64);
    }

    #[test]
    fn test_with_cache_dir() {
        let cache_path = PathBuf::from("/tmp/cache");
        let config = PaddleOcrConfig::new("en").with_cache_dir(cache_path.clone());
        assert_eq!(config.cache_dir, Some(cache_path));
    }

    #[test]
    fn test_resolve_cache_dir_explicit() {
        let cache_path = PathBuf::from("/tmp/explicit");
        let config = PaddleOcrConfig::new("en").with_cache_dir(cache_path.clone());
        assert_eq!(config.resolve_cache_dir(), cache_path);
    }

    #[test]
    fn test_resolve_cache_dir_default() {
        let config = PaddleOcrConfig::new("en");
        let cache_dir = config.resolve_cache_dir();
        assert!(cache_dir.ends_with("paddle-ocr"));
        // The result depends on the ambient environment: a set KREUZBERG_CACHE_DIR
        // wins over the working-directory fallback. Only assert the `.kreuzberg`
        // component when the variable is absent, so the test is not flaky on
        // machines that export it.
        if std::env::var_os("KREUZBERG_CACHE_DIR").is_none() {
            assert!(cache_dir.ends_with(PathBuf::from(".kreuzberg").join("paddle-ocr")));
        }
    }

    #[test]
    fn test_paddle_language_code() {
        assert_eq!(PaddleLanguage::English.code(), "en");
        assert_eq!(PaddleLanguage::Chinese.code(), "ch");
        assert_eq!(PaddleLanguage::Japanese.code(), "jpn");
        assert_eq!(PaddleLanguage::Korean.code(), "kor");
        assert_eq!(PaddleLanguage::German.code(), "deu");
        assert_eq!(PaddleLanguage::French.code(), "fra");
    }

    #[test]
    fn test_paddle_language_from_code() {
        assert_eq!(PaddleLanguage::from_code("en"), Some(PaddleLanguage::English));
        assert_eq!(PaddleLanguage::from_code("ch"), Some(PaddleLanguage::Chinese));
        assert_eq!(PaddleLanguage::from_code("jpn"), Some(PaddleLanguage::Japanese));
        assert_eq!(PaddleLanguage::from_code("kor"), Some(PaddleLanguage::Korean));
        assert_eq!(PaddleLanguage::from_code("deu"), Some(PaddleLanguage::German));
        assert_eq!(PaddleLanguage::from_code("fra"), Some(PaddleLanguage::French));
        assert_eq!(PaddleLanguage::from_code("unknown"), None);
        assert_eq!(PaddleLanguage::from_code(""), None);
    }

    #[test]
    fn test_paddle_language_round_trip() {
        let all = [
            PaddleLanguage::English,
            PaddleLanguage::Chinese,
            PaddleLanguage::Japanese,
            PaddleLanguage::Korean,
            PaddleLanguage::German,
            PaddleLanguage::French,
        ];
        for lang in all.iter().copied() {
            assert_eq!(PaddleLanguage::from_code(lang.code()), Some(lang));
        }
    }

    #[test]
    fn test_paddle_language_display() {
        assert_eq!(PaddleLanguage::English.to_string(), "English");
        assert_eq!(PaddleLanguage::Chinese.to_string(), "Chinese");
        assert_eq!(PaddleLanguage::Japanese.to_string(), "Japanese");
        assert_eq!(PaddleLanguage::Korean.to_string(), "Korean");
        assert_eq!(PaddleLanguage::German.to_string(), "German");
        assert_eq!(PaddleLanguage::French.to_string(), "French");
    }

    #[test]
    fn test_threshold_values() {
        let config = PaddleOcrConfig::new("en")
            .with_det_db_thresh(0.25)
            .with_det_db_box_thresh(0.6)
            .with_det_db_unclip_ratio(1.8);
        assert_eq!(config.det_db_thresh, 0.25);
        assert_eq!(config.det_db_box_thresh, 0.6);
        assert_eq!(config.det_db_unclip_ratio, 1.8);
    }

    #[test]
    fn test_side_length_and_batch() {
        let config = PaddleOcrConfig::new("en")
            .with_det_limit_side_len(1280)
            .with_rec_batch_num(8);
        assert_eq!(config.det_limit_side_len, 1280);
        assert_eq!(config.rec_batch_num, 8);
    }

    #[test]
    fn test_serialization() {
        let config = PaddleOcrConfig::new("ch")
            .with_table_detection(true)
            .with_angle_cls(false)
            .with_det_limit_side_len(1280);
        let json = serde_json::to_string(&config).expect("Failed to serialize");
        let deserialized: PaddleOcrConfig = serde_json::from_str(&json).expect("Failed to deserialize");
        assert_eq!(deserialized.language, config.language);
        assert_eq!(deserialized.enable_table_detection, config.enable_table_detection);
        assert_eq!(deserialized.use_angle_cls, config.use_angle_cls);
        assert_eq!(deserialized.det_limit_side_len, config.det_limit_side_len);
        assert_eq!(deserialized.cache_dir, config.cache_dir);
    }

    #[test]
    fn test_deserialize_fills_missing_fields_with_defaults() {
        // `#[serde(default)]` fills absent fields from `Default::default()`.
        let partial: PaddleOcrConfig =
            serde_json::from_str(r#"{"language":"ch"}"#).expect("Failed to deserialize");
        assert_eq!(partial.language, "ch");
        assert!(partial.use_angle_cls);
        assert!(!partial.enable_table_detection);
        assert_eq!(partial.det_limit_side_len, 960);
        assert_eq!(partial.rec_batch_num, 6);
        assert_eq!(partial.cache_dir, None);

        let empty: PaddleOcrConfig = serde_json::from_str("{}").expect("Failed to deserialize");
        assert_eq!(empty.language, "en");
    }
}