Skip to main content

ocr_rs/
det.rs

//! Text detection model.
//!
//! Provides text-region detection built on PaddleOCR detection models exported to MNN format.
4
5use image::{DynamicImage, GenericImageView};
6use ndarray::ArrayD;
7use std::path::Path;
8
9use crate::error::{OcrError, OcrResult};
10use crate::mnn::{InferenceConfig, InferenceEngine};
11use crate::postprocess::{extract_boxes_with_unclip, TextBox};
12use crate::preprocess::{preprocess_for_det, NormalizeParams};
13
/// Detection precision mode.
///
/// Currently only a single-pass mode is available.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DetPrecisionMode {
    /// Fast mode - a single detection pass over the (optionally downscaled) image.
    #[default]
    Fast,
}

/// Detection options.
///
/// Build with [`DetOptions::new`] / [`DetOptions::default`] and the `with_*` setters.
#[derive(Debug, Clone)]
pub struct DetOptions {
    /// Maximum image side length; larger images are downscaled (aspect ratio
    /// preserved) so that their longer side equals this value.
    pub max_side_len: u32,
    /// Box-level score threshold (0.0 - 1.0).
    ///
    /// Note: pixel binarization is controlled by `score_threshold`, not by this field.
    pub box_threshold: f32,
    /// Expansion ratio applied to detected contours ("unclip" step of DB post-processing).
    pub unclip_ratio: f32,
    /// Pixel-level segmentation threshold: probability-map values strictly greater
    /// than this are treated as text.
    pub score_threshold: f32,
    /// Minimum bounding box area; smaller boxes are discarded.
    pub min_area: u32,
    /// Border added around each box before cropping.
    pub box_border: u32,
    /// Whether to merge adjacent text boxes.
    pub merge_boxes: bool,
    /// Merge distance threshold.
    pub merge_threshold: i32,
    /// Precision mode.
    pub precision_mode: DetPrecisionMode,
    /// Scale ratios for multi-scale detection (not used by [`DetPrecisionMode::Fast`]).
    pub multi_scales: Vec<f32>,
    /// Block size for block-wise detection (not used by [`DetPrecisionMode::Fast`]).
    pub block_size: u32,
    /// Overlap between neighbouring blocks in block-wise detection.
    pub block_overlap: u32,
    /// NMS IoU threshold (0.0 - 1.0).
    pub nms_threshold: f32,
}

impl Default for DetOptions {
    fn default() -> Self {
        Self {
            max_side_len: 960,
            box_threshold: 0.5,
            unclip_ratio: 1.5,
            score_threshold: 0.3,
            min_area: 16,
            box_border: 5,
            merge_boxes: false,
            merge_threshold: 10,
            precision_mode: DetPrecisionMode::Fast,
            multi_scales: vec![0.5, 1.0, 1.5],
            block_size: 640,
            block_overlap: 100,
            nms_threshold: 0.3,
        }
    }
}

impl DetOptions {
    /// Create detection options with default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set maximum side length.
    #[must_use]
    pub fn with_max_side_len(mut self, len: u32) -> Self {
        self.max_side_len = len;
        self
    }

    /// Set bounding box score threshold.
    #[must_use]
    pub fn with_box_threshold(mut self, threshold: f32) -> Self {
        self.box_threshold = threshold;
        self
    }

    /// Set contour expansion (unclip) ratio.
    #[must_use]
    pub fn with_unclip_ratio(mut self, ratio: f32) -> Self {
        self.unclip_ratio = ratio;
        self
    }

    /// Set pixel-level segmentation threshold.
    #[must_use]
    pub fn with_score_threshold(mut self, threshold: f32) -> Self {
        self.score_threshold = threshold;
        self
    }

    /// Set minimum area.
    #[must_use]
    pub fn with_min_area(mut self, area: u32) -> Self {
        self.min_area = area;
        self
    }

