oar_ocr_core/processors/
db_postprocess.rs1#[path = "db_bitmap.rs"]
10mod db_bitmap;
11#[path = "db_mask.rs"]
12mod db_mask;
13#[path = "db_score.rs"]
14mod db_score;
15
16use crate::core::Tensor4D;
17use crate::processors::geometry::BoundingBox;
18use crate::processors::types::{BoxType, ImageScaleInfo, ScoreMode};
19use ndarray::Axis;
20
/// Per-call overrides for the tunable values of [`DBPostProcess`].
///
/// When a value of this type is passed to [`DBPostProcess::apply`], its three
/// fields are used in place of the corresponding fields stored on the
/// post-processor for that call only.
#[derive(Debug, Clone)]
pub struct DBPostProcessConfig {
    /// Binarisation threshold: prediction values strictly greater than this
    /// are marked as foreground in the mask.
    pub thresh: f32,
    /// Threshold forwarded to the box/polygon extraction step.
    // NOTE(review): presumably the minimum score a candidate box must reach —
    // confirm against the extraction helpers.
    pub box_thresh: f32,
    /// Ratio forwarded to the box/polygon extraction step.
    // NOTE(review): presumably the expansion factor applied to detected
    // regions — confirm against the extraction helpers.
    pub unclip_ratio: f32,
}

impl DBPostProcessConfig {
    /// Builds a configuration from the three override values, stored verbatim.
    pub fn new(thresh: f32, box_thresh: f32, unclip_ratio: f32) -> Self {
        DBPostProcessConfig {
            thresh,
            box_thresh,
            unclip_ratio,
        }
    }
}
45
46#[derive(Debug)]
48pub struct DBPostProcess {
49 pub thresh: f32,
51 pub box_thresh: f32,
53 pub max_candidates: usize,
55 pub unclip_ratio: f32,
57 pub min_size: f32,
59 pub score_mode: ScoreMode,
61 pub box_type: BoxType,
63 pub use_dilation: bool,
65}
66
67impl DBPostProcess {
68 pub fn new(
70 thresh: Option<f32>,
71 box_thresh: Option<f32>,
72 max_candidates: Option<usize>,
73 unclip_ratio: Option<f32>,
74 use_dilation: Option<bool>,
75 score_mode: Option<ScoreMode>,
76 box_type: Option<BoxType>,
77 ) -> Self {
78 Self {
79 thresh: thresh.unwrap_or(0.3),
80 box_thresh: box_thresh.unwrap_or(0.6),
81 max_candidates: max_candidates.unwrap_or(1000),
82 unclip_ratio: unclip_ratio.unwrap_or(1.5),
83 min_size: 3.0,
84 score_mode: score_mode.unwrap_or(ScoreMode::Fast),
85 box_type: box_type.unwrap_or(BoxType::Quad),
86 use_dilation: use_dilation.unwrap_or(false),
87 }
88 }
89
90 pub fn apply(
101 &self,
102 preds: &Tensor4D,
103 img_shapes: Vec<ImageScaleInfo>,
104 config: Option<&DBPostProcessConfig>,
105 ) -> (Vec<Vec<BoundingBox>>, Vec<Vec<f32>>) {
106 let thresh = config.map(|c| c.thresh).unwrap_or(self.thresh);
108 let box_thresh = config.map(|c| c.box_thresh).unwrap_or(self.box_thresh);
109 let unclip_ratio = config.map(|c| c.unclip_ratio).unwrap_or(self.unclip_ratio);
110
111 let mut all_boxes = Vec::new();
112 let mut all_scores = Vec::new();
113
114 for (batch_idx, shape_batch) in img_shapes.iter().enumerate() {
115 let pred_slice = preds.index_axis(Axis(0), batch_idx);
116 let pred_channel = pred_slice.index_axis(Axis(0), 0);
117
118 let (boxes, scores) =
119 self.process(&pred_channel, shape_batch, thresh, box_thresh, unclip_ratio);
120 all_boxes.push(boxes);
121 all_scores.push(scores);
122 }
123
124 (all_boxes, all_scores)
125 }
126
127 fn process(
128 &self,
129 pred: &ndarray::ArrayView2<f32>,
130 img_shape: &ImageScaleInfo,
131 thresh: f32,
132 box_thresh: f32,
133 unclip_ratio: f32,
134 ) -> (Vec<BoundingBox>, Vec<f32>) {
135 let src_h = img_shape.src_h as u32;
136 let src_w = img_shape.src_w as u32;
137
138 let height = pred.shape()[0] as u32;
139 let width = pred.shape()[1] as u32;
140
141 tracing::debug!(
142 "DBPostProcess: pred {}x{}, src {}x{} (dest dimensions)",
143 height,
144 width,
145 src_h,
146 src_w
147 );
148
149 let mut mask_img = image::GrayImage::new(width, height);
151 for y in 0..height as usize {
152 for x in 0..width as usize {
153 let pixel_value = if pred[[y, x]] > thresh { 255 } else { 0 };
154 mask_img.put_pixel(x as u32, y as u32, image::Luma([pixel_value]));
155 }
156 }
157
158 let mask_img = if self.use_dilation {
160 self.dilate_mask_img(&mask_img)
161 } else {
162 mask_img
163 };
164
165 match self.box_type {
166 BoxType::Poly => {
167 self.polygons_from_bitmap(pred, &mask_img, src_w, src_h, box_thresh, unclip_ratio)
168 }
169 BoxType::Quad => {
170 self.boxes_from_bitmap(pred, &mask_img, src_w, src_h, box_thresh, unclip_ratio)
171 }
172 }
173 }
174}