1use crate::core::traits::ImageReader as CoreImageReader;
9use crate::core::{
10 BatchData, CommonBuilderConfig, DefaultImageReader, OCRError, OrtInfer, Tensor4D,
11 config::{ConfigValidator, ConfigValidatorExt},
12};
13use crate::core::{
14 GranularImageReader as GIReader, ModularPredictor, OrtInfer4D, Postprocessor as GPostprocessor,
15 Preprocessor as GPreprocessor,
16};
17use crate::processors::{DocTrPostProcess, NormalizeImage};
18
19use image::{DynamicImage, RgbImage};
20use std::path::Path;
21use std::sync::Arc;
22
23use crate::impl_config_new_and_with_common;
24
/// Result of document rectification for one batch, as built by `DRPostprocessor::postprocess`.
///
/// Each field is a per-image vector. Nothing in this module checks that the four vectors have
/// the same length.
#[derive(Debug, Clone)]
pub struct DoctrRectifierResult {
    /// Input paths of the batch, cloned from `BatchData::input_paths`.
    pub input_path: Vec<Arc<str>>,
    /// Batch indexes, cloned from `BatchData::indexes`.
    pub index: Vec<usize>,
    /// The raw RGB images handed to the post-processor.
    pub input_img: Vec<Arc<RgbImage>>,
    /// Rectified images returned by `DocTrPostProcess::apply_batch`.
    pub rectified_img: Vec<Arc<RgbImage>>,
}
40
/// Serializable configuration for the DocTr document-rectifier predictor.
///
/// Note the two different "defaults":
/// - `ConfigValidator::get_defaults` yields the common defaults plus
///   `rec_image_shape = Some([3, 512, 512])`.
/// - The derived `Default::default()` leaves `rec_image_shape` as `None`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct DoctrRectifierPredictorConfig {
    /// Settings shared with other predictors (including the model path); validated first.
    pub common: CommonBuilderConfig,
    /// Optional recognition image shape; when set, every dimension must be in `1..=2048`.
    ///
    /// NOTE(review): presumably `[channels, height, width]` — TODO confirm. The predictor
    /// builder validates this value but no pipeline stage reads it.
    pub rec_image_shape: Option<[usize; 3]>,
}
52
// Generates constructor helpers for `DoctrRectifierPredictorConfig` (the macro is defined
// elsewhere in the crate; judging by its name it provides `new`/`with_common`-style functions).
// NOTE(review): `common_defaults` presumably supplies the default model name
// ("doctr_rectifier") and batch size (32) — confirm against the macro definition.
// The `rec_image_shape` default here matches the one returned by `get_defaults`; keep them in sync.
impl_config_new_and_with_common!(
    DoctrRectifierPredictorConfig,
    common_defaults: (Some("doctr_rectifier".to_string()), Some(32)),
    fields: {
        rec_image_shape: Some([3, 512, 512])
    }
);
60
61impl DoctrRectifierPredictorConfig {
62 pub fn validate(&self) -> Result<(), crate::core::ConfigError> {
70 ConfigValidator::validate(self)
71 }
72}
73
74impl ConfigValidator for DoctrRectifierPredictorConfig {
75 fn validate(&self) -> Result<(), crate::core::ConfigError> {
76 self.common.validate()?;
77
78 if let Some(rec_shape) = self.rec_image_shape {
79 if rec_shape[0] == 0 || rec_shape[1] == 0 || rec_shape[2] == 0 {
80 return Err(crate::core::ConfigError::InvalidConfig {
81 message: format!(
82 "Recognition image shape dimensions must be greater than 0, got [{}, {}, {}]",
83 rec_shape[0], rec_shape[1], rec_shape[2]
84 ),
85 });
86 }
87
88 const MAX_SHAPE_SIZE: usize = 2048;
89 for (i, &dim) in rec_shape.iter().enumerate() {
90 if dim > MAX_SHAPE_SIZE {
91 return Err(crate::core::ConfigError::ResourceLimitExceeded {
92 message: format!(
93 "Recognition image shape dimension {i} ({dim}) exceeds maximum allowed size {MAX_SHAPE_SIZE}"
94 ),
95 });
96 }
97 }
98 }
99
100 Ok(())
101 }
102
103 fn get_defaults() -> Self {
104 Self {
105 common: CommonBuilderConfig::get_defaults(),
106 rec_image_shape: Some([3, 512, 512]),
107 }
108 }
109}
110
111impl DoctrRectifierResult {
112 pub fn new() -> Self {
121 Self {
122 input_path: Vec::new(),
123 index: Vec::new(),
124 input_img: Vec::new(),
125 rectified_img: Vec::new(),
126 }
127 }
128}
129
130impl Default for DoctrRectifierResult {
131 fn default() -> Self {
139 Self::new()
140 }
141}
142
/// Document-rectification predictor assembled from the DocTr-specific stages.
///
/// Type parameters, in order: image reader ([`DRImageReader`]), preprocessor
/// ([`DRPreprocessor`]), inference engine ([`OrtInfer4D`]) and post-processor
/// ([`DRPostprocessor`]). Instances are produced by `DoctrRectifierPredictorBuilder::build`.
pub type DoctrRectifierPredictor =
    ModularPredictor<DRImageReader, DRPreprocessor, OrtInfer4D, DRPostprocessor>;
