oar_ocr/processors/postprocess/
db.rs

//! Post-processing for DB (Differentiable Binarization) text detection models.
//!
//! The [`DBPostProcess`] struct converts raw detection heatmaps into geometric
//! bounding boxes by thresholding, contour extraction, scoring, and optional
//! polygonal post-processing. Supporting functionality (bitmap extraction,
//! scoring, mask morphology) is split across the helper modules `db_bitmap`,
//! `db_mask`, and `db_score`, which live alongside this file.
8
9#[path = "db_bitmap.rs"]
10mod db_bitmap;
11#[path = "db_mask.rs"]
12mod db_mask;
13#[path = "db_score.rs"]
14mod db_score;
15
16use crate::core::Tensor4D;
17use crate::processors::geometry::BoundingBox;
18use crate::processors::types::{BoxType, ScoreMode};
19use ndarray::Axis;
20
21/// Post-processor for DB (Differentiable Binarization) text detection models.
22#[derive(Debug)]
23pub struct DBPostProcess {
24    /// Threshold for binarizing the prediction map (default: 0.3).
25    pub thresh: f32,
26    /// Threshold for filtering bounding boxes based on their score (default: 0.7).
27    pub box_thresh: f32,
28    /// Maximum number of candidate bounding boxes to consider (default: 1000).
29    pub max_candidates: usize,
30    /// Ratio for unclipping (expanding) bounding boxes (default: 2.0).
31    pub unclip_ratio: f32,
32    /// Minimum side length for detected bounding boxes.
33    pub min_size: f32,
34    /// Method for calculating the score of a bounding box.
35    pub score_mode: ScoreMode,
36    /// Type of bounding box to generate (quadrilateral or polygon).
37    pub box_type: BoxType,
38    /// Whether to apply dilation to the segmentation mask before contour detection.
39    pub use_dilation: bool,
40}
41
42impl DBPostProcess {
43    /// Creates a new `DBPostProcess` instance with optional overrides.
44    pub fn new(
45        thresh: Option<f32>,
46        box_thresh: Option<f32>,
47        max_candidates: Option<usize>,
48        unclip_ratio: Option<f32>,
49        use_dilation: Option<bool>,
50        score_mode: Option<ScoreMode>,
51        box_type: Option<BoxType>,
52    ) -> Self {
53        Self {
54            thresh: thresh.unwrap_or(0.3),
55            box_thresh: box_thresh.unwrap_or(0.7),
56            max_candidates: max_candidates.unwrap_or(1000),
57            unclip_ratio: unclip_ratio.unwrap_or(2.0),
58            min_size: 3.0,
59            score_mode: score_mode.unwrap_or(ScoreMode::Fast),
60            box_type: box_type.unwrap_or(BoxType::Quad),
61            use_dilation: use_dilation.unwrap_or(false),
62        }
63    }
64
65    /// Applies post-processing to a batch of prediction maps.
66    pub fn apply(
67        &self,
68        preds: &Tensor4D,
69        img_shapes: Vec<[f32; 4]>,
70        thresh: Option<f32>,
71        box_thresh: Option<f32>,
72        unclip_ratio: Option<f32>,
73    ) -> (Vec<Vec<BoundingBox>>, Vec<Vec<f32>>) {
74        let mut all_boxes = Vec::new();
75        let mut all_scores = Vec::new();
76
77        for (batch_idx, shape_batch) in img_shapes.iter().enumerate() {
78            let pred_slice = preds.index_axis(Axis(0), batch_idx);
79            let pred_channel = pred_slice.index_axis(Axis(0), 0);
80
81            let (boxes, scores) = self.process(
82                &pred_channel.to_owned(),
83                *shape_batch,
84                thresh.unwrap_or(self.thresh),
85                box_thresh.unwrap_or(self.box_thresh),
86                unclip_ratio.unwrap_or(self.unclip_ratio),
87            );
88            all_boxes.push(boxes);
89            all_scores.push(scores);
90        }
91
92        (all_boxes, all_scores)
93    }
94
95    fn process(
96        &self,
97        pred: &ndarray::Array2<f32>,
98        img_shape: [f32; 4],
99        thresh: f32,
100        box_thresh: f32,
101        unclip_ratio: f32,
102    ) -> (Vec<BoundingBox>, Vec<f32>) {
103        let src_h = img_shape[0] as u32;
104        let src_w = img_shape[1] as u32;
105
106        let height = pred.shape()[0] as u32;
107        let width = pred.shape()[1] as u32;
108
109        let mut segmentation = vec![vec![false; width as usize]; height as usize];
110        for y in 0..height as usize {
111            for x in 0..width as usize {
112                segmentation[y][x] = pred[[y, x]] > thresh;
113            }
114        }
115
116        let mask = if self.use_dilation {
117            self.dilate_mask(&segmentation)
118        } else {
119            segmentation
120        };
121
122        match self.box_type {
123            BoxType::Poly => {
124                self.polygons_from_bitmap(pred, &mask, src_w, src_h, box_thresh, unclip_ratio)
125            }
126            BoxType::Quad => {
127                self.boxes_from_bitmap(pred, &mask, src_w, src_h, box_thresh, unclip_ratio)
128            }
129        }
130    }
131}