Skip to main content

oximedia_align/
projective_warp.rs

1#![allow(dead_code)]
2//! Projective (perspective) warp transformations for image alignment.
3//!
4//! This module provides tools for applying and computing projective (homographic)
5//! warps between image planes. Unlike affine transforms which preserve parallelism,
6//! projective warps can model arbitrary perspective changes, making them essential
7//! for aligning images from cameras at different positions and orientations.
8//!
9//! # Features
10//!
11//! - **3x3 Homography Matrix** representation and arithmetic
12//! - **Direct Linear Transform (DLT)** for computing homographies from point correspondences
13//! - **Forward and inverse warp** of 2D points through the projective transformation
14//! - **Decomposition** of a homography into rotation, translation, and normal components
15//! - **Condition number** estimation for numerical stability assessment
16
/// A 3x3 homography matrix stored in row-major order.
///
/// The nine coefficients are kept in a flat array; element `(row, col)` lives
/// at offset `row * 3 + col`.
#[derive(Debug, Clone, PartialEq)]
pub struct HomographyMatrix {
    /// Elements in row-major order: `[h00, h01, h02, h10, h11, h12, h20, h21, h22]`
    pub data: [f64; 9],
}

impl HomographyMatrix {
    /// Magnitude below which a determinant or scale factor is treated as zero.
    const NEGLIGIBLE: f64 = 1e-12;

    /// Create a new homography from 9 row-major elements.
    pub fn new(data: [f64; 9]) -> Self {
        Self { data }
    }

    /// Create the identity homography.
    pub fn identity() -> Self {
        Self {
            data: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
        }
    }

    /// Translate `(row, col)` into an offset into `data`.
    ///
    /// Both indices are checked separately: without the check an out-of-range
    /// column such as `(0, 3)` would silently alias element `(1, 0)`.
    fn offset(row: usize, col: usize) -> usize {
        assert!(
            row < 3 && col < 3,
            "homography index ({}, {}) out of range for a 3x3 matrix",
            row,
            col
        );
        row * 3 + col
    }

    /// Access element at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if `row` or `col` is not in `0..3`.
    pub fn get(&self, row: usize, col: usize) -> f64 {
        self.data[Self::offset(row, col)]
    }

    /// Set element at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if `row` or `col` is not in `0..3`.
    pub fn set(&mut self, row: usize, col: usize, value: f64) {
        self.data[Self::offset(row, col)] = value;
    }

    /// Compute the determinant of the 3x3 matrix (cofactor expansion along row 0).
    pub fn determinant(&self) -> f64 {
        let d = &self.data;
        d[0] * (d[4] * d[8] - d[5] * d[7]) - d[1] * (d[3] * d[8] - d[5] * d[6])
            + d[2] * (d[3] * d[7] - d[4] * d[6])
    }

    /// Return the inverse homography, or `None` if singular.
    ///
    /// The matrix is considered singular when the absolute determinant is below
    /// `1e-12`. A NaN or infinite determinant also yields `None`, so the result
    /// never contains coefficients derived from a non-finite determinant.
    pub fn inverse(&self) -> Option<Self> {
        let det = self.determinant();
        // NaN fails every `<` comparison, so it has to be rejected explicitly.
        if !det.is_finite() || det.abs() < Self::NEGLIGIBLE {
            return None;
        }
        let d = &self.data;
        let inv_det = 1.0 / det;
        // Adjugate (transposed cofactor matrix) scaled by 1/det.
        Some(Self {
            data: [
                (d[4] * d[8] - d[5] * d[7]) * inv_det,
                (d[2] * d[7] - d[1] * d[8]) * inv_det,
                (d[1] * d[5] - d[2] * d[4]) * inv_det,
                (d[5] * d[6] - d[3] * d[8]) * inv_det,
                (d[0] * d[8] - d[2] * d[6]) * inv_det,
                (d[2] * d[3] - d[0] * d[5]) * inv_det,
                (d[3] * d[7] - d[4] * d[6]) * inv_det,
                (d[1] * d[6] - d[0] * d[7]) * inv_det,
                (d[0] * d[4] - d[1] * d[3]) * inv_det,
            ],
        })
    }

    /// Normalize the matrix so that `h22 == 1.0` (when possible).
    ///
    /// The matrix is left untouched when `h22` is not finite or its magnitude
    /// does not exceed `1e-12`, because dividing by it would not give a usable
    /// result.
    pub fn normalize(&mut self) {
        let scale = self.data[8];
        if scale.is_finite() && scale.abs() > Self::NEGLIGIBLE {
            for v in &mut self.data {
                *v /= scale;
            }
        }
    }

