Skip to main content

oximedia_gpu/
perspective_transform.rs

1//! GPU-accelerated perspective transform and lens distortion correction.
2//!
3//! This module provides two closely related geometric operations:
4//!
5//! 1. **Perspective (homography) transform**: maps a quadrilateral region of
6//!    the source image to a rectangle in the output (or vice versa).  The
7//!    transform is specified as a 3×3 homography matrix.
8//!
9//! 2. **Lens distortion correction**: removes barrel or pincushion distortion
10//!    using the Brown-Conrady radial/tangential distortion model.
11//!
12//! Both operations use backward-mapping with bilinear interpolation: for each
13//! output pixel the inverse transform is applied to find the corresponding
14//! source location, which is then sampled bilinearly from the input.
15//!
16//! All heavy work is parallelised over rows using rayon.
17//!
18//! # Example
19//!
20//! ```no_run
21//! use oximedia_gpu::perspective_transform::{
22//!     HomographyMatrix, PerspectiveTransform, LensDistortionParams, LensDistortionCorrector,
23//! };
24//!
25//! let src = vec![0u8; 640 * 480 * 4];
26//! let mut dst = vec![0u8; 640 * 480 * 4];
27//!
28//! // Identity transform
29//! let h = HomographyMatrix::identity();
30//! PerspectiveTransform::new(h)
31//!     .warp_rgba(&src, 640, 480, &mut dst, 640, 480)
32//!     .unwrap();
33//! ```
34
35use crate::{GpuError, Result};
36use rayon::prelude::*;
37
38// ─────────────────────────────────────────────────────────────────────────────
39// Homography matrix
40// ─────────────────────────────────────────────────────────────────────────────
41
/// A 3×3 homography (perspective transform) matrix stored in row-major order.
///
/// The matrix maps homogeneous source coordinates `[x, y, 1]` to destination
/// coordinates:
/// ```text
/// [x', y', w'] = H * [x, y, 1]
/// dst = (x'/w', y'/w')
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HomographyMatrix {
    /// Row-major 3×3 matrix coefficients (`m[row][col]`).
    pub m: [[f64; 3]; 3],
}

impl HomographyMatrix {
    /// Magnitude of `w'` below which [`apply`](Self::apply) treats a point as
    /// degenerate (mapped to infinity).
    const DEGENERATE_W: f64 = 1e-10;

    /// Magnitude of the determinant below which [`inverse`](Self::inverse)
    /// treats the matrix as singular.
    const SINGULAR_DET: f64 = 1e-12;

    /// Create a homography from a row-major 3×3 array.
    #[must_use]
    pub fn new(m: [[f64; 3]; 3]) -> Self {
        Self { m }
    }

    /// Identity homography (no transform).
    #[must_use]
    pub fn identity() -> Self {
        Self {
            m: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
        }
    }

    /// Create a 2D translation by `(tx, ty)`.
    #[must_use]
    pub fn translation(tx: f64, ty: f64) -> Self {
        Self {
            m: [[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]],
        }
    }

    /// Create an axis-aligned scale about the origin.
    ///
    /// `sx` and `sy` are independent; pass the same value twice for a uniform
    /// scale.
    #[must_use]
    pub fn scale(sx: f64, sy: f64) -> Self {
        Self {
            m: [[sx, 0.0, 0.0], [0.0, sy, 0.0], [0.0, 0.0, 1.0]],
        }
    }

