use std::path::PathBuf;
/// Strategy used to choose the input size for the text-detection stage.
///
/// NOTE(review): the actual resize arithmetic lives in the preprocessing code, not here; the
/// per-variant notes below describe the intent suggested by the field names — confirm against
/// the consumer.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DetResizeStrategy {
    /// Resize keyed on the longer image side, bounded by `max_side`.
    MaxSide {
        /// Bound for the longer side (presumably in pixels).
        max_side: u32,
    },
    /// Resize keyed on the shorter image side, with a separate cap on the longer side.
    MinSide {
        /// Target / lower bound for the shorter side (presumably in pixels).
        min_side: u32,
        /// Cap applied to the longer side (presumably in pixels).
        max_side_limit: u32,
    },
}

impl Default for DetResizeStrategy {
    /// Returns `MaxSide { max_side: 960 }`, the same value `OcrConfig::default()` uses for
    /// `det_max_side`.
    fn default() -> Self {
        Self::MaxSide { max_side: 960 }
    }
}
/// Configuration for the OCR pipeline: detection, recognition and related options.
///
/// Obtain the baseline with [`OcrConfig::default`] and override fields directly or via
/// struct-update syntax (`OcrConfig { num_threads: 8, ..Default::default() }`).
#[derive(Debug, Clone, PartialEq)]
pub struct OcrConfig {
    /// Detection threshold (default `0.3`).
    pub det_threshold: f32,
    /// Box score threshold (default `0.6`).
    pub box_threshold: f32,
    /// Recognition threshold (default `0.5`).
    pub rec_threshold: f32,
    /// Detection max-side value (default `960`).
    ///
    /// NOTE(review): this duplicates the `MaxSide` payload of `det_resize_strategy`; assigning
    /// either field directly does not update the other. Confirm which one the detector reads.
    pub det_max_side: u32,
    /// How detection input is resized (default `DetResizeStrategy::default()`).
    pub det_resize_strategy: DetResizeStrategy,
    /// Target height for recognition input (default `48`, presumably pixels).
    pub rec_target_height: u32,
    /// Requested number of threads (default `4`).
    pub num_threads: usize,
    /// Unclip ratio for detected boxes (default `1.5`).
    pub unclip_ratio: f32,
    /// Maximum number of detection candidates (default `1000`).
    pub max_candidates: usize,
    /// Whether style detection is enabled (default `true`).
    pub detect_styles: bool,
    /// Optional detection model path (default `None`).
    ///
    /// NOTE(review): presumably a built-in model is used when unset; confirm.
    pub det_model_path: Option<PathBuf>,
    /// Optional recognition model path (default `None`).
    pub rec_model_path: Option<PathBuf>,
    /// Optional character dictionary path (default `None`).
    pub dict_path: Option<PathBuf>,
}

impl Default for OcrConfig {
    /// Baseline configuration. `det_max_side` (960) matches the `MaxSide { max_side: 960 }`
    /// payload of the default resize strategy.
    fn default() -> Self {
        Self {
            det_threshold: 0.3,
            box_threshold: 0.6,
            rec_threshold: 0.5,
            det_max_side: 960,
            det_resize_strategy: DetResizeStrategy::default(),
            rec_target_height: 48,
            num_threads: 4,
            unclip_ratio: 1.5,
            max_candidates: 1000,
            detect_styles: true,
            det_model_path: None,
            rec_model_path: None,
            dict_path: None,
        }
    }
}
impl OcrConfig {
    /// Equivalent to [`OcrConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Preset named `v5`.
    ///
    /// It uses `MinSide { min_side: 64, max_side_limit: 4000 }` detection resizing and an
    /// explicit box threshold of `0.6`. Every other field comes from `Default`.
    ///
    /// NOTE(review): `det_max_side` keeps its default value here even though the strategy is
    /// `MinSide`. Confirm that no consumer reads `det_max_side` for this preset.
    pub fn v5() -> Self {
        let det_resize_strategy = DetResizeStrategy::MinSide {
            min_side: 64,
            max_side_limit: 4000,
        };
        Self {
            box_threshold: 0.6,
            det_resize_strategy,
            ..Self::default()
        }
    }

