1use ferrolearn_core::error::FerroError;
38use ferrolearn_core::traits::{Fit, Predict};
39use ndarray::{Array1, Array2, ScalarOperand};
40use num_traits::Float;
41use rand::Rng;
42use rand::SeedableRng;
43
44#[derive(Debug, Clone)]
59pub struct RANSACRegressor<F, E> {
60 pub estimator: E,
62 pub min_samples: Option<usize>,
64 pub residual_threshold: Option<F>,
67 pub max_trials: usize,
69 pub random_state: Option<u64>,
71}
72
73impl<F: Float, E> RANSACRegressor<F, E> {
74 #[must_use]
80 pub fn new(estimator: E) -> Self {
81 Self {
82 estimator,
83 min_samples: None,
84 residual_threshold: None,
85 max_trials: 100,
86 random_state: None,
87 }
88 }
89
90 #[must_use]
92 pub fn with_min_samples(mut self, min_samples: usize) -> Self {
93 self.min_samples = Some(min_samples);
94 self
95 }
96
97 #[must_use]
99 pub fn with_residual_threshold(mut self, threshold: F) -> Self {
100 self.residual_threshold = Some(threshold);
101 self
102 }
103
104 #[must_use]
106 pub fn with_max_trials(mut self, max_trials: usize) -> Self {
107 self.max_trials = max_trials;
108 self
109 }
110
111 #[must_use]
113 pub fn with_random_state(mut self, seed: u64) -> Self {
114 self.random_state = Some(seed);
115 self
116 }
117}
118
/// A fitted [`RANSACRegressor`]: the model RANSAC selected plus the inlier mask it produced.
///
/// Predictions are delegated to the wrapped fitted estimator.
#[derive(Clone, Debug)]
pub struct FittedRANSACRegressor<Fitted> {
    /// The model chosen during fitting.
    fitted_estimator: Fitted,
    /// One flag per training sample. `true` when that sample's absolute residual under
    /// `fitted_estimator` was within the residual threshold.
    inlier_mask: Vec<bool>,
}