    /// Create a rotation by `angle_rad` radians around the origin.
    ///
    /// The matrix is `[[cos, -sin], [sin, cos]]`, i.e. counter-clockwise when
    /// the y axis points up (it appears clockwise when y points down).
    #[must_use]
    pub fn rotation(angle_rad: f64) -> Self {
        let (s, c) = angle_rad.sin_cos();
        Self {
            m: [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]],
        }
    }

    /// Apply the homography to a point `(x, y)`.
    ///
    /// Returns `None` if the projective `w` coordinate is near zero or not
    /// finite (degenerate), so callers never receive a result produced by
    /// dividing by a NaN/infinite `w`.
    #[must_use]
    pub fn apply(&self, x: f64, y: f64) -> Option<(f64, f64)> {
        let [r0, r1, r2] = &self.m;
        let xp = r0[0] * x + r0[1] * y + r0[2];
        let yp = r1[0] * x + r1[1] * y + r1[2];
        let wp = r2[0] * x + r2[1] * y + r2[2];

        // `NaN.abs() < eps` is false, so non-finite values need their own test.
        if !wp.is_finite() || wp.abs() < Self::DEGENERATE_W {
            return None;
        }
        Some((xp / wp, yp / wp))
    }

    /// Compute the inverse of this homography as `adjugate / determinant`.
    ///
    /// Returns `None` if the matrix is singular (|det| < 1e-12) or if the
    /// determinant is not finite (the matrix contains NaN/infinite entries or
    /// its cofactors overflow).
    #[must_use]
    pub fn inverse(&self) -> Option<Self> {
        let m = &self.m;

        // Cofactors
        let c00 = m[1][1] * m[2][2] - m[1][2] * m[2][1];
        let c01 = -(m[1][0] * m[2][2] - m[1][2] * m[2][0]);
        let c02 = m[1][0] * m[2][1] - m[1][1] * m[2][0];
        let c10 = -(m[0][1] * m[2][2] - m[0][2] * m[2][1]);
        let c11 = m[0][0] * m[2][2] - m[0][2] * m[2][0];
        let c12 = -(m[0][0] * m[2][1] - m[0][1] * m[2][0]);
        let c20 = m[0][1] * m[1][2] - m[0][2] * m[1][1];
        let c21 = -(m[0][0] * m[1][2] - m[0][2] * m[1][0]);
        let c22 = m[0][0] * m[1][1] - m[0][1] * m[1][0];

        // Laplace expansion along the first row.
        let det = m[0][0] * c00 + m[0][1] * c01 + m[0][2] * c02;
        if !det.is_finite() || det.abs() < Self::SINGULAR_DET {
            return None;
        }

        let inv_det = 1.0 / det;

        // Adjugate (transpose of cofactor matrix)
        Some(Self {
            m: [
                [c00 * inv_det, c10 * inv_det, c20 * inv_det],
                [c01 * inv_det, c11 * inv_det, c21 * inv_det],
                [c02 * inv_det, c12 * inv_det, c22 * inv_det],
            ],
        })
    }

    /// Multiply two homographies, returning the matrix product `self * other`.
    ///
    /// Applying the result to a point is equivalent to applying `other` first
    /// and `self` second.
    #[must_use]
    pub fn compose(&self, other: &Self) -> Self {
        let a = &self.m;
        let b = &other.m;
        let mut product = [[0.0f64; 3]; 3];
        for (i, row) in product.iter_mut().enumerate() {
            for (j, cell) in row.iter_mut().enumerate() {
                for k in 0..3 {
                    *cell += a[i][k] * b[k][j];
                }
            }
        }
        Self { m: product }
    }
}
163
164// ─────────────────────────────────────────────────────────────────────────────
165// Perspective transform
166// ─────────────────────────────────────────────────────────────────────────────
167
168/// Perspective (homography) transform applied to RGBA images.
169#[derive(Debug, Clone)]
170pub struct PerspectiveTransform {
171    /// The forward homography (src → dst mapping).
172    pub homography: HomographyMatrix,
173}
174
175impl PerspectiveTransform {
176    /// Create a new perspective transform with the given homography.
177    #[must_use]
178    pub fn new(homography: HomographyMatrix) -> Self {
179        Self { homography }
180    }
181
182    /// Apply a perspective warp to an RGBA source image.
183    ///
184    /// Uses backward-mapping: for each destination pixel, the inverse
185    /// homography maps it to a source coordinate, which is sampled bilinearly.
186    ///
187    /// Out-of-bounds source coordinates are filled with black (`[0, 0, 0, 0]`).
188    ///
189    /// # Errors
190    ///
191    /// Returns an error if buffer sizes are inconsistent or if the homography
192    /// is not invertible.
193    pub fn warp_rgba(
194        &self,
195        src: &[u8],
196        src_w: u32,
197        src_h: u32,
198        dst: &mut [u8],
199        dst_w: u32,
200        dst_h: u32,
201    ) -> Result<()> {
202        let src_expected = (src_w as usize) * (src_h as usize) * 4;
203        let dst_expected = (dst_w as usize) * (dst_h as usize) * 4;
204
205        if src.len() != src_expected {
206            return Err(GpuError::InvalidBufferSize {
207                expected: src_expected,
208                actual: src.len(),
209            });
210        }
211        if dst.len() != dst_expected {
212            return Err(GpuError::InvalidBufferSize {
213                expected: dst_expected,
214                actual: dst.len(),
215            });
216        }
217
218        let inv = self
219            .homography
220            .inverse()
221            .ok_or_else(|| GpuError::Internal("Homography is not invertible".to_string()))?;
222
223        let sw = src_w as usize;
224        let sh = src_h as usize;
225        let dw = dst_w as usize;
226
227        // Process rows in parallel
228        dst.par_chunks_exact_mut(dw * 4)
229            .enumerate()
230            .for_each(|(dy, row)| {
231                for dx in 0..dw {
232                    let (sx, sy) = match inv.apply(dx as f64, dy as f64) {
233                        Some(p) => p,
234                        None => {
235                            let off = dx * 4;
236                            row[off..off + 4].copy_from_slice(&[0u8; 4]);
237                            continue;
238                        }
239                    };
240
241                    let pixel = bilinear_sample_rgba(src, sx, sy, sw, sh);
242                    let off = dx * 4;
243                    row[off..off + 4].copy_from_slice(&pixel);
244                }
245            });
246
247        Ok(())
248    }
249}
250
251// ─────────────────────────────────────────────────────────────────────────────
252// Lens distortion correction
253// ─────────────────────────────────────────────────────────────────────────────
254
/// Brown-Conrady radial and tangential lens distortion parameters.
///
/// Coordinates are normalised image coordinates in which (0, 0) is the
/// top-left corner and (1, 1) the bottom-right corner.  In the presets below
/// a negative `k1` is used for barrel distortion and a positive `k1` for
/// pincushion distortion.
#[derive(Debug, Clone, Copy)]
pub struct LensDistortionParams {
    /// Radial coefficient multiplying r².
    pub k1: f64,
    /// Radial coefficient multiplying r⁴.
    pub k2: f64,
    /// Radial coefficient multiplying r⁶.
    pub k3: f64,
    /// First tangential coefficient.
    pub p1: f64,
    /// Second tangential coefficient.
    pub p2: f64,
    /// Principal point x in normalised coordinates (default 0.5).
    pub cx: f64,
    /// Principal point y in normalised coordinates (default 0.5).
    pub cy: f64,
    /// Focal length along x in normalised units (default 1.0).
    pub fx: f64,
    /// Focal length along y in normalised units (default 1.0).
    pub fy: f64,
}

