use image::{DynamicImage, GenericImageView};
use ndarray::ArrayD;
use std::path::Path;
use crate::error::{OcrError, OcrResult};
use crate::mnn::{InferenceConfig, InferenceEngine};
use crate::postprocess::{extract_boxes_with_unclip, TextBox};
use crate::preprocess::{preprocess_for_det, NormalizeParams};
/// Precision preset carried by [`DetOptions::precision_mode`].
///
/// `Fast` is currently the only variant and is also what
/// [`Default::default`] returns.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DetPrecisionMode {
    /// The single available mode; marked as the default variant.
    #[default]
    Fast,
}
/// Tunable settings for text detection.
///
/// Only some of these fields are read by the detection code in this file
/// (`max_side_len`, `unclip_ratio`, `score_threshold`, `min_area` and
/// `box_border`); the remaining ones are stored and exposed but not consulted
/// here.
///
/// `PartialEq` is derived so two option sets can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct DetOptions {
    /// Longest allowed image side in pixels; larger images are downscaled
    /// before inference so that their longer side matches this value.
    pub max_side_len: u32,
    /// Box-level threshold. NOTE(review): not read by any code in this file.
    pub box_threshold: f32,
    /// Unclip ratio forwarded to `extract_boxes_with_unclip`.
    pub unclip_ratio: f32,
    /// Model output values strictly greater than this are marked as
    /// foreground when the binary mask is built.
    pub score_threshold: f32,
    /// Minimum area forwarded to `extract_boxes_with_unclip`.
    pub min_area: u32,
    /// Border passed to `TextBox::expand` before cropping in
    /// `DetModel::detect_and_crop`.
    pub box_border: u32,
    /// Whether boxes should be merged. NOTE(review): not read in this file.
    pub merge_boxes: bool,
    /// Threshold associated with box merging. NOTE(review): not read in this file.
    pub merge_threshold: i32,
    /// Selected precision preset. NOTE(review): not read in this file.
    pub precision_mode: DetPrecisionMode,
    /// Scale factors for multi-scale detection. NOTE(review): not read in this file.
    pub multi_scales: Vec<f32>,
    /// Block size for block-wise detection. NOTE(review): not read in this file.
    pub block_size: u32,
    /// Overlap between neighbouring blocks. NOTE(review): not read in this file.
    pub block_overlap: u32,
    /// Non-maximum-suppression threshold. NOTE(review): not read in this file.
    pub nms_threshold: f32,
}
impl Default for DetOptions {
fn default() -> Self {
Self {
max_side_len: 960,
box_threshold: 0.5,
unclip_ratio: 1.5,
score_threshold: 0.3,
min_area: 16,
box_border: 5,
merge_boxes: false,
merge_threshold: 10,
precision_mode: DetPrecisionMode::Fast,
multi_scales: vec![0.5, 1.0, 1.5],
block_size: 640,
block_overlap: 100,
nms_threshold: 0.3,
}
}
}
/// Builder-style construction helpers for [`DetOptions`].
///
/// Every `with_*` method consumes the options, overwrites one field and
/// returns the updated value so calls can be chained. They are `#[must_use]`
/// because dropping the return value discards the whole configuration.
impl DetOptions {
    /// Creates the default options; identical to [`DetOptions::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets `max_side_len`.
    #[must_use]
    pub fn with_max_side_len(mut self, len: u32) -> Self {
        self.max_side_len = len;
        self
    }

    /// Sets `box_threshold`.
    #[must_use]
    pub fn with_box_threshold(mut self, threshold: f32) -> Self {
        self.box_threshold = threshold;
        self
    }

    /// Sets `unclip_ratio`.
    #[must_use]
    pub fn with_unclip_ratio(mut self, ratio: f32) -> Self {
        self.unclip_ratio = ratio;
        self
    }

    /// Sets `score_threshold`.
    #[must_use]
    pub fn with_score_threshold(mut self, threshold: f32) -> Self {
        self.score_threshold = threshold;
        self
    }

    /// Sets `min_area`.
    #[must_use]
    pub fn with_min_area(mut self, area: u32) -> Self {
        self.min_area = area;
        self
    }

    /// Sets `box_border`.
    #[must_use]
    pub fn with_box_border(mut self, border: u32) -> Self {
        self.box_border = border;
        self
    }

    /// Sets `merge_boxes`.
    #[must_use]
    pub fn with_merge_boxes(mut self, merge: bool) -> Self {
        self.merge_boxes = merge;
        self
    }

    /// Sets `merge_threshold`.
    #[must_use]
    pub fn with_merge_threshold(mut self, threshold: i32) -> Self {
        self.merge_threshold = threshold;
        self
    }

    /// Sets `precision_mode`.
    #[must_use]
    pub fn with_precision_mode(mut self, mode: DetPrecisionMode) -> Self {
        self.precision_mode = mode;
        self
    }