    /// Starts an [`OcrConfigBuilder`] via `OcrConfigBuilder::new()`.
    pub fn builder() -> OcrConfigBuilder {
        OcrConfigBuilder::new()
    }
}
/// Fluent builder for [`OcrConfig`], seeded with [`OcrConfig::default`].
///
/// Each setter consumes the builder and returns it, so dropping the result silently loses the
/// change; the type is therefore `#[must_use]`. Numeric setters clamp out-of-range values
/// instead of failing.
#[must_use = "builder methods return the updated builder; call `.build()` to get the config"]
#[derive(Debug, Clone, Default)]
pub struct OcrConfigBuilder {
    config: OcrConfig,
}

impl OcrConfigBuilder {
    /// Smallest `max_side` accepted for the `MaxSide` detection strategy.
    const MIN_DET_MAX_SIDE: u32 = 32;

    /// Creates a builder seeded with [`OcrConfig::default`].
    pub fn new() -> Self {
        Self {
            config: OcrConfig::default(),
        }
    }

    /// Clamps `value` into `[0.0, 1.0]`. If `value` is NaN, it is ignored and `current` is kept.
    ///
    /// `f32::clamp` returns NaN for a NaN input. Every ordered comparison against NaN is false,
    /// so storing it would silently disable whatever threshold it feeds.
    fn unit_threshold(value: f32, current: f32) -> f32 {
        if value.is_nan() {
            current
        } else {
            value.clamp(0.0, 1.0)
        }
    }

    /// Sets the detection threshold, clamped to `[0.0, 1.0]` (NaN is ignored).
    pub fn det_threshold(mut self, threshold: f32) -> Self {
        self.config.det_threshold = Self::unit_threshold(threshold, self.config.det_threshold);
        self
    }

    /// Sets the box threshold, clamped to `[0.0, 1.0]` (NaN is ignored).
    pub fn box_threshold(mut self, threshold: f32) -> Self {
        self.config.box_threshold = Self::unit_threshold(threshold, self.config.box_threshold);
        self
    }

    /// Sets the recognition threshold, clamped to `[0.0, 1.0]` (NaN is ignored).
    pub fn rec_threshold(mut self, threshold: f32) -> Self {
        self.config.rec_threshold = Self::unit_threshold(threshold, self.config.rec_threshold);
        self
    }

    /// Selects `DetResizeStrategy::MaxSide` with the given `max_side` (raised to at least 32).
    ///
    /// Both `det_max_side` and `det_resize_strategy` are updated so they agree.
    pub fn det_max_side(self, max_side: u32) -> Self {
        self.det_resize_strategy(DetResizeStrategy::MaxSide { max_side })
    }

    /// Sets the detection resize strategy.
    ///
    /// For `MaxSide`, the same floor of 32 as [`Self::det_max_side`] is applied, and
    /// `det_max_side` is updated to match. `MinSide` is stored as given and leaves
    /// `det_max_side` untouched.
    pub fn det_resize_strategy(mut self, strategy: DetResizeStrategy) -> Self {
        let strategy = match strategy {
            DetResizeStrategy::MaxSide { max_side } => {
                let max_side = max_side.max(Self::MIN_DET_MAX_SIDE);
                self.config.det_max_side = max_side;
                DetResizeStrategy::MaxSide { max_side }
            }
            min_side @ DetResizeStrategy::MinSide { .. } => min_side,
        };
        self.config.det_resize_strategy = strategy;
        self
    }

    /// Sets the recognition target height; values below 16 are raised to 16.
    pub fn rec_target_height(mut self, height: u32) -> Self {
        self.config.rec_target_height = height.max(16);
        self
    }

    /// Sets the thread count; `0` is raised to `1`.
    pub fn num_threads(mut self, threads: usize) -> Self {
        self.config.num_threads = threads.max(1);
        self
    }

    /// Sets the unclip ratio; values below `1.0` (and NaN, per `f32::max`) become `1.0`.
    pub fn unclip_ratio(mut self, ratio: f32) -> Self {
        self.config.unclip_ratio = ratio.max(1.0);
        self
    }

    /// Sets the maximum number of detection candidates; `0` is raised to `1`.
    pub fn max_candidates(mut self, max: usize) -> Self {
        self.config.max_candidates = max.max(1);
        self
    }

    /// Enables or disables style detection.
    pub fn detect_styles(mut self, detect: bool) -> Self {
        self.config.detect_styles = detect;
        self
    }

