sklears_dummy/
robust.rs

1//! Robust dummy estimators resistant to outliers and anomalies
2//!
3//! This module provides dummy estimators that are robust to outliers and provide
4//! reliable baseline predictions even in the presence of data quality issues.
5
6use scirs2_core::ndarray::Array1;
7use scirs2_core::random::{
8    essentials::Normal, prelude::*, rngs::StdRng, Distribution, Rng, SeedableRng,
9};
10use sklears_core::error::Result;
11use sklears_core::traits::{Estimator, Fit, Predict};
12use sklears_core::types::{Features, Float};
13
/// Strategy for robust predictions.
///
/// Selects which fitting routine the robust dummy regressor runs and carries the
/// parameters that routine needs.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum RobustStrategy {
    /// Outlier-resistant methods using robust statistics: the regressor drops the values
    /// flagged by the detection method and uses the median / MAD of what remains.
    OutlierResistant {
        /// Contamination rate (proportion of outliers expected)
        contamination: Float,
        /// Method for outlier detection
        detection_method: OutlierDetectionMethod,
    },
    /// Trimmed mean baselines removing extreme values
    TrimmedMean {
        /// Proportion of values to trim from each tail; fitting rejects values outside
        /// the half-open range `0.0..0.5`
        trim_proportion: Float,
    },
    /// Robust scale estimation using various estimators
    RobustScale {
        /// Scale estimator to use
        scale_estimator: ScaleEstimator,
        /// Location estimator to use
        location_estimator: LocationEstimator,
    },
    /// Breakdown point analysis with different robustness levels. `breakdown_point` is the
    /// target value; fitting requires it to lie strictly between 0 and 0.5 and uses it as
    /// the per-tail trimming proportion of a trimmed mean.
    BreakdownPoint { breakdown_point: Float },
    /// Influence-resistant methods using M-estimators
    InfluenceResistant {
        /// Huber parameter for loss function (threshold on the scaled residual beyond
        /// which observations are down-weighted)
        huber_delta: Float,
        /// Maximum number of iterations
        max_iter: usize,
        /// Convergence tolerance on the change of the location estimate between iterations
        tolerance: Float,
    },
}
49
/// Methods for outlier detection
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum OutlierDetectionMethod {
    /// Interquartile Range (IQR) method: a value is an outlier when it lies below
    /// `q1 - multiplier * iqr` or above `q3 + multiplier * iqr`.
    IQR { multiplier: Float },
    /// Z-score method using median absolute deviation: a value is an outlier when
    /// `0.6745 * |value - median| / mad` exceeds `threshold`.
    ModifiedZScore { threshold: Float },
    /// Isolation based on distance from median: a value is an outlier when its absolute
    /// distance from the median exceeds `threshold` times the median of those distances.
    MedianDistance { threshold: Float },
}
61
/// Scale estimation methods
///
/// Each estimate is multiplied (or divided) by a fixed consistency constant in the
/// regressor's scale computation.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum ScaleEstimator {
    /// Median Absolute Deviation
    MAD,
    /// Qn estimator; implemented in this module in a simplified form based on a low
    /// quantile of the pairwise absolute differences
    Qn,
    /// Inter-Quartile Range
    IQR,
    /// Rousseeuw-Croux Sn estimator; implemented in this module in a simplified form as
    /// the median of each point's median distance to the other points
    Sn,
}
75
/// Location estimation methods
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum LocationEstimator {
    /// Median (most robust)
    Median,
    /// Trimmed mean; `trim_proportion` is the fraction removed from each tail
    TrimmedMean { trim_proportion: Float },
    /// Huber M-estimator; residuals larger than `delta` times the MAD are down-weighted
    Huber { delta: Float },
    /// Biweight location: a weighted mean using weights `(1 - u^2)^2` with
    /// `u = (value - median) / (9 * mad)`, ignoring values with `|u| >= 1`
    Biweight,
}
89
/// Robust dummy regressor
///
/// The `State` type parameter marks whether the value is unfitted (the default,
/// `Untrained`) or the result of a successful fit (`Trained`). All fitted parameters are
/// `None` until fitting fills in the ones relevant to the chosen strategy.
#[derive(Debug, Clone)]
pub struct RobustDummyRegressor<State = sklears_core::traits::Untrained> {
    /// Strategy for robust predictions
    pub strategy: RobustStrategy,
    /// Random state for reproducible output
    pub random_state: Option<u64>,

    // Fitted parameters
    /// Robust location estimate
    pub(crate) robust_location_: Option<Float>,
    /// Robust scale estimate
    pub(crate) robust_scale_: Option<Float>,
    /// Outlier mask (true for outliers); only set by the outlier-resistant strategy
    pub(crate) outlier_mask_: Option<Array1<bool>>,
    /// Clean data after outlier removal; only set by the outlier-resistant strategy
    pub(crate) clean_data_: Option<Array1<Float>>,
    /// Breakdown point achieved
    pub(crate) breakdown_point_: Option<Float>,
    /// M-estimator weights; only set by the influence-resistant strategy
    pub(crate) m_weights_: Option<Array1<Float>>,