    /// Sets `multi_scales`.
    #[must_use]
    pub fn with_multi_scales(mut self, scales: Vec<f32>) -> Self {
        self.multi_scales = scales;
        self
    }

    /// Sets `block_size`.
    #[must_use]
    pub fn with_block_size(mut self, size: u32) -> Self {
        self.block_size = size;
        self
    }

    /// Sets `block_overlap`.
    #[must_use]
    pub fn with_block_overlap(mut self, overlap: u32) -> Self {
        self.block_overlap = overlap;
        self
    }

    /// Sets `nms_threshold`.
    #[must_use]
    pub fn with_nms_threshold(mut self, threshold: f32) -> Self {
        self.nms_threshold = threshold;
        self
    }

    /// Returns the "fast" preset: the defaults with `max_side_len` pinned to
    /// 960 and `precision_mode` pinned to [`DetPrecisionMode::Fast`].
    ///
    /// Both values are spelled out so the preset stays fixed even if the
    /// defaults are changed later.
    #[must_use]
    pub fn fast() -> Self {
        Self {
            max_side_len: 960,
            precision_mode: DetPrecisionMode::Fast,
            ..Default::default()
        }
    }
}
/// Text-detection model: an inference engine bundled with the detection
/// options and the normalization parameters used when preparing its input.
pub struct DetModel {
/// Engine that executes the model; this file calls `run_dynamic`,
/// `input_shape` and `output_shape` on it.
engine: InferenceEngine,
/// Detection settings; replaceable through `with_options` and editable
/// through `options_mut`.
options: DetOptions,
/// Normalization handed to `preprocess_for_det`; both constructors set it
/// to `NormalizeParams::paddle_det()`.
normalize_params: NormalizeParams,
}
impl DetModel {
pub fn from_file(
model_path: impl AsRef<Path>,
config: Option<InferenceConfig>,
) -> OcrResult<Self> {
let engine = InferenceEngine::from_file(model_path, config)?;
Ok(Self {
engine,
options: DetOptions::default(),
normalize_params: NormalizeParams::paddle_det(),
})
}
pub fn from_bytes(model_bytes: &[u8], config: Option<InferenceConfig>) -> OcrResult<Self> {
let engine = InferenceEngine::from_buffer(model_bytes, config)?;
Ok(Self {
engine,
options: DetOptions::default(),
normalize_params: NormalizeParams::paddle_det(),
})
}
pub fn with_options(mut self, options: DetOptions) -> Self {
self.options = options;
self
}
pub fn options(&self) -> &DetOptions {
&self.options
}
pub fn options_mut(&mut self) -> &mut DetOptions {
&mut self.options
}
pub fn detect(&self, image: &DynamicImage) -> OcrResult<Vec<TextBox>> {
self.detect_fast(image)
}
pub fn detect_and_crop(&self, image: &DynamicImage) -> OcrResult<Vec<(DynamicImage, TextBox)>> {
let boxes = self.detect(image)?;
let (width, height) = image.dimensions();
let mut results = Vec::with_capacity(boxes.len());
for text_box in boxes {
let expanded = text_box.expand(self.options.box_border, width, height);
let cropped = image.crop_imm(
expanded.rect.left() as u32,
expanded.rect.top() as u32,
expanded.rect.width(),
expanded.rect.height(),
);
results.push((cropped, expanded));
}
Ok(results)
}
fn detect_fast(&self, image: &DynamicImage) -> OcrResult<Vec<TextBox>> {
let (original_width, original_height) = image.dimensions();
let scaled = self.scale_image(image);
let (scaled_width, scaled_height) = scaled.dimensions();
let input = preprocess_for_det(&scaled, &self.normalize_params)?;
let output = self.engine.run_dynamic(input.view().into_dyn())?;
let output_shape = output.shape();
let out_w = output_shape[3] as u32;
let out_h = output_shape[2] as u32;
let boxes = self.postprocess_output(
&output,
out_w,
out_h,
scaled_width,
scaled_height,
original_width,
original_height,
)?;
Ok(boxes)
}
fn scale_image(&self, image: &DynamicImage) -> DynamicImage {
let (w, h) = image.dimensions();
let max_dim = w.max(h);
if max_dim <= self.options.max_side_len {
return image.clone();
}
let scale = self.options.max_side_len as f64 / max_dim as f64;
let new_w = (w as f64 * scale).round() as u32;
let new_h = (h as f64 * scale).round() as u32;
image.resize_exact(new_w, new_h, image::imageops::FilterType::Lanczos3)
}
fn postprocess_output(
&self,
output: &ArrayD<f32>,
out_w: u32,
out_h: u32,
scaled_width: u32,
scaled_height: u32,
original_width: u32,
original_height: u32,
) -> OcrResult<Vec<TextBox>> {
let output_shape = output.shape();
if output_shape.len() < 3 {
return Err(OcrError::PostprocessError(
"Detection model output shape invalid".to_string(),
));
}
let mask_data: Vec<f32> = output.iter().cloned().collect();
let binary_mask: Vec<u8> = mask_data
.iter()
.map(|&v| {
if v > self.options.score_threshold {
255u8
} else {
0u8
}
})
.collect();
let boxes = extract_boxes_with_unclip(
&binary_mask,
out_w,
out_h,
scaled_width,
scaled_height,
original_width,
original_height,
self.options.min_area,
self.options.unclip_ratio,
);
Ok(boxes)
}
}
/// Low-level access to the underlying inference engine.
impl DetModel {
    /// Passes `input` directly to the engine's `run_dynamic` and returns the
    /// resulting tensor, without the image scaling, preprocessing or box
    /// extraction performed by `detect`.
    ///
    /// # Errors
    ///
    /// Returns the engine's error, converted into the crate error type.
    pub fn run_raw(&self, input: ndarray::ArrayViewD<f32>) -> OcrResult<ArrayD<f32>> {
        let raw_output = self.engine.run_dynamic(input)?;
        Ok(raw_output)
    }

