1use ferrolearn_core::error::FerroError;
38use ferrolearn_core::traits::{Fit, Predict};
39use ndarray::{Array1, Array2, ScalarOperand};
40use num_traits::Float;
41use rand::Rng;
42use rand::SeedableRng;
43
/// RANSAC (RANdom SAmple Consensus) wrapper that fits a base regressor robustly.
///
/// Fitting repeatedly trains `estimator` on random minimal subsets of the data and
/// keeps the model that explains the most samples within the residual threshold.
/// Build one with [`RANSACRegressor::new`] and the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RANSACRegressor<F, E> {
    /// Base estimator, fitted on each random subset and refitted on the inliers
    /// of every candidate that improves on the best model so far.
    pub estimator: E,
    /// Samples drawn per trial; `None` means `n_features + 1`. A value of `0`
    /// is treated as `1` at fit time.
    pub min_samples: Option<usize>,
    /// Largest absolute residual for which a sample counts as an inlier; `None`
    /// means the median absolute deviation of the targets (replaced by `1e-6`
    /// when that deviation is effectively zero).
    pub residual_threshold: Option<F>,
    /// Number of random subsets tried during fitting.
    pub max_trials: usize,
    /// Seed for subset sampling; `None` falls back to the fixed seed `42`.
    pub random_state: Option<u64>,
}
72
73impl<F: Float, E> RANSACRegressor<F, E> {
74 #[must_use]
80 pub fn new(estimator: E) -> Self {
81 Self {
82 estimator,
83 min_samples: None,
84 residual_threshold: None,
85 max_trials: 100,
86 random_state: None,
87 }
88 }
89
90 #[must_use]
92 pub fn with_min_samples(mut self, min_samples: usize) -> Self {
93 self.min_samples = Some(min_samples);
94 self
95 }
96
97 #[must_use]
99 pub fn with_residual_threshold(mut self, threshold: F) -> Self {
100 self.residual_threshold = Some(threshold);
101 self
102 }
103
104 #[must_use]
106 pub fn with_max_trials(mut self, max_trials: usize) -> Self {
107 self.max_trials = max_trials;
108 self
109 }
110
111 #[must_use]
113 pub fn with_random_state(mut self, seed: u64) -> Self {
114 self.random_state = Some(seed);
115 self
116 }
117}
118
/// Result of [`RANSACRegressor`] fitting: the selected base model plus the
/// inlier classification it induces on the training data.
#[derive(Debug, Clone)]
pub struct FittedRANSACRegressor<Fitted> {
    /// Fitted base estimator chosen by the RANSAC search; used for prediction.
    fitted_estimator: Fitted,
    /// One flag per training row: `true` if that row's absolute residual under
    /// `fitted_estimator` was within the threshold.
    inlier_mask: Vec<bool>,
}
133
134impl<Fitted> FittedRANSACRegressor<Fitted> {
135 #[must_use]
137 pub fn inlier_mask(&self) -> &[bool] {
138 &self.inlier_mask
139 }
140}
141
142fn median<F: Float>(values: &[F]) -> F {
148 let mut sorted: Vec<F> = values.to_vec();
149 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
150 let n = sorted.len();
151 if n == 0 {
152 return F::zero();
153 }
154 if n % 2 == 0 {
155 (sorted[n / 2 - 1] + sorted[n / 2]) / (F::one() + F::one())
156 } else {
157 sorted[n / 2]
158 }
159}
160
161fn mad<F: Float>(values: &[F]) -> F {
163 let med = median(values);
164 let abs_devs: Vec<F> = values.iter().map(|&v| (v - med).abs()).collect();
165 median(&abs_devs)
166}
167
168fn sample_indices<R: Rng>(rng: &mut R, n: usize, k: usize) -> Vec<usize> {
174 let mut indices: Vec<usize> = (0..n).collect();
175 for i in 0..k {
176 let j = rng.random_range(i..n);
177 indices.swap(i, j);
178 }
179 indices.truncate(k);
180 indices
181}
182
183fn subset<F: Float>(x: &Array2<F>, y: &Array1<F>, indices: &[usize]) -> (Array2<F>, Array1<F>) {
185 let n_features = x.ncols();
186 let n = indices.len();
187 let mut x_sub = Array2::<F>::zeros((n, n_features));
188 let mut y_sub = Array1::<F>::zeros(n);
189 for (row, &idx) in indices.iter().enumerate() {
190 for col in 0..n_features {
191 x_sub[[row, col]] = x[[idx, col]];
192 }
193 y_sub[row] = y[idx];
194 }
195 (x_sub, y_sub)
196}
197
198impl<F, E, Ef> Fit<Array2<F>, Array1<F>> for RANSACRegressor<F, E>
203where
204 F: Float + Send + Sync + ScalarOperand + num_traits::FromPrimitive + 'static,
205 E: Fit<Array2<F>, Array1<F>, Fitted = Ef, Error = FerroError> + Clone,
206 Ef: Predict<Array2<F>, Output = Array1<F>, Error = FerroError> + Clone,
207{
208 type Fitted = FittedRANSACRegressor<Ef>;
209 type Error = FerroError;
210
211 fn fit(
220 &self,
221 x: &Array2<F>,
222 y: &Array1<F>,
223 ) -> Result<FittedRANSACRegressor<E::Fitted>, FerroError> {
224 let (n_samples, n_features) = x.dim();
225
226 if n_samples != y.len() {
227 return Err(FerroError::ShapeMismatch {
228 expected: vec![n_samples],
229 actual: vec![y.len()],
230 context: "y length must match number of samples in X".into(),
231 });
232 }
233
234 let min_samples = self.min_samples.unwrap_or(n_features + 1).max(1);
235
236 if n_samples < min_samples {
237 return Err(FerroError::InsufficientSamples {
238 required: min_samples,
239 actual: n_samples,
240 context: "RANSAC requires at least min_samples samples".into(),
241 });
242 }
243
244 let threshold = if let Some(t) = self.residual_threshold {
246 t
247 } else {
248 let y_mad = mad(&y.to_vec());
249 if y_mad <= F::epsilon() {
250 F::from(1e-6).unwrap()
252 } else {
253 y_mad
254 }
255 };
256
257 let mut rng = match self.random_state {
258 Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
259 None => rand::rngs::StdRng::seed_from_u64(42),
260 };
261
262 let mut best_fitted: Option<E::Fitted> = None;
263 let mut best_inlier_mask: Option<Vec<bool>> = None;
264 let mut best_n_inliers = 0usize;
265 let mut best_residual_sum = F::infinity();
266
267 for _ in 0..self.max_trials {
268 let indices = sample_indices(&mut rng, n_samples, min_samples);
270 let (x_sub, y_sub) = subset(x, y, &indices);
271
272 let fitted = match self.estimator.fit(&x_sub, &y_sub) {
274 Ok(f) => f,
275 Err(_) => continue, };
277
278 let preds = match fitted.predict(x) {
280 Ok(p) => p,
281 Err(_) => continue,
282 };
283
284 let mut inlier_mask = vec![false; n_samples];
285 let mut n_inliers = 0usize;
286 let mut residual_sum = F::zero();
287
288 for i in 0..n_samples {
289 let residual = (preds[i] - y[i]).abs();
290 if residual <= threshold {
291 inlier_mask[i] = true;
292 n_inliers += 1;
293 residual_sum = residual_sum + residual;
294 }
295 }
296
297 let is_better = n_inliers > best_n_inliers
299 || (n_inliers == best_n_inliers && residual_sum < best_residual_sum);
300
301 if is_better && n_inliers >= min_samples {
302 let inlier_indices: Vec<usize> = inlier_mask
304 .iter()
305 .enumerate()
306 .filter(|&(_, &is_inlier)| is_inlier)
307 .map(|(i, _)| i)
308 .collect();
309 let (x_inlier, y_inlier) = subset(x, y, &inlier_indices);
310
311 if let Ok(refit) = self.estimator.fit(&x_inlier, &y_inlier) {
312 if let Ok(new_preds) = refit.predict(x) {
314 let mut new_mask = vec![false; n_samples];
315 let mut new_n_inliers = 0;
316 let mut new_residual_sum = F::zero();
317 for i in 0..n_samples {
318 let r = (new_preds[i] - y[i]).abs();
319 if r <= threshold {
320 new_mask[i] = true;
321 new_n_inliers += 1;
322 new_residual_sum = new_residual_sum + r;
323 }
324 }
325 best_fitted = Some(refit);
326 best_inlier_mask = Some(new_mask);
327 best_n_inliers = new_n_inliers;
328 best_residual_sum = new_residual_sum;
329 }
330 } else {
331 best_fitted = Some(fitted);
333 best_inlier_mask = Some(inlier_mask);
334 best_n_inliers = n_inliers;
335 best_residual_sum = residual_sum;
336 }
337 }
338 }
339
340 match (best_fitted, best_inlier_mask) {
341 (Some(fitted), Some(mask)) => Ok(FittedRANSACRegressor {
342 fitted_estimator: fitted,
343 inlier_mask: mask,
344 }),
345 _ => Err(FerroError::ConvergenceFailure {
346 iterations: self.max_trials,
347 message: "RANSAC could not find a valid model after max_trials iterations".into(),
348 }),
349 }
350 }
351}
352
353impl<F, Fitted> Predict<Array2<F>> for FittedRANSACRegressor<Fitted>
354where
355 F: Float + Send + Sync + 'static,
356 Fitted: Predict<Array2<F>, Output = Array1<F>, Error = FerroError>,
357{
358 type Output = Array1<F>;
359 type Error = FerroError;
360
361 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
367 self.fitted_estimator.predict(x)
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LinearRegression;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Perfectly linear data: every sample is an inlier and predictions match.
    #[test]
    fn test_ransac_no_outliers() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base)
            .with_random_state(42)
            .with_residual_threshold(1.0);
        let fitted = model.fit(&x, &y).unwrap();

        let mask = fitted.inlier_mask();
        assert_eq!(mask.len(), 5, "Mask must have one entry per sample");
        assert!(mask.iter().all(|&v| v), "All should be inliers");

        let preds = fitted.predict(&x).unwrap();
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*p, actual, epsilon = 0.5);
        }
    }

    /// A single gross outlier is excluded and does not distort predictions.
    #[test]
    fn test_ransac_with_outlier() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 100.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base)
            .with_random_state(42)
            .with_max_trials(200)
            .with_residual_threshold(2.0);
        let fitted = model.fit(&x, &y).unwrap();

        let mask = fitted.inlier_mask();
        assert_eq!(mask.len(), 6, "Mask must have one entry per sample");
        assert!(!mask[5], "Outlier at index 5 should not be an inlier");

        let n_inliers: usize = mask.iter().filter(|&&v| v).count();
        assert!(
            n_inliers >= 4,
            "Expected at least 4 inliers, got {n_inliers}"
        );

        let x_test = Array2::from_shape_vec((1, 1), vec![3.0]).unwrap();
        let pred = fitted.predict(&x_test).unwrap();
        assert!(
            (pred[0] - 6.0).abs() < 3.0,
            "Prediction at x=3 should be near 6.0, got {}",
            pred[0]
        );
    }

    /// Outliers on both sides of the trend are all excluded.
    #[test]
    fn test_ransac_multiple_outliers() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![2.0, 3.0, 50.0, 5.0, 6.0, -40.0, 8.0, 9.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base)
            .with_random_state(123)
            .with_max_trials(500)
            .with_residual_threshold(2.0);
        let fitted = model.fit(&x, &y).unwrap();

        let mask = fitted.inlier_mask();
        assert!(!mask[2], "Outlier at index 2 should not be an inlier");
        assert!(!mask[5], "Outlier at index 5 should not be an inlier");
    }

    #[test]
    fn test_ransac_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_ransac_insufficient_samples() {
        let x = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let y = array![1.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base).with_min_samples(3);
        assert!(model.fit(&x, &y).is_err());
    }

    /// With no trials no candidate can be found, so fitting must report failure.
    #[test]
    fn test_ransac_zero_trials_fails() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base)
            .with_random_state(7)
            .with_max_trials(0);
        assert!(model.fit(&x, &y).is_err());
    }

    /// The same seed must reproduce the same inlier classification.
    #[test]
    fn test_ransac_reproducible_with_seed() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 100.0];

        let base1 = LinearRegression::<f64>::new();
        let model1 = RANSACRegressor::new(base1)
            .with_random_state(42)
            .with_residual_threshold(2.0);
        let fitted1 = model1.fit(&x, &y).unwrap();

        let base2 = LinearRegression::<f64>::new();
        let model2 = RANSACRegressor::new(base2)
            .with_random_state(42)
            .with_residual_threshold(2.0);
        let fitted2 = model2.fit(&x, &y).unwrap();

        assert_eq!(fitted1.inlier_mask(), fitted2.inlier_mask());
    }

    /// Without an explicit threshold the MAD-based default still finds inliers.
    #[test]
    fn test_ransac_auto_threshold() {
        let x = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 100.0];

        let base = LinearRegression::<f64>::new();
        let model = RANSACRegressor::new(base)
            .with_random_state(42)
            .with_max_trials(200);
        let fitted = model.fit(&x, &y).unwrap();

        let mask = fitted.inlier_mask();
        let n_inliers: usize = mask.iter().filter(|&&v| v).count();
        assert!(
            n_inliers >= 3,
            "Expected at least 3 inliers, got {n_inliers}"
        );
    }
}