Skip to main content

oar_ocr_core/processors/
db_postprocess.rs

//! Post-processing for DB (Differentiable Binarization) text detection models.
//!
//! The [`DBPostProcess`] struct converts raw detection heatmaps into geometric
//! bounding boxes by thresholding, contour extraction, scoring, and optional
//! polygonal post-processing. Supporting functionality (bitmap extraction,
//! scoring, mask morphology) is split across helper modules within this
//! directory.
8
9#[path = "db_bitmap.rs"]
10mod db_bitmap;
11#[path = "db_mask.rs"]
12mod db_mask;
13#[path = "db_score.rs"]
14mod db_score;
15
16use crate::core::Tensor4D;
17use crate::processors::geometry::BoundingBox;
18use crate::processors::types::{BoxType, ImageScaleInfo, ScoreMode};
19use ndarray::Axis;
20
/// Per-call runtime configuration for DB post-processing.
///
/// Passing one to [`DBPostProcess::apply`] makes that call use these values
/// instead of the thresholds and ratio stored on the processor itself.
///
/// The type is plain data (three `f32`s), so it is `Copy` and can be compared
/// with `==`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DBPostProcessConfig {
    /// Binarization threshold: a prediction-map pixel strictly greater than this
    /// value is marked as foreground in the mask.
    pub thresh: f32,
    /// Score threshold used to filter candidate bounding boxes.
    pub box_thresh: f32,
    /// Ratio used when unclipping (expanding) bounding boxes.
    pub unclip_ratio: f32,
}
34
35impl DBPostProcessConfig {
36    /// Creates a new runtime config with specified values.
37    pub fn new(thresh: f32, box_thresh: f32, unclip_ratio: f32) -> Self {
38        Self {
39            thresh,
40            box_thresh,
41            unclip_ratio,
42        }
43    }
44}
45
46/// Post-processor for DB (Differentiable Binarization) text detection models.
47#[derive(Debug)]
48pub struct DBPostProcess {
49    /// Default threshold for binarizing the prediction map (default: 0.3).
50    pub thresh: f32,
51    /// Default threshold for filtering bounding boxes based on their score (default: 0.6).
52    pub box_thresh: f32,
53    /// Maximum number of candidate bounding boxes to consider (default: 1000).
54    pub max_candidates: usize,
55    /// Default ratio for unclipping (expanding) bounding boxes (default: 1.5).
56    pub unclip_ratio: f32,
57    /// Minimum side length for detected bounding boxes.
58    pub min_size: f32,
59    /// Method for calculating the score of a bounding box.
60    pub score_mode: ScoreMode,
61    /// Type of bounding box to generate (quadrilateral or polygon).
62    pub box_type: BoxType,
63    /// Whether to apply dilation to the segmentation mask before contour detection.
64    pub use_dilation: bool,
65}
66
67impl DBPostProcess {
68    /// Creates a new `DBPostProcess` instance with optional overrides.
69    pub fn new(
70        thresh: Option<f32>,
71        box_thresh: Option<f32>,
72        max_candidates: Option<usize>,
73        unclip_ratio: Option<f32>,
74        use_dilation: Option<bool>,
75        score_mode: Option<ScoreMode>,
76        box_type: Option<BoxType>,
77    ) -> Self {
78        Self {
79            thresh: thresh.unwrap_or(0.3),
80            box_thresh: box_thresh.unwrap_or(0.6),
81            max_candidates: max_candidates.unwrap_or(1000),
82            unclip_ratio: unclip_ratio.unwrap_or(1.5),
83            min_size: 3.0,
84            score_mode: score_mode.unwrap_or(ScoreMode::Fast),
85            box_type: box_type.unwrap_or(BoxType::Quad),
86            use_dilation: use_dilation.unwrap_or(false),
87        }
88    }
89
90    /// Applies post-processing to a batch of prediction maps.
91    ///
92    /// # Arguments
93    /// * `preds` - Model predictions (batch of heatmaps)
94    /// * `img_shapes` - Original image dimensions for each image in batch
95    /// * `config` - Runtime configuration for thresholds and ratios.
96    ///   If `None`, uses the default values stored in this processor.
97    ///
98    /// # Returns
99    /// Tuple of (bounding_boxes, scores) for each image in batch
100    pub fn apply(
101        &self,
102        preds: &Tensor4D,
103        img_shapes: Vec<ImageScaleInfo>,
104        config: Option<&DBPostProcessConfig>,
105    ) -> (Vec<Vec<BoundingBox>>, Vec<Vec<f32>>) {
106        // Use provided config or fall back to stored defaults
107        let thresh = config.map(|c| c.thresh).unwrap_or(self.thresh);
108        let box_thresh = config.map(|c| c.box_thresh).unwrap_or(self.box_thresh);
109        let unclip_ratio = config.map(|c| c.unclip_ratio).unwrap_or(self.unclip_ratio);
110
111        let mut all_boxes = Vec::new();
112        let mut all_scores = Vec::new();
113
114        for (batch_idx, shape_batch) in img_shapes.iter().enumerate() {
115            let pred_slice = preds.index_axis(Axis(0), batch_idx);
116            let pred_channel = pred_slice.index_axis(Axis(0), 0);
117
118            let (boxes, scores) =
119                self.process(&pred_channel, shape_batch, thresh, box_thresh, unclip_ratio);
120            all_boxes.push(boxes);
121            all_scores.push(scores);
122        }
123
124        (all_boxes, all_scores)
125    }
126
127    fn process(
128        &self,
129        pred: &ndarray::ArrayView2<f32>,
130        img_shape: &ImageScaleInfo,
131        thresh: f32,
132        box_thresh: f32,
133        unclip_ratio: f32,
134    ) -> (Vec<BoundingBox>, Vec<f32>) {
135        let src_h = img_shape.src_h as u32;
136        let src_w = img_shape.src_w as u32;
137
138        let height = pred.shape()[0] as u32;
139        let width = pred.shape()[1] as u32;
140
141        tracing::debug!(
142            "DBPostProcess: pred {}x{}, src {}x{} (dest dimensions)",
143            height,
144            width,
145            src_h,
146            src_w
147        );
148
149        // Create binary mask directly as GrayImage to avoid intermediate Vec<Vec<bool>>
150        let mut mask_img = image::GrayImage::new(width, height);
151        for y in 0..height as usize {
152            for x in 0..width as usize {
153                let pixel_value = if pred[[y, x]] > thresh { 255 } else { 0 };
154                mask_img.put_pixel(x as u32, y as u32, image::Luma([pixel_value]));
155            }
156        }
157
158        // Apply dilation if needed
159        let mask_img = if self.use_dilation {
160            self.dilate_mask_img(&mask_img)
161        } else {
162            mask_img
163        };
164
165        match self.box_type {
166            BoxType::Poly => {
167                self.polygons_from_bitmap(pred, &mask_img, src_w, src_h, box_thresh, unclip_ratio)
168            }
169            BoxType::Quad => {
170                self.boxes_from_bitmap(pred, &mask_img, src_w, src_h, box_thresh, unclip_ratio)
171            }
172        }
173    }
174}