oar_ocr_core/processors/
db_postprocess.rs1#[path = "db_bitmap.rs"]
10mod db_bitmap;
11#[path = "db_mask.rs"]
12mod db_mask;
13#[path = "db_score.rs"]
14mod db_score;
15
16use crate::processors::geometry::BoundingBox;
17use crate::processors::types::{BoxType, ImageScaleInfo, ScoreMode};
18use ndarray::Axis;
19
/// Per-call threshold overrides for [`DBPostProcess::apply`].
///
/// When passed to `apply`, these values replace the post-processor's own
/// `thresh`, `box_thresh` and `unclip_ratio` for that call only.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DBPostProcessConfig {
    /// Binarization threshold: prediction-map pixels strictly greater than this
    /// become foreground in the mask.
    pub thresh: f32,
    /// Threshold forwarded to box/polygon extraction (presumably the minimum
    /// region score for a candidate to be kept — confirm in `db_bitmap`).
    pub box_thresh: f32,
    /// Ratio forwarded to box/polygon extraction (presumably how far detected
    /// regions are expanded — confirm in `db_bitmap`).
    pub unclip_ratio: f32,
}

impl DBPostProcessConfig {
    /// Bundles the three threshold overrides into a config value.
    #[must_use]
    pub fn new(thresh: f32, box_thresh: f32, unclip_ratio: f32) -> Self {
        Self {
            thresh,
            box_thresh,
            unclip_ratio,
        }
    }
}
44
45#[derive(Debug)]
47pub struct DBPostProcess {
48 pub thresh: f32,
50 pub box_thresh: f32,
52 pub max_candidates: usize,
54 pub unclip_ratio: f32,
56 pub min_size: f32,
58 pub score_mode: ScoreMode,
60 pub box_type: BoxType,
62 pub use_dilation: bool,
64}
65
66impl DBPostProcess {
67 pub fn new(
69 thresh: Option<f32>,
70 box_thresh: Option<f32>,
71 max_candidates: Option<usize>,
72 unclip_ratio: Option<f32>,
73 use_dilation: Option<bool>,
74 score_mode: Option<ScoreMode>,
75 box_type: Option<BoxType>,
76 ) -> Self {
77 Self {
78 thresh: thresh.unwrap_or(0.3),
79 box_thresh: box_thresh.unwrap_or(0.6),
80 max_candidates: max_candidates.unwrap_or(1000),
81 unclip_ratio: unclip_ratio.unwrap_or(1.5),
82 min_size: 3.0,
83 score_mode: score_mode.unwrap_or(ScoreMode::Fast),
84 box_type: box_type.unwrap_or(BoxType::Quad),
85 use_dilation: use_dilation.unwrap_or(false),
86 }
87 }
88
89 pub fn apply(
100 &self,
101 preds: &ndarray::Array4<f32>,
102 img_shapes: Vec<ImageScaleInfo>,
103 config: Option<&DBPostProcessConfig>,
104 ) -> (Vec<Vec<BoundingBox>>, Vec<Vec<f32>>) {
105 let thresh = config.map(|c| c.thresh).unwrap_or(self.thresh);
107 let box_thresh = config.map(|c| c.box_thresh).unwrap_or(self.box_thresh);
108 let unclip_ratio = config.map(|c| c.unclip_ratio).unwrap_or(self.unclip_ratio);
109
110 let mut all_boxes = Vec::new();
111 let mut all_scores = Vec::new();
112
113 for (batch_idx, shape_batch) in img_shapes.iter().enumerate() {
114 let pred_slice = preds.index_axis(Axis(0), batch_idx);
115 let pred_channel = pred_slice.index_axis(Axis(0), 0);
116
117 let (boxes, scores) =
118 self.process(&pred_channel, shape_batch, thresh, box_thresh, unclip_ratio);
119 all_boxes.push(boxes);
120 all_scores.push(scores);
121 }
122
123 (all_boxes, all_scores)
124 }
125
126 fn process(
127 &self,
128 pred: &ndarray::ArrayView2<f32>,
129 img_shape: &ImageScaleInfo,
130 thresh: f32,
131 box_thresh: f32,
132 unclip_ratio: f32,
133 ) -> (Vec<BoundingBox>, Vec<f32>) {
134 let src_h = img_shape.src_h as u32;
135 let src_w = img_shape.src_w as u32;
136
137 let height = pred.shape()[0] as u32;
138 let width = pred.shape()[1] as u32;
139
140 tracing::debug!(
141 "DBPostProcess: pred {}x{}, src {}x{} (dest dimensions)",
142 height,
143 width,
144 src_h,
145 src_w
146 );
147
148 let mut mask_img = image::GrayImage::new(width, height);
150 for y in 0..height as usize {
151 for x in 0..width as usize {
152 let pixel_value = if pred[[y, x]] > thresh { 255 } else { 0 };
153 mask_img.put_pixel(x as u32, y as u32, image::Luma([pixel_value]));
154 }
155 }
156
157 let mask_img = if self.use_dilation {
159 self.dilate_mask_img(&mask_img)
160 } else {
161 mask_img
162 };
163
164 match self.box_type {
165 BoxType::Poly => {
166 self.polygons_from_bitmap(pred, &mask_img, src_w, src_h, box_thresh, unclip_ratio)
167 }
168 BoxType::Quad => {
169 self.boxes_from_bitmap(pred, &mask_img, src_w, src_h, box_thresh, unclip_ratio)
170 }
171 }
172 }
173}