// oximedia_align/spatial.rs

1//! Spatial registration and geometric alignment.
2//!
3//! This module provides tools for aligning images geometrically:
4//!
5//! - Homography estimation (serial and parallel RANSAC)
6//! - Perspective transformation
7//! - RANSAC for robust fitting
8//! - Affine transformation
9
10use crate::features::MatchPair;
11use crate::{AlignError, AlignResult, Point2D};
12use nalgebra::{Matrix3, Vector3};
13use rayon::prelude::*;
14use std::sync::atomic::{AtomicUsize, Ordering};
15
16/// 3x3 homography matrix for perspective transformation
17#[derive(Debug, Clone)]
18pub struct Homography {
19    /// The 3x3 transformation matrix
20    pub matrix: Matrix3<f64>,
21}
22
23impl Homography {
24    /// Create a new homography from a matrix
25    #[must_use]
26    pub fn new(matrix: Matrix3<f64>) -> Self {
27        Self { matrix }
28    }
29
30    /// Create identity homography
31    #[must_use]
32    pub fn identity() -> Self {
33        Self {
34            matrix: Matrix3::identity(),
35        }
36    }
37
38    /// Transform a point
39    #[must_use]
40    pub fn transform(&self, point: &Point2D) -> Point2D {
41        let p = Vector3::new(point.x, point.y, 1.0);
42        let transformed = self.matrix * p;
43
44        if transformed[2].abs() > f64::EPSILON {
45            Point2D::new(
46                transformed[0] / transformed[2],
47                transformed[1] / transformed[2],
48            )
49        } else {
50            *point
51        }
52    }
53
54    /// Compute inverse homography
55    ///
56    /// # Errors
57    /// Returns error if matrix is singular
58    pub fn inverse(&self) -> AlignResult<Self> {
59        self.matrix
60            .try_inverse()
61            .map(Self::new)
62            .ok_or_else(|| AlignError::NumericalError("Singular matrix".to_string()))
63    }
64
65    /// Compose two homographies
66    #[must_use]
67    pub fn compose(&self, other: &Self) -> Self {
68        Self::new(self.matrix * other.matrix)
69    }
70}
71
/// Tunable parameters shared by the RANSAC-based estimators in this module.
#[derive(Debug, Clone)]
pub struct RansacConfig {
    /// A match counts as an inlier when its reprojection distance is strictly
    /// below this value.
    pub threshold: f64,
    /// Upper bound on the number of RANSAC iterations.
    pub max_iterations: usize,
    /// Smallest inlier count for which an estimate is accepted.
    pub min_inliers: usize,
}

