// Source: oar_ocr/processors/postprocess/doctr.rs

//! Post-processing for document transformation (DocTr) model outputs.
2
3use std::str::FromStr;
4
/// Post-processor for document transformation (DocTr) model outputs.
///
/// Converts between normalized and pixel coordinates using a single scale
/// factor, and turns the model's output tensor into RGB images.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DocTrPostProcess {
    /// Scale factor: normalized values are multiplied by it to obtain pixel
    /// values, and pixel values are divided by it to obtain normalized ones.
    pub scale: f32,
}
15
16impl DocTrPostProcess {
17    /// Creates a new DocTrPostProcess instance.
18    ///
19    /// # Arguments
20    ///
21    /// * `scale` - Scale factor for converting normalized coordinates to pixels.
22    ///
23    /// # Examples
24    ///
25    /// ```rust,no_run
26    /// use oar_ocr::processors::DocTrPostProcess;
27    ///
28    /// let postprocessor = DocTrPostProcess::new(1.0);
29    /// ```
30    pub fn new(scale: f32) -> Self {
31        Self { scale }
32    }
33
34    /// Gets the current scale factor.
35    ///
36    /// # Returns
37    ///
38    /// The scale factor used for coordinate conversion.
39    pub fn scale(&self) -> f32 {
40        self.scale
41    }
42
43    /// Sets a new scale factor.
44    ///
45    /// # Arguments
46    ///
47    /// * `scale` - New scale factor.
48    pub fn set_scale(&mut self, scale: f32) {
49        self.scale = scale;
50    }
51
52    /// Converts normalized coordinates to pixel coordinates.
53    ///
54    /// # Arguments
55    ///
56    /// * `normalized_coords` - Vector of normalized coordinates (0.0 to 1.0).
57    ///
58    /// # Returns
59    ///
60    /// * `Vec<f32>` - Vector of pixel coordinates.
61    ///
62    /// # Examples
63    ///
64    /// ```rust,no_run
65    /// use oar_ocr::processors::DocTrPostProcess;
66    ///
67    /// let postprocessor = DocTrPostProcess::new(100.0);
68    /// let normalized = vec![0.1, 0.2, 0.8, 0.9];
69    /// let pixels = postprocessor.denormalize_coordinates(&normalized);
70    /// assert_eq!(pixels, vec![10.0, 20.0, 80.0, 90.0]);
71    /// ```
72    pub fn denormalize_coordinates(&self, normalized_coords: &[f32]) -> Vec<f32> {
73        normalized_coords
74            .iter()
75            .map(|&coord| coord * self.scale)
76            .collect()
77    }
78
79    /// Converts pixel coordinates to normalized coordinates.
80    ///
81    /// # Arguments
82    ///
83    /// * `pixel_coords` - Vector of pixel coordinates.
84    ///
85    /// # Returns
86    ///
87    /// * `Vec<f32>` - Vector of normalized coordinates (0.0 to 1.0).
88    ///
89    /// # Examples
90    ///
91    /// ```rust,no_run
92    /// use oar_ocr::processors::DocTrPostProcess;
93    ///
94    /// let postprocessor = DocTrPostProcess::new(100.0);
95    /// let pixels = vec![10.0, 20.0, 80.0, 90.0];
96    /// let normalized = postprocessor.normalize_coordinates(&pixels);
97    /// assert_eq!(normalized, vec![0.1, 0.2, 0.8, 0.9]);
98    /// ```
99    pub fn normalize_coordinates(&self, pixel_coords: &[f32]) -> Vec<f32> {
100        if self.scale == 0.0 {
101            return vec![0.0; pixel_coords.len()];
102        }
103        pixel_coords
104            .iter()
105            .map(|&coord| coord / self.scale)
106            .collect()
107    }
108
109    /// Processes a bounding box from normalized to pixel coordinates.
110    ///
111    /// # Arguments
112    ///
113    /// * `bbox` - Bounding box as [x1, y1, x2, y2] in normalized coordinates.
114    ///
115    /// # Returns
116    ///
117    /// * `[f32; 4]` - Bounding box in pixel coordinates.
118    ///
119    /// # Examples
120    ///
121    /// ```rust,no_run
122    /// use oar_ocr::processors::DocTrPostProcess;
123    ///
124    /// let postprocessor = DocTrPostProcess::new(100.0);
125    /// let normalized_bbox = [0.1, 0.2, 0.8, 0.9];
126    /// let pixel_bbox = postprocessor.process_bbox(&normalized_bbox);
127    /// assert_eq!(pixel_bbox, [10.0, 20.0, 80.0, 90.0]);
128    /// ```
129    pub fn process_bbox(&self, bbox: &[f32; 4]) -> [f32; 4] {
130        [
131            bbox[0] * self.scale,
132            bbox[1] * self.scale,
133            bbox[2] * self.scale,
134            bbox[3] * self.scale,
135        ]
136    }
137
138    /// Processes multiple bounding boxes.
139    ///
140    /// # Arguments
141    ///
142    /// * `bboxes` - Vector of bounding boxes in normalized coordinates.
143    ///
144    /// # Returns
145    ///
146    /// * `Vec<[f32; 4]>` - Vector of bounding boxes in pixel coordinates.
147    pub fn process_bboxes(&self, bboxes: &[[f32; 4]]) -> Vec<[f32; 4]> {
148        bboxes.iter().map(|bbox| self.process_bbox(bbox)).collect()
149    }
150
151    /// Processes a polygon from normalized to pixel coordinates.
152    ///
153    /// # Arguments
154    ///
155    /// * `polygon` - Vector of points as [x, y] pairs in normalized coordinates.
156    ///
157    /// # Returns
158    ///
159    /// * `Vec<[f32; 2]>` - Vector of points in pixel coordinates.
160    ///
161    /// # Examples
162    ///
163    /// ```rust,no_run
164    /// use oar_ocr::processors::DocTrPostProcess;
165    ///
166    /// let postprocessor = DocTrPostProcess::new(100.0);
167    /// let normalized_polygon = vec![[0.1, 0.2], [0.8, 0.2], [0.8, 0.9], [0.1, 0.9]];
168    /// let pixel_polygon = postprocessor.process_polygon(&normalized_polygon);
169    /// assert_eq!(pixel_polygon[0], [10.0, 20.0]);
170    /// ```
171    pub fn process_polygon(&self, polygon: &[[f32; 2]]) -> Vec<[f32; 2]> {
172        polygon
173            .iter()
174            .map(|&[x, y]| [x * self.scale, y * self.scale])
175            .collect()
176    }
177
178    /// Clamps coordinates to valid ranges.
179    ///
180    /// # Arguments
181    ///
182    /// * `coords` - Vector of coordinates to clamp.
183    /// * `min_val` - Minimum allowed value.
184    /// * `max_val` - Maximum allowed value.
185    ///
186    /// # Returns
187    ///
188    /// * `Vec<f32>` - Vector of clamped coordinates.
189    pub fn clamp_coordinates(&self, coords: &[f32], min_val: f32, max_val: f32) -> Vec<f32> {
190        coords
191            .iter()
192            .map(|&coord| coord.clamp(min_val, max_val))
193            .collect()
194    }
195
196    /// Validates that coordinates are within expected ranges.
197    ///
198    /// # Arguments
199    ///
200    /// * `coords` - Vector of coordinates to validate.
201    /// * `min_val` - Minimum expected value.
202    /// * `max_val` - Maximum expected value.
203    ///
204    /// # Returns
205    ///
206    /// * `true` - If all coordinates are within range.
207    /// * `false` - If any coordinate is out of range.
208    pub fn validate_coordinates(&self, coords: &[f32], min_val: f32, max_val: f32) -> bool {
209        coords
210            .iter()
211            .all(|&coord| coord >= min_val && coord <= max_val)
212    }
213
214    /// Rounds coordinates to integer values.
215    ///
216    /// # Arguments
217    ///
218    /// * `coords` - Vector of coordinates to round.
219    ///
220    /// # Returns
221    ///
222    /// * `Vec<i32>` - Vector of rounded integer coordinates.
223    pub fn round_coordinates(&self, coords: &[f32]) -> Vec<i32> {
224        coords.iter().map(|&coord| coord.round() as i32).collect()
225    }
226
227    /// Processes transformation matrix values.
228    ///
229    /// # Arguments
230    ///
231    /// * `matrix` - 3x3 transformation matrix as a flat vector.
232    ///
233    /// # Returns
234    ///
235    /// * `Vec<f32>` - Processed transformation matrix.
236    pub fn process_transformation_matrix(&self, matrix: &[f32; 9]) -> [f32; 9] {
237        // Apply scale to translation components (indices 2 and 5)
238        let mut processed = *matrix;
239        processed[2] *= self.scale; // tx
240        processed[5] *= self.scale; // ty
241        processed
242    }
243
244    /// Applies inverse transformation to coordinates.
245    ///
246    /// # Arguments
247    ///
248    /// * `coords` - Vector of coordinates to transform.
249    /// * `matrix` - 3x3 transformation matrix.
250    ///
251    /// # Returns
252    ///
253    /// * `Result<Vec<[f32; 2]>, String>` - Transformed coordinates or error.
254    pub fn apply_inverse_transform(
255        &self,
256        coords: &[[f32; 2]],
257        matrix: &[f32; 9],
258    ) -> Result<Vec<[f32; 2]>, String> {
259        // Calculate determinant for matrix inversion
260        let det = matrix[0] * (matrix[4] * matrix[8] - matrix[5] * matrix[7])
261            - matrix[1] * (matrix[3] * matrix[8] - matrix[5] * matrix[6])
262            + matrix[2] * (matrix[3] * matrix[7] - matrix[4] * matrix[6]);
263
264        if det.abs() < f32::EPSILON {
265            return Err("Matrix is not invertible (determinant is zero)".to_string());
266        }
267
268        // For simplicity, this is a basic implementation
269        // In practice, you might want to use a proper matrix library
270        let mut transformed = Vec::new();
271        for &[x, y] in coords {
272            // Apply inverse transformation (simplified)
273            let new_x = (x - matrix[2]) / matrix[0];
274            let new_y = (y - matrix[5]) / matrix[4];
275            transformed.push([new_x, new_y]);
276        }
277
278        Ok(transformed)
279    }
280
281    /// Applies batch processing to tensor output to produce rectified images.
282    ///
283    /// # Arguments
284    ///
285    /// * `output` - 4D tensor output from the model [batch, channels, height, width].
286    ///
287    /// # Returns
288    ///
289    /// * `Result<Vec<image::RgbImage>, String>` - Vector of rectified images or error.
290    pub fn apply_batch(
291        &self,
292        output: &crate::core::Tensor4D,
293    ) -> Result<Vec<image::RgbImage>, String> {
294        use image::{Rgb, RgbImage};
295
296        let shape = output.shape();
297        if shape.len() != 4 {
298            return Err("Expected 4D tensor [batch, channels, height, width]".to_string());
299        }
300
301        let batch_size = shape[0];
302        let channels = shape[1];
303        let height = shape[2];
304        let width = shape[3];
305
306        if channels != 3 {
307            return Err("Expected 3 channels (RGB)".to_string());
308        }
309
310        let mut images = Vec::with_capacity(batch_size);
311
312        for b in 0..batch_size {
313            let mut img = RgbImage::new(width as u32, height as u32);
314
315            for y in 0..height {
316                for x in 0..width {
317                    // Extract RGB values and denormalize
318                    let r = (output[[b, 0, y, x]] * 255.0).clamp(0.0, 255.0) as u8;
319                    let g = (output[[b, 1, y, x]] * 255.0).clamp(0.0, 255.0) as u8;
320                    let b_val = (output[[b, 2, y, x]] * 255.0).clamp(0.0, 255.0) as u8;
321
322                    img.put_pixel(x as u32, y as u32, Rgb([r, g, b_val]));
323                }
324            }
325
326            images.push(img);
327        }
328
329        Ok(images)
330    }
331}
332
333impl Default for DocTrPostProcess {
334    /// Creates a default DocTrPostProcess with scale factor 1.0.
335    fn default() -> Self {
336        Self::new(1.0)
337    }
338}
339
340impl FromStr for DocTrPostProcess {
341    type Err = std::num::ParseFloatError;
342
343    /// Creates a DocTrPostProcess from a string representation of the scale factor.
344    ///
345    /// # Arguments
346    ///
347    /// * `s` - String representation of the scale factor.
348    ///
349    /// # Returns
350    ///
351    /// * `Ok(DocTrPostProcess)` - If the string can be parsed as a float.
352    /// * `Err(ParseFloatError)` - If the string cannot be parsed.
353    fn from_str(s: &str) -> Result<Self, Self::Err> {
354        let scale = s.parse::<f32>()?;
355        Ok(Self::new(scale))
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_denormalize_coordinates() {
        let postprocessor = DocTrPostProcess::new(100.0);
        let normalized = vec![0.1, 0.2, 0.8, 0.9];
        let pixels = postprocessor.denormalize_coordinates(&normalized);
        assert_eq!(pixels, vec![10.0, 20.0, 80.0, 90.0]);
    }

    #[test]
    fn test_denormalize_empty() {
        let postprocessor = DocTrPostProcess::new(100.0);
        assert!(postprocessor.denormalize_coordinates(&[]).is_empty());
    }

    #[test]
    fn test_normalize_coordinates() {
        let postprocessor = DocTrPostProcess::new(100.0);
        let pixels = vec![10.0, 20.0, 80.0, 90.0];
        let normalized = postprocessor.normalize_coordinates(&pixels);
        assert_eq!(normalized, vec![0.1, 0.2, 0.8, 0.9]);
    }

    #[test]
    fn test_process_bbox() {
        let postprocessor = DocTrPostProcess::new(100.0);
        let normalized_bbox = [0.1, 0.2, 0.8, 0.9];
        let pixel_bbox = postprocessor.process_bbox(&normalized_bbox);
        assert_eq!(pixel_bbox, [10.0, 20.0, 80.0, 90.0]);
    }

    #[test]
    fn test_process_bboxes() {
        let postprocessor = DocTrPostProcess::new(10.0);
        let bboxes = [[0.0, 0.5, 1.0, 1.5], [0.25, 0.25, 0.75, 0.75]];
        assert_eq!(
            postprocessor.process_bboxes(&bboxes),
            vec![[0.0, 5.0, 10.0, 15.0], [2.5, 2.5, 7.5, 7.5]]
        );
        assert!(postprocessor.process_bboxes(&[]).is_empty());
    }

    #[test]
    fn test_process_polygon() {
        let postprocessor = DocTrPostProcess::new(100.0);
        let normalized_polygon = vec![[0.1, 0.2], [0.8, 0.2], [0.8, 0.9], [0.1, 0.9]];
        let pixel_polygon = postprocessor.process_polygon(&normalized_polygon);
        assert_eq!(pixel_polygon[0], [10.0, 20.0]);
        assert_eq!(pixel_polygon[1], [80.0, 20.0]);
    }

    #[test]
    fn test_clamp_coordinates() {
        let postprocessor = DocTrPostProcess::new(1.0);
        let coords = vec![-10.0, 50.0, 150.0];
        let clamped = postprocessor.clamp_coordinates(&coords, 0.0, 100.0);
        assert_eq!(clamped, vec![0.0, 50.0, 100.0]);
    }

    #[test]
    fn test_validate_coordinates() {
        let postprocessor = DocTrPostProcess::new(1.0);
        let valid_coords = vec![10.0, 50.0, 90.0];
        let invalid_coords = vec![10.0, 150.0, 90.0];

        assert!(postprocessor.validate_coordinates(&valid_coords, 0.0, 100.0));
        assert!(!postprocessor.validate_coordinates(&invalid_coords, 0.0, 100.0));
    }

    #[test]
    fn test_validate_coordinates_bounds_and_nan() {
        let postprocessor = DocTrPostProcess::new(1.0);
        // Bounds are inclusive.
        assert!(postprocessor.validate_coordinates(&[0.0, 100.0], 0.0, 100.0));
        // NaN never compares as in-range.
        assert!(!postprocessor.validate_coordinates(&[f32::NAN], 0.0, 100.0));
        // Vacuously valid.
        assert!(postprocessor.validate_coordinates(&[], 0.0, 100.0));
    }

    #[test]
    fn test_round_coordinates() {
        let postprocessor = DocTrPostProcess::new(1.0);
        let coords = vec![10.3, 20.7, 30.5];
        let rounded = postprocessor.round_coordinates(&coords);
        assert_eq!(rounded, vec![10, 21, 31]);
    }

    #[test]
    fn test_round_coordinates_saturates() {
        let postprocessor = DocTrPostProcess::new(1.0);
        let coords = [f32::NAN, f32::INFINITY, f32::NEG_INFINITY];
        assert_eq!(
            postprocessor.round_coordinates(&coords),
            vec![0, i32::MAX, i32::MIN]
        );
    }

    #[test]
    fn test_process_transformation_matrix_scales_translation_only() {
        let postprocessor = DocTrPostProcess::new(2.0);
        let matrix = [1.0, 0.5, 3.0, 0.25, 1.0, 4.0, 0.0, 0.0, 1.0];
        assert_eq!(
            postprocessor.process_transformation_matrix(&matrix),
            [1.0, 0.5, 6.0, 0.25, 1.0, 8.0, 0.0, 0.0, 1.0]
        );
    }

    #[test]
    fn test_apply_inverse_transform_axis_aligned() {
        let postprocessor = DocTrPostProcess::default();
        // Scale x by 2, y by 4, then translate by (10, 20).
        let matrix = [2.0, 0.0, 10.0, 0.0, 4.0, 20.0, 0.0, 0.0, 1.0];
        let result = postprocessor
            .apply_inverse_transform(&[[30.0, 60.0]], &matrix)
            .unwrap();
        assert_eq!(result, vec![[10.0, 10.0]]);
    }

    #[test]
    fn test_apply_inverse_transform_singular() {
        let postprocessor = DocTrPostProcess::default();
        // The second row is twice the first, so the determinant is zero.
        let singular = [1.0, 2.0, 0.0, 2.0, 4.0, 0.0, 0.0, 0.0, 1.0];
        assert!(postprocessor
            .apply_inverse_transform(&[[1.0, 1.0]], &singular)
            .is_err());
    }

    #[test]
    fn test_apply_inverse_transform_empty_coords() {
        let postprocessor = DocTrPostProcess::default();
        let identity = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let result = postprocessor.apply_inverse_transform(&[], &identity).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_default_and_set_scale() {
        let mut postprocessor = DocTrPostProcess::default();
        assert_eq!(postprocessor.scale(), 1.0);
        postprocessor.set_scale(3.0);
        assert_eq!(postprocessor.scale(), 3.0);
    }

    #[test]
    fn test_from_str() {
        let postprocessor: DocTrPostProcess = "2.5".parse().unwrap();
        assert_eq!(postprocessor.scale(), 2.5);

        assert!("invalid".parse::<DocTrPostProcess>().is_err());
    }

    #[test]
    fn test_zero_scale_normalize() {
        let postprocessor = DocTrPostProcess::new(0.0);
        let pixels = vec![10.0, 20.0];
        let normalized = postprocessor.normalize_coordinates(&pixels);
        assert_eq!(normalized, vec![0.0, 0.0]);
    }
}