    /// Phantom data for state
    pub(crate) _state: std::marker::PhantomData<State>,
}
115
116impl RobustDummyRegressor {
117    /// Create a new robust dummy regressor
118    pub fn new(strategy: RobustStrategy) -> Self {
119        Self {
120            strategy,
121            random_state: None,
122            robust_location_: None,
123            robust_scale_: None,
124            outlier_mask_: None,
125            clean_data_: None,
126            breakdown_point_: None,
127            m_weights_: None,
128            _state: std::marker::PhantomData,
129        }
130    }
131
132    /// Set the random state for reproducible output
133    pub fn with_random_state(mut self, random_state: u64) -> Self {
134        self.random_state = Some(random_state);
135        self
136    }
137
138    /// Get the breakdown point achieved by the fitted estimator
139    pub fn breakdown_point(&self) -> Option<Float> {
140        self.breakdown_point_
141    }
142
143    /// Get the outlier mask (available after fitting with outlier-resistant strategy)
144    pub fn outlier_mask(&self) -> Option<&Array1<bool>> {
145        self.outlier_mask_.as_ref()
146    }
147
148    /// Get the M-estimator weights (available after fitting with influence-resistant strategy)
149    pub fn m_weights(&self) -> Option<&Array1<Float>> {
150        self.m_weights_.as_ref()
151    }
152}
153
154impl Default for RobustDummyRegressor {
155    fn default() -> Self {
156        Self::new(RobustStrategy::TrimmedMean {
157            trim_proportion: 0.1,
158        })
159    }
160}
161
162impl Estimator for RobustDummyRegressor {
163    type Config = ();
164    type Error = sklears_core::error::SklearsError;
165    type Float = Float;
166
167    fn config(&self) -> &Self::Config {
168        &()
169    }
170}
171
172impl Fit<Features, Array1<Float>> for RobustDummyRegressor {
173    type Fitted = RobustDummyRegressor<sklears_core::traits::Trained>;
174
175    fn fit(self, x: &Features, y: &Array1<Float>) -> Result<Self::Fitted> {
176        if x.is_empty() || y.is_empty() {
177            return Err(sklears_core::error::SklearsError::InvalidInput(
178                "Input cannot be empty".to_string(),
179            ));
180        }
181
182        if x.nrows() != y.len() {
183            return Err(sklears_core::error::SklearsError::InvalidInput(
184                "Number of samples in X and y must be equal".to_string(),
185            ));
186        }
187
188        let mut fitted = RobustDummyRegressor {
189            strategy: self.strategy.clone(),
190            random_state: self.random_state,
191            robust_location_: None,
192            robust_scale_: None,
193            outlier_mask_: None,
194            clean_data_: None,
195            breakdown_point_: None,
196            m_weights_: None,
197            _state: std::marker::PhantomData,
198        };
199
200        match &self.strategy {
201            RobustStrategy::OutlierResistant {
202                contamination,
203                detection_method,
204            } => {
205                fitted.fit_outlier_resistant(y, *contamination, detection_method)?;
206            }
207            RobustStrategy::TrimmedMean { trim_proportion } => {
208                fitted.fit_trimmed_mean(y, *trim_proportion)?;
209            }
210            RobustStrategy::RobustScale {
211                scale_estimator,
212                location_estimator,
213            } => {
214                fitted.fit_robust_scale(y, scale_estimator, location_estimator)?;
215            }
216            RobustStrategy::BreakdownPoint { breakdown_point } => {
217                fitted.fit_breakdown_point(y, *breakdown_point)?;
218            }
219            RobustStrategy::InfluenceResistant {
220                huber_delta,
221                max_iter,
222                tolerance,
223            } => {
224                fitted.fit_influence_resistant(y, *huber_delta, *max_iter, *tolerance)?;
225            }
226        }
227
228        Ok(fitted)
229    }
230}
231
232impl RobustDummyRegressor<sklears_core::traits::Trained> {
233    /// Get the breakdown point achieved by the fitted estimator
234    pub fn breakdown_point(&self) -> Option<Float> {
235        self.breakdown_point_
236    }
237
238    /// Get the outlier mask (available after fitting with outlier-resistant strategy)
239    pub fn outlier_mask(&self) -> Option<&Array1<bool>> {
240        self.outlier_mask_.as_ref()
241    }
242
243    /// Get the M-estimator weights (available after fitting with influence-resistant strategy)
244    pub fn m_weights(&self) -> Option<&Array1<Float>> {
245        self.m_weights_.as_ref()
246    }
247    /// Fit outlier-resistant strategy
248    fn fit_outlier_resistant(
249        &mut self,
250        y: &Array1<Float>,
251        contamination: Float,
252        detection_method: &OutlierDetectionMethod,
253    ) -> Result<()> {
254        let outlier_mask = self.detect_outliers(y, detection_method, contamination)?;
255        let clean_data: Array1<Float> = y
256            .iter()
257            .zip(outlier_mask.iter())
258            .filter_map(|(&value, &is_outlier)| if !is_outlier { Some(value) } else { None })
259            .collect();
260
261        if clean_data.is_empty() {
262            return Err(sklears_core::error::SklearsError::InvalidInput(
263                "All data points detected as outliers".to_string(),
264            ));
265        }
266
267        let location = self.compute_median(&clean_data);
268        let scale = self.compute_mad(&clean_data, location);
269
270        // Calculate actual breakdown point achieved
271        let n_outliers = outlier_mask.iter().filter(|&&x| x).count();
272        let breakdown_point = n_outliers as Float / y.len() as Float;
273
274        self.robust_location_ = Some(location);
275        self.robust_scale_ = Some(scale);
276        self.outlier_mask_ = Some(outlier_mask);
277        self.clean_data_ = Some(clean_data);
278        self.breakdown_point_ = Some(breakdown_point);
279
280        Ok(())
281    }
282
283    /// Fit trimmed mean strategy
284    fn fit_trimmed_mean(&mut self, y: &Array1<Float>, trim_proportion: Float) -> Result<()> {
285        if !(0.0..0.5).contains(&trim_proportion) {
286            return Err(sklears_core::error::SklearsError::InvalidInput(
287                "Trim proportion must be between 0 and 0.5".to_string(),
288            ));
289        }
290
291        let mut sorted_y = y.to_vec();
292        sorted_y.sort_by(|a, b| a.partial_cmp(b).unwrap());
293
294        let n = sorted_y.len();
295        let trim_count = (n as Float * trim_proportion).floor() as usize;
296
297        if trim_count * 2 >= n {
298            return Err(sklears_core::error::SklearsError::InvalidInput(
299                "Too much trimming for dataset size".to_string(),
300            ));
301        }
302
303        let trimmed_data = &sorted_y[trim_count..(n - trim_count)];
304        let location = trimmed_data.iter().sum::<Float>() / trimmed_data.len() as Float;
305
306        // Robust scale using trimmed standard deviation
307        let mean = location;
308        let variance = trimmed_data
309            .iter()
310            .map(|&x| (x - mean).powi(2))
311            .sum::<Float>()
312            / (trimmed_data.len() - 1) as Float;
313        let scale = variance.sqrt();
314
315        let breakdown_point = trim_proportion;
316
317        self.robust_location_ = Some(location);
318        self.robust_scale_ = Some(scale);
319        self.breakdown_point_ = Some(breakdown_point);
320
321        Ok(())
322    }
323
324    /// Fit robust scale strategy
325    fn fit_robust_scale(
326        &mut self,
327        y: &Array1<Float>,
328        scale_estimator: &ScaleEstimator,
329        location_estimator: &LocationEstimator,
330    ) -> Result<()> {
331        let location = self.compute_robust_location(y, location_estimator)?;
332        let scale = self.compute_robust_scale(y, scale_estimator, location)?;
333
334        // Breakdown point depends on the estimators used
335        let breakdown_point = match (location_estimator, scale_estimator) {
336            (LocationEstimator::Median, ScaleEstimator::MAD) => 0.5,
337            (LocationEstimator::Median, ScaleEstimator::Qn) => 0.5,
338            (LocationEstimator::TrimmedMean { trim_proportion }, _) => *trim_proportion,
339            _ => 0.25, // Conservative estimate
340        };
341
342        self.robust_location_ = Some(location);
343        self.robust_scale_ = Some(scale);
344        self.breakdown_point_ = Some(breakdown_point);
345
346        Ok(())
347    }
348
349    /// Fit breakdown point strategy
350    fn fit_breakdown_point(&mut self, y: &Array1<Float>, target_breakdown: Float) -> Result<()> {
351        if target_breakdown <= 0.0 || target_breakdown >= 0.5 {
352            return Err(sklears_core::error::SklearsError::InvalidInput(
353                "Breakdown point must be between 0 and 0.5".to_string(),
354            ));
355        }
356
357        // Use trimmed mean with appropriate trimming to achieve target breakdown point
358        let trim_proportion = target_breakdown;
359        self.fit_trimmed_mean(y, trim_proportion)?;
360
361        Ok(())
362    }
363
364    /// Fit influence-resistant strategy using M-estimators
365    fn fit_influence_resistant(
366        &mut self,
367        y: &Array1<Float>,
368        huber_delta: Float,
369        max_iter: usize,
370        tolerance: Float,
371    ) -> Result<()> {
372        // Initialize with median
373        let mut location = self.compute_median(y);
374        let initial_scale = self.compute_mad(y, location);
375
376        let mut weights = Array1::ones(y.len());
377
378        // Iteratively reweighted least squares (IRLS) for Huber M-estimator
379        for _iter in 0..max_iter {
380            let old_location = location;
381
382            // Update weights based on Huber function
383            for i in 0..y.len() {
384                let residual = (y[i] - location).abs();
385                let scaled_residual = residual / initial_scale;
386
387                weights[i] = if scaled_residual <= huber_delta {
388                    1.0
389                } else {
390                    huber_delta / scaled_residual
391                };
392            }
393
394            // Update location estimate
395            let weighted_sum: Float = y.iter().zip(weights.iter()).map(|(&yi, &wi)| wi * yi).sum();
396            let weight_sum: Float = weights.sum();
397
398            if weight_sum > 0.0 {
399                location = weighted_sum / weight_sum;
400            }
401
402            // Check convergence
403            if (location - old_location).abs() < tolerance {
404                break;
405            }
406        }
407
408        // Compute robust scale with final weights
409        let weighted_variance: Float = y
410            .iter()
411            .zip(weights.iter())
412            .map(|(&yi, &wi)| wi * (yi - location).powi(2))
413            .sum();
414        let effective_sample_size: Float = weights.sum();
415
416        let scale = if effective_sample_size > 1.0 {
417            (weighted_variance / (effective_sample_size - 1.0)).sqrt()
418        } else {
419            initial_scale
420        };
421
422        // Breakdown point for Huber M-estimator is approximately 1/(2*delta + 1)
423        let breakdown_point = 1.0 / (2.0 * huber_delta + 1.0);
424
425        self.robust_location_ = Some(location);
426        self.robust_scale_ = Some(scale);
427        self.m_weights_ = Some(weights);
428        self.breakdown_point_ = Some(breakdown_point);
429
430        Ok(())
431    }
432
433    /// Detect outliers using various methods
434    fn detect_outliers(
435        &self,
436        y: &Array1<Float>,
437        method: &OutlierDetectionMethod,
438        _contamination: Float,
439    ) -> Result<Array1<bool>> {
440        let mut outlier_mask = Array1::from_elem(y.len(), false);
441
442        match method {
443            OutlierDetectionMethod::IQR { multiplier } => {
444                let mut sorted_y = y.to_vec();
445                sorted_y.sort_by(|a, b| a.partial_cmp(b).unwrap());
446
447                let n = sorted_y.len();
448                let q1_idx = n / 4;
449                let q3_idx = 3 * n / 4;
450
451                let q1 = sorted_y[q1_idx];
452                let q3 = sorted_y[q3_idx];
453                let iqr = q3 - q1;
454
455                let lower_bound = q1 - multiplier * iqr;
456                let upper_bound = q3 + multiplier * iqr;
457
458                for (i, &value) in y.iter().enumerate() {
459                    if value < lower_bound || value > upper_bound {
460                        outlier_mask[i] = true;
461                    }
462                }
463            }
464            OutlierDetectionMethod::ModifiedZScore { threshold } => {
465                let median = self.compute_median(y);
466                let mad = self.compute_mad(y, median);
467
468                if mad > 0.0 {
469                    for (i, &value) in y.iter().enumerate() {
470                        let modified_z_score = 0.6745 * (value - median).abs() / mad;
471                        if modified_z_score > *threshold {
472                            outlier_mask[i] = true;
473                        }
474                    }
475                }
476            }
477            OutlierDetectionMethod::MedianDistance { threshold } => {
478                let median = self.compute_median(y);
479                let distances: Array1<Float> =
480                    y.iter().map(|&value| (value - median).abs()).collect();
481                let distance_threshold = self.compute_median(&distances) * threshold;
482
483                for (i, &distance) in distances.iter().enumerate() {
484                    if distance > distance_threshold {
485                        outlier_mask[i] = true;
486                    }
487                }
488            }
489        }
490
491        Ok(outlier_mask)
492    }
493
494    /// Compute median
495    fn compute_median(&self, data: &Array1<Float>) -> Float {
496        let mut sorted_data = data.to_vec();
497        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
498
499        let n = sorted_data.len();
500        if n % 2 == 0 {
501            (sorted_data[n / 2 - 1] + sorted_data[n / 2]) / 2.0
502        } else {
503            sorted_data[n / 2]
504        }
505    }
506
507    /// Compute Median Absolute Deviation (MAD)
508    fn compute_mad(&self, data: &Array1<Float>, median: Float) -> Float {
509        let deviations: Array1<Float> = data.iter().map(|&x| (x - median).abs()).collect();
510        self.compute_median(&deviations) * 1.4826 // Consistency factor for normal distribution
511    }
512
513    /// Compute robust location estimate
514    fn compute_robust_location(
515        &self,
516        y: &Array1<Float>,
517        estimator: &LocationEstimator,
518    ) -> Result<Float> {
519        match estimator {
520            LocationEstimator::Median => Ok(self.compute_median(y)),
521            LocationEstimator::TrimmedMean { trim_proportion } => {
522                let mut sorted_y = y.to_vec();
523                sorted_y.sort_by(|a, b| a.partial_cmp(b).unwrap());
524
525                let n = sorted_y.len();
526                let trim_count = (n as Float * trim_proportion).floor() as usize;
527
528                if trim_count * 2 >= n {
529                    return Err(sklears_core::error::SklearsError::InvalidInput(
530                        "Too much trimming for dataset size".to_string(),
531                    ));
532                }
533
534                let trimmed_data = &sorted_y[trim_count..(n - trim_count)];
535                Ok(trimmed_data.iter().sum::<Float>() / trimmed_data.len() as Float)
536            }
537            LocationEstimator::Huber { delta } => {
538                // Simplified Huber M-estimator
539                let mut location = self.compute_median(y);
540                let scale = self.compute_mad(y, location);
541
542                for _iter in 0..10 {
543                    let old_location = location;
544                    let mut weighted_sum = 0.0;
545                    let mut weight_sum = 0.0;
546
547                    for &yi in y.iter() {
548                        let residual = (yi - location).abs();
549                        let weight = if residual <= delta * scale {
550                            1.0
551                        } else {
552                            delta * scale / residual
553                        };
554
555                        weighted_sum += weight * yi;
556                        weight_sum += weight;
557                    }
558
559                    if weight_sum > 0.0 {
560                        location = weighted_sum / weight_sum;
561                    }
562
563                    if (location - old_location).abs() < 1e-6 {
564                        break;
565                    }
566                }
567
568                Ok(location)
569            }
570            LocationEstimator::Biweight => {
571                // Biweight midvariance (simplified implementation)
572                let median = self.compute_median(y);
573                let mad = self.compute_mad(y, median);
574
575                if mad == 0.0 {
576                    return Ok(median);
577                }
578
579                let mut weighted_sum = 0.0;
580                let mut weight_sum = 0.0;
581
582                for &yi in y.iter() {
583                    let u = (yi - median) / (9.0 * mad);
584                    if u.abs() < 1.0 {
585                        let weight = (1.0 - u * u).powi(2);
586                        weighted_sum += weight * yi;
587                        weight_sum += weight;
588                    }
589                }
590
591                if weight_sum > 0.0 {
592                    Ok(weighted_sum / weight_sum)
593                } else {
594                    Ok(median)
595                }
596            }
597        }
598    }
599
600    /// Compute robust scale estimate
601    fn compute_robust_scale(
602        &self,
603        y: &Array1<Float>,
604        estimator: &ScaleEstimator,
605        location: Float,
606    ) -> Result<Float> {
607        match estimator {
608            ScaleEstimator::MAD => Ok(self.compute_mad(y, location)),
609            ScaleEstimator::IQR => {
610                let mut sorted_y = y.to_vec();
611                sorted_y.sort_by(|a, b| a.partial_cmp(b).unwrap());
612
613                let n = sorted_y.len();
614                let q1_idx = n / 4;
615                let q3_idx = 3 * n / 4;
616
617                let q1 = sorted_y[q1_idx];
618                let q3 = sorted_y[q3_idx];
619                Ok((q3 - q1) / 1.349) // Consistency factor for normal distribution
620            }
621            ScaleEstimator::Qn => {
622                // Simplified Qn estimator (first quartile of pairwise distances)
623                let mut pairwise_distances = Vec::new();
624                for i in 0..y.len() {
625                    for j in (i + 1)..y.len() {
626                        pairwise_distances.push((y[i] - y[j]).abs());
627                    }
628                }
629
630                if pairwise_distances.is_empty() {
631                    return Ok(0.0);
632                }
633
634                pairwise_distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
635                let q1_idx = pairwise_distances.len() / 4;
636                Ok(pairwise_distances[q1_idx] * 2.2219) // Consistency factor
637            }
638            ScaleEstimator::Sn => {
639                // Simplified Sn estimator (median of medians of distances)
640                let mut medians = Vec::new();
641
642                for i in 0..y.len() {
643                    let mut distances: Vec<Float> = y
644                        .iter()
645                        .enumerate()
646                        .filter(|(j, _)| *j != i)
647                        .map(|(_, &yj)| (y[i] - yj).abs())
648                        .collect();
649
650                    if !distances.is_empty() {
651                        distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
652                        let median_dist = if distances.len() % 2 == 0 {
653                            (distances[distances.len() / 2 - 1] + distances[distances.len() / 2])
654                                / 2.0
655                        } else {
656                            distances[distances.len() / 2]
657                        };
658                        medians.push(median_dist);
659                    }
660                }
661
662                if medians.is_empty() {
663                    return Ok(0.0);
664                }
665
666                medians.sort_by(|a, b| a.partial_cmp(b).unwrap());
667                let result = if medians.len() % 2 == 0 {
668                    (medians[medians.len() / 2 - 1] + medians[medians.len() / 2]) / 2.0
669                } else {
670                    medians[medians.len() / 2]
671                };
672
673                Ok(result * 1.1926) // Consistency factor
674            }
675        }
676    }
677}
678
679impl Predict<Features, Array1<Float>> for RobustDummyRegressor<sklears_core::traits::Trained> {
680    fn predict(&self, x: &Features) -> Result<Array1<Float>> {
681        if x.is_empty() {
682            return Err(sklears_core::error::SklearsError::InvalidInput(
683                "Input cannot be empty".to_string(),
684            ));
685        }
686
687        let n_samples = x.nrows();
688        let mut predictions = Array1::zeros(n_samples);
689        let location = self.robust_location_.unwrap_or(0.0);
690        let scale = self.robust_scale_.unwrap_or(1.0);
691
692        let mut rng = if let Some(seed) = self.random_state {
693            StdRng::seed_from_u64(seed)
694        } else {
695            StdRng::seed_from_u64(0)
696        };
697
698        match &self.strategy {
699            RobustStrategy::OutlierResistant { .. }
700            | RobustStrategy::TrimmedMean { .. }
701            | RobustStrategy::RobustScale { .. }
702            | RobustStrategy::BreakdownPoint { .. } => {
703                // For most strategies, predict the robust location estimate
704                predictions.fill(location);
705            }
706            RobustStrategy::InfluenceResistant { .. } => {
707                // For influence-resistant methods, we can add some controlled noise
708                // based on the robust scale estimate
709                if scale > 0.0 {
710                    let normal = Normal::new(location, scale * 0.1).unwrap();
711                    for i in 0..n_samples {
712                        predictions[i] = normal.sample(&mut rng);
713                    }
714                } else {
715                    predictions.fill(location);
716                }
717            }
718        }
719
720        Ok(predictions)
721    }
722}
723
/// Robust dummy classifier
///
/// The `State` type parameter marks whether the value is unfitted (the default,
/// `Untrained`) or the result of a successful fit (`Trained`).
#[derive(Debug, Clone)]
pub struct RobustDummyClassifier<State = sklears_core::traits::Untrained> {
    /// Strategy for robust predictions
    pub strategy: RobustStrategy,
    /// Random state for reproducible output
    pub random_state: Option<u64>,