impl Default for RansacConfig {
    /// Defaults: threshold `3.0`, `1000` iterations, at least `8` inliers.
    fn default() -> Self {
        let threshold = 3.0;
        let max_iterations = 1000;
        let min_inliers = 8;

        RansacConfig {
            threshold,
            max_iterations,
            min_inliers,
        }
    }
}
92
/// Homography estimator using RANSAC.
///
/// Holds the [`RansacConfig`] read by the estimation methods: the inlier
/// distance threshold, the iteration budget and the minimum accepted inlier
/// count.
pub struct HomographyEstimator {
    /// RANSAC configuration
    pub config: RansacConfig,
}
98
99impl HomographyEstimator {
100    /// Create a new homography estimator
101    #[must_use]
102    pub fn new(config: RansacConfig) -> Self {
103        Self { config }
104    }
105
106    /// Estimate homography from matched points using parallel RANSAC.
107    ///
108    /// Divides `max_iterations` into batches of `BATCH_SIZE = 64`. Each batch
109    /// runs in parallel via `rayon`, with a shared `AtomicUsize` tracking the
110    /// global best inlier count for advisory early exit.
111    ///
112    /// Each parallel iteration uses a deterministic per-iteration seeded RNG
113    /// (xorshift64), so results are reproducible for a given configuration
114    /// though not necessarily identical to the serial path (they may be
115    /// equal-or-better, which is the desired property).
116    ///
117    /// # Errors
118    /// Returns error if insufficient matches or estimation fails
119    #[allow(clippy::too_many_lines)]
120    pub fn estimate(&self, matches: &[MatchPair]) -> AlignResult<(Homography, Vec<bool>)> {
121        if matches.len() < 4 {
122            return Err(AlignError::InsufficientData(
123                "Need at least 4 matches for homography".to_string(),
124            ));
125        }
126
127        const BATCH_SIZE: usize = 64;
128        let early_exit_threshold = self.config.min_inliers.max(matches.len() * 80 / 100);
129        let num_batches = self.config.max_iterations.div_ceil(BATCH_SIZE);
130
131        // Shared advisory best-inlier counter for early-exit across batches.
132        let global_best_count = AtomicUsize::new(0);
133
134        let mut best_inliers: Vec<bool> = Vec::new();
135        let mut best_homography: Option<Homography> = None;
136        let mut best_inlier_count = 0usize;
137
138        'outer: for batch_idx in 0..num_batches {
139            // Advisory soft early exit — load with Relaxed; this is a hint only.
140            if global_best_count.load(Ordering::Relaxed) >= early_exit_threshold {
141                break 'outer;
142            }
143
144            let start_iter = batch_idx * BATCH_SIZE;
145            let end_iter = (start_iter + BATCH_SIZE).min(self.config.max_iterations);
146            let batch_len = end_iter - start_iter;
147
148            // Run each iteration in the batch in parallel.
149            // Each thread gets a deterministic seed from the global iteration index.
150            let batch_results: Vec<Option<(Homography, Vec<bool>, usize)>> = (0..batch_len)
151                .into_par_iter()
152                .map(|iter_in_batch| {
153                    // Advisory soft exit — Relaxed is fine; correctness not at stake.
154                    if global_best_count.load(Ordering::Relaxed) >= early_exit_threshold {
155                        return None;
156                    }
157
158                    let global_iter = start_iter + iter_in_batch;
159                    // Seeded deterministic sampling per iteration using xorshift64.
160                    let sample = self.sample_matches_seeded(matches, 4, global_iter as u64);
161
162                    let h = self.estimate_from_4_points(&sample).ok()?;
163                    let inliers = self.find_inliers(&h, matches);
164                    let inlier_count = inliers.iter().filter(|&&x| x).count();
165
166                    // Update global advisory counter (Relaxed: may be stale, that is fine).
167                    let prev = global_best_count.load(Ordering::Relaxed);
168                    if inlier_count > prev {
169                        global_best_count.fetch_max(inlier_count, Ordering::Relaxed);
170                    }
171
172                    Some((h, inliers, inlier_count))
173                })
174                .collect();
175
176            // Reduce: find the best result in this batch and merge with global best.
177            for result in batch_results.into_iter().flatten() {
178                let (h, inliers, inlier_count) = result;
179                if inlier_count > best_inlier_count {
180                    best_inlier_count = inlier_count;
181                    best_inliers = inliers;
182                    best_homography = Some(h);
183                }
184            }
185
186            // Hard post-batch early exit check.
187            if best_inlier_count >= early_exit_threshold {
188                break 'outer;
189            }
190        }
191
192        if best_inlier_count < self.config.min_inliers {
193            return Err(AlignError::NoSolution(format!(
194                "Insufficient inliers: {} < {}",
195                best_inlier_count, self.config.min_inliers
196            )));
197        }
198
199        let homography = best_homography
200            .ok_or_else(|| AlignError::NoSolution("No valid homography found".to_string()))?;
201
202        // Refine with all inliers
203        let inlier_matches: Vec<&MatchPair> = matches
204            .iter()
205            .zip(&best_inliers)
206            .filter(|(_, &is_inlier)| is_inlier)
207            .map(|(m, _)| m)
208            .collect();
209
210        let refined = self.refine_homography(&homography, &inlier_matches)?;
211
212        Ok((refined, best_inliers))
213    }
214
215    /// Estimate homography using the original sequential RANSAC.
216    ///
217    /// This is kept as a reference implementation for tests and benchmarks.
218    /// Production callers should prefer `estimate` (parallel version).
219    ///
220    /// # Errors
221    /// Returns error if insufficient matches or estimation fails
222    pub fn estimate_serial(&self, matches: &[MatchPair]) -> AlignResult<(Homography, Vec<bool>)> {
223        if matches.len() < 4 {
224            return Err(AlignError::InsufficientData(
225                "Need at least 4 matches for homography".to_string(),
226            ));
227        }
228
229        let mut best_inliers = Vec::new();
230        let mut best_homography = None;
231        let mut best_inlier_count = 0;
232        let early_exit_threshold = self.config.min_inliers.max(matches.len() * 80 / 100);
233
234        for iter in 0..self.config.max_iterations {
235            let sample = self.sample_matches_seeded(matches, 4, iter as u64);
236
237            if let Ok(h) = self.estimate_from_4_points(&sample) {
238                let inliers = self.find_inliers(&h, matches);
239                let inlier_count = inliers.iter().filter(|&&x| x).count();
240
241                if inlier_count > best_inlier_count {
242                    best_inlier_count = inlier_count;
243                    best_inliers = inliers;
244                    best_homography = Some(h);
245
246                    if inlier_count >= early_exit_threshold {
247                        break;
248                    }
249                }
250            }
251        }
252
253        if best_inlier_count < self.config.min_inliers {
254            return Err(AlignError::NoSolution(format!(
255                "Insufficient inliers: {} < {}",
256                best_inlier_count, self.config.min_inliers
257            )));
258        }
259
260        let homography = best_homography
261            .ok_or_else(|| AlignError::NoSolution("No valid homography found".to_string()))?;
262
263        let inlier_matches: Vec<&MatchPair> = matches
264            .iter()
265            .zip(&best_inliers)
266            .filter(|(_, &is_inlier)| is_inlier)
267            .map(|(m, _)| m)
268            .collect();
269
270        let refined = self.refine_homography(&homography, &inlier_matches)?;
271
272        Ok((refined, best_inliers))
273    }
274
275    /// Sample N random matches using a deterministic xorshift64 PRNG seeded by
276    /// `seed`. Produces a distinct sample for each unique seed value, enabling
277    /// parallel iterations without shared state.
278    fn sample_matches_seeded(&self, matches: &[MatchPair], n: usize, seed: u64) -> Vec<MatchPair> {
279        let pool = matches.len();
280        debug_assert!(pool >= n, "cannot sample {n} from {pool} matches");
281
282        // xorshift64 — fast, non-crypto, reproducible per seed.
283        let mut state = seed ^ 0x9e37_79b9_7f4a_7c15;
284        if state == 0 {
285            state = 0xdead_beef_cafe_babe;
286        }
287
288        let xorshift64 = |s: &mut u64| -> u64 {
289            *s ^= *s << 13;
290            *s ^= *s >> 7;
291            *s ^= *s << 17;
292            *s
293        };
294
295        let mut indices = Vec::with_capacity(n);
296        while indices.len() < n {
297            let r = xorshift64(&mut state);
298            let idx = (r as usize) % pool;
299            if !indices.contains(&idx) {
300                indices.push(idx);
301            }
302        }
303
304        indices.iter().map(|&i| matches[i].clone()).collect()
305    }
306
307    /// Estimate homography from 4 or more point correspondences using DLT.
308    ///
309    /// Uses the normal equations `A^T A` (a 9×9 symmetric positive semi-definite
310    /// matrix) and extracts the eigenvector corresponding to the smallest
311    /// eigenvalue via SVD.  This is numerically equivalent to the direct SVD of A
312    /// but avoids dimension edge-cases with nalgebra's thin-SVD for m < n matrices.
313    #[allow(clippy::similar_names)]
314    fn estimate_from_4_points(&self, matches: &[MatchPair]) -> AlignResult<Homography> {
315        if matches.len() < 4 {
316            return Err(AlignError::InvalidConfig(
317                "Need at least 4 points for DLT".to_string(),
318            ));
319        }
320
321        // Accumulate A^T A (9×9) directly from the DLT rows.
322        // Each correspondence contributes two rows r1, r2 to A;
323        // we add r1*r1^T + r2*r2^T to the accumulator.
324        let mut ata = nalgebra::Matrix::<f64, nalgebra::U9, nalgebra::U9, _>::zeros();
325
326        for m in matches {
327            let x1 = m.point1.x;
328            let y1 = m.point1.y;
329            let x2 = m.point2.x;
330            let y2 = m.point2.y;
331
332            // Row 1: [-x1, -y1, -1,  0,   0,  0, x2*x1, x2*y1, x2]
333            let r1 = nalgebra::Vector::<f64, nalgebra::U9, _>::from_row_slice(&[
334                -x1,
335                -y1,
336                -1.0,
337                0.0,
338                0.0,
339                0.0,
340                x2 * x1,
341                x2 * y1,
342                x2,
343            ]);
344            // Row 2: [0, 0, 0, -x1, -y1, -1, y2*x1, y2*y1, y2]
345            let r2 = nalgebra::Vector::<f64, nalgebra::U9, _>::from_row_slice(&[
346                0.0,
347                0.0,
348                0.0,
349                -x1,
350                -y1,
351                -1.0,
352                y2 * x1,
353                y2 * y1,
354                y2,
355            ]);
356
357            ata += r1 * r1.transpose() + r2 * r2.transpose();
358        }
359
360        // SVD of the 9×9 symmetric matrix: V^T has shape 9×9 (always square).
361        let svd = ata.svd(false, true);
362        let v = svd
363            .v_t
364            .ok_or_else(|| AlignError::NumericalError("SVD failed to compute V".to_string()))?;
365
366        // The solution is the last row of V^T (smallest eigenvalue of A^T A =
367        // smallest squared singular value of A = null-space direction).
368        let h_vec = v.row(8); // safe: v is always 9×9
369
370        if h_vec[8].abs() < 1e-10 {
371            return Err(AlignError::NumericalError(
372                "Degenerate solution (h[8] ≈ 0)".to_string(),
373            ));
374        }
375
376        // Normalize so that h[8] = 1, then reshape into 3×3.
377        let scale = h_vec[8];
378        let matrix = Matrix3::new(
379            h_vec[0] / scale,
380            h_vec[1] / scale,
381            h_vec[2] / scale,
382            h_vec[3] / scale,
383            h_vec[4] / scale,
384            h_vec[5] / scale,
385            h_vec[6] / scale,
386            h_vec[7] / scale,
387            1.0,
388        );
389
390        Ok(Homography::new(matrix))
391    }
392
393    /// Find inliers based on reprojection error
394    fn find_inliers(&self, homography: &Homography, matches: &[MatchPair]) -> Vec<bool> {
395        matches
396            .iter()
397            .map(|m| {
398                let transformed = homography.transform(&m.point1);
399                let error = transformed.distance(&m.point2);
400                error < self.config.threshold
401            })
402            .collect()
403    }
404
405    /// Refine homography using all inliers (overdetermined DLT via normal equations).
406    fn refine_homography(
407        &self,
408        _initial: &Homography,
409        inliers: &[&MatchPair],
410    ) -> AlignResult<Homography> {
411        if inliers.len() < 4 {
412            return Err(AlignError::InsufficientData(
413                "Need at least 4 inliers for refinement".to_string(),
414            ));
415        }
416
417        // Re-use the same A^T A accumulation as in estimate_from_4_points so
418        // we always work with a 9×9 symmetric matrix whose SVD is well-defined.
419        let matches_owned: Vec<MatchPair> = inliers.iter().map(|m| (*m).clone()).collect();
420        self.estimate_from_4_points(&matches_owned)
421    }
422}
423
424/// Affine transformation (6 DOF)
425#[derive(Debug, Clone)]
426pub struct AffineTransform {
427    /// 2x3 affine matrix [a b tx; c d ty]
428    pub matrix: nalgebra::Matrix2x3<f64>,
429}
430
431impl AffineTransform {
432    /// Create a new affine transform
433    #[must_use]
434    pub fn new(matrix: nalgebra::Matrix2x3<f64>) -> Self {
435        Self { matrix }
436    }
437
438    /// Create identity transform
439    #[must_use]
440    pub fn identity() -> Self {
441        Self {
442            matrix: nalgebra::Matrix2x3::new(1.0, 0.0, 0.0, 0.0, 1.0, 0.0),
443        }
444    }
445
446    /// Create translation
447    #[must_use]
448    pub fn translation(tx: f64, ty: f64) -> Self {
449        Self {
450            matrix: nalgebra::Matrix2x3::new(1.0, 0.0, tx, 0.0, 1.0, ty),
451        }
452    }
453
454    /// Create rotation
455    #[must_use]
456    pub fn rotation(angle: f64) -> Self {
457        let c = angle.cos();
458        let s = angle.sin();
459        Self {
460            matrix: nalgebra::Matrix2x3::new(c, -s, 0.0, s, c, 0.0),
461        }
462    }
463
464    /// Create scale
465    #[must_use]
466    pub fn scale(sx: f64, sy: f64) -> Self {
467        Self {
468            matrix: nalgebra::Matrix2x3::new(sx, 0.0, 0.0, 0.0, sy, 0.0),
469        }
470    }
471
472    /// Transform a point
473    #[must_use]
474    pub fn transform(&self, point: &Point2D) -> Point2D {
475        let p = nalgebra::Vector3::new(point.x, point.y, 1.0);
476        let transformed = self.matrix * p;
477        Point2D::new(transformed[0], transformed[1])
478    }
479}
480
481/// Affine estimator
482pub struct AffineEstimator {
483    /// RANSAC configuration
484    pub config: RansacConfig,
485}
486
487impl AffineEstimator {
488    /// Create a new affine estimator
489    #[must_use]
490    pub fn new(config: RansacConfig) -> Self {
491        Self { config }
492    }
493
494    /// Estimate affine transform from matched points
495    ///
496    /// # Errors
497    /// Returns error if insufficient matches or estimation fails
498    pub fn estimate(&self, matches: &[MatchPair]) -> AlignResult<(AffineTransform, Vec<bool>)> {
499        if matches.len() < 3 {
500            return Err(AlignError::InsufficientData(
501                "Need at least 3 matches for affine".to_string(),
502            ));
503        }
504
505        let mut best_inliers = Vec::new();
506        let mut best_transform = None;
507        let mut best_inlier_count = 0;
508
509        // RANSAC iterations
510        for _ in 0..self.config.max_iterations {
511            // Sample 3 random matches
512            let sample = self.sample_matches(matches, 3);
513
514            // Estimate affine from 3 points
515            if let Ok(t) = self.estimate_from_3_points(&sample) {
516                // Count inliers
517                let inliers = self.find_inliers(&t, matches);
518                let inlier_count = inliers.iter().filter(|&&x| x).count();
519
520                if inlier_count > best_inlier_count {
521                    best_inlier_count = inlier_count;
522                    best_inliers = inliers;
523                    best_transform = Some(t);
524
525                    if inlier_count >= self.config.min_inliers {
526                        break;
527                    }
528                }
529            }
530        }
531
532        if best_inlier_count < self.config.min_inliers {
533            return Err(AlignError::NoSolution("Insufficient inliers".to_string()));
534        }
535
536        let transform = best_transform
537            .ok_or_else(|| AlignError::NoSolution("No valid transform found".to_string()))?;
538
539        Ok((transform, best_inliers))
540    }
541
542    /// Sample matches
543    fn sample_matches(&self, matches: &[MatchPair], n: usize) -> Vec<MatchPair> {
544        let step = matches.len() / n;
545        (0..n)
546            .map(|i| matches[(i * step) % matches.len()].clone())
547            .collect()
548    }
549
550    /// Estimate affine from 3 points
551    fn estimate_from_3_points(&self, matches: &[MatchPair]) -> AlignResult<AffineTransform> {
552        if matches.len() != 3 {
553            return Err(AlignError::InvalidConfig(
554                "Need exactly 3 points".to_string(),
555            ));
556        }
557
558        // Build linear system: [x1 y1 1 0  0  0] [a]   [x1']
559        //                       [0  0  0 x1 y1 1] [b]   [y1']
560        //                       [x2 y2 1 0  0  0] [tx] = [x2']
561        //                       [0  0  0 x2 y2 1] [c]   [y2']
562        //                       [x3 y3 1 0  0  0] [d]   [x3']
563        //                       [0  0  0 x3 y3 1] [ty]  [y3']
564
565        let mut a = nalgebra::DMatrix::zeros(6, 6);
566        let mut b_vec = nalgebra::DVector::zeros(6);
567
568        for (i, m) in matches.iter().enumerate() {
569            let x = m.point1.x;
570            let y = m.point1.y;
571            let x_prime = m.point2.x;
572            let y_prime = m.point2.y;
573
574            a[(i * 2, 0)] = x;
575            a[(i * 2, 1)] = y;
576            a[(i * 2, 2)] = 1.0;
577            b_vec[i * 2] = x_prime;
578
579            a[(i * 2 + 1, 3)] = x;
580            a[(i * 2 + 1, 4)] = y;
581            a[(i * 2 + 1, 5)] = 1.0;
582            b_vec[i * 2 + 1] = y_prime;
583        }
584
585        let decomp = a.lu();
586        let solution = decomp.solve(&b_vec).ok_or_else(|| {
587            AlignError::NumericalError("Failed to solve linear system".to_string())
588        })?;
589
590        let matrix = nalgebra::Matrix2x3::new(
591            solution[0],
592            solution[1],
593            solution[2],
594            solution[3],
595            solution[4],
596            solution[5],
597        );
598
599        Ok(AffineTransform::new(matrix))
600    }
601
602    /// Find inliers
603    fn find_inliers(&self, transform: &AffineTransform, matches: &[MatchPair]) -> Vec<bool> {
604        matches
605            .iter()
606            .map(|m| {
607                let transformed = transform.transform(&m.point1);
608                let error = transformed.distance(&m.point2);
609                error < self.config.threshold
610            })
611            .collect()
612    }
613}
614
615/// Perspective correction
616pub struct PerspectiveCorrector {
617    /// Target width
618    pub target_width: usize,
619    /// Target height
620    pub target_height: usize,
621}
622
623impl PerspectiveCorrector {
624    /// Create a new perspective corrector
625    #[must_use]
626    pub fn new(target_width: usize, target_height: usize) -> Self {
627        Self {
628            target_width,
629            target_height,
630        }
631    }
632
633    /// Compute homography to correct perspective distortion
634    ///
635    /// # Errors
636    /// Returns error if corners are invalid
637    pub fn compute_correction(&self, corners: &[Point2D; 4]) -> AlignResult<Homography> {
638        // Target corners (rectangle)
639        let target = [
640            Point2D::new(0.0, 0.0),
641            Point2D::new(self.target_width as f64, 0.0),
642            Point2D::new(self.target_width as f64, self.target_height as f64),
643            Point2D::new(0.0, self.target_height as f64),
644        ];
645
646        // Create match pairs
647        let matches: Vec<MatchPair> = corners
648            .iter()
649            .zip(&target)
650            .enumerate()
651            .map(|(i, (src, dst))| MatchPair::new(i, i, 0, *src, *dst))
652            .collect();
653
654        // Estimate homography
655        let estimator = HomographyEstimator::new(RansacConfig::default());
656        estimator.estimate_from_4_points(&matches)
657    }
658}
659
660/// Similarity transform (4 DOF: translation, rotation, uniform scale)
661#[derive(Debug, Clone)]
662pub struct SimilarityTransform {
663    /// Scale factor
664    pub scale: f64,
665    /// Rotation angle (radians)
666    pub rotation: f64,
667    /// Translation X
668    pub tx: f64,
669    /// Translation Y
670    pub ty: f64,
671}
672
673impl SimilarityTransform {
674    /// Create a new similarity transform
675    #[must_use]
676    pub fn new(scale: f64, rotation: f64, tx: f64, ty: f64) -> Self {
677        Self {
678            scale,
679            rotation,
680            tx,
681            ty,
682        }
683    }
684
685    /// Create identity transform
686    #[must_use]
687    pub fn identity() -> Self {
688        Self {
689            scale: 1.0,
690            rotation: 0.0,
691            tx: 0.0,
692            ty: 0.0,
693        }
694    }
695
696    /// Transform a point
697    #[must_use]
698    pub fn transform(&self, point: &Point2D) -> Point2D {
699        let c = self.rotation.cos();
700        let s = self.rotation.sin();
701
702        let x = self.scale * (c * point.x - s * point.y) + self.tx;
703        let y = self.scale * (s * point.x + c * point.y) + self.ty;
704
705        Point2D::new(x, y)
706    }
707
708    /// Convert to affine transform
709    #[must_use]
710    pub fn to_affine(&self) -> AffineTransform {
711        let c = self.rotation.cos();
712        let s = self.rotation.sin();
713        let sc = self.scale * c;
714        let ss = self.scale * s;
715
716        let matrix = nalgebra::Matrix2x3::new(sc, -ss, self.tx, ss, sc, self.ty);
717
718        AffineTransform::new(matrix)
719    }
720}
721
722/// Similarity transform estimator
723pub struct SimilarityEstimator {
724    /// RANSAC configuration
725    pub config: RansacConfig,
726}
727
728impl SimilarityEstimator {
729    /// Create a new similarity estimator
730    #[must_use]
731    pub fn new(config: RansacConfig) -> Self {
732        Self { config }
733    }
734
735    /// Estimate similarity transform from matched points
736    ///
737    /// # Errors
738    /// Returns error if estimation fails
739    pub fn estimate(&self, matches: &[MatchPair]) -> AlignResult<(SimilarityTransform, Vec<bool>)> {
740        if matches.len() < 2 {
741            return Err(AlignError::InsufficientData(
742                "Need at least 2 matches for similarity".to_string(),
743            ));
744        }
745
746        let mut best_inliers = Vec::new();
747        let mut best_transform = None;
748        let mut best_inlier_count = 0;
749
750        // RANSAC iterations
751        for _ in 0..self.config.max_iterations {
752            // Sample 2 random matches
753            let sample = self.sample_matches(matches, 2);
754
755            // Estimate similarity from 2 points
756            if let Ok(t) = self.estimate_from_2_points(&sample) {
757                // Count inliers
758                let inliers = self.find_inliers(&t, matches);
759                let inlier_count = inliers.iter().filter(|&&x| x).count();
760
761                if inlier_count > best_inlier_count {
762                    best_inlier_count = inlier_count;
763                    best_inliers = inliers;
764                    best_transform = Some(t);
765
766                    if inlier_count >= self.config.min_inliers {
767                        break;
768                    }
769                }
770            }
771        }
772
773        if best_inlier_count < self.config.min_inliers {
774            return Err(AlignError::NoSolution("Insufficient inliers".to_string()));
775        }
776
777        let transform = best_transform
778            .ok_or_else(|| AlignError::NoSolution("No valid transform found".to_string()))?;
779
780        Ok((transform, best_inliers))
781    }
782
783    /// Sample matches
784    fn sample_matches(&self, matches: &[MatchPair], n: usize) -> Vec<MatchPair> {
785        let step = matches.len() / n;
786        (0..n)
787            .map(|i| matches[(i * step) % matches.len()].clone())
788            .collect()
789    }
790
791    /// Estimate similarity from 2 points
792    fn estimate_from_2_points(&self, matches: &[MatchPair]) -> AlignResult<SimilarityTransform> {
793        if matches.len() != 2 {
794            return Err(AlignError::InvalidConfig(
795                "Need exactly 2 points".to_string(),
796            ));
797        }
798
799        let p1 = &matches[0].point1;
800        let p2 = &matches[1].point1;
801        let q1 = &matches[0].point2;
802        let q2 = &matches[1].point2;
803
804        // Compute centroid
805        let pc = Point2D::new((p1.x + p2.x) / 2.0, (p1.y + p2.y) / 2.0);
806        let qc = Point2D::new((q1.x + q2.x) / 2.0, (q1.y + q2.y) / 2.0);
807
808        // Center points
809        let p1c = Point2D::new(p1.x - pc.x, p1.y - pc.y);
810        let p2c = Point2D::new(p2.x - pc.x, p2.y - pc.y);
811        let q1c = Point2D::new(q1.x - qc.x, q1.y - qc.y);
812        let q2c = Point2D::new(q2.x - qc.x, q2.y - qc.y);
813
814        // Compute scale
815        let dist_p = (p1c.distance_squared(&p2c)).sqrt();
816        let dist_q = (q1c.distance_squared(&q2c)).sqrt();
817
818        if dist_p < 1e-10 {
819            return Err(AlignError::NumericalError("Degenerate points".to_string()));
820        }
821
822        let scale = dist_q / dist_p;
823
824        // Compute rotation
825        let angle_p = (p2c.y - p1c.y).atan2(p2c.x - p1c.x);
826        let angle_q = (q2c.y - q1c.y).atan2(q2c.x - q1c.x);
827        let rotation = angle_q - angle_p;
828
829        // Compute translation
830        let c = rotation.cos();
831        let s = rotation.sin();
832        let tx = qc.x - scale * (c * pc.x - s * pc.y);
833        let ty = qc.y - scale * (s * pc.x + c * pc.y);
834
835        Ok(SimilarityTransform::new(scale, rotation, tx, ty))
836    }
837
838    /// Find inliers
839    fn find_inliers(&self, transform: &SimilarityTransform, matches: &[MatchPair]) -> Vec<bool> {
840        matches
841            .iter()
842            .map(|m| {
843                let transformed = transform.transform(&m.point1);
844                let error = transformed.distance(&m.point2);
845                error < self.config.threshold
846            })
847            .collect()
848    }
849}
850
/// Homography refiner based on weighted least squares.
///
/// Intended to run after RANSAC has selected the inliers: each correspondence
/// is weighted by its reprojection error under the current model, so that
/// points closer to the model contribute more and near-outlier noise has less
/// influence on the estimate.
///
/// Weights come from a Cauchy (Lorentzian) kernel:
///
/// ```text
/// w(e) = 1 / (1 + (e / sigma)^2)
/// ```
///
/// Reweighting and refitting are repeated a fixed number of times (IRLS,
/// Iteratively Reweighted Least Squares) to approach a robust M-estimate.
pub struct WeightedHomographyRefiner {
    /// Scale parameter `sigma` of the Cauchy kernel.
    pub sigma: f64,
    /// Number of IRLS iterations to run.
    pub iterations: usize,
}

