Skip to main content

oar_ocr_core/processors/
uvdoc_postprocess.rs

1//! Document transformation post-processing functionality.
2
3use crate::core::OcrResult;
4use crate::core::errors::OCRError;
5use std::str::FromStr;
6
/// Post-processor for document transformation results.
///
/// The `UVDocPostProcess` struct handles the post-processing of document
/// transformation model outputs, converting normalized coordinates back
/// to pixel coordinates and applying various transformations.
///
/// The type is a plain value holding a single scale factor, so it can be
/// cloned and compared cheaply.
#[derive(Debug, Clone, PartialEq)]
pub struct UVDocPostProcess {
    /// Scale factor to convert normalized values back to pixel values.
    pub scale: f32,
}
17
18impl UVDocPostProcess {
19    /// Creates a new UVDocPostProcess instance.
20    ///
21    /// # Arguments
22    ///
23    /// * `scale` - Scale factor for converting normalized coordinates to pixels.
24    ///
25    /// # Examples
26    ///
27    /// ```rust,no_run
28    /// use oar_ocr_core::processors::UVDocPostProcess;
29    ///
30    /// let postprocessor = UVDocPostProcess::new(1.0);
31    /// ```
32    pub fn new(scale: f32) -> Self {
33        Self { scale }
34    }
35
36    /// Gets the current scale factor.
37    ///
38    /// # Returns
39    ///
40    /// The scale factor used for coordinate conversion.
41    pub fn scale(&self) -> f32 {
42        self.scale
43    }
44
45    /// Sets a new scale factor.
46    ///
47    /// # Arguments
48    ///
49    /// * `scale` - New scale factor.
50    pub fn set_scale(&mut self, scale: f32) {
51        self.scale = scale;
52    }
53
54    /// Converts normalized coordinates to pixel coordinates.
55    ///
56    /// # Arguments
57    ///
58    /// * `normalized_coords` - Vector of normalized coordinates (0.0 to 1.0).
59    ///
60    /// # Returns
61    ///
62    /// * `Vec<f32>` - Vector of pixel coordinates.
63    ///
64    /// # Examples
65    ///
66    /// ```rust,no_run
67    /// use oar_ocr_core::processors::UVDocPostProcess;
68    ///
69    /// let postprocessor = UVDocPostProcess::new(100.0);
70    /// let normalized = vec![0.1, 0.2, 0.8, 0.9];
71    /// let pixels = postprocessor.denormalize_coordinates(&normalized);
72    /// assert_eq!(pixels, vec![10.0, 20.0, 80.0, 90.0]);
73    /// ```
74    pub fn denormalize_coordinates(&self, normalized_coords: &[f32]) -> Vec<f32> {
75        normalized_coords
76            .iter()
77            .map(|&coord| coord * self.scale)
78            .collect()
79    }
80
81    /// Converts pixel coordinates to normalized coordinates.
82    ///
83    /// # Arguments
84    ///
85    /// * `pixel_coords` - Vector of pixel coordinates.
86    ///
87    /// # Returns
88    ///
89    /// * `Vec<f32>` - Vector of normalized coordinates (0.0 to 1.0).
90    ///
91    /// # Examples
92    ///
93    /// ```rust,no_run
94    /// use oar_ocr_core::processors::UVDocPostProcess;
95    ///
96    /// let postprocessor = UVDocPostProcess::new(100.0);
97    /// let pixels = vec![10.0, 20.0, 80.0, 90.0];
98    /// let normalized = postprocessor.normalize_coordinates(&pixels);
99    /// assert_eq!(normalized, vec![0.1, 0.2, 0.8, 0.9]);
100    /// ```
101    pub fn normalize_coordinates(&self, pixel_coords: &[f32]) -> Vec<f32> {
102        if self.scale == 0.0 {
103            return vec![0.0; pixel_coords.len()];
104        }
105        pixel_coords
106            .iter()
107            .map(|&coord| coord / self.scale)
108            .collect()
109    }
110
111    /// Processes a bounding box from normalized to pixel coordinates.
112    ///
113    /// # Arguments
114    ///
115    /// * `bbox` - Bounding box as [x1, y1, x2, y2] in normalized coordinates.
116    ///
117    /// # Returns
118    ///
119    /// * `[f32; 4]` - Bounding box in pixel coordinates.
120    ///
121    /// # Examples
122    ///
123    /// ```rust,no_run
124    /// use oar_ocr_core::processors::UVDocPostProcess;
125    ///
126    /// let postprocessor = UVDocPostProcess::new(100.0);
127    /// let normalized_bbox = [0.1, 0.2, 0.8, 0.9];
128    /// let pixel_bbox = postprocessor.process_bbox(&normalized_bbox);
129    /// assert_eq!(pixel_bbox, [10.0, 20.0, 80.0, 90.0]);
130    /// ```
131    pub fn process_bbox(&self, bbox: &[f32; 4]) -> [f32; 4] {
132        [
133            bbox[0] * self.scale,
134            bbox[1] * self.scale,
135            bbox[2] * self.scale,
136            bbox[3] * self.scale,
137        ]
138    }
139
140    /// Processes multiple bounding boxes.
141    ///
142    /// # Arguments
143    ///
144    /// * `bboxes` - Vector of bounding boxes in normalized coordinates.
145    ///
146    /// # Returns
147    ///
148    /// * `Vec<[f32; 4]>` - Vector of bounding boxes in pixel coordinates.
149    pub fn process_bboxes(&self, bboxes: &[[f32; 4]]) -> Vec<[f32; 4]> {
150        bboxes.iter().map(|bbox| self.process_bbox(bbox)).collect()
151    }
152
153    /// Processes a polygon from normalized to pixel coordinates.
154    ///
155    /// # Arguments
156    ///
157    /// * `polygon` - Vector of points as [x, y] pairs in normalized coordinates.
158    ///
159    /// # Returns
160    ///
161    /// * `Vec<[f32; 2]>` - Vector of points in pixel coordinates.
162    ///
163    /// # Examples
164    ///
165    /// ```rust,no_run
166    /// use oar_ocr_core::processors::UVDocPostProcess;
167    ///
168    /// let postprocessor = UVDocPostProcess::new(100.0);
169    /// let normalized_polygon = vec![[0.1, 0.2], [0.8, 0.2], [0.8, 0.9], [0.1, 0.9]];
170    /// let pixel_polygon = postprocessor.process_polygon(&normalized_polygon);
171    /// assert_eq!(pixel_polygon[0], [10.0, 20.0]);
172    /// ```
173    pub fn process_polygon(&self, polygon: &[[f32; 2]]) -> Vec<[f32; 2]> {
174        polygon
175            .iter()
176            .map(|&[x, y]| [x * self.scale, y * self.scale])
177            .collect()
178    }
179
180    /// Clamps coordinates to valid ranges.
181    ///
182    /// # Arguments
183    ///
184    /// * `coords` - Vector of coordinates to clamp.
185    /// * `min_val` - Minimum allowed value.
186    /// * `max_val` - Maximum allowed value.
187    ///
188    /// # Returns
189    ///
190    /// * `Vec<f32>` - Vector of clamped coordinates.
191    pub fn clamp_coordinates(&self, coords: &[f32], min_val: f32, max_val: f32) -> Vec<f32> {
192        coords
193            .iter()
194            .map(|&coord| coord.clamp(min_val, max_val))
195            .collect()
196    }
197
198    /// Validates that coordinates are within expected ranges.
199    ///
200    /// # Arguments
201    ///
202    /// * `coords` - Vector of coordinates to validate.
203    /// * `min_val` - Minimum expected value.
204    /// * `max_val` - Maximum expected value.
205    ///
206    /// # Returns
207    ///
208    /// * `true` - If all coordinates are within range.
209    /// * `false` - If any coordinate is out of range.
210    pub fn validate_coordinates(&self, coords: &[f32], min_val: f32, max_val: f32) -> bool {
211        coords
212            .iter()
213            .all(|&coord| coord >= min_val && coord <= max_val)
214    }
215
216    /// Rounds coordinates to integer values.
217    ///
218    /// # Arguments
219    ///
220    /// * `coords` - Vector of coordinates to round.
221    ///
222    /// # Returns
223    ///
224    /// * `Vec<i32>` - Vector of rounded integer coordinates.
225    pub fn round_coordinates(&self, coords: &[f32]) -> Vec<i32> {
226        coords.iter().map(|&coord| coord.round() as i32).collect()
227    }
228
229    /// Processes transformation matrix values.
230    ///
231    /// # Arguments
232    ///
233    /// * `matrix` - 3x3 transformation matrix as a flat vector.
234    ///
235    /// # Returns
236    ///
237    /// * `Vec<f32>` - Processed transformation matrix.
238    pub fn process_transformation_matrix(&self, matrix: &[f32; 9]) -> [f32; 9] {
239        // Apply scale to translation components (indices 2 and 5)
240        let mut processed = *matrix;
241        processed[2] *= self.scale; // tx
242        processed[5] *= self.scale; // ty
243        processed
244    }
245
246    /// Applies inverse transformation to coordinates.
247    ///
248    /// # Arguments
249    ///
250    /// * `coords` - Vector of coordinates to transform.
251    /// * `matrix` - 3x3 transformation matrix.
252    ///
253    /// # Returns
254    ///
255    /// * `OcrResult<Vec<[f32; 2]>>` - Transformed coordinates or error.
256    pub fn apply_inverse_transform(
257        &self,
258        coords: &[[f32; 2]],
259        matrix: &[f32; 9],
260    ) -> OcrResult<Vec<[f32; 2]>> {
261        // Calculate determinant for matrix inversion
262        let det = matrix[0] * (matrix[4] * matrix[8] - matrix[5] * matrix[7])
263            - matrix[1] * (matrix[3] * matrix[8] - matrix[5] * matrix[6])
264            + matrix[2] * (matrix[3] * matrix[7] - matrix[4] * matrix[6]);
265
266        if det.abs() < f32::EPSILON {
267            return Err(OCRError::InvalidInput {
268                message: "Matrix is not invertible (determinant is zero)".to_string(),
269            });
270        }
271
272        // For simplicity, this is a basic implementation
273        // In practice, you might want to use a proper matrix library
274        let mut transformed = Vec::new();
275        for &[x, y] in coords {
276            // Apply inverse transformation (simplified)
277            let new_x = (x - matrix[2]) / matrix[0];
278            let new_y = (y - matrix[5]) / matrix[4];
279            transformed.push([new_x, new_y]);
280        }
281
282        Ok(transformed)
283    }
284
285    /// Applies batch processing to tensor output to produce rectified images.
286    ///
287    /// # Arguments
288    ///
289    /// * `output` - 4D tensor output from the model [batch, channels, height, width].
290    ///
291    /// # Returns
292    ///
293    /// * `OcrResult<Vec<image::RgbImage>>` - Vector of rectified images or error.
294    pub fn apply_batch(&self, output: &crate::core::Tensor4D) -> OcrResult<Vec<image::RgbImage>> {
295        use image::{Rgb, RgbImage};
296
297        let shape = output.shape();
298        if shape.len() != 4 {
299            return Err(OCRError::InvalidInput {
300                message: "Expected 4D tensor [batch, channels, height, width]".to_string(),
301            });
302        }
303
304        let batch_size = shape[0];
305        let channels = shape[1];
306        let height = shape[2];
307        let width = shape[3];
308
309        if channels != 3 {
310            return Err(OCRError::InvalidInput {
311                message: "Expected 3 channels (RGB)".to_string(),
312            });
313        }
314
315        let mut images = Vec::with_capacity(batch_size);
316
317        let scale = self.scale;
318
319        for b in 0..batch_size {
320            let mut img = RgbImage::new(width as u32, height as u32);
321
322            for y in 0..height {
323                for x in 0..width {
324                    // Model outputs are in BGR order; convert back to RGB.
325                    let b_val = (output[[b, 0, y, x]] * scale).clamp(0.0, 255.0) as u8;
326                    let g_val = (output[[b, 1, y, x]] * scale).clamp(0.0, 255.0) as u8;
327                    let r_val = (output[[b, 2, y, x]] * scale).clamp(0.0, 255.0) as u8;
328
329                    img.put_pixel(x as u32, y as u32, Rgb([r_val, g_val, b_val]));
330                }
331            }
332
333            images.push(img);
334        }
335
336        Ok(images)
337    }
338}
339
340impl Default for UVDocPostProcess {
341    /// Creates a default UVDocPostProcess with scale factor 1.0.
342    fn default() -> Self {
343        Self::new(1.0)
344    }
345}
346
347impl FromStr for UVDocPostProcess {
348    type Err = std::num::ParseFloatError;
349
350    /// Creates a UVDocPostProcess from a string representation of the scale factor.
351    ///
352    /// # Arguments
353    ///
354    /// * `s` - String representation of the scale factor.
355    ///
356    /// # Returns
357    ///
358    /// * `Ok(UVDocPostProcess)` - If the string can be parsed as a float.
359    /// * `Err(ParseFloatError)` - If the string cannot be parsed.
360    fn from_str(s: &str) -> Result<Self, Self::Err> {
361        let scale = s.parse::<f32>()?;
362        Ok(Self::new(scale))
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Scale used by the normalized <-> pixel conversion tests.
    const PIXEL_SCALE: f32 = 100.0;

    /// Scaling normalized values up yields the expected pixel values.
    #[test]
    fn test_denormalize_coordinates() {
        let processor = UVDocPostProcess::new(PIXEL_SCALE);
        let result = processor.denormalize_coordinates(&[0.1, 0.2, 0.8, 0.9]);
        assert_eq!(result, vec![10.0, 20.0, 80.0, 90.0]);
    }

    /// Dividing pixel values by the scale yields the normalized values.
    #[test]
    fn test_normalize_coordinates() {
        let processor = UVDocPostProcess::new(PIXEL_SCALE);
        let result = processor.normalize_coordinates(&[10.0, 20.0, 80.0, 90.0]);
        assert_eq!(result, vec![0.1, 0.2, 0.8, 0.9]);
    }

    /// All four corners of a bounding box are scaled.
    #[test]
    fn test_process_bbox() {
        let processor = UVDocPostProcess::new(PIXEL_SCALE);
        let scaled = processor.process_bbox(&[0.1, 0.2, 0.8, 0.9]);
        assert_eq!(scaled, [10.0, 20.0, 80.0, 90.0]);
    }

    /// Polygon points are scaled component-wise and keep their order.
    #[test]
    fn test_process_polygon() {
        let processor = UVDocPostProcess::new(PIXEL_SCALE);
        let corners = [[0.1, 0.2], [0.8, 0.2], [0.8, 0.9], [0.1, 0.9]];
        let scaled = processor.process_polygon(&corners);
        assert_eq!(scaled[0], [10.0, 20.0]);
        assert_eq!(scaled[1], [80.0, 20.0]);
    }

    /// Out-of-range values are pulled back to the nearest bound.
    #[test]
    fn test_clamp_coordinates() {
        let processor = UVDocPostProcess::new(1.0);
        let result = processor.clamp_coordinates(&[-10.0, 50.0, 150.0], 0.0, 100.0);
        assert_eq!(result, vec![0.0, 50.0, 100.0]);
    }

    /// Validation succeeds only when every coordinate is inside the range.
    #[test]
    fn test_validate_coordinates() {
        let processor = UVDocPostProcess::new(1.0);
        let inside = [10.0, 50.0, 90.0];
        let outside = [10.0, 150.0, 90.0];

        assert!(processor.validate_coordinates(&inside, 0.0, 100.0));
        assert!(!processor.validate_coordinates(&outside, 0.0, 100.0));
    }

    /// Rounding goes to the nearest integer, halves away from zero.
    #[test]
    fn test_round_coordinates() {
        let processor = UVDocPostProcess::new(1.0);
        let result = processor.round_coordinates(&[10.3, 20.7, 30.5]);
        assert_eq!(result, vec![10, 21, 31]);
    }

    /// A numeric string parses into a processor; garbage is rejected.
    #[test]
    fn test_from_str() -> Result<(), std::num::ParseFloatError> {
        let parsed = "2.5".parse::<UVDocPostProcess>()?;
        assert_eq!(parsed.scale(), 2.5);

        let rejected = "invalid".parse::<UVDocPostProcess>();
        assert!(rejected.is_err());
        Ok(())
    }

    /// A zero scale must not divide; zeros are returned instead.
    #[test]
    fn test_zero_scale_normalize() {
        let processor = UVDocPostProcess::new(0.0);
        let result = processor.normalize_coordinates(&[10.0, 20.0]);
        assert_eq!(result, vec![0.0, 0.0]);
    }
}