    /// Returns the input shape reported by the engine.
    pub fn input_shape(&self) -> &[usize] {
        let engine = &self.engine;
        engine.input_shape()
    }

    /// Returns the output shape reported by the engine.
    pub fn output_shape(&self) -> &[usize] {
        let engine = &self.engine;
        engine.output_shape()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` must produce the documented baseline values.
    #[test]
    fn test_det_options_default() {
        let defaults = DetOptions::default();

        // Integer and boolean settings.
        assert_eq!(defaults.max_side_len, 960);
        assert_eq!(defaults.min_area, 16);
        assert_eq!(defaults.box_border, 5);
        assert_eq!(defaults.merge_threshold, 10);
        assert!(!defaults.merge_boxes);

        // Floating-point settings.
        assert_eq!(defaults.box_threshold, 0.5);
        assert_eq!(defaults.unclip_ratio, 1.5);
        assert_eq!(defaults.score_threshold, 0.3);
        assert_eq!(defaults.nms_threshold, 0.3);

        assert_eq!(defaults.precision_mode, DetPrecisionMode::Fast);
    }

    /// The `fast` preset pins the side limit and the precision mode.
    #[test]
    fn test_det_options_fast() {
        let preset = DetOptions::fast();
        assert_eq!(preset.max_side_len, 960);
        assert_eq!(preset.precision_mode, DetPrecisionMode::Fast);
    }

    /// Each builder method writes exactly the field it is named after.
    #[test]
    fn test_det_options_builder() {
        let scales = vec![0.5, 1.0, 1.5];
        let built = DetOptions::new()
            .with_max_side_len(1280)
            .with_box_threshold(0.6)
            .with_score_threshold(0.4)
            .with_min_area(32)
            .with_box_border(10)
            .with_merge_boxes(true)
            .with_merge_threshold(20)
            .with_precision_mode(DetPrecisionMode::Fast)
            .with_multi_scales(scales.clone())
            .with_block_size(800);

        assert_eq!(built.max_side_len, 1280);
        assert_eq!(built.box_threshold, 0.6);
        assert_eq!(built.score_threshold, 0.4);
        assert_eq!(built.min_area, 32);
        assert_eq!(built.box_border, 10);
        assert!(built.merge_boxes);
        assert_eq!(built.merge_threshold, 20);
        assert_eq!(built.precision_mode, DetPrecisionMode::Fast);
        assert_eq!(built.multi_scales, scales);
        assert_eq!(built.block_size, 800);
    }

    /// The default precision mode is `Fast`.
    #[test]
    fn test_det_precision_mode_default() {
        assert_eq!(DetPrecisionMode::default(), DetPrecisionMode::Fast);
    }

    /// A precision mode compares equal to itself.
    #[test]
    fn test_det_precision_mode_equality() {
        let lhs = DetPrecisionMode::Fast;
        let rhs = DetPrecisionMode::Fast;
        assert_eq!(lhs, rhs);
    }

    /// Chained builders leave untouched fields at their defaults.
    #[test]
    fn test_det_options_chaining() {
        let chained = DetOptions::new()
            .with_max_side_len(1000)
            .with_box_threshold(0.7);

        assert_eq!(chained.max_side_len, 1000);
        assert_eq!(chained.box_threshold, 0.7);
        // Not set above, so it must still hold the default.
        assert_eq!(chained.score_threshold, 0.3);
    }

    /// Threshold values of the preset must lie in the closed unit interval.
    #[test]
    fn test_det_options_presets_are_valid() {
        let preset = DetOptions::fast();
        let thresholds = [
            preset.box_threshold,
            preset.score_threshold,
            preset.nms_threshold,
        ];
        for &threshold in thresholds.iter() {
            assert!((0.0..=1.0).contains(&threshold));
        }
    }
}