impl Default for WeightedHomographyRefiner {
    /// Defaults: `sigma = 3.0`, `5` iterations.
    fn default() -> Self {
        let sigma = 3.0;
        let iterations = 5;

        WeightedHomographyRefiner { sigma, iterations }
    }
}
881
882impl WeightedHomographyRefiner {
883    /// Create a new weighted homography refiner.
884    #[must_use]
885    pub fn new(sigma: f64, iterations: usize) -> Self {
886        Self { sigma, iterations }
887    }
888
889    /// Refine a homography using iteratively reweighted least squares.
890    ///
891    /// `initial` is the RANSAC-estimated homography.
892    /// `matches` is the full set of inlier correspondences.
893    ///
894    /// # Errors
895    ///
896    /// Returns an error if there are fewer than 4 matches or the system is
897    /// degenerate.
898    pub fn refine(&self, initial: &Homography, matches: &[MatchPair]) -> AlignResult<Homography> {
899        if matches.len() < 4 {
900            return Err(AlignError::InsufficientData(
901                "Need at least 4 matches for WLS refinement".to_string(),
902            ));
903        }
904
905        let mut current = initial.clone();
906
907        for _iter in 0..self.iterations {
908            // Compute weights using Cauchy kernel based on reprojection error
909            let weights: Vec<f64> = matches
910                .iter()
911                .map(|m| {
912                    let projected = current.transform(&m.point1);
913                    let err = projected.distance(&m.point2);
914                    1.0 / (1.0 + (err / self.sigma).powi(2))
915                })
916                .collect();
917
918            // Solve weighted DLT
919            current = self.weighted_dlt(matches, &weights)?;
920        }
921
922        Ok(current)
923    }
924
925    /// Weighted Direct Linear Transform.
926    fn weighted_dlt(&self, matches: &[MatchPair], weights: &[f64]) -> AlignResult<Homography> {
927        let _n = matches.len();
928
929        // Build weighted AᵀWA (9x9) where W = diag(weights)
930        // Each match contributes two rows to A
931        let mut ata = [[0.0_f64; 9]; 9];
932
933        for (idx, m) in matches.iter().enumerate() {
934            let w = weights.get(idx).copied().unwrap_or(1.0);
935            let x1 = m.point1.x;
936            let y1 = m.point1.y;
937            let x2 = m.point2.x;
938            let y2 = m.point2.y;
939
940            let r1 = [-x1, -y1, -1.0, 0.0, 0.0, 0.0, x2 * x1, x2 * y1, x2];
941            let r2 = [0.0, 0.0, 0.0, -x1, -y1, -1.0, y2 * x1, y2 * y1, y2];
942
943            for i in 0..9 {
944                for j in 0..9 {
945                    ata[i][j] += w * (r1[i] * r1[j] + r2[i] * r2[j]);
946                }
947            }
948        }
949
950        // Find the eigenvector of AᵀWA with the smallest eigenvalue
951        // using inverse iteration
952        let h_vec = self.smallest_eigenvector(&ata)?;
953
954        if h_vec[8].abs() < 1e-12 {
955            return Err(AlignError::NumericalError(
956                "Degenerate WLS homography".to_string(),
957            ));
958        }
959
960        let scale = h_vec[8];
961        let matrix = Matrix3::new(
962            h_vec[0] / scale,
963            h_vec[1] / scale,
964            h_vec[2] / scale,
965            h_vec[3] / scale,
966            h_vec[4] / scale,
967            h_vec[5] / scale,
968            h_vec[6] / scale,
969            h_vec[7] / scale,
970            1.0,
971        );
972
973        Ok(Homography::new(matrix))
974    }
975
976    /// Find smallest eigenvector of a 9x9 symmetric matrix using
977    /// inverse iteration with a small shift.
978    fn smallest_eigenvector(&self, ata: &[[f64; 9]; 9]) -> AlignResult<[f64; 9]> {
979        let shift = 1e-8;
980        let mut shifted = *ata;
981        for i in 0..9 {
982            shifted[i][i] += shift;
983        }
984
985        let mut v = [1.0_f64 / 3.0; 9];
986
987        for _ in 0..50 {
988            let w = self.solve_9x9(&shifted, &v)?;
989
990            let norm: f64 = w.iter().map(|x| x * x).sum::<f64>().sqrt();
991            if norm < 1e-15 {
992                return Err(AlignError::NumericalError(
993                    "Eigenvector iteration diverged".to_string(),
994                ));
995            }
996            v = [0.0; 9];
997            for i in 0..9 {
998                v[i] = w[i] / norm;
999            }
1000        }
1001
1002        Ok(v)
1003    }
1004
1005    /// Solve a 9x9 system using Gaussian elimination.
1006    fn solve_9x9(&self, a: &[[f64; 9]; 9], b: &[f64; 9]) -> AlignResult<[f64; 9]> {
1007        // Flatten to work array
1008        let mut mat = [[0.0_f64; 10]; 9];
1009        for i in 0..9 {
1010            for j in 0..9 {
1011                mat[i][j] = a[i][j];
1012            }
1013            mat[i][9] = b[i];
1014        }
1015
1016        // Forward elimination with partial pivoting
1017        for col in 0..9 {
1018            let mut max_row = col;
1019            let mut max_val = mat[col][col].abs();
1020            for row in (col + 1)..9 {
1021                let val = mat[row][col].abs();
1022                if val > max_val {
1023                    max_val = val;
1024                    max_row = row;
1025                }
1026            }
1027
1028            if max_val < 1e-14 {
1029                return Err(AlignError::NumericalError(
1030                    "Singular matrix in WLS 9x9 solve".to_string(),
1031                ));
1032            }
1033
1034            mat.swap(col, max_row);
1035
1036            let pivot = mat[col][col];
1037            for row in (col + 1)..9 {
1038                let factor = mat[row][col] / pivot;
1039                for j in col..10 {
1040                    mat[row][j] -= factor * mat[col][j];
1041                }
1042            }
1043        }
1044
1045        // Back substitution
1046        let mut x = [0.0_f64; 9];
1047        for col in (0..9).rev() {
1048            let mut sum = mat[col][9];
1049            for j in (col + 1)..9 {
1050                sum -= mat[col][j] * x[j];
1051            }
1052            x[col] = sum / mat[col][col];
1053        }
1054
1055        Ok(x)
1056    }
1057}
1058
1059/// Fundamental matrix for epipolar geometry
1060#[derive(Debug, Clone)]
1061pub struct FundamentalMatrix {
1062    /// The 3x3 fundamental matrix
1063    pub matrix: Matrix3<f64>,
1064}
1065
1066impl FundamentalMatrix {
1067    /// Create a new fundamental matrix
1068    #[must_use]
1069    pub fn new(matrix: Matrix3<f64>) -> Self {
1070        Self { matrix }
1071    }
1072
1073    /// Compute epipolar line in second image for a point in first image
1074    #[must_use]
1075    pub fn compute_epipolar_line(&self, point: &Point2D) -> (f64, f64, f64) {
1076        let p = Vector3::new(point.x, point.y, 1.0);
1077        let line = self.matrix * p;
1078        (line[0], line[1], line[2])
1079    }
1080
1081    /// Compute distance from point to epipolar line
1082    #[must_use]
1083    pub fn epipolar_distance(&self, point1: &Point2D, point2: &Point2D) -> f64 {
1084        let (a, b, c) = self.compute_epipolar_line(point1);
1085        let denominator = (a * a + b * b).sqrt();
1086
1087        if denominator < 1e-10 {
1088            return f64::INFINITY;
1089        }
1090
1091        (a * point2.x + b * point2.y + c).abs() / denominator
1092    }
1093}
1094
1095/// Essential matrix for calibrated camera geometry
1096#[derive(Debug, Clone)]
1097pub struct EssentialMatrix {
1098    /// The 3x3 essential matrix
1099    pub matrix: Matrix3<f64>,
1100}
1101
1102impl EssentialMatrix {
1103    /// Create a new essential matrix
1104    #[must_use]
1105    pub fn new(matrix: Matrix3<f64>) -> Self {
1106        Self { matrix }
1107    }
1108
1109    /// Decompose into rotation and translation (up to scale)
1110    #[must_use]
1111    pub fn decompose(&self) -> Vec<(Matrix3<f64>, Vector3<f64>)> {
1112        // Simplified decomposition (in production, use proper SVD)
1113        // Returns 4 possible solutions
1114        vec![]
1115    }
1116}
1117
1118/// Homography decomposition for plane-based structure
1119pub struct HomographyDecomposer;
1120
1121impl HomographyDecomposer {
1122    /// Decompose homography into rotation, translation, and normal
1123    ///
1124    /// # Errors
1125    /// Returns error if decomposition fails
1126    #[allow(dead_code)]
1127    pub fn decompose(
1128        _homography: &Homography,
1129        _k1: &Matrix3<f64>,
1130        _k2: &Matrix3<f64>,
1131    ) -> AlignResult<Vec<(Matrix3<f64>, Vector3<f64>, Vector3<f64>)>> {
1132        // Simplified placeholder (in production, implement Faugeras-Lustman decomposition)
1133        Ok(vec![])
1134    }
1135}
1136
1137/// Planar rectification for document scanning
1138pub struct PlanarRectifier {
1139    /// Target aspect ratio
1140    pub aspect_ratio: f64,
1141}
1142
1143impl PlanarRectifier {
1144    /// Create a new planar rectifier
1145    #[must_use]
1146    pub fn new(aspect_ratio: f64) -> Self {
1147        Self { aspect_ratio }
1148    }
1149
1150    /// Rectify a planar surface
1151    ///
1152    /// # Errors
1153    /// Returns error if rectification fails
1154    pub fn rectify(&self, corners: &[Point2D; 4], output_width: usize) -> AlignResult<Homography> {
1155        let output_height = (output_width as f64 / self.aspect_ratio) as usize;
1156
1157        let target = [
1158            Point2D::new(0.0, 0.0),
1159            Point2D::new(output_width as f64, 0.0),
1160            Point2D::new(output_width as f64, output_height as f64),
1161            Point2D::new(0.0, output_height as f64),
1162        ];
1163
1164        // Create match pairs
1165        let matches: Vec<MatchPair> = corners
1166            .iter()
1167            .zip(&target)
1168            .enumerate()
1169            .map(|(i, (src, dst))| MatchPair::new(i, i, 0, *src, *dst))
1170            .collect();
1171
1172        let estimator = HomographyEstimator::new(RansacConfig::default());
1173        estimator.estimate_from_4_points(&matches)
1174    }
1175}
1176
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_homography_identity() {
        let input = Point2D::new(10.0, 20.0);
        let output = Homography::identity().transform(&input);
        assert!((output.x - 10.0).abs() < 1e-10);
        assert!((output.y - 20.0).abs() < 1e-10);
    }

    #[test]
    fn test_affine_identity() {
        let input = Point2D::new(10.0, 20.0);
        let output = AffineTransform::identity().transform(&input);
        assert!((output.x - 10.0).abs() < 1e-10);
        assert!((output.y - 20.0).abs() < 1e-10);
    }

    #[test]
    fn test_affine_translation() {
        let input = Point2D::new(10.0, 20.0);
        let output = AffineTransform::translation(5.0, 10.0).transform(&input);
        assert!((output.x - 15.0).abs() < 1e-10);
        assert!((output.y - 30.0).abs() < 1e-10);
    }

    #[test]
    fn test_affine_scale() {
        let input = Point2D::new(10.0, 20.0);
        let output = AffineTransform::scale(2.0, 3.0).transform(&input);
        assert!((output.x - 20.0).abs() < 1e-10);
        assert!((output.y - 60.0).abs() < 1e-10);
    }

    #[test]
    fn test_ransac_config() {
        let defaults = RansacConfig::default();
        assert_eq!(defaults.threshold, 3.0);
        assert_eq!(defaults.max_iterations, 1000);
        assert_eq!(defaults.min_inliers, 8);
    }

    #[test]
    fn test_similarity_identity() {
        let input = Point2D::new(10.0, 20.0);
        let output = SimilarityTransform::identity().transform(&input);
        assert!((output.x - 10.0).abs() < 1e-10);
        assert!((output.y - 20.0).abs() < 1e-10);
    }

    #[test]
    fn test_similarity_scale() {
        let input = Point2D::new(10.0, 20.0);
        let output = SimilarityTransform::new(2.0, 0.0, 0.0, 0.0).transform(&input);
        assert!((output.x - 20.0).abs() < 1e-10);
        assert!((output.y - 40.0).abs() < 1e-10);
    }

    #[test]
    fn test_similarity_to_affine() {
        // The affine form of a similarity must map points identically.
        let similarity = SimilarityTransform::new(2.0, std::f64::consts::PI / 2.0, 10.0, 20.0);
        let as_affine = similarity.to_affine();

        let probe = Point2D::new(1.0, 0.0);
        let via_similarity = similarity.transform(&probe);
        let via_affine = as_affine.transform(&probe);

        assert!((via_similarity.x - via_affine.x).abs() < 1e-10);
        assert!((via_similarity.y - via_affine.y).abs() < 1e-10);
    }

    #[test]
    fn test_fundamental_matrix() {
        let fundamental = FundamentalMatrix::new(Matrix3::identity());
        let (a, b, c) = fundamental.compute_epipolar_line(&Point2D::new(10.0, 20.0));
        assert!(a.is_finite() && b.is_finite() && c.is_finite());
    }

    #[test]
    fn test_planar_rectifier() {
        // Only checks that the constructor stores the ratio.
        let rectifier = PlanarRectifier::new(1.5);
        assert_eq!(rectifier.aspect_ratio, 1.5);
    }

    // ── WeightedHomographyRefiner ────────────────────────────────────────────

    #[test]
    fn test_weighted_refiner_default() {
        let refiner = WeightedHomographyRefiner::default();
        assert_eq!(refiner.sigma, 3.0);
        assert_eq!(refiner.iterations, 5);
    }

    #[test]
    fn test_weighted_refiner_identity() {
        // Correspondences in which every point maps onto itself.
        let matches: Vec<MatchPair> = (0..20)
            .map(|i| {
                let x = (i as f64 * 17.0) % 100.0 + 10.0;
                let y = (i as f64 * 31.0) % 100.0 + 10.0;
                let pt = Point2D::new(x, y);
                MatchPair::new(i, i, 0, pt, Point2D::new(x, y))
            })
            .collect();

        let refined = WeightedHomographyRefiner::new(3.0, 5)
            .refine(&Homography::identity(), &matches)
            .expect("should succeed");

        // The refined homography must still behave like the identity.
        let transformed = refined.transform(&Point2D::new(50.0, 50.0));
        assert!((transformed.x - 50.0).abs() < 0.5, "x={}", transformed.x);
        assert!((transformed.y - 50.0).abs() < 0.5, "y={}", transformed.y);
    }

    #[test]
    fn test_weighted_refiner_with_translation() {
        // Correspondences follow a pure translation (dx=10, dy=-5).
        let matches: Vec<MatchPair> = (0..20)
            .map(|i| {
                let x = (i as f64 * 13.0) % 80.0 + 20.0;
                let y = (i as f64 * 29.0) % 80.0 + 20.0;
                MatchPair::new(i, i, 0, Point2D::new(x, y), Point2D::new(x + 10.0, y - 5.0))
            })
            .collect();

        // The starting estimate is off by one pixel on each axis.
        let initial = Homography::new(Matrix3::new(1.0, 0.0, 9.0, 0.0, 1.0, -4.0, 0.0, 0.0, 1.0));

        let refined = WeightedHomographyRefiner::new(3.0, 10)
            .refine(&initial, &matches)
            .expect("should succeed");

        let transformed = refined.transform(&Point2D::new(50.0, 50.0));
        assert!(
            (transformed.x - 60.0).abs() < 1.0,
            "expected ~60, got {}",
            transformed.x
        );
        assert!(
            (transformed.y - 45.0).abs() < 1.0,
            "expected ~45, got {}",
            transformed.y
        );
    }

    #[test]
    fn test_weighted_refiner_with_outliers() {
        // 30 clean matches following a (dx=5, dy=3) translation plus three
        // moderate outliers.  The Cauchy reweighting should keep the outliers
        // from dragging the estimate far from the true translation.
        let mut matches: Vec<MatchPair> = (0..30)
            .map(|i| {
                let x = (i as f64 * 17.0) % 100.0 + 10.0;
                let y = (i as f64 * 31.0) % 100.0 + 10.0;
                MatchPair::new(i, i, 0, Point2D::new(x, y), Point2D::new(x + 5.0, y + 3.0))
            })
            .collect();

        matches.extend((0..3).map(|i| {
            MatchPair::new(
                30 + i,
                30 + i,
                100,
                Point2D::new(50.0 + i as f64 * 10.0, 50.0),
                Point2D::new(80.0 + i as f64 * 20.0, 80.0),
            )
        }));

        // Start from the exact translation.
        let initial = Homography::new(Matrix3::new(1.0, 0.0, 5.0, 0.0, 1.0, 3.0, 0.0, 0.0, 1.0));

        let refined = WeightedHomographyRefiner::new(3.0, 20)
            .refine(&initial, &matches)
            .expect("should succeed");

        // Deliberately loose: only the neighbourhood of the answer is checked.
        let transformed = refined.transform(&Point2D::new(50.0, 50.0));
        assert!(
            (transformed.x - 55.0).abs() < 15.0,
            "expected ~55, got {}",
            transformed.x
        );
        assert!(
            (transformed.y - 53.0).abs() < 15.0,
            "expected ~53, got {}",
            transformed.y
        );
    }

    #[test]
    fn test_weighted_refiner_insufficient_matches() {
        // A single correspondence is below the four-match minimum.
        let lone_match = MatchPair::new(0, 0, 0, Point2D::new(0.0, 0.0), Point2D::new(1.0, 1.0));
        let matches = vec![lone_match];

        let outcome = WeightedHomographyRefiner::default().refine(&Homography::identity(), &matches);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_weighted_refiner_improves_accuracy() {
        // Known translation with a small deterministic perturbation per point.
        let true_tx = 7.5;
        let true_ty = -3.2;
        let matches: Vec<MatchPair> = (0..30)
            .map(|i| {
                let x = (i as f64 * 11.0) % 90.0 + 10.0;
                let y = (i as f64 * 23.0) % 90.0 + 10.0;
                let noise_x = ((i as f64 * 0.7).sin()) * 0.5;
                let noise_y = ((i as f64 * 1.3).cos()) * 0.5;
                MatchPair::new(
                    i,
                    i,
                    0,
                    Point2D::new(x, y),
                    Point2D::new(x + true_tx + noise_x, y + true_ty + noise_y),
                )
            })
            .collect();

        // Imperfect starting estimate.
        let initial = Homography::new(Matrix3::new(1.0, 0.0, 7.0, 0.0, 1.0, -3.0, 0.0, 0.0, 1.0));

        let refined = WeightedHomographyRefiner::new(2.0, 10)
            .refine(&initial, &matches)
            .expect("should succeed");

        // Mean reprojection error of a model over all matches.
        let mean_error = |model: &Homography| -> f64 {
            matches
                .iter()
                .map(|m| model.transform(&m.point1).distance(&m.point2))
                .sum::<f64>()
                / matches.len() as f64
        };
        let err_before = mean_error(&initial);
        let err_after = mean_error(&refined);

        assert!(
            err_after <= err_before + 0.1,
            "WLS should improve or maintain: before={err_before:.4}, after={err_after:.4}"
        );
    }

    /// Parallel RANSAC produces an equal-or-better inlier count than serial.
    ///
    /// 40 clean correspondences following a small rotation plus translation
    /// are generated and 10 outliers with pseudo-random coordinates are
    /// appended (20% of the 50 matches).  Both the parallel `estimate` and the
    /// serial `estimate_serial` are run with identical configuration.  The
    /// parallel result must have an inlier count >= the serial result — it is
    /// allowed to find a strictly better result.
    #[test]
    fn test_parallel_ransac_inliers_ge_serial() {
        let n_clean = 40usize;
        let n_outlier = 10usize;

        // Ground-truth: small rotation + translation.
        let angle = 5.0_f64.to_radians();
        let (cos_a, sin_a) = (angle.cos(), angle.sin());
        let tx = 8.0_f64;
        let ty = -4.0_f64;

        // A simple LCG for deterministic "random" outlier positions.
        // NOTE: `>> 33` leaves 31 bits, so dividing by `u32::MAX` yields
        // values in [0, 0.5) and outlier coordinates in [10, 110).  The
        // generator is left as-is so the test data stays unchanged.
        let lcg = |s: &mut u64| -> f64 {
            *s = s
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (*s >> 33) as f64 / (u32::MAX as f64)
        };
        let mut lcg_state = 0xABCD_1234_u64;

        let mut matches: Vec<MatchPair> = (0..n_clean)
            .map(|i| {
                let x = (i as f64 * 17.3) % 200.0 + 10.0;
                let y = (i as f64 * 31.7) % 200.0 + 10.0;
                let xp = cos_a * x - sin_a * y + tx;
                let yp = sin_a * x + cos_a * y + ty;
                MatchPair::new(i, i, 0, Point2D::new(x, y), Point2D::new(xp, yp))
            })
            .collect();

        // Append outliers whose source and destination are both pseudo-random.
        for i in 0..n_outlier {
            let x = lcg(&mut lcg_state) * 200.0 + 10.0;
            let y = lcg(&mut lcg_state) * 200.0 + 10.0;
            let xp = lcg(&mut lcg_state) * 200.0 + 10.0;
            let yp = lcg(&mut lcg_state) * 200.0 + 10.0;
            matches.push(MatchPair::new(
                n_clean + i,
                n_clean + i,
                100,
                Point2D::new(x, y),
                Point2D::new(xp, yp),
            ));
        }

        let estimator = HomographyEstimator::new(RansacConfig {
            threshold: 2.0,
            max_iterations: 500,
            min_inliers: 8,
        });

        let (_, serial_flags) = estimator
            .estimate_serial(&matches)
            .expect("serial RANSAC should find a solution");
        let (_, parallel_flags) = estimator
            .estimate(&matches)
            .expect("parallel RANSAC should find a solution");

        let serial_inliers = serial_flags.iter().filter(|&&b| b).count();
        let parallel_inliers = parallel_flags.iter().filter(|&&b| b).count();

        assert!(
            parallel_inliers >= serial_inliers,
            "parallel inliers ({parallel_inliers}) should be >= serial inliers ({serial_inliers})"
        );
    }

    /// Verify that a known homography can be recovered from projected correspondences.
    ///
    /// We construct a synthetic 3×3 homography with a modest perspective
    /// component, project test points through it, then recover the inverse and
    /// verify that the original points are reproduced within tolerance.
    ///
    /// Additionally we validate that the RANSAC estimator can find a solution
    /// when given six noise-free correspondences (two more than the minimum
    /// of four required for DLT).
    #[test]
    fn test_homography_roundtrip() {
        // --- Ground-truth homography (rotation + slight perspective warp) ---
        //
        //   H = [cos θ  -sin θ  tx ]
        //       [sin θ   cos θ  ty ]
        //       [p1      p2     1  ]
        //
        // with θ = 5°, tx = 8, ty = -4, p1 = 0.0005, p2 = 0.0003
        let angle = 5.0_f64.to_radians();
        let (cos_a, sin_a) = (angle.cos(), angle.sin());

        let h_true_obj = Homography::new(Matrix3::new(
            cos_a, -sin_a, 8.0, sin_a, cos_a, -4.0, 0.0005, 0.0003, 1.0,
        ));

        // --- Part 1: Forward-inverse roundtrip via Homography::inverse() ---
        //
        // Project a 6x5 grid of test points through H, then apply H^{-1} and
        // verify we recover the originals within sub-pixel tolerance.
        let test_points: Vec<Point2D> = (0..5)
            .flat_map(|row| {
                (0..6).map(move |col| {
                    Point2D::new(20.0 + col as f64 * 35.0, 20.0 + row as f64 * 40.0)
                })
            })
            .collect();

        let h_inv = h_true_obj
            .inverse()
            .expect("ground-truth H should be invertible");

        let tolerance = 1e-6_f64; // purely numerical roundtrip; noise-free
        for pt in &test_points {
            let recovered = h_inv.transform(&h_true_obj.transform(pt));
            let err = pt.distance(&recovered);
            assert!(
                err < tolerance,
                "inverse roundtrip error {err:.2e} at ({:.1},{:.1})",
                pt.x,
                pt.y
            );
        }

        // --- Part 2: RANSAC recovery from a set of correspondences ---
        //
        // Use 6 correspondences (12 equations, 9 unknowns — tall matrix) so
        // that nalgebra's SVD produces a 9×9 V^T from which the last row (the
        // null vector) can be read reliably.
        let six_src = [
            Point2D::new(20.0, 20.0),
            Point2D::new(195.0, 20.0),
            Point2D::new(20.0, 180.0),
            Point2D::new(195.0, 180.0),
            Point2D::new(108.0, 20.0), // midpoints for better conditioning
            Point2D::new(20.0, 100.0),
        ];
        let ransac_matches: Vec<MatchPair> = six_src
            .iter()
            .enumerate()
            .map(|(i, &src)| MatchPair::new(i, i, 0, src, h_true_obj.transform(&src)))
            .collect();

        let estimator = HomographyEstimator::new(RansacConfig {
            threshold: 2.0,
            max_iterations: 50,
            min_inliers: 4,
        });

        let (recovered_h, inlier_flags) = estimator
            .estimate(&ransac_matches)
            .expect("RANSAC should find a solution for noise-free correspondences");

        // At least 4 of the 6 correspondences should be classified as inliers.
        let num_inliers = inlier_flags.iter().filter(|&&b| b).count();
        assert!(num_inliers >= 4, "expected ≥4 inliers, got {num_inliers}");

        // Each source point, projected through the recovered H, should land
        // within 2 px of the expected destination.
        let reproj_tol = 2.0_f64;
        for m in &ransac_matches {
            let err = recovered_h.transform(&m.point1).distance(&m.point2);
            assert!(
                err < reproj_tol,
                "reprojection error {err:.4} > {reproj_tol} at ({:.1},{:.1})",
                m.point1.x,
                m.point1.y
            );
        }
    }
}