Skip to main content

ferrolearn_linear/
ransac.rs

1//! RANSAC (RANdom SAmple Consensus) robust regression.
2//!
3//! This module provides [`RANSACRegressor`], a meta-estimator that fits a
4//! base regressor to inlier data, automatically detecting and excluding
5//! outliers.
6//!
7//! # Algorithm
8//!
9//! 1. Randomly sample `min_samples` points.
10//! 2. Fit the base estimator on the sample.
//! 3. Compute residuals for all points; a point is an inlier when its absolute
//!    residual is at most `residual_threshold` (default: the MAD of the target).
13//! 4. If enough inliers, refit on all inliers.
14//! 5. Keep the model with the most inliers (ties broken by lowest residual).
15//!
16//! # Examples
17//!
18//! ```
19//! use ferrolearn_linear::ransac::RANSACRegressor;
20//! use ferrolearn_linear::LinearRegression;
21//! use ferrolearn_core::{Fit, Predict};
22//! use ndarray::{array, Array1, Array2};
23//!
24//! // Data with an outlier at index 4.
25//! let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
26//! let y = array![2.0, 4.0, 6.0, 8.0, 100.0]; // last point is outlier
27//!
28//! let base = LinearRegression::<f64>::new();
29//! let model = RANSACRegressor::new(base);
30//! let fitted = model.fit(&x, &y).unwrap();
31//!
32//! // The outlier should be detected.
33//! let mask = fitted.inlier_mask();
34//! assert!(!mask[4], "outlier at index 4 should be detected");
35//! ```
36
37use ferrolearn_core::error::FerroError;
38use ferrolearn_core::traits::{Fit, Predict};
39use ndarray::{Array1, Array2, ScalarOperand};
40use num_traits::Float;
41use rand::Rng;
42use rand::SeedableRng;
43
44// ---------------------------------------------------------------------------
45// RANSACRegressor (unfitted)
46// ---------------------------------------------------------------------------
47
/// RANSAC robust regression meta-estimator.
///
/// Wraps a base regressor (e.g., [`LinearRegression`](crate::LinearRegression))
/// and repeatedly fits it on random subsets to find a model robust to
/// outliers.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
/// - `E`: The base estimator type.
#[derive(Debug, Clone)]
pub struct RANSACRegressor<F, E> {
    /// The base estimator; it is fitted anew on every trial's sample and on
    /// that trial's inliers.
    pub estimator: E,
    /// Number of rows drawn per trial. If `None`, `n_features + 1` is used.
    /// The effective value is never less than 1.
    pub min_samples: Option<usize>,
    /// Residual threshold: a sample is an inlier when its absolute residual is
    /// less than or equal to this value. If `None`, the median absolute
    /// deviation (MAD) of the target is used. When that MAD is not greater
    /// than `F::epsilon()` (e.g. a constant target), it falls back to `1e-6`.
    pub residual_threshold: Option<F>,
    /// Maximum number of random trials.
    pub max_trials: usize,
    /// Seed for the random number generator. If `None`, a fixed default seed
    /// (42) is used, so fits are deterministic even without an explicit seed.
    pub random_state: Option<u64>,
}
72
73impl<F: Float, E> RANSACRegressor<F, E> {
74    /// Create a new `RANSACRegressor` with the given base estimator.
75    ///
76    /// Defaults: `min_samples = None` (auto: n_features + 1),
77    /// `residual_threshold = None` (auto: MAD), `max_trials = 100`,
78    /// `random_state = None`.
79    #[must_use]
80    pub fn new(estimator: E) -> Self {
81        Self {
82            estimator,
83            min_samples: None,
84            residual_threshold: None,
85            max_trials: 100,
86            random_state: None,
87        }
88    }
89
90    /// Set the minimum number of samples for fitting.
91    #[must_use]
92    pub fn with_min_samples(mut self, min_samples: usize) -> Self {
93        self.min_samples = Some(min_samples);
94        self
95    }
96
97    /// Set the residual threshold for inlier detection.
98    #[must_use]
99    pub fn with_residual_threshold(mut self, threshold: F) -> Self {
100        self.residual_threshold = Some(threshold);
101        self
102    }
103
104    /// Set the maximum number of random trials.
105    #[must_use]
106    pub fn with_max_trials(mut self, max_trials: usize) -> Self {
107        self.max_trials = max_trials;
108        self
109    }
110
111    /// Set the random seed for reproducibility.
112    #[must_use]
113    pub fn with_random_state(mut self, seed: u64) -> Self {
114        self.random_state = Some(seed);
115        self
116    }
117}
118
119// ---------------------------------------------------------------------------
120// FittedRANSACRegressor
121// ---------------------------------------------------------------------------
122
/// Fitted RANSAC robust regression model.
///
/// Holds the base estimator selected during fitting, together with the inlier
/// mask that this estimator produced on the training data.
#[derive(Debug, Clone)]
pub struct FittedRANSACRegressor<Fitted> {
    /// The selected fitted base estimator.
    fitted_estimator: Fitted,
    /// One entry per training sample: `true` if it was classified as an inlier.
    inlier_mask: Vec<bool>,
}

