linear_ransac/
estimator.rs

1use rand::{SeedableRng};
2use rand::prelude::IndexedRandom; 
3use rand_chacha::ChaCha8Rng;
4use crate::types::{Point, LinearModel};
5use crate::error::RansacError;
6use crate::utils;
7
/// Configuration for a seeded RANSAC line fitter.
///
/// Construct one with [`RansacSolver::new`] (or [`Default`]) and adjust it
/// with the `with_*` builder methods before calling `fit`.
#[derive(Clone, Debug)]
pub struct RansacSolver {
    /// Seed for the deterministic RNG that draws the minimal samples.
    seed: u64,
    /// Target confidence handed to the adaptive trial-count estimate.
    stop_probability: f64,
    /// Starting trial budget; intended as a safety ceiling on the search.
    max_trials_limit: usize,
    /// Floor applied whenever the adaptive estimate shrinks the trial budget.
    min_trials_limit: usize,
}

impl Default for RansacSolver {
    /// Seed 42, stop probability 0.99, trial budget between 500 and 10 000.
    fn default() -> Self {
        // Upper bound on the number of trials (safety ceiling).
        const MAX_TRIALS: usize = 10_000;
        // Lower bound that guarantees a minimum amount of exploration.
        const MIN_TRIALS: usize = 500;

        RansacSolver {
            min_trials_limit: MIN_TRIALS,
            max_trials_limit: MAX_TRIALS,
            stop_probability: 0.99,
            seed: 42,
        }
    }
}
26
27impl RansacSolver {
28    pub fn new() -> Self { Self::default() }
29
30    pub fn with_seed(mut self, seed: u64) -> Self {
31        self.seed = seed;
32        self
33    }
34
35    /// Set a custom minimum number of RANSAC trials.
36    /// This lower bound prevents premature stopping when an early
37    pub fn with_min_trials(mut self, min_trials: usize) -> Self {
38        self.min_trials_limit = min_trials.max(1);
39        self
40    }
41
42    pub fn fit(&self, data: &[Point]) -> Result<LinearModel, RansacError> {
43        let n_samples = data.len();
44        if n_samples < 2 {
45            return Err(RansacError::InsufficientData { needed: 2, count: n_samples });
46        }
47
48        // 1. Auto-calculate threshold using MAD
49        // If MAD fails (empty data), default to 1.0
50        let threshold = utils::estimate_threshold_via_mad(data).unwrap_or(1.0);
51
52        // 2. Setup Seeding and Loop
53        let mut rng = ChaCha8Rng::seed_from_u64(self.seed);
54        
55        let mut best_model: Option<LinearModel> = None;
56        let mut best_inlier_count = 0;
57        let mut best_error = f64::MAX;
58
59        // Start assuming we know nothing (high iterations)
60        let mut dynamic_max_trials = self.max_trials_limit;
61        let mut trials_performed = 0;
62
63        // 3. Sequential Loop
64        while trials_performed < dynamic_max_trials {
65            trials_performed += 1;
66
67            // A. Sample 2 points
68            let sample: Vec<&Point> = data.choose_multiple(&mut rng, 2).collect();
69            
70            // B. Fit temporary model
71            let model = match utils::fit_line_from_two_points(sample[0], sample[1]) {
72                Some(m) => m,
73                None => continue, 
74            };
75
76            // C. Count Inliers & Error (sum of absolute residuals
77            // used as a proxy score similar to sklearn's strategy
78            // of ranking models with more inliers and better score).
79            let mut current_inliers = 0;
80            let mut current_error = 0.0;
81
82            for p in data {
83                let err = (p.y - model.predict(p.x)).abs();
84                if err < threshold {
85                    current_inliers += 1;
86                    current_error += err;
87                }
88            }
89
90            // D. Is this the best model so far?
91            if current_inliers > best_inlier_count || 
92               (current_inliers == best_inlier_count && current_error < best_error) {
93                
94                best_inlier_count = current_inliers;
95                best_error = current_error;
96                best_model = Some(model);
97
98                // --- DYNAMIC STOPPING ---
99                // "I found a model that explains X% of data. 
100                //  Math says I only need Y trials to be 99% sure."
101                let ratio = best_inlier_count as f64 / n_samples as f64;
102                let estimated_k = utils::calculate_k_iterations(self.stop_probability, ratio);
103
104                // Never increase iterations, only reduce if we found something better,
105                // but always respect a user-configurable minimum to avoid premature stop.
106                dynamic_max_trials = dynamic_max_trials
107                    .min(estimated_k)
108                    .max(self.min_trials_limit);
109            }
110        }
111
112        // 4. Final Polish (OLS on Inliers)
113        if let Some(best) = best_model {
114            let final_inliers: Vec<&Point> = data.iter()
115                .filter(|p| (p.y - best.predict(p.x)).abs() < threshold)
116                .collect();
117            
118            if final_inliers.len() < 2 {
119                return Err(RansacError::ModelFittingFailed);
120            }
121
122            utils::fit_ols(&final_inliers).ok_or(RansacError::ModelFittingFailed)
123        } else {
124            Err(RansacError::NoConsensusFound(trials_performed))
125        }
126    }
127}