// oar_ocr_core/models/detection/db.rs
//! DB (Differentiable Binarization) Model
//!
//! This module provides a pure implementation of the DB text detection model.
//! The model handles preprocessing, inference, and postprocessing independently of tasks.

6use crate::core::inference::{OrtInfer, TensorInput};
7use crate::core::{OCRError, validate_positive, validate_range};
8use crate::processors::{
9    BoundingBox, BoxType, DBPostProcess, DBPostProcessConfig, DetResizeForTest, ImageScaleInfo,
10    LimitType, NormalizeImage, ScoreMode, TensorLayout,
11};
12use image::{DynamicImage, RgbImage};
13use std::path::Path;
14use tracing::debug;
15
/// Configuration for DB model preprocessing.
///
/// All fields are optional; a `None` value defers to the defaults of the
/// resizer (`DetResizeForTest`) constructed at build time.
#[derive(Debug, Clone, Default)]
pub struct DBPreprocessConfig {
    /// Limit for the side length of the image (interpreted per `limit_type`).
    pub limit_side_len: Option<u32>,
    /// Type of limit to apply when resizing (see [`LimitType`]).
    pub limit_type: Option<LimitType>,
    /// Maximum side limit for the image.
    pub max_side_limit: Option<u32>,
    /// Resize the long dimension to this value (alternative to `limit_side_len`).
    pub resize_long: Option<u32>,
}
28
/// Configuration for DB model postprocessing.
///
/// Value constraints are enforced by [`DBPostprocessConfig::validate`]:
/// both thresholds must lie in `[0, 1]`, and `unclip_ratio` /
/// `max_candidates` must be positive.
#[derive(Debug, Clone)]
pub struct DBPostprocessConfig {
    /// Pixel-level threshold for text detection (expected range `[0, 1]`).
    pub score_threshold: f32,
    /// Box-level threshold for filtering detections (expected range `[0, 1]`).
    pub box_threshold: f32,
    /// Expansion ratio for detected regions using Vatti clipping (must be positive).
    pub unclip_ratio: f32,
    /// Maximum number of candidate detections (must be positive).
    pub max_candidates: usize,
    /// Whether to use dilation
    pub use_dilation: bool,
    /// Score calculation mode (see [`ScoreMode`]).
    pub score_mode: ScoreMode,
    /// Type of bounding box (Quad or Poly)
    pub box_type: BoxType,
}
47
48impl Default for DBPostprocessConfig {
49    fn default() -> Self {
50        Self {
51            score_threshold: 0.3,
52            box_threshold: 0.7,
53            unclip_ratio: 1.5,
54            max_candidates: 1000,
55            use_dilation: false,
56            score_mode: ScoreMode::Fast,
57            box_type: BoxType::Quad,
58        }
59    }
60}
61
62impl DBPostprocessConfig {
63    /// Validates the configuration parameters.
64    pub fn validate(&self) -> Result<(), OCRError> {
65        // Validate score_threshold is in [0, 1]
66        validate_range(self.score_threshold, 0.0, 1.0, "score_threshold")?;
67
68        // Validate box_threshold is in [0, 1]
69        validate_range(self.box_threshold, 0.0, 1.0, "box_threshold")?;
70
71        // Validate unclip_ratio is positive
72        validate_positive(self.unclip_ratio, "unclip_ratio")?;
73
74        // Validate max_candidates is positive
75        validate_positive(self.max_candidates, "max_candidates")?;
76
77        Ok(())
78    }
79}
80
/// DB model output containing bounding boxes and confidence scores.
///
/// Both fields are indexed by image within the batch; the inner vectors are
/// presumably parallel (`scores[i][j]` scoring `boxes[i][j]`) — they come from
/// the same `DBPostProcess::apply` call.
#[derive(Debug, Clone)]
pub struct DBModelOutput {
    /// Detected bounding boxes for each image in the batch
    pub boxes: Vec<Vec<BoundingBox>>,
    /// Confidence scores for each bounding box
    pub scores: Vec<Vec<f32>>,
}
89
/// Pure DB model implementation.
///
/// This model implements the core DB architecture and can be configured
/// for different detection tasks through preprocessing and postprocessing configs.
/// The four fields correspond to the pipeline stages run by
/// [`DBModel::forward`]: resize -> normalize -> ONNX inference -> box extraction.
#[derive(Debug)]
pub struct DBModel {
    /// ONNX Runtime inference engine
    inference: OrtInfer,
    /// Image resizer for preprocessing
    resizer: DetResizeForTest,
    /// Image normalizer for preprocessing
    normalizer: NormalizeImage,
    /// Postprocessor for converting predictions to bounding boxes
    postprocessor: DBPostProcess,
}
105
106impl DBModel {
107    /// Creates a new DB model.
108    pub fn new(
109        inference: OrtInfer,
110        resizer: DetResizeForTest,
111        normalizer: NormalizeImage,
112        postprocessor: DBPostProcess,
113    ) -> Self {
114        Self {
115            inference,
116            resizer,
117            normalizer,
118            postprocessor,
119        }
120    }
121
122    /// Preprocesses images for detection.
123    pub fn preprocess(
124        &self,
125        images: Vec<RgbImage>,
126    ) -> Result<(ndarray::Array4<f32>, Vec<ImageScaleInfo>), OCRError> {
127        // Convert to DynamicImage
128        let dynamic_images: Vec<DynamicImage> =
129            images.into_iter().map(DynamicImage::ImageRgb8).collect();
130
131        // Apply detection resizing
132        let (resized_images, img_shapes) = self.resizer.apply(
133            dynamic_images,
134            None, // Use default limit_side_len
135            None, // Use default limit_type
136            None, // Use default max_side_limit
137        );
138
139        debug!("After resize: {} images", resized_images.len());
140        for (i, (img, shape)) in resized_images.iter().zip(&img_shapes).enumerate() {
141            debug!(
142                "  Image {}: {}x{}, shape=[src_h={:.0}, src_w={:.0}, ratio_h={:.3}, ratio_w={:.3}]",
143                i,
144                img.width(),
145                img.height(),
146                shape.src_h,
147                shape.src_w,
148                shape.ratio_h,
149                shape.ratio_w
150            );
151        }
152
153        // Apply ImageNet normalization and convert to tensor.
154        //
155        // Note: External models often decode images as BGR and then normalize with
156        // mean/std as provided in their configs. In this repo, input images are
157        // loaded as RGB; we keep them in RGB here and rely on `NormalizeImage`
158        // with `ColorOrder::BGR` to map channels (RGB -> BGR) without a manual swap.
159        let batch_tensor = self.normalizer.normalize_batch_to(resized_images)?;
160        debug!("Batch tensor shape: {:?}", batch_tensor.shape());
161
162        Ok((batch_tensor, img_shapes))
163    }
164
165    /// Runs inference on the preprocessed batch.
166    pub fn infer(
167        &self,
168        batch_tensor: &ndarray::Array4<f32>,
169    ) -> Result<ndarray::Array4<f32>, OCRError> {
170        let input_name = self.inference.input_name();
171        let inputs = vec![(input_name, TensorInput::Array4(batch_tensor))];
172
173        let outputs = self
174            .inference
175            .infer(&inputs)
176            .map_err(|e| OCRError::Inference {
177                model_name: "DB".to_string(),
178                context: format!(
179                    "failed to run inference on batch with shape {:?}",
180                    batch_tensor.shape()
181                ),
182                source: Box::new(e),
183            })?;
184
185        let output = outputs
186            .into_iter()
187            .next()
188            .ok_or_else(|| OCRError::InvalidInput {
189                message: "DB: no output returned from inference".to_string(),
190            })?;
191
192        output
193            .1
194            .try_into_array4_f32()
195            .map_err(|e| OCRError::Inference {
196                model_name: "DB".to_string(),
197                context: "failed to convert output to 4D array".to_string(),
198                source: Box::new(e),
199            })
200    }
201
202    /// Postprocesses model predictions to bounding boxes.
203    pub fn postprocess(
204        &self,
205        predictions: &ndarray::Array4<f32>,
206        img_shapes: Vec<ImageScaleInfo>,
207        score_threshold: f32,
208        box_threshold: f32,
209        unclip_ratio: f32,
210    ) -> DBModelOutput {
211        let config = DBPostProcessConfig::new(score_threshold, box_threshold, unclip_ratio);
212        let (boxes, scores) = self
213            .postprocessor
214            .apply(predictions, img_shapes, Some(&config));
215        DBModelOutput { boxes, scores }
216    }
217
218    /// Runs the complete forward pass: preprocess -> infer -> postprocess.
219    pub fn forward(
220        &self,
221        images: Vec<RgbImage>,
222        score_threshold: f32,
223        box_threshold: f32,
224        unclip_ratio: f32,
225    ) -> Result<DBModelOutput, OCRError> {
226        let (batch_tensor, img_shapes) = self.preprocess(images)?;
227        let predictions = self.infer(&batch_tensor)?;
228        Ok(self.postprocess(
229            &predictions,
230            img_shapes,
231            score_threshold,
232            box_threshold,
233            unclip_ratio,
234        ))
235    }
236}
237
/// Builder for DB model.
///
/// Collects preprocessing, postprocessing, and ONNX Runtime session options,
/// then assembles a [`DBModel`] via [`DBModelBuilder::build`].
pub struct DBModelBuilder {
    /// Preprocessing configuration
    preprocess_config: DBPreprocessConfig,
    /// Postprocessing configuration
    postprocess_config: DBPostprocessConfig,
    /// ONNX Runtime session configuration; `None` uses default session settings
    ort_config: Option<crate::core::config::OrtSessionConfig>,
}
247
248impl DBModelBuilder {
249    /// Creates a new DB model builder with default settings.
250    pub fn new() -> Self {
251        Self {
252            preprocess_config: DBPreprocessConfig::default(),
253            postprocess_config: DBPostprocessConfig::default(),
254            ort_config: None,
255        }
256    }
257
258    /// Sets the preprocessing configuration.
259    pub fn preprocess_config(mut self, config: DBPreprocessConfig) -> Self {
260        self.preprocess_config = config;
261        self
262    }
263
264    /// Sets the postprocessing configuration.
265    pub fn postprocess_config(mut self, config: DBPostprocessConfig) -> Self {
266        self.postprocess_config = config;
267        self
268    }
269
270    /// Sets the ONNX Runtime session configuration.
271    pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
272        self.ort_config = Some(config);
273        self
274    }
275
276    /// Builds the DB model.
277    pub fn build(self, model_path: &Path) -> Result<DBModel, OCRError> {
278        // Create ONNX inference engine
279        let inference = if self.ort_config.is_some() {
280            use crate::core::config::ModelInferenceConfig;
281            let common_config = ModelInferenceConfig {
282                ort_session: self.ort_config,
283                ..Default::default()
284            };
285            OrtInfer::from_config(&common_config, model_path, Some("x"))?
286        } else {
287            OrtInfer::new(model_path, Some("x"))?
288        };
289
290        // Create resizer
291        let resizer = DetResizeForTest::new(
292            None,                                  // input_shape
293            None,                                  // image_shape
294            None,                                  // keep_ratio
295            self.preprocess_config.limit_side_len, // limit_side_len
296            self.preprocess_config.limit_type,     // limit_type
297            self.preprocess_config.resize_long,    // resize_long
298            self.preprocess_config.max_side_limit, // max_side_limit
299        );
300
301        // Create normalizer.
302        // External models read images in BGR. Their configs use ImageNet stats
303        // in that *same* channel order (B, G, R). Our images are loaded as RGB,
304        // so we keep them in RGB and use `ColorOrder::BGR` to map channels
305        // into BGR order during normalization.
306        let normalizer = NormalizeImage::with_color_order(
307            Some(1.0 / 255.0),               // scale
308            Some(vec![0.485, 0.456, 0.406]), // mean
309            Some(vec![0.229, 0.224, 0.225]), // std
310            Some(TensorLayout::CHW),         // order
311            Some(crate::processors::types::ColorOrder::BGR),
312        )?;
313
314        // Create postprocessor
315        let postprocessor = DBPostProcess::new(
316            Some(self.postprocess_config.score_threshold),
317            Some(self.postprocess_config.box_threshold),
318            Some(self.postprocess_config.max_candidates),
319            Some(self.postprocess_config.unclip_ratio),
320            Some(self.postprocess_config.use_dilation),
321            Some(self.postprocess_config.score_mode),
322            Some(self.postprocess_config.box_type),
323        );
324
325        Ok(DBModel::new(inference, resizer, normalizer, postprocessor))
326    }
327}
328
329impl Default for DBModelBuilder {
330    fn default() -> Self {
331        Self::new()
332    }
333}