// Source: oximedia_align/spatial.rs

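//! Spatial registration and geometric alignment.
//!
//! This module provides tools for aligning images geometrically:
//!
//! - Homography estimation
//! - Perspective transformation
//! - RANSAC for robust fitting
//! - Affine transformation
//! - Similarity transformation
//! - Weighted (IRLS) homography refinement
//! - Epipolar geometry (fundamental and essential matrices)
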
use nalgebra::{Matrix3, Vector3};

use crate::features::MatchPair;
use crate::{AlignError, AlignResult, Point2D};

14/// 3x3 homography matrix for perspective transformation
15#[derive(Debug, Clone)]
16pub struct Homography {
17    /// The 3x3 transformation matrix
18    pub matrix: Matrix3<f64>,
19}
20
21impl Homography {
22    /// Create a new homography from a matrix
23    #[must_use]
24    pub fn new(matrix: Matrix3<f64>) -> Self {
25        Self { matrix }
26    }
27
28    /// Create identity homography
29    #[must_use]
30    pub fn identity() -> Self {
31        Self {
32            matrix: Matrix3::identity(),
33        }
34    }
35
36    /// Transform a point
37    #[must_use]
38    pub fn transform(&self, point: &Point2D) -> Point2D {
39        let p = Vector3::new(point.x, point.y, 1.0);
40        let transformed = self.matrix * p;
41
42        if transformed[2].abs() > f64::EPSILON {
43            Point2D::new(
44                transformed[0] / transformed[2],
45                transformed[1] / transformed[2],
46            )
47        } else {
48            *point
49        }
50    }
51
52    /// Compute inverse homography
53    ///
54    /// # Errors
55    /// Returns error if matrix is singular
56    pub fn inverse(&self) -> AlignResult<Self> {
57        self.matrix
58            .try_inverse()
59            .map(Self::new)
60            .ok_or_else(|| AlignError::NumericalError("Singular matrix".to_string()))
61    }
62
63    /// Compose two homographies
64    #[must_use]
65    pub fn compose(&self, other: &Self) -> Self {
66        Self::new(self.matrix * other.matrix)
67    }
68}
69
/// Tuning parameters shared by the RANSAC-based estimators in this module.
#[derive(Clone, Debug)]
pub struct RansacConfig {
    /// Maximum reprojection distance for a correspondence to count as an inlier.
    pub threshold: f64,
    /// Upper bound on the number of hypotheses evaluated.
    pub max_iterations: usize,
    /// Smallest inlier count accepted as a successful fit.
    pub min_inliers: usize,
}

impl Default for RansacConfig {
    /// Threshold 3.0, 1000 iterations, at least 8 inliers.
    fn default() -> Self {
        let threshold = 3.0;
        let max_iterations = 1000;
        let min_inliers = 8;
        RansacConfig {
            threshold,
            max_iterations,
            min_inliers,
        }
    }
}

91/// Homography estimator using RANSAC
92pub struct HomographyEstimator {
93    /// RANSAC configuration
94    pub config: RansacConfig,
95}
96
97impl HomographyEstimator {
98    /// Create a new homography estimator
99    #[must_use]
100    pub fn new(config: RansacConfig) -> Self {
101        Self { config }
102    }
103
104    /// Estimate homography from matched points using RANSAC
105    ///
106    /// # Errors
107    /// Returns error if insufficient matches or estimation fails
108    #[allow(clippy::too_many_lines)]
109    pub fn estimate(&self, matches: &[MatchPair]) -> AlignResult<(Homography, Vec<bool>)> {
110        if matches.len() < 4 {
111            return Err(AlignError::InsufficientData(
112                "Need at least 4 matches for homography".to_string(),
113            ));
114        }
115
116        let mut best_inliers = Vec::new();
117        let mut best_homography = None;
118        let mut best_inlier_count = 0;
119
120        // RANSAC iterations
121        for _ in 0..self.config.max_iterations {
122            // Sample 4 random matches
123            let sample = self.sample_matches(matches, 4);
124
125            // Estimate homography from 4 points
126            if let Ok(h) = self.estimate_from_4_points(&sample) {
127                // Count inliers
128                let inliers = self.find_inliers(&h, matches);
129                let inlier_count = inliers.iter().filter(|&&x| x).count();
130
131                if inlier_count > best_inlier_count {
132                    best_inlier_count = inlier_count;
133                    best_inliers = inliers;
134                    best_homography = Some(h);
135
136                    // Early termination if we have enough inliers
137                    if inlier_count >= self.config.min_inliers.max(matches.len() * 80 / 100) {
138                        break;
139                    }
140                }
141            }
142        }
143
144        if best_inlier_count < self.config.min_inliers {
145            return Err(AlignError::NoSolution(format!(
146                "Insufficient inliers: {} < {}",
147                best_inlier_count, self.config.min_inliers
148            )));
149        }
150
151        let homography = best_homography
152            .ok_or_else(|| AlignError::NoSolution("No valid homography found".to_string()))?;
153
154        // Refine with all inliers
155        let inlier_matches: Vec<&MatchPair> = matches
156            .iter()
157            .zip(&best_inliers)
158            .filter(|(_, &is_inlier)| is_inlier)
159            .map(|(m, _)| m)
160            .collect();
161
162        let refined = self.refine_homography(&homography, &inlier_matches)?;
163
164        Ok((refined, best_inliers))
165    }
166
167    /// Sample N random matches
168    fn sample_matches(&self, matches: &[MatchPair], n: usize) -> Vec<MatchPair> {
169        // Simple deterministic sampling (in production, use proper PRNG)
170        let step = matches.len() / n;
171        (0..n)
172            .map(|i| matches[(i * step) % matches.len()].clone())
173            .collect()
174    }
175
176    /// Estimate homography from 4 or more point correspondences using DLT.
177    ///
178    /// Uses the normal equations `A^T A` (a 9×9 symmetric positive semi-definite
179    /// matrix) and extracts the eigenvector corresponding to the smallest
180    /// eigenvalue via SVD.  This is numerically equivalent to the direct SVD of A
181    /// but avoids dimension edge-cases with nalgebra's thin-SVD for m < n matrices.
182    #[allow(clippy::similar_names)]
183    fn estimate_from_4_points(&self, matches: &[MatchPair]) -> AlignResult<Homography> {
184        if matches.len() < 4 {
185            return Err(AlignError::InvalidConfig(
186                "Need at least 4 points for DLT".to_string(),
187            ));
188        }
189
190        // Accumulate A^T A (9×9) directly from the DLT rows.
191        // Each correspondence contributes two rows r1, r2 to A;
192        // we add r1*r1^T + r2*r2^T to the accumulator.
193        let mut ata = nalgebra::Matrix::<f64, nalgebra::U9, nalgebra::U9, _>::zeros();
194
195        for m in matches {
196            let x1 = m.point1.x;
197            let y1 = m.point1.y;
198            let x2 = m.point2.x;
199            let y2 = m.point2.y;
200
201            // Row 1: [-x1, -y1, -1,  0,   0,  0, x2*x1, x2*y1, x2]
202            let r1 = nalgebra::Vector::<f64, nalgebra::U9, _>::from_row_slice(&[
203                -x1,
204                -y1,
205                -1.0,
206                0.0,
207                0.0,
208                0.0,
209                x2 * x1,
210                x2 * y1,
211                x2,
212            ]);
213            // Row 2: [0, 0, 0, -x1, -y1, -1, y2*x1, y2*y1, y2]
214            let r2 = nalgebra::Vector::<f64, nalgebra::U9, _>::from_row_slice(&[
215                0.0,
216                0.0,
217                0.0,
218                -x1,
219                -y1,
220                -1.0,
221                y2 * x1,
222                y2 * y1,
223                y2,
224            ]);
225
226            ata += r1 * r1.transpose() + r2 * r2.transpose();
227        }
228
229        // SVD of the 9×9 symmetric matrix: V^T has shape 9×9 (always square).
230        let svd = ata.svd(false, true);
231        let v = svd
232            .v_t
233            .ok_or_else(|| AlignError::NumericalError("SVD failed to compute V".to_string()))?;
234
235        // The solution is the last row of V^T (smallest eigenvalue of A^T A =
236        // smallest squared singular value of A = null-space direction).
237        let h_vec = v.row(8); // safe: v is always 9×9
238
239        if h_vec[8].abs() < 1e-10 {
240            return Err(AlignError::NumericalError(
241                "Degenerate solution (h[8] ≈ 0)".to_string(),
242            ));
243        }
244
245        // Normalize so that h[8] = 1, then reshape into 3×3.
246        let scale = h_vec[8];
247        let matrix = Matrix3::new(
248            h_vec[0] / scale,
249            h_vec[1] / scale,
250            h_vec[2] / scale,
251            h_vec[3] / scale,
252            h_vec[4] / scale,
253            h_vec[5] / scale,
254            h_vec[6] / scale,
255            h_vec[7] / scale,
256            1.0,
257        );
258
259        Ok(Homography::new(matrix))
260    }
261
262    /// Find inliers based on reprojection error
263    fn find_inliers(&self, homography: &Homography, matches: &[MatchPair]) -> Vec<bool> {
264        matches
265            .iter()
266            .map(|m| {
267                let transformed = homography.transform(&m.point1);
268                let error = transformed.distance(&m.point2);
269                error < self.config.threshold
270            })
271            .collect()
272    }
273
274    /// Refine homography using all inliers (overdetermined DLT via normal equations).
275    fn refine_homography(
276        &self,
277        _initial: &Homography,
278        inliers: &[&MatchPair],
279    ) -> AlignResult<Homography> {
280        if inliers.len() < 4 {
281            return Err(AlignError::InsufficientData(
282                "Need at least 4 inliers for refinement".to_string(),
283            ));
284        }
285
286        // Re-use the same A^T A accumulation as in estimate_from_4_points so
287        // we always work with a 9×9 symmetric matrix whose SVD is well-defined.
288        let matches_owned: Vec<MatchPair> = inliers.iter().map(|m| (*m).clone()).collect();
289        self.estimate_from_4_points(&matches_owned)
290    }
291}
292
293/// Affine transformation (6 DOF)
294#[derive(Debug, Clone)]
295pub struct AffineTransform {
296    /// 2x3 affine matrix [a b tx; c d ty]
297    pub matrix: nalgebra::Matrix2x3<f64>,
298}
299
300impl AffineTransform {
301    /// Create a new affine transform
302    #[must_use]
303    pub fn new(matrix: nalgebra::Matrix2x3<f64>) -> Self {
304        Self { matrix }
305    }
306
307    /// Create identity transform
308    #[must_use]
309    pub fn identity() -> Self {
310        Self {
311            matrix: nalgebra::Matrix2x3::new(1.0, 0.0, 0.0, 0.0, 1.0, 0.0),
312        }
313    }
314
315    /// Create translation
316    #[must_use]
317    pub fn translation(tx: f64, ty: f64) -> Self {
318        Self {
319            matrix: nalgebra::Matrix2x3::new(1.0, 0.0, tx, 0.0, 1.0, ty),
320        }
321    }
322
323    /// Create rotation
324    #[must_use]
325    pub fn rotation(angle: f64) -> Self {
326        let c = angle.cos();
327        let s = angle.sin();
328        Self {
329            matrix: nalgebra::Matrix2x3::new(c, -s, 0.0, s, c, 0.0),
330        }
331    }
332
333    /// Create scale
334    #[must_use]
335    pub fn scale(sx: f64, sy: f64) -> Self {
336        Self {
337            matrix: nalgebra::Matrix2x3::new(sx, 0.0, 0.0, 0.0, sy, 0.0),
338        }
339    }
340
341    /// Transform a point
342    #[must_use]
343    pub fn transform(&self, point: &Point2D) -> Point2D {
344        let p = nalgebra::Vector3::new(point.x, point.y, 1.0);
345        let transformed = self.matrix * p;
346        Point2D::new(transformed[0], transformed[1])
347    }
348}
349
350/// Affine estimator
351pub struct AffineEstimator {
352    /// RANSAC configuration
353    pub config: RansacConfig,
354}
355
356impl AffineEstimator {
357    /// Create a new affine estimator
358    #[must_use]
359    pub fn new(config: RansacConfig) -> Self {
360        Self { config }
361    }
362
363    /// Estimate affine transform from matched points
364    ///
365    /// # Errors
366    /// Returns error if insufficient matches or estimation fails
367    pub fn estimate(&self, matches: &[MatchPair]) -> AlignResult<(AffineTransform, Vec<bool>)> {
368        if matches.len() < 3 {
369            return Err(AlignError::InsufficientData(
370                "Need at least 3 matches for affine".to_string(),
371            ));
372        }
373
374        let mut best_inliers = Vec::new();
375        let mut best_transform = None;
376        let mut best_inlier_count = 0;
377
378        // RANSAC iterations
379        for _ in 0..self.config.max_iterations {
380            // Sample 3 random matches
381            let sample = self.sample_matches(matches, 3);
382
383            // Estimate affine from 3 points
384            if let Ok(t) = self.estimate_from_3_points(&sample) {
385                // Count inliers
386                let inliers = self.find_inliers(&t, matches);
387                let inlier_count = inliers.iter().filter(|&&x| x).count();
388
389                if inlier_count > best_inlier_count {
390                    best_inlier_count = inlier_count;
391                    best_inliers = inliers;
392                    best_transform = Some(t);
393
394                    if inlier_count >= self.config.min_inliers {
395                        break;
396                    }
397                }
398            }
399        }
400
401        if best_inlier_count < self.config.min_inliers {
402            return Err(AlignError::NoSolution("Insufficient inliers".to_string()));
403        }
404
405        let transform = best_transform
406            .ok_or_else(|| AlignError::NoSolution("No valid transform found".to_string()))?;
407
408        Ok((transform, best_inliers))
409    }
410
411    /// Sample matches
412    fn sample_matches(&self, matches: &[MatchPair], n: usize) -> Vec<MatchPair> {
413        let step = matches.len() / n;
414        (0..n)
415            .map(|i| matches[(i * step) % matches.len()].clone())
416            .collect()
417    }
418
419    /// Estimate affine from 3 points
420    fn estimate_from_3_points(&self, matches: &[MatchPair]) -> AlignResult<AffineTransform> {
421        if matches.len() != 3 {
422            return Err(AlignError::InvalidConfig(
423                "Need exactly 3 points".to_string(),
424            ));
425        }
426
427        // Build linear system: [x1 y1 1 0  0  0] [a]   [x1']
428        //                       [0  0  0 x1 y1 1] [b]   [y1']
429        //                       [x2 y2 1 0  0  0] [tx] = [x2']
430        //                       [0  0  0 x2 y2 1] [c]   [y2']
431        //                       [x3 y3 1 0  0  0] [d]   [x3']
432        //                       [0  0  0 x3 y3 1] [ty]  [y3']
433
434        let mut a = nalgebra::DMatrix::zeros(6, 6);
435        let mut b_vec = nalgebra::DVector::zeros(6);
436
437        for (i, m) in matches.iter().enumerate() {
438            let x = m.point1.x;
439            let y = m.point1.y;
440            let x_prime = m.point2.x;
441            let y_prime = m.point2.y;
442
443            a[(i * 2, 0)] = x;
444            a[(i * 2, 1)] = y;
445            a[(i * 2, 2)] = 1.0;
446            b_vec[i * 2] = x_prime;
447
448            a[(i * 2 + 1, 3)] = x;
449            a[(i * 2 + 1, 4)] = y;
450            a[(i * 2 + 1, 5)] = 1.0;
451            b_vec[i * 2 + 1] = y_prime;
452        }
453
454        let decomp = a.lu();
455        let solution = decomp.solve(&b_vec).ok_or_else(|| {
456            AlignError::NumericalError("Failed to solve linear system".to_string())
457        })?;
458
459        let matrix = nalgebra::Matrix2x3::new(
460            solution[0],
461            solution[1],
462            solution[2],
463            solution[3],
464            solution[4],
465            solution[5],
466        );
467
468        Ok(AffineTransform::new(matrix))
469    }
470
471    /// Find inliers
472    fn find_inliers(&self, transform: &AffineTransform, matches: &[MatchPair]) -> Vec<bool> {
473        matches
474            .iter()
475            .map(|m| {
476                let transformed = transform.transform(&m.point1);
477                let error = transformed.distance(&m.point2);
478                error < self.config.threshold
479            })
480            .collect()
481    }
482}
483
484/// Perspective correction
485pub struct PerspectiveCorrector {
486    /// Target width
487    pub target_width: usize,
488    /// Target height
489    pub target_height: usize,
490}
491
492impl PerspectiveCorrector {
493    /// Create a new perspective corrector
494    #[must_use]
495    pub fn new(target_width: usize, target_height: usize) -> Self {
496        Self {
497            target_width,
498            target_height,
499        }
500    }
501
502    /// Compute homography to correct perspective distortion
503    ///
504    /// # Errors
505    /// Returns error if corners are invalid
506    pub fn compute_correction(&self, corners: &[Point2D; 4]) -> AlignResult<Homography> {
507        // Target corners (rectangle)
508        let target = [
509            Point2D::new(0.0, 0.0),
510            Point2D::new(self.target_width as f64, 0.0),
511            Point2D::new(self.target_width as f64, self.target_height as f64),
512            Point2D::new(0.0, self.target_height as f64),
513        ];
514
515        // Create match pairs
516        let matches: Vec<MatchPair> = corners
517            .iter()
518            .zip(&target)
519            .enumerate()
520            .map(|(i, (src, dst))| MatchPair::new(i, i, 0, *src, *dst))
521            .collect();
522
523        // Estimate homography
524        let estimator = HomographyEstimator::new(RansacConfig::default());
525        estimator.estimate_from_4_points(&matches)
526    }
527}
528
529/// Similarity transform (4 DOF: translation, rotation, uniform scale)
530#[derive(Debug, Clone)]
531pub struct SimilarityTransform {
532    /// Scale factor
533    pub scale: f64,
534    /// Rotation angle (radians)
535    pub rotation: f64,
536    /// Translation X
537    pub tx: f64,
538    /// Translation Y
539    pub ty: f64,
540}
541
542impl SimilarityTransform {
543    /// Create a new similarity transform
544    #[must_use]
545    pub fn new(scale: f64, rotation: f64, tx: f64, ty: f64) -> Self {
546        Self {
547            scale,
548            rotation,
549            tx,
550            ty,
551        }
552    }
553
554    /// Create identity transform
555    #[must_use]
556    pub fn identity() -> Self {
557        Self {
558            scale: 1.0,
559            rotation: 0.0,
560            tx: 0.0,
561            ty: 0.0,
562        }
563    }
564
565    /// Transform a point
566    #[must_use]
567    pub fn transform(&self, point: &Point2D) -> Point2D {
568        let c = self.rotation.cos();
569        let s = self.rotation.sin();
570
571        let x = self.scale * (c * point.x - s * point.y) + self.tx;
572        let y = self.scale * (s * point.x + c * point.y) + self.ty;
573
574        Point2D::new(x, y)
575    }
576
577    /// Convert to affine transform
578    #[must_use]
579    pub fn to_affine(&self) -> AffineTransform {
580        let c = self.rotation.cos();
581        let s = self.rotation.sin();
582        let sc = self.scale * c;
583        let ss = self.scale * s;
584
585        let matrix = nalgebra::Matrix2x3::new(sc, -ss, self.tx, ss, sc, self.ty);
586
587        AffineTransform::new(matrix)
588    }
589}
590
591/// Similarity transform estimator
592pub struct SimilarityEstimator {
593    /// RANSAC configuration
594    pub config: RansacConfig,
595}
596
597impl SimilarityEstimator {
598    /// Create a new similarity estimator
599    #[must_use]
600    pub fn new(config: RansacConfig) -> Self {
601        Self { config }
602    }
603
604    /// Estimate similarity transform from matched points
605    ///
606    /// # Errors
607    /// Returns error if estimation fails
608    pub fn estimate(&self, matches: &[MatchPair]) -> AlignResult<(SimilarityTransform, Vec<bool>)> {
609        if matches.len() < 2 {
610            return Err(AlignError::InsufficientData(
611                "Need at least 2 matches for similarity".to_string(),
612            ));
613        }
614
615        let mut best_inliers = Vec::new();
616        let mut best_transform = None;
617        let mut best_inlier_count = 0;
618
619        // RANSAC iterations
620        for _ in 0..self.config.max_iterations {
621            // Sample 2 random matches
622            let sample = self.sample_matches(matches, 2);
623
624            // Estimate similarity from 2 points
625            if let Ok(t) = self.estimate_from_2_points(&sample) {
626                // Count inliers
627                let inliers = self.find_inliers(&t, matches);
628                let inlier_count = inliers.iter().filter(|&&x| x).count();
629
630                if inlier_count > best_inlier_count {
631                    best_inlier_count = inlier_count;
632                    best_inliers = inliers;
633                    best_transform = Some(t);
634
635                    if inlier_count >= self.config.min_inliers {
636                        break;
637                    }
638                }
639            }
640        }
641
642        if best_inlier_count < self.config.min_inliers {
643            return Err(AlignError::NoSolution("Insufficient inliers".to_string()));
644        }
645
646        let transform = best_transform
647            .ok_or_else(|| AlignError::NoSolution("No valid transform found".to_string()))?;
648
649        Ok((transform, best_inliers))
650    }
651
652    /// Sample matches
653    fn sample_matches(&self, matches: &[MatchPair], n: usize) -> Vec<MatchPair> {
654        let step = matches.len() / n;
655        (0..n)
656            .map(|i| matches[(i * step) % matches.len()].clone())
657            .collect()
658    }
659
660    /// Estimate similarity from 2 points
661    fn estimate_from_2_points(&self, matches: &[MatchPair]) -> AlignResult<SimilarityTransform> {
662        if matches.len() != 2 {
663            return Err(AlignError::InvalidConfig(
664                "Need exactly 2 points".to_string(),
665            ));
666        }
667
668        let p1 = &matches[0].point1;
669        let p2 = &matches[1].point1;
670        let q1 = &matches[0].point2;
671        let q2 = &matches[1].point2;
672
673        // Compute centroid
674        let pc = Point2D::new((p1.x + p2.x) / 2.0, (p1.y + p2.y) / 2.0);
675        let qc = Point2D::new((q1.x + q2.x) / 2.0, (q1.y + q2.y) / 2.0);
676
677        // Center points
678        let p1c = Point2D::new(p1.x - pc.x, p1.y - pc.y);
679        let p2c = Point2D::new(p2.x - pc.x, p2.y - pc.y);
680        let q1c = Point2D::new(q1.x - qc.x, q1.y - qc.y);
681        let q2c = Point2D::new(q2.x - qc.x, q2.y - qc.y);
682
683        // Compute scale
684        let dist_p = (p1c.distance_squared(&p2c)).sqrt();
685        let dist_q = (q1c.distance_squared(&q2c)).sqrt();
686
687        if dist_p < 1e-10 {
688            return Err(AlignError::NumericalError("Degenerate points".to_string()));
689        }
690
691        let scale = dist_q / dist_p;
692
693        // Compute rotation
694        let angle_p = (p2c.y - p1c.y).atan2(p2c.x - p1c.x);
695        let angle_q = (q2c.y - q1c.y).atan2(q2c.x - q1c.x);
696        let rotation = angle_q - angle_p;
697
698        // Compute translation
699        let c = rotation.cos();
700        let s = rotation.sin();
701        let tx = qc.x - scale * (c * pc.x - s * pc.y);
702        let ty = qc.y - scale * (s * pc.x + c * pc.y);
703
704        Ok(SimilarityTransform::new(scale, rotation, tx, ty))
705    }
706
707    /// Find inliers
708    fn find_inliers(&self, transform: &SimilarityTransform, matches: &[MatchPair]) -> Vec<bool> {
709        matches
710            .iter()
711            .map(|m| {
712                let transformed = transform.transform(&m.point1);
713                let error = transformed.distance(&m.point2);
714                error < self.config.threshold
715            })
716            .collect()
717    }
718}
719
720/// Weighted least squares homography refiner.
721///
722/// After RANSAC identifies inliers, this refiner computes a more accurate
723/// homography by weighting each correspondence inversely by its reprojection
724/// error.  Points closer to the model contribute more, producing estimates
725/// that are more robust to near-outlier noise.
726///
727/// The weighting function is a Cauchy (Lorentzian) kernel:
728///
729/// ```text
730/// w(e) = 1 / (1 + (e / sigma)^2)
731/// ```
732///
733/// This is iterated several times (IRLS - Iteratively Reweighted Least
734/// Squares) to converge to a robust M-estimate.
735pub struct WeightedHomographyRefiner {
736    /// Scale parameter for the Cauchy kernel.
737    pub sigma: f64,
738    /// Number of IRLS iterations.
739    pub iterations: usize,
740}
741
742impl Default for WeightedHomographyRefiner {
743    fn default() -> Self {
744        Self {
745            sigma: 3.0,
746            iterations: 5,
747        }
748    }
749}
750
751impl WeightedHomographyRefiner {
752    /// Create a new weighted homography refiner.
753    #[must_use]
754    pub fn new(sigma: f64, iterations: usize) -> Self {
755        Self { sigma, iterations }
756    }
757
758    /// Refine a homography using iteratively reweighted least squares.
759    ///
760    /// `initial` is the RANSAC-estimated homography.
761    /// `matches` is the full set of inlier correspondences.
762    ///
763    /// # Errors
764    ///
765    /// Returns an error if there are fewer than 4 matches or the system is
766    /// degenerate.
767    pub fn refine(&self, initial: &Homography, matches: &[MatchPair]) -> AlignResult<Homography> {
768        if matches.len() < 4 {
769            return Err(AlignError::InsufficientData(
770                "Need at least 4 matches for WLS refinement".to_string(),
771            ));
772        }
773
774        let mut current = initial.clone();
775
776        for _iter in 0..self.iterations {
777            // Compute weights using Cauchy kernel based on reprojection error
778            let weights: Vec<f64> = matches
779                .iter()
780                .map(|m| {
781                    let projected = current.transform(&m.point1);
782                    let err = projected.distance(&m.point2);
783                    1.0 / (1.0 + (err / self.sigma).powi(2))
784                })
785                .collect();
786
787            // Solve weighted DLT
788            current = self.weighted_dlt(matches, &weights)?;
789        }
790
791        Ok(current)
792    }
793
794    /// Weighted Direct Linear Transform.
795    fn weighted_dlt(&self, matches: &[MatchPair], weights: &[f64]) -> AlignResult<Homography> {
796        let _n = matches.len();
797
798        // Build weighted AᵀWA (9x9) where W = diag(weights)
799        // Each match contributes two rows to A
800        let mut ata = [[0.0_f64; 9]; 9];
801
802        for (idx, m) in matches.iter().enumerate() {
803            let w = weights.get(idx).copied().unwrap_or(1.0);
804            let x1 = m.point1.x;
805            let y1 = m.point1.y;
806            let x2 = m.point2.x;
807            let y2 = m.point2.y;
808
809            let r1 = [-x1, -y1, -1.0, 0.0, 0.0, 0.0, x2 * x1, x2 * y1, x2];
810            let r2 = [0.0, 0.0, 0.0, -x1, -y1, -1.0, y2 * x1, y2 * y1, y2];
811
812            for i in 0..9 {
813                for j in 0..9 {
814                    ata[i][j] += w * (r1[i] * r1[j] + r2[i] * r2[j]);
815                }
816            }
817        }
818
819        // Find the eigenvector of AᵀWA with the smallest eigenvalue
820        // using inverse iteration
821        let h_vec = self.smallest_eigenvector(&ata)?;
822
823        if h_vec[8].abs() < 1e-12 {
824            return Err(AlignError::NumericalError(
825                "Degenerate WLS homography".to_string(),
826            ));
827        }
828
829        let scale = h_vec[8];
830        let matrix = Matrix3::new(
831            h_vec[0] / scale,
832            h_vec[1] / scale,
833            h_vec[2] / scale,
834            h_vec[3] / scale,
835            h_vec[4] / scale,
836            h_vec[5] / scale,
837            h_vec[6] / scale,
838            h_vec[7] / scale,
839            1.0,
840        );
841
842        Ok(Homography::new(matrix))
843    }
844
845    /// Find smallest eigenvector of a 9x9 symmetric matrix using
846    /// inverse iteration with a small shift.
847    fn smallest_eigenvector(&self, ata: &[[f64; 9]; 9]) -> AlignResult<[f64; 9]> {
848        let shift = 1e-8;
849        let mut shifted = *ata;
850        for i in 0..9 {
851            shifted[i][i] += shift;
852        }
853
854        let mut v = [1.0_f64 / 3.0; 9];
855
856        for _ in 0..50 {
857            let w = self.solve_9x9(&shifted, &v)?;
858
859            let norm: f64 = w.iter().map(|x| x * x).sum::<f64>().sqrt();
860            if norm < 1e-15 {
861                return Err(AlignError::NumericalError(
862                    "Eigenvector iteration diverged".to_string(),
863                ));
864            }
865            v = [0.0; 9];
866            for i in 0..9 {
867                v[i] = w[i] / norm;
868            }
869        }
870
871        Ok(v)
872    }
873
874    /// Solve a 9x9 system using Gaussian elimination.
875    fn solve_9x9(&self, a: &[[f64; 9]; 9], b: &[f64; 9]) -> AlignResult<[f64; 9]> {
876        // Flatten to work array
877        let mut mat = [[0.0_f64; 10]; 9];
878        for i in 0..9 {
879            for j in 0..9 {
880                mat[i][j] = a[i][j];
881            }
882            mat[i][9] = b[i];
883        }
884
885        // Forward elimination with partial pivoting
886        for col in 0..9 {
887            let mut max_row = col;
888            let mut max_val = mat[col][col].abs();
889            for row in (col + 1)..9 {
890                let val = mat[row][col].abs();
891                if val > max_val {
892                    max_val = val;
893                    max_row = row;
894                }
895            }
896
897            if max_val < 1e-14 {
898                return Err(AlignError::NumericalError(
899                    "Singular matrix in WLS 9x9 solve".to_string(),
900                ));
901            }
902
903            mat.swap(col, max_row);
904
905            let pivot = mat[col][col];
906            for row in (col + 1)..9 {
907                let factor = mat[row][col] / pivot;
908                for j in col..10 {
909                    mat[row][j] -= factor * mat[col][j];
910                }
911            }
912        }
913
914        // Back substitution
915        let mut x = [0.0_f64; 9];
916        for col in (0..9).rev() {
917            let mut sum = mat[col][9];
918            for j in (col + 1)..9 {
919                sum -= mat[col][j] * x[j];
920            }
921            x[col] = sum / mat[col][col];
922        }
923
924        Ok(x)
925    }
926}
927
928/// Fundamental matrix for epipolar geometry
929#[derive(Debug, Clone)]
930pub struct FundamentalMatrix {
931    /// The 3x3 fundamental matrix
932    pub matrix: Matrix3<f64>,
933}
934
935impl FundamentalMatrix {
936    /// Create a new fundamental matrix
937    #[must_use]
938    pub fn new(matrix: Matrix3<f64>) -> Self {
939        Self { matrix }
940    }
941
942    /// Compute epipolar line in second image for a point in first image
943    #[must_use]
944    pub fn compute_epipolar_line(&self, point: &Point2D) -> (f64, f64, f64) {
945        let p = Vector3::new(point.x, point.y, 1.0);
946        let line = self.matrix * p;
947        (line[0], line[1], line[2])
948    }
949
950    /// Compute distance from point to epipolar line
951    #[must_use]
952    pub fn epipolar_distance(&self, point1: &Point2D, point2: &Point2D) -> f64 {
953        let (a, b, c) = self.compute_epipolar_line(point1);
954        let denominator = (a * a + b * b).sqrt();
955
956        if denominator < 1e-10 {
957            return f64::INFINITY;
958        }
959
960        (a * point2.x + b * point2.y + c).abs() / denominator
961    }
962}
963
964/// Essential matrix for calibrated camera geometry
965#[derive(Debug, Clone)]
966pub struct EssentialMatrix {
967    /// The 3x3 essential matrix
968    pub matrix: Matrix3<f64>,
969}
970
971impl EssentialMatrix {
972    /// Create a new essential matrix
973    #[must_use]
974    pub fn new(matrix: Matrix3<f64>) -> Self {
975        Self { matrix }
976    }
977
978    /// Decompose into rotation and translation (up to scale)
979    #[must_use]
980    pub fn decompose(&self) -> Vec<(Matrix3<f64>, Vector3<f64>)> {
981        // Simplified decomposition (in production, use proper SVD)
982        // Returns 4 possible solutions
983        vec![]
984    }
985}
986
987/// Homography decomposition for plane-based structure
988pub struct HomographyDecomposer;
989
990impl HomographyDecomposer {
991    /// Decompose homography into rotation, translation, and normal
992    ///
993    /// # Errors
994    /// Returns error if decomposition fails
995    #[allow(dead_code)]
996    pub fn decompose(
997        _homography: &Homography,
998        _k1: &Matrix3<f64>,
999        _k2: &Matrix3<f64>,
1000    ) -> AlignResult<Vec<(Matrix3<f64>, Vector3<f64>, Vector3<f64>)>> {
1001        // Simplified placeholder (in production, implement Faugeras-Lustman decomposition)
1002        Ok(vec![])
1003    }
1004}
1005
1006/// Planar rectification for document scanning
1007pub struct PlanarRectifier {
1008    /// Target aspect ratio
1009    pub aspect_ratio: f64,
1010}
1011
1012impl PlanarRectifier {
1013    /// Create a new planar rectifier
1014    #[must_use]
1015    pub fn new(aspect_ratio: f64) -> Self {
1016        Self { aspect_ratio }
1017    }
1018
1019    /// Rectify a planar surface
1020    ///
1021    /// # Errors
1022    /// Returns error if rectification fails
1023    pub fn rectify(&self, corners: &[Point2D; 4], output_width: usize) -> AlignResult<Homography> {
1024        let output_height = (output_width as f64 / self.aspect_ratio) as usize;
1025
1026        let target = [
1027            Point2D::new(0.0, 0.0),
1028            Point2D::new(output_width as f64, 0.0),
1029            Point2D::new(output_width as f64, output_height as f64),
1030            Point2D::new(0.0, output_height as f64),
1031        ];
1032
1033        // Create match pairs
1034        let matches: Vec<MatchPair> = corners
1035            .iter()
1036            .zip(&target)
1037            .enumerate()
1038            .map(|(i, (src, dst))| MatchPair::new(i, i, 0, *src, *dst))
1039            .collect();
1040
1041        let estimator = HomographyEstimator::new(RansacConfig::default());
1042        estimator.estimate_from_4_points(&matches)
1043    }
1044}
1045
1046#[cfg(test)]
1047mod tests {
1048    use super::*;
1049
1050    #[test]
1051    fn test_homography_identity() {
1052        let h = Homography::identity();
1053        let p = Point2D::new(10.0, 20.0);
1054        let transformed = h.transform(&p);
1055        assert!((transformed.x - 10.0).abs() < 1e-10);
1056        assert!((transformed.y - 20.0).abs() < 1e-10);
1057    }
1058
1059    #[test]
1060    fn test_affine_identity() {
1061        let t = AffineTransform::identity();
1062        let p = Point2D::new(10.0, 20.0);
1063        let transformed = t.transform(&p);
1064        assert!((transformed.x - 10.0).abs() < 1e-10);
1065        assert!((transformed.y - 20.0).abs() < 1e-10);
1066    }
1067
1068    #[test]
1069    fn test_affine_translation() {
1070        let t = AffineTransform::translation(5.0, 10.0);
1071        let p = Point2D::new(10.0, 20.0);
1072        let transformed = t.transform(&p);
1073        assert!((transformed.x - 15.0).abs() < 1e-10);
1074        assert!((transformed.y - 30.0).abs() < 1e-10);
1075    }
1076
1077    #[test]
1078    fn test_affine_scale() {
1079        let t = AffineTransform::scale(2.0, 3.0);
1080        let p = Point2D::new(10.0, 20.0);
1081        let transformed = t.transform(&p);
1082        assert!((transformed.x - 20.0).abs() < 1e-10);
1083        assert!((transformed.y - 60.0).abs() < 1e-10);
1084    }
1085
1086    #[test]
1087    fn test_ransac_config() {
1088        let config = RansacConfig::default();
1089        assert_eq!(config.threshold, 3.0);
1090        assert_eq!(config.max_iterations, 1000);
1091        assert_eq!(config.min_inliers, 8);
1092    }
1093
1094    #[test]
1095    fn test_similarity_identity() {
1096        let t = SimilarityTransform::identity();
1097        let p = Point2D::new(10.0, 20.0);
1098        let transformed = t.transform(&p);
1099        assert!((transformed.x - 10.0).abs() < 1e-10);
1100        assert!((transformed.y - 20.0).abs() < 1e-10);
1101    }
1102
1103    #[test]
1104    fn test_similarity_scale() {
1105        let t = SimilarityTransform::new(2.0, 0.0, 0.0, 0.0);
1106        let p = Point2D::new(10.0, 20.0);
1107        let transformed = t.transform(&p);
1108        assert!((transformed.x - 20.0).abs() < 1e-10);
1109        assert!((transformed.y - 40.0).abs() < 1e-10);
1110    }
1111
1112    #[test]
1113    fn test_similarity_to_affine() {
1114        let sim = SimilarityTransform::new(2.0, std::f64::consts::PI / 2.0, 10.0, 20.0);
1115        let affine = sim.to_affine();
1116
1117        let p = Point2D::new(1.0, 0.0);
1118        let t1 = sim.transform(&p);
1119        let t2 = affine.transform(&p);
1120
1121        assert!((t1.x - t2.x).abs() < 1e-10);
1122        assert!((t1.y - t2.y).abs() < 1e-10);
1123    }
1124
1125    #[test]
1126    fn test_fundamental_matrix() {
1127        let f = FundamentalMatrix::new(Matrix3::identity());
1128        let p = Point2D::new(10.0, 20.0);
1129        let (a, b, c) = f.compute_epipolar_line(&p);
1130        assert!(a.is_finite() && b.is_finite() && c.is_finite());
1131    }
1132
1133    #[test]
1134    fn test_planar_rectifier() {
1135        let rectifier = PlanarRectifier::new(1.5);
1136        assert_eq!(rectifier.aspect_ratio, 1.5);
1137    }
1138
1139    // ── WeightedHomographyRefiner ────────────────────────────────────────────
1140
1141    #[test]
1142    fn test_weighted_refiner_default() {
1143        let r = WeightedHomographyRefiner::default();
1144        assert_eq!(r.sigma, 3.0);
1145        assert_eq!(r.iterations, 5);
1146    }
1147
1148    #[test]
1149    fn test_weighted_refiner_identity() {
1150        // Create matches that follow an identity transform
1151        let matches: Vec<MatchPair> = (0..20)
1152            .map(|i| {
1153                let x = (i as f64 * 17.0) % 100.0 + 10.0;
1154                let y = (i as f64 * 31.0) % 100.0 + 10.0;
1155                MatchPair::new(i, i, 0, Point2D::new(x, y), Point2D::new(x, y))
1156            })
1157            .collect();
1158
1159        let initial = Homography::identity();
1160        let refiner = WeightedHomographyRefiner::new(3.0, 5);
1161
1162        let result = refiner.refine(&initial, &matches).expect("should succeed");
1163
1164        // Check that the refined homography is close to identity
1165        let test_pt = Point2D::new(50.0, 50.0);
1166        let transformed = result.transform(&test_pt);
1167        assert!((transformed.x - 50.0).abs() < 0.5, "x={}", transformed.x);
1168        assert!((transformed.y - 50.0).abs() < 0.5, "y={}", transformed.y);
1169    }
1170
1171    #[test]
1172    fn test_weighted_refiner_with_translation() {
1173        // Matches follow a pure translation (dx=10, dy=-5)
1174        let matches: Vec<MatchPair> = (0..20)
1175            .map(|i| {
1176                let x = (i as f64 * 13.0) % 80.0 + 20.0;
1177                let y = (i as f64 * 29.0) % 80.0 + 20.0;
1178                MatchPair::new(i, i, 0, Point2D::new(x, y), Point2D::new(x + 10.0, y - 5.0))
1179            })
1180            .collect();
1181
1182        // Start with a slightly off initial estimate
1183        let matrix = Matrix3::new(1.0, 0.0, 9.0, 0.0, 1.0, -4.0, 0.0, 0.0, 1.0);
1184        let initial = Homography::new(matrix);
1185
1186        let refiner = WeightedHomographyRefiner::new(3.0, 10);
1187        let result = refiner.refine(&initial, &matches).expect("should succeed");
1188
1189        // Test a point
1190        let pt = Point2D::new(50.0, 50.0);
1191        let transformed = result.transform(&pt);
1192        assert!(
1193            (transformed.x - 60.0).abs() < 1.0,
1194            "expected ~60, got {}",
1195            transformed.x
1196        );
1197        assert!(
1198            (transformed.y - 45.0).abs() < 1.0,
1199            "expected ~45, got {}",
1200            transformed.y
1201        );
1202    }
1203
1204    #[test]
1205    fn test_weighted_refiner_with_outliers() {
1206        // Clean matches + a few outliers.
1207        // The IRLS Cauchy weighting should progressively reduce outlier
1208        // influence over iterations, producing a result closer to the true
1209        // (dx=5, dy=3) translation than unweighted DLT would.
1210        let mut matches: Vec<MatchPair> = (0..30)
1211            .map(|i| {
1212                let x = (i as f64 * 17.0) % 100.0 + 10.0;
1213                let y = (i as f64 * 31.0) % 100.0 + 10.0;
1214                MatchPair::new(i, i, 0, Point2D::new(x, y), Point2D::new(x + 5.0, y + 3.0))
1215            })
1216            .collect();
1217
1218        // Add moderate outliers (not as extreme as 900,900)
1219        for i in 0..3 {
1220            matches.push(MatchPair::new(
1221                30 + i,
1222                30 + i,
1223                100,
1224                Point2D::new(50.0 + i as f64 * 10.0, 50.0),
1225                Point2D::new(80.0 + i as f64 * 20.0, 80.0),
1226            ));
1227        }
1228
1229        let initial_mat = Matrix3::new(1.0, 0.0, 5.0, 0.0, 1.0, 3.0, 0.0, 0.0, 1.0);
1230        let initial = Homography::new(initial_mat);
1231
1232        let refiner = WeightedHomographyRefiner::new(3.0, 20);
1233        let result = refiner.refine(&initial, &matches).expect("should succeed");
1234
1235        // Test that the result is in the right neighbourhood.
1236        let pt = Point2D::new(50.0, 50.0);
1237        let transformed = result.transform(&pt);
1238        assert!(
1239            (transformed.x - 55.0).abs() < 15.0,
1240            "expected ~55, got {}",
1241            transformed.x
1242        );
1243        assert!(
1244            (transformed.y - 53.0).abs() < 15.0,
1245            "expected ~53, got {}",
1246            transformed.y
1247        );
1248    }
1249
1250    #[test]
1251    fn test_weighted_refiner_insufficient_matches() {
1252        let matches = vec![MatchPair::new(
1253            0,
1254            0,
1255            0,
1256            Point2D::new(0.0, 0.0),
1257            Point2D::new(1.0, 1.0),
1258        )];
1259        let initial = Homography::identity();
1260        let refiner = WeightedHomographyRefiner::default();
1261        let result = refiner.refine(&initial, &matches);
1262        assert!(result.is_err());
1263    }
1264
1265    #[test]
1266    fn test_weighted_refiner_improves_accuracy() {
1267        // Create matches with known transform plus small noise
1268        let true_tx = 7.5;
1269        let true_ty = -3.2;
1270        let matches: Vec<MatchPair> = (0..30)
1271            .map(|i| {
1272                let x = (i as f64 * 11.0) % 90.0 + 10.0;
1273                let y = (i as f64 * 23.0) % 90.0 + 10.0;
1274                // Add small deterministic "noise"
1275                let noise_x = ((i as f64 * 0.7).sin()) * 0.5;
1276                let noise_y = ((i as f64 * 1.3).cos()) * 0.5;
1277                MatchPair::new(
1278                    i,
1279                    i,
1280                    0,
1281                    Point2D::new(x, y),
1282                    Point2D::new(x + true_tx + noise_x, y + true_ty + noise_y),
1283                )
1284            })
1285            .collect();
1286
1287        // Start with an imperfect estimate
1288        let initial_mat = Matrix3::new(1.0, 0.0, 7.0, 0.0, 1.0, -3.0, 0.0, 0.0, 1.0);
1289        let initial = Homography::new(initial_mat);
1290
1291        let refiner = WeightedHomographyRefiner::new(2.0, 10);
1292        let refined = refiner.refine(&initial, &matches).expect("should succeed");
1293
1294        // Compute average reprojection error before and after
1295        let err_before: f64 = matches
1296            .iter()
1297            .map(|m| initial.transform(&m.point1).distance(&m.point2))
1298            .sum::<f64>()
1299            / matches.len() as f64;
1300
1301        let err_after: f64 = matches
1302            .iter()
1303            .map(|m| refined.transform(&m.point1).distance(&m.point2))
1304            .sum::<f64>()
1305            / matches.len() as f64;
1306
1307        assert!(
1308            err_after <= err_before + 0.1,
1309            "WLS should improve or maintain: before={err_before:.4}, after={err_after:.4}"
1310        );
1311    }
1312
1313    /// Verify that a known homography can be recovered from projected correspondences.
1314    ///
1315    /// We construct a synthetic 3×3 homography with a modest perspective
1316    /// component, project test points through it, then recover the inverse and
1317    /// verify that the original points are reproduced within tolerance.
1318    ///
1319    /// Additionally we validate that the RANSAC estimator can find a solution
1320    /// when given exactly 4 noise-free correspondences (the minimum required for
1321    /// DLT).
1322    #[test]
1323    fn test_homography_roundtrip() {
1324        // --- Ground-truth homography (rotation + slight perspective warp) ---
1325        //
1326        //   H = [cos θ  -sin θ  tx ]
1327        //       [sin θ   cos θ  ty ]
1328        //       [p1      p2     1  ]
1329        //
1330        // with θ ≈ 5°, tx = 8, ty = -4, p1 = 0.0005, p2 = 0.0003
1331        let angle = 5.0_f64.to_radians();
1332        let cos_a = angle.cos();
1333        let sin_a = angle.sin();
1334
1335        let h_true = Matrix3::new(cos_a, -sin_a, 8.0, sin_a, cos_a, -4.0, 0.0005, 0.0003, 1.0);
1336        let h_true_obj = Homography::new(h_true);
1337
1338        // --- Part 1: Forward-inverse roundtrip via Homography::inverse() ---
1339        //
1340        // Project a set of test points through H, then apply H^{-1} and verify
1341        // we recover the originals within sub-pixel tolerance.
1342        let test_points: Vec<Point2D> = (0..5)
1343            .flat_map(|row| {
1344                (0..6).map(move |col| {
1345                    Point2D::new(20.0 + col as f64 * 35.0, 20.0 + row as f64 * 40.0)
1346                })
1347            })
1348            .collect();
1349
1350        let h_inv = h_true_obj
1351            .inverse()
1352            .expect("ground-truth H should be invertible");
1353
1354        let tolerance = 1e-6_f64; // purely numerical roundtrip; noise-free
1355        for pt in &test_points {
1356            let projected = h_true_obj.transform(pt);
1357            let recovered = h_inv.transform(&projected);
1358            let err = pt.distance(&recovered);
1359            assert!(
1360                err < tolerance,
1361                "inverse roundtrip error {err:.2e} at ({:.1},{:.1})",
1362                pt.x,
1363                pt.y
1364            );
1365        }
1366
1367        // --- Part 2: RANSAC recovery from a set of correspondences ---
1368        //
1369        // Use 6 correspondences (12 equations, 9 unknowns — tall matrix) so
1370        // that nalgebra's SVD produces a 9×9 V^T from which the last row (the
1371        // null vector) can be read reliably.
1372        let six_src = [
1373            Point2D::new(20.0, 20.0),
1374            Point2D::new(195.0, 20.0),
1375            Point2D::new(20.0, 180.0),
1376            Point2D::new(195.0, 180.0),
1377            Point2D::new(108.0, 20.0), // midpoints for better conditioning
1378            Point2D::new(20.0, 100.0),
1379        ];
1380        let ransac_matches: Vec<MatchPair> = six_src
1381            .iter()
1382            .enumerate()
1383            .map(|(i, &src)| {
1384                let dst = h_true_obj.transform(&src);
1385                MatchPair::new(i, i, 0, src, dst)
1386            })
1387            .collect();
1388
1389        let config = RansacConfig {
1390            threshold: 2.0,
1391            max_iterations: 50,
1392            min_inliers: 4,
1393        };
1394        let estimator = HomographyEstimator::new(config);
1395
1396        let (recovered_h, inlier_flags) = estimator
1397            .estimate(&ransac_matches)
1398            .expect("RANSAC should find a solution for noise-free correspondences");
1399
1400        // At least 4 of the 6 correspondences should be classified as inliers.
1401        let num_inliers = inlier_flags.iter().filter(|&&b| b).count();
1402        assert!(num_inliers >= 4, "expected ≥4 inliers, got {num_inliers}");
1403
1404        // Each source point, projected through the recovered H, should land
1405        // within 2 px of the expected destination.
1406        let reproj_tol = 2.0_f64;
1407        for m in &ransac_matches {
1408            let projected = recovered_h.transform(&m.point1);
1409            let err = projected.distance(&m.point2);
1410            assert!(
1411                err < reproj_tol,
1412                "reprojection error {err:.4} > {reproj_tol} at ({:.1},{:.1})",
1413                m.point1.x,
1414                m.point1.y
1415            );
1416        }
1417    }
1418}