1use crate::features::MatchPair;
25use crate::{AlignError, AlignResult};
26
/// Tuning parameters for [`ProsacEstimator`].
#[derive(Debug, Clone)]
pub struct ProsacConfig {
    /// Reprojection-error cutoff: a match counts as an inlier when the
    /// distance between its projected first point and its second point is
    /// strictly below this value (same units as the match coordinates).
    pub inlier_threshold: f64,
    /// Upper bound on the number of hypothesis-generation iterations.
    pub max_iterations: usize,
    /// Smallest inlier count for which estimation reports success; a best
    /// hypothesis with at least this many inliers also feeds the adaptive
    /// iteration bound.
    pub min_inliers: usize,
    /// Target success probability used to compute the adaptive iteration bound.
    pub confidence: f64,
    /// Initial size of the sampling pool taken from the front of the match
    /// list. `None` starts at the model's minimal sample size; any value is
    /// clamped to `[min_samples, matches.len()]`.
    pub initial_n: Option<usize>,
}

impl Default for ProsacConfig {
    /// Threshold 3.0, 2000 iterations, 8 inliers, 0.99 confidence, and a
    /// pool that starts at the minimal sample size.
    fn default() -> Self {
        ProsacConfig {
            confidence: 0.99,
            inlier_threshold: 3.0,
            initial_n: None,
            max_iterations: 2000,
            min_inliers: 8,
        }
    }
}
54
/// Geometric model family fitted by [`ProsacEstimator`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProsacModelType {
    /// Affine transform described by 6 parameters.
    Affine,
    /// Projective homography stored as 9 row-major matrix entries.
    Homography,
}

