oar_ocr_core/models/rectification/
uvdoc.rs1use crate::core::OCRError;
7use crate::core::inference::{OrtInfer, TensorInput};
8use crate::processors::{NormalizeImage, TensorLayout, UVDocPostProcess};
9use image::{DynamicImage, RgbImage, imageops::FilterType};
10
/// Result of [`UVDocModel::preprocess`]: the batched input tensor paired with the
/// `(width, height)` of every source image, in input order.
type PreprocessResult = Result<(ndarray::Array4<f32>, Vec<(u32, u32)>), OCRError>;
12
/// Preprocessing settings for the UVDoc rectification model.
#[derive(Debug, Clone)]
pub struct UVDocPreprocessConfig {
    /// Model input shape. Preprocessing reads index 1 as the target height and
    /// index 2 as the target width; index 0 is presumably the channel count
    /// (TODO confirm — it is not read in this file).
    pub rec_image_shape: [usize; 3],
}
19
20impl Default for UVDocPreprocessConfig {
21 fn default() -> Self {
22 Self {
23 rec_image_shape: [3, 512, 512],
24 }
25 }
26}
27
/// Output of [`UVDocModel::forward`].
#[derive(Debug, Clone)]
pub struct UVDocModelOutput {
    /// Rectified images, one per input image, in input order.
    pub images: Vec<RgbImage>,
}
34
/// UVDoc rectification model: preprocessing, inference and postprocessing
/// bundled behind a single type.
#[derive(Debug)]
pub struct UVDocModel {
    /// Inference session wrapper used to run the model.
    inference: OrtInfer,
    /// Turns the (optionally resized) image batch into the input tensor.
    normalizer: NormalizeImage,
    /// Converts raw model predictions into RGB images.
    postprocessor: UVDocPostProcess,
    /// Model input shape; indices 1 and 2 are used as the resize target
    /// height and width.
    rec_image_shape: [usize; 3],
}
49
50impl UVDocModel {
51 pub fn new(
53 inference: OrtInfer,
54 normalizer: NormalizeImage,
55 postprocessor: UVDocPostProcess,
56 rec_image_shape: [usize; 3],
57 ) -> Self {
58 Self {
59 inference,
60 normalizer,
61 postprocessor,
62 rec_image_shape,
63 }
64 }
65
66 pub fn preprocess(&self, images: Vec<RgbImage>) -> PreprocessResult {
76 let mut original_sizes = Vec::with_capacity(images.len());
77 let mut processed_images = Vec::with_capacity(images.len());
78
79 let target_height = self.rec_image_shape[1] as u32;
80 let target_width = self.rec_image_shape[2] as u32;
81 let should_resize = target_height > 0 && target_width > 0;
82
83 for img in images {
84 let original_size = (img.width(), img.height());
85 original_sizes.push(original_size);
86
87 if should_resize && (img.width() != target_width || img.height() != target_height) {
88 let resized = DynamicImage::ImageRgb8(img).resize_exact(
90 target_width,
91 target_height,
92 FilterType::Triangle,
93 );
94 processed_images.push(resized);
95 } else {
96 processed_images.push(DynamicImage::ImageRgb8(img));
97 }
98 }
99
100 let batch_tensor = self.normalizer.normalize_batch_to(processed_images)?;
102
103 Ok((batch_tensor, original_sizes))
104 }
105
106 pub fn infer(
116 &self,
117 batch_tensor: &ndarray::Array4<f32>,
118 ) -> Result<ndarray::Array4<f32>, OCRError> {
119 let input_name = self.inference.input_name();
120 let inputs = vec![(input_name, TensorInput::Array4(batch_tensor))];
121
122 let outputs = self
123 .inference
124 .infer(&inputs)
125 .map_err(|e| OCRError::Inference {
126 model_name: "UVDoc".to_string(),
127 context: format!(
128 "failed to run inference on batch with shape {:?}",
129 batch_tensor.shape()
130 ),
131 source: Box::new(e),
132 })?;
133
134 let output = outputs
135 .into_iter()
136 .next()
137 .ok_or_else(|| OCRError::InvalidInput {
138 message: "UVDoc: no output returned from inference".to_string(),
139 })?;
140
141 output
142 .1
143 .try_into_array4_f32()
144 .map_err(|e| OCRError::Inference {
145 model_name: "UVDoc".to_string(),
146 context: "failed to convert output to 4D array".to_string(),
147 source: Box::new(e),
148 })
149 }
150
151 pub fn postprocess(
162 &self,
163 predictions: &ndarray::Array4<f32>,
164 original_sizes: &[(u32, u32)],
165 ) -> Result<Vec<RgbImage>, OCRError> {
166 let mut images =
168 self.postprocessor
169 .apply_batch(predictions)
170 .map_err(|e| OCRError::ConfigError {
171 message: format!("Failed to postprocess rectification output: {}", e),
172 })?;
173
174 if images.len() != original_sizes.len() {
175 return Err(OCRError::InvalidInput {
176 message: format!(
177 "Mismatched rectification batch sizes: predictions={}, originals={}",
178 images.len(),
179 original_sizes.len()
180 ),
181 });
182 }
183
184 for (img, &(orig_w, orig_h)) in images.iter_mut().zip(original_sizes) {
186 if orig_w == 0 || orig_h == 0 {
187 continue;
188 }
189
190 if img.width() != orig_w || img.height() != orig_h {
191 let resized = DynamicImage::ImageRgb8(std::mem::take(img)).resize_exact(
193 orig_w,
194 orig_h,
195 FilterType::Triangle,
196 );
197 *img = resized.into_rgb8();
198 }
199 }
200
201 Ok(images)
202 }
203
204 pub fn forward(&self, images: Vec<RgbImage>) -> Result<UVDocModelOutput, OCRError> {
214 let (batch_tensor, original_sizes) = self.preprocess(images)?;
215 let predictions = self.infer(&batch_tensor)?;
216 let rectified_images = self.postprocess(&predictions, &original_sizes)?;
217
218 Ok(UVDocModelOutput {
219 images: rectified_images,
220 })
221 }
222}
223
/// Builder for [`UVDocModel`].
#[derive(Debug, Default)]
pub struct UVDocModelBuilder {
    /// Preprocessing settings handed to the built model.
    preprocess_config: UVDocPreprocessConfig,
    /// Optional session configuration; when `None` the session is created
    /// through `OrtInfer::new`.
    ort_config: Option<crate::core::config::OrtSessionConfig>,
}
232
233impl UVDocModelBuilder {
234 pub fn new() -> Self {
236 Self {
237 preprocess_config: UVDocPreprocessConfig::default(),
238 ort_config: None,
239 }
240 }
241
242 pub fn preprocess_config(mut self, config: UVDocPreprocessConfig) -> Self {
244 self.preprocess_config = config;
245 self
246 }
247
248 pub fn rec_image_shape(mut self, shape: [usize; 3]) -> Self {
250 self.preprocess_config.rec_image_shape = shape;
251 self
252 }
253
254 pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
256 self.ort_config = Some(config);
257 self
258 }
259
260 pub fn build(self, model_path: &std::path::Path) -> Result<UVDocModel, OCRError> {
270 let inference = if self.ort_config.is_some() {
272 use crate::core::config::ModelInferenceConfig;
273 let common_config = ModelInferenceConfig {
274 ort_session: self.ort_config,
275 ..Default::default()
276 };
277 OrtInfer::from_config(&common_config, model_path, Some("image"))?
278 } else {
279 OrtInfer::new(model_path, Some("image"))?
280 };
281
282 let normalizer = NormalizeImage::with_color_order(
286 Some(1.0 / 255.0),
287 Some(vec![0.0, 0.0, 0.0]),
288 Some(vec![1.0, 1.0, 1.0]),
289 Some(TensorLayout::CHW),
290 Some(crate::processors::types::ColorOrder::BGR),
291 )?;
292
293 let postprocessor = UVDocPostProcess::new(255.0);
295
296 Ok(UVDocModel::new(
297 inference,
298 normalizer,
299 postprocessor,
300 self.preprocess_config.rec_image_shape,
301 ))
302 }
303}