    /// Sets the detection model path.
    pub fn det_model_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.config.det_model_path = Some(path.into());
        self
    }

    /// Sets the recognition model path.
    pub fn rec_model_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.config.rec_model_path = Some(path.into());
        self
    }

    /// Sets the character dictionary path.
    pub fn dict_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.config.dict_path = Some(path.into());
        self
    }

    /// Consumes the builder and returns the finished configuration.
    #[must_use]
    pub fn build(self) -> OcrConfig {
        self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of `OcrConfig::default()` has its documented baseline value.
    #[test]
    fn test_default_config() {
        let config = OcrConfig::default();
        assert!((config.det_threshold - 0.3).abs() < f32::EPSILON);
        assert!((config.box_threshold - 0.6).abs() < f32::EPSILON);
        assert!((config.rec_threshold - 0.5).abs() < f32::EPSILON);
        assert_eq!(config.det_max_side, 960);
        assert_eq!(config.rec_target_height, 48);
        assert_eq!(config.num_threads, 4);
        assert!((config.unclip_ratio - 1.5).abs() < f32::EPSILON);
        assert_eq!(config.max_candidates, 1000);
        assert!(config.detect_styles);
        assert_eq!(
            config.det_resize_strategy,
            DetResizeStrategy::MaxSide { max_side: 960 }
        );
        assert!(config.det_model_path.is_none());
        assert!(config.rec_model_path.is_none());
        assert!(config.dict_path.is_none());
    }

    /// The v5 preset switches to `MinSide` resizing and keeps the other defaults.
    #[test]
    fn test_v5_config() {
        let config = OcrConfig::v5();
        assert!(matches!(
            config.det_resize_strategy,
            DetResizeStrategy::MinSide {
                min_side: 64,
                max_side_limit: 4000
            }
        ));
        assert!((config.box_threshold - 0.6).abs() < f32::EPSILON);
        assert!((config.det_threshold - 0.3).abs() < f32::EPSILON);
    }

    /// In-range values passed to the builder are stored unchanged.
    #[test]
    fn test_builder() {
        let config = OcrConfig::builder()
            .det_threshold(0.5)
            .num_threads(8)
            .detect_styles(false)
            .build();
        assert!((config.det_threshold - 0.5).abs() < f32::EPSILON);
        assert_eq!(config.num_threads, 8);
        assert!(!config.detect_styles);
    }

    /// Upper-bound clamping for thresholds and the zero-thread floor.
    #[test]
    fn test_builder_clamping() {
        let config = OcrConfig::builder()
            .det_threshold(1.5)
            .num_threads(0)
            .build();
        assert!((config.det_threshold - 1.0).abs() < f32::EPSILON);
        assert_eq!(config.num_threads, 1);
    }

    /// Lower-bound floors applied by the numeric builder setters.
    #[test]
    fn test_builder_lower_bounds() {
        let config = OcrConfig::builder()
            .det_threshold(-0.2)
            .det_max_side(8)
            .rec_target_height(4)
            .unclip_ratio(0.5)
            .max_candidates(0)
            .build();
        assert!(config.det_threshold.abs() < f32::EPSILON);
        assert_eq!(config.det_max_side, 32);
        assert_eq!(
            config.det_resize_strategy,
            DetResizeStrategy::MaxSide { max_side: 32 }
        );
        assert_eq!(config.rec_target_height, 16);
        assert!((config.unclip_ratio - 1.0).abs() < f32::EPSILON);
        assert_eq!(config.max_candidates, 1);
    }

    /// `det_max_side` updates both the scalar field and the resize strategy.
    #[test]
    fn test_det_max_side_syncs_strategy() {
        let config = OcrConfig::builder().det_max_side(1280).build();
        assert_eq!(config.det_max_side, 1280);
        assert_eq!(
            config.det_resize_strategy,
            DetResizeStrategy::MaxSide { max_side: 1280 }
        );
    }

    /// Path setters store only the path they are given.
    #[test]
    fn test_builder_paths() {
        let config = OcrConfig::builder().det_model_path("det.onnx").build();
        assert_eq!(
            config.det_model_path.as_deref(),
            Some(std::path::Path::new("det.onnx"))
        );
        assert!(config.rec_model_path.is_none());
        assert!(config.dict_path.is_none());
    }
}