Skip to main content

oar_ocr_core/utils/
transform.rs

1//! Image transformation utilities for OCR processing.
2//!
3//! This module provides functions for perspective transformation and image warping,
4//! which are essential for correcting skewed text regions in images.
5
6use crate::core::OCRError;
7use crate::processors::Point;
8use image::{Rgb, RgbImage, imageops};
9use nalgebra::{Matrix3, Vector3};
10use rayon::prelude::*;
11use tracing::debug;
12
13/// Calculates the Euclidean distance between two points.
14///
15/// # Arguments
16///
17/// * `p1` - First point
18/// * `p2` - Second point
19///
20/// # Returns
21///
22/// The distance between the two points.
23fn distance(p1: &Point, p2: &Point) -> f32 {
24    (p1.x - p2.x).hypot(p1.y - p2.y)
25}
26
27/// Extracts a rotated and cropped image from a source image based on bounding box points.
28///
29/// This function takes a source image and a set of four points that define a quadrilateral
30/// region in the image. It crops the image to the bounding box of these points, then applies
31/// a perspective transformation to produce a rectified image of the region. If the resulting
32/// image has an aspect ratio that suggests it's rotated, it will be automatically rotated.
33///
34/// # Arguments
35///
36/// * `src_image` - The source image to crop from
37/// * `box_points` - Array of exactly 4 points defining the quadrilateral region
38///
39/// # Returns
40///
41/// A Result containing the cropped and transformed image, or an OCRError if the operation fails.
42///
43/// # Errors
44///
45/// Returns an OCRError if:
46/// * The box_points array doesn't contain exactly 4 points
47/// * The calculated crop region is invalid
48/// * The calculated crop dimensions are zero
49/// * The perspective transformation fails
50pub fn get_rotate_crop_image(
51    src_image: &RgbImage,
52    box_points: &[Point],
53) -> Result<RgbImage, OCRError> {
54    // Validate input
55    if box_points.len() != 4 {
56        return Err(OCRError::InvalidInput {
57            message: "Box must contain exactly 4 points".to_string(),
58        });
59    }
60
61    // Find bounding box of the points
62    let mut min_x = f32::INFINITY;
63    let mut max_x = f32::NEG_INFINITY;
64    let mut min_y = f32::INFINITY;
65    let mut max_y = f32::NEG_INFINITY;
66
67    for p in box_points {
68        min_x = min_x.min(p.x);
69        max_x = max_x.max(p.x);
70        min_y = min_y.min(p.y);
71        max_y = max_y.max(p.y);
72    }
73
74    // Calculate crop boundaries, clamping to image dimensions
75    let left = min_x.max(0.0) as u32;
76    let top = min_y.max(0.0) as u32;
77    let right = max_x.min(src_image.width() as f32) as u32;
78    let bottom = max_y.min(src_image.height() as f32) as u32;
79
80    // Validate crop region
81    if right <= left || bottom <= top {
82        return Err(OCRError::InvalidInput {
83            message: "Invalid crop region".to_string(),
84        });
85    }
86
87    // Perform initial crop
88    let crop_width = right - left;
89    let crop_height = bottom - top;
90    let img_crop = imageops::crop_imm(src_image, left, top, crop_width, crop_height).to_image();
91
92    // Adjust points relative to the cropped image
93    let points: Vec<Point> = box_points
94        .iter()
95        .map(|p| Point::new(p.x - left as f32, p.y - top as f32))
96        .collect();
97
98    // Reorder points to (top-left, top-right, bottom-right, bottom-left)
99    // to keep width/height estimation stable when point order varies.
100    let mut sorted = points.clone();
101    sorted.sort_by(|a, b| a.x.partial_cmp(&b.x).unwrap_or(std::cmp::Ordering::Equal));
102    let (mut index_a, mut index_d) = (0usize, 1usize);
103    if sorted[1].y < sorted[0].y {
104        index_a = 1;
105        index_d = 0;
106    }
107    let (mut index_b, mut index_c) = (2usize, 3usize);
108    if sorted[3].y < sorted[2].y {
109        index_b = 3;
110        index_c = 2;
111    }
112    let ordered = [
113        sorted[index_a],
114        sorted[index_b],
115        sorted[index_c],
116        sorted[index_d],
117    ];
118
119    // Calculate target image dimensions based on the max opposite-edge lengths
120    let width1 = distance(&ordered[0], &ordered[1]);
121    let width2 = distance(&ordered[2], &ordered[3]);
122    let img_crop_width = width1.max(width2).round() as u32;
123
124    let height1 = distance(&ordered[0], &ordered[3]);
125    let height2 = distance(&ordered[1], &ordered[2]);
126    let img_crop_height = height1.max(height2).round() as u32;
127
128    // Validate target dimensions
129    if img_crop_width == 0 || img_crop_height == 0 {
130        return Err(OCRError::InvalidInput {
131            message: "Invalid crop dimensions".to_string(),
132        });
133    }
134
135    // Define standard points for the target rectangle
136    let pts_std = [
137        Point::new(0.0, 0.0),
138        Point::new(img_crop_width as f32, 0.0),
139        Point::new(img_crop_width as f32, img_crop_height as f32),
140        Point::new(0.0, img_crop_height as f32),
141    ];
142
143    // Calculate perspective transformation matrix
144    let transform_matrix = get_perspective_transform(&ordered, &pts_std)?;
145
146    // Apply perspective transformation
147    let dst_img = warp_perspective(
148        &img_crop,
149        &transform_matrix,
150        img_crop_width,
151        img_crop_height,
152    )?;
153
154    // Automatically rotate if the aspect ratio suggests the text is vertical
155    if dst_img.height() as f32 >= dst_img.width() as f32 * 1.5 {
156        debug!(
157            "Rotating image due to aspect ratio: {}x{}",
158            dst_img.width(),
159            dst_img.height()
160        );
161
162        Ok(imageops::rotate270(&dst_img))
163    } else {
164        Ok(dst_img)
165    }
166}
167
168/// Calculates the perspective transformation matrix that maps source points to destination points.
169///
170/// This function solves the linear system of equations to find the perspective transformation
171/// matrix that maps four source points to four destination points.
172///
173/// # Arguments
174///
175/// * `src_points` - Array of exactly 4 source points
176/// * `dst_points` - Array of exactly 4 destination points
177///
178/// # Returns
179///
180/// A Result containing the 3x3 transformation matrix, or an OCRError if the operation fails.
181///
182/// # Errors
183///
184/// Returns an OCRError if:
185/// * Either array doesn't contain exactly 4 points
186/// * The linear system cannot be solved
187fn get_perspective_transform(
188    src_points: &[Point],
189    dst_points: &[Point],
190) -> Result<Matrix3<f32>, OCRError> {
191    // Validate input
192    if src_points.len() != 4 || dst_points.len() != 4 {
193        return Err(OCRError::InvalidInput {
194            message: "Need exactly 4 points for perspective transformation".to_string(),
195        });
196    }
197
198    // Set up the linear system of equations
199    let mut a = nalgebra::DMatrix::<f32>::zeros(8, 8);
200    let mut b = nalgebra::DVector::<f32>::zeros(8);
201
202    // Fill the matrix A and vector b with the equations for perspective transformation
203    for i in 0..4 {
204        let src = &src_points[i];
205        let dst = &dst_points[i];
206
207        // First equation for x coordinate transformation
208        a.set_row(
209            i * 2,
210            &nalgebra::RowDVector::from_row_slice(&[
211                src.x,
212                src.y,
213                1.0,
214                0.0,
215                0.0,
216                0.0,
217                -src.x * dst.x,
218                -src.y * dst.x,
219            ]),
220        );
221        b[i * 2] = dst.x;
222
223        // Second equation for y coordinate transformation
224        a.set_row(
225            i * 2 + 1,
226            &nalgebra::RowDVector::from_row_slice(&[
227                0.0,
228                0.0,
229                0.0,
230                src.x,
231                src.y,
232                1.0,
233                -src.x * dst.y,
234                -src.y * dst.y,
235            ]),
236        );
237        b[i * 2 + 1] = dst.y;
238    }
239
240    // Solve the linear system to find the transformation parameters
241    let decomp = a.lu();
242    let solution = decomp.solve(&b).ok_or_else(|| OCRError::InvalidInput {
243        message: "Cannot solve perspective transformation".to_string(),
244    })?;
245
246    // Construct the 3x3 transformation matrix
247    Ok(Matrix3::new(
248        solution[0],
249        solution[1],
250        solution[2],
251        solution[3],
252        solution[4],
253        solution[5],
254        solution[6],
255        solution[7],
256        1.0,
257    ))
258}
259
260/// Applies a perspective transformation to an image.
261///
262/// This function transforms an image using a given perspective transformation matrix.
263/// It uses inverse mapping with bilinear interpolation to produce the output image.
264///
265/// # Arguments
266///
267/// * `src_image` - The source image to transform
268/// * `transform_matrix` - The 3x3 perspective transformation matrix
269/// * `dst_width` - Width of the output image
270/// * `dst_height` - Height of the output image
271///
272/// # Returns
273///
274/// A Result containing the transformed image, or an OCRError if the operation fails.
275///
276/// # Errors
277///
278/// Returns an OCRError if:
279/// * The transformation matrix cannot be inverted
280fn warp_perspective(
281    src_image: &RgbImage,
282    transform_matrix: &Matrix3<f32>,
283    dst_width: u32,
284    dst_height: u32,
285) -> Result<RgbImage, OCRError> {
286    // Calculate the inverse transformation matrix for inverse mapping
287    let inv_matrix = transform_matrix
288        .try_inverse()
289        .ok_or_else(|| OCRError::InvalidInput {
290            message: "Cannot invert transformation matrix".to_string(),
291        })?;
292
293    // Create the destination image
294    let mut dst_image = RgbImage::new(dst_width, dst_height);
295    let buffer: &mut [u8] = dst_image.as_mut();
296
297    // Process rows with a small-image sequential fast path to avoid rayon overhead
298    // Use bicubic interpolation with border replication (matches cv2.warpPerspective
299    // with flags=INTER_CUBIC and borderMode=BORDER_REPLICATE)
300    if dst_height <= 1 {
301        let row_buffer = &mut buffer[0..(dst_width * 3) as usize];
302        let dst_y = 0u32;
303        for dst_x in 0..dst_width {
304            let dst_point = Vector3::new(dst_x as f32, dst_y as f32, 1.0);
305            let src_point = inv_matrix * dst_point;
306            let final_pixel = if src_point.z.abs() > f32::EPSILON {
307                let src_x = src_point.x / src_point.z;
308                let src_y = src_point.y / src_point.z;
309                // bicubic_interpolate handles out-of-bounds via border replication
310                bicubic_interpolate(src_image, src_x, src_y)
311            } else {
312                // Degenerate case: replicate top-left corner pixel
313                *src_image.get_pixel(0, 0)
314            };
315            let index = (dst_x * 3) as usize;
316            row_buffer[index..index + 3].copy_from_slice(&final_pixel.0);
317        }
318    } else {
319        buffer
320            .par_chunks_mut((dst_width * 3) as usize)
321            .enumerate()
322            .for_each(|(dst_y, row_buffer)| {
323                for dst_x in 0..dst_width {
324                    let dst_point = Vector3::new(dst_x as f32, dst_y as f32, 1.0);
325                    let src_point = inv_matrix * dst_point;
326                    let final_pixel = if src_point.z.abs() > f32::EPSILON {
327                        let src_x = src_point.x / src_point.z;
328                        let src_y = src_point.y / src_point.z;
329                        // bicubic_interpolate handles out-of-bounds via border replication
330                        bicubic_interpolate(src_image, src_x, src_y)
331                    } else {
332                        // Degenerate case: replicate top-left corner pixel
333                        *src_image.get_pixel(0, 0)
334                    };
335                    let index = (dst_x * 3) as usize;
336                    row_buffer[index..index + 3].copy_from_slice(&final_pixel.0);
337                }
338            });
339    }
340
341    Ok(dst_image)
342}
343
344/// Gets a pixel value with border replication for out-of-bounds coordinates.
345///
346/// This function implements OpenCV's BORDER_REPLICATE behavior:
347/// when coordinates are outside the image, the nearest edge pixel is used.
348///
349/// # Arguments
350///
351/// * `image` - The source image
352/// * `x` - X coordinate (can be negative or >= width)
353/// * `y` - Y coordinate (can be negative or >= height)
354///
355/// # Returns
356///
357/// The pixel value at the clamped coordinates.
358#[inline]
359fn get_pixel_replicate(image: &RgbImage, x: i32, y: i32) -> Rgb<u8> {
360    let clamped_x = x.clamp(0, image.width() as i32 - 1) as u32;
361    let clamped_y = y.clamp(0, image.height() as i32 - 1) as u32;
362    *image.get_pixel(clamped_x, clamped_y)
363}
364
/// Evaluates the cubic convolution kernel used to weight bicubic samples.
///
/// With `a = -0.5` (the Catmull-Rom spline) and `d = |t|`, the kernel is piecewise:
/// - `d <= 1`: `(a + 2)·d³ − (a + 3)·d² + 1`
/// - `1 < d < 2`: `a·d³ − 5a·d² + 8a·d − 4a`
/// - otherwise (including NaN input): `0`
///
/// NOTE(review): OpenCV's `INTER_CUBIC` kernel is usually documented as using `a = -0.75`;
/// confirm whether bit-level parity with OpenCV is intended before relying on it.
#[inline]
fn cubic_kernel(t: f32) -> f32 {
    const A: f32 = -0.5; // Catmull-Rom spline coefficient

    let d = t.abs();

    // Inner lobe of the kernel.
    if d <= 1.0 {
        let cubic_term = (A + 2.0) * d * d * d;
        let square_term = (A + 3.0) * d * d;
        return cubic_term - square_term + 1.0;
    }

    // Outer lobe of the kernel.
    if d < 2.0 {
        let cubic_term = A * d * d * d;
        let square_term = 5.0 * A * d * d;
        let linear_term = 8.0 * A * d;
        return cubic_term - square_term + linear_term - 4.0 * A;
    }

    // Outside the kernel's support (NaN also ends up here).
    0.0
}
387
388/// Performs bicubic interpolation to get a pixel value at non-integer coordinates.
389///
390/// This function calculates the pixel value at a fractional (x, y) coordinate
391/// by interpolating using a 4x4 neighborhood of pixels with cubic convolution.
392/// Uses border replication for edge handling (same as OpenCV's BORDER_REPLICATE).
393///
394/// # Arguments
395///
396/// * `image` - The source image
397/// * `x` - X coordinate (can be fractional)
398/// * `y` - Y coordinate (can be fractional)
399///
400/// # Returns
401///
402/// The interpolated pixel value.
403fn bicubic_interpolate(image: &RgbImage, x: f32, y: f32) -> Rgb<u8> {
404    let x_int = x.floor() as i32;
405    let y_int = y.floor() as i32;
406    let dx = x - x_int as f32;
407    let dy = y - y_int as f32;
408
409    // Calculate x-direction weights
410    let wx = [
411        cubic_kernel(dx + 1.0),
412        cubic_kernel(dx),
413        cubic_kernel(dx - 1.0),
414        cubic_kernel(dx - 2.0),
415    ];
416
417    // Calculate y-direction weights
418    let wy = [
419        cubic_kernel(dy + 1.0),
420        cubic_kernel(dy),
421        cubic_kernel(dy - 1.0),
422        cubic_kernel(dy - 2.0),
423    ];
424
425    let mut result = [0.0f32; 3];
426
427    // Sample 4x4 neighborhood
428    for (j, &weight_y) in wy.iter().enumerate() {
429        let sample_y = y_int - 1 + j as i32;
430
431        for (i, &weight_x) in wx.iter().enumerate() {
432            let sample_x = x_int - 1 + i as i32;
433            let weight = weight_x * weight_y;
434
435            // Use border replication for out-of-bounds pixels
436            let pixel = get_pixel_replicate(image, sample_x, sample_y);
437
438            for (c, result_c) in result.iter_mut().enumerate().take(3) {
439                *result_c += weight * pixel.0[c] as f32;
440            }
441        }
442    }
443
444    // Clamp and convert to u8
445    Rgb([
446        result[0].round().clamp(0.0, 255.0) as u8,
447        result[1].round().clamp(0.0, 255.0) as u8,
448        result[2].round().clamp(0.0, 255.0) as u8,
449    ])
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 2x2 image with four distinct colours:
    /// red (0,0), green (1,0), blue (0,1), yellow (1,1).
    fn sample_2x2() -> RgbImage {
        let mut image = RgbImage::new(2, 2);
        image.put_pixel(0, 0, Rgb([255, 0, 0])); // Red
        image.put_pixel(1, 0, Rgb([0, 255, 0])); // Green
        image.put_pixel(0, 1, Rgb([0, 0, 255])); // Blue
        image.put_pixel(1, 1, Rgb([255, 255, 0])); // Yellow
        image
    }

    /// Performs bilinear interpolation to get a pixel value at non-integer coordinates.
    ///
    /// Test-only reference sampler: interpolates between the four nearest pixels and uses
    /// border replication for edge handling (same as OpenCV's BORDER_REPLICATE).
    fn bilinear_interpolate(image: &RgbImage, x: f32, y: f32) -> Rgb<u8> {
        let x_int = x.floor() as i32;
        let y_int = y.floor() as i32;

        // Calculate the fractional parts
        let dx = x - x_int as f32;
        let dy = y - y_int as f32;

        // Get the four neighboring pixels with border replication
        let p11 = get_pixel_replicate(image, x_int, y_int);
        let p12 = get_pixel_replicate(image, x_int, y_int + 1);
        let p21 = get_pixel_replicate(image, x_int + 1, y_int);
        let p22 = get_pixel_replicate(image, x_int + 1, y_int + 1);

        // Interpolate each color channel
        let mut result = [0u8; 3];
        for (i, result_channel) in result.iter_mut().enumerate() {
            let val = (1.0 - dx) * (1.0 - dy) * p11.0[i] as f32
                + dx * (1.0 - dy) * p21.0[i] as f32
                + (1.0 - dx) * dy * p12.0[i] as f32
                + dx * dy * p22.0[i] as f32;
            *result_channel = val.round().clamp(0.0, 255.0) as u8;
        }

        Rgb(result)
    }

    #[test]
    fn test_distance() {
        let p1 = Point::new(0.0, 0.0);
        let p2 = Point::new(3.0, 4.0);
        let dist = distance(&p1, &p2);
        assert_eq!(dist, 5.0);
    }

    #[test]
    fn test_get_perspective_transform() -> Result<(), OCRError> {
        // Unit square mapped onto a square twice its size: a pure 2x scale.
        let src_points = [
            Point::new(0.0, 0.0),
            Point::new(1.0, 0.0),
            Point::new(1.0, 1.0),
            Point::new(0.0, 1.0),
        ];

        let dst_points = [
            Point::new(0.0, 0.0),
            Point::new(2.0, 0.0),
            Point::new(2.0, 2.0),
            Point::new(0.0, 2.0),
        ];

        let transform = get_perspective_transform(&src_points, &dst_points)?;

        // Check that the transformation matrix is valid (all elements are finite)
        assert!(transform.iter().all(|&x| x.is_finite()));

        // ...and that it is the expected scale matrix, up to solver round-off.
        let expected = Matrix3::new(2.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 1.0);
        assert!(
            transform
                .iter()
                .zip(expected.iter())
                .all(|(actual, wanted)| (actual - wanted).abs() < 1e-4),
            "unexpected transform: {transform:?}"
        );
        Ok(())
    }

    #[test]
    fn test_get_perspective_transform_invalid_input() {
        // Test with wrong number of points
        let src_points = [Point::new(0.0, 0.0), Point::new(1.0, 0.0)];

        let dst_points = [
            Point::new(0.0, 0.0),
            Point::new(2.0, 0.0),
            Point::new(2.0, 2.0),
            Point::new(0.0, 2.0),
        ];

        let result = get_perspective_transform(&src_points, &dst_points);
        assert!(result.is_err());
    }

    #[test]
    fn test_get_rotate_crop_image_invalid_points() {
        // Create a simple 4x4 image
        let image = RgbImage::new(4, 4);

        // Test with wrong number of points
        let points = vec![Point::new(0.0, 0.0), Point::new(1.0, 0.0)];

        let result = get_rotate_crop_image(&image, &points);
        assert!(result.is_err());
    }

    #[test]
    fn test_get_rotate_crop_image_success() -> Result<(), OCRError> {
        // Create a simple 4x4 image with distinct colors
        let mut image = RgbImage::new(4, 4);
        for y in 0..4 {
            for x in 0..4 {
                // Create a gradient
                let r = (x * 64) as u8;
                let g = (y * 64) as u8;
                let b = ((x + y) * 32) as u8;
                image.put_pixel(x, y, Rgb([r, g, b]));
            }
        }

        // Define a simple square region
        let points = vec![
            Point::new(1.0, 1.0),
            Point::new(3.0, 1.0),
            Point::new(3.0, 3.0),
            Point::new(1.0, 3.0),
        ];

        let cropped_image = get_rotate_crop_image(&image, &points)?;
        // A 2x2 axis-aligned square must come back as a 2x2 image, unrotated.
        assert_eq!(cropped_image.dimensions(), (2, 2));
        Ok(())
    }

    #[test]
    fn test_get_rotate_crop_image_rotates_tall_region() -> Result<(), OCRError> {
        let image = RgbImage::new(4, 4);

        // A 1x3 upright strip: height >= 1.5 * width, so the result must be rotated.
        let points = vec![
            Point::new(1.0, 0.0),
            Point::new(2.0, 0.0),
            Point::new(2.0, 3.0),
            Point::new(1.0, 3.0),
        ];

        let cropped_image = get_rotate_crop_image(&image, &points)?;
        assert_eq!(cropped_image.dimensions(), (3, 1));
        Ok(())
    }

    #[test]
    fn test_warp_perspective_invalid_matrix() {
        // Create a simple 2x2 image
        let image = RgbImage::new(2, 2);

        // Create a singular matrix (non-invertible)
        let matrix = Matrix3::new(1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0);

        let result = warp_perspective(&image, &matrix, 2, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_warp_perspective_identity() -> Result<(), OCRError> {
        let image = sample_2x2();

        // The identity transform samples every pixel at its own integer position,
        // so the output must reproduce the input exactly.
        let warped = warp_perspective(&image, &Matrix3::identity(), 2, 2)?;
        assert_eq!(warped.dimensions(), image.dimensions());
        assert_eq!(warped.as_raw(), image.as_raw());
        Ok(())
    }

    #[test]
    fn test_cubic_kernel_reference_values() {
        // Peak at the centre, zeros at the integer offsets, nothing outside |t| < 2.
        assert_eq!(cubic_kernel(0.0), 1.0);
        assert_eq!(cubic_kernel(1.0), 0.0);
        assert_eq!(cubic_kernel(-1.0), 0.0);
        assert_eq!(cubic_kernel(2.0), 0.0);
        // Hand-computed values for a = -0.5.
        assert_eq!(cubic_kernel(0.5), 0.5625);
        assert_eq!(cubic_kernel(1.5), -0.0625);
    }

    #[test]
    fn test_bicubic_interpolate_at_pixel_centres() {
        let image = sample_2x2();

        // At integer coordinates only the centre tap has non-zero weight.
        assert_eq!(bicubic_interpolate(&image, 0.0, 0.0), Rgb([255, 0, 0]));
        assert_eq!(bicubic_interpolate(&image, 1.0, 0.0), Rgb([0, 255, 0]));
        assert_eq!(bicubic_interpolate(&image, 0.0, 1.0), Rgb([0, 0, 255]));
        assert_eq!(bicubic_interpolate(&image, 1.0, 1.0), Rgb([255, 255, 0]));
    }

    #[test]
    fn test_bicubic_interpolate_replicates_border() {
        let image = sample_2x2();

        // Coordinates outside the image resolve to the nearest corner pixel.
        assert_eq!(bicubic_interpolate(&image, -5.0, -5.0), Rgb([255, 0, 0]));
        assert_eq!(bicubic_interpolate(&image, 7.0, 7.0), Rgb([255, 255, 0]));
    }

    #[test]
    fn test_bilinear_interpolate() {
        let image = sample_2x2();

        // Test interpolation at the center
        let pixel = bilinear_interpolate(&image, 0.5, 0.5);
        // Expected: average of all four colors
        // Red + Green + Blue + Yellow = (255, 0, 0) + (0, 255, 0) + (0, 0, 255) + (255, 255, 0)
        // = (510, 510, 255) / 4 = (127.5, 127.5, 63.75) ≈ (128, 128, 64)
        assert_eq!(pixel.0[0], 128);
        assert_eq!(pixel.0[1], 128);
        assert_eq!(pixel.0[2], 64);
    }
}