Skip to main content

ferrolearn_linear/
ransac.rs

1//! RANSAC (RANdom SAmple Consensus) robust regression.
2//!
3//! This module provides [`RANSACRegressor`], a meta-estimator that fits a
4//! base regressor to inlier data, automatically detecting and excluding
5//! outliers.
6//!
7//! # Algorithm
8//!
9//! 1. Randomly sample `min_samples` points.
10//! 2. Fit the base estimator on the sample.
//! 3. Compute residuals for all points, identify inliers (absolute residual
//!    at or below `residual_threshold`).
13//! 4. If enough inliers, refit on all inliers.
14//! 5. Keep the model with the most inliers (ties broken by lowest residual).
15//!
16//! # Examples
17//!
18//! ```
19//! use ferrolearn_linear::ransac::RANSACRegressor;
20//! use ferrolearn_linear::LinearRegression;
21//! use ferrolearn_core::{Fit, Predict};
22//! use ndarray::{array, Array1, Array2};
23//!
24//! // Data with an outlier at index 4.
25//! let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
26//! let y = array![2.0, 4.0, 6.0, 8.0, 100.0]; // last point is outlier
27//!
28//! let base = LinearRegression::<f64>::new();
29//! let model = RANSACRegressor::new(base);
30//! let fitted = model.fit(&x, &y).unwrap();
31//!
32//! // The outlier should be detected.
33//! let mask = fitted.inlier_mask();
34//! assert!(!mask[4], "outlier at index 4 should be detected");
35//! ```
36
37use ferrolearn_core::error::FerroError;
38use ferrolearn_core::traits::{Fit, Predict};
39use ndarray::{Array1, Array2, ScalarOperand};
40use num_traits::Float;
41use rand::Rng;
42use rand::SeedableRng;
43
44// ---------------------------------------------------------------------------
45// RANSACRegressor (unfitted)
46// ---------------------------------------------------------------------------
47
48/// RANSAC robust regression meta-estimator.
49///
50/// Wraps a base regressor (e.g., [`LinearRegression`](crate::LinearRegression))
51/// and repeatedly fits it on random subsets to find a model robust to
52/// outliers.
53///
54/// # Type Parameters
55///
56/// - `F`: The floating-point type (`f32` or `f64`).
57/// - `E`: The base estimator type.
58#[derive(Debug, Clone)]
59pub struct RANSACRegressor<F, E> {
60    /// The base estimator.
61    pub estimator: E,
62    /// Minimum number of samples for fitting.
63    pub min_samples: Option<usize>,
64    /// Residual threshold: points with absolute residual below this are
65    /// considered inliers. If `None`, uses the MAD of the target.
66    pub residual_threshold: Option<F>,
67    /// Maximum number of random trials.
68    pub max_trials: usize,
69    /// Optional random seed for reproducibility.
70    pub random_state: Option<u64>,
71}
72
73impl<F: Float, E> RANSACRegressor<F, E> {
74    /// Create a new `RANSACRegressor` with the given base estimator.
75    ///
76    /// Defaults: `min_samples = None` (auto: n_features + 1),
77    /// `residual_threshold = None` (auto: MAD), `max_trials = 100`,
78    /// `random_state = None`.
79    #[must_use]
80    pub fn new(estimator: E) -> Self {
81        Self {
82            estimator,
83            min_samples: None,
84            residual_threshold: None,
85            max_trials: 100,
86            random_state: None,
87        }
88    }
89
90    /// Set the minimum number of samples for fitting.
91    #[must_use]
92    pub fn with_min_samples(mut self, min_samples: usize) -> Self {
93        self.min_samples = Some(min_samples);
94        self
95    }
96
97    /// Set the residual threshold for inlier detection.
98    #[must_use]
99    pub fn with_residual_threshold(mut self, threshold: F) -> Self {
100        self.residual_threshold = Some(threshold);
101        self
102    }
103
104    /// Set the maximum number of random trials.
105    #[must_use]
106    pub fn with_max_trials(mut self, max_trials: usize) -> Self {
107        self.max_trials = max_trials;
108        self
109    }
110
111    /// Set the random seed for reproducibility.
112    #[must_use]
113    pub fn with_random_state(mut self, seed: u64) -> Self {
114        self.random_state = Some(seed);
115        self
116    }
117}
118
119// ---------------------------------------------------------------------------
120// FittedRANSACRegressor
121// ---------------------------------------------------------------------------
122
/// The result of fitting a [`RANSACRegressor`].
///
/// Holds the base model selected during fitting together with one inlier
/// flag per training sample.
#[derive(Debug, Clone)]
pub struct FittedRANSACRegressor<Fitted> {
    /// Base estimator in its fitted state, as selected by RANSAC.
    fitted_estimator: Fitted,
    /// One entry per training sample; `true` marks an inlier.
    inlier_mask: Vec<bool>,
}

