oar_ocr_core/models/detection/
db.rs1use crate::core::inference::{OrtInfer, TensorInput};
7use crate::core::{OCRError, validate_positive, validate_range};
8use crate::processors::{
9 BoundingBox, BoxType, DBPostProcess, DBPostProcessConfig, DetResizeForTest, ImageScaleInfo,
10 LimitType, NormalizeImage, ScoreMode, TensorLayout,
11};
12use image::{DynamicImage, RgbImage};
13use std::path::Path;
14use tracing::debug;
15
16#[derive(Debug, Clone, Default)]
18pub struct DBPreprocessConfig {
19 pub limit_side_len: Option<u32>,
21 pub limit_type: Option<LimitType>,
23 pub max_side_limit: Option<u32>,
25 pub resize_long: Option<u32>,
27}
28
29#[derive(Debug, Clone)]
31pub struct DBPostprocessConfig {
32 pub score_threshold: f32,
34 pub box_threshold: f32,
36 pub unclip_ratio: f32,
38 pub max_candidates: usize,
40 pub use_dilation: bool,
42 pub score_mode: ScoreMode,
44 pub box_type: BoxType,
46}
47
48impl Default for DBPostprocessConfig {
49 fn default() -> Self {
50 Self {
51 score_threshold: 0.3,
52 box_threshold: 0.7,
53 unclip_ratio: 1.5,
54 max_candidates: 1000,
55 use_dilation: false,
56 score_mode: ScoreMode::Fast,
57 box_type: BoxType::Quad,
58 }
59 }
60}
61
62impl DBPostprocessConfig {
63 pub fn validate(&self) -> Result<(), OCRError> {
65 validate_range(self.score_threshold, 0.0, 1.0, "score_threshold")?;
67
68 validate_range(self.box_threshold, 0.0, 1.0, "box_threshold")?;
70
71 validate_positive(self.unclip_ratio, "unclip_ratio")?;
73
74 validate_positive(self.max_candidates, "max_candidates")?;
76
77 Ok(())
78 }
79}
80
81#[derive(Debug, Clone)]
83pub struct DBModelOutput {
84 pub boxes: Vec<Vec<BoundingBox>>,
86 pub scores: Vec<Vec<f32>>,
88}
89
90#[derive(Debug)]
95pub struct DBModel {
96 inference: OrtInfer,
98 resizer: DetResizeForTest,
100 normalizer: NormalizeImage,
102 postprocessor: DBPostProcess,
104}
105
106impl DBModel {
107 pub fn new(
109 inference: OrtInfer,
110 resizer: DetResizeForTest,
111 normalizer: NormalizeImage,
112 postprocessor: DBPostProcess,
113 ) -> Self {
114 Self {
115 inference,
116 resizer,
117 normalizer,
118 postprocessor,
119 }
120 }
121
122 pub fn preprocess(
124 &self,
125 images: Vec<RgbImage>,
126 ) -> Result<(ndarray::Array4<f32>, Vec<ImageScaleInfo>), OCRError> {
127 let dynamic_images: Vec<DynamicImage> =
129 images.into_iter().map(DynamicImage::ImageRgb8).collect();
130
131 let (resized_images, img_shapes) = self.resizer.apply(
133 dynamic_images,
134 None, None, None, );
138
139 debug!("After resize: {} images", resized_images.len());
140 for (i, (img, shape)) in resized_images.iter().zip(&img_shapes).enumerate() {
141 debug!(
142 " Image {}: {}x{}, shape=[src_h={:.0}, src_w={:.0}, ratio_h={:.3}, ratio_w={:.3}]",
143 i,
144 img.width(),
145 img.height(),
146 shape.src_h,
147 shape.src_w,
148 shape.ratio_h,
149 shape.ratio_w
150 );
151 }
152
153 let batch_tensor = self.normalizer.normalize_batch_to(resized_images)?;
160 debug!("Batch tensor shape: {:?}", batch_tensor.shape());
161
162 Ok((batch_tensor, img_shapes))
163 }
164
165 pub fn infer(
167 &self,
168 batch_tensor: &ndarray::Array4<f32>,
169 ) -> Result<ndarray::Array4<f32>, OCRError> {
170 let input_name = self.inference.input_name();
171 let inputs = vec![(input_name, TensorInput::Array4(batch_tensor))];
172
173 let outputs = self
174 .inference
175 .infer(&inputs)
176 .map_err(|e| OCRError::Inference {
177 model_name: "DB".to_string(),
178 context: format!(
179 "failed to run inference on batch with shape {:?}",
180 batch_tensor.shape()
181 ),
182 source: Box::new(e),
183 })?;
184
185 let output = outputs
186 .into_iter()
187 .next()
188 .ok_or_else(|| OCRError::InvalidInput {
189 message: "DB: no output returned from inference".to_string(),
190 })?;
191
192 output
193 .1
194 .try_into_array4_f32()
195 .map_err(|e| OCRError::Inference {
196 model_name: "DB".to_string(),
197 context: "failed to convert output to 4D array".to_string(),
198 source: Box::new(e),
199 })
200 }
201
202 pub fn postprocess(
204 &self,
205 predictions: &ndarray::Array4<f32>,
206 img_shapes: Vec<ImageScaleInfo>,
207 score_threshold: f32,
208 box_threshold: f32,
209 unclip_ratio: f32,
210 ) -> DBModelOutput {
211 let config = DBPostProcessConfig::new(score_threshold, box_threshold, unclip_ratio);
212 let (boxes, scores) = self
213 .postprocessor
214 .apply(predictions, img_shapes, Some(&config));
215 DBModelOutput { boxes, scores }
216 }
217
218 pub fn forward(
220 &self,
221 images: Vec<RgbImage>,
222 score_threshold: f32,
223 box_threshold: f32,
224 unclip_ratio: f32,
225 ) -> Result<DBModelOutput, OCRError> {
226 let (batch_tensor, img_shapes) = self.preprocess(images)?;
227 let predictions = self.infer(&batch_tensor)?;
228 Ok(self.postprocess(
229 &predictions,
230 img_shapes,
231 score_threshold,
232 box_threshold,
233 unclip_ratio,
234 ))
235 }
236}
237
238pub struct DBModelBuilder {
240 preprocess_config: DBPreprocessConfig,
242 postprocess_config: DBPostprocessConfig,
244 ort_config: Option<crate::core::config::OrtSessionConfig>,
246}
247
248impl DBModelBuilder {
249 pub fn new() -> Self {
251 Self {
252 preprocess_config: DBPreprocessConfig::default(),
253 postprocess_config: DBPostprocessConfig::default(),
254 ort_config: None,
255 }
256 }
257
258 pub fn preprocess_config(mut self, config: DBPreprocessConfig) -> Self {
260 self.preprocess_config = config;
261 self
262 }
263
264 pub fn postprocess_config(mut self, config: DBPostprocessConfig) -> Self {
266 self.postprocess_config = config;
267 self
268 }
269
270 pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
272 self.ort_config = Some(config);
273 self
274 }
275
276 pub fn build(self, model_path: &Path) -> Result<DBModel, OCRError> {
278 let inference = if self.ort_config.is_some() {
280 use crate::core::config::ModelInferenceConfig;
281 let common_config = ModelInferenceConfig {
282 ort_session: self.ort_config,
283 ..Default::default()
284 };
285 OrtInfer::from_config(&common_config, model_path, Some("x"))?
286 } else {
287 OrtInfer::new(model_path, Some("x"))?
288 };
289
290 let resizer = DetResizeForTest::new(
292 None, None, None, self.preprocess_config.limit_side_len, self.preprocess_config.limit_type, self.preprocess_config.resize_long, self.preprocess_config.max_side_limit, );
300
301 let normalizer = NormalizeImage::with_color_order(
307 Some(1.0 / 255.0), Some(vec![0.485, 0.456, 0.406]), Some(vec![0.229, 0.224, 0.225]), Some(TensorLayout::CHW), Some(crate::processors::types::ColorOrder::BGR),
312 )?;
313
314 let postprocessor = DBPostProcess::new(
316 Some(self.postprocess_config.score_threshold),
317 Some(self.postprocess_config.box_threshold),
318 Some(self.postprocess_config.max_candidates),
319 Some(self.postprocess_config.unclip_ratio),
320 Some(self.postprocess_config.use_dilation),
321 Some(self.postprocess_config.score_mode),
322 Some(self.postprocess_config.box_type),
323 );
324
325 Ok(DBModel::new(inference, resizer, normalizer, postprocessor))
326 }
327}
328
329impl Default for DBModelBuilder {
330 fn default() -> Self {
331 Self::new()
332 }
333}