Skip to main content

oar_ocr_core/models/rectification/
uvdoc.rs

1//! UVDoc Document Rectification Model
2//!
3//! This module provides a pure implementation of the UVDoc model for document rectification.
4//! The model takes distorted document images and outputs rectified (flattened) versions.
5
6use crate::core::OCRError;
7use crate::core::inference::{OrtInfer, TensorInput};
8use crate::processors::{NormalizeImage, TensorLayout, UVDocPostProcess};
9use image::{DynamicImage, RgbImage, imageops::FilterType};
10
/// Result of [`UVDocModel::preprocess`]: the normalized batch tensor together with the
/// original `(width, height)` of every input image, in input order.
type PreprocessResult = Result<(ndarray::Array4<f32>, Vec<(u32, u32)>), OCRError>;
12
/// Preprocessing settings for the UVDoc rectification model.
#[derive(Clone, Debug)]
pub struct UVDocPreprocessConfig {
    /// Network input shape laid out as `[channels, height, width]`.
    pub rec_image_shape: [usize; 3],
}

/// Input shape used when no explicit configuration is supplied: 3 channels, 512x512 pixels.
const DEFAULT_REC_IMAGE_SHAPE: [usize; 3] = [3, 512, 512];

impl Default for UVDocPreprocessConfig {
    /// Returns a configuration using [`DEFAULT_REC_IMAGE_SHAPE`].
    fn default() -> Self {
        UVDocPreprocessConfig {
            rec_image_shape: DEFAULT_REC_IMAGE_SHAPE,
        }
    }
}
27
28/// Output from UVDoc model.
29#[derive(Debug, Clone)]
30pub struct UVDocModelOutput {
31    /// Rectified images
32    pub images: Vec<RgbImage>,
33}
34
35/// Pure UVDoc model implementation.
36///
37/// This model performs document rectification (unwarping) on distorted document images.
38#[derive(Debug)]
39pub struct UVDocModel {
40    /// ONNX Runtime inference engine
41    inference: OrtInfer,
42    /// Image normalizer for preprocessing
43    normalizer: NormalizeImage,
44    /// UVDoc postprocessor for converting tensor to images
45    postprocessor: UVDocPostProcess,
46    /// Input shape [channels, height, width]
47    rec_image_shape: [usize; 3],
48}
49
50impl UVDocModel {
51    /// Creates a new UVDoc model.
52    pub fn new(
53        inference: OrtInfer,
54        normalizer: NormalizeImage,
55        postprocessor: UVDocPostProcess,
56        rec_image_shape: [usize; 3],
57    ) -> Self {
58        Self {
59            inference,
60            normalizer,
61            postprocessor,
62            rec_image_shape,
63        }
64    }
65
66    /// Preprocesses images for rectification.
67    ///
68    /// # Arguments
69    ///
70    /// * `images` - Input images to preprocess
71    ///
72    /// # Returns
73    ///
74    /// A tuple of (batch_tensor, original_sizes)
75    pub fn preprocess(&self, images: Vec<RgbImage>) -> PreprocessResult {
76        let mut original_sizes = Vec::with_capacity(images.len());
77        let mut processed_images = Vec::with_capacity(images.len());
78
79        let target_height = self.rec_image_shape[1] as u32;
80        let target_width = self.rec_image_shape[2] as u32;
81        let should_resize = target_height > 0 && target_width > 0;
82
83        for img in images {
84            let original_size = (img.width(), img.height());
85            original_sizes.push(original_size);
86
87            if should_resize && (img.width() != target_width || img.height() != target_height) {
88                // Use cv2.INTER_LINEAR for UVDoc resize.
89                let resized = DynamicImage::ImageRgb8(img).resize_exact(
90                    target_width,
91                    target_height,
92                    FilterType::Triangle,
93                );
94                processed_images.push(resized);
95            } else {
96                processed_images.push(DynamicImage::ImageRgb8(img));
97            }
98        }
99
100        // Normalize and convert to tensor
101        let batch_tensor = self.normalizer.normalize_batch_to(processed_images)?;
102
103        Ok((batch_tensor, original_sizes))
104    }
105
106    /// Runs inference on the preprocessed batch.
107    ///
108    /// # Arguments
109    ///
110    /// * `batch_tensor` - Preprocessed batch tensor
111    ///
112    /// # Returns
113    ///
114    /// Model predictions as a 4D tensor
115    pub fn infer(
116        &self,
117        batch_tensor: &ndarray::Array4<f32>,
118    ) -> Result<ndarray::Array4<f32>, OCRError> {
119        let input_name = self.inference.input_name();
120        let inputs = vec![(input_name, TensorInput::Array4(batch_tensor))];
121
122        let outputs = self
123            .inference
124            .infer(&inputs)
125            .map_err(|e| OCRError::Inference {
126                model_name: "UVDoc".to_string(),
127                context: format!(
128                    "failed to run inference on batch with shape {:?}",
129                    batch_tensor.shape()
130                ),
131                source: Box::new(e),
132            })?;
133
134        let output = outputs
135            .into_iter()
136            .next()
137            .ok_or_else(|| OCRError::InvalidInput {
138                message: "UVDoc: no output returned from inference".to_string(),
139            })?;
140
141        output
142            .1
143            .try_into_array4_f32()
144            .map_err(|e| OCRError::Inference {
145                model_name: "UVDoc".to_string(),
146                context: "failed to convert output to 4D array".to_string(),
147                source: Box::new(e),
148            })
149    }
150
151    /// Postprocesses model predictions to rectified images.
152    ///
153    /// # Arguments
154    ///
155    /// * `predictions` - Model predictions
156    /// * `original_sizes` - Original image sizes (width, height)
157    ///
158    /// # Returns
159    ///
160    /// Rectified images resized to original dimensions
161    pub fn postprocess(
162        &self,
163        predictions: &ndarray::Array4<f32>,
164        original_sizes: &[(u32, u32)],
165    ) -> Result<Vec<RgbImage>, OCRError> {
166        // Use UVDocPostProcess to convert tensor to images
167        let mut images =
168            self.postprocessor
169                .apply_batch(predictions)
170                .map_err(|e| OCRError::ConfigError {
171                    message: format!("Failed to postprocess rectification output: {}", e),
172                })?;
173
174        if images.len() != original_sizes.len() {
175            return Err(OCRError::InvalidInput {
176                message: format!(
177                    "Mismatched rectification batch sizes: predictions={}, originals={}",
178                    images.len(),
179                    original_sizes.len()
180                ),
181            });
182        }
183
184        // Resize back to original dimensions
185        for (img, &(orig_w, orig_h)) in images.iter_mut().zip(original_sizes) {
186            if orig_w == 0 || orig_h == 0 {
187                continue;
188            }
189
190            if img.width() != orig_w || img.height() != orig_h {
191                // Use cv2.INTER_LINEAR for resizing outputs back to original size.
192                let resized = DynamicImage::ImageRgb8(std::mem::take(img)).resize_exact(
193                    orig_w,
194                    orig_h,
195                    FilterType::Triangle,
196                );
197                *img = resized.into_rgb8();
198            }
199        }
200
201        Ok(images)
202    }
203
204    /// Performs complete forward pass: preprocess -> infer -> postprocess.
205    ///
206    /// # Arguments
207    ///
208    /// * `images` - Input images to rectify
209    ///
210    /// # Returns
211    ///
212    /// UVDocModelOutput containing rectified images
213    pub fn forward(&self, images: Vec<RgbImage>) -> Result<UVDocModelOutput, OCRError> {
214        let (batch_tensor, original_sizes) = self.preprocess(images)?;
215        let predictions = self.infer(&batch_tensor)?;
216        let rectified_images = self.postprocess(&predictions, &original_sizes)?;
217
218        Ok(UVDocModelOutput {
219            images: rectified_images,
220        })
221    }
222}
223
224/// Builder for UVDoc model.
225#[derive(Debug, Default)]
226pub struct UVDocModelBuilder {
227    /// Preprocessing configuration
228    preprocess_config: UVDocPreprocessConfig,
229    /// ONNX Runtime session configuration
230    ort_config: Option<crate::core::config::OrtSessionConfig>,
231}
232
233impl UVDocModelBuilder {
234    /// Creates a new UVDoc model builder.
235    pub fn new() -> Self {
236        Self {
237            preprocess_config: UVDocPreprocessConfig::default(),
238            ort_config: None,
239        }
240    }
241
242    /// Sets the preprocessing configuration.
243    pub fn preprocess_config(mut self, config: UVDocPreprocessConfig) -> Self {
244        self.preprocess_config = config;
245        self
246    }
247
248    /// Sets the input image shape.
249    pub fn rec_image_shape(mut self, shape: [usize; 3]) -> Self {
250        self.preprocess_config.rec_image_shape = shape;
251        self
252    }
253
254    /// Sets the ONNX Runtime session configuration.
255    pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
256        self.ort_config = Some(config);
257        self
258    }
259
260    /// Builds the UVDoc model.
261    ///
262    /// # Arguments
263    ///
264    /// * `model_path` - Path to the ONNX model file
265    ///
266    /// # Returns
267    ///
268    /// A configured UVDoc model instance
269    pub fn build(self, model_path: &std::path::Path) -> Result<UVDocModel, OCRError> {
270        // Create ONNX inference engine
271        let inference = if self.ort_config.is_some() {
272            use crate::core::config::ModelInferenceConfig;
273            let common_config = ModelInferenceConfig {
274                ort_session: self.ort_config,
275                ..Default::default()
276            };
277            OrtInfer::from_config(&common_config, model_path, Some("image"))?
278        } else {
279            OrtInfer::new(model_path, Some("image"))?
280        };
281
282        // Create normalizer (scale to [0, 1] without mean shift).
283        // Images are read in BGR and UVDoc models are trained with BGR order,
284        // so keep color order consistent here.
285        let normalizer = NormalizeImage::with_color_order(
286            Some(1.0 / 255.0),
287            Some(vec![0.0, 0.0, 0.0]),
288            Some(vec![1.0, 1.0, 1.0]),
289            Some(TensorLayout::CHW),
290            Some(crate::processors::types::ColorOrder::BGR),
291        )?;
292
293        // Create postprocessor
294        let postprocessor = UVDocPostProcess::new(255.0);
295
296        Ok(UVDocModel::new(
297            inference,
298            normalizer,
299            postprocessor,
300            self.preprocess_config.rec_image_shape,
301        ))
302    }
303}