149
/// Image-reading stage of [`DoctrRectifierPredictor`]; delegates all work to a
/// [`DefaultImageReader`].
#[derive(Debug)]
pub struct DRImageReader {
    /// Reader that actually loads the images.
    inner: DefaultImageReader,
}
154impl DRImageReader {
155 pub fn new() -> Self {
156 Self {
157 inner: DefaultImageReader::new(),
158 }
159 }
160}
161impl Default for DRImageReader {
162 fn default() -> Self {
163 Self::new()
164 }
165}
166impl GIReader for DRImageReader {
167 fn read_images<'a>(
168 &self,
169 paths: impl Iterator<Item = &'a str>,
170 ) -> Result<Vec<RgbImage>, OCRError> {
171 self.inner.apply(paths)
172 }
173}
174
/// Preprocessing stage of [`DoctrRectifierPredictor`]: normalizes a batch of RGB images into a
/// single [`Tensor4D`].
#[derive(Debug)]
pub struct DRPreprocessor {
    /// Normalizer applied to the whole batch.
    pub normalize: NormalizeImage,
}
179impl GPreprocessor for DRPreprocessor {
180 type Config = DoctrRectifierConfig;
181 type Output = Tensor4D;
182 fn preprocess(
183 &self,
184 images: Vec<RgbImage>,
185 _config: Option<&Self::Config>,
186 ) -> Result<Self::Output, OCRError> {
187 let batch_imgs: Vec<DynamicImage> =
188 images.into_iter().map(DynamicImage::ImageRgb8).collect();
189 self.normalize.normalize_batch_to(batch_imgs)
190 }
191}
192
/// Post-processing stage of [`DoctrRectifierPredictor`]: converts the model's [`Tensor4D`]
/// output into rectified RGB images.
#[derive(Debug)]
pub struct DRPostprocessor {
    /// DocTr output decoder that produces the rectified images.
    pub op: DocTrPostProcess,
}
197impl GPostprocessor for DRPostprocessor {
198 type Config = DoctrRectifierConfig;
199 type InferenceOutput = Tensor4D;
200 type PreprocessOutput = Tensor4D;
201 type Result = DoctrRectifierResult;
202 fn postprocess(
203 &self,
204 output: Self::InferenceOutput,
205 _pre: Option<&Self::PreprocessOutput>,
206 batch_data: &BatchData,
207 raw_images: Vec<RgbImage>,
208 _config: Option<&Self::Config>,
209 ) -> crate::core::OcrResult<Self::Result> {
210 let rectified_imgs = self
211 .op
212 .apply_batch(&output)
213 .map_err(|e| OCRError::ConfigError {
214 message: format!("DocTr post-processing failed: {}", e),
215 })?;
216 Ok(DoctrRectifierResult {
217 input_path: batch_data.input_paths.clone(),
218 index: batch_data.indexes.clone(),
219 input_img: raw_images.into_iter().map(Arc::new).collect(),
220 rectified_img: rectified_imgs.into_iter().map(Arc::new).collect(),
221 })
222 }
223 fn empty_result(&self) -> crate::core::OcrResult<Self::Result> {
224 Ok(DoctrRectifierResult::new())
225 }
226}
227
/// Per-call configuration type for the DocTr rectifier pre- and post-processors.
///
/// It carries no settings. Both `DRPreprocessor::preprocess` and `DRPostprocessor::postprocess`
/// ignore the config they receive.
#[derive(Debug, Clone)]
pub struct DoctrRectifierConfig;
235
/// Builder for [`DoctrRectifierPredictor`].
///
/// The rectifier-specific option is set with the `rec_image_shape` method. Shared options live
/// in `common`.
pub struct DoctrRectifierPredictorBuilder {
    /// Shared builder settings, including the optional model path.
    common: CommonBuilderConfig,

    /// Optional recognition image shape; `None` unless set explicitly.
    rec_image_shape: Option<[usize; 3]>,
}
247
// NOTE(review): presumably generates the shared builder setter methods that write into the
// `common` field — confirm against the macro definition.
crate::impl_common_builder_methods!(DoctrRectifierPredictorBuilder, common);
249
250impl DoctrRectifierPredictorBuilder {
251 pub fn new() -> Self {
260 Self {
261 common: CommonBuilderConfig::new(),
262 rec_image_shape: None,
263 }
264 }
265
266 pub fn rec_image_shape(mut self, rec_image_shape: [usize; 3]) -> Self {
278 self.rec_image_shape = Some(rec_image_shape);
279 self
280 }
281
282 pub fn build(self, model_path: &Path) -> Result<DoctrRectifierPredictor, OCRError> {
295 self.build_internal(model_path)
296 }
297
298 fn build_internal(mut self, model_path: &Path) -> Result<DoctrRectifierPredictor, OCRError> {
312 if self.common.model_path.is_none() {
313 self.common = self.common.model_path(model_path.to_path_buf());
314 }
315
316 let config = DoctrRectifierPredictorConfig {
317 common: self.common,
318 rec_image_shape: self.rec_image_shape,
319 };
320
321 let config = config.validate_and_wrap_ocr_error()?;
322
323 let image_reader = DRImageReader::new();
325 let normalize = NormalizeImage::new(
326 Some(1.0 / 255.0),
327 Some(vec![0.0, 0.0, 0.0]),
328 Some(vec![1.0, 1.0, 1.0]),
329 None,
330 )?;
331 let preprocessor = DRPreprocessor { normalize };
332 let infer = OrtInfer::from_common_with_auto_input(&config.common, model_path)?;
333 let inference_engine = OrtInfer4D::new(infer);
334 let postprocessor = DRPostprocessor {
335 op: DocTrPostProcess::new(1.0),
336 };
337
338 Ok(ModularPredictor::new(
339 image_reader,
340 preprocessor,
341 inference_engine,
342 postprocessor,
343 ))
344 }
345}
346
347impl Default for DoctrRectifierPredictorBuilder {
348 fn default() -> Self {
356 Self::new()
357 }
358}