impl Default for LensDistortionParams {
    /// No distortion: all coefficients zero, principal point at the image
    /// centre, unit focal lengths.
    fn default() -> Self {
        Self {
            k1: 0.0,
            k2: 0.0,
            k3: 0.0,
            p1: 0.0,
            p2: 0.0,
            cx: 0.5,
            cy: 0.5,
            fx: 1.0,
            fy: 1.0,
        }
    }
}

impl LensDistortionParams {
    /// Preset for mild barrel distortion (`k1 = -0.3`, `k2 = 0.1`); every
    /// other field takes its default value.
    #[must_use]
    pub fn barrel() -> Self {
        Self {
            k1: -0.3,
            k2: 0.1,
            ..Self::default()
        }
    }

    /// Preset for mild pincushion distortion (`k1 = 0.3`, `k2 = -0.05`);
    /// every other field takes its default value.
    #[must_use]
    pub fn pincushion() -> Self {
        Self {
            k1: 0.3,
            k2: -0.05,
            ..Self::default()
        }
    }

    /// Evaluate the Brown-Conrady polynomial at a normalised image coordinate.
    ///
    /// The input is centred on the principal point and divided by the focal
    /// lengths, scaled by `1 + k1·r² + k2·r⁴ + k3·r⁶`, offset by the
    /// tangential terms, and converted back to normalised image coordinates.
    ///
    /// Despite the name, this is the forward form of the model: the radial
    /// factor and tangential offsets are applied directly to the input point.
    /// `LensDistortionCorrector::undistort_rgba` calls it once per output
    /// pixel to obtain the location to sample in the source frame.
    ///
    /// The coordinate is normalised so that (0, 0) = top-left and
    /// (1, 1) = bottom-right.
    #[must_use]
    pub fn undistort_point(&self, x_nd: f64, y_nd: f64) -> (f64, f64) {
        let Self {
            k1,
            k2,
            k3,
            p1,
            p2,
            cx,
            cy,
            fx,
            fy,
        } = *self;

        // Camera-space coordinates relative to the principal point.
        let x = (x_nd - cx) / fx;
        let y = (y_nd - cy) / fy;

        // Even powers of the radius.
        let r2 = x * x + y * y;
        let r4 = r2 * r2;
        let r6 = r2 * r4;

        let radial = 1.0 + k1 * r2 + k2 * r4 + k3 * r6;
        let tangential_x = 2.0 * p1 * x * y + p2 * (r2 + 2.0 * x * x);
        let tangential_y = p1 * (r2 + 2.0 * y * y) + 2.0 * p2 * x * y;

        let mapped_x = x * radial + tangential_x;
        let mapped_y = y * radial + tangential_y;

        // Return to normalised image coordinates.
        (mapped_x * fx + cx, mapped_y * fy + cy)
    }
}
349
350/// Lens distortion corrector.
351///
352/// Removes lens distortion from RGBA frames using the Brown-Conrady model.
353#[derive(Debug, Clone)]
354pub struct LensDistortionCorrector {
355    params: LensDistortionParams,
356}
357
358impl LensDistortionCorrector {
359    /// Create a new corrector with the given lens parameters.
360    #[must_use]
361    pub fn new(params: LensDistortionParams) -> Self {
362        Self { params }
363    }
364
365    /// Remove lens distortion from an RGBA frame.
366    ///
367    /// For each output pixel the undistortion mapping is applied to find the
368    /// source location, which is sampled bilinearly.
369    ///
370    /// # Errors
371    ///
372    /// Returns an error if `src` and `dst` buffer sizes don't match the
373    /// declared dimensions.
374    pub fn undistort_rgba(&self, src: &[u8], src_w: u32, src_h: u32, dst: &mut [u8]) -> Result<()> {
375        let expected = (src_w as usize) * (src_h as usize) * 4;
376        if src.len() != expected {
377            return Err(GpuError::InvalidBufferSize {
378                expected,
379                actual: src.len(),
380            });
381        }
382        if dst.len() != expected {
383            return Err(GpuError::InvalidBufferSize {
384                expected,
385                actual: dst.len(),
386            });
387        }
388
389        let sw = src_w as usize;
390        let sh = src_h as usize;
391
392        dst.par_chunks_exact_mut(sw * 4)
393            .enumerate()
394            .for_each(|(dy, row)| {
395                for dx in 0..sw {
396                    // Normalised coordinates [0, 1)
397                    let x_nd = (dx as f64 + 0.5) / sw as f64;
398                    let y_nd = (dy as f64 + 0.5) / sh as f64;
399
400                    let (sx, sy) = self.params.undistort_point(x_nd, y_nd);
401
402                    // Convert back to pixel coordinates
403                    let sx_px = sx * sw as f64 - 0.5;
404                    let sy_px = sy * sh as f64 - 0.5;
405
406                    let pixel = bilinear_sample_rgba(src, sx_px, sy_px, sw, sh);
407                    let off = dx * 4;
408                    row[off..off + 4].copy_from_slice(&pixel);
409                }
410            });
411
412        Ok(())
413    }
414}
415
416// ─────────────────────────────────────────────────────────────────────────────
417// Shared bilinear sampling helper
418// ─────────────────────────────────────────────────────────────────────────────
419
/// Bilinear sample from an RGBA image at fractional pixel coordinate `(x, y)`.
///
/// `src` is a tightly packed, row-major RGBA8 buffer of `w * h` pixels, with
/// integer coordinates addressing texels directly.  The four texels
/// surrounding `(x, y)` are blended; texels that fall outside the image
/// contribute `[0, 0, 0, 0]`, so samples near the border fade towards
/// transparent black.  Each channel is clamped to `0..=255` and truncated
/// towards zero.
///
/// Coordinates that cannot touch any texel (`x <= -1`, `x >= w`, likewise for
/// `y`, or NaN) return `[0, 0, 0, 0]` immediately.  Besides skipping useless
/// work, this keeps the integer texel indices small: a huge or infinite
/// coordinate would otherwise saturate to `isize::MAX` and overflow on `+ 1`.
///
/// # Panics
///
/// Panics if `src` holds fewer than `w * h * 4` bytes and an in-bounds texel
/// lies past its end.
fn bilinear_sample_rgba(src: &[u8], x: f64, y: f64, w: usize, h: usize) -> [u8; 4] {
    let (width, height) = (w as f64, h as f64);

    // Written as a positive test so that NaN coordinates are rejected too.
    let reachable = x > -1.0 && y > -1.0 && x < width && y < height;
    if !reachable {
        return [0; 4];
    }

    // From here on x0 ∈ [-1, w - 1] and y0 ∈ [-1, h - 1].
    let x0 = x.floor() as isize;
    let y0 = y.floor() as isize;
    let tx = (x - x0 as f64) as f32;
    let ty = (y - y0 as f64) as f32;

    let texel = |xi: isize, yi: isize| -> [f32; 4] {
        if xi < 0 || yi < 0 || xi >= w as isize || yi >= h as isize {
            return [0.0; 4];
        }
        let off = (yi as usize * w + xi as usize) * 4;
        [
            f32::from(src[off]),
            f32::from(src[off + 1]),
            f32::from(src[off + 2]),
            f32::from(src[off + 3]),
        ]
    };

    let c00 = texel(x0, y0);
    let c10 = texel(x0 + 1, y0);
    let c01 = texel(x0, y0 + 1);
    let c11 = texel(x0 + 1, y0 + 1);

    let mut out = [0u8; 4];
    for (i, channel) in out.iter_mut().enumerate() {
        let v = c00[i] * (1.0 - tx) * (1.0 - ty)
            + c10[i] * tx * (1.0 - ty)
            + c01[i] * (1.0 - tx) * ty
            + c11[i] * tx * ty;
        *channel = v.clamp(0.0, 255.0) as u8;
    }
    out
}
457
458// ─────────────────────────────────────────────────────────────────────────────
459// Tests
460// ─────────────────────────────────────────────────────────────────────────────
461
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a deterministic `w`×`h` RGBA frame whose bytes cycle 0..=255.
    fn rgba_frame(w: usize, h: usize) -> Vec<u8> {
        (0..w * h * 4).map(|i| (i % 256) as u8).collect()
    }

    // ── HomographyMatrix ──────────────────────────────────────────────────────

    #[test]
    fn test_identity_apply() {
        let h = HomographyMatrix::identity();
        let (x, y) = h.apply(3.0, 7.0).expect("should not be degenerate");
        assert!((x - 3.0).abs() < 1e-10);
        assert!((y - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_inverse_of_identity() {
        let h = HomographyMatrix::identity();
        let inv = h.inverse().expect("identity is invertible");
        let (x, y) = inv.apply(5.0, 9.0).expect("ok");
        assert!((x - 5.0).abs() < 1e-9);
        assert!((y - 9.0).abs() < 1e-9);
    }

    #[test]
    fn test_translation_roundtrip() {
        let h = HomographyMatrix::translation(10.0, -5.0);
        let inv = h.inverse().expect("translation is invertible");
        let (x, y) = h.apply(3.0, 3.0).expect("ok");
        assert!((x - 13.0).abs() < 1e-9);
        assert!((y + 2.0).abs() < 1e-9);
        let (x2, y2) = inv.apply(x, y).expect("ok");
        assert!((x2 - 3.0).abs() < 1e-9);
        assert!((y2 - 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_scale_apply() {
        let h = HomographyMatrix::scale(2.0, 3.0);
        let (x, y) = h.apply(4.0, 5.0).expect("ok");
        assert!((x - 8.0).abs() < 1e-9);
        assert!((y - 15.0).abs() < 1e-9);
    }

    #[test]
    fn test_rotation_preserves_magnitude() {
        use std::f64::consts::{FRAC_1_SQRT_2, PI};
        let h = HomographyMatrix::rotation(PI / 4.0);
        let (x, y) = h.apply(1.0, 0.0).expect("ok");
        let mag = (x * x + y * y).sqrt();
        assert!((mag - 1.0).abs() < 1e-10);
        // (1, 0) rotated by 45° lands on the diagonal.
        assert!((x - FRAC_1_SQRT_2).abs() < 1e-10);
        assert!((y - FRAC_1_SQRT_2).abs() < 1e-10);
    }

    #[test]
    fn test_compose() {
        let t1 = HomographyMatrix::translation(1.0, 0.0);
        let t2 = HomographyMatrix::translation(2.0, 0.0);
        let composed = t1.compose(&t2);
        let (x, _) = composed.apply(0.0, 0.0).expect("ok");
        assert!((x - 3.0).abs() < 1e-9);

        // Translations commute, so also pin the order with a pair that does
        // not: `a.compose(&b)` applies `b` first, then `a`.
        let s = HomographyMatrix::scale(2.0, 2.0);
        let (sx, _) = s.compose(&t1).apply(0.0, 0.0).expect("ok");
        assert!((sx - 2.0).abs() < 1e-9);
        let (tx, _) = t1.compose(&s).apply(0.0, 0.0).expect("ok");
        assert!((tx - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_singular_inverse_returns_none() {
        let h = HomographyMatrix::new([[0.0; 3]; 3]);
        assert!(h.inverse().is_none());
    }

    // ── PerspectiveTransform ──────────────────────────────────────────────────

    #[test]
    fn test_identity_warp_preserves_size() {
        let w = 8u32;
        let h = 8u32;
        let src = rgba_frame(w as usize, h as usize);
        let mut dst = vec![0u8; src.len()];
        let pt = PerspectiveTransform::new(HomographyMatrix::identity());
        pt.warp_rgba(&src, w, h, &mut dst, w, h)
            .expect("warp should succeed");
        // The identity maps every destination pixel onto the same source
        // texel with zero interpolation weight on its neighbours.
        assert_eq!(dst, src);
    }

    #[test]
    fn test_warp_wrong_size_rejected() {
        let src = vec![0u8; 8 * 8 * 4];
        let mut dst = vec![0u8; 4 * 4 * 4]; // wrong size for declared dimensions
        let pt = PerspectiveTransform::new(HomographyMatrix::identity());
        // dst declared as 8x8 but actually 4x4 → error
        let res = pt.warp_rgba(&src, 8, 8, &mut dst, 8, 8);
        assert!(res.is_err());
    }

    #[test]
    fn test_singular_homography_rejected() {
        let src = vec![0u8; 4 * 4 * 4];
        let mut dst = vec![0u8; 4 * 4 * 4];
        let pt = PerspectiveTransform::new(HomographyMatrix::new([[0.0; 3]; 3]));
        let res = pt.warp_rgba(&src, 4, 4, &mut dst, 4, 4);
        assert!(res.is_err());
    }

    // ── LensDistortionCorrector ───────────────────────────────────────────────

    #[test]
    fn test_no_distortion_identity() {
        let params = LensDistortionParams::default();
        let (xu, yu) = params.undistort_point(0.5, 0.5);
        assert!((xu - 0.5).abs() < 1e-10);
        assert!((yu - 0.5).abs() < 1e-10);

        // Off-centre points must be left alone as well.
        let (xo, yo) = params.undistort_point(0.25, 0.75);
        assert!((xo - 0.25).abs() < 1e-10);
        assert!((yo - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_undistort_rgba_correct_size() {
        let w = 8u32;
        let h = 8u32;
        let src = rgba_frame(w as usize, h as usize);
        let mut dst = vec![0u8; src.len()];
        let corrector = LensDistortionCorrector::new(LensDistortionParams::default());
        corrector
            .undistort_rgba(&src, w, h, &mut dst)
            .expect("should succeed");
        // With zero coefficients and an 8×8 frame every intermediate value
        // is an exact binary fraction, so the frame is reproduced bit for bit.
        assert_eq!(dst, src);
    }

    #[test]
    fn test_undistort_rgba_wrong_size_rejected() {
        let src = vec![0u8; 8 * 8 * 4];
        let mut dst = vec![0u8; 4]; // too small
        let corrector = LensDistortionCorrector::new(LensDistortionParams::default());
        let res = corrector.undistort_rgba(&src, 8, 8, &mut dst);
        assert!(res.is_err());
    }

    #[test]
    fn test_barrel_pincushion_differ() {
        let barrel = LensDistortionParams::barrel();
        let pin = LensDistortionParams::pincushion();
        assert!(barrel.k1 < 0.0);
        assert!(pin.k1 > 0.0);

        // At the right-hand edge the two presets move the point in opposite
        // directions relative to the principal point.
        let (bx, _) = barrel.undistort_point(1.0, 0.5);
        let (px, _) = pin.undistort_point(1.0, 0.5);
        assert!(bx < 1.0);
        assert!(px > 1.0);
    }
}