    /// Multiply two homography matrices, returning `self * other`.
    ///
    /// Warping a point with the result is equivalent to warping with `other`
    /// first and then with `self`.
    pub fn compose(&self, other: &Self) -> Self {
        let a = &self.data;
        let b = &other.data;
        let mut out = [0.0f64; 9];
        for row in 0..3 {
            for col in 0..3 {
                out[row * 3 + col] =
                    a[row * 3] * b[col] + a[row * 3 + 1] * b[3 + col] + a[row * 3 + 2] * b[6 + col];
            }
        }
        Self { data: out }
    }

    /// Estimate the condition number as `max_singular / min_singular` approximated
    /// via the Frobenius norm and the inverse Frobenius norm.
    ///
    /// Returns `None` when the matrix has no inverse (see [`Self::inverse`]).
    pub fn condition_number_approx(&self) -> Option<f64> {
        let fro: f64 = self.data.iter().map(|v| v * v).sum::<f64>().sqrt();
        let inv = self.inverse()?;
        let fro_inv: f64 = inv.data.iter().map(|v| v * v).sum::<f64>().sqrt();
        Some(fro * fro_inv)
    }
}
113
/// A 2D point with `f64` coordinates, used by the projective operations.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WarpPoint {
    /// Position along the X axis.
    pub x: f64,
    /// Position along the Y axis.
    pub y: f64,
}

