1use image::{DynamicImage, GenericImageView};
6use ndarray::ArrayD;
7use std::path::Path;
8
9use crate::error::{OcrError, OcrResult};
10use crate::mnn::{InferenceConfig, InferenceEngine};
11use crate::postprocess::{extract_boxes_with_unclip, TextBox};
12use crate::preprocess::{preprocess_for_det, NormalizeParams};
13
/// Detection precision mode, stored in `DetOptions::precision_mode`.
///
/// `Fast` is currently the only variant and is also the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DetPrecisionMode {
    /// Single-pass detection: `DetModel::detect` runs its fast path, which
    /// downscales the image to `max_side_len` when needed and runs one inference.
    #[default]
    Fast,
}
21
22#[derive(Debug, Clone)]
24pub struct DetOptions {
25 pub max_side_len: u32,
27 pub box_threshold: f32,
29 pub unclip_ratio: f32,
31 pub score_threshold: f32,
33 pub min_area: u32,
35 pub box_border: u32,
37 pub merge_boxes: bool,
39 pub merge_threshold: i32,
41 pub precision_mode: DetPrecisionMode,
43 pub multi_scales: Vec<f32>,
45 pub block_size: u32,
47 pub block_overlap: u32,
49 pub nms_threshold: f32,
51}
52
53impl Default for DetOptions {
54 fn default() -> Self {
55 Self {
56 max_side_len: 960,
57 box_threshold: 0.5,
58 unclip_ratio: 1.5,
59 score_threshold: 0.3,
60 min_area: 16,
61 box_border: 5,
62 merge_boxes: false,
63 merge_threshold: 10,
64 precision_mode: DetPrecisionMode::Fast,
65 multi_scales: vec![0.5, 1.0, 1.5],
66 block_size: 640,
67 block_overlap: 100,
68 nms_threshold: 0.3,
69 }
70 }
71}
72
73impl DetOptions {
74 pub fn new() -> Self {
76 Self::default()
77 }
78
79 pub fn with_max_side_len(mut self, len: u32) -> Self {
81 self.max_side_len = len;
82 self
83 }
84
85 pub fn with_box_threshold(mut self, threshold: f32) -> Self {
87 self.box_threshold = threshold;
88 self
89 }
90
91 pub fn with_score_threshold(mut self, threshold: f32) -> Self {
93 self.score_threshold = threshold;
94 self
95 }
96
97 pub fn with_min_area(mut self, area: u32) -> Self {
99 self.min_area = area;
100 self
101 }
102
103 pub fn with_box_border(mut self, border: u32) -> Self {
105 self.box_border = border;
106 self
107 }
108
109 pub fn with_merge_boxes(mut self, merge: bool) -> Self {
111 self.merge_boxes = merge;
112 self
113 }
114
115 pub fn with_merge_threshold(mut self, threshold: i32) -> Self {
117 self.merge_threshold = threshold;
118 self
119 }
120
121 pub fn with_precision_mode(mut self, mode: DetPrecisionMode) -> Self {
123 self.precision_mode = mode;
124 self
125 }
126
127 pub fn with_multi_scales(mut self, scales: Vec<f32>) -> Self {
129 self.multi_scales = scales;
130 self
131 }
132
133 pub fn with_block_size(mut self, size: u32) -> Self {
135 self.block_size = size;
136 self
137 }
138
139 pub fn fast() -> Self {
141 Self {
142 max_side_len: 960,
143 precision_mode: DetPrecisionMode::Fast,
144 ..Default::default()
145 }
146 }
147}
148
149pub struct DetModel {
151 engine: InferenceEngine,
152 options: DetOptions,
153 normalize_params: NormalizeParams,
154}
155
156impl DetModel {
157 pub fn from_file(
163 model_path: impl AsRef<Path>,
164 config: Option<InferenceConfig>,
165 ) -> OcrResult<Self> {
166 let engine = InferenceEngine::from_file(model_path, config)?;
167 Ok(Self {
168 engine,
169 options: DetOptions::default(),
170 normalize_params: NormalizeParams::paddle_det(),
171 })
172 }
173
174 pub fn from_bytes(model_bytes: &[u8], config: Option<InferenceConfig>) -> OcrResult<Self> {
176 let engine = InferenceEngine::from_buffer(model_bytes, config)?;
177 Ok(Self {
178 engine,
179 options: DetOptions::default(),
180 normalize_params: NormalizeParams::paddle_det(),
181 })
182 }
183
184 pub fn with_options(mut self, options: DetOptions) -> Self {
186 self.options = options;
187 self
188 }
189
190 pub fn options(&self) -> &DetOptions {
192 &self.options
193 }
194
195 pub fn options_mut(&mut self) -> &mut DetOptions {
197 &mut self.options
198 }
199
200 pub fn detect(&self, image: &DynamicImage) -> OcrResult<Vec<TextBox>> {
208 self.detect_fast(image)
209 }
210
211 pub fn detect_and_crop(&self, image: &DynamicImage) -> OcrResult<Vec<(DynamicImage, TextBox)>> {
219 let boxes = self.detect(image)?;
220 let (width, height) = image.dimensions();
221
222 let mut results = Vec::with_capacity(boxes.len());
223
224 for text_box in boxes {
225 let expanded = text_box.expand(self.options.box_border, width, height);
227
228 let cropped = image.crop_imm(
230 expanded.rect.left() as u32,
231 expanded.rect.top() as u32,
232 expanded.rect.width(),
233 expanded.rect.height(),
234 );
235
236 results.push((cropped, expanded));
237 }
238
239 Ok(results)
240 }
241
242 fn detect_fast(&self, image: &DynamicImage) -> OcrResult<Vec<TextBox>> {
244 let (original_width, original_height) = image.dimensions();
245
246 let scaled = self.scale_image(image);
248 let (scaled_width, scaled_height) = scaled.dimensions();
249
250 let input = preprocess_for_det(&scaled, &self.normalize_params);
252
253 let output = self.engine.run_dynamic(input.view().into_dyn())?;
255
256 let output_shape = output.shape();
258 let out_w = output_shape[3] as u32;
259 let out_h = output_shape[2] as u32;
260
261 let boxes = self.postprocess_output(
262 &output,
263 out_w,
264 out_h,
265 scaled_width,
266 scaled_height,
267 original_width,
268 original_height,
269 )?;
270
271 Ok(boxes)
272 }
273
274 fn scale_image(&self, image: &DynamicImage) -> DynamicImage {
277 let (w, h) = image.dimensions();
278 let max_dim = w.max(h);
279
280 if max_dim <= self.options.max_side_len {
281 return image.clone();
282 }
283
284 let scale = self.options.max_side_len as f64 / max_dim as f64;
285 let new_w = (w as f64 * scale).round() as u32;
286 let new_h = (h as f64 * scale).round() as u32;
287
288 image.resize_exact(new_w, new_h, image::imageops::FilterType::Lanczos3)
289 }
290
291 fn postprocess_output(
293 &self,
294 output: &ArrayD<f32>,
295 out_w: u32,
296 out_h: u32,
297 scaled_width: u32,
298 scaled_height: u32,
299 original_width: u32,
300 original_height: u32,
301 ) -> OcrResult<Vec<TextBox>> {
302 let output_shape = output.shape();
304 if output_shape.len() < 3 {
305 return Err(OcrError::PostprocessError(
306 "Detection model output shape invalid".to_string(),
307 ));
308 }
309
310 let mask_data: Vec<f32> = output.iter().cloned().collect();
312
313 let binary_mask: Vec<u8> = mask_data
315 .iter()
316 .map(|&v| {
317 if v > self.options.score_threshold {
318 255u8
319 } else {
320 0u8
321 }
322 })
323 .collect();
324
325 let boxes = extract_boxes_with_unclip(
328 &binary_mask,
329 out_w,
330 out_h,
331 scaled_width,
332 scaled_height,
333 original_width,
334 original_height,
335 self.options.min_area,
336 self.options.unclip_ratio,
337 );
338
339 Ok(boxes)
340 }
341}
342
343impl DetModel {
345 pub fn run_raw(&self, input: ndarray::ArrayViewD<f32>) -> OcrResult<ArrayD<f32>> {
355 Ok(self.engine.run_dynamic(input)?)
356 }
357
358 pub fn input_shape(&self) -> &[usize] {
360 self.engine.input_shape()
361 }
362
363 pub fn output_shape(&self) -> &[usize] {
365 self.engine.output_shape()
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Every default value, including the multi-scale / tiling fields.
    #[test]
    fn test_det_options_default() {
        let opts = DetOptions::default();
        assert_eq!(opts.max_side_len, 960);
        assert_eq!(opts.box_threshold, 0.5);
        assert_eq!(opts.unclip_ratio, 1.5);
        assert_eq!(opts.score_threshold, 0.3);
        assert_eq!(opts.min_area, 16);
        assert_eq!(opts.box_border, 5);
        assert!(!opts.merge_boxes);
        assert_eq!(opts.merge_threshold, 10);
        assert_eq!(opts.precision_mode, DetPrecisionMode::Fast);
        assert_eq!(opts.multi_scales, vec![0.5, 1.0, 1.5]);
        assert_eq!(opts.block_size, 640);
        assert_eq!(opts.block_overlap, 100);
        assert_eq!(opts.nms_threshold, 0.3);
    }

    /// `new()` must be identical to `default()` (compared via `Debug`, since
    /// `DetOptions` does not implement `PartialEq`).
    #[test]
    fn test_det_options_new_matches_default() {
        assert_eq!(
            format!("{:?}", DetOptions::new()),
            format!("{:?}", DetOptions::default())
        );
    }

    #[test]
    fn test_det_options_fast() {
        let opts = DetOptions::fast();
        assert_eq!(opts.max_side_len, 960);
        assert_eq!(opts.precision_mode, DetPrecisionMode::Fast);
    }

    /// `fast()` may only differ from the defaults in the fields it sets.
    #[test]
    fn test_det_options_fast_keeps_other_defaults() {
        let defaults = DetOptions::default();
        let normalized = DetOptions {
            max_side_len: defaults.max_side_len,
            precision_mode: defaults.precision_mode,
            ..DetOptions::fast()
        };
        assert_eq!(format!("{:?}", normalized), format!("{:?}", defaults));
    }

    #[test]
    fn test_det_options_builder() {
        let opts = DetOptions::new()
            .with_max_side_len(1280)
            .with_box_threshold(0.6)
            .with_score_threshold(0.4)
            .with_min_area(32)
            .with_box_border(10)
            .with_merge_boxes(true)
            .with_merge_threshold(20)
            .with_precision_mode(DetPrecisionMode::Fast)
            .with_multi_scales(vec![0.5, 1.0, 1.5])
            .with_block_size(800);

        assert_eq!(opts.max_side_len, 1280);
        assert_eq!(opts.box_threshold, 0.6);
        assert_eq!(opts.score_threshold, 0.4);
        assert_eq!(opts.min_area, 32);
        assert_eq!(opts.box_border, 10);
        assert!(opts.merge_boxes);
        assert_eq!(opts.merge_threshold, 20);
        assert_eq!(opts.precision_mode, DetPrecisionMode::Fast);
        assert_eq!(opts.multi_scales, vec![0.5, 1.0, 1.5]);
        assert_eq!(opts.block_size, 800);
    }

    #[test]
    fn test_det_precision_mode_default() {
        let mode = DetPrecisionMode::default();
        assert_eq!(mode, DetPrecisionMode::Fast);
    }

    #[test]
    fn test_det_precision_mode_equality() {
        assert_eq!(DetPrecisionMode::Fast, DetPrecisionMode::Fast);
    }

    /// Setters only touch their own field; everything else stays at default.
    #[test]
    fn test_det_options_chaining() {
        let opts = DetOptions::new()
            .with_max_side_len(1000)
            .with_box_threshold(0.7);

        assert_eq!(opts.max_side_len, 1000);
        assert_eq!(opts.box_threshold, 0.7);
        assert_eq!(opts.score_threshold, 0.3);
        assert_eq!(opts.min_area, 16);
        assert_eq!(opts.box_border, 5);
    }

    #[test]
    fn test_det_options_presets_are_valid() {
        let fast = DetOptions::fast();
        assert!(fast.box_threshold >= 0.0 && fast.box_threshold <= 1.0);
        assert!(fast.score_threshold >= 0.0 && fast.score_threshold <= 1.0);
        assert!(fast.nms_threshold >= 0.0 && fast.nms_threshold <= 1.0);
        assert!(fast.unclip_ratio > 0.0);
        assert!(fast.max_side_len > 0);
    }
}