Skip to main content

oar_ocr_core/models/rectification/
uvdoc.rs

1//! UVDoc Document Rectification Model
2//!
3//! This module provides a pure implementation of the UVDoc model for document rectification.
4//! The model takes distorted document images and outputs rectified (flattened) versions.
5
6use crate::core::inference::OrtInfer;
7use crate::core::{OCRError, Tensor4D};
8use crate::processors::{NormalizeImage, TensorLayout, UVDocPostProcess};
9use image::{DynamicImage, RgbImage, imageops::FilterType};
10
/// Preprocessing settings for the UVDoc rectification model.
#[derive(Clone, Debug)]
pub struct UVDocPreprocessConfig {
    /// Model input shape, ordered as `[channels, height, width]`.
    pub rec_image_shape: [usize; 3],
}

impl Default for UVDocPreprocessConfig {
    /// Returns a configuration whose input shape is `[3, 512, 512]`.
    fn default() -> Self {
        let rec_image_shape = [3, 512, 512];
        Self { rec_image_shape }
    }
}
25
26/// Output from UVDoc model.
27#[derive(Debug, Clone)]
28pub struct UVDocModelOutput {
29    /// Rectified images
30    pub images: Vec<RgbImage>,
31}
32
33/// Pure UVDoc model implementation.
34///
35/// This model performs document rectification (unwarping) on distorted document images.
36#[derive(Debug)]
37pub struct UVDocModel {
38    /// ONNX Runtime inference engine
39    inference: OrtInfer,
40    /// Image normalizer for preprocessing
41    normalizer: NormalizeImage,
42    /// UVDoc postprocessor for converting tensor to images
43    postprocessor: UVDocPostProcess,
44    /// Input shape [channels, height, width]
45    rec_image_shape: [usize; 3],
46}
47
48impl UVDocModel {
49    /// Creates a new UVDoc model.
50    pub fn new(
51        inference: OrtInfer,
52        normalizer: NormalizeImage,
53        postprocessor: UVDocPostProcess,
54        rec_image_shape: [usize; 3],
55    ) -> Self {
56        Self {
57            inference,
58            normalizer,
59            postprocessor,
60            rec_image_shape,
61        }
62    }
63
64    /// Preprocesses images for rectification.
65    ///
66    /// # Arguments
67    ///
68    /// * `images` - Input images to preprocess
69    ///
70    /// # Returns
71    ///
72    /// A tuple of (batch_tensor, original_sizes)
73    pub fn preprocess(
74        &self,
75        images: Vec<RgbImage>,
76    ) -> Result<(Tensor4D, Vec<(u32, u32)>), OCRError> {
77        let mut original_sizes = Vec::with_capacity(images.len());
78        let mut processed_images = Vec::with_capacity(images.len());
79
80        let target_height = self.rec_image_shape[1] as u32;
81        let target_width = self.rec_image_shape[2] as u32;
82        let should_resize = target_height > 0 && target_width > 0;
83
84        for img in images {
85            let original_size = (img.width(), img.height());
86            original_sizes.push(original_size);
87
88            if should_resize && (img.width() != target_width || img.height() != target_height) {
89                // Use cv2.INTER_LINEAR for UVDoc resize.
90                let resized = DynamicImage::ImageRgb8(img).resize_exact(
91                    target_width,
92                    target_height,
93                    FilterType::Triangle,
94                );
95                processed_images.push(resized);
96            } else {
97                processed_images.push(DynamicImage::ImageRgb8(img));
98            }
99        }
100
101        // Normalize and convert to tensor
102        let batch_tensor = self.normalizer.normalize_batch_to(processed_images)?;
103
104        Ok((batch_tensor, original_sizes))
105    }
106
107    /// Runs inference on the preprocessed batch.
108    ///
109    /// # Arguments
110    ///
111    /// * `batch_tensor` - Preprocessed batch tensor
112    ///
113    /// # Returns
114    ///
115    /// Model predictions as a 4D tensor
116    pub fn infer(&self, batch_tensor: &Tensor4D) -> Result<Tensor4D, OCRError> {
117        self.inference
118            .infer_4d(batch_tensor)
119            .map_err(|e| OCRError::Inference {
120                model_name: "UVDoc".to_string(),
121                context: format!(
122                    "failed to run inference on batch with shape {:?}",
123                    batch_tensor.shape()
124                ),
125                source: Box::new(e),
126            })
127    }
128
129    /// Postprocesses model predictions to rectified images.
130    ///
131    /// # Arguments
132    ///
133    /// * `predictions` - Model predictions
134    /// * `original_sizes` - Original image sizes (width, height)
135    ///
136    /// # Returns
137    ///
138    /// Rectified images resized to original dimensions
139    pub fn postprocess(
140        &self,
141        predictions: &Tensor4D,
142        original_sizes: &[(u32, u32)],
143    ) -> Result<Vec<RgbImage>, OCRError> {
144        // Use UVDocPostProcess to convert tensor to images
145        let mut images =
146            self.postprocessor
147                .apply_batch(predictions)
148                .map_err(|e| OCRError::ConfigError {
149                    message: format!("Failed to postprocess rectification output: {}", e),
150                })?;
151
152        if images.len() != original_sizes.len() {
153            return Err(OCRError::InvalidInput {
154                message: format!(
155                    "Mismatched rectification batch sizes: predictions={}, originals={}",
156                    images.len(),
157                    original_sizes.len()
158                ),
159            });
160        }
161
162        // Resize back to original dimensions
163        for (img, &(orig_w, orig_h)) in images.iter_mut().zip(original_sizes) {
164            if orig_w == 0 || orig_h == 0 {
165                continue;
166            }
167
168            if img.width() != orig_w || img.height() != orig_h {
169                // Use cv2.INTER_LINEAR for resizing outputs back to original size.
170                let resized = DynamicImage::ImageRgb8(std::mem::take(img)).resize_exact(
171                    orig_w,
172                    orig_h,
173                    FilterType::Triangle,
174                );
175                *img = resized.into_rgb8();
176            }
177        }
178
179        Ok(images)
180    }
181
182    /// Performs complete forward pass: preprocess -> infer -> postprocess.
183    ///
184    /// # Arguments
185    ///
186    /// * `images` - Input images to rectify
187    ///
188    /// # Returns
189    ///
190    /// UVDocModelOutput containing rectified images
191    pub fn forward(&self, images: Vec<RgbImage>) -> Result<UVDocModelOutput, OCRError> {
192        let (batch_tensor, original_sizes) = self.preprocess(images)?;
193        let predictions = self.infer(&batch_tensor)?;
194        let rectified_images = self.postprocess(&predictions, &original_sizes)?;
195
196        Ok(UVDocModelOutput {
197            images: rectified_images,
198        })
199    }
200}
201
202/// Builder for UVDoc model.
203#[derive(Debug, Default)]
204pub struct UVDocModelBuilder {
205    /// Preprocessing configuration
206    preprocess_config: UVDocPreprocessConfig,
207    /// ONNX Runtime session configuration
208    ort_config: Option<crate::core::config::OrtSessionConfig>,
209}
210
211impl UVDocModelBuilder {
212    /// Creates a new UVDoc model builder.
213    pub fn new() -> Self {
214        Self {
215            preprocess_config: UVDocPreprocessConfig::default(),
216            ort_config: None,
217        }
218    }
219
220    /// Sets the preprocessing configuration.
221    pub fn preprocess_config(mut self, config: UVDocPreprocessConfig) -> Self {
222        self.preprocess_config = config;
223        self
224    }
225
226    /// Sets the input image shape.
227    pub fn rec_image_shape(mut self, shape: [usize; 3]) -> Self {
228        self.preprocess_config.rec_image_shape = shape;
229        self
230    }
231
232    /// Sets the ONNX Runtime session configuration.
233    pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
234        self.ort_config = Some(config);
235        self
236    }
237
238    /// Builds the UVDoc model.
239    ///
240    /// # Arguments
241    ///
242    /// * `model_path` - Path to the ONNX model file
243    ///
244    /// # Returns
245    ///
246    /// A configured UVDoc model instance
247    pub fn build(self, model_path: &std::path::Path) -> Result<UVDocModel, OCRError> {
248        // Create ONNX inference engine
249        let inference = if self.ort_config.is_some() {
250            use crate::core::config::ModelInferenceConfig;
251            let common_config = ModelInferenceConfig {
252                ort_session: self.ort_config,
253                ..Default::default()
254            };
255            OrtInfer::from_config(&common_config, model_path, Some("image"))?
256        } else {
257            OrtInfer::new(model_path, Some("image"))?
258        };
259
260        // Create normalizer (scale to [0, 1] without mean shift).
261        // Images are read in BGR and UVDoc models are trained with BGR order,
262        // so keep color order consistent here.
263        let normalizer = NormalizeImage::with_color_order(
264            Some(1.0 / 255.0),
265            Some(vec![0.0, 0.0, 0.0]),
266            Some(vec![1.0, 1.0, 1.0]),
267            Some(TensorLayout::CHW),
268            Some(crate::processors::types::ColorOrder::BGR),
269        )?;
270
271        // Create postprocessor
272        let postprocessor = UVDocPostProcess::new(255.0);
273
274        Ok(UVDocModel::new(
275            inference,
276            normalizer,
277            postprocessor,
278            self.preprocess_config.rec_image_shape,
279        ))
280    }
281}