oar_ocr/processors/postprocess/doctr.rs
1//! Document transformation post-processing functionality.
2
3use std::str::FromStr;
4
/// Post-processor for document transformation results.
///
/// `DocTrPostProcess` converts document transformation model outputs from
/// normalized values back to pixel values by multiplying with a single
/// scale factor, and offers small helpers for clamping, validating and
/// rounding the resulting coordinates.
///
/// The type is a plain value wrapper around one `f32`, so it is cheap to
/// copy and can be compared for equality.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DocTrPostProcess {
    /// Scale factor to convert normalized values back to pixel values.
    pub scale: f32,
}
15
16impl DocTrPostProcess {
17 /// Creates a new DocTrPostProcess instance.
18 ///
19 /// # Arguments
20 ///
21 /// * `scale` - Scale factor for converting normalized coordinates to pixels.
22 ///
23 /// # Examples
24 ///
25 /// ```rust,no_run
26 /// use oar_ocr::processors::DocTrPostProcess;
27 ///
28 /// let postprocessor = DocTrPostProcess::new(1.0);
29 /// ```
30 pub fn new(scale: f32) -> Self {
31 Self { scale }
32 }
33
34 /// Gets the current scale factor.
35 ///
36 /// # Returns
37 ///
38 /// The scale factor used for coordinate conversion.
39 pub fn scale(&self) -> f32 {
40 self.scale
41 }
42
43 /// Sets a new scale factor.
44 ///
45 /// # Arguments
46 ///
47 /// * `scale` - New scale factor.
48 pub fn set_scale(&mut self, scale: f32) {
49 self.scale = scale;
50 }
51
52 /// Converts normalized coordinates to pixel coordinates.
53 ///
54 /// # Arguments
55 ///
56 /// * `normalized_coords` - Vector of normalized coordinates (0.0 to 1.0).
57 ///
58 /// # Returns
59 ///
60 /// * `Vec<f32>` - Vector of pixel coordinates.
61 ///
62 /// # Examples
63 ///
64 /// ```rust,no_run
65 /// use oar_ocr::processors::DocTrPostProcess;
66 ///
67 /// let postprocessor = DocTrPostProcess::new(100.0);
68 /// let normalized = vec![0.1, 0.2, 0.8, 0.9];
69 /// let pixels = postprocessor.denormalize_coordinates(&normalized);
70 /// assert_eq!(pixels, vec![10.0, 20.0, 80.0, 90.0]);
71 /// ```
72 pub fn denormalize_coordinates(&self, normalized_coords: &[f32]) -> Vec<f32> {
73 normalized_coords
74 .iter()
75 .map(|&coord| coord * self.scale)
76 .collect()
77 }
78
79 /// Converts pixel coordinates to normalized coordinates.
80 ///
81 /// # Arguments
82 ///
83 /// * `pixel_coords` - Vector of pixel coordinates.
84 ///
85 /// # Returns
86 ///
87 /// * `Vec<f32>` - Vector of normalized coordinates (0.0 to 1.0).
88 ///
89 /// # Examples
90 ///
91 /// ```rust,no_run
92 /// use oar_ocr::processors::DocTrPostProcess;
93 ///
94 /// let postprocessor = DocTrPostProcess::new(100.0);
95 /// let pixels = vec![10.0, 20.0, 80.0, 90.0];
96 /// let normalized = postprocessor.normalize_coordinates(&pixels);
97 /// assert_eq!(normalized, vec![0.1, 0.2, 0.8, 0.9]);
98 /// ```
99 pub fn normalize_coordinates(&self, pixel_coords: &[f32]) -> Vec<f32> {
100 if self.scale == 0.0 {
101 return vec![0.0; pixel_coords.len()];
102 }
103 pixel_coords
104 .iter()
105 .map(|&coord| coord / self.scale)
106 .collect()
107 }
108
109 /// Processes a bounding box from normalized to pixel coordinates.
110 ///
111 /// # Arguments
112 ///
113 /// * `bbox` - Bounding box as [x1, y1, x2, y2] in normalized coordinates.
114 ///
115 /// # Returns
116 ///
117 /// * `[f32; 4]` - Bounding box in pixel coordinates.
118 ///
119 /// # Examples
120 ///
121 /// ```rust,no_run
122 /// use oar_ocr::processors::DocTrPostProcess;
123 ///
124 /// let postprocessor = DocTrPostProcess::new(100.0);
125 /// let normalized_bbox = [0.1, 0.2, 0.8, 0.9];
126 /// let pixel_bbox = postprocessor.process_bbox(&normalized_bbox);
127 /// assert_eq!(pixel_bbox, [10.0, 20.0, 80.0, 90.0]);
128 /// ```
129 pub fn process_bbox(&self, bbox: &[f32; 4]) -> [f32; 4] {
130 [
131 bbox[0] * self.scale,
132 bbox[1] * self.scale,
133 bbox[2] * self.scale,
134 bbox[3] * self.scale,
135 ]
136 }
137
138 /// Processes multiple bounding boxes.
139 ///
140 /// # Arguments
141 ///
142 /// * `bboxes` - Vector of bounding boxes in normalized coordinates.
143 ///
144 /// # Returns
145 ///
146 /// * `Vec<[f32; 4]>` - Vector of bounding boxes in pixel coordinates.
147 pub fn process_bboxes(&self, bboxes: &[[f32; 4]]) -> Vec<[f32; 4]> {
148 bboxes.iter().map(|bbox| self.process_bbox(bbox)).collect()
149 }
150
151 /// Processes a polygon from normalized to pixel coordinates.
152 ///
153 /// # Arguments
154 ///
155 /// * `polygon` - Vector of points as [x, y] pairs in normalized coordinates.
156 ///
157 /// # Returns
158 ///
159 /// * `Vec<[f32; 2]>` - Vector of points in pixel coordinates.
160 ///
161 /// # Examples
162 ///
163 /// ```rust,no_run
164 /// use oar_ocr::processors::DocTrPostProcess;
165 ///
166 /// let postprocessor = DocTrPostProcess::new(100.0);
167 /// let normalized_polygon = vec![[0.1, 0.2], [0.8, 0.2], [0.8, 0.9], [0.1, 0.9]];
168 /// let pixel_polygon = postprocessor.process_polygon(&normalized_polygon);
169 /// assert_eq!(pixel_polygon[0], [10.0, 20.0]);
170 /// ```
171 pub fn process_polygon(&self, polygon: &[[f32; 2]]) -> Vec<[f32; 2]> {
172 polygon
173 .iter()
174 .map(|&[x, y]| [x * self.scale, y * self.scale])
175 .collect()
176 }
177
178 /// Clamps coordinates to valid ranges.
179 ///
180 /// # Arguments
181 ///
182 /// * `coords` - Vector of coordinates to clamp.
183 /// * `min_val` - Minimum allowed value.
184 /// * `max_val` - Maximum allowed value.
185 ///
186 /// # Returns
187 ///
188 /// * `Vec<f32>` - Vector of clamped coordinates.
189 pub fn clamp_coordinates(&self, coords: &[f32], min_val: f32, max_val: f32) -> Vec<f32> {
190 coords
191 .iter()
192 .map(|&coord| coord.clamp(min_val, max_val))
193 .collect()
194 }
195
196 /// Validates that coordinates are within expected ranges.
197 ///
198 /// # Arguments
199 ///
200 /// * `coords` - Vector of coordinates to validate.
201 /// * `min_val` - Minimum expected value.
202 /// * `max_val` - Maximum expected value.
203 ///
204 /// # Returns
205 ///
206 /// * `true` - If all coordinates are within range.
207 /// * `false` - If any coordinate is out of range.
208 pub fn validate_coordinates(&self, coords: &[f32], min_val: f32, max_val: f32) -> bool {
209 coords
210 .iter()
211 .all(|&coord| coord >= min_val && coord <= max_val)
212 }
213
214 /// Rounds coordinates to integer values.
215 ///
216 /// # Arguments
217 ///
218 /// * `coords` - Vector of coordinates to round.
219 ///
220 /// # Returns
221 ///
222 /// * `Vec<i32>` - Vector of rounded integer coordinates.
223 pub fn round_coordinates(&self, coords: &[f32]) -> Vec<i32> {
224 coords.iter().map(|&coord| coord.round() as i32).collect()
225 }
226
227 /// Processes transformation matrix values.
228 ///
229 /// # Arguments
230 ///
231 /// * `matrix` - 3x3 transformation matrix as a flat vector.
232 ///
233 /// # Returns
234 ///
235 /// * `Vec<f32>` - Processed transformation matrix.
236 pub fn process_transformation_matrix(&self, matrix: &[f32; 9]) -> [f32; 9] {
237 // Apply scale to translation components (indices 2 and 5)
238 let mut processed = *matrix;
239 processed[2] *= self.scale; // tx
240 processed[5] *= self.scale; // ty
241 processed
242 }
243
244 /// Applies inverse transformation to coordinates.
245 ///
246 /// # Arguments
247 ///
248 /// * `coords` - Vector of coordinates to transform.
249 /// * `matrix` - 3x3 transformation matrix.
250 ///
251 /// # Returns
252 ///
253 /// * `Result<Vec<[f32; 2]>, String>` - Transformed coordinates or error.
254 pub fn apply_inverse_transform(
255 &self,
256 coords: &[[f32; 2]],
257 matrix: &[f32; 9],
258 ) -> Result<Vec<[f32; 2]>, String> {
259 // Calculate determinant for matrix inversion
260 let det = matrix[0] * (matrix[4] * matrix[8] - matrix[5] * matrix[7])
261 - matrix[1] * (matrix[3] * matrix[8] - matrix[5] * matrix[6])
262 + matrix[2] * (matrix[3] * matrix[7] - matrix[4] * matrix[6]);
263
264 if det.abs() < f32::EPSILON {
265 return Err("Matrix is not invertible (determinant is zero)".to_string());
266 }
267
268 // For simplicity, this is a basic implementation
269 // In practice, you might want to use a proper matrix library
270 let mut transformed = Vec::new();
271 for &[x, y] in coords {
272 // Apply inverse transformation (simplified)
273 let new_x = (x - matrix[2]) / matrix[0];
274 let new_y = (y - matrix[5]) / matrix[4];
275 transformed.push([new_x, new_y]);
276 }
277
278 Ok(transformed)
279 }
280
281 /// Applies batch processing to tensor output to produce rectified images.
282 ///
283 /// # Arguments
284 ///
285 /// * `output` - 4D tensor output from the model [batch, channels, height, width].
286 ///
287 /// # Returns
288 ///
289 /// * `Result<Vec<image::RgbImage>, String>` - Vector of rectified images or error.
290 pub fn apply_batch(
291 &self,
292 output: &crate::core::Tensor4D,
293 ) -> Result<Vec<image::RgbImage>, String> {
294 use image::{Rgb, RgbImage};
295
296 let shape = output.shape();
297 if shape.len() != 4 {
298 return Err("Expected 4D tensor [batch, channels, height, width]".to_string());
299 }
300
301 let batch_size = shape[0];
302 let channels = shape[1];
303 let height = shape[2];
304 let width = shape[3];
305
306 if channels != 3 {
307 return Err("Expected 3 channels (RGB)".to_string());
308 }
309
310 let mut images = Vec::with_capacity(batch_size);
311
312 for b in 0..batch_size {
313 let mut img = RgbImage::new(width as u32, height as u32);
314
315 for y in 0..height {
316 for x in 0..width {
317 // Extract RGB values and denormalize
318 let r = (output[[b, 0, y, x]] * 255.0).clamp(0.0, 255.0) as u8;
319 let g = (output[[b, 1, y, x]] * 255.0).clamp(0.0, 255.0) as u8;
320 let b_val = (output[[b, 2, y, x]] * 255.0).clamp(0.0, 255.0) as u8;
321
322 img.put_pixel(x as u32, y as u32, Rgb([r, g, b_val]));
323 }
324 }
325
326 images.push(img);
327 }
328
329 Ok(images)
330 }
331}
332
333impl Default for DocTrPostProcess {
334 /// Creates a default DocTrPostProcess with scale factor 1.0.
335 fn default() -> Self {
336 Self::new(1.0)
337 }
338}
339
340impl FromStr for DocTrPostProcess {
341 type Err = std::num::ParseFloatError;
342
343 /// Creates a DocTrPostProcess from a string representation of the scale factor.
344 ///
345 /// # Arguments
346 ///
347 /// * `s` - String representation of the scale factor.
348 ///
349 /// # Returns
350 ///
351 /// * `Ok(DocTrPostProcess)` - If the string can be parsed as a float.
352 /// * `Err(ParseFloatError)` - If the string cannot be parsed.
353 fn from_str(s: &str) -> Result<Self, Self::Err> {
354 let scale = s.parse::<f32>()?;
355 Ok(Self::new(scale))
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Normalized values multiplied by a scale of 100 give pixel values.
    #[test]
    fn test_denormalize_coordinates() {
        let pp = DocTrPostProcess::new(100.0);
        let result = pp.denormalize_coordinates(&[0.1, 0.2, 0.8, 0.9]);
        assert_eq!(result, vec![10.0, 20.0, 80.0, 90.0]);
    }

    /// Pixel values divided by a scale of 100 give normalized values.
    #[test]
    fn test_normalize_coordinates() {
        let pp = DocTrPostProcess::new(100.0);
        let result = pp.normalize_coordinates(&[10.0, 20.0, 80.0, 90.0]);
        assert_eq!(result, vec![0.1, 0.2, 0.8, 0.9]);
    }

    /// All four corners of a bounding box are scaled.
    #[test]
    fn test_process_bbox() {
        let pp = DocTrPostProcess::new(100.0);
        let result = pp.process_bbox(&[0.1, 0.2, 0.8, 0.9]);
        assert_eq!(result, [10.0, 20.0, 80.0, 90.0]);
    }

    /// Both components of every polygon point are scaled.
    #[test]
    fn test_process_polygon() {
        let pp = DocTrPostProcess::new(100.0);
        let corners = [[0.1, 0.2], [0.8, 0.2], [0.8, 0.9], [0.1, 0.9]];
        let result = pp.process_polygon(&corners);
        assert_eq!(result[0], [10.0, 20.0]);
        assert_eq!(result[1], [80.0, 20.0]);
    }

    /// Values outside the range are pulled back to the nearest bound.
    #[test]
    fn test_clamp_coordinates() {
        let pp = DocTrPostProcess::new(1.0);
        let result = pp.clamp_coordinates(&[-10.0, 50.0, 150.0], 0.0, 100.0);
        assert_eq!(result, vec![0.0, 50.0, 100.0]);
    }

    /// A single out-of-range value makes validation fail.
    #[test]
    fn test_validate_coordinates() {
        let pp = DocTrPostProcess::new(1.0);
        let in_range = [10.0, 50.0, 90.0];
        let out_of_range = [10.0, 150.0, 90.0];

        assert!(pp.validate_coordinates(&in_range, 0.0, 100.0));
        assert!(!pp.validate_coordinates(&out_of_range, 0.0, 100.0));
    }

    /// Rounding goes to the nearest integer, with .5 rounding up here.
    #[test]
    fn test_round_coordinates() {
        let pp = DocTrPostProcess::new(1.0);
        let result = pp.round_coordinates(&[10.3, 20.7, 30.5]);
        assert_eq!(result, vec![10, 21, 31]);
    }

    /// A float string parses into a post-processor; garbage is rejected.
    #[test]
    fn test_from_str() {
        let parsed = "2.5".parse::<DocTrPostProcess>().unwrap();
        assert_eq!(parsed.scale(), 2.5);

        assert!("invalid".parse::<DocTrPostProcess>().is_err());
    }

    /// A zero scale yields zeros instead of dividing by zero.
    #[test]
    fn test_zero_scale_normalize() {
        let pp = DocTrPostProcess::new(0.0);
        let result = pp.normalize_coordinates(&[10.0, 20.0]);
        assert_eq!(result, vec![0.0, 0.0]);
    }
}