impl WarpPoint {
    /// Build a point from its `x` and `y` coordinates.
    pub fn new(x: f64, y: f64) -> Self {
        WarpPoint { x, y }
    }
}
129
130/// Apply a homography to warp a point from source to destination coordinates.
131///
132/// The projective warp is: `[x', y', w'] = H * [x, y, 1]`, then `(x'/w', y'/w')`.
133pub fn warp_point(h: &HomographyMatrix, pt: &WarpPoint) -> Option<WarpPoint> {
134    let d = &h.data;
135    let w = d[6] * pt.x + d[7] * pt.y + d[8];
136    if w.abs() < 1e-12 {
137        return None;
138    }
139    let x = (d[0] * pt.x + d[1] * pt.y + d[2]) / w;
140    let y = (d[3] * pt.x + d[4] * pt.y + d[5]) / w;
141    Some(WarpPoint::new(x, y))
142}
143
144/// Apply the inverse homography to warp a point from destination back to source.
145pub fn inverse_warp_point(h: &HomographyMatrix, pt: &WarpPoint) -> Option<WarpPoint> {
146    let inv = h.inverse()?;
147    warp_point(&inv, pt)
148}
149
150/// A correspondence pair of source and destination points.
151#[derive(Debug, Clone, Copy)]
152pub struct PointCorrespondence {
153    /// Source point
154    pub src: WarpPoint,
155    /// Destination point
156    pub dst: WarpPoint,
157}
158
159/// Compute a homography from 4 or more point correspondences using the
160/// normalised Direct Linear Transform (DLT).
161///
162/// Returns `None` if fewer than 4 correspondences are provided or if the
163/// system is degenerate.
164#[allow(clippy::cast_precision_loss)]
165pub fn compute_homography_dlt(correspondences: &[PointCorrespondence]) -> Option<HomographyMatrix> {
166    if correspondences.len() < 4 {
167        return None;
168    }
169
170    // Compute centroids and average distances for normalisation
171    let n = correspondences.len() as f64;
172    let (cx_s, cy_s) = correspondences
173        .iter()
174        .fold((0.0, 0.0), |(sx, sy), c| (sx + c.src.x, sy + c.src.y));
175    let (cx_d, cy_d) = correspondences
176        .iter()
177        .fold((0.0, 0.0), |(sx, sy), c| (sx + c.dst.x, sy + c.dst.y));
178    let (cx_s, cy_s) = (cx_s / n, cy_s / n);
179    let (cx_d, cy_d) = (cx_d / n, cy_d / n);
180
181    let avg_dist_s: f64 = correspondences
182        .iter()
183        .map(|c| ((c.src.x - cx_s).powi(2) + (c.src.y - cy_s).powi(2)).sqrt())
184        .sum::<f64>()
185        / n;
186    let avg_dist_d: f64 = correspondences
187        .iter()
188        .map(|c| ((c.dst.x - cx_d).powi(2) + (c.dst.y - cy_d).powi(2)).sqrt())
189        .sum::<f64>()
190        / n;
191
192    if avg_dist_s < 1e-12 || avg_dist_d < 1e-12 {
193        return None;
194    }
195    let scale_s = std::f64::consts::SQRT_2 / avg_dist_s;
196    let scale_d = std::f64::consts::SQRT_2 / avg_dist_d;
197
198    // Build simplified 9-element solution using the smallest eigenvalue approach
199    // For exactly 4 points we solve the 8x9 system; for more we use least-squares
200    // Here we use a simplified iterative approach for small numbers of correspondences
201
202    // For the basic implementation, use the 4-point exact solution
203    let pts: Vec<(f64, f64, f64, f64)> = correspondences
204        .iter()
205        .map(|c| {
206            (
207                (c.src.x - cx_s) * scale_s,
208                (c.src.y - cy_s) * scale_s,
209                (c.dst.x - cx_d) * scale_d,
210                (c.dst.y - cy_d) * scale_d,
211            )
212        })
213        .collect();
214
215    // Build the A matrix rows and solve via simple Gaussian elimination for h
216    // We set h[8] = 1 and solve 8 equations
217    let npts = pts.len();
218    let mut a_mat = vec![vec![0.0f64; 9]; 2 * npts];
219    for (i, &(xs, ys, xd, yd)) in pts.iter().enumerate() {
220        a_mat[2 * i] = vec![-xs, -ys, -1.0, 0.0, 0.0, 0.0, xd * xs, xd * ys, xd];
221        a_mat[2 * i + 1] = vec![0.0, 0.0, 0.0, -xs, -ys, -1.0, yd * xs, yd * ys, yd];
222    }
223
224    // Solve using least-squares normal equations: AᵀA h = 0
225    // Use power iteration to find the smallest eigenvector of AᵀA
226    let mut ata = vec![vec![0.0f64; 9]; 9];
227    for row in &a_mat {
228        for i in 0..9 {
229            for j in 0..9 {
230                ata[i][j] += row[i] * row[j];
231            }
232        }
233    }
234
235    // Inverse iteration to find smallest eigenvector
236    let mut h_vec = vec![0.0f64; 9];
237    h_vec[8] = 1.0;
238    for _ in 0..50 {
239        // Solve (AᵀA + eps*I) * y = h_vec  using simple elimination
240        let eps = 1e-10;
241        let mut aug = vec![vec![0.0f64; 10]; 9];
242        for i in 0..9 {
243            for j in 0..9 {
244                aug[i][j] = ata[i][j] + if i == j { eps } else { 0.0 };
245            }
246            aug[i][9] = h_vec[i];
247        }
248        // Gaussian elimination
249        for col in 0..9 {
250            let mut max_row = col;
251            let mut max_val = aug[col][col].abs();
252            for row in (col + 1)..9 {
253                if aug[row][col].abs() > max_val {
254                    max_val = aug[row][col].abs();
255                    max_row = row;
256                }
257            }
258            aug.swap(col, max_row);
259            if aug[col][col].abs() < 1e-15 {
260                continue;
261            }
262            let pivot = aug[col][col];
263            for j in col..10 {
264                aug[col][j] /= pivot;
265            }
266            for row in 0..9 {
267                if row == col {
268                    continue;
269                }
270                let factor = aug[row][col];
271                for j in col..10 {
272                    aug[row][j] -= factor * aug[col][j];
273                }
274            }
275        }
276        let mut y = vec![0.0f64; 9];
277        for i in 0..9 {
278            y[i] = aug[i][9];
279        }
280        let norm: f64 = y.iter().map(|v| v * v).sum::<f64>().sqrt();
281        if norm < 1e-15 {
282            return None;
283        }
284        for v in &mut y {
285            *v /= norm;
286        }
287        h_vec = y;
288    }
289
290    // De-normalise: H = Td_inv * Hn * Ts
291    let mut h_norm = HomographyMatrix::new([
292        h_vec[0], h_vec[1], h_vec[2], h_vec[3], h_vec[4], h_vec[5], h_vec[6], h_vec[7], h_vec[8],
293    ]);
294
295    // T_s normalisation matrix
296    let t_s = HomographyMatrix::new([
297        scale_s,
298        0.0,
299        -cx_s * scale_s,
300        0.0,
301        scale_s,
302        -cy_s * scale_s,
303        0.0,
304        0.0,
305        1.0,
306    ]);
307    let t_d_inv = HomographyMatrix::new([
308        1.0 / scale_d,
309        0.0,
310        cx_d,
311        0.0,
312        1.0 / scale_d,
313        cy_d,
314        0.0,
315        0.0,
316        1.0,
317    ]);
318
319    let mut result = t_d_inv.compose(&h_norm.compose(&t_s));
320    result.normalize();
321    let _ = &mut h_norm; // suppress unused warning
322    Some(result)
323}
324
325/// Compute the reprojection error for a set of correspondences given a homography.
326#[allow(clippy::cast_precision_loss)]
327pub fn reprojection_error(h: &HomographyMatrix, correspondences: &[PointCorrespondence]) -> f64 {
328    if correspondences.is_empty() {
329        return 0.0;
330    }
331    let total: f64 = correspondences
332        .iter()
333        .filter_map(|c| {
334            let warped = warp_point(h, &c.src)?;
335            let dx = warped.x - c.dst.x;
336            let dy = warped.y - c.dst.y;
337            Some((dx * dx + dy * dy).sqrt())
338        })
339        .sum();
340    total / correspondences.len() as f64
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// `true` when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Shorthand for building a correspondence from raw coordinates.
    fn pair(sx: f64, sy: f64, dx: f64, dy: f64) -> PointCorrespondence {
        PointCorrespondence {
            src: WarpPoint::new(sx, sy),
            dst: WarpPoint::new(dx, dy),
        }
    }

    /// Homography that translates by `(+3, -2)`.
    fn shift_3_minus_2() -> HomographyMatrix {
        HomographyMatrix::new([1.0, 0.0, 3.0, 0.0, 1.0, -2.0, 0.0, 0.0, 1.0])
    }

    /// Element-wise comparison of two matrices with tolerance `1e-12`.
    fn same_matrix(a: &HomographyMatrix, b: &HomographyMatrix) -> bool {
        a.data
            .iter()
            .zip(b.data.iter())
            .all(|(x, y)| close(*x, *y, 1e-12))
    }

    #[test]
    fn test_identity_creation() {
        let h = HomographyMatrix::identity();
        for k in 0..3 {
            assert!(close(h.get(k, k), 1.0, 1e-12));
        }
        assert!(close(h.get(0, 1), 0.0, 1e-12));
    }

    #[test]
    fn test_determinant_identity() {
        let det = HomographyMatrix::identity().determinant();
        assert!(close(det, 1.0, 1e-12));
    }

    #[test]
    fn test_inverse_identity() {
        let h = HomographyMatrix::identity();
        let inv = h.inverse().expect("inv should be valid");
        assert!(same_matrix(&h, &inv));
    }

    #[test]
    fn test_singular_matrix_no_inverse() {
        let zero = HomographyMatrix::new([0.0; 9]);
        assert!(zero.inverse().is_none());
    }

    #[test]
    fn test_warp_point_identity() {
        let warped = warp_point(&HomographyMatrix::identity(), &WarpPoint::new(5.0, 10.0))
            .expect("warped should be valid");
        assert!(close(warped.x, 5.0, 1e-12));
        assert!(close(warped.y, 10.0, 1e-12));
    }

    #[test]
    fn test_warp_point_translation() {
        let warped = warp_point(&shift_3_minus_2(), &WarpPoint::new(1.0, 1.0))
            .expect("warped should be valid");
        assert!(close(warped.x, 4.0, 1e-12));
        assert!(close(warped.y, -1.0, 1e-12));
    }

    #[test]
    fn test_inverse_warp_roundtrip() {
        let h = shift_3_minus_2();
        let start = WarpPoint::new(7.0, 3.0);
        let forward = warp_point(&h, &start).expect("warped should be valid");
        let back = inverse_warp_point(&h, &forward).expect("back should be valid");
        assert!(close(back.x, start.x, 1e-9));
        assert!(close(back.y, start.y, 1e-9));
    }

    #[test]
    fn test_compose_identity() {
        let h = HomographyMatrix::new([2.0, 0.0, 1.0, 0.0, 3.0, 2.0, 0.0, 0.0, 1.0]);
        let composed = h.compose(&HomographyMatrix::identity());
        assert!(same_matrix(&composed, &h));
    }

    #[test]
    fn test_normalize() {
        let mut h = HomographyMatrix::new([2.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 2.0]);
        h.normalize();
        let expected_diagonal = [1.0, 2.0, 1.0];
        for (k, want) in expected_diagonal.iter().enumerate() {
            assert!(close(h.get(k, k), *want, 1e-12));
        }
    }

    #[test]
    fn test_condition_number_identity() {
        let cond = HomographyMatrix::identity()
            .condition_number_approx()
            .expect("cond should be valid");
        // The Frobenius norm of I_3 is sqrt(3), so the estimate is sqrt(3)^2 = 3.
        assert!(close(cond, 3.0, 1e-9));
    }

    #[test]
    fn test_reprojection_error_perfect() {
        let corr = [pair(0.0, 0.0, 0.0, 0.0), pair(1.0, 0.0, 1.0, 0.0)];
        let err = reprojection_error(&HomographyMatrix::identity(), &corr);
        assert!(err < 1e-12);
    }

    #[test]
    fn test_reprojection_error_with_offset() {
        let h = HomographyMatrix::new([1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
        let corr = [pair(0.0, 0.0, 0.0, 0.0)];
        let err = reprojection_error(&h, &corr);
        assert!(close(err, 1.0, 1e-12));
    }

    #[test]
    fn test_dlt_insufficient_points() {
        let corrs = [pair(0.0, 0.0, 1.0, 1.0), pair(1.0, 0.0, 2.0, 1.0)];
        assert!(compute_homography_dlt(&corrs).is_none());
    }

    #[test]
    fn test_reprojection_error_empty() {
        let err = reprojection_error(&HomographyMatrix::identity(), &[]);
        assert!(close(err, 0.0, 1e-12));
    }
}