impl<Fitted> FittedRANSACRegressor<Fitted> {
    /// Borrow the per-sample inlier flags (`true` = inlier), in the same
    /// order as the training samples.
    #[must_use]
    pub fn inlier_mask(&self) -> &[bool] {
        self.inlier_mask.as_slice()
    }
}
141
142// ---------------------------------------------------------------------------
143// Helper: Median Absolute Deviation
144// ---------------------------------------------------------------------------
145
146/// Compute the median of a slice of floats.
147fn median<F: Float>(values: &[F]) -> F {
148    let mut sorted: Vec<F> = values.to_vec();
149    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
150    let n = sorted.len();
151    if n == 0 {
152        return F::zero();
153    }
154    if n % 2 == 0 {
155        (sorted[n / 2 - 1] + sorted[n / 2]) / (F::one() + F::one())
156    } else {
157        sorted[n / 2]
158    }
159}
160
161/// Compute the Median Absolute Deviation (MAD) of a slice.
162fn mad<F: Float>(values: &[F]) -> F {
163    let med = median(values);
164    let abs_devs: Vec<F> = values.iter().map(|&v| (v - med).abs()).collect();
165    median(&abs_devs)
166}
167
168// ---------------------------------------------------------------------------
169// Random subset sampling
170// ---------------------------------------------------------------------------
171
172/// Sample `k` distinct indices from `0..n` using Fisher-Yates.
173fn sample_indices<R: Rng>(rng: &mut R, n: usize, k: usize) -> Vec<usize> {
174    let mut indices: Vec<usize> = (0..n).collect();
175    for i in 0..k {
176        let j = rng.random_range(i..n);
177        indices.swap(i, j);
178    }
179    indices.truncate(k);
180    indices
181}
182
183/// Extract a subset of rows from a 2D array and a 1D array.
184fn subset<F: Float>(x: &Array2<F>, y: &Array1<F>, indices: &[usize]) -> (Array2<F>, Array1<F>) {
185    let n_features = x.ncols();
186    let n = indices.len();
187    let mut x_sub = Array2::<F>::zeros((n, n_features));
188    let mut y_sub = Array1::<F>::zeros(n);
189    for (row, &idx) in indices.iter().enumerate() {
190        for col in 0..n_features {
191            x_sub[[row, col]] = x[[idx, col]];
192        }
193        y_sub[row] = y[idx];
194    }
195    (x_sub, y_sub)
196}
197
198// ---------------------------------------------------------------------------
199// Fit and Predict
200// ---------------------------------------------------------------------------
201
202impl<F, E, Ef> Fit<Array2<F>, Array1<F>> for RANSACRegressor<F, E>
203where
204    F: Float + Send + Sync + ScalarOperand + num_traits::FromPrimitive + 'static,
205    E: Fit<Array2<F>, Array1<F>, Fitted = Ef, Error = FerroError> + Clone,
206    Ef: Predict<Array2<F>, Output = Array1<F>, Error = FerroError> + Clone,
207{
208    type Fitted = FittedRANSACRegressor<Ef>;
209    type Error = FerroError;
210
211    /// Fit the RANSAC model by repeatedly sampling and fitting.
212    ///
213    /// # Errors
214    ///
215    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have different
216    /// sample counts.
217    /// Returns [`FerroError::ConvergenceFailure`] if no valid model is found
218    /// after `max_trials` iterations.
219    fn fit(
220        &self,
221        x: &Array2<F>,
222        y: &Array1<F>,
223    ) -> Result<FittedRANSACRegressor<E::Fitted>, FerroError> {
224        let (n_samples, n_features) = x.dim();
225
226        if n_samples != y.len() {
227            return Err(FerroError::ShapeMismatch {
228                expected: vec![n_samples],
229                actual: vec![y.len()],
230                context: "y length must match number of samples in X".into(),
231            });
232        }
233
234        let min_samples = self.min_samples.unwrap_or(n_features + 1).max(1);
235
236        if n_samples < min_samples {
237            return Err(FerroError::InsufficientSamples {
238                required: min_samples,
239                actual: n_samples,
240                context: "RANSAC requires at least min_samples samples".into(),
241            });
242        }
243
244        // Compute residual threshold if not provided.
245        let threshold = if let Some(t) = self.residual_threshold {
246            t
247        } else {
248            let y_mad = mad(&y.to_vec());
249            if y_mad <= F::epsilon() {
250                // If MAD is zero (constant target), use a small default.
251                F::from(1e-6).unwrap()
252            } else {
253                y_mad
254            }
255        };
256
257        let mut rng = match self.random_state {
258            Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
259            None => rand::rngs::StdRng::seed_from_u64(42),
260        };
261
262        let mut best_fitted: Option<E::Fitted> = None;
263        let mut best_inlier_mask: Option<Vec<bool>> = None;
264        let mut best_n_inliers = 0usize;
265        let mut best_residual_sum = F::infinity();
266
267        for _ in 0..self.max_trials {
268            // Sample random subset.
269            let indices = sample_indices(&mut rng, n_samples, min_samples);
270            let (x_sub, y_sub) = subset(x, y, &indices);
271
272            // Fit base estimator on the subset.
273            let fitted = match self.estimator.fit(&x_sub, &y_sub) {
274                Ok(f) => f,
275                Err(_) => continue, // Skip failed fits.
276            };
277
278            // Compute residuals for all points.
279            let preds = match fitted.predict(x) {
280                Ok(p) => p,
281                Err(_) => continue,
282            };
283
284            let mut inlier_mask = vec![false; n_samples];
285            let mut n_inliers = 0usize;
286            let mut residual_sum = F::zero();
287
288            for i in 0..n_samples {
289                let residual = (preds[i] - y[i]).abs();
290                if residual <= threshold {
291                    inlier_mask[i] = true;
292                    n_inliers += 1;
293                    residual_sum = residual_sum + residual;
294                }
295            }
296
297            // Check if this is better than the current best.
298            let is_better = n_inliers > best_n_inliers
299                || (n_inliers == best_n_inliers && residual_sum < best_residual_sum);
300
301            if is_better && n_inliers >= min_samples {
302                // Refit on all inliers.
303                let inlier_indices: Vec<usize> = inlier_mask
304                    .iter()
305                    .enumerate()
306                    .filter(|&(_, &is_inlier)| is_inlier)
307                    .map(|(i, _)| i)
308                    .collect();
309                let (x_inlier, y_inlier) = subset(x, y, &inlier_indices);
310
311                if let Ok(refit) = self.estimator.fit(&x_inlier, &y_inlier) {
312                    // Recompute inlier mask with the refitted model.
313                    if let Ok(new_preds) = refit.predict(x) {
314                        let mut new_mask = vec![false; n_samples];
315                        let mut new_n_inliers = 0;
316                        let mut new_residual_sum = F::zero();
317                        for i in 0..n_samples {
318                            let r = (new_preds[i] - y[i]).abs();
319                            if r <= threshold {
320                                new_mask[i] = true;
321                                new_n_inliers += 1;
322                                new_residual_sum = new_residual_sum + r;
323                            }
324                        }
325                        best_fitted = Some(refit);
326                        best_inlier_mask = Some(new_mask);
327                        best_n_inliers = new_n_inliers;
328                        best_residual_sum = new_residual_sum;
329                    }
330                } else {
331                    // Keep the original fit if refit fails.
332                    best_fitted = Some(fitted);
333                    best_inlier_mask = Some(inlier_mask);
334                    best_n_inliers = n_inliers;
335                    best_residual_sum = residual_sum;
336                }
337            }
338        }
339
340        match (best_fitted, best_inlier_mask) {
341            (Some(fitted), Some(mask)) => Ok(FittedRANSACRegressor {
342                fitted_estimator: fitted,
343                inlier_mask: mask,
344            }),
345            _ => Err(FerroError::ConvergenceFailure {
346                iterations: self.max_trials,
347                message: "RANSAC could not find a valid model after max_trials iterations".into(),
348            }),
349        }
350    }
351}
352
353impl<F, Fitted> Predict<Array2<F>> for FittedRANSACRegressor<Fitted>
354where
355    F: Float + Send + Sync + 'static,
356    Fitted: Predict<Array2<F>, Output = Array1<F>, Error = FerroError>,
357{
358    type Output = Array1<F>;
359    type Error = FerroError;
360
361    /// Predict target values using the base estimator fitted on inliers.
362    ///
363    /// # Errors
364    ///
365    /// Returns any error from the base estimator's predict method.
366    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
367        self.fitted_estimator.predict(x)
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LinearRegression;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Build an `(n, 1)` feature matrix from a flat list of values.
    fn column(values: &[f64]) -> Array2<f64> {
        Array2::from_shape_vec((values.len(), 1), values.to_vec()).unwrap()
    }

    /// Count the `true` entries of an inlier mask.
    fn count_inliers(mask: &[bool]) -> usize {
        mask.iter().filter(|&&flag| flag).count()
    }

    /// Clean `y = 2x` data: every sample is an inlier and predictions match.
    #[test]
    fn test_ransac_no_outliers() {
        let features = column(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let targets = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let model = RANSACRegressor::new(LinearRegression::<f64>::new())
            .with_random_state(42)
            .with_residual_threshold(1.0);
        let fitted = model.fit(&features, &targets).unwrap();

        // Every sample lies on the line, so none may be rejected.
        assert!(
            fitted.inlier_mask().iter().all(|&v| v),
            "All should be inliers"
        );

        // The fitted line must reproduce the targets.
        let preds = fitted.predict(&features).unwrap();
        for (p, &actual) in preds.iter().zip(targets.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 0.5);
        }
    }

    /// `y = 2x` with a single gross outlier at the last index.
    #[test]
    fn test_ransac_with_outlier() {
        let features = column(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let targets = array![2.0, 4.0, 6.0, 8.0, 10.0, 100.0];

        let model = RANSACRegressor::new(LinearRegression::<f64>::new())
            .with_random_state(42)
            .with_max_trials(200)
            .with_residual_threshold(2.0);
        let fitted = model.fit(&features, &targets).unwrap();

        // Index 5 holds the outlier and must be rejected.
        let mask = fitted.inlier_mask();
        assert!(!mask[5], "Outlier at index 5 should not be an inlier");

        // Most of the remaining samples should be accepted.
        let n_inliers = count_inliers(mask);
        assert!(
            n_inliers >= 4,
            "Expected at least 4 inliers, got {n_inliers}"
        );

        // The robust fit should follow y = 2x, i.e. roughly 6 at x = 3.
        let pred = fitted.predict(&column(&[3.0])).unwrap();
        assert!(
            (pred[0] - 6.0).abs() < 3.0,
            "Prediction at x=3 should be near 6.0, got {}",
            pred[0]
        );
    }

    /// `y = x + 1` with outliers at indices 2 and 5.
    #[test]
    fn test_ransac_multiple_outliers() {
        let features = column(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let targets = array![2.0, 3.0, 50.0, 5.0, 6.0, -40.0, 8.0, 9.0];

        let model = RANSACRegressor::new(LinearRegression::<f64>::new())
            .with_random_state(123)
            .with_max_trials(500)
            .with_residual_threshold(2.0);
        let fitted = model.fit(&features, &targets).unwrap();

        // Both corrupted samples must be rejected.
        let mask = fitted.inlier_mask();
        assert!(!mask[2], "Outlier at index 2 should not be an inlier");
        assert!(!mask[5], "Outlier at index 5 should not be an inlier");
    }

    /// Three feature rows but only two targets must be rejected.
    #[test]
    fn test_ransac_shape_mismatch() {
        let features = column(&[1.0, 2.0, 3.0]);
        let targets = array![1.0, 2.0];

        let model = RANSACRegressor::new(LinearRegression::<f64>::new());
        let outcome = model.fit(&features, &targets);
        assert!(outcome.is_err());
    }

    /// One sample cannot satisfy `min_samples = 3`.
    #[test]
    fn test_ransac_insufficient_samples() {
        let features = column(&[1.0]);
        let targets = array![1.0];

        let model = RANSACRegressor::new(LinearRegression::<f64>::new()).with_min_samples(3);
        let outcome = model.fit(&features, &targets);
        assert!(outcome.is_err());
    }

    /// Two fits with the same seed must agree on the inlier mask.
    #[test]
    fn test_ransac_reproducible_with_seed() {
        let features = column(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let targets = array![2.0, 4.0, 6.0, 8.0, 10.0, 100.0];

        let first = RANSACRegressor::new(LinearRegression::<f64>::new())
            .with_random_state(42)
            .with_residual_threshold(2.0);
        let fitted_first = first.fit(&features, &targets).unwrap();

        let second = RANSACRegressor::new(LinearRegression::<f64>::new())
            .with_random_state(42)
            .with_residual_threshold(2.0);
        let fitted_second = second.fit(&features, &targets).unwrap();

        assert_eq!(fitted_first.inlier_mask(), fitted_second.inlier_mask());
    }

    /// Without an explicit threshold the MAD-based default is used.
    #[test]
    fn test_ransac_auto_threshold() {
        let features = column(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let targets = array![2.0, 4.0, 6.0, 8.0, 10.0, 100.0];

        let model = RANSACRegressor::new(LinearRegression::<f64>::new())
            .with_random_state(42)
            .with_max_trials(200);
        let fitted = model.fit(&features, &targets).unwrap();

        // At least some samples should be accepted as inliers.
        let n_inliers = count_inliers(fitted.inlier_mask());
        assert!(
            n_inliers >= 3,
            "Expected at least 3 inliers, got {n_inliers}"
        );
    }
}
516}