// oximedia_align/prosac.rs

//! PROSAC (PROgressive SAmple Consensus) for robust model estimation.
//!
//! PROSAC is an improvement over RANSAC that exploits a quality ordering of
//! the input correspondences. Instead of sampling uniformly from all matches,
//! PROSAC starts by drawing from the top-ranked matches and progressively
//! expands the sampling pool. This typically converges much faster than RANSAC
//! when the matches are sorted by descriptor distance or response strength.
//!
//! # Algorithm
//!
//! 1. The caller sorts matches by quality (ascending descriptor distance,
//!    best first); the estimator itself does not reorder them.
//! 2. Maintain a growing subset size `n` (starting from `min_sample_size`,
//!    or from `ProsacConfig::initial_n` when it is set).
//! 3. At each iteration, with high probability draw at least one sample from
//!    the `n`-th match (the "growth" step), and the rest from the top `n-1`.
//! 4. If the model from the current sample has more inliers than the best so
//!    far, update the best model.
//! 5. Enlarge `n` on an iteration-count schedule so that lower-ranked matches
//!    are progressively admitted into the sampling pool; the search may stop
//!    early once enough iterations have run for the configured confidence.
//! 6. Refit the model by least squares on the inliers of the best hypothesis.
//!
//! # References
//!
//! - Chum, O. and Matas, J. "Matching with PROSAC - Progressive Sample Consensus"
//!   CVPR 2005.
23
24use crate::features::MatchPair;
25use crate::{AlignError, AlignResult};
26
/// Configuration for PROSAC.
#[derive(Debug, Clone, PartialEq)]
pub struct ProsacConfig {
    /// Distance threshold for counting inliers (in pixels).
    ///
    /// A match counts as an inlier when the Euclidean distance between its
    /// projected first point and its second point is strictly below this value.
    pub inlier_threshold: f64,
    /// Maximum number of hypothesis-generation iterations.
    pub max_iterations: usize,
    /// Minimum number of inliers required for a valid model.
    pub min_inliers: usize,
    /// Confidence level, expected in `[0.0, 1.0)`, used for adaptive early
    /// termination. Higher values mean more iterations before stopping.
    pub confidence: f64,
    /// Initial size of the sampling pool (the top-`n` matches).
    ///
    /// If `None`, sampling starts from the model's minimal sample size. Values
    /// below the minimal sample size or above the number of matches are
    /// clamped into that range by the estimator.
    pub initial_n: Option<usize>,
}

impl Default for ProsacConfig {
    /// 3 px inlier threshold, 2000 iterations, at least 8 inliers and 99%
    /// confidence, with the sampling pool starting at the minimal sample size.
    fn default() -> Self {
        Self {
            inlier_threshold: 3.0,
            max_iterations: 2000,
            min_inliers: 8,
            confidence: 0.99,
            initial_n: None,
        }
    }
}
54
/// Geometric model family that PROSAC can estimate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProsacModelType {
    /// Affine transform (6 parameters, minimum 3 points).
    Affine,
    /// Homography (8 degrees of freedom, minimum 4 points).
    Homography,
}

