Skip to main content

oar_ocr_core/processors/
db_postprocess.rs

1//! Post-processing for DB (Differentiable Binarization) text detection models.
2//!
3//! The [`DBPostProcess`] struct converts raw detection heatmaps into geometric
4//! bounding boxes by thresholding, contour extraction, scoring, and optional
5//! polygonal post-processing. Supporting functionality (bitmap extraction,
6//! scoring, mask morphology) is split across helper modules within this
7//! directory.
8
9#[path = "db_bitmap.rs"]
10mod db_bitmap;
11#[path = "db_mask.rs"]
12mod db_mask;
13#[path = "db_score.rs"]
14mod db_score;
15
16use crate::processors::geometry::BoundingBox;
17use crate::processors::types::{BoxType, ImageScaleInfo, ScoreMode};
18use ndarray::Axis;
19
/// Runtime configuration for DB post-processing.
///
/// Bundles the parameters that may vary per inference call: the binarization
/// threshold, the box score threshold, and the unclip (expansion) ratio.
///
/// The type is a plain group of three `f32` values, so it is `Copy` (no
/// `.clone()` needed to reuse a config across calls) and supports `==` for
/// comparisons in tests and caches.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DBPostProcessConfig {
    /// Threshold for binarizing the prediction map.
    pub thresh: f32,
    /// Threshold for filtering bounding boxes based on their score.
    pub box_thresh: f32,
    /// Ratio for unclipping (expanding) bounding boxes.
    pub unclip_ratio: f32,
}

impl DBPostProcessConfig {
    /// Creates a new runtime config from explicit values.
    ///
    /// The values are stored exactly as given; no range validation is
    /// performed. Being a `const fn`, this can also initialise `const` and
    /// `static` items.
    ///
    /// # Arguments
    /// * `thresh` - Threshold for binarizing the prediction map.
    /// * `box_thresh` - Threshold for filtering boxes by score.
    /// * `unclip_ratio` - Ratio for unclipping (expanding) boxes.
    #[must_use]
    pub const fn new(thresh: f32, box_thresh: f32, unclip_ratio: f32) -> Self {
        Self {
            thresh,
            box_thresh,
            unclip_ratio,
        }
    }
}
44
45/// Post-processor for DB (Differentiable Binarization) text detection models.
46#[derive(Debug)]
47pub struct DBPostProcess {
48    /// Default threshold for binarizing the prediction map (default: 0.3).
49    pub thresh: f32,
50    /// Default threshold for filtering bounding boxes based on their score (default: 0.6).
51    pub box_thresh: f32,
52    /// Maximum number of candidate bounding boxes to consider (default: 1000).
53    pub max_candidates: usize,
54    /// Default ratio for unclipping (expanding) bounding boxes (default: 1.5).
55    pub unclip_ratio: f32,
56    /// Minimum side length for detected bounding boxes.
57    pub min_size: f32,
58    /// Method for calculating the score of a bounding box.
59    pub score_mode: ScoreMode,
60    /// Type of bounding box to generate (quadrilateral or polygon).
61    pub box_type: BoxType,
62    /// Whether to apply dilation to the segmentation mask before contour detection.
63    pub use_dilation: bool,
64}
65
66impl DBPostProcess {
67    /// Creates a new `DBPostProcess` instance with optional overrides.
68    pub fn new(
69        thresh: Option<f32>,
70        box_thresh: Option<f32>,
71        max_candidates: Option<usize>,
72        unclip_ratio: Option<f32>,
73        use_dilation: Option<bool>,
74        score_mode: Option<ScoreMode>,
75        box_type: Option<BoxType>,
76    ) -> Self {
77        Self {
78            thresh: thresh.unwrap_or(0.3),
79            box_thresh: box_thresh.unwrap_or(0.6),
80            max_candidates: max_candidates.unwrap_or(1000),
81            unclip_ratio: unclip_ratio.unwrap_or(1.5),
82            min_size: 3.0,
83            score_mode: score_mode.unwrap_or(ScoreMode::Fast),
84            box_type: box_type.unwrap_or(BoxType::Quad),
85            use_dilation: use_dilation.unwrap_or(false),
86        }
87    }
88
89    /// Applies post-processing to a batch of prediction maps.
90    ///
91    /// # Arguments
92    /// * `preds` - Model predictions (batch of heatmaps)
93    /// * `img_shapes` - Original image dimensions for each image in batch
94    /// * `config` - Runtime configuration for thresholds and ratios.
95    ///   If `None`, uses the default values stored in this processor.
96    ///
97    /// # Returns
98    /// Tuple of (bounding_boxes, scores) for each image in batch
99    pub fn apply(
100        &self,
101        preds: &ndarray::Array4<f32>,
102        img_shapes: Vec<ImageScaleInfo>,
103        config: Option<&DBPostProcessConfig>,
104    ) -> (Vec<Vec<BoundingBox>>, Vec<Vec<f32>>) {
105        // Use provided config or fall back to stored defaults
106        let thresh = config.map(|c| c.thresh).unwrap_or(self.thresh);
107        let box_thresh = config.map(|c| c.box_thresh).unwrap_or(self.box_thresh);
108        let unclip_ratio = config.map(|c| c.unclip_ratio).unwrap_or(self.unclip_ratio);
109
110        let mut all_boxes = Vec::new();
111        let mut all_scores = Vec::new();
112
113        for (batch_idx, shape_batch) in img_shapes.iter().enumerate() {
114            let pred_slice = preds.index_axis(Axis(0), batch_idx);
115            let pred_channel = pred_slice.index_axis(Axis(0), 0);
116
117            let (boxes, scores) =
118                self.process(&pred_channel, shape_batch, thresh, box_thresh, unclip_ratio);
119            all_boxes.push(boxes);
120            all_scores.push(scores);
121        }
122
123        (all_boxes, all_scores)
124    }
125
126    fn process(
127        &self,
128        pred: &ndarray::ArrayView2<f32>,
129        img_shape: &ImageScaleInfo,
130        thresh: f32,
131        box_thresh: f32,
132        unclip_ratio: f32,
133    ) -> (Vec<BoundingBox>, Vec<f32>) {
134        let src_h = img_shape.src_h as u32;
135        let src_w = img_shape.src_w as u32;
136
137        let height = pred.shape()[0] as u32;
138        let width = pred.shape()[1] as u32;
139
140        tracing::debug!(
141            "DBPostProcess: pred {}x{}, src {}x{} (dest dimensions)",
142            height,
143            width,
144            src_h,
145            src_w
146        );
147
148        // Create binary mask directly as GrayImage to avoid intermediate Vec<Vec<bool>>
149        let mut mask_img = image::GrayImage::new(width, height);
150        for y in 0..height as usize {
151            for x in 0..width as usize {
152                let pixel_value = if pred[[y, x]] > thresh { 255 } else { 0 };
153                mask_img.put_pixel(x as u32, y as u32, image::Luma([pixel_value]));
154            }
155        }
156
157        // Apply dilation if needed
158        let mask_img = if self.use_dilation {
159            self.dilate_mask_img(&mask_img)
160        } else {
161            mask_img
162        };
163
164        match self.box_type {
165            BoxType::Poly => {
166                self.polygons_from_bitmap(pred, &mask_img, src_w, src_h, box_thresh, unclip_ratio)
167            }
168            BoxType::Quad => {
169                self.boxes_from_bitmap(pred, &mask_img, src_w, src_h, box_thresh, unclip_ratio)
170            }
171        }
172    }
173}