    // Fitted parameters
    // Class probabilities estimated while fitting, aligned index-by-index with `classes_`.
    pub(crate) robust_class_probs_: Option<Array1<Float>>,
    // Sorted, de-duplicated class labels seen while fitting.
    pub(crate) classes_: Option<Array1<i32>>,
    // Per-sample mask from fitting; `true` marks a sample treated as an outlier.
    pub(crate) outlier_mask_: Option<Array1<bool>>,

    /// Phantom data for state
    pub(crate) _state: std::marker::PhantomData<State>,
}
740
741impl RobustDummyClassifier {
742    /// Create a new robust dummy classifier
743    pub fn new(strategy: RobustStrategy) -> Self {
744        Self {
745            strategy,
746            random_state: None,
747            robust_class_probs_: None,
748            classes_: None,
749            outlier_mask_: None,
750            _state: std::marker::PhantomData,
751        }
752    }
753
754    /// Set the random state for reproducible output
755    pub fn with_random_state(mut self, random_state: u64) -> Self {
756        self.random_state = Some(random_state);
757        self
758    }
759
760    /// Get the outlier mask (available after fitting)
761    pub fn outlier_mask(&self) -> Option<&Array1<bool>> {
762        self.outlier_mask_.as_ref()
763    }
764}
765
766impl Default for RobustDummyClassifier {
767    fn default() -> Self {
768        Self::new(RobustStrategy::OutlierResistant {
769            contamination: 0.1,
770            detection_method: OutlierDetectionMethod::IQR { multiplier: 1.5 },
771        })
772    }
773}
774
775impl Estimator for RobustDummyClassifier {
776    type Config = ();
777    type Error = sklears_core::error::SklearsError;
778    type Float = Float;
779
780    fn config(&self) -> &Self::Config {
781        &()
782    }
783}
784
785impl Fit<Features, Array1<i32>> for RobustDummyClassifier {
786    type Fitted = RobustDummyClassifier<sklears_core::traits::Trained>;
787
788    fn fit(self, x: &Features, y: &Array1<i32>) -> Result<Self::Fitted> {
789        if x.is_empty() || y.is_empty() {
790            return Err(sklears_core::error::SklearsError::InvalidInput(
791                "Input cannot be empty".to_string(),
792            ));
793        }
794
795        if x.nrows() != y.len() {
796            return Err(sklears_core::error::SklearsError::InvalidInput(
797                "Number of samples in X and y must be equal".to_string(),
798            ));
799        }
800
801        // Get unique classes
802        let mut unique_classes = y.iter().cloned().collect::<Vec<_>>();
803        unique_classes.sort_unstable();
804        unique_classes.dedup();
805        let classes = Array1::from_vec(unique_classes);
806
807        // For classification, we'll use a simplified approach:
808        // Remove outliers and compute robust class probabilities
809        let mut fitted = RobustDummyClassifier {
810            strategy: self.strategy.clone(),
811            random_state: self.random_state,
812            robust_class_probs_: None,
813            classes_: Some(classes.clone()),
814            outlier_mask_: None,
815            _state: std::marker::PhantomData,
816        };
817
818        // Simple outlier detection based on class frequency
819        let mut class_counts = std::collections::HashMap::new();
820        for &class in y.iter() {
821            *class_counts.entry(class).or_insert(0) += 1;
822        }
823
824        // Identify outlier classes (those with very low frequency)
825        let total_samples = y.len() as f64;
826        let min_frequency = 0.05; // Classes with < 5% frequency are considered outliers
827        let mut outlier_mask = Array1::from_elem(y.len(), false);
828
829        for (i, &class) in y.iter().enumerate() {
830            let class_freq = *class_counts.get(&class).unwrap() as f64 / total_samples;
831            if class_freq < min_frequency {
832                outlier_mask[i] = true;
833            }
834        }
835
836        // Compute robust class probabilities excluding outliers
837        let mut robust_class_counts = std::collections::HashMap::new();
838        let mut total_clean_samples = 0;
839
840        for (i, &class) in y.iter().enumerate() {
841            if !outlier_mask[i] {
842                *robust_class_counts.entry(class).or_insert(0) += 1;
843                total_clean_samples += 1;
844            }
845        }
846
847        let mut class_probs = Array1::zeros(classes.len());
848        for (i, &class) in classes.iter().enumerate() {
849            let count = *robust_class_counts.get(&class).unwrap_or(&0);
850            class_probs[i] = if total_clean_samples > 0 {
851                count as Float / total_clean_samples as Float
852            } else {
853                1.0 / classes.len() as Float
854            };
855        }
856
857        fitted.robust_class_probs_ = Some(class_probs);
858        fitted.outlier_mask_ = Some(outlier_mask);
859
860        Ok(fitted)
861    }
862}
863
864impl RobustDummyClassifier<sklears_core::traits::Trained> {
865    /// Get the outlier mask (available after fitting)
866    pub fn outlier_mask(&self) -> Option<&Array1<bool>> {
867        self.outlier_mask_.as_ref()
868    }
869}
870
871impl Predict<Features, Array1<i32>> for RobustDummyClassifier<sklears_core::traits::Trained> {
872    fn predict(&self, x: &Features) -> Result<Array1<i32>> {
873        if x.is_empty() {
874            return Err(sklears_core::error::SklearsError::InvalidInput(
875                "Input cannot be empty".to_string(),
876            ));
877        }
878
879        let n_samples = x.nrows();
880        let mut predictions = Array1::zeros(n_samples);
881
882        let classes = self.classes_.as_ref().unwrap();
883        let class_probs = self.robust_class_probs_.as_ref().unwrap();
884
885        let mut rng = if let Some(seed) = self.random_state {
886            StdRng::seed_from_u64(seed)
887        } else {
888            StdRng::seed_from_u64(0)
889        };
890
891        // Sample from robust class distribution
892        for i in 0..n_samples {
893            let rand_val: Float = rng.gen();
894            let mut cumulative_prob = 0.0;
895            let mut selected_class = classes[0];
896
897            for (j, &class) in classes.iter().enumerate() {
898                cumulative_prob += class_probs[j];
899                if rand_val <= cumulative_prob {
900                    selected_class = class;
901                    break;
902                }
903            }
904            predictions[i] = selected_class;
905        }
906
907        Ok(predictions)
908    }
909}
910
911#[allow(non_snake_case)]
912#[cfg(test)]
913mod tests {
914    use super::*;
915    use approx::assert_abs_diff_eq;
916    use scirs2_core::ndarray::{array, Array2};
917
918    #[test]
919    fn test_outlier_resistant_regressor() {
920        let x = Array2::from_shape_vec(
921            (10, 2),
922            vec![
923                1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
924                100.0, 101.0, 102.0, 103.0, // Last two rows are outliers
925            ],
926        )
927        .unwrap();
928        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 100.0, 101.0]; // Last two are outliers
929
930        let regressor = RobustDummyRegressor::new(RobustStrategy::OutlierResistant {
931            contamination: 0.2,
932            detection_method: OutlierDetectionMethod::IQR { multiplier: 1.5 },
933        });
934
935        let fitted = regressor.fit(&x, &y).unwrap();
936        let predictions = fitted.predict(&x).unwrap();
937
938        assert_eq!(predictions.len(), 10);
939
940        // Check that some outliers were detected
941        let outlier_mask = fitted.outlier_mask().unwrap();
942        let n_outliers = outlier_mask.iter().filter(|&&x| x).count();
943        assert!(n_outliers > 0);
944    }
945
946    #[test]
947    fn test_trimmed_mean_regressor() {
948        let x = Array2::from_shape_vec((10, 2), (0..20).map(|x| x as f64).collect()).unwrap();
949        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 100.0, 101.0];
950
951        let regressor = RobustDummyRegressor::new(RobustStrategy::TrimmedMean {
952            trim_proportion: 0.2,
953        });
954
955        let fitted = regressor.fit(&x, &y).unwrap();
956        let predictions = fitted.predict(&x).unwrap();
957
958        assert_eq!(predictions.len(), 10);
959
960        // Trimmed mean should be more robust than regular mean
961        let robust_mean = predictions[0];
962        let regular_mean = y.mean().unwrap();
963        assert!(robust_mean < regular_mean); // Should be less affected by outliers
964    }
965
966    #[test]
967    fn test_robust_scale_regressor() {
968        let x = Array2::from_shape_vec((8, 2), (0..16).map(|x| x as f64).collect()).unwrap();
969        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 100.0]; // Last value is outlier
970
971        let regressor = RobustDummyRegressor::new(RobustStrategy::RobustScale {
972            scale_estimator: ScaleEstimator::MAD,
973            location_estimator: LocationEstimator::Median,
974        });
975
976        let fitted = regressor.fit(&x, &y).unwrap();
977        let predictions = fitted.predict(&x).unwrap();
978
979        assert_eq!(predictions.len(), 8);
980
981        // Should predict median value
982        let mut sorted_y = y.to_vec();
983        sorted_y.sort_by(|a, b| a.partial_cmp(b).unwrap());
984        let expected_median = (sorted_y[3] + sorted_y[4]) / 2.0; // 4.5
985        assert_abs_diff_eq!(predictions[0], expected_median, epsilon = 0.1);
986    }
987
988    #[test]
989    fn test_breakdown_point_regressor() {
990        let x = Array2::from_shape_vec((10, 2), (0..20).map(|x| x as f64).collect()).unwrap();
991        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
992
993        let regressor = RobustDummyRegressor::new(RobustStrategy::BreakdownPoint {
994            breakdown_point: 0.3,
995        });
996
997        let fitted = regressor.fit(&x, &y).unwrap();
998        let predictions = fitted.predict(&x).unwrap();
999
1000        assert_eq!(predictions.len(), 10);
1001        assert_eq!(fitted.breakdown_point().unwrap(), 0.3);
1002    }
1003
1004    #[test]
1005    fn test_influence_resistant_regressor() {
1006        let x = Array2::from_shape_vec((8, 2), (0..16).map(|x| x as f64).collect()).unwrap();
1007        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 100.0]; // Last value is outlier
1008
1009        let regressor = RobustDummyRegressor::new(RobustStrategy::InfluenceResistant {
1010            huber_delta: 1.345,
1011            max_iter: 50,
1012            tolerance: 1e-6,
1013        });
1014
1015        let fitted = regressor.fit(&x, &y).unwrap();
1016        let predictions = fitted.predict(&x).unwrap();
1017
1018        assert_eq!(predictions.len(), 8);
1019
1020        // Should have M-estimator weights
1021        let weights = fitted.m_weights().unwrap();
1022        assert_eq!(weights.len(), 8);
1023
1024        // Outlier should have lower weight
1025        assert!(weights[7] < weights[0]);
1026    }
1027
1028    #[test]
1029    fn test_robust_classifier() {
1030        let x = Array2::from_shape_vec((10, 2), (0..20).map(|x| x as f64).collect()).unwrap();
1031        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 3, 3]; // Class 3 is less frequent
1032
1033        let classifier = RobustDummyClassifier::new(RobustStrategy::OutlierResistant {
1034            contamination: 0.1,
1035            detection_method: OutlierDetectionMethod::IQR { multiplier: 1.5 },
1036        })
1037        .with_random_state(42);
1038
1039        let fitted = classifier.fit(&x, &y).unwrap();
1040        let predictions = fitted.predict(&x).unwrap();
1041
1042        assert_eq!(predictions.len(), 10);
1043
1044        // Check that predictions are valid classes
1045        let classes = fitted.classes_.as_ref().unwrap();
1046        for &pred in predictions.iter() {
1047            assert!(classes.iter().any(|&c| c == pred));
1048        }
1049    }
1050}