oar_ocr/predictor/
db_detector.rs

1//! DB (Differentiable Binarization) Text Detector
2//!
3//! This module implements a text detection predictor using the DB model,
4//! which is designed for detecting text regions in images. The DB model
5//! uses a differentiable binarization technique to improve the accuracy
6//! of text detection.
7//!
8//! The main components are:
9//! - `TextDetPredictor`: The main predictor that performs text detection
10//! - `TextDetPredictorConfig`: Configuration for the predictor
11//! - `TextDetResult`: Results from text detection
12//! - `TextDetPredictorBuilder`: Builder for creating predictor instances
13
14use crate::processors::{BoundingBox, DBPostProcess, DetResizeForTest, LimitType, NormalizeImage};
15use image::{DynamicImage, RgbImage};
16use std::fmt;
17use std::path::Path;
18use std::sync::Arc;
19
20use crate::impl_config_new_and_with_common;
21
22use crate::impl_common_builder_methods;
23
24use crate::core::ImageReader as CoreImageReader;
25use crate::core::{
26    BatchData, CommonBuilderConfig, OCRError, Tensor4D,
27    config::{ConfigValidator, ConfigValidatorExt},
28    constants::{DEFAULT_BATCH_SIZE, DEFAULT_MAX_SIDE_LIMIT},
29};
30use crate::core::{DefaultImageReader, OrtInfer};
31use crate::core::{
32    GranularImageReader as GIReader, InferenceEngine as GInferenceEngine, ModularPredictor,
33    Postprocessor as GPostprocessor, Preprocessor as GPreprocessor,
34};
35
/// Fallback binarization threshold, used when no `thresh` value is configured.
const DEFAULT_THRESH: f32 = 0.3_f32;

/// Fallback box-filtering threshold, used when no `box_thresh` value is configured.
const DEFAULT_BOX_THRESH: f32 = 0.6_f32;