impl<Fitted> FittedRANSACRegressor<Fitted> {
    /// Per-sample inlier flags, in the same order as the rows of the training `X`.
    #[must_use]
    pub fn inlier_mask(&self) -> &[bool] {
        self.inlier_mask.as_slice()
    }
}
141
142fn median<F: Float>(values: &[F]) -> F {
148 let mut sorted: Vec<F> = values.to_vec();
149 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
150 let n = sorted.len();
151 if n == 0 {
152 return F::zero();
153 }
154 if n % 2 == 0 {
155 (sorted[n / 2 - 1] + sorted[n / 2]) / (F::one() + F::one())
156 } else {
157 sorted[n / 2]
158 }
159}
160
161fn mad<F: Float>(values: &[F]) -> F {
163 let med = median(values);
164 let abs_devs: Vec<F> = values.iter().map(|&v| (v - med).abs()).collect();
165 median(&abs_devs)
166}
167
168fn sample_indices<R: Rng>(rng: &mut R, n: usize, k: usize) -> Vec<usize> {
174 let mut indices: Vec<usize> = (0..n).collect();
175 for i in 0..k {
176 let j = rng.random_range(i..n);
177 indices.swap(i, j);
178 }
179 indices.truncate(k);
180 indices
181}
182
183fn subset<F: Float>(x: &Array2<F>, y: &Array1<F>, indices: &[usize]) -> (Array2<F>, Array1<F>) {
185 let n_features = x.ncols();
186 let n = indices.len();
187 let mut x_sub = Array2::<F>::zeros((n, n_features));
188 let mut y_sub = Array1::<F>::zeros(n);
189 for (row, &idx) in indices.iter().enumerate() {
190 for col in 0..n_features {
191 x_sub[[row, col]] = x[[idx, col]];
192 }
193 y_sub[row] = y[idx];
194 }
195 (x_sub, y_sub)
196}
197
198impl<F, E, Ef> Fit<Array2<F>, Array1<F>> for RANSACRegressor<F, E>
203where
204 F: Float + Send + Sync + ScalarOperand + num_traits::FromPrimitive + 'static,
205 E: Fit<Array2<F>, Array1<F>, Fitted = Ef, Error = FerroError> + Clone,
206 Ef: Predict<Array2<F>, Output = Array1<F>, Error = FerroError> + Clone,
207{
208 type Fitted = FittedRANSACRegressor<Ef>;
209 type Error = FerroError;
210
211 fn fit(
220 &self,
221 x: &Array2<F>,
222 y: &Array1<F>,
223 ) -> Result<FittedRANSACRegressor<E::Fitted>, FerroError> {
224 let (n_samples, n_features) = x.dim();
225
226 if n_samples != y.len() {
227 return Err(FerroError::ShapeMismatch {
228 expected: vec![n_samples],
229 actual: vec![y.len()],
230 context: "y length must match number of samples in X".into(),
231 });
232 }
233
234 let min_samples = self.min_samples.unwrap_or(n_features + 1).max(1);
235
236 if n_samples < min_samples {
237 return Err(FerroError::InsufficientSamples {
238 required: min_samples,
239 actual: n_samples,
240 context: "RANSAC requires at least min_samples samples".into(),
241 });
242 }
243
244 let threshold = match self.residual_threshold {
246 Some(t) => t,
247 None => {
248 let y_mad = mad(&y.to_vec());
249 if y_mad <= F::epsilon() {
250 F::from(1e-6).unwrap()
252 } else {
253 y_mad
254 }
255 }
256 };
257
258 let mut rng = match self.random_state {
259 Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
260 None => rand::rngs::StdRng::seed_from_u64(42),
261 };
262
263 let mut best_fitted: Option<E::Fitted> = None;
264 let mut best_inlier_mask: Option<Vec<bool>> = None;
265 let mut best_n_inliers = 0usize;
266 let mut best_residual_sum = F::infinity();
267
268 for _ in 0..self.max_trials {
269 let indices = sample_indices(&mut rng, n_samples, min_samples);
271 let (x_sub, y_sub) = subset(x, y, &indices);
272
273 let fitted = match self.estimator.fit(&x_sub, &y_sub) {
275 Ok(f) => f,
276 Err(_) => continue, };
278
279 let preds = match fitted.predict(x) {
281 Ok(p) => p,
282 Err(_) => continue,
283 };
284
285 let mut inlier_mask = vec![false; n_samples];
286 let mut n_inliers = 0usize;
287 let mut residual_sum = F::zero();
288
289 for i in 0..n_samples {
290 let residual = (preds[i] - y[i]).abs();
291 if residual <= threshold {
292 inlier_mask[i] = true;
293 n_inliers += 1;
294 residual_sum = residual_sum + residual;
295 }
296 }
297
298 let is_better = n_inliers > best_n_inliers
300 || (n_inliers == best_n_inliers && residual_sum < best_residual_sum);
301
302 if is_better && n_inliers >= min_samples {
303 let inlier_indices: Vec<usize> = inlier_mask
305 .iter()
306 .enumerate()
307 .filter(|&(_, &is_inlier)| is_inlier)
308 .map(|(i, _)| i)
309 .collect();
310 let (x_inlier, y_inlier) = subset(x, y, &inlier_indices);
311
312 match self.estimator.fit(&x_inlier, &y_inlier) {
313 Ok(refit) => {
314 if let Ok(new_preds) = refit.predict(x) {
316 let mut new_mask = vec![false; n_samples];
317 let mut new_n_inliers = 0;
318 let mut new_residual_sum = F::zero();
319 for i in 0..n_samples {
320 let r = (new_preds[i] - y[i]).abs();
321 if r <= threshold {
322 new_mask[i] = true;
323 new_n_inliers += 1;
324 new_residual_sum = new_residual_sum + r;
325 }
326 }
327 best_fitted = Some(refit);
328 best_inlier_mask = Some(new_mask);
329 best_n_inliers = new_n_inliers;
330 best_residual_sum = new_residual_sum;
331 }
332 }
333 Err(_) => {
334 best_fitted = Some(fitted);
336 best_inlier_mask = Some(inlier_mask);
337 best_n_inliers = n_inliers;
338 best_residual_sum = residual_sum;
339 }
340 }
341 }
342 }
343
344 match (best_fitted, best_inlier_mask) {
345 (Some(fitted), Some(mask)) => Ok(FittedRANSACRegressor {
346 fitted_estimator: fitted,
347 inlier_mask: mask,
348 }),
349 _ => Err(FerroError::ConvergenceFailure {
350 iterations: self.max_trials,
351 message: "RANSAC could not find a valid model after max_trials iterations".into(),
352 }),
353 }
354 }
355}
356
357impl<F, Fitted> Predict<Array2<F>> for FittedRANSACRegressor<Fitted>
358where
359 F: Float + Send + Sync + 'static,
360 Fitted: Predict<Array2<F>, Output = Array1<F>, Error = FerroError>,
361{
362 type Output = Array1<F>;
363 type Error = FerroError;
364
365 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
371 self.fitted_estimator.predict(x)
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LinearRegression;
    use approx::assert_relative_eq;
    use ndarray::array;

    #[test]
    fn test_ransac_no_outliers() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base)
            .with_random_state(42)
            .with_residual_threshold(1.0);
        let fitted = model.fit(&x, &y).unwrap();

        // A perfectly linear target leaves nothing to reject.
        let mask = fitted.inlier_mask();
        assert!(mask.iter().all(|&v| v), "All should be inliers");

        let preds = fitted.predict(&x).unwrap();
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 0.5);
        }
    }

    #[test]
    fn test_ransac_with_outlier() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        // y = 2x except for a gross outlier at index 5.
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 100.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base)
            .with_random_state(42)
            .with_max_trials(200)
            .with_residual_threshold(2.0);
        let fitted = model.fit(&x, &y).unwrap();

        let mask = fitted.inlier_mask();
        assert!(!mask[5], "Outlier at index 5 should not be an inlier");

        let n_inliers: usize = mask.iter().filter(|&&v| v).count();
        assert!(
            n_inliers >= 4,
            "Expected at least 4 inliers, got {n_inliers}"
        );

        let x_test = Array2::from_shape_vec((1, 1), vec![3.0]).unwrap();
        let pred = fitted.predict(&x_test).unwrap();
        assert!(
            (pred[0] - 6.0).abs() < 3.0,
            "Prediction at x=3 should be near 6.0, got {}",
            pred[0]
        );
    }

    #[test]
    fn test_ransac_multiple_outliers() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        // y = x + 1 except for outliers at indices 2 and 5.
        let y = array![2.0, 3.0, 50.0, 5.0, 6.0, -40.0, 8.0, 9.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base)
            .with_random_state(123)
            .with_max_trials(500)
            .with_residual_threshold(2.0);
        let fitted = model.fit(&x, &y).unwrap();

        let mask = fitted.inlier_mask();
        assert!(!mask[2], "Outlier at index 2 should not be an inlier");
        assert!(!mask[5], "Outlier at index 5 should not be an inlier");
    }

    #[test]
    fn test_ransac_inlier_mask_matches_predictions() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![2.0, 3.0, 50.0, 5.0, 6.0, -40.0, 8.0, 9.0];
        let threshold = 2.0;

        let fitted = RANSACRegressor::new(LinearRegression::<f64>::new())
            .with_random_state(7)
            .with_max_trials(50)
            .with_residual_threshold(threshold)
            .fit(&x, &y)
            .unwrap();

        // The reported mask must describe the model that is actually used for prediction.
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(fitted.inlier_mask().len(), y.len());
        for ((&p, &t), &is_inlier) in preds.iter().zip(y.iter()).zip(fitted.inlier_mask()) {
            assert_eq!(is_inlier, (p - t).abs() <= threshold);
        }
    }

    #[test]
    fn test_ransac_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ransac_insufficient_samples() {
        let x = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let y = array![1.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base).with_min_samples(3);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ransac_reproducible_with_seed() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 100.0];

        let base1 = LinearRegression::<f64>::new();
        let model1 = RANSACRegressor::new(base1)
            .with_random_state(42)
            .with_residual_threshold(2.0);
        let fitted1 = model1.fit(&x, &y).unwrap();

        let base2 = LinearRegression::<f64>::new();
        let model2 = RANSACRegressor::new(base2)
            .with_random_state(42)
            .with_residual_threshold(2.0);
        let fitted2 = model2.fit(&x, &y).unwrap();

        assert_eq!(fitted1.inlier_mask(), fitted2.inlier_mask());
    }

    #[test]
    fn test_ransac_auto_threshold() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 100.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base)
            .with_random_state(42)
            .with_max_trials(200);
        let fitted = model.fit(&x, &y).unwrap();

        let mask = fitted.inlier_mask();
        let n_inliers: usize = mask.iter().filter(|&&v| v).count();
        assert!(
            n_inliers >= 3,
            "Expected at least 3 inliers, got {n_inliers}"
        );
    }

    #[test]
    fn test_median_and_mad() {
        assert_relative_eq!(median(&[3.0_f64, 1.0, 2.0]), 2.0);
        assert_relative_eq!(median(&[4.0_f64, 1.0, 3.0, 2.0]), 2.5);
        assert_relative_eq!(median::<f64>(&[]), 0.0);
        // median = 3, absolute deviations [2, 1, 0, 1, 97] -> median 1.
        assert_relative_eq!(mad(&[1.0_f64, 2.0, 3.0, 4.0, 100.0]), 1.0);
    }
}