1use crate::processors::{BoundingBox, DBPostProcess, DetResizeForTest, LimitType, NormalizeImage};
15use image::{DynamicImage, RgbImage};
16use std::fmt;
17use std::path::Path;
18use std::sync::Arc;
19
20use crate::impl_config_new_and_with_common;
21
22use crate::impl_common_builder_methods;
23
24use crate::core::ImageReader as CoreImageReader;
25use crate::core::{
26 BatchData, CommonBuilderConfig, OCRError, Tensor4D,
27 config::{ConfigValidator, ConfigValidatorExt},
28 constants::{DEFAULT_BATCH_SIZE, DEFAULT_MAX_SIDE_LIMIT},
29};
30use crate::core::{DefaultImageReader, OrtInfer};
31use crate::core::{
32 GranularImageReader as GIReader, InferenceEngine as GInferenceEngine, ModularPredictor,
33 Postprocessor as GPostprocessor, Preprocessor as GPreprocessor,
34};
35
/// Fallback for the `thresh` setting: used by `TDPostprocessor::postprocess` when neither
/// the runtime nor the default configuration provides a value, and by
/// `TextDetPredictorConfig::get_defaults`.
const DEFAULT_THRESH: f32 = 0.3;

/// Fallback for the `box_thresh` setting, used in the same places as [`DEFAULT_THRESH`].
const DEFAULT_BOX_THRESH: f32 = 0.6;