impl ProsacModelType {
    /// Minimum number of correspondences required to fit this model.
    ///
    /// Every correspondence contributes two equations, so an affine transform
    /// (6 unknowns) needs 3 and a homography (8 degrees of freedom) needs 4.
    #[must_use]
    pub const fn min_samples(&self) -> usize {
        match self {
            Self::Affine => 3,
            Self::Homography => 4,
        }
    }
}
74
/// Result of a PROSAC estimation.
#[derive(Debug, Clone, PartialEq)]
pub struct ProsacResult {
    /// The estimated model parameters as a flat vector.
    ///
    /// - Affine: `[a, b, tx, c, d, ty]` (6 elements), mapping
    ///   `(x, y) -> (a*x + b*y + tx, c*x + d*y + ty)`.
    /// - Homography: `[h00, h01, ..., h22]` (9 elements, row-major 3x3),
    ///   scaled so that `h22 == 1`.
    pub params: Vec<f64>,
    /// Per-match inlier flags, parallel to the input matches (`true` = inlier).
    pub inlier_mask: Vec<bool>,
    /// Number of `true` entries in `inlier_mask`.
    pub num_inliers: usize,
    /// Number of hypothesis-generation iterations performed.
    pub iterations: usize,
}
89
90/// PROSAC estimator.
91pub struct ProsacEstimator {
92    /// Configuration.
93    pub config: ProsacConfig,
94    /// Model type.
95    pub model_type: ProsacModelType,
96}
97
98impl ProsacEstimator {
99    /// Create a new PROSAC estimator.
100    #[must_use]
101    pub fn new(config: ProsacConfig, model_type: ProsacModelType) -> Self {
102        Self { config, model_type }
103    }
104
105    /// Run PROSAC on the given matches.
106    ///
107    /// Matches should be pre-sorted by quality (ascending descriptor distance
108    /// = best first). If they are not sorted, PROSAC degrades gracefully to
109    /// RANSAC-like behaviour.
110    ///
111    /// # Errors
112    ///
113    /// Returns an error if there are insufficient matches.
114    pub fn estimate(&self, matches: &[MatchPair]) -> AlignResult<ProsacResult> {
115        let min_s = self.model_type.min_samples();
116
117        if matches.len() < min_s {
118            return Err(AlignError::InsufficientData(format!(
119                "Need at least {min_s} matches, got {}",
120                matches.len()
121            )));
122        }
123
124        let total = matches.len();
125        let mut n = self.config.initial_n.unwrap_or(min_s).max(min_s).min(total);
126        let mut best_inliers = 0usize;
127        let mut best_mask: Vec<bool> = vec![false; total];
128        let mut best_params: Vec<f64> = Vec::new();
129        let mut best_iter = 0;
130
131        // Deterministic PRNG seed
132        let mut rng_state = 0x1234_5678_u64;
133
134        // Growth function: how many iterations to spend at subset size n
135        // before expanding to n+1.  We use the simplified formula from the
136        // PROSAC paper:  T(n) = T(N) * C(n, m) / C(N, m)
137        // where m = min_samples, N = total.
138        // For practical purposes we approximate with a linear growth schedule.
139        let mut t_n = 1.0_f64; // iterations at current n
140        let mut t_n_prime = 0.0_f64; // fractional accumulator
141
142        for iter in 0..self.config.max_iterations {
143            // Progressive sampling: decide whether to include the n-th point
144            t_n_prime += 1.0;
145
146            if t_n_prime >= t_n && n < total {
147                n += 1;
148                // Update T(n) using the ratio formula
149                let ratio = if n > min_s {
150                    (n - min_s) as f64 / n as f64
151                } else {
152                    1.0
153                };
154                t_n *= 1.0 + ratio;
155                t_n_prime = 0.0;
156            }
157
158            // Sample min_s points from the top n matches
159            let sample = self.sample_from_top_n(matches, n, min_s, &mut rng_state);
160
161            // Fit model
162            let params = match self.model_type {
163                ProsacModelType::Affine => self.fit_affine(&sample),
164                ProsacModelType::Homography => self.fit_homography(&sample),
165            };
166
167            let params = match params {
168                Ok(p) => p,
169                Err(_) => continue,
170            };
171
172            // Count inliers
173            let (mask, count) = self.count_inliers(matches, &params);
174
175            if count > best_inliers {
176                best_inliers = count;
177                best_mask = mask;
178                best_params = params;
179                best_iter = iter;
180
181                // Adaptive termination: if we have enough inliers, check if
182                // we can stop early based on the confidence level.
183                if best_inliers >= self.config.min_inliers {
184                    let inlier_ratio = best_inliers as f64 / total as f64;
185                    let expected_iters =
186                        adaptive_max_iterations(inlier_ratio, min_s, self.config.confidence);
187                    if iter as f64 >= expected_iters {
188                        break;
189                    }
190                }
191            }
192        }
193
194        if best_inliers < self.config.min_inliers {
195            return Err(AlignError::NoSolution(format!(
196                "Insufficient inliers: {best_inliers} < {}",
197                self.config.min_inliers
198            )));
199        }
200
201        // Refine with all inliers
202        let inlier_matches: Vec<&MatchPair> = matches
203            .iter()
204            .zip(&best_mask)
205            .filter(|(_, &is_inlier)| is_inlier)
206            .map(|(m, _)| m)
207            .collect();
208
209        let refined_params = match self.model_type {
210            ProsacModelType::Affine => {
211                let pairs: Vec<MatchPair> = inlier_matches.iter().map(|m| (*m).clone()).collect();
212                self.fit_affine(&pairs).unwrap_or(best_params.clone())
213            }
214            ProsacModelType::Homography => {
215                let pairs: Vec<MatchPair> = inlier_matches.iter().map(|m| (*m).clone()).collect();
216                self.fit_homography(&pairs).unwrap_or(best_params.clone())
217            }
218        };
219
220        // Recount inliers with refined model
221        let (final_mask, final_count) = self.count_inliers(matches, &refined_params);
222
223        Ok(ProsacResult {
224            params: refined_params,
225            inlier_mask: final_mask,
226            num_inliers: final_count,
227            iterations: best_iter + 1,
228        })
229    }
230
231    // -- Sampling -------------------------------------------------------------
232
233    fn sample_from_top_n(
234        &self,
235        matches: &[MatchPair],
236        n: usize,
237        count: usize,
238        rng: &mut u64,
239    ) -> Vec<MatchPair> {
240        let pool_size = n.min(matches.len());
241        let mut indices = Vec::with_capacity(count);
242
243        while indices.len() < count {
244            let idx = lcg_next(rng) as usize % pool_size;
245            if !indices.contains(&idx) {
246                indices.push(idx);
247            }
248        }
249
250        indices.iter().map(|&i| matches[i].clone()).collect()
251    }
252
253    // -- Model fitting --------------------------------------------------------
254
255    fn fit_affine(&self, matches: &[MatchPair]) -> AlignResult<Vec<f64>> {
256        if matches.len() < 3 {
257            return Err(AlignError::InsufficientData(
258                "Need >= 3 points for affine".to_string(),
259            ));
260        }
261
262        // Least squares: [x y 1 0 0 0] [a]   [x']
263        //                 [0 0 0 x y 1] [b] = [y']
264        //                                      [tx]
265        //                                      [c]
266        //                                      [d]
267        //                                      [ty]
268        let n = matches.len();
269        let _rows = n * 2;
270
271        // Build ATA (6x6) and ATb (6x1) incrementally
272        let mut ata = [0.0_f64; 36];
273        let mut atb = [0.0_f64; 6];
274
275        for m in matches {
276            let x = m.point1.x;
277            let y = m.point1.y;
278            let xp = m.point2.x;
279            let yp = m.point2.y;
280
281            // Row 1: [x, y, 1, 0, 0, 0] -> xp
282            let r1 = [x, y, 1.0, 0.0, 0.0, 0.0];
283            // Row 2: [0, 0, 0, x, y, 1] -> yp
284            let r2 = [0.0, 0.0, 0.0, x, y, 1.0];
285
286            for i in 0..6 {
287                for j in 0..6 {
288                    ata[i * 6 + j] += r1[i] * r1[j] + r2[i] * r2[j];
289                }
290                atb[i] += r1[i] * xp + r2[i] * yp;
291            }
292        }
293
294        // Solve 6x6 system using Cramer's/Gaussian elimination
295        let solution = solve_6x6(&ata, &atb)?;
296
297        Ok(solution.to_vec())
298    }
299
300    fn fit_homography(&self, matches: &[MatchPair]) -> AlignResult<Vec<f64>> {
301        if matches.len() < 4 {
302            return Err(AlignError::InsufficientData(
303                "Need >= 4 points for homography".to_string(),
304            ));
305        }
306
307        // Use DLT: build 2n x 9 matrix and find null space via SVD-like approach.
308        // For simplicity, we use the iterative power method for the smallest
309        // singular value.
310
311        // Normalize points for numerical stability
312        let (norm1, t1) = normalize_points(matches, true);
313        let (norm2, t2) = normalize_points(matches, false);
314
315        let n = matches.len();
316        // Build ATA (9x9) = sum of a_i * a_i^T
317        let mut ata = [0.0_f64; 81];
318
319        for i in 0..n {
320            let x = norm1[i].0;
321            let y = norm1[i].1;
322            let xp = norm2[i].0;
323            let yp = norm2[i].1;
324
325            let r1 = [-x, -y, -1.0, 0.0, 0.0, 0.0, xp * x, xp * y, xp];
326            let r2 = [0.0, 0.0, 0.0, -x, -y, -1.0, yp * x, yp * y, yp];
327
328            for a in 0..9 {
329                for b in 0..9 {
330                    ata[a * 9 + b] += r1[a] * r1[b] + r2[a] * r2[b];
331                }
332            }
333        }
334
335        // Find eigenvector of ATA with smallest eigenvalue using inverse iteration
336        let h_norm = find_smallest_eigenvector_9x9(&ata)?;
337
338        // Denormalize: H = T2_inv * H_norm * T1
339        let h = denormalize_homography(&h_norm, &t1, &t2);
340
341        // Normalize so h[8] = 1
342        if h[8].abs() < 1e-12 {
343            return Err(AlignError::NumericalError(
344                "Degenerate homography".to_string(),
345            ));
346        }
347
348        let scale = h[8];
349        Ok(h.iter().map(|&v| v / scale).collect())
350    }
351
352    // -- Inlier counting ------------------------------------------------------
353
354    fn count_inliers(&self, matches: &[MatchPair], params: &[f64]) -> (Vec<bool>, usize) {
355        let threshold_sq = self.config.inlier_threshold * self.config.inlier_threshold;
356        let mut mask = vec![false; matches.len()];
357        let mut count = 0usize;
358
359        for (i, m) in matches.iter().enumerate() {
360            let projected = self.project_point(m.point1.x, m.point1.y, params);
361            let dx = projected.0 - m.point2.x;
362            let dy = projected.1 - m.point2.y;
363            let err_sq = dx * dx + dy * dy;
364
365            if err_sq < threshold_sq {
366                mask[i] = true;
367                count += 1;
368            }
369        }
370
371        (mask, count)
372    }
373
374    fn project_point(&self, x: f64, y: f64, params: &[f64]) -> (f64, f64) {
375        match self.model_type {
376            ProsacModelType::Affine => {
377                if params.len() < 6 {
378                    return (x, y);
379                }
380                let xp = params[0] * x + params[1] * y + params[2];
381                let yp = params[3] * x + params[4] * y + params[5];
382                (xp, yp)
383            }
384            ProsacModelType::Homography => {
385                if params.len() < 9 {
386                    return (x, y);
387                }
388                let w = params[6] * x + params[7] * y + params[8];
389                if w.abs() < 1e-12 {
390                    return (x, y);
391                }
392                let xp = (params[0] * x + params[1] * y + params[2]) / w;
393                let yp = (params[3] * x + params[4] * y + params[5]) / w;
394                (xp, yp)
395            }
396        }
397    }
398}
399
400// -- Adaptive iteration count -------------------------------------------------
401
/// Number of iterations needed to draw at least one all-inlier minimal
/// sample with probability `confidence`.
///
/// Evaluates `ln(1 - confidence) / ln(1 - w^m)`, where `w = inlier_ratio` and
/// `m = min_samples`. Both logarithms use `ln_1p`, so a tiny `w^m` is not
/// rounded away: computing `1 - w^m` directly would give exactly `1.0`, and
/// a spurious bound of one iteration.
///
/// Returns `1.0` when `inlier_ratio` is outside `(0, 1)`. Returns
/// `f64::INFINITY` when `w^m` underflows to zero or `confidence == 1.0`
/// (no finite bound). `confidence` is expected to lie in `[0, 1]`.
fn adaptive_max_iterations(inlier_ratio: f64, min_samples: usize, confidence: f64) -> f64 {
    if inlier_ratio <= 0.0 || inlier_ratio >= 1.0 {
        return 1.0;
    }

    let exponent = i32::try_from(min_samples).unwrap_or(i32::MAX);
    let p_all_inliers = inlier_ratio.powi(exponent);

    // ln(1 - w^m), accurate even when w^m is far below f64 epsilon.
    let denom = (-p_all_inliers).ln_1p();
    if denom == 0.0 {
        return f64::INFINITY;
    }

    (-confidence).ln_1p() / denom
}
416
417// -- LCG PRNG -----------------------------------------------------------------
418
/// Advance the 64-bit linear congruential generator in `state` and return
/// its upper 31 bits (the low bits of an LCG have short periods).
fn lcg_next(state: &mut u64) -> u64 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    let next = MULTIPLIER.wrapping_mul(*state).wrapping_add(INCREMENT);
    *state = next;
    next >> 33
}
425
426// -- Linear algebra helpers ---------------------------------------------------
427
428/// Solve a 6x6 linear system Ax = b using Gaussian elimination with partial pivoting.
429fn solve_6x6(ata: &[f64; 36], atb: &[f64; 6]) -> AlignResult<[f64; 6]> {
430    let mut a = *ata;
431    let mut b = *atb;
432
433    // Forward elimination
434    for col in 0..6 {
435        // Partial pivoting
436        let mut max_row = col;
437        let mut max_val = a[col * 6 + col].abs();
438        for row in (col + 1)..6 {
439            let val = a[row * 6 + col].abs();
440            if val > max_val {
441                max_val = val;
442                max_row = row;
443            }
444        }
445
446        if max_val < 1e-12 {
447            return Err(AlignError::NumericalError(
448                "Singular matrix in 6x6 solve".to_string(),
449            ));
450        }
451
452        // Swap rows
453        if max_row != col {
454            for j in 0..6 {
455                a.swap(col * 6 + j, max_row * 6 + j);
456            }
457            b.swap(col, max_row);
458        }
459
460        // Eliminate below
461        let pivot = a[col * 6 + col];
462        for row in (col + 1)..6 {
463            let factor = a[row * 6 + col] / pivot;
464            for j in col..6 {
465                a[row * 6 + j] -= factor * a[col * 6 + j];
466            }
467            b[row] -= factor * b[col];
468        }
469    }
470
471    // Back substitution
472    let mut x = [0.0_f64; 6];
473    for col in (0..6).rev() {
474        let mut sum = b[col];
475        for j in (col + 1)..6 {
476            sum -= a[col * 6 + j] * x[j];
477        }
478        x[col] = sum / a[col * 6 + col];
479    }
480
481    Ok(x)
482}
483
484/// Normalize 2D points: translate centroid to origin, scale so avg distance = sqrt(2).
485fn normalize_points(matches: &[MatchPair], use_first: bool) -> (Vec<(f64, f64)>, [f64; 9]) {
486    let pts: Vec<(f64, f64)> = if use_first {
487        matches.iter().map(|m| (m.point1.x, m.point1.y)).collect()
488    } else {
489        matches.iter().map(|m| (m.point2.x, m.point2.y)).collect()
490    };
491
492    let n = pts.len() as f64;
493    let cx: f64 = pts.iter().map(|p| p.0).sum::<f64>() / n;
494    let cy: f64 = pts.iter().map(|p| p.1).sum::<f64>() / n;
495
496    let avg_dist: f64 = pts
497        .iter()
498        .map(|p| ((p.0 - cx).powi(2) + (p.1 - cy).powi(2)).sqrt())
499        .sum::<f64>()
500        / n;
501
502    let s = if avg_dist > 1e-10 {
503        std::f64::consts::SQRT_2 / avg_dist
504    } else {
505        1.0
506    };
507
508    let normalized: Vec<(f64, f64)> = pts
509        .iter()
510        .map(|p| ((p.0 - cx) * s, (p.1 - cy) * s))
511        .collect();
512
513    // T = [s 0 -s*cx; 0 s -s*cy; 0 0 1]
514    let t = [s, 0.0, -s * cx, 0.0, s, -s * cy, 0.0, 0.0, 1.0];
515
516    (normalized, t)
517}
518
519/// Denormalize homography: H = T2_inv * Hn * T1
520fn denormalize_homography(h_norm: &[f64; 9], t1: &[f64; 9], t2: &[f64; 9]) -> [f64; 9] {
521    // T2_inv
522    let s2 = t2[0];
523    let tx2 = t2[2];
524    let ty2 = t2[5];
525
526    let t2_inv = if s2.abs() > 1e-15 {
527        let inv_s = 1.0 / s2;
528        [
529            inv_s,
530            0.0,
531            -tx2 * inv_s,
532            0.0,
533            inv_s,
534            -ty2 * inv_s,
535            0.0,
536            0.0,
537            1.0,
538        ]
539    } else {
540        [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
541    };
542
543    let tmp = mat3_mul(&t2_inv, h_norm);
544    mat3_mul(&tmp, t1)
545}
546
/// Multiply two row-major 3x3 matrices, returning `a * b`.
fn mat3_mul(a: &[f64; 9], b: &[f64; 9]) -> [f64; 9] {
    let mut product = [0.0_f64; 9];
    for (idx, cell) in product.iter_mut().enumerate() {
        let (row, col) = (idx / 3, idx % 3);
        *cell = (0..3).fold(0.0, |acc, k| acc + a[row * 3 + k] * b[k * 3 + col]);
    }
    product
}
558
559/// Find the eigenvector corresponding to the smallest eigenvalue of a 9x9
560/// symmetric positive semi-definite matrix, using inverse iteration with
561/// a shift near zero.
562fn find_smallest_eigenvector_9x9(ata: &[f64; 81]) -> AlignResult<[f64; 9]> {
563    // We use 50 iterations of inverse power iteration with a small shift.
564    let shift = 1e-8;
565    let mut a_shifted = *ata;
566    for i in 0..9 {
567        a_shifted[i * 9 + i] += shift;
568    }
569
570    // Start with a uniform vector
571    let mut v = [1.0_f64 / 3.0; 9];
572
573    for _iter in 0..50 {
574        // Solve (ATA + shift*I) * w = v
575        let w = solve_9x9_gauss(&a_shifted, &v)?;
576
577        // Normalize
578        let norm: f64 = w.iter().map(|&x| x * x).sum::<f64>().sqrt();
579        if norm < 1e-15 {
580            return Err(AlignError::NumericalError(
581                "Eigenvector iteration diverged".to_string(),
582            ));
583        }
584        for i in 0..9 {
585            v[i] = w[i] / norm;
586        }
587    }
588
589    Ok(v)
590}
591
592/// Solve a 9x9 linear system using Gaussian elimination.
593fn solve_9x9_gauss(a: &[f64; 81], b: &[f64; 9]) -> AlignResult<[f64; 9]> {
594    let mut mat = *a;
595    let mut rhs = *b;
596
597    for col in 0..9 {
598        // Partial pivoting
599        let mut max_row = col;
600        let mut max_val = mat[col * 9 + col].abs();
601        for row in (col + 1)..9 {
602            let val = mat[row * 9 + col].abs();
603            if val > max_val {
604                max_val = val;
605                max_row = row;
606            }
607        }
608
609        if max_val < 1e-14 {
610            return Err(AlignError::NumericalError(
611                "Singular matrix in 9x9 solve".to_string(),
612            ));
613        }
614
615        if max_row != col {
616            for j in 0..9 {
617                mat.swap(col * 9 + j, max_row * 9 + j);
618            }
619            rhs.swap(col, max_row);
620        }
621
622        let pivot = mat[col * 9 + col];
623        for row in (col + 1)..9 {
624            let factor = mat[row * 9 + col] / pivot;
625            for j in col..9 {
626                mat[row * 9 + j] -= factor * mat[col * 9 + j];
627            }
628            rhs[row] -= factor * rhs[col];
629        }
630    }
631
632    let mut x = [0.0_f64; 9];
633    for col in (0..9).rev() {
634        let mut sum = rhs[col];
635        for j in (col + 1)..9 {
636            sum -= mat[col * 9 + j] * x[j];
637        }
638        x[col] = sum / mat[col * 9 + col];
639    }
640
641    Ok(x)
642}
643
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Point2D;

    /// Build `n` exact correspondences under the affine map
    /// `(x, y) -> (a*x + b*y + tx, c*x + d*y + ty)`.
    ///
    /// Source points are `((17 i) mod 100, (31 i) mod 100)`. The first four,
    /// (0,0), (17,31), (34,62) and (51,93), are collinear, so small samples
    /// also exercise degenerate-sample handling.
    fn make_affine_matches(
        a: f64,
        b: f64,
        tx: f64,
        c: f64,
        d: f64,
        ty: f64,
        n: usize,
    ) -> Vec<MatchPair> {
        (0..n)
            .map(|i| {
                let x = (i as f64 * 17.0) % 100.0;
                let y = (i as f64 * 31.0) % 100.0;
                let xp = a * x + b * y + tx;
                let yp = c * x + d * y + ty;
                MatchPair::new(i, i, i as u32, Point2D::new(x, y), Point2D::new(xp, yp))
            })
            .collect()
    }

    fn make_identity_matches(n: usize) -> Vec<MatchPair> {
        make_affine_matches(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, n)
    }

    /// Build `n` exact correspondences under the row-major homography `h`,
    /// using the same source points as `make_affine_matches`.
    fn make_homography_matches(h: &[f64; 9], n: usize) -> Vec<MatchPair> {
        (0..n)
            .map(|i| {
                let x = (i as f64 * 17.0) % 100.0;
                let y = (i as f64 * 31.0) % 100.0;
                let w = h[6] * x + h[7] * y + h[8];
                let xp = (h[0] * x + h[1] * y + h[2]) / w;
                let yp = (h[3] * x + h[4] * y + h[5]) / w;
                MatchPair::new(i, i, i as u32, Point2D::new(x, y), Point2D::new(xp, yp))
            })
            .collect()
    }

    // -- ProsacConfig ---------------------------------------------------------

    #[test]
    fn test_prosac_config_default() {
        let config = ProsacConfig::default();
        assert_eq!(config.max_iterations, 2000);
        assert_eq!(config.min_inliers, 8);
    }

    #[test]
    fn test_model_type_min_samples() {
        assert_eq!(ProsacModelType::Affine.min_samples(), 3);
        assert_eq!(ProsacModelType::Homography.min_samples(), 4);
    }

    // -- Affine estimation ----------------------------------------------------

    #[test]
    fn test_prosac_affine_identity() {
        let matches = make_identity_matches(20);

        let estimator = ProsacEstimator::new(
            ProsacConfig {
                min_inliers: 5,
                ..ProsacConfig::default()
            },
            ProsacModelType::Affine,
        );

        let result = estimator.estimate(&matches).expect("should succeed");
        assert!(result.num_inliers >= 5);
        assert_eq!(result.params.len(), 6);

        // Check that parameters are close to identity [1,0,0, 0,1,0]
        assert!(
            (result.params[0] - 1.0).abs() < 0.1,
            "a={}",
            result.params[0]
        );
        assert!((result.params[1]).abs() < 0.1, "b={}", result.params[1]);
        assert!((result.params[2]).abs() < 1.0, "tx={}", result.params[2]);
        assert!((result.params[3]).abs() < 0.1, "c={}", result.params[3]);
        assert!(
            (result.params[4] - 1.0).abs() < 0.1,
            "d={}",
            result.params[4]
        );
        assert!((result.params[5]).abs() < 1.0, "ty={}", result.params[5]);
    }

    #[test]
    fn test_prosac_affine_translation() {
        let matches = make_affine_matches(1.0, 0.0, 10.0, 0.0, 1.0, -5.0, 20);

        let estimator = ProsacEstimator::new(
            ProsacConfig {
                min_inliers: 5,
                ..ProsacConfig::default()
            },
            ProsacModelType::Affine,
        );

        let result = estimator.estimate(&matches).expect("should succeed");
        assert!(
            (result.params[2] - 10.0).abs() < 1.0,
            "tx={}",
            result.params[2]
        );
        assert!(
            (result.params[5] + 5.0).abs() < 1.0,
            "ty={}",
            result.params[5]
        );
    }

    #[test]
    fn test_prosac_affine_with_outliers() {
        let mut matches = make_affine_matches(1.0, 0.0, 5.0, 0.0, 1.0, 3.0, 30);

        // Add some outliers
        for i in 0..5 {
            matches.push(MatchPair::new(
                30 + i,
                30 + i,
                100,
                Point2D::new(i as f64 * 10.0, i as f64 * 10.0),
                Point2D::new(999.0, 999.0),
            ));
        }

        let estimator = ProsacEstimator::new(
            ProsacConfig {
                min_inliers: 5,
                ..ProsacConfig::default()
            },
            ProsacModelType::Affine,
        );

        let result = estimator.estimate(&matches).expect("should succeed");
        // Should reject outliers
        assert!(result.num_inliers >= 20);
        // The appended outliers must be flagged, and the mask must agree with
        // the reported count.
        assert!(result.inlier_mask[30..].iter().all(|&is_inlier| !is_inlier));
        assert_eq!(
            result.inlier_mask.iter().filter(|&&b| b).count(),
            result.num_inliers
        );
    }

    #[test]
    fn test_prosac_deterministic() {
        let matches = make_affine_matches(0.9, 0.1, 4.0, -0.1, 1.1, 2.0, 25);
        let estimator = ProsacEstimator::new(
            ProsacConfig {
                min_inliers: 5,
                ..ProsacConfig::default()
            },
            ProsacModelType::Affine,
        );

        let first = estimator.estimate(&matches).expect("should succeed");
        let second = estimator.estimate(&matches).expect("should succeed");
        assert_eq!(first.params, second.params);
        assert_eq!(first.inlier_mask, second.inlier_mask);
        assert_eq!(first.iterations, second.iterations);
        assert!(first.iterations >= 1);
        assert!(first.iterations <= estimator.config.max_iterations);
    }

    #[test]
    fn test_prosac_empty_matches() {
        let estimator = ProsacEstimator::new(ProsacConfig::default(), ProsacModelType::Affine);
        assert!(estimator.estimate(&[]).is_err());
    }

    // -- Homography estimation ------------------------------------------------

    #[test]
    fn test_prosac_homography_identity() {
        let matches = make_identity_matches(20);

        let estimator = ProsacEstimator::new(
            ProsacConfig {
                min_inliers: 5,
                ..ProsacConfig::default()
            },
            ProsacModelType::Homography,
        );

        let result = estimator.estimate(&matches).expect("should succeed");
        assert!(result.num_inliers >= 5);
        assert_eq!(result.params.len(), 9);

        // H should be close to identity
        assert!(
            (result.params[0] - 1.0).abs() < 0.2,
            "h00={}",
            result.params[0]
        );
        assert!(
            (result.params[4] - 1.0).abs() < 0.2,
            "h11={}",
            result.params[4]
        );
        assert!(
            (result.params[8] - 1.0).abs() < 0.2,
            "h22={}",
            result.params[8]
        );
    }

    #[test]
    fn test_prosac_homography_projective() {
        let h = [1.2, 0.1, 5.0, 0.05, 0.9, -3.0, 0.001, 0.0005, 1.0];
        let matches = make_homography_matches(&h, 25);

        let estimator = ProsacEstimator::new(
            ProsacConfig {
                min_inliers: 5,
                ..ProsacConfig::default()
            },
            ProsacModelType::Homography,
        );

        let result = estimator.estimate(&matches).expect("should succeed");
        assert_eq!(result.num_inliers, 25);
        for (k, (&got, &want)) in result.params.iter().zip(&h).enumerate() {
            assert!((got - want).abs() < 1e-4, "h[{k}]: got {got}, want {want}");
        }
    }

    #[test]
    fn test_prosac_insufficient_matches() {
        let matches = vec![MatchPair::new(
            0,
            0,
            0,
            Point2D::new(0.0, 0.0),
            Point2D::new(1.0, 1.0),
        )];

        let estimator = ProsacEstimator::new(ProsacConfig::default(), ProsacModelType::Homography);
        let result = estimator.estimate(&matches);
        assert!(result.is_err());
    }

    // -- Adaptive iteration count ---------------------------------------------

    #[test]
    fn test_adaptive_max_iterations() {
        let iters = adaptive_max_iterations(0.5, 4, 0.99);
        assert!(iters > 0.0 && iters < 10_000.0, "iters={iters}");
    }

    #[test]
    fn test_adaptive_max_iterations_high_inlier_ratio() {
        let iters = adaptive_max_iterations(0.9, 4, 0.99);
        // High inlier ratio should need few iterations
        assert!(iters < 100.0, "iters={iters}");
    }

    #[test]
    fn test_adaptive_max_iterations_edge_cases() {
        assert_eq!(adaptive_max_iterations(0.0, 4, 0.99), 1.0);
        assert_eq!(adaptive_max_iterations(1.0, 4, 0.99), 1.0);
    }

    #[test]
    fn test_adaptive_max_iterations_monotonic_in_inlier_ratio() {
        let low = adaptive_max_iterations(0.3, 3, 0.99);
        let high = adaptive_max_iterations(0.6, 3, 0.99);
        assert!(low > high, "low={low}, high={high}");
    }

    // -- LCG PRNG -------------------------------------------------------------

    #[test]
    fn test_lcg_deterministic() {
        let mut state1 = 42u64;
        let mut state2 = 42u64;
        assert_eq!(lcg_next(&mut state1), lcg_next(&mut state2));
    }

    #[test]
    fn test_lcg_different_seeds() {
        let mut s1 = 1u64;
        let mut s2 = 2u64;
        assert_ne!(lcg_next(&mut s1), lcg_next(&mut s2));
    }

    // -- Linear algebra -------------------------------------------------------

    #[test]
    fn test_solve_6x6_identity_system() {
        // Identity matrix * x = b => x = b
        let mut ata = [0.0_f64; 36];
        for i in 0..6 {
            ata[i * 6 + i] = 1.0;
        }
        let atb = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let x = solve_6x6(&ata, &atb).expect("should succeed");
        for i in 0..6 {
            assert!((x[i] - atb[i]).abs() < 1e-10);
        }
    }

    #[test]
    fn test_solve_6x6_singular() {
        let ata = [0.0_f64; 36]; // all zeros = singular
        let atb = [1.0; 6];
        assert!(solve_6x6(&ata, &atb).is_err());
    }

    #[test]
    fn test_solve_6x6_requires_pivoting() {
        // Identity with rows 0 and 1 swapped: the leading diagonal entry is zero.
        let mut ata = [0.0_f64; 36];
        for i in 2..6 {
            ata[i * 6 + i] = 1.0;
        }
        ata[1] = 1.0; // row 0, col 1
        ata[6] = 1.0; // row 1, col 0
        let atb = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let x = solve_6x6(&ata, &atb).expect("should succeed");
        let expected = [2.0, 1.0, 3.0, 4.0, 5.0, 6.0];
        for (got, want) in x.iter().zip(&expected) {
            assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
        }
    }

    #[test]
    fn test_solve_9x9_diagonal_system() {
        let mut a = [0.0_f64; 81];
        for i in 0..9 {
            a[i * 9 + i] = 2.0;
        }
        let b = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0];
        let x = solve_9x9_gauss(&a, &b).expect("should succeed");
        for (i, &xi) in x.iter().enumerate() {
            assert!((xi - (i + 1) as f64).abs() < 1e-12, "x[{i}]={xi}");
        }
    }

    #[test]
    fn test_find_smallest_eigenvector_diagonal() {
        // Diagonal matrix whose only zero eigenvalue sits at index 4.
        let mut a = [0.0_f64; 81];
        for i in 0..9 {
            a[i * 9 + i] = if i == 4 { 0.0 } else { 1.0 + i as f64 };
        }
        let v = find_smallest_eigenvector_9x9(&a).expect("should succeed");
        assert!((v[4].abs() - 1.0).abs() < 1e-9, "v={v:?}");
        for (i, &c) in v.iter().enumerate() {
            if i != 4 {
                assert!(c.abs() < 1e-9, "v[{i}]={c}");
            }
        }
    }

    #[test]
    fn test_mat3_mul_identity() {
        let id = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let c = mat3_mul(&id, &a);
        for i in 0..9 {
            assert!((c[i] - a[i]).abs() < 1e-10);
        }
    }

    #[test]
    fn test_mat3_mul_known_product() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let b = [9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        assert_eq!(
            mat3_mul(&a, &b),
            [30.0, 24.0, 18.0, 84.0, 69.0, 54.0, 138.0, 114.0, 90.0]
        );
    }

    #[test]
    fn test_denormalize_homography_identity_transforms() {
        let id = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let h = [1.5, -0.2, 3.0, 0.1, 0.8, -4.0, 0.001, 0.002, 1.0];
        assert_eq!(denormalize_homography(&h, &id, &id), h);
    }

    // -- Normalization --------------------------------------------------------

    #[test]
    fn test_normalize_points_centered() {
        let matches = vec![
            MatchPair::new(0, 0, 0, Point2D::new(-1.0, -1.0), Point2D::new(0.0, 0.0)),
            MatchPair::new(1, 1, 0, Point2D::new(1.0, 1.0), Point2D::new(0.0, 0.0)),
        ];
        let (norm, _t) = normalize_points(&matches, true);
        // Centroid should be at origin
        let cx = norm.iter().map(|p| p.0).sum::<f64>() / norm.len() as f64;
        let cy = norm.iter().map(|p| p.1).sum::<f64>() / norm.len() as f64;
        assert!(cx.abs() < 1e-10);
        assert!(cy.abs() < 1e-10);
    }

    #[test]
    fn test_normalize_points_mean_distance_is_sqrt2() {
        let matches: Vec<MatchPair> = [(0.0, 0.0), (10.0, 0.0), (10.0, 10.0), (0.0, 10.0)]
            .iter()
            .enumerate()
            .map(|(i, &(x, y))| MatchPair::new(i, i, 0, Point2D::new(x, y), Point2D::new(x, y)))
            .collect();
        let (norm, t) = normalize_points(&matches, false);

        let mean = norm.iter().map(|p| (p.0 * p.0 + p.1 * p.1).sqrt()).sum::<f64>()
            / norm.len() as f64;
        assert!((mean - std::f64::consts::SQRT_2).abs() < 1e-12, "mean={mean}");

        // T must map the original points onto the normalised ones.
        for (m, &(nx, ny)) in matches.iter().zip(&norm) {
            let x = t[0] * m.point2.x + t[1] * m.point2.y + t[2];
            let y = t[3] * m.point2.x + t[4] * m.point2.y + t[5];
            assert!((x - nx).abs() < 1e-12 && (y - ny).abs() < 1e-12);
        }
    }
}