impl ProsacModelType {
    /// Number of correspondences in a minimal sample for this model family.
    #[must_use]
    pub fn min_samples(&self) -> usize {
        match *self {
            ProsacModelType::Affine => 3,
            ProsacModelType::Homography => 4,
        }
    }
}
74
/// Outcome of a successful `ProsacEstimator::estimate` call.
#[derive(Debug, Clone)]
pub struct ProsacResult {
    /// Model parameters.
    /// Affine: `[a, b, tx, c, d, ty]` with `x' = a*x + b*y + tx`, `y' = c*x + d*y + ty`.
    /// Homography: the 3x3 matrix in row-major order, scaled so the last entry is 1.
    pub params: Vec<f64>,
    /// One flag per input match (same order); `true` when the match's
    /// reprojection error under `params` is below the configured threshold.
    pub inlier_mask: Vec<bool>,
    /// Number of `true` entries in `inlier_mask`.
    pub num_inliers: usize,
    /// 1-based index of the sampling iteration that produced the best
    /// hypothesis. This is not necessarily the total number of iterations run.
    pub iterations: usize,
}
89
90pub struct ProsacEstimator {
92 pub config: ProsacConfig,
94 pub model_type: ProsacModelType,
96}
97
98impl ProsacEstimator {
99 #[must_use]
101 pub fn new(config: ProsacConfig, model_type: ProsacModelType) -> Self {
102 Self { config, model_type }
103 }
104
105 pub fn estimate(&self, matches: &[MatchPair]) -> AlignResult<ProsacResult> {
115 let min_s = self.model_type.min_samples();
116
117 if matches.len() < min_s {
118 return Err(AlignError::InsufficientData(format!(
119 "Need at least {min_s} matches, got {}",
120 matches.len()
121 )));
122 }
123
124 let total = matches.len();
125 let mut n = self.config.initial_n.unwrap_or(min_s).max(min_s).min(total);
126 let mut best_inliers = 0usize;
127 let mut best_mask: Vec<bool> = vec![false; total];
128 let mut best_params: Vec<f64> = Vec::new();
129 let mut best_iter = 0;
130
131 let mut rng_state = 0x1234_5678_u64;
133
134 let mut t_n = 1.0_f64; let mut t_n_prime = 0.0_f64; for iter in 0..self.config.max_iterations {
143 t_n_prime += 1.0;
145
146 if t_n_prime >= t_n && n < total {
147 n += 1;
148 let ratio = if n > min_s {
150 (n - min_s) as f64 / n as f64
151 } else {
152 1.0
153 };
154 t_n *= 1.0 + ratio;
155 t_n_prime = 0.0;
156 }
157
158 let sample = self.sample_from_top_n(matches, n, min_s, &mut rng_state);
160
161 let params = match self.model_type {
163 ProsacModelType::Affine => self.fit_affine(&sample),
164 ProsacModelType::Homography => self.fit_homography(&sample),
165 };
166
167 let params = match params {
168 Ok(p) => p,
169 Err(_) => continue,
170 };
171
172 let (mask, count) = self.count_inliers(matches, ¶ms);
174
175 if count > best_inliers {
176 best_inliers = count;
177 best_mask = mask;
178 best_params = params;
179 best_iter = iter;
180
181 if best_inliers >= self.config.min_inliers {
184 let inlier_ratio = best_inliers as f64 / total as f64;
185 let expected_iters =
186 adaptive_max_iterations(inlier_ratio, min_s, self.config.confidence);
187 if iter as f64 >= expected_iters {
188 break;
189 }
190 }
191 }
192 }
193
194 if best_inliers < self.config.min_inliers {
195 return Err(AlignError::NoSolution(format!(
196 "Insufficient inliers: {best_inliers} < {}",
197 self.config.min_inliers
198 )));
199 }
200
201 let inlier_matches: Vec<&MatchPair> = matches
203 .iter()
204 .zip(&best_mask)
205 .filter(|(_, &is_inlier)| is_inlier)
206 .map(|(m, _)| m)
207 .collect();
208
209 let refined_params = match self.model_type {
210 ProsacModelType::Affine => {
211 let pairs: Vec<MatchPair> = inlier_matches.iter().map(|m| (*m).clone()).collect();
212 self.fit_affine(&pairs).unwrap_or(best_params.clone())
213 }
214 ProsacModelType::Homography => {
215 let pairs: Vec<MatchPair> = inlier_matches.iter().map(|m| (*m).clone()).collect();
216 self.fit_homography(&pairs).unwrap_or(best_params.clone())
217 }
218 };
219
220 let (final_mask, final_count) = self.count_inliers(matches, &refined_params);
222
223 Ok(ProsacResult {
224 params: refined_params,
225 inlier_mask: final_mask,
226 num_inliers: final_count,
227 iterations: best_iter + 1,
228 })
229 }
230
231 fn sample_from_top_n(
234 &self,
235 matches: &[MatchPair],
236 n: usize,
237 count: usize,
238 rng: &mut u64,
239 ) -> Vec<MatchPair> {
240 let pool_size = n.min(matches.len());
241 let mut indices = Vec::with_capacity(count);
242
243 while indices.len() < count {
244 let idx = lcg_next(rng) as usize % pool_size;
245 if !indices.contains(&idx) {
246 indices.push(idx);
247 }
248 }
249
250 indices.iter().map(|&i| matches[i].clone()).collect()
251 }
252
253 fn fit_affine(&self, matches: &[MatchPair]) -> AlignResult<Vec<f64>> {
256 if matches.len() < 3 {
257 return Err(AlignError::InsufficientData(
258 "Need >= 3 points for affine".to_string(),
259 ));
260 }
261
262 let n = matches.len();
269 let _rows = n * 2;
270
271 let mut ata = [0.0_f64; 36];
273 let mut atb = [0.0_f64; 6];
274
275 for m in matches {
276 let x = m.point1.x;
277 let y = m.point1.y;
278 let xp = m.point2.x;
279 let yp = m.point2.y;
280
281 let r1 = [x, y, 1.0, 0.0, 0.0, 0.0];
283 let r2 = [0.0, 0.0, 0.0, x, y, 1.0];
285
286 for i in 0..6 {
287 for j in 0..6 {
288 ata[i * 6 + j] += r1[i] * r1[j] + r2[i] * r2[j];
289 }
290 atb[i] += r1[i] * xp + r2[i] * yp;
291 }
292 }
293
294 let solution = solve_6x6(&ata, &atb)?;
296
297 Ok(solution.to_vec())
298 }
299
300 fn fit_homography(&self, matches: &[MatchPair]) -> AlignResult<Vec<f64>> {
301 if matches.len() < 4 {
302 return Err(AlignError::InsufficientData(
303 "Need >= 4 points for homography".to_string(),
304 ));
305 }
306
307 let (norm1, t1) = normalize_points(matches, true);
313 let (norm2, t2) = normalize_points(matches, false);
314
315 let n = matches.len();
316 let mut ata = [0.0_f64; 81];
318
319 for i in 0..n {
320 let x = norm1[i].0;
321 let y = norm1[i].1;
322 let xp = norm2[i].0;
323 let yp = norm2[i].1;
324
325 let r1 = [-x, -y, -1.0, 0.0, 0.0, 0.0, xp * x, xp * y, xp];
326 let r2 = [0.0, 0.0, 0.0, -x, -y, -1.0, yp * x, yp * y, yp];
327
328 for a in 0..9 {
329 for b in 0..9 {
330 ata[a * 9 + b] += r1[a] * r1[b] + r2[a] * r2[b];
331 }
332 }
333 }
334
335 let h_norm = find_smallest_eigenvector_9x9(&ata)?;
337
338 let h = denormalize_homography(&h_norm, &t1, &t2);
340
341 if h[8].abs() < 1e-12 {
343 return Err(AlignError::NumericalError(
344 "Degenerate homography".to_string(),
345 ));
346 }
347
348 let scale = h[8];
349 Ok(h.iter().map(|&v| v / scale).collect())
350 }
351
352 fn count_inliers(&self, matches: &[MatchPair], params: &[f64]) -> (Vec<bool>, usize) {
355 let threshold_sq = self.config.inlier_threshold * self.config.inlier_threshold;
356 let mut mask = vec![false; matches.len()];
357 let mut count = 0usize;
358
359 for (i, m) in matches.iter().enumerate() {
360 let projected = self.project_point(m.point1.x, m.point1.y, params);
361 let dx = projected.0 - m.point2.x;
362 let dy = projected.1 - m.point2.y;
363 let err_sq = dx * dx + dy * dy;
364
365 if err_sq < threshold_sq {
366 mask[i] = true;
367 count += 1;
368 }
369 }
370
371 (mask, count)
372 }
373
374 fn project_point(&self, x: f64, y: f64, params: &[f64]) -> (f64, f64) {
375 match self.model_type {
376 ProsacModelType::Affine => {
377 if params.len() < 6 {
378 return (x, y);
379 }
380 let xp = params[0] * x + params[1] * y + params[2];
381 let yp = params[3] * x + params[4] * y + params[5];
382 (xp, yp)
383 }
384 ProsacModelType::Homography => {
385 if params.len() < 9 {
386 return (x, y);
387 }
388 let w = params[6] * x + params[7] * y + params[8];
389 if w.abs() < 1e-12 {
390 return (x, y);
391 }
392 let xp = (params[0] * x + params[1] * y + params[2]) / w;
393 let yp = (params[3] * x + params[4] * y + params[5]) / w;
394 (xp, yp)
395 }
396 }
397 }
398}
399
/// Number of minimal samples needed so that, with probability `confidence`,
/// at least one sample is all-inlier, given the current `inlier_ratio`:
/// `ln(1 - confidence) / ln(1 - inlier_ratio^min_samples)`.
///
/// Returns 1.0 when the ratio is outside (0, 1), or when the denominator is
/// numerically zero. A NaN ratio propagates to a NaN result.
fn adaptive_max_iterations(inlier_ratio: f64, min_samples: usize, confidence: f64) -> f64 {
    let ratio_out_of_range = inlier_ratio <= 0.0 || inlier_ratio >= 1.0;
    if ratio_out_of_range {
        return 1.0;
    }

    let all_inlier_prob = inlier_ratio.powi(min_samples as i32);
    let log_miss_per_draw = (1.0 - all_inlier_prob).ln();
    if log_miss_per_draw.abs() < 1e-15 {
        return 1.0;
    }

    (1.0 - confidence).ln() / log_miss_per_draw
}
416
/// Advances a 64-bit linear congruential generator in place and returns the
/// high 31 bits of the new state.
fn lcg_next(state: &mut u64) -> u64 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    let next = MULTIPLIER.wrapping_mul(*state).wrapping_add(INCREMENT);
    *state = next;
    next >> 33
}
425
426fn solve_6x6(ata: &[f64; 36], atb: &[f64; 6]) -> AlignResult<[f64; 6]> {
430 let mut a = *ata;
431 let mut b = *atb;
432
433 for col in 0..6 {
435 let mut max_row = col;
437 let mut max_val = a[col * 6 + col].abs();
438 for row in (col + 1)..6 {
439 let val = a[row * 6 + col].abs();
440 if val > max_val {
441 max_val = val;
442 max_row = row;
443 }
444 }
445
446 if max_val < 1e-12 {
447 return Err(AlignError::NumericalError(
448 "Singular matrix in 6x6 solve".to_string(),
449 ));
450 }
451
452 if max_row != col {
454 for j in 0..6 {
455 a.swap(col * 6 + j, max_row * 6 + j);
456 }
457 b.swap(col, max_row);
458 }
459
460 let pivot = a[col * 6 + col];
462 for row in (col + 1)..6 {
463 let factor = a[row * 6 + col] / pivot;
464 for j in col..6 {
465 a[row * 6 + j] -= factor * a[col * 6 + j];
466 }
467 b[row] -= factor * b[col];
468 }
469 }
470
471 let mut x = [0.0_f64; 6];
473 for col in (0..6).rev() {
474 let mut sum = b[col];
475 for j in (col + 1)..6 {
476 sum -= a[col * 6 + j] * x[j];
477 }
478 x[col] = sum / a[col * 6 + col];
479 }
480
481 Ok(x)
482}
483
484fn normalize_points(matches: &[MatchPair], use_first: bool) -> (Vec<(f64, f64)>, [f64; 9]) {
486 let pts: Vec<(f64, f64)> = if use_first {
487 matches.iter().map(|m| (m.point1.x, m.point1.y)).collect()
488 } else {
489 matches.iter().map(|m| (m.point2.x, m.point2.y)).collect()
490 };
491
492 let n = pts.len() as f64;
493 let cx: f64 = pts.iter().map(|p| p.0).sum::<f64>() / n;
494 let cy: f64 = pts.iter().map(|p| p.1).sum::<f64>() / n;
495
496 let avg_dist: f64 = pts
497 .iter()
498 .map(|p| ((p.0 - cx).powi(2) + (p.1 - cy).powi(2)).sqrt())
499 .sum::<f64>()
500 / n;
501
502 let s = if avg_dist > 1e-10 {
503 std::f64::consts::SQRT_2 / avg_dist
504 } else {
505 1.0
506 };
507
508 let normalized: Vec<(f64, f64)> = pts
509 .iter()
510 .map(|p| ((p.0 - cx) * s, (p.1 - cy) * s))
511 .collect();
512
513 let t = [s, 0.0, -s * cx, 0.0, s, -s * cy, 0.0, 0.0, 1.0];
515
516 (normalized, t)
517}
518
519fn denormalize_homography(h_norm: &[f64; 9], t1: &[f64; 9], t2: &[f64; 9]) -> [f64; 9] {
521 let s2 = t2[0];
523 let tx2 = t2[2];
524 let ty2 = t2[5];
525
526 let t2_inv = if s2.abs() > 1e-15 {
527 let inv_s = 1.0 / s2;
528 [
529 inv_s,
530 0.0,
531 -tx2 * inv_s,
532 0.0,
533 inv_s,
534 -ty2 * inv_s,
535 0.0,
536 0.0,
537 1.0,
538 ]
539 } else {
540 [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
541 };
542
543 let tmp = mat3_mul(&t2_inv, h_norm);
544 mat3_mul(&tmp, t1)
545}
546
/// Product `a · b` of two row-major 3x3 matrices.
fn mat3_mul(a: &[f64; 9], b: &[f64; 9]) -> [f64; 9] {
    let mut product = [0.0_f64; 9];
    for (idx, cell) in product.iter_mut().enumerate() {
        let (row, col) = (idx / 3, idx % 3);
        *cell = (0..3).fold(0.0, |acc, k| acc + a[row * 3 + k] * b[k * 3 + col]);
    }
    product
}
558
559fn find_smallest_eigenvector_9x9(ata: &[f64; 81]) -> AlignResult<[f64; 9]> {
563 let shift = 1e-8;
565 let mut a_shifted = *ata;
566 for i in 0..9 {
567 a_shifted[i * 9 + i] += shift;
568 }
569
570 let mut v = [1.0_f64 / 3.0; 9];
572
573 for _iter in 0..50 {
574 let w = solve_9x9_gauss(&a_shifted, &v)?;
576
577 let norm: f64 = w.iter().map(|&x| x * x).sum::<f64>().sqrt();
579 if norm < 1e-15 {
580 return Err(AlignError::NumericalError(
581 "Eigenvector iteration diverged".to_string(),
582 ));
583 }
584 for i in 0..9 {
585 v[i] = w[i] / norm;
586 }
587 }
588
589 Ok(v)
590}
591
592fn solve_9x9_gauss(a: &[f64; 81], b: &[f64; 9]) -> AlignResult<[f64; 9]> {
594 let mut mat = *a;
595 let mut rhs = *b;
596
597 for col in 0..9 {
598 let mut max_row = col;
600 let mut max_val = mat[col * 9 + col].abs();
601 for row in (col + 1)..9 {
602 let val = mat[row * 9 + col].abs();
603 if val > max_val {
604 max_val = val;
605 max_row = row;
606 }
607 }
608
609 if max_val < 1e-14 {
610 return Err(AlignError::NumericalError(
611 "Singular matrix in 9x9 solve".to_string(),
612 ));
613 }
614
615 if max_row != col {
616 for j in 0..9 {
617 mat.swap(col * 9 + j, max_row * 9 + j);
618 }
619 rhs.swap(col, max_row);
620 }
621
622 let pivot = mat[col * 9 + col];
623 for row in (col + 1)..9 {
624 let factor = mat[row * 9 + col] / pivot;
625 for j in col..9 {
626 mat[row * 9 + j] -= factor * mat[col * 9 + j];
627 }
628 rhs[row] -= factor * rhs[col];
629 }
630 }
631
632 let mut x = [0.0_f64; 9];
633 for col in (0..9).rev() {
634 let mut sum = rhs[col];
635 for j in (col + 1)..9 {
636 sum -= mat[col * 9 + j] * x[j];
637 }
638 x[col] = sum / mat[col * 9 + col];
639 }
640
641 Ok(x)
642}
643
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Point2D;

    // Builds `n` exact correspondences under x' = a*x + b*y + tx,
    // y' = c*x + d*y + ty. Source points come from modular strides over
    // [0, 100). Note that i = 0..=3 gives (0,0), (17,31), (34,62), (51,93),
    // which are collinear, so the first minimal samples drawn by the
    // estimator are degenerate.
    fn make_affine_matches(
        a: f64,
        b: f64,
        tx: f64,
        c: f64,
        d: f64,
        ty: f64,
        n: usize,
    ) -> Vec<MatchPair> {
        (0..n)
            .map(|i| {
                let x = (i as f64 * 17.0) % 100.0;
                let y = (i as f64 * 31.0) % 100.0;
                let xp = a * x + b * y + tx;
                let yp = c * x + d * y + ty;
                MatchPair::new(i, i, i as u32, Point2D::new(x, y), Point2D::new(xp, yp))
            })
            .collect()
    }

    // Correspondences under the identity transform.
    fn make_identity_matches(n: usize) -> Vec<MatchPair> {
        make_affine_matches(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, n)
    }

    // Pins the documented default configuration values.
    #[test]
    fn test_prosac_config_default() {
        let config = ProsacConfig::default();
        assert_eq!(config.max_iterations, 2000);
        assert_eq!(config.min_inliers, 8);
    }

    #[test]
    fn test_model_type_min_samples() {
        assert_eq!(ProsacModelType::Affine.min_samples(), 3);
        assert_eq!(ProsacModelType::Homography.min_samples(), 4);
    }

    // Noise-free identity data: the recovered affine parameters should be
    // close to [1, 0, 0, 0, 1, 0].
    #[test]
    fn test_prosac_affine_identity() {
        let matches = make_identity_matches(20);

        let estimator = ProsacEstimator::new(
            ProsacConfig {
                min_inliers: 5,
                ..ProsacConfig::default()
            },
            ProsacModelType::Affine,
        );

        let result = estimator.estimate(&matches).expect("should succeed");
        assert!(result.num_inliers >= 5);
        assert_eq!(result.params.len(), 6);

        assert!(
            (result.params[0] - 1.0).abs() < 0.1,
            "a={}",
            result.params[0]
        );
        assert!((result.params[1]).abs() < 0.1, "b={}", result.params[1]);
        assert!((result.params[2]).abs() < 1.0, "tx={}", result.params[2]);
        assert!((result.params[3]).abs() < 0.1, "c={}", result.params[3]);
        assert!(
            (result.params[4] - 1.0).abs() < 0.1,
            "d={}",
            result.params[4]
        );
        assert!((result.params[5]).abs() < 1.0, "ty={}", result.params[5]);
    }

    // Pure translation by (10, -5) must show up in params[2] / params[5].
    #[test]
    fn test_prosac_affine_translation() {
        let matches = make_affine_matches(1.0, 0.0, 10.0, 0.0, 1.0, -5.0, 20);

        let estimator = ProsacEstimator::new(
            ProsacConfig {
                min_inliers: 5,
                ..ProsacConfig::default()
            },
            ProsacModelType::Affine,
        );

        let result = estimator.estimate(&matches).expect("should succeed");
        assert!(
            (result.params[2] - 10.0).abs() < 1.0,
            "tx={}",
            result.params[2]
        );
        assert!(
            (result.params[5] + 5.0).abs() < 1.0,
            "ty={}",
            result.params[5]
        );
    }

    // 30 exact inliers followed by 5 gross outliers mapped to (999, 999);
    // the outliers sit at the end of the list, i.e. lowest priority.
    #[test]
    fn test_prosac_affine_with_outliers() {
        let mut matches = make_affine_matches(1.0, 0.0, 5.0, 0.0, 1.0, 3.0, 30);

        for i in 0..5 {
            matches.push(MatchPair::new(
                30 + i,
                30 + i,
                100,
                Point2D::new(i as f64 * 10.0, i as f64 * 10.0),
                Point2D::new(999.0, 999.0),
            ));
        }

        let estimator = ProsacEstimator::new(
            ProsacConfig {
                min_inliers: 5,
                ..ProsacConfig::default()
            },
            ProsacModelType::Affine,
        );

        let result = estimator.estimate(&matches).expect("should succeed");
        assert!(result.num_inliers >= 20);
    }

    // Identity data through the homography path: the diagonal entries of the
    // normalised matrix should be close to 1.
    #[test]
    fn test_prosac_homography_identity() {
        let matches = make_identity_matches(20);

        let estimator = ProsacEstimator::new(
            ProsacConfig {
                min_inliers: 5,
                ..ProsacConfig::default()
            },
            ProsacModelType::Homography,
        );

        let result = estimator.estimate(&matches).expect("should succeed");
        assert!(result.num_inliers >= 5);
        assert_eq!(result.params.len(), 9);

        assert!(
            (result.params[0] - 1.0).abs() < 0.2,
            "h00={}",
            result.params[0]
        );
        assert!(
            (result.params[4] - 1.0).abs() < 0.2,
            "h11={}",
            result.params[4]
        );
        assert!(
            (result.params[8] - 1.0).abs() < 0.2,
            "h22={}",
            result.params[8]
        );
    }

    // One match is below the homography minimal sample size (4).
    #[test]
    fn test_prosac_insufficient_matches() {
        let matches = vec![MatchPair::new(
            0,
            0,
            0,
            Point2D::new(0.0, 0.0),
            Point2D::new(1.0, 1.0),
        )];

        let estimator = ProsacEstimator::new(ProsacConfig::default(), ProsacModelType::Homography);
        let result = estimator.estimate(&matches);
        assert!(result.is_err());
    }

    // Sanity range for a 50% inlier ratio with 4-point samples.
    #[test]
    fn test_adaptive_max_iterations() {
        let iters = adaptive_max_iterations(0.5, 4, 0.99);
        assert!(iters > 0.0 && iters < 10_000.0, "iters={iters}");
    }

    #[test]
    fn test_adaptive_max_iterations_high_inlier_ratio() {
        let iters = adaptive_max_iterations(0.9, 4, 0.99);
        assert!(iters < 100.0, "iters={iters}");
    }

    // Ratios of exactly 0 and 1 short-circuit to 1.0.
    #[test]
    fn test_adaptive_max_iterations_edge_cases() {
        assert_eq!(adaptive_max_iterations(0.0, 4, 0.99), 1.0);
        assert_eq!(adaptive_max_iterations(1.0, 4, 0.99), 1.0);
    }

    // Same seed, same first output.
    #[test]
    fn test_lcg_deterministic() {
        let mut state1 = 42u64;
        let mut state2 = 42u64;
        assert_eq!(lcg_next(&mut state1), lcg_next(&mut state2));
    }

    #[test]
    fn test_lcg_different_seeds() {
        let mut s1 = 1u64;
        let mut s2 = 2u64;
        assert_ne!(lcg_next(&mut s1), lcg_next(&mut s2));
    }

    // With an identity matrix the solution equals the right-hand side.
    #[test]
    fn test_solve_6x6_identity_system() {
        let mut ata = [0.0_f64; 36];
        for i in 0..6 {
            ata[i * 6 + i] = 1.0;
        }
        let atb = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let x = solve_6x6(&ata, &atb).expect("should succeed");
        for i in 0..6 {
            assert!((x[i] - atb[i]).abs() < 1e-10);
        }
    }

    // An all-zero matrix must be reported as singular.
    #[test]
    fn test_solve_6x6_singular() {
        let ata = [0.0_f64; 36]; let atb = [1.0; 6];
        assert!(solve_6x6(&ata, &atb).is_err());
    }

    #[test]
    fn test_mat3_mul_identity() {
        let id = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let c = mat3_mul(&id, &a);
        for i in 0..9 {
            assert!((c[i] - a[i]).abs() < 1e-10);
        }
    }

    // After normalisation, the centroid of the chosen point set is the origin.
    #[test]
    fn test_normalize_points_centered() {
        let matches = vec![
            MatchPair::new(0, 0, 0, Point2D::new(-1.0, -1.0), Point2D::new(0.0, 0.0)),
            MatchPair::new(1, 1, 0, Point2D::new(1.0, 1.0), Point2D::new(0.0, 0.0)),
        ];
        let (norm, _t) = normalize_points(&matches, true);
        let cx = norm.iter().map(|p| p.0).sum::<f64>() / norm.len() as f64;
        let cy = norm.iter().map(|p| p.1).sum::<f64>() / norm.len() as f64;
        assert!(cx.abs() < 1e-10);
        assert!(cy.abs() < 1e-10);
    }
}