/// Fallback for the `unclip_ratio` setting, used in the same places as [`DEFAULT_THRESH`].
const DEFAULT_UNCLIP_RATIO: f32 = 1.5;
41
/// Per-call settings for the text detection pipeline.
///
/// Every field is optional. `TDPreprocessor::preprocess` and
/// `TDPostprocessor::postprocess` treat a `None` field as "use the value from the
/// stage's `default_config`".
#[derive(Debug, Clone, Default)]
pub struct TextDetConfig {
    /// Side-length limit forwarded to `DetResizeForTest::apply`.
    pub limit_side_len: Option<u32>,
    /// Limit type forwarded to `DetResizeForTest::apply` together with `limit_side_len`.
    pub limit_type: Option<LimitType>,
    /// Threshold forwarded to `DBPostProcess::apply`; falls back to `DEFAULT_THRESH`.
    pub thresh: Option<f32>,
    /// Box threshold forwarded to `DBPostProcess::apply`; falls back to `DEFAULT_BOX_THRESH`.
    pub box_thresh: Option<f32>,
    /// Unclip ratio forwarded to `DBPostProcess::apply`; falls back to `DEFAULT_UNCLIP_RATIO`.
    pub unclip_ratio: Option<f32>,
    /// Maximum side limit forwarded to `DetResizeForTest::apply`.
    pub max_side_limit: Option<u32>,
}
60
/// Serializable build-time configuration for the text detection predictor.
///
/// `TextDetPredictorBuilder::build` assembles one of these from the builder's fields and
/// validates it before constructing the pipeline stages.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TextDetPredictorConfig {
    /// Shared builder settings; validated first, and its `model_name` selects the
    /// resize defaults in `TextDetPredictorBuilder::build`.
    pub common: CommonBuilderConfig,
    /// Side-length limit for resizing; must be positive when set.
    pub limit_side_len: Option<u32>,
    /// Limit type used together with `limit_side_len`.
    pub limit_type: Option<LimitType>,
    /// Detection threshold; checked against the range 0.0–1.0 when set.
    pub thresh: Option<f32>,
    /// Box threshold; checked against the range 0.0–1.0 when set.
    pub box_thresh: Option<f32>,
    /// Unclip ratio; must be positive when set.
    pub unclip_ratio: Option<f32>,
    /// Three-component input shape, destructured as `(c, h, w)` during validation;
    /// every component must be non-zero. Forwarded to `DetResizeForTest::new`.
    pub input_shape: Option<(u32, u32, u32)>,
    /// Maximum side limit for resizing; must be positive when set.
    pub max_side_limit: Option<u32>,
}
83
// Expands the crate's constructor macro for `TextDetPredictorConfig`; the unit test in
// this file relies on the resulting `new()`. The `fields` list gives the constructor's
// initial values: only `max_side_limit` starts populated, everything else is `None`.
// NOTE(review): `common_defaults` presumably seeds the shared `common` config (the second
// element looks like the batch size) — confirm against the macro definition.
impl_config_new_and_with_common!(
    TextDetPredictorConfig,
    common_defaults: (None, Some(DEFAULT_BATCH_SIZE)),
    fields: {
        limit_side_len: None,
        limit_type: None,
        thresh: None,
        box_thresh: None,
        unclip_ratio: None,
        input_shape: None,
        max_side_limit: Some(DEFAULT_MAX_SIDE_LIMIT)
    }
);
97
98impl TextDetPredictorConfig {
99 pub fn validate(&self) -> Result<(), crate::core::ConfigError> {
104 ConfigValidator::validate(self)
105 }
106}
107
108impl ConfigValidator for TextDetPredictorConfig {
109 fn validate(&self) -> Result<(), crate::core::ConfigError> {
110 self.common.validate()?;
111
112 if let Some(thresh) = self.thresh {
113 self.validate_f32_range(thresh, 0.0, 1.0, "threshold")?;
114 }
115
116 if let Some(box_thresh) = self.box_thresh {
117 self.validate_f32_range(box_thresh, 0.0, 1.0, "box threshold")?;
118 }
119
120 if let Some(unclip_ratio) = self.unclip_ratio {
121 self.validate_positive_f32(unclip_ratio, "unclip ratio")?;
122 }
123
124 if let Some(max_side_limit) = self.max_side_limit {
125 self.validate_positive_usize(max_side_limit as usize, "max side limit")?;
126 }
127
128 if let Some(limit_side_len) = self.limit_side_len {
129 self.validate_positive_usize(limit_side_len as usize, "limit side length")?;
130 }
131
132 if let Some((c, h, w)) = self.input_shape
133 && (c == 0 || h == 0 || w == 0)
134 {
135 return Err(crate::core::ConfigError::InvalidConfig {
136 message: format!(
137 "Input shape dimensions must be greater than 0, got ({c}, {h}, {w})"
138 ),
139 });
140 }
141
142 Ok(())
143 }
144
145 fn get_defaults() -> Self {
146 Self {
147 common: CommonBuilderConfig::get_defaults(),
148 limit_side_len: Some(960),
149 limit_type: Some(LimitType::Max),
150 thresh: Some(DEFAULT_THRESH),
151 box_thresh: Some(DEFAULT_BOX_THRESH),
152 unclip_ratio: Some(DEFAULT_UNCLIP_RATIO),
153 input_shape: Some((3, 640, 640)),
154 max_side_limit: Some(DEFAULT_MAX_SIDE_LIMIT),
155 }
156 }
157}
158
/// Batched output of text detection.
///
/// The vectors are parallel: `TDPostprocessor::postprocess` fills them from the same
/// batch, and the `Display` implementation zips them together per image.
#[derive(Debug, Clone)]
pub struct TextDetResult {
    /// Input paths, copied from `BatchData::input_paths`.
    pub input_path: Vec<Arc<str>>,
    /// Input indexes, copied from `BatchData::indexes`.
    pub index: Vec<usize>,
    /// The raw input images handed to the postprocessor, each wrapped in an `Arc`.
    pub input_img: Vec<Arc<RgbImage>>,
    /// Detected regions per image, as returned by `DBPostProcess::apply`.
    pub dt_polys: Vec<Vec<BoundingBox>>,
    /// Scores per image, as returned by `DBPostProcess::apply` alongside `dt_polys`.
    pub dt_scores: Vec<Vec<f32>>,
}
175
176impl TextDetResult {
177 pub fn new() -> Self {
182 Self {
183 input_path: Vec::new(),
184 index: Vec::new(),
185 input_img: Vec::new(),
186 dt_polys: Vec::new(),
187 dt_scores: Vec::new(),
188 }
189 }
190}
191
192impl fmt::Display for TextDetResult {
193 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194 for (i, ((path, polys), scores)) in self
195 .input_path
196 .iter()
197 .zip(self.dt_polys.iter())
198 .zip(self.dt_scores.iter())
199 .enumerate()
200 {
201 writeln!(f, "Image {} of {}: {}", i + 1, self.input_path.len(), path)?;
202 writeln!(f, " Total regions: {}", polys.len())?;
203
204 if !polys.is_empty() {
205 writeln!(f, " Detection polygons:")?;
206 for (j, (bbox, &score)) in polys.iter().zip(scores.iter()).enumerate() {
207 if bbox.points.is_empty() {
208 writeln!(f, " Region {j}: [] (empty, score: {score:.3})")?;
209 continue;
210 }
211
212 write!(f, " Region {j}: [")?;
213 for (k, point) in bbox.points.iter().enumerate() {
214 if k == 0 {
215 write!(f, "[{:.0}, {:.0}]", point.x, point.y)?;
216 } else {
217 write!(f, ", [{:.0}, {:.0}]", point.x, point.y)?;
218 }
219 }
220 writeln!(f, "] (score: {score:.3})")?;
221 }
222 }
223
224 if i < self.input_path.len() - 1 {
225 writeln!(f)?;
226 }
227 }
228
229 Ok(())
230 }
231}
232
233impl Default for TextDetResult {
234 fn default() -> Self {
235 Self::new()
236 }
237}
238
/// Text detection predictor: a [`ModularPredictor`] wired with this module's image
/// reader, preprocessor, inference engine, and postprocessor.
///
/// Instances are created through `TextDetPredictorBuilder::build`.
pub type TextDetPredictor =
    ModularPredictor<TDImageReader, TDPreprocessor, TDOrtInfer, TDPostprocessor>;