    /// Set box border expansion.
    #[must_use]
    pub fn with_box_border(mut self, border: u32) -> Self {
        self.box_border = border;
        self
    }

    /// Enable or disable box merging.
    #[must_use]
    pub fn with_merge_boxes(mut self, merge: bool) -> Self {
        self.merge_boxes = merge;
        self
    }

    /// Set merge threshold.
    #[must_use]
    pub fn with_merge_threshold(mut self, threshold: i32) -> Self {
        self.merge_threshold = threshold;
        self
    }

    /// Set precision mode.
    #[must_use]
    pub fn with_precision_mode(mut self, mode: DetPrecisionMode) -> Self {
        self.precision_mode = mode;
        self
    }

    /// Set multi-scale ratios.
    #[must_use]
    pub fn with_multi_scales(mut self, scales: Vec<f32>) -> Self {
        self.multi_scales = scales;
        self
    }

    /// Set block size.
    #[must_use]
    pub fn with_block_size(mut self, size: u32) -> Self {
        self.block_size = size;
        self
    }

    /// Set block overlap.
    #[must_use]
    pub fn with_block_overlap(mut self, overlap: u32) -> Self {
        self.block_overlap = overlap;
        self
    }

    /// Set NMS IoU threshold.
    #[must_use]
    pub fn with_nms_threshold(mut self, threshold: f32) -> Self {
        self.nms_threshold = threshold;
        self
    }