impl<Fitted> FittedRANSACRegressor<Fitted> {
    /// Returns the inlier mask. `true` indicates the sample was an inlier.
    #[must_use]
    pub fn inlier_mask(&self) -> &[bool] {
        &self.inlier_mask
    }

    /// Returns the selected fitted base estimator, e.g. to inspect its
    /// parameters.
    #[must_use]
    pub fn fitted_estimator(&self) -> &Fitted {
        &self.fitted_estimator
    }

    /// Returns how many training samples were classified as inliers.
    #[must_use]
    pub fn n_inliers(&self) -> usize {
        self.inlier_mask.iter().filter(|&&inlier| inlier).count()
    }
}
141
142// ---------------------------------------------------------------------------
143// Helper: Median Absolute Deviation
144// ---------------------------------------------------------------------------
145
146/// Compute the median of a slice of floats.
147fn median<F: Float>(values: &[F]) -> F {
148    let mut sorted: Vec<F> = values.to_vec();
149    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
150    let n = sorted.len();
151    if n == 0 {
152        return F::zero();
153    }
154    if n % 2 == 0 {
155        (sorted[n / 2 - 1] + sorted[n / 2]) / (F::one() + F::one())
156    } else {
157        sorted[n / 2]
158    }
159}
160
161/// Compute the Median Absolute Deviation (MAD) of a slice.
162fn mad<F: Float>(values: &[F]) -> F {
163    let med = median(values);
164    let abs_devs: Vec<F> = values.iter().map(|&v| (v - med).abs()).collect();
165    median(&abs_devs)
166}
167
168// ---------------------------------------------------------------------------
169// Random subset sampling
170// ---------------------------------------------------------------------------
171
172/// Sample `k` distinct indices from `0..n` using Fisher-Yates.
173fn sample_indices<R: Rng>(rng: &mut R, n: usize, k: usize) -> Vec<usize> {
174    let mut indices: Vec<usize> = (0..n).collect();
175    for i in 0..k {
176        let j = rng.random_range(i..n);
177        indices.swap(i, j);
178    }
179    indices.truncate(k);
180    indices
181}
182
183/// Extract a subset of rows from a 2D array and a 1D array.
184fn subset<F: Float>(x: &Array2<F>, y: &Array1<F>, indices: &[usize]) -> (Array2<F>, Array1<F>) {
185    let n_features = x.ncols();
186    let n = indices.len();
187    let mut x_sub = Array2::<F>::zeros((n, n_features));
188    let mut y_sub = Array1::<F>::zeros(n);
189    for (row, &idx) in indices.iter().enumerate() {
190        for col in 0..n_features {
191            x_sub[[row, col]] = x[[idx, col]];
192        }
193        y_sub[row] = y[idx];
194    }
195    (x_sub, y_sub)
196}
197
198// ---------------------------------------------------------------------------
199// Fit and Predict
200// ---------------------------------------------------------------------------
201
202impl<F, E, Ef> Fit<Array2<F>, Array1<F>> for RANSACRegressor<F, E>
203where
204    F: Float + Send + Sync + ScalarOperand + num_traits::FromPrimitive + 'static,
205    E: Fit<Array2<F>, Array1<F>, Fitted = Ef, Error = FerroError> + Clone,
206    Ef: Predict<Array2<F>, Output = Array1<F>, Error = FerroError> + Clone,
207{
208    type Fitted = FittedRANSACRegressor<Ef>;
209    type Error = FerroError;
210
211    /// Fit the RANSAC model by repeatedly sampling and fitting.
212    ///
213    /// # Errors
214    ///
215    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
216    /// sample counts.
217    /// Returns [`FerroError::ConvergenceFailure`] if no valid model is found
218    /// after `max_trials` iterations.
219    fn fit(
220        &self,
221        x: &Array2<F>,
222        y: &Array1<F>,
223    ) -> Result<FittedRANSACRegressor<E::Fitted>, FerroError> {
224        let (n_samples, n_features) = x.dim();
225
226        if n_samples != y.len() {
227            return Err(FerroError::ShapeMismatch {
228                expected: vec![n_samples],
229                actual: vec![y.len()],
230                context: "y length must match number of samples in X".into(),
231            });
232        }
233
234        let min_samples = self.min_samples.unwrap_or(n_features + 1).max(1);
235
236        if n_samples < min_samples {
237            return Err(FerroError::InsufficientSamples {
238                required: min_samples,
239                actual: n_samples,
240                context: "RANSAC requires at least min_samples samples".into(),
241            });
242        }
243
244        // Compute residual threshold if not provided.
245        let threshold = match self.residual_threshold {
246            Some(t) => t,
247            None => {
248                let y_mad = mad(&y.to_vec());
249                if y_mad <= F::epsilon() {
250                    // If MAD is zero (constant target), use a small default.
251                    F::from(1e-6).unwrap()
252                } else {
253                    y_mad
254                }
255            }
256        };
257
258        let mut rng = match self.random_state {
259            Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
260            None => rand::rngs::StdRng::seed_from_u64(42),
261        };
262
263        let mut best_fitted: Option<E::Fitted> = None;
264        let mut best_inlier_mask: Option<Vec<bool>> = None;
265        let mut best_n_inliers = 0usize;
266        let mut best_residual_sum = F::infinity();
267
268        for _ in 0..self.max_trials {
269            // Sample random subset.
270            let indices = sample_indices(&mut rng, n_samples, min_samples);
271            let (x_sub, y_sub) = subset(x, y, &indices);
272
273            // Fit base estimator on the subset.
274            let fitted = match self.estimator.fit(&x_sub, &y_sub) {
275                Ok(f) => f,
276                Err(_) => continue, // Skip failed fits.
277            };
278
279            // Compute residuals for all points.
280            let preds = match fitted.predict(x) {
281                Ok(p) => p,
282                Err(_) => continue,
283            };
284
285            let mut inlier_mask = vec![false; n_samples];
286            let mut n_inliers = 0usize;
287            let mut residual_sum = F::zero();
288
289            for i in 0..n_samples {
290                let residual = (preds[i] - y[i]).abs();
291                if residual <= threshold {
292                    inlier_mask[i] = true;
293                    n_inliers += 1;
294                    residual_sum = residual_sum + residual;
295                }
296            }
297
298            // Check if this is better than the current best.
299            let is_better = n_inliers > best_n_inliers
300                || (n_inliers == best_n_inliers && residual_sum < best_residual_sum);
301
302            if is_better && n_inliers >= min_samples {
303                // Refit on all inliers.
304                let inlier_indices: Vec<usize> = inlier_mask
305                    .iter()
306                    .enumerate()
307                    .filter(|&(_, &is_inlier)| is_inlier)
308                    .map(|(i, _)| i)
309                    .collect();
310                let (x_inlier, y_inlier) = subset(x, y, &inlier_indices);
311
312                match self.estimator.fit(&x_inlier, &y_inlier) {
313                    Ok(refit) => {
314                        // Recompute inlier mask with the refitted model.
315                        if let Ok(new_preds) = refit.predict(x) {
316                            let mut new_mask = vec![false; n_samples];
317                            let mut new_n_inliers = 0;
318                            let mut new_residual_sum = F::zero();
319                            for i in 0..n_samples {
320                                let r = (new_preds[i] - y[i]).abs();
321                                if r <= threshold {
322                                    new_mask[i] = true;
323                                    new_n_inliers += 1;
324                                    new_residual_sum = new_residual_sum + r;
325                                }
326                            }
327                            best_fitted = Some(refit);
328                            best_inlier_mask = Some(new_mask);
329                            best_n_inliers = new_n_inliers;
330                            best_residual_sum = new_residual_sum;
331                        }
332                    }
333                    Err(_) => {
334                        // Keep the original fit if refit fails.
335                        best_fitted = Some(fitted);
336                        best_inlier_mask = Some(inlier_mask);
337                        best_n_inliers = n_inliers;
338                        best_residual_sum = residual_sum;
339                    }
340                }
341            }
342        }
343
344        match (best_fitted, best_inlier_mask) {
345            (Some(fitted), Some(mask)) => Ok(FittedRANSACRegressor {
346                fitted_estimator: fitted,
347                inlier_mask: mask,
348            }),
349            _ => Err(FerroError::ConvergenceFailure {
350                iterations: self.max_trials,
351                message: "RANSAC could not find a valid model after max_trials iterations".into(),
352            }),
353        }
354    }
355}
356
357impl<F, Fitted> Predict<Array2<F>> for FittedRANSACRegressor<Fitted>
358where
359    F: Float + Send + Sync + 'static,
360    Fitted: Predict<Array2<F>, Output = Array1<F>, Error = FerroError>,
361{
362    type Output = Array1<F>;
363    type Error = FerroError;
364
365    /// Predict target values using the base estimator fitted on inliers.
366    ///
367    /// # Errors
368    ///
369    /// Returns any error from the base estimator's predict method.
370    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
371        self.fitted_estimator.predict(x)
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LinearRegression;
    use approx::assert_relative_eq;
    use ndarray::array;

    #[test]
    fn test_ransac_no_outliers() {
        // Perfect linear data, no outliers.
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base)
            .with_random_state(42)
            .with_residual_threshold(1.0);
        let fitted = model.fit(&x, &y).unwrap();

        // All should be inliers, one mask entry per sample.
        let mask = fitted.inlier_mask();
        assert_eq!(mask.len(), 5);
        assert!(mask.iter().all(|&v| v), "All should be inliers");

        // Predictions should be accurate.
        let preds = fitted.predict(&x).unwrap();
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 0.5);
        }
    }

    #[test]
    fn test_ransac_with_outlier() {
        // y = 2x, but one outlier.
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 100.0]; // outlier at idx 5

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base)
            .with_random_state(42)
            .with_max_trials(200)
            .with_residual_threshold(2.0);
        let fitted = model.fit(&x, &y).unwrap();

        let mask = fitted.inlier_mask();
        assert_eq!(mask.len(), 6);
        // The outlier at index 5 should be detected.
        assert!(!mask[5], "Outlier at index 5 should not be an inlier");

        // Most other points should be inliers.
        let n_inliers: usize = mask.iter().filter(|&&v| v).count();
        assert!(
            n_inliers >= 4,
            "Expected at least 4 inliers, got {n_inliers}"
        );

        // The prediction at x=3 should be close to 6.
        let x_test = Array2::from_shape_vec((1, 1), vec![3.0]).unwrap();
        let pred = fitted.predict(&x_test).unwrap();
        assert!(
            (pred[0] - 6.0).abs() < 3.0,
            "Prediction at x=3 should be near 6.0, got {}",
            pred[0]
        );
    }

    #[test]
    fn test_ransac_multiple_outliers() {
        // y = x + 1, with two outliers.
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![2.0, 3.0, 50.0, 5.0, 6.0, -40.0, 8.0, 9.0]; // outliers at 2 and 5

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base)
            .with_random_state(123)
            .with_max_trials(500)
            .with_residual_threshold(2.0);
        let fitted = model.fit(&x, &y).unwrap();

        let mask = fitted.inlier_mask();
        // Outliers at index 2 and 5 should be detected.
        assert!(!mask[2], "Outlier at index 2 should not be an inlier");
        assert!(!mask[5], "Outlier at index 5 should not be an inlier");
    }

    #[test]
    fn test_ransac_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base);
        assert!(matches!(
            model.fit(&x, &y),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_ransac_insufficient_samples() {
        let x = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let y = array![1.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base).with_min_samples(3);
        assert!(matches!(
            model.fit(&x, &y),
            Err(FerroError::InsufficientSamples {
                required: 3,
                actual: 1,
                ..
            })
        ));
    }

    #[test]
    fn test_ransac_zero_trials_fails() {
        // With no trials there is no candidate model at all.
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base)
            .with_random_state(1)
            .with_max_trials(0);
        assert!(matches!(
            model.fit(&x, &y),
            Err(FerroError::ConvergenceFailure { iterations: 0, .. })
        ));
    }

    #[test]
    fn test_ransac_reproducible_with_seed() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 100.0];

        let base1 = LinearRegression::<f64>::new();
        let model1 = RANSACRegressor::new(base1)
            .with_random_state(42)
            .with_residual_threshold(2.0);
        let fitted1 = model1.fit(&x, &y).unwrap();

        let base2 = LinearRegression::<f64>::new();
        let model2 = RANSACRegressor::new(base2)
            .with_random_state(42)
            .with_residual_threshold(2.0);
        let fitted2 = model2.fit(&x, &y).unwrap();

        // Same seed should produce same inlier mask.
        assert_eq!(fitted1.inlier_mask(), fitted2.inlier_mask());
    }

    #[test]
    fn test_ransac_auto_threshold() {
        // No explicit threshold — should use MAD.
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 100.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base)
            .with_random_state(42)
            .with_max_trials(200);
        let fitted = model.fit(&x, &y).unwrap();

        let mask = fitted.inlier_mask();
        // At least some points should be inliers.
        let n_inliers: usize = mask.iter().filter(|&&v| v).count();
        assert!(
            n_inliers >= 3,
            "Expected at least 3 inliers, got {n_inliers}"
        );
    }
}