/// Fallback unclip ratio, used when no `unclip_ratio` value is configured.
const DEFAULT_UNCLIP_RATIO: f32 = 1.5_f32;
41
42/// Configuration for text detection
43///
44/// This struct holds configuration parameters for text detection.
45#[derive(Debug, Clone, Default)]
46pub struct TextDetConfig {
47    /// Limit for the side length of the image
48    pub limit_side_len: Option<u32>,
49    /// Type of limit to apply (Max or Min)
50    pub limit_type: Option<LimitType>,
51    /// Threshold for binarization
52    pub thresh: Option<f32>,
53    /// Threshold for filtering text boxes
54    pub box_thresh: Option<f32>,
55    /// Ratio for unclipping text boxes
56    pub unclip_ratio: Option<f32>,
57    /// Maximum side limit for the image
58    pub max_side_limit: Option<u32>,
59}
60
61/// Configuration for the text detection predictor
62///
63/// This struct holds configuration parameters for the text detection predictor.
64#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
65pub struct TextDetPredictorConfig {
66    /// Common configuration parameters
67    pub common: CommonBuilderConfig,
68    /// Limit for the side length of the image
69    pub limit_side_len: Option<u32>,
70    /// Type of limit to apply (Max or Min)
71    pub limit_type: Option<LimitType>,
72    /// Threshold for binarization
73    pub thresh: Option<f32>,
74    /// Threshold for filtering text boxes
75    pub box_thresh: Option<f32>,
76    /// Ratio for unclipping text boxes
77    pub unclip_ratio: Option<f32>,
78    /// Input shape for the model (channels, height, width)
79    pub input_shape: Option<(u32, u32, u32)>,
80    /// Maximum side limit for the image
81    pub max_side_limit: Option<u32>,
82}
83
84impl_config_new_and_with_common!(
85    TextDetPredictorConfig,
86    common_defaults: (None, Some(DEFAULT_BATCH_SIZE)),
87    fields: {
88        limit_side_len: None,
89        limit_type: None,
90        thresh: None,
91        box_thresh: None,
92        unclip_ratio: None,
93        input_shape: None,
94        max_side_limit: Some(DEFAULT_MAX_SIDE_LIMIT)
95    }
96);
97
98impl TextDetPredictorConfig {
99    /// Validates the configuration
100    ///
101    /// This function validates the configuration parameters to ensure they are within
102    /// acceptable ranges and formats.
103    pub fn validate(&self) -> Result<(), crate::core::ConfigError> {
104        ConfigValidator::validate(self)
105    }
106}
107
108impl ConfigValidator for TextDetPredictorConfig {
109    fn validate(&self) -> Result<(), crate::core::ConfigError> {
110        self.common.validate()?;
111
112        if let Some(thresh) = self.thresh {
113            self.validate_f32_range(thresh, 0.0, 1.0, "threshold")?;
114        }
115
116        if let Some(box_thresh) = self.box_thresh {
117            self.validate_f32_range(box_thresh, 0.0, 1.0, "box threshold")?;
118        }
119
120        if let Some(unclip_ratio) = self.unclip_ratio {
121            self.validate_positive_f32(unclip_ratio, "unclip ratio")?;
122        }
123
124        if let Some(max_side_limit) = self.max_side_limit {
125            self.validate_positive_usize(max_side_limit as usize, "max side limit")?;
126        }
127
128        if let Some(limit_side_len) = self.limit_side_len {
129            self.validate_positive_usize(limit_side_len as usize, "limit side length")?;
130        }
131
132        if let Some((c, h, w)) = self.input_shape
133            && (c == 0 || h == 0 || w == 0)
134        {
135            return Err(crate::core::ConfigError::InvalidConfig {
136                message: format!(
137                    "Input shape dimensions must be greater than 0, got ({c}, {h}, {w})"
138                ),
139            });
140        }
141
142        Ok(())
143    }
144
145    fn get_defaults() -> Self {
146        Self {
147            common: CommonBuilderConfig::get_defaults(),
148            limit_side_len: Some(960),
149            limit_type: Some(LimitType::Max),
150            thresh: Some(DEFAULT_THRESH),
151            box_thresh: Some(DEFAULT_BOX_THRESH),
152            unclip_ratio: Some(DEFAULT_UNCLIP_RATIO),
153            input_shape: Some((3, 640, 640)),
154            max_side_limit: Some(DEFAULT_MAX_SIDE_LIMIT),
155        }
156    }
157}
158
159/// Results from text detection
160///
161/// This struct holds the results of text detection operations.
162#[derive(Debug, Clone)]
163pub struct TextDetResult {
164    /// Paths to the input images
165    pub input_path: Vec<Arc<str>>,
166    /// Indexes of the input images
167    pub index: Vec<usize>,
168    /// Input images
169    pub input_img: Vec<Arc<RgbImage>>,
170    /// Detected polygons
171    pub dt_polys: Vec<Vec<BoundingBox>>,
172    /// Detection scores
173    pub dt_scores: Vec<Vec<f32>>,
174}
175
176impl TextDetResult {
177    /// Creates a new, empty `TextDetResult`
178    ///
179    /// This function initializes a new text detection result with empty vectors
180    /// for all fields.
181    pub fn new() -> Self {
182        Self {
183            input_path: Vec::new(),
184            index: Vec::new(),
185            input_img: Vec::new(),
186            dt_polys: Vec::new(),
187            dt_scores: Vec::new(),
188        }
189    }
190}
191
192impl fmt::Display for TextDetResult {
193    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194        for (i, ((path, polys), scores)) in self
195            .input_path
196            .iter()
197            .zip(self.dt_polys.iter())
198            .zip(self.dt_scores.iter())
199            .enumerate()
200        {
201            writeln!(f, "Image {} of {}: {}", i + 1, self.input_path.len(), path)?;
202            writeln!(f, "  Total regions: {}", polys.len())?;
203
204            if !polys.is_empty() {
205                writeln!(f, "  Detection polygons:")?;
206                for (j, (bbox, &score)) in polys.iter().zip(scores.iter()).enumerate() {
207                    if bbox.points.is_empty() {
208                        writeln!(f, "    Region {j}: [] (empty, score: {score:.3})")?;
209                        continue;
210                    }
211
212                    write!(f, "    Region {j}: [")?;
213                    for (k, point) in bbox.points.iter().enumerate() {
214                        if k == 0 {
215                            write!(f, "[{:.0}, {:.0}]", point.x, point.y)?;
216                        } else {
217                            write!(f, ", [{:.0}, {:.0}]", point.x, point.y)?;
218                        }
219                    }
220                    writeln!(f, "] (score: {score:.3})")?;
221                }
222            }
223
224            if i < self.input_path.len() - 1 {
225                writeln!(f)?;
226            }
227        }
228
229        Ok(())
230    }
231}
232
233impl Default for TextDetResult {
234    fn default() -> Self {
235        Self::new()
236    }
237}
238
239/// Text detection predictor built from modular components
240///
241/// This is a type alias over `ModularPredictor` with concrete, composable components
242/// to eliminate duplicated StandardPredictor implementations across predictors.
243pub type TextDetPredictor =
244    ModularPredictor<TDImageReader, TDPreprocessor, TDOrtInfer, TDPostprocessor>;
245
246#[derive(Debug)]
247pub struct TDImageReader {
248    inner: DefaultImageReader,
249}
250impl TDImageReader {
251    pub fn new() -> Self {
252        Self {
253            inner: DefaultImageReader::new(),
254        }
255    }
256}
257impl Default for TDImageReader {
258    fn default() -> Self {
259        Self::new()
260    }
261}
262impl GIReader for TDImageReader {
263    fn read_images<'a>(
264        &self,
265        paths: impl Iterator<Item = &'a str>,
266    ) -> Result<Vec<RgbImage>, OCRError> {
267        self.inner.apply(paths)
268    }
269}
270
271#[derive(Debug)]
272pub struct TDPreprocessor {
273    pub resize: DetResizeForTest,
274    pub normalize: NormalizeImage,
275    // Store default configuration values for merging with runtime config
276    pub default_config: TextDetConfig,
277}
278#[derive(Debug)]
279pub struct TextDetPreprocessOutput {
280    pub tensor: Tensor4D,
281    pub shapes: Vec<[f32; 4]>,
282}
283impl GPreprocessor for TDPreprocessor {
284    type Config = TextDetConfig;
285    type Output = TextDetPreprocessOutput;
286    fn preprocess(
287        &self,
288        images: Vec<RgbImage>,
289        config: Option<&Self::Config>,
290    ) -> crate::core::OcrResult<Self::Output> {
291        // Merge runtime config with stored defaults
292        let merged = match config {
293            Some(runtime_config) => TextDetConfig {
294                limit_side_len: runtime_config
295                    .limit_side_len
296                    .or(self.default_config.limit_side_len),
297                limit_type: runtime_config
298                    .limit_type
299                    .clone()
300                    .or(self.default_config.limit_type.clone()),
301                thresh: runtime_config.thresh.or(self.default_config.thresh),
302                box_thresh: runtime_config.box_thresh.or(self.default_config.box_thresh),
303                unclip_ratio: runtime_config
304                    .unclip_ratio
305                    .or(self.default_config.unclip_ratio),
306                max_side_limit: runtime_config
307                    .max_side_limit
308                    .or(self.default_config.max_side_limit),
309            },
310            None => self.default_config.clone(),
311        };
312
313        let limit_side_len = merged
314            .limit_side_len
315            .unwrap_or(self.resize.limit_side_len.unwrap_or(960));
316        let limit_type = merged
317            .limit_type
318            .unwrap_or(self.resize.limit_type.clone().unwrap_or(LimitType::Min));
319        let max_side_limit = merged.max_side_limit.unwrap_or(self.resize.max_side_limit);
320        let batch_imgs: Vec<DynamicImage> =
321            images.into_iter().map(DynamicImage::ImageRgb8).collect();
322        let (resized_imgs, shapes) = self.resize.apply(
323            batch_imgs,
324            Some(limit_side_len),
325            Some(limit_type.clone()),
326            Some(max_side_limit),
327        );
328        let tensor = self
329            .normalize
330            .normalize_batch_to(resized_imgs)
331            .map_err(|e| {
332                OCRError::model_inference_error(
333                    "TextDetection",
334                    "preprocessing_normalization",
335                    0,
336                    &[shapes.len()],
337                    "Normalization failed in TDPreprocessor",
338                    e,
339                )
340            })?;
341        Ok(TextDetPreprocessOutput { tensor, shapes })
342    }
343}
344
345#[derive(Debug)]
346pub struct TDOrtInfer(pub OrtInfer);
347impl GInferenceEngine for TDOrtInfer {
348    type Input = TextDetPreprocessOutput;
349    type Output = Tensor4D;
350    fn infer(&self, input: &Self::Input) -> Result<Self::Output, OCRError> {
351        // Performance improvement: Pass reference instead of cloning the tensor
352        self.0.infer_4d(&input.tensor)
353    }
354    fn engine_info(&self) -> String {
355        "ONNXRuntime-4D".to_string()
356    }
357}
358
359#[derive(Debug)]
360pub struct TDPostprocessor {
361    pub op: DBPostProcess,
362    // Store default configuration values for merging with runtime config
363    pub default_config: TextDetConfig,
364}
365impl GPostprocessor for TDPostprocessor {
366    type Config = TextDetConfig;
367    type InferenceOutput = Tensor4D;
368    type PreprocessOutput = TextDetPreprocessOutput;
369    type Result = TextDetResult;
370    fn postprocess(
371        &self,
372        output: Self::InferenceOutput,
373        pre: Option<&Self::PreprocessOutput>,
374        batch_data: &BatchData,
375        raw_images: Vec<RgbImage>,
376        config: Option<&Self::Config>,
377    ) -> crate::core::OcrResult<Self::Result> {
378        // Merge runtime config with stored defaults
379        let merged = match config {
380            Some(runtime_config) => TextDetConfig {
381                limit_side_len: runtime_config
382                    .limit_side_len
383                    .or(self.default_config.limit_side_len),
384                limit_type: runtime_config
385                    .limit_type
386                    .clone()
387                    .or(self.default_config.limit_type.clone()),
388                thresh: runtime_config.thresh.or(self.default_config.thresh),
389                box_thresh: runtime_config.box_thresh.or(self.default_config.box_thresh),
390                unclip_ratio: runtime_config
391                    .unclip_ratio
392                    .or(self.default_config.unclip_ratio),
393                max_side_limit: runtime_config
394                    .max_side_limit
395                    .or(self.default_config.max_side_limit),
396            },
397            None => self.default_config.clone(),
398        };
399
400        let thresh = merged.thresh.unwrap_or(DEFAULT_THRESH);
401        let box_thresh = merged.box_thresh.unwrap_or(DEFAULT_BOX_THRESH);
402        let unclip_ratio = merged.unclip_ratio.unwrap_or(DEFAULT_UNCLIP_RATIO);
403        let shapes = pre.map(|p| p.shapes.clone()).unwrap_or_default();
404        let (polys, scores) = self.op.apply(
405            &output,
406            shapes,
407            Some(thresh),
408            Some(box_thresh),
409            Some(unclip_ratio),
410        );
411        Ok(TextDetResult {
412            input_path: batch_data.input_paths.clone(),
413            index: batch_data.indexes.clone(),
414            input_img: raw_images.into_iter().map(Arc::new).collect(),
415            dt_polys: polys,
416            dt_scores: scores,
417        })
418    }
419    fn empty_result(&self) -> crate::core::OcrResult<Self::Result> {
420        Ok(TextDetResult::new())
421    }
422}
423
424/// Builder for `TextDetPredictor`
425///
426/// This struct is used to build a `TextDetPredictor` with the desired configuration.
427pub struct TextDetPredictorBuilder {
428    /// Common configuration parameters
429    common: CommonBuilderConfig,
430
431    /// Limit for the side length of the image
432    limit_side_len: Option<u32>,
433    /// Type of limit to apply (Max or Min)
434    limit_type: Option<LimitType>,
435    /// Threshold for binarization
436    thresh: Option<f32>,
437    /// Threshold for filtering text boxes
438    box_thresh: Option<f32>,
439    /// Ratio for unclipping text boxes
440    unclip_ratio: Option<f32>,
441    /// Input shape for the model (channels, height, width)
442    input_shape: Option<(u32, u32, u32)>,
443    /// Maximum side limit for the image
444    max_side_limit: Option<u32>,
445}
446
447impl_common_builder_methods!(TextDetPredictorBuilder, common);
448
449impl TextDetPredictorBuilder {
450    /// Creates a new `TextDetPredictorBuilder`
451    ///
452    /// This function initializes a new builder with default values.
453    pub fn new() -> Self {
454        Self {
455            common: CommonBuilderConfig::new(),
456            limit_side_len: None,
457            limit_type: None,
458            thresh: None,
459            box_thresh: None,
460            unclip_ratio: None,
461            input_shape: None,
462            max_side_limit: None,
463        }
464    }
465
466    /// Sets the limit for the side length of the image
467    ///
468    /// This function sets the limit for the side length of the image used in text detection.
469    pub fn limit_side_len(mut self, limit_side_len: u32) -> Self {
470        self.limit_side_len = Some(limit_side_len);
471        self
472    }
473
474    /// Sets the type of limit to apply
475    ///
476    /// This function sets the type of limit (Max or Min) to apply to the image side length
477    /// in text detection.
478    pub fn limit_type(mut self, limit_type: LimitType) -> Self {
479        self.limit_type = Some(limit_type);
480        self
481    }
482
483    /// Sets the threshold for binarization
484    ///
485    /// This function sets the threshold value used for binarization in text detection.
486    pub fn thresh(mut self, thresh: f32) -> Self {
487        self.thresh = Some(thresh);
488        self
489    }
490
491    /// Sets the threshold for filtering text boxes
492    ///
493    /// This function sets the threshold value used for filtering text boxes in text detection.
494    pub fn box_thresh(mut self, box_thresh: f32) -> Self {
495        self.box_thresh = Some(box_thresh);
496        self
497    }
498
499    /// Sets the ratio for unclipping text boxes
500    ///
501    /// This function sets the ratio used for unclipping text boxes in text detection.
502    pub fn unclip_ratio(mut self, unclip_ratio: f32) -> Self {
503        self.unclip_ratio = Some(unclip_ratio);
504        self
505    }
506
507    /// Sets the input shape for the model
508    ///
509    /// This function sets the input shape (channels, height, width) for the model.
510    pub fn input_shape(mut self, input_shape: (u32, u32, u32)) -> Self {
511        self.input_shape = Some(input_shape);
512        self
513    }
514
515    /// Sets the maximum side limit for the image
516    ///
517    /// This function sets the maximum side limit for the image used in text detection.
518    pub fn max_side_limit(mut self, max_side_limit: u32) -> Self {
519        self.max_side_limit = Some(max_side_limit);
520        self
521    }
522
523    /// Builds the `TextDetPredictor`
524    ///
525    /// This function builds the `TextDetPredictor` with the provided configuration.
526    pub fn build(self, model_path: &Path) -> Result<TextDetPredictor, OCRError> {
527        self.build_internal(model_path)
528    }
529
530    /// Builds the `TextDetPredictor` internally
531    ///
532    /// This function builds the `TextDetPredictor` with the provided configuration.
533    /// It also validates the configuration and handles the model path.
534    fn build_internal(mut self, model_path: &Path) -> Result<TextDetPredictor, OCRError> {
535        if self.common.model_path.is_none() {
536            self.common = self.common.model_path(model_path.to_path_buf());
537        }
538
539        let config = TextDetPredictorConfig {
540            common: self.common,
541            limit_side_len: self.limit_side_len,
542            limit_type: self.limit_type,
543            thresh: self.thresh,
544            box_thresh: self.box_thresh,
545            unclip_ratio: self.unclip_ratio,
546            input_shape: self.input_shape,
547            max_side_limit: self.max_side_limit,
548        };
549        let config = config.validate_and_wrap_ocr_error()?;
550
551        // Determine default values based on model name
552        let (default_limit_side_len, default_limit_type) =
553            if let Some(model_name) = &config.common.model_name {
554                match model_name.as_str() {
555                    "PP-OCRv5_server_det"
556                    | "PP-OCRv5_mobile_det"
557                    | "PP-OCRv4_server_det"
558                    | "PP-OCRv4_mobile_det"
559                    | "PP-OCRv3_server_det"
560                    | "PP-OCRv3_mobile_det" => (960, LimitType::Max),
561                    _ => (736, LimitType::Min),
562                }
563            } else {
564                (736, LimitType::Min)
565            };
566
567        let limit_side_len = config.limit_side_len.unwrap_or(default_limit_side_len);
568        let limit_type = config.limit_type.clone().unwrap_or(default_limit_type);
569        let max_side_limit = config.max_side_limit.unwrap_or(DEFAULT_MAX_SIDE_LIMIT);
570
571        // Create default configuration for components
572        let default_config = TextDetConfig {
573            limit_side_len: Some(limit_side_len),
574            limit_type: Some(limit_type.clone()),
575            thresh: config.thresh,
576            box_thresh: config.box_thresh,
577            unclip_ratio: config.unclip_ratio,
578            max_side_limit: Some(max_side_limit),
579        };
580
581        // Build modular components
582        let image_reader = TDImageReader::new();
583        let resize = DetResizeForTest::new(
584            config.input_shape,
585            None,
586            None,
587            Some(limit_side_len),
588            Some(limit_type.clone()),
589            None,
590            Some(max_side_limit),
591        );
592        let normalize = NormalizeImage::new(None, None, None, None)?;
593        let preprocessor = TDPreprocessor {
594            resize,
595            normalize,
596            default_config: default_config.clone(),
597        };
598        let infer = OrtInfer::from_common(&config.common, model_path, None)?;
599        let inference_engine = TDOrtInfer(infer);
600        let post_op = DBPostProcess::new(None, None, None, None, None, None, None);
601        let postprocessor = TDPostprocessor {
602            op: post_op,
603            default_config,
604        };
605
606        Ok(ModularPredictor::new(
607            image_reader,
608            preprocessor,
609            inference_engine,
610            postprocessor,
611        ))
612    }
613}
614
615impl Default for TextDetPredictorBuilder {
616    fn default() -> Self {
617        Self::new()
618    }
619}
620
#[cfg(test)]
mod tests_local {
    use super::*;

    /// A config from `new()` carries the default max side limit and passes validation.
    #[test]
    fn test_text_det_config_defaults_and_validate() {
        let cfg = TextDetPredictorConfig::new();

        // `new()` is expected to preset the max side limit.
        assert_eq!(cfg.max_side_limit, Some(DEFAULT_MAX_SIDE_LIMIT));

        let outcome = cfg.validate();
        assert!(outcome.is_ok());
    }
}