    /// Fast mode preset.
    #[must_use]
    pub fn fast() -> Self {
        Self {
            max_side_len: 960,
            precision_mode: DetPrecisionMode::Fast,
            ..Default::default()
        }
    }
}
148
149/// Text detection model
150pub struct DetModel {
151    engine: InferenceEngine,
152    options: DetOptions,
153    normalize_params: NormalizeParams,
154}
155
156impl DetModel {
157    /// Create detector from model file
158    ///
159    /// # Parameters
160    /// - `model_path`: Model file path (.mnn format)
161    /// - `config`: Optional inference config
162    pub fn from_file(
163        model_path: impl AsRef<Path>,
164        config: Option<InferenceConfig>,
165    ) -> OcrResult<Self> {
166        let engine = InferenceEngine::from_file(model_path, config)?;
167        Ok(Self {
168            engine,
169            options: DetOptions::default(),
170            normalize_params: NormalizeParams::paddle_det(),
171        })
172    }
173
174    /// Create detector from model bytes
175    pub fn from_bytes(model_bytes: &[u8], config: Option<InferenceConfig>) -> OcrResult<Self> {
176        let engine = InferenceEngine::from_buffer(model_bytes, config)?;
177        Ok(Self {
178            engine,
179            options: DetOptions::default(),
180            normalize_params: NormalizeParams::paddle_det(),
181        })
182    }
183
184    /// Set detection options
185    pub fn with_options(mut self, options: DetOptions) -> Self {
186        self.options = options;
187        self
188    }
189
190    /// Get current detection options
191    pub fn options(&self) -> &DetOptions {
192        &self.options
193    }
194
195    /// Modify detection options
196    pub fn options_mut(&mut self) -> &mut DetOptions {
197        &mut self.options
198    }
199
200    /// Detect text regions in image
201    ///
202    /// # Parameters
203    /// - `image`: Input image
204    ///
205    /// # Returns
206    /// List of detected text bounding boxes
207    pub fn detect(&self, image: &DynamicImage) -> OcrResult<Vec<TextBox>> {
208        self.detect_fast(image)
209    }
210
211    /// Detect and return cropped text images
212    ///
213    /// # Parameters
214    /// - `image`: Input image
215    ///
216    /// # Returns
217    /// List of (text image, corresponding bounding box)
218    pub fn detect_and_crop(&self, image: &DynamicImage) -> OcrResult<Vec<(DynamicImage, TextBox)>> {
219        let boxes = self.detect(image)?;
220        let (width, height) = image.dimensions();
221
222        let mut results = Vec::with_capacity(boxes.len());
223
224        for text_box in boxes {
225            // Expand bounding box
226            let expanded = text_box.expand(self.options.box_border, width, height);
227
228            // Crop image
229            let cropped = image.crop_imm(
230                expanded.rect.left() as u32,
231                expanded.rect.top() as u32,
232                expanded.rect.width(),
233                expanded.rect.height(),
234            );
235
236            results.push((cropped, expanded));
237        }
238
239        Ok(results)
240    }
241
242    /// Fast detection (single inference)
243    fn detect_fast(&self, image: &DynamicImage) -> OcrResult<Vec<TextBox>> {
244        let (original_width, original_height) = image.dimensions();
245
246        // Scale image
247        let scaled = self.scale_image(image);
248        let (scaled_width, scaled_height) = scaled.dimensions();
249
250        // Preprocess
251        let input = preprocess_for_det(&scaled, &self.normalize_params)?;
252
253        // Inference (using dynamic shape)
254        let output = self.engine.run_dynamic(input.view().into_dyn())?;
255
256        // Post-processing - output shape matches input (including padding)
257        let output_shape = output.shape();
258        let out_w = output_shape[3] as u32;
259        let out_h = output_shape[2] as u32;
260
261        let boxes = self.postprocess_output(
262            &output,
263            out_w,
264            out_h,
265            scaled_width,
266            scaled_height,
267            original_width,
268            original_height,
269        )?;
270
271        Ok(boxes)
272    }
273
274    /// Balanced mode detection (multi-scale)
275    /// Scale image to maximum side length limit
276    fn scale_image(&self, image: &DynamicImage) -> DynamicImage {
277        let (w, h) = image.dimensions();
278        let max_dim = w.max(h);
279
280        if max_dim <= self.options.max_side_len {
281            return image.clone();
282        }
283
284        let scale = self.options.max_side_len as f64 / max_dim as f64;
285        let new_w = (w as f64 * scale).round() as u32;
286        let new_h = (h as f64 * scale).round() as u32;
287
288        image.resize_exact(new_w, new_h, image::imageops::FilterType::Lanczos3)
289    }
290
291    /// Post-process inference output
292    fn postprocess_output(
293        &self,
294        output: &ArrayD<f32>,
295        out_w: u32,
296        out_h: u32,
297        scaled_width: u32,
298        scaled_height: u32,
299        original_width: u32,
300        original_height: u32,
301    ) -> OcrResult<Vec<TextBox>> {
302        // Retrieve output data
303        let output_shape = output.shape();
304        if output_shape.len() < 3 {
305            return Err(OcrError::PostprocessError(
306                "Detection model output shape invalid".to_string(),
307            ));
308        }
309
310        // Extract segmentation mask (only valid region, remove padding)
311        let mask_data: Vec<f32> = output.iter().cloned().collect();
312
313        // Binarization
314        let binary_mask: Vec<u8> = mask_data
315            .iter()
316            .map(|&v| {
317                if v > self.options.score_threshold {
318                    255u8
319                } else {
320                    0u8
321                }
322            })
323            .collect();
324
325        // Extract bounding boxes (with unclip expansion)
326        // DB algorithm needs to expand detected contours because model output segmentation mask is usually smaller than actual text region
327        let boxes = extract_boxes_with_unclip(
328            &binary_mask,
329            out_w,
330            out_h,
331            scaled_width,
332            scaled_height,
333            original_width,
334            original_height,
335            self.options.min_area,
336            self.options.unclip_ratio,
337        );
338
339        Ok(boxes)
340    }
341}
342
343/// Low-level detection API
344impl DetModel {
345    /// Raw inference interface
346    ///
347    /// Execute model inference directly without preprocessing and postprocessing
348    ///
349    /// # Parameters
350    /// - `input`: Preprocessed input tensor [1, 3, H, W]
351    ///
352    /// # Returns
353    /// Model raw output
354    pub fn run_raw(&self, input: ndarray::ArrayViewD<f32>) -> OcrResult<ArrayD<f32>> {
355        Ok(self.engine.run_dynamic(input)?)
356    }
357
358    /// Get model input shape
359    pub fn input_shape(&self) -> &[usize] {
360        self.engine.input_shape()
361    }
362
363    /// Get model output shape
364    pub fn output_shape(&self) -> &[usize] {
365        self.engine.output_shape()
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `value` lies in the closed interval [0.0, 1.0].
    fn assert_unit_interval(value: f32) {
        assert!((0.0..=1.0).contains(&value));
    }

    #[test]
    fn test_det_options_default() {
        let defaults = DetOptions::default();

        // Sizing and thresholds.
        assert_eq!(defaults.max_side_len, 960);
        assert_eq!(defaults.box_threshold, 0.5);
        assert_eq!(defaults.unclip_ratio, 1.5);
        assert_eq!(defaults.score_threshold, 0.3);
        assert_eq!(defaults.min_area, 16);
        assert_eq!(defaults.box_border, 5);

        // Merging, mode and NMS.
        assert!(!defaults.merge_boxes);
        assert_eq!(defaults.merge_threshold, 10);
        assert_eq!(defaults.precision_mode, DetPrecisionMode::Fast);
        assert_eq!(defaults.nms_threshold, 0.3);
    }

    #[test]
    fn test_det_options_fast() {
        let preset = DetOptions::fast();
        assert_eq!(preset.max_side_len, 960);
        assert_eq!(preset.precision_mode, DetPrecisionMode::Fast);
    }

    #[test]
    fn test_det_options_builder() {
        let scales = vec![0.5, 1.0, 1.5];
        let built = DetOptions::new()
            .with_max_side_len(1280)
            .with_box_threshold(0.6)
            .with_score_threshold(0.4)
            .with_min_area(32)
            .with_box_border(10)
            .with_merge_boxes(true)
            .with_merge_threshold(20)
            .with_precision_mode(DetPrecisionMode::Fast)
            .with_multi_scales(scales.clone())
            .with_block_size(800);

        assert_eq!(built.max_side_len, 1280);
        assert_eq!(built.box_threshold, 0.6);
        assert_eq!(built.score_threshold, 0.4);
        assert_eq!(built.min_area, 32);
        assert_eq!(built.box_border, 10);
        assert!(built.merge_boxes);
        assert_eq!(built.merge_threshold, 20);
        assert_eq!(built.precision_mode, DetPrecisionMode::Fast);
        assert_eq!(built.multi_scales, scales);
        assert_eq!(built.block_size, 800);
    }

    #[test]
    fn test_det_precision_mode_default() {
        assert_eq!(DetPrecisionMode::default(), DetPrecisionMode::Fast);
    }

    #[test]
    fn test_det_precision_mode_equality() {
        let lhs = DetPrecisionMode::Fast;
        let rhs = DetPrecisionMode::Fast;
        assert_eq!(lhs, rhs);
    }

    #[test]
    fn test_det_options_chaining() {
        // Chained setters must keep earlier settings and leave the rest at defaults.
        let chained = DetOptions::new()
            .with_max_side_len(1000)
            .with_box_threshold(0.7);

        assert_eq!(chained.max_side_len, 1000);
        assert_eq!(chained.box_threshold, 0.7);
        assert_eq!(chained.score_threshold, 0.3);
    }

    #[test]
    fn test_det_options_presets_are_valid() {
        // Every probability-like threshold in the preset must be within [0, 1].
        let preset = DetOptions::fast();
        for threshold in [
            preset.box_threshold,
            preset.score_threshold,
            preset.nms_threshold,
        ] {
            assert_unit_interval(threshold);
        }
    }
}