// oar_ocr_core/models/rectification/uvdoc.rs
use std::convert::TryFrom;

use image::{DynamicImage, RgbImage, imageops::FilterType};

use crate::core::inference::OrtInfer;
use crate::core::{OCRError, Tensor4D};
use crate::processors::{NormalizeImage, TensorLayout, UVDocPostProcess};

/// Preprocessing configuration for the UVDoc rectification model.
#[derive(Debug, Clone)]
pub struct UVDocPreprocessConfig {
    /// Target input shape.
    ///
    /// Index 1 is used as the target height and index 2 as the target width
    /// when input images are resized. Index 0 is not read in this file
    /// (presumably the channel count — confirm against the model definition).
    pub rec_image_shape: [usize; 3],
}

18impl Default for UVDocPreprocessConfig {
19 fn default() -> Self {
20 Self {
21 rec_image_shape: [3, 512, 512],
22 }
23 }
24}
25
26#[derive(Debug, Clone)]
28pub struct UVDocModelOutput {
29 pub images: Vec<RgbImage>,
31}
32
33#[derive(Debug)]
37pub struct UVDocModel {
38 inference: OrtInfer,
40 normalizer: NormalizeImage,
42 postprocessor: UVDocPostProcess,
44 rec_image_shape: [usize; 3],
46}
47
48impl UVDocModel {
49 pub fn new(
51 inference: OrtInfer,
52 normalizer: NormalizeImage,
53 postprocessor: UVDocPostProcess,
54 rec_image_shape: [usize; 3],
55 ) -> Self {
56 Self {
57 inference,
58 normalizer,
59 postprocessor,
60 rec_image_shape,
61 }
62 }
63
64 pub fn preprocess(
74 &self,
75 images: Vec<RgbImage>,
76 ) -> Result<(Tensor4D, Vec<(u32, u32)>), OCRError> {
77 let mut original_sizes = Vec::with_capacity(images.len());
78 let mut processed_images = Vec::with_capacity(images.len());
79
80 let target_height = self.rec_image_shape[1] as u32;
81 let target_width = self.rec_image_shape[2] as u32;
82 let should_resize = target_height > 0 && target_width > 0;
83
84 for img in images {
85 let original_size = (img.width(), img.height());
86 original_sizes.push(original_size);
87
88 if should_resize && (img.width() != target_width || img.height() != target_height) {
89 let resized = DynamicImage::ImageRgb8(img).resize_exact(
91 target_width,
92 target_height,
93 FilterType::Triangle,
94 );
95 processed_images.push(resized);
96 } else {
97 processed_images.push(DynamicImage::ImageRgb8(img));
98 }
99 }
100
101 let batch_tensor = self.normalizer.normalize_batch_to(processed_images)?;
103
104 Ok((batch_tensor, original_sizes))
105 }
106
107 pub fn infer(&self, batch_tensor: &Tensor4D) -> Result<Tensor4D, OCRError> {
117 self.inference
118 .infer_4d(batch_tensor)
119 .map_err(|e| OCRError::Inference {
120 model_name: "UVDoc".to_string(),
121 context: format!(
122 "failed to run inference on batch with shape {:?}",
123 batch_tensor.shape()
124 ),
125 source: Box::new(e),
126 })
127 }
128
129 pub fn postprocess(
140 &self,
141 predictions: &Tensor4D,
142 original_sizes: &[(u32, u32)],
143 ) -> Result<Vec<RgbImage>, OCRError> {
144 let mut images =
146 self.postprocessor
147 .apply_batch(predictions)
148 .map_err(|e| OCRError::ConfigError {
149 message: format!("Failed to postprocess rectification output: {}", e),
150 })?;
151
152 if images.len() != original_sizes.len() {
153 return Err(OCRError::InvalidInput {
154 message: format!(
155 "Mismatched rectification batch sizes: predictions={}, originals={}",
156 images.len(),
157 original_sizes.len()
158 ),
159 });
160 }
161
162 for (img, &(orig_w, orig_h)) in images.iter_mut().zip(original_sizes) {
164 if orig_w == 0 || orig_h == 0 {
165 continue;
166 }
167
168 if img.width() != orig_w || img.height() != orig_h {
169 let resized = DynamicImage::ImageRgb8(std::mem::take(img)).resize_exact(
171 orig_w,
172 orig_h,
173 FilterType::Triangle,
174 );
175 *img = resized.into_rgb8();
176 }
177 }
178
179 Ok(images)
180 }
181
182 pub fn forward(&self, images: Vec<RgbImage>) -> Result<UVDocModelOutput, OCRError> {
192 let (batch_tensor, original_sizes) = self.preprocess(images)?;
193 let predictions = self.infer(&batch_tensor)?;
194 let rectified_images = self.postprocess(&predictions, &original_sizes)?;
195
196 Ok(UVDocModelOutput {
197 images: rectified_images,
198 })
199 }
200}
201
202#[derive(Debug, Default)]
204pub struct UVDocModelBuilder {
205 preprocess_config: UVDocPreprocessConfig,
207 ort_config: Option<crate::core::config::OrtSessionConfig>,
209}
210
211impl UVDocModelBuilder {
212 pub fn new() -> Self {
214 Self {
215 preprocess_config: UVDocPreprocessConfig::default(),
216 ort_config: None,
217 }
218 }
219
220 pub fn preprocess_config(mut self, config: UVDocPreprocessConfig) -> Self {
222 self.preprocess_config = config;
223 self
224 }
225
226 pub fn rec_image_shape(mut self, shape: [usize; 3]) -> Self {
228 self.preprocess_config.rec_image_shape = shape;
229 self
230 }
231
232 pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
234 self.ort_config = Some(config);
235 self
236 }
237
238 pub fn build(self, model_path: &std::path::Path) -> Result<UVDocModel, OCRError> {
248 let inference = if self.ort_config.is_some() {
250 use crate::core::config::ModelInferenceConfig;
251 let common_config = ModelInferenceConfig {
252 ort_session: self.ort_config,
253 ..Default::default()
254 };
255 OrtInfer::from_config(&common_config, model_path, Some("image"))?
256 } else {
257 OrtInfer::new(model_path, Some("image"))?
258 };
259
260 let normalizer = NormalizeImage::with_color_order(
264 Some(1.0 / 255.0),
265 Some(vec![0.0, 0.0, 0.0]),
266 Some(vec![1.0, 1.0, 1.0]),
267 Some(TensorLayout::CHW),
268 Some(crate::processors::types::ColorOrder::BGR),
269 )?;
270
271 let postprocessor = UVDocPostProcess::new(255.0);
273
274 Ok(UVDocModel::new(
275 inference,
276 normalizer,
277 postprocessor,
278 self.preprocess_config.rec_image_shape,
279 ))
280 }
281}