245
/// Image-reading stage of the text detection pipeline; a thin wrapper around
/// [`DefaultImageReader`].
#[derive(Debug)]
pub struct TDImageReader {
    /// Reader that performs the actual image loading.
    inner: DefaultImageReader,
}
250impl TDImageReader {
251 pub fn new() -> Self {
252 Self {
253 inner: DefaultImageReader::new(),
254 }
255 }
256}
257impl Default for TDImageReader {
258 fn default() -> Self {
259 Self::new()
260 }
261}
262impl GIReader for TDImageReader {
263 fn read_images<'a>(
264 &self,
265 paths: impl Iterator<Item = &'a str>,
266 ) -> Result<Vec<RgbImage>, OCRError> {
267 self.inner.apply(paths)
268 }
269}
270
/// Preprocessing stage of the text detection pipeline: resize, then normalize.
#[derive(Debug)]
pub struct TDPreprocessor {
    /// Resize operator; also supplies fallback limits when the configs leave them unset.
    pub resize: DetResizeForTest,
    /// Normalization operator that turns the resized images into the input tensor.
    pub normalize: NormalizeImage,
    /// Settings used for any field the runtime config leaves as `None`.
    pub default_config: TextDetConfig,
}
/// Output of `TDPreprocessor::preprocess`, consumed by inference and postprocessing.
#[derive(Debug)]
pub struct TextDetPreprocessOutput {
    /// Normalized batch tensor passed to the inference engine.
    pub tensor: Tensor4D,
    /// Per-image shape records returned by `DetResizeForTest::apply`; the postprocessor
    /// forwards them to `DBPostProcess::apply`.
    pub shapes: Vec<[f32; 4]>,
}
283impl GPreprocessor for TDPreprocessor {
284 type Config = TextDetConfig;
285 type Output = TextDetPreprocessOutput;
286 fn preprocess(
287 &self,
288 images: Vec<RgbImage>,
289 config: Option<&Self::Config>,
290 ) -> crate::core::OcrResult<Self::Output> {
291 let merged = match config {
293 Some(runtime_config) => TextDetConfig {
294 limit_side_len: runtime_config
295 .limit_side_len
296 .or(self.default_config.limit_side_len),
297 limit_type: runtime_config
298 .limit_type
299 .clone()
300 .or(self.default_config.limit_type.clone()),
301 thresh: runtime_config.thresh.or(self.default_config.thresh),
302 box_thresh: runtime_config.box_thresh.or(self.default_config.box_thresh),
303 unclip_ratio: runtime_config
304 .unclip_ratio
305 .or(self.default_config.unclip_ratio),
306 max_side_limit: runtime_config
307 .max_side_limit
308 .or(self.default_config.max_side_limit),
309 },
310 None => self.default_config.clone(),
311 };
312
313 let limit_side_len = merged
314 .limit_side_len
315 .unwrap_or(self.resize.limit_side_len.unwrap_or(960));
316 let limit_type = merged
317 .limit_type
318 .unwrap_or(self.resize.limit_type.clone().unwrap_or(LimitType::Min));
319 let max_side_limit = merged.max_side_limit.unwrap_or(self.resize.max_side_limit);
320 let batch_imgs: Vec<DynamicImage> =
321 images.into_iter().map(DynamicImage::ImageRgb8).collect();
322 let (resized_imgs, shapes) = self.resize.apply(
323 batch_imgs,
324 Some(limit_side_len),
325 Some(limit_type.clone()),
326 Some(max_side_limit),
327 );
328 let tensor = self
329 .normalize
330 .normalize_batch_to(resized_imgs)
331 .map_err(|e| {
332 OCRError::model_inference_error(
333 "TextDetection",
334 "preprocessing_normalization",
335 0,
336 &[shapes.len()],
337 "Normalization failed in TDPreprocessor",
338 e,
339 )
340 })?;
341 Ok(TextDetPreprocessOutput { tensor, shapes })
342 }
343}
344
/// Inference stage of the text detection pipeline; a newtype around [`OrtInfer`].
#[derive(Debug)]
pub struct TDOrtInfer(pub OrtInfer);
347impl GInferenceEngine for TDOrtInfer {
348 type Input = TextDetPreprocessOutput;
349 type Output = Tensor4D;
350 fn infer(&self, input: &Self::Input) -> Result<Self::Output, OCRError> {
351 self.0.infer_4d(&input.tensor)
353 }
354 fn engine_info(&self) -> String {
355 "ONNXRuntime-4D".to_string()
356 }
357}
358
/// Postprocessing stage of the text detection pipeline.
#[derive(Debug)]
pub struct TDPostprocessor {
    /// Operator that converts the inference output into regions and scores.
    pub op: DBPostProcess,
    /// Settings used for any field the runtime config leaves as `None`.
    pub default_config: TextDetConfig,
}
365impl GPostprocessor for TDPostprocessor {
366 type Config = TextDetConfig;
367 type InferenceOutput = Tensor4D;
368 type PreprocessOutput = TextDetPreprocessOutput;
369 type Result = TextDetResult;
370 fn postprocess(
371 &self,
372 output: Self::InferenceOutput,
373 pre: Option<&Self::PreprocessOutput>,
374 batch_data: &BatchData,
375 raw_images: Vec<RgbImage>,
376 config: Option<&Self::Config>,
377 ) -> crate::core::OcrResult<Self::Result> {
378 let merged = match config {
380 Some(runtime_config) => TextDetConfig {
381 limit_side_len: runtime_config
382 .limit_side_len
383 .or(self.default_config.limit_side_len),
384 limit_type: runtime_config
385 .limit_type
386 .clone()
387 .or(self.default_config.limit_type.clone()),
388 thresh: runtime_config.thresh.or(self.default_config.thresh),
389 box_thresh: runtime_config.box_thresh.or(self.default_config.box_thresh),
390 unclip_ratio: runtime_config
391 .unclip_ratio
392 .or(self.default_config.unclip_ratio),
393 max_side_limit: runtime_config
394 .max_side_limit
395 .or(self.default_config.max_side_limit),
396 },
397 None => self.default_config.clone(),
398 };
399
400 let thresh = merged.thresh.unwrap_or(DEFAULT_THRESH);
401 let box_thresh = merged.box_thresh.unwrap_or(DEFAULT_BOX_THRESH);
402 let unclip_ratio = merged.unclip_ratio.unwrap_or(DEFAULT_UNCLIP_RATIO);
403 let shapes = pre.map(|p| p.shapes.clone()).unwrap_or_default();
404 let (polys, scores) = self.op.apply(
405 &output,
406 shapes,
407 Some(thresh),
408 Some(box_thresh),
409 Some(unclip_ratio),
410 );
411 Ok(TextDetResult {
412 input_path: batch_data.input_paths.clone(),
413 index: batch_data.indexes.clone(),
414 input_img: raw_images.into_iter().map(Arc::new).collect(),
415 dt_polys: polys,
416 dt_scores: scores,
417 })
418 }
419 fn empty_result(&self) -> crate::core::OcrResult<Self::Result> {
420 Ok(TextDetResult::new())
421 }
422}
423
424pub struct TextDetPredictorBuilder {
428 common: CommonBuilderConfig,
430
431 limit_side_len: Option<u32>,
433 limit_type: Option<LimitType>,
435 thresh: Option<f32>,
437 box_thresh: Option<f32>,
439 unclip_ratio: Option<f32>,
441 input_shape: Option<(u32, u32, u32)>,
443 max_side_limit: Option<u32>,
445}
446
// Expands the crate's common builder-method macro for this builder, targeting its
// `common` field. NOTE(review): the generated method set is defined by the macro —
// confirm there.
impl_common_builder_methods!(TextDetPredictorBuilder, common);
448
449impl TextDetPredictorBuilder {
450 pub fn new() -> Self {
454 Self {
455 common: CommonBuilderConfig::new(),
456 limit_side_len: None,
457 limit_type: None,
458 thresh: None,
459 box_thresh: None,
460 unclip_ratio: None,
461 input_shape: None,
462 max_side_limit: None,
463 }
464 }
465
466 pub fn limit_side_len(mut self, limit_side_len: u32) -> Self {
470 self.limit_side_len = Some(limit_side_len);
471 self
472 }
473
474 pub fn limit_type(mut self, limit_type: LimitType) -> Self {
479 self.limit_type = Some(limit_type);
480 self
481 }
482
483 pub fn thresh(mut self, thresh: f32) -> Self {
487 self.thresh = Some(thresh);
488 self
489 }
490
491 pub fn box_thresh(mut self, box_thresh: f32) -> Self {
495 self.box_thresh = Some(box_thresh);
496 self
497 }
498
499 pub fn unclip_ratio(mut self, unclip_ratio: f32) -> Self {
503 self.unclip_ratio = Some(unclip_ratio);
504 self
505 }
506
507 pub fn input_shape(mut self, input_shape: (u32, u32, u32)) -> Self {
511 self.input_shape = Some(input_shape);
512 self
513 }
514
515 pub fn max_side_limit(mut self, max_side_limit: u32) -> Self {
519 self.max_side_limit = Some(max_side_limit);
520 self
521 }
522
523 pub fn build(self, model_path: &Path) -> Result<TextDetPredictor, OCRError> {
527 self.build_internal(model_path)
528 }
529
530 fn build_internal(mut self, model_path: &Path) -> Result<TextDetPredictor, OCRError> {
535 if self.common.model_path.is_none() {
536 self.common = self.common.model_path(model_path.to_path_buf());
537 }
538
539 let config = TextDetPredictorConfig {
540 common: self.common,
541 limit_side_len: self.limit_side_len,
542 limit_type: self.limit_type,
543 thresh: self.thresh,
544 box_thresh: self.box_thresh,
545 unclip_ratio: self.unclip_ratio,
546 input_shape: self.input_shape,
547 max_side_limit: self.max_side_limit,
548 };
549 let config = config.validate_and_wrap_ocr_error()?;
550
551 let (default_limit_side_len, default_limit_type) =
553 if let Some(model_name) = &config.common.model_name {
554 match model_name.as_str() {
555 "PP-OCRv5_server_det"
556 | "PP-OCRv5_mobile_det"
557 | "PP-OCRv4_server_det"
558 | "PP-OCRv4_mobile_det"
559 | "PP-OCRv3_server_det"
560 | "PP-OCRv3_mobile_det" => (960, LimitType::Max),
561 _ => (736, LimitType::Min),
562 }
563 } else {
564 (736, LimitType::Min)
565 };
566
567 let limit_side_len = config.limit_side_len.unwrap_or(default_limit_side_len);
568 let limit_type = config.limit_type.clone().unwrap_or(default_limit_type);
569 let max_side_limit = config.max_side_limit.unwrap_or(DEFAULT_MAX_SIDE_LIMIT);
570
571 let default_config = TextDetConfig {
573 limit_side_len: Some(limit_side_len),
574 limit_type: Some(limit_type.clone()),
575 thresh: config.thresh,
576 box_thresh: config.box_thresh,
577 unclip_ratio: config.unclip_ratio,
578 max_side_limit: Some(max_side_limit),
579 };
580
581 let image_reader = TDImageReader::new();
583 let resize = DetResizeForTest::new(
584 config.input_shape,
585 None,
586 None,
587 Some(limit_side_len),
588 Some(limit_type.clone()),
589 None,
590 Some(max_side_limit),
591 );
592 let normalize = NormalizeImage::new(None, None, None, None)?;
593 let preprocessor = TDPreprocessor {
594 resize,
595 normalize,
596 default_config: default_config.clone(),
597 };
598 let infer = OrtInfer::from_common(&config.common, model_path, None)?;
599 let inference_engine = TDOrtInfer(infer);
600 let post_op = DBPostProcess::new(None, None, None, None, None, None, None);
601 let postprocessor = TDPostprocessor {
602 op: post_op,
603 default_config,
604 };
605
606 Ok(ModularPredictor::new(
607 image_reader,
608 preprocessor,
609 inference_engine,
610 postprocessor,
611 ))
612 }
613}
614
615impl Default for TextDetPredictorBuilder {
616 fn default() -> Self {
617 Self::new()
618 }
619}
620
621#[cfg(test)]
622mod tests_local {
623 use super::*;
624
625 #[test]
626 fn test_text_det_config_defaults_and_validate() {
627 let config = TextDetPredictorConfig::new();
628 assert_eq!(config.max_side_limit, Some(DEFAULT_MAX_SIDE_LIMIT));
630 assert!(config.validate().is_ok());
631 }
632}