oar_ocr_core/processors/uvdoc_postprocess.rs
1//! Document transformation post-processing functionality.
2
3use crate::core::OcrResult;
4use crate::core::errors::OCRError;
5use std::str::FromStr;
6
/// Post-processor for document transformation results.
///
/// The `UVDocPostProcess` struct handles the post-processing of document
/// transformation model outputs, converting normalized coordinates back
/// to pixel coordinates and applying various transformations.
///
/// The single `scale` factor is used both for coordinate conversion and, in
/// `apply_batch`, to rescale raw model output values before they are clamped
/// to the 0–255 pixel range.
///
/// The type is a small plain-data configuration value, so it is `Copy` and
/// can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct UVDocPostProcess {
    /// Scale factor to convert normalized values back to pixel values.
    pub scale: f32,
}
17
18impl UVDocPostProcess {
19 /// Creates a new UVDocPostProcess instance.
20 ///
21 /// # Arguments
22 ///
23 /// * `scale` - Scale factor for converting normalized coordinates to pixels.
24 ///
25 /// # Examples
26 ///
27 /// ```rust,no_run
28 /// use oar_ocr_core::processors::UVDocPostProcess;
29 ///
30 /// let postprocessor = UVDocPostProcess::new(1.0);
31 /// ```
32 pub fn new(scale: f32) -> Self {
33 Self { scale }
34 }
35
36 /// Gets the current scale factor.
37 ///
38 /// # Returns
39 ///
40 /// The scale factor used for coordinate conversion.
41 pub fn scale(&self) -> f32 {
42 self.scale
43 }
44
45 /// Sets a new scale factor.
46 ///
47 /// # Arguments
48 ///
49 /// * `scale` - New scale factor.
50 pub fn set_scale(&mut self, scale: f32) {
51 self.scale = scale;
52 }
53
54 /// Converts normalized coordinates to pixel coordinates.
55 ///
56 /// # Arguments
57 ///
58 /// * `normalized_coords` - Vector of normalized coordinates (0.0 to 1.0).
59 ///
60 /// # Returns
61 ///
62 /// * `Vec<f32>` - Vector of pixel coordinates.
63 ///
64 /// # Examples
65 ///
66 /// ```rust,no_run
67 /// use oar_ocr_core::processors::UVDocPostProcess;
68 ///
69 /// let postprocessor = UVDocPostProcess::new(100.0);
70 /// let normalized = vec![0.1, 0.2, 0.8, 0.9];
71 /// let pixels = postprocessor.denormalize_coordinates(&normalized);
72 /// assert_eq!(pixels, vec![10.0, 20.0, 80.0, 90.0]);
73 /// ```
74 pub fn denormalize_coordinates(&self, normalized_coords: &[f32]) -> Vec<f32> {
75 normalized_coords
76 .iter()
77 .map(|&coord| coord * self.scale)
78 .collect()
79 }
80
81 /// Converts pixel coordinates to normalized coordinates.
82 ///
83 /// # Arguments
84 ///
85 /// * `pixel_coords` - Vector of pixel coordinates.
86 ///
87 /// # Returns
88 ///
89 /// * `Vec<f32>` - Vector of normalized coordinates (0.0 to 1.0).
90 ///
91 /// # Examples
92 ///
93 /// ```rust,no_run
94 /// use oar_ocr_core::processors::UVDocPostProcess;
95 ///
96 /// let postprocessor = UVDocPostProcess::new(100.0);
97 /// let pixels = vec![10.0, 20.0, 80.0, 90.0];
98 /// let normalized = postprocessor.normalize_coordinates(&pixels);
99 /// assert_eq!(normalized, vec![0.1, 0.2, 0.8, 0.9]);
100 /// ```
101 pub fn normalize_coordinates(&self, pixel_coords: &[f32]) -> Vec<f32> {
102 if self.scale == 0.0 {
103 return vec![0.0; pixel_coords.len()];
104 }
105 pixel_coords
106 .iter()
107 .map(|&coord| coord / self.scale)
108 .collect()
109 }
110
111 /// Processes a bounding box from normalized to pixel coordinates.
112 ///
113 /// # Arguments
114 ///
115 /// * `bbox` - Bounding box as [x1, y1, x2, y2] in normalized coordinates.
116 ///
117 /// # Returns
118 ///
119 /// * `[f32; 4]` - Bounding box in pixel coordinates.
120 ///
121 /// # Examples
122 ///
123 /// ```rust,no_run
124 /// use oar_ocr_core::processors::UVDocPostProcess;
125 ///
126 /// let postprocessor = UVDocPostProcess::new(100.0);
127 /// let normalized_bbox = [0.1, 0.2, 0.8, 0.9];
128 /// let pixel_bbox = postprocessor.process_bbox(&normalized_bbox);
129 /// assert_eq!(pixel_bbox, [10.0, 20.0, 80.0, 90.0]);
130 /// ```
131 pub fn process_bbox(&self, bbox: &[f32; 4]) -> [f32; 4] {
132 [
133 bbox[0] * self.scale,
134 bbox[1] * self.scale,
135 bbox[2] * self.scale,
136 bbox[3] * self.scale,
137 ]
138 }
139
140 /// Processes multiple bounding boxes.
141 ///
142 /// # Arguments
143 ///
144 /// * `bboxes` - Vector of bounding boxes in normalized coordinates.
145 ///
146 /// # Returns
147 ///
148 /// * `Vec<[f32; 4]>` - Vector of bounding boxes in pixel coordinates.
149 pub fn process_bboxes(&self, bboxes: &[[f32; 4]]) -> Vec<[f32; 4]> {
150 bboxes.iter().map(|bbox| self.process_bbox(bbox)).collect()
151 }
152
153 /// Processes a polygon from normalized to pixel coordinates.
154 ///
155 /// # Arguments
156 ///
157 /// * `polygon` - Vector of points as [x, y] pairs in normalized coordinates.
158 ///
159 /// # Returns
160 ///
161 /// * `Vec<[f32; 2]>` - Vector of points in pixel coordinates.
162 ///
163 /// # Examples
164 ///
165 /// ```rust,no_run
166 /// use oar_ocr_core::processors::UVDocPostProcess;
167 ///
168 /// let postprocessor = UVDocPostProcess::new(100.0);
169 /// let normalized_polygon = vec![[0.1, 0.2], [0.8, 0.2], [0.8, 0.9], [0.1, 0.9]];
170 /// let pixel_polygon = postprocessor.process_polygon(&normalized_polygon);
171 /// assert_eq!(pixel_polygon[0], [10.0, 20.0]);
172 /// ```
173 pub fn process_polygon(&self, polygon: &[[f32; 2]]) -> Vec<[f32; 2]> {
174 polygon
175 .iter()
176 .map(|&[x, y]| [x * self.scale, y * self.scale])
177 .collect()
178 }
179
180 /// Clamps coordinates to valid ranges.
181 ///
182 /// # Arguments
183 ///
184 /// * `coords` - Vector of coordinates to clamp.
185 /// * `min_val` - Minimum allowed value.
186 /// * `max_val` - Maximum allowed value.
187 ///
188 /// # Returns
189 ///
190 /// * `Vec<f32>` - Vector of clamped coordinates.
191 pub fn clamp_coordinates(&self, coords: &[f32], min_val: f32, max_val: f32) -> Vec<f32> {
192 coords
193 .iter()
194 .map(|&coord| coord.clamp(min_val, max_val))
195 .collect()
196 }
197
198 /// Validates that coordinates are within expected ranges.
199 ///
200 /// # Arguments
201 ///
202 /// * `coords` - Vector of coordinates to validate.
203 /// * `min_val` - Minimum expected value.
204 /// * `max_val` - Maximum expected value.
205 ///
206 /// # Returns
207 ///
208 /// * `true` - If all coordinates are within range.
209 /// * `false` - If any coordinate is out of range.
210 pub fn validate_coordinates(&self, coords: &[f32], min_val: f32, max_val: f32) -> bool {
211 coords
212 .iter()
213 .all(|&coord| coord >= min_val && coord <= max_val)
214 }
215
216 /// Rounds coordinates to integer values.
217 ///
218 /// # Arguments
219 ///
220 /// * `coords` - Vector of coordinates to round.
221 ///
222 /// # Returns
223 ///
224 /// * `Vec<i32>` - Vector of rounded integer coordinates.
225 pub fn round_coordinates(&self, coords: &[f32]) -> Vec<i32> {
226 coords.iter().map(|&coord| coord.round() as i32).collect()
227 }
228
229 /// Processes transformation matrix values.
230 ///
231 /// # Arguments
232 ///
233 /// * `matrix` - 3x3 transformation matrix as a flat vector.
234 ///
235 /// # Returns
236 ///
237 /// * `Vec<f32>` - Processed transformation matrix.
238 pub fn process_transformation_matrix(&self, matrix: &[f32; 9]) -> [f32; 9] {
239 // Apply scale to translation components (indices 2 and 5)
240 let mut processed = *matrix;
241 processed[2] *= self.scale; // tx
242 processed[5] *= self.scale; // ty
243 processed
244 }
245
246 /// Applies inverse transformation to coordinates.
247 ///
248 /// # Arguments
249 ///
250 /// * `coords` - Vector of coordinates to transform.
251 /// * `matrix` - 3x3 transformation matrix.
252 ///
253 /// # Returns
254 ///
255 /// * `OcrResult<Vec<[f32; 2]>>` - Transformed coordinates or error.
256 pub fn apply_inverse_transform(
257 &self,
258 coords: &[[f32; 2]],
259 matrix: &[f32; 9],
260 ) -> OcrResult<Vec<[f32; 2]>> {
261 // Calculate determinant for matrix inversion
262 let det = matrix[0] * (matrix[4] * matrix[8] - matrix[5] * matrix[7])
263 - matrix[1] * (matrix[3] * matrix[8] - matrix[5] * matrix[6])
264 + matrix[2] * (matrix[3] * matrix[7] - matrix[4] * matrix[6]);
265
266 if det.abs() < f32::EPSILON {
267 return Err(OCRError::InvalidInput {
268 message: "Matrix is not invertible (determinant is zero)".to_string(),
269 });
270 }
271
272 // For simplicity, this is a basic implementation
273 // In practice, you might want to use a proper matrix library
274 let mut transformed = Vec::new();
275 for &[x, y] in coords {
276 // Apply inverse transformation (simplified)
277 let new_x = (x - matrix[2]) / matrix[0];
278 let new_y = (y - matrix[5]) / matrix[4];
279 transformed.push([new_x, new_y]);
280 }
281
282 Ok(transformed)
283 }
284
285 /// Applies batch processing to tensor output to produce rectified images.
286 ///
287 /// # Arguments
288 ///
289 /// * `output` - 4D tensor output from the model [batch, channels, height, width].
290 ///
291 /// # Returns
292 ///
293 /// * `OcrResult<Vec<image::RgbImage>>` - Vector of rectified images or error.
294 pub fn apply_batch(&self, output: &crate::core::Tensor4D) -> OcrResult<Vec<image::RgbImage>> {
295 use image::{Rgb, RgbImage};
296
297 let shape = output.shape();
298 if shape.len() != 4 {
299 return Err(OCRError::InvalidInput {
300 message: "Expected 4D tensor [batch, channels, height, width]".to_string(),
301 });
302 }
303
304 let batch_size = shape[0];
305 let channels = shape[1];
306 let height = shape[2];
307 let width = shape[3];
308
309 if channels != 3 {
310 return Err(OCRError::InvalidInput {
311 message: "Expected 3 channels (RGB)".to_string(),
312 });
313 }
314
315 let mut images = Vec::with_capacity(batch_size);
316
317 let scale = self.scale;
318
319 for b in 0..batch_size {
320 let mut img = RgbImage::new(width as u32, height as u32);
321
322 for y in 0..height {
323 for x in 0..width {
324 // Model outputs are in BGR order; convert back to RGB.
325 let b_val = (output[[b, 0, y, x]] * scale).clamp(0.0, 255.0) as u8;
326 let g_val = (output[[b, 1, y, x]] * scale).clamp(0.0, 255.0) as u8;
327 let r_val = (output[[b, 2, y, x]] * scale).clamp(0.0, 255.0) as u8;
328
329 img.put_pixel(x as u32, y as u32, Rgb([r_val, g_val, b_val]));
330 }
331 }
332
333 images.push(img);
334 }
335
336 Ok(images)
337 }
338}
339
340impl Default for UVDocPostProcess {
341 /// Creates a default UVDocPostProcess with scale factor 1.0.
342 fn default() -> Self {
343 Self::new(1.0)
344 }
345}
346
347impl FromStr for UVDocPostProcess {
348 type Err = std::num::ParseFloatError;
349
350 /// Creates a UVDocPostProcess from a string representation of the scale factor.
351 ///
352 /// # Arguments
353 ///
354 /// * `s` - String representation of the scale factor.
355 ///
356 /// # Returns
357 ///
358 /// * `Ok(UVDocPostProcess)` - If the string can be parsed as a float.
359 /// * `Err(ParseFloatError)` - If the string cannot be parsed.
360 fn from_str(s: &str) -> Result<Self, Self::Err> {
361 let scale = s.parse::<f32>()?;
362 Ok(Self::new(scale))
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_denormalize_coordinates() {
        let postprocessor = UVDocPostProcess::new(100.0);
        let normalized = vec![0.1, 0.2, 0.8, 0.9];
        let pixels = postprocessor.denormalize_coordinates(&normalized);
        assert_eq!(pixels, vec![10.0, 20.0, 80.0, 90.0]);
    }

    #[test]
    fn test_normalize_coordinates() {
        let postprocessor = UVDocPostProcess::new(100.0);
        let pixels = vec![10.0, 20.0, 80.0, 90.0];
        let normalized = postprocessor.normalize_coordinates(&pixels);
        assert_eq!(normalized, vec![0.1, 0.2, 0.8, 0.9]);
    }

    #[test]
    fn test_process_bbox() {
        let postprocessor = UVDocPostProcess::new(100.0);
        let normalized_bbox = [0.1, 0.2, 0.8, 0.9];
        let pixel_bbox = postprocessor.process_bbox(&normalized_bbox);
        assert_eq!(pixel_bbox, [10.0, 20.0, 80.0, 90.0]);
    }

    #[test]
    fn test_process_bboxes() {
        let postprocessor = UVDocPostProcess::new(100.0);
        let bboxes = [[0.1, 0.2, 0.8, 0.9], [0.0, 0.0, 1.0, 1.0]];
        let pixel_bboxes = postprocessor.process_bboxes(&bboxes);
        assert_eq!(
            pixel_bboxes,
            vec![[10.0, 20.0, 80.0, 90.0], [0.0, 0.0, 100.0, 100.0]]
        );
        assert!(postprocessor.process_bboxes(&[]).is_empty());
    }

    #[test]
    fn test_process_polygon() {
        let postprocessor = UVDocPostProcess::new(100.0);
        let normalized_polygon = vec![[0.1, 0.2], [0.8, 0.2], [0.8, 0.9], [0.1, 0.9]];
        let pixel_polygon = postprocessor.process_polygon(&normalized_polygon);
        assert_eq!(pixel_polygon[0], [10.0, 20.0]);
        assert_eq!(pixel_polygon[1], [80.0, 20.0]);
    }

    #[test]
    fn test_clamp_coordinates() {
        let postprocessor = UVDocPostProcess::new(1.0);
        let coords = vec![-10.0, 50.0, 150.0];
        let clamped = postprocessor.clamp_coordinates(&coords, 0.0, 100.0);
        assert_eq!(clamped, vec![0.0, 50.0, 100.0]);
    }

    #[test]
    fn test_validate_coordinates() {
        let postprocessor = UVDocPostProcess::new(1.0);
        let valid_coords = vec![10.0, 50.0, 90.0];
        let invalid_coords = vec![10.0, 150.0, 90.0];

        assert!(postprocessor.validate_coordinates(&valid_coords, 0.0, 100.0));
        assert!(!postprocessor.validate_coordinates(&invalid_coords, 0.0, 100.0));
        // Inclusive bounds; NaN is never in range.
        assert!(postprocessor.validate_coordinates(&[0.0, 100.0], 0.0, 100.0));
        assert!(!postprocessor.validate_coordinates(&[f32::NAN], 0.0, 100.0));
    }

    #[test]
    fn test_round_coordinates() {
        let postprocessor = UVDocPostProcess::new(1.0);
        let coords = vec![10.3, 20.7, 30.5];
        let rounded = postprocessor.round_coordinates(&coords);
        assert_eq!(rounded, vec![10, 21, 31]);
    }

    #[test]
    fn test_process_transformation_matrix() {
        let postprocessor = UVDocPostProcess::new(2.0);
        let matrix = [1.0, 0.0, 3.0, 0.0, 1.0, 4.0, 0.0, 0.0, 1.0];
        assert_eq!(
            postprocessor.process_transformation_matrix(&matrix),
            [1.0, 0.0, 6.0, 0.0, 1.0, 8.0, 0.0, 0.0, 1.0]
        );
    }

    #[test]
    fn test_apply_inverse_transform_scale_and_translation() {
        let postprocessor = UVDocPostProcess::default();
        let matrix = [2.0, 0.0, 10.0, 0.0, 4.0, 20.0, 0.0, 0.0, 1.0];
        let result = postprocessor.apply_inverse_transform(&[[14.0, 28.0], [10.0, 20.0]], &matrix);
        assert_eq!(result.ok(), Some(vec![[2.0, 2.0], [0.0, 0.0]]));
    }

    #[test]
    fn test_apply_inverse_transform_empty_input() {
        let postprocessor = UVDocPostProcess::default();
        let identity = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let result = postprocessor.apply_inverse_transform(&[], &identity);
        assert_eq!(result.ok(), Some(Vec::new()));
    }

    #[test]
    fn test_apply_inverse_transform_singular_matrix() {
        let postprocessor = UVDocPostProcess::default();
        assert!(
            postprocessor
                .apply_inverse_transform(&[[1.0, 1.0]], &[0.0; 9])
                .is_err()
        );
    }

    #[test]
    fn test_from_str() {
        let postprocessor: UVDocPostProcess = "2.5".parse().unwrap();
        assert_eq!(postprocessor.scale(), 2.5);

        assert!("invalid".parse::<UVDocPostProcess>().is_err());
    }

    #[test]
    fn test_default_and_set_scale() {
        let mut postprocessor = UVDocPostProcess::default();
        assert_eq!(postprocessor.scale(), 1.0);
        postprocessor.set_scale(4.0);
        assert_eq!(postprocessor.scale(), 4.0);
    }

    #[test]
    fn test_zero_scale_normalize() {
        let postprocessor = UVDocPostProcess::new(0.0);
        let pixels = vec![10.0, 20.0];
        let normalized = postprocessor.normalize_coordinates(&pixels);
        assert_eq!(normalized, vec![0.0, 0.0]);
    }
}