1use crate::imputation::OutlierAwareImputer;
18use crate::outlier_detection::{OutlierDetectionMethod, OutlierDetector};
19use crate::outlier_transformation::{OutlierTransformationMethod, OutlierTransformer};
20use crate::scaling::RobustScaler;
21use scirs2_core::ndarray::Array2;
22use sklears_core::{
23 error::{Result, SklearsError},
24 traits::{Fit, Trained, Transform, Untrained},
25 types::Float,
26};
27use std::marker::PhantomData;
28
/// High-level preset describing how aggressively outliers are handled.
///
/// Each variant corresponds to one of the `RobustPreprocessorConfig` preset
/// constructors and is echoed in the preprocessing report.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RobustStrategy {
    /// Fixed threshold of 3.0, 5% contamination, outlier transformation disabled.
    Conservative,
    /// Default preset: adaptive thresholds, 10% contamination, `Log1p` transformation.
    Moderate,
    /// Adaptive thresholds, 15% contamination, `BoxCox` transformation.
    Aggressive,
    /// Default settings tagged as user-customised.
    Custom,
}
41
/// Configuration for [`RobustPreprocessor`].
///
/// Obtained from [`Default`] or one of the preset constructors (`conservative`,
/// `moderate`, `aggressive`, `custom`) and typically adjusted through the builder
/// methods on `RobustPreprocessor<Untrained>`.
#[derive(Debug, Clone)]
pub struct RobustPreprocessorConfig {
    /// Preset this configuration was derived from; printed by `preprocessing_report`.
    pub strategy: RobustStrategy,
    /// Whether `fit` fits an outlier detector on the (imputed) data.
    pub enable_outlier_detection: bool,
    /// Whether `fit` fits and applies an outlier transformer.
    pub enable_outlier_transformation: bool,
    /// Whether `fit` fills NaN entries with a per-column median before the other stages.
    pub enable_outlier_imputation: bool,
    /// Whether robust scaling is requested.
    ///
    /// NOTE(review): `fit` does not currently fit or store a scaler, so this flag
    /// has no effect on `transform`.
    pub enable_robust_scaling: bool,
    /// Fixed detection threshold, used only when `adaptive_thresholds` is `false`;
    /// `fit` falls back to 2.5 when this is `None`.
    pub outlier_threshold: Option<Float>,
    /// Detection algorithm handed to the outlier detector.
    pub detection_method: OutlierDetectionMethod,
    /// Transformation handed to the outlier transformer.
    pub transformation_method: OutlierTransformationMethod,
    /// Expected share of outliers (e.g. `0.1`); drives the adaptive threshold.
    pub contamination_rate: Float,
    /// When `true`, the detection threshold is derived from `contamination_rate`
    /// instead of `outlier_threshold`.
    pub adaptive_thresholds: bool,
    /// Quantile range intended for robust scaling (default `(25.0, 75.0)`).
    /// NOTE(review): not read by `fit` in this module.
    pub quantile_range: (Float, Float),
    /// Whether robust scaling should center the data.
    /// NOTE(review): not read by `fit` in this module.
    pub with_centering: bool,
    /// Whether robust scaling should scale the data.
    /// NOTE(review): not read by `fit` in this module.
    pub with_scaling: bool,
    /// Parallel-execution flag.
    /// NOTE(review): not read anywhere in this module.
    pub parallel: bool,
}
74
75impl Default for RobustPreprocessorConfig {
76 fn default() -> Self {
77 Self {
78 strategy: RobustStrategy::Moderate,
79 enable_outlier_detection: true,
80 enable_outlier_transformation: true,
81 enable_outlier_imputation: true,
82 enable_robust_scaling: true,
83 outlier_threshold: None, detection_method: OutlierDetectionMethod::MahalanobisDistance,
85 transformation_method: OutlierTransformationMethod::Log1p,
86 contamination_rate: 0.1,
87 adaptive_thresholds: true,
88 quantile_range: (25.0, 75.0),
89 with_centering: true,
90 with_scaling: true,
91 parallel: true,
92 }
93 }
94}
95
96impl RobustPreprocessorConfig {
97 pub fn conservative() -> Self {
99 Self {
100 strategy: RobustStrategy::Conservative,
101 outlier_threshold: Some(3.0),
102 contamination_rate: 0.05,
103 adaptive_thresholds: false,
104 enable_outlier_transformation: false,
105 transformation_method: OutlierTransformationMethod::RobustScale,
106 ..Self::default()
107 }
108 }
109
110 pub fn moderate() -> Self {
112 Self {
113 strategy: RobustStrategy::Moderate,
114 outlier_threshold: Some(2.5),
115 contamination_rate: 0.1,
116 adaptive_thresholds: true,
117 transformation_method: OutlierTransformationMethod::Log1p,
118 ..Self::default()
119 }
120 }
121
122 pub fn aggressive() -> Self {
124 Self {
125 strategy: RobustStrategy::Aggressive,
126 outlier_threshold: Some(2.0),
127 contamination_rate: 0.15,
128 adaptive_thresholds: true,
129 transformation_method: OutlierTransformationMethod::BoxCox,
130 ..Self::default()
131 }
132 }
133
134 pub fn custom() -> Self {
136 Self {
137 strategy: RobustStrategy::Custom,
138 adaptive_thresholds: true,
139 ..Self::default()
140 }
141 }
142}
143
/// Outlier-robust preprocessing pipeline (imputation, outlier detection, outlier
/// transformation, robust scaling).
///
/// Uses the typestate pattern: a `RobustPreprocessor<Untrained>` is configured and
/// passed to `fit`, which yields a `RobustPreprocessor<Trained>` that can
/// `transform` data.
#[derive(Debug, Clone)]
pub struct RobustPreprocessor<State = Untrained> {
    /// Pipeline configuration.
    config: RobustPreprocessorConfig,
    /// Zero-sized marker for the `Untrained` / `Trained` state.
    state: PhantomData<State>,
    /// Detector fitted by `fit` when outlier detection is enabled.
    outlier_detector_: Option<OutlierDetector<Trained>>,
    /// Transformer fitted by `fit` when outlier transformation is enabled;
    /// applied by `transform`.
    outlier_transformer_: Option<OutlierTransformer<Trained>>,
    /// Imputer applied by `transform` when present.
    /// NOTE(review): `fit` never populates this, so it is currently always `None`.
    outlier_imputer_: Option<OutlierAwareImputer>,
    /// Scaler applied by `transform` when present.
    /// NOTE(review): `fit` never populates this, so it is currently always `None`.
    robust_scaler_: Option<RobustScaler>,
    /// Summary statistics gathered during `fit`.
    preprocessing_stats_: Option<RobustPreprocessingStats>,
    /// Number of columns seen by `fit`; `transform` rejects inputs of other widths.
    n_features_in_: Option<usize>,
}
158
/// Summary statistics collected by `fit`.
///
/// Exposed through `preprocessing_stats` on the trained preprocessor and rendered by
/// `preprocessing_report`.
#[derive(Debug, Clone)]
pub struct RobustPreprocessingStats {
    /// Outlier count per feature.
    /// NOTE(review): `fit` currently stores the detector's dataset-wide count for every feature.
    pub outliers_per_feature: Vec<usize>,
    /// Outlier share per feature, on the 0–100 percentage scale assumed by the
    /// robustness score, the report, and the recommendations. Currently the same
    /// dataset-wide value for every feature.
    pub outlier_percentages: Vec<Float>,
    /// Detection threshold handed to the detector, one entry per feature
    /// (all entries equal; zeros when detection is disabled).
    pub adaptive_thresholds: Vec<Float>,
    /// Overall robustness score, clamped to `[0, 1]`.
    pub robustness_score: Float,
    /// NaN bookkeeping for the imputation step.
    pub missing_stats: MissingValueStats,
    /// Distribution-shape changes produced by the outlier transformation.
    pub transformation_stats: TransformationStats,
    /// Weighted quality-improvement score, clamped to `[0, 1]`.
    pub quality_improvement: Float,
}
177
/// NaN counts around the imputation step of `fit`.
#[derive(Debug, Clone)]
pub struct MissingValueStats {
    /// Number of NaN entries in the data passed to `fit`.
    pub missing_before: usize,
    /// Number of NaN entries after imputation (see `fit` for when it is recomputed).
    pub missing_after: usize,
    /// Fraction of missing entries that were filled, in `[0, 1]`; it multiplies
    /// directly into the robustness score.
    pub imputation_success_rate: Float,
}
185
/// Per-feature changes in distribution shape caused by the outlier transformation.
///
/// All entries stay `0.0` when the transformation stage is disabled.
#[derive(Debug, Clone)]
pub struct TransformationStats {
    /// `max(0, |skew_before| - |skew_after|)` for each feature.
    pub skewness_reduction: Vec<Float>,
    /// `max(0, |excess_kurtosis_before| - |excess_kurtosis_after|)` for each feature.
    pub kurtosis_reduction: Vec<Float>,
    /// Per-feature normality measure.
    /// NOTE(review): `fit` only zero-fills this; no computation populates it yet.
    pub normality_improvement: Vec<Float>,
}
196
197impl RobustPreprocessor<Untrained> {
198 pub fn new() -> Self {
200 Self {
201 config: RobustPreprocessorConfig::default(),
202 state: PhantomData,
203 outlier_detector_: None,
204 outlier_transformer_: None,
205 outlier_imputer_: None,
206 robust_scaler_: None,
207 preprocessing_stats_: None,
208 n_features_in_: None,
209 }
210 }
211
212 pub fn conservative() -> Self {
214 Self::new().config(RobustPreprocessorConfig::conservative())
215 }
216
217 pub fn moderate() -> Self {
219 Self::new().config(RobustPreprocessorConfig::moderate())
220 }
221
222 pub fn aggressive() -> Self {
224 Self::new().config(RobustPreprocessorConfig::aggressive())
225 }
226
227 pub fn custom() -> Self {
229 Self::new().config(RobustPreprocessorConfig::custom())
230 }
231
232 pub fn config(mut self, config: RobustPreprocessorConfig) -> Self {
234 self.config = config;
235 self
236 }
237
238 pub fn outlier_detection(mut self, enable: bool) -> Self {
240 self.config.enable_outlier_detection = enable;
241 self
242 }
243
244 pub fn outlier_transformation(mut self, enable: bool) -> Self {
246 self.config.enable_outlier_transformation = enable;
247 self
248 }
249
250 pub fn outlier_imputation(mut self, enable: bool) -> Self {
252 self.config.enable_outlier_imputation = enable;
253 self
254 }
255
256 pub fn robust_scaling(mut self, enable: bool) -> Self {
258 self.config.enable_robust_scaling = enable;
259 self
260 }
261
262 pub fn detection_method(mut self, method: OutlierDetectionMethod) -> Self {
264 self.config.detection_method = method;
265 self
266 }
267
268 pub fn transformation_method(mut self, method: OutlierTransformationMethod) -> Self {
270 self.config.transformation_method = method;
271 self
272 }
273
274 pub fn outlier_threshold(mut self, threshold: Float) -> Self {
276 self.config.outlier_threshold = Some(threshold);
277 self.config.adaptive_thresholds = false;
278 self
279 }
280
281 pub fn adaptive_thresholds(mut self, enable: bool) -> Self {
283 self.config.adaptive_thresholds = enable;
284 if enable {
285 self.config.outlier_threshold = None;
286 }
287 self
288 }
289
290 pub fn contamination_rate(mut self, rate: Float) -> Self {
292 self.config.contamination_rate = rate;
293 self
294 }
295
296 pub fn quantile_range(mut self, range: (Float, Float)) -> Self {
298 self.config.quantile_range = range;
299 self
300 }
301
302 pub fn with_centering(mut self, center: bool) -> Self {
304 self.config.with_centering = center;
305 self
306 }
307
308 pub fn with_scaling(mut self, scale: bool) -> Self {
310 self.config.with_scaling = scale;
311 self
312 }
313
314 pub fn parallel(mut self, enable: bool) -> Self {
316 self.config.parallel = enable;
317 self
318 }
319}
320
321impl Fit<Array2<Float>, ()> for RobustPreprocessor<Untrained> {
322 type Fitted = RobustPreprocessor<Trained>;
323
324 fn fit(mut self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
325 let (n_samples, n_features) = x.dim();
326
327 if n_samples == 0 || n_features == 0 {
328 return Err(SklearsError::InvalidInput(
329 "Input array is empty".to_string(),
330 ));
331 }
332
333 self.n_features_in_ = Some(n_features);
334
335 let mut stats = RobustPreprocessingStats {
337 outliers_per_feature: vec![0; n_features],
338 outlier_percentages: vec![0.0; n_features],
339 adaptive_thresholds: vec![0.0; n_features],
340 robustness_score: 0.0,
341 missing_stats: MissingValueStats {
342 missing_before: 0,
343 missing_after: 0,
344 imputation_success_rate: 0.0,
345 },
346 transformation_stats: TransformationStats {
347 skewness_reduction: vec![0.0; n_features],
348 kurtosis_reduction: vec![0.0; n_features],
349 normality_improvement: vec![0.0; n_features],
350 },
351 quality_improvement: 0.0,
352 };
353
354 stats.missing_stats.missing_before = x.iter().filter(|x| x.is_nan()).count();
356
357 let mut current_data = x.clone();
358
359 if self.config.enable_outlier_imputation {
361 let threshold = self.get_adaptive_threshold(¤t_data, 0.5)?;
362
363 let _imputer = OutlierAwareImputer::exclude_outliers(threshold, "mad")?
364 .base_strategy(crate::imputation::ImputationStrategy::Median);
365
366 for j in 0..current_data.ncols() {
369 let mut column: Vec<Float> = current_data.column(j).to_vec();
370 column.retain(|x| !x.is_nan()); if !column.is_empty() {
372 column.sort_by(|a, b| a.partial_cmp(b).unwrap());
373 let median = column[column.len() / 2];
374
375 for i in 0..current_data.nrows() {
377 if current_data[[i, j]].is_nan() {
378 current_data[[i, j]] = median;
379 }
380 }
381 }
382 }
383
384 stats.missing_stats.missing_after = current_data.iter().filter(|x| x.is_nan()).count();
388 stats.missing_stats.imputation_success_rate = 1.0
389 - (stats.missing_stats.missing_after as Float
390 / stats.missing_stats.missing_before.max(1) as Float);
391 }
392
393 if self.config.enable_outlier_detection {
395 let threshold = if self.config.adaptive_thresholds {
396 self.get_adaptive_threshold(¤t_data, self.config.contamination_rate)?
397 } else {
398 self.config.outlier_threshold.unwrap_or(2.5)
399 };
400
401 let detector = OutlierDetector::new()
402 .method(self.config.detection_method)
403 .threshold(threshold);
404
405 let fitted_detector = detector.fit(¤t_data, &())?;
406
407 let outlier_result = fitted_detector.detect_outliers(¤t_data)?;
409 stats.outliers_per_feature = vec![outlier_result.summary.n_outliers; n_features]; stats.outlier_percentages = vec![outlier_result.summary.outlier_fraction; n_features]; stats.adaptive_thresholds = vec![threshold; n_features];
413
414 self.outlier_detector_ = Some(fitted_detector);
415 }
416
417 if self.config.enable_outlier_transformation {
419 let transformer = OutlierTransformer::new()
420 .method(self.config.transformation_method)
421 .handle_negatives(true)
422 .feature_wise(true);
423
424 let fitted_transformer = transformer.fit(¤t_data, &())?;
425
426 let original_stats = self.compute_distribution_stats(¤t_data);
428
429 current_data = fitted_transformer.transform(¤t_data)?;
430
431 let transformed_stats = self.compute_distribution_stats(¤t_data);
433 stats.transformation_stats.skewness_reduction = original_stats
434 .iter()
435 .zip(transformed_stats.iter())
436 .map(|((orig_skew, _), (trans_skew, _))| {
437 (orig_skew.abs() - trans_skew.abs()).max(0.0)
438 })
439 .collect();
440
441 stats.transformation_stats.kurtosis_reduction = original_stats
442 .iter()
443 .zip(transformed_stats.iter())
444 .map(|((_, orig_kurt), (_, trans_kurt))| {
445 (orig_kurt.abs() - trans_kurt.abs()).max(0.0)
446 })
447 .collect();
448
449 self.outlier_transformer_ = Some(fitted_transformer);
450 }
451
452 if self.config.enable_robust_scaling {
454 let _scaler = RobustScaler::new();
455 }
462
463 stats.robustness_score = self.compute_robustness_score(&stats);
465
466 stats.quality_improvement = self.compute_quality_improvement(&stats);
468
469 self.preprocessing_stats_ = Some(stats);
470
471 Ok(RobustPreprocessor {
472 config: self.config,
473 state: PhantomData,
474 outlier_detector_: self.outlier_detector_,
475 outlier_transformer_: self.outlier_transformer_,
476 outlier_imputer_: self.outlier_imputer_,
477 robust_scaler_: self.robust_scaler_,
478 preprocessing_stats_: self.preprocessing_stats_,
479 n_features_in_: self.n_features_in_,
480 })
481 }
482}
483
484impl RobustPreprocessor<Untrained> {
485 fn get_adaptive_threshold(
487 &self,
488 data: &Array2<Float>,
489 contamination_rate: Float,
490 ) -> Result<Float> {
491 let valid_values: Vec<Float> = data.iter().filter(|x| x.is_finite()).copied().collect();
492
493 if valid_values.is_empty() {
494 return Ok(2.5); }
496
497 let mut sorted_values = valid_values.clone();
499 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
500
501 let median = if sorted_values.len() % 2 == 0 {
502 let mid = sorted_values.len() / 2;
503 (sorted_values[mid - 1] + sorted_values[mid]) / 2.0
504 } else {
505 sorted_values[sorted_values.len() / 2]
506 };
507
508 let deviations: Vec<Float> = valid_values.iter().map(|x| (x - median).abs()).collect();
510 let mut sorted_deviations = deviations;
511 sorted_deviations.sort_by(|a, b| a.partial_cmp(b).unwrap());
512
513 let _mad = if sorted_deviations.len() % 2 == 0 {
514 let mid = sorted_deviations.len() / 2;
515 (sorted_deviations[mid - 1] + sorted_deviations[mid]) / 2.0
516 } else {
517 sorted_deviations[sorted_deviations.len() / 2]
518 };
519
520 let base_threshold = 2.5;
523 let adaptation_factor = 1.0 - contamination_rate;
524 let threshold = base_threshold * adaptation_factor + 1.5 * contamination_rate;
525
526 Ok(threshold.clamp(1.5, 4.0)) }
528
529 fn compute_distribution_stats(&self, data: &Array2<Float>) -> Vec<(Float, Float)> {
531 (0..data.ncols())
532 .map(|j| {
533 let column = data.column(j);
534 let valid_values: Vec<Float> =
535 column.iter().filter(|x| x.is_finite()).copied().collect();
536
537 if valid_values.len() < 3 {
538 return (0.0, 0.0);
539 }
540
541 let mean = valid_values.iter().sum::<Float>() / valid_values.len() as Float;
542 let variance = valid_values
543 .iter()
544 .map(|x| (x - mean).powi(2))
545 .sum::<Float>()
546 / valid_values.len() as Float;
547 let std = variance.sqrt();
548
549 if std == 0.0 {
550 return (0.0, 0.0);
551 }
552
553 let skewness = valid_values
555 .iter()
556 .map(|x| ((x - mean) / std).powi(3))
557 .sum::<Float>()
558 / valid_values.len() as Float;
559
560 let kurtosis = valid_values
562 .iter()
563 .map(|x| ((x - mean) / std).powi(4))
564 .sum::<Float>()
565 / valid_values.len() as Float
566 - 3.0; (skewness, kurtosis)
569 })
570 .collect()
571 }
572
573 fn compute_robustness_score(&self, stats: &RobustPreprocessingStats) -> Float {
575 let mut score = 1.0;
576
577 let avg_outlier_rate = stats.outlier_percentages.iter().sum::<Float>()
579 / stats.outlier_percentages.len() as Float;
580 score *= (1.0 - avg_outlier_rate / 100.0).max(0.1);
581
582 score *= stats.missing_stats.imputation_success_rate;
584
585 let avg_skewness_reduction = stats
587 .transformation_stats
588 .skewness_reduction
589 .iter()
590 .sum::<Float>()
591 / stats.transformation_stats.skewness_reduction.len() as Float;
592 score *= (1.0 + avg_skewness_reduction / 10.0).min(1.5);
593
594 score.clamp(0.0, 1.0)
595 }
596
597 fn compute_quality_improvement(&self, stats: &RobustPreprocessingStats) -> Float {
599 let imputation_improvement = stats.missing_stats.imputation_success_rate * 0.3;
600 let outlier_improvement = (1.0
601 - stats.outlier_percentages.iter().sum::<Float>()
602 / (stats.outlier_percentages.len() as Float * 100.0))
603 * 0.4;
604 let transformation_improvement = (stats
605 .transformation_stats
606 .skewness_reduction
607 .iter()
608 .sum::<Float>()
609 / stats.transformation_stats.skewness_reduction.len() as Float)
610 * 0.3;
611
612 (imputation_improvement + outlier_improvement + transformation_improvement).clamp(0.0, 1.0)
613 }
614}
615
616impl Transform<Array2<Float>, Array2<Float>> for RobustPreprocessor<Trained> {
617 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
618 let (_n_samples, n_features) = x.dim();
619
620 if n_features != self.n_features_in().unwrap() {
621 return Err(SklearsError::FeatureMismatch {
622 expected: self.n_features_in().unwrap(),
623 actual: n_features,
624 });
625 }
626
627 let mut result = x.clone();
628
629 if let Some(ref imputer) = self.outlier_imputer_ {
633 result = imputer.transform(&result)?;
634 }
635
636 if let Some(ref transformer) = self.outlier_transformer_ {
638 result = transformer.transform(&result)?;
639 }
640
641 if let Some(ref scaler) = self.robust_scaler_ {
643 result = scaler.transform(&result)?;
644 }
645
646 Ok(result)
647 }
648}
649
650impl RobustPreprocessor<Trained> {
651 pub fn n_features_in(&self) -> Option<usize> {
653 self.n_features_in_
654 }
655
656 pub fn preprocessing_stats(&self) -> Option<&RobustPreprocessingStats> {
658 self.preprocessing_stats_.as_ref()
659 }
660
661 pub fn outlier_detector(&self) -> Option<&OutlierDetector<Trained>> {
663 self.outlier_detector_.as_ref()
664 }
665
666 pub fn outlier_transformer(&self) -> Option<&OutlierTransformer<Trained>> {
668 self.outlier_transformer_.as_ref()
669 }
670
671 pub fn outlier_imputer(&self) -> Option<&OutlierAwareImputer> {
673 self.outlier_imputer_.as_ref()
674 }
675
676 pub fn robust_scaler(&self) -> Option<&RobustScaler> {
678 self.robust_scaler_.as_ref()
679 }
680
681 pub fn preprocessing_report(&self) -> Result<String> {
683 let stats = self.preprocessing_stats_.as_ref().ok_or_else(|| {
684 SklearsError::InvalidInput("No preprocessing statistics available".to_string())
685 })?;
686
687 let mut report = String::new();
688
689 report.push_str("=== Robust Preprocessing Report ===\n\n");
690
691 report.push_str(&format!(
693 "Robustness Score: {:.3}\n",
694 stats.robustness_score
695 ));
696 report.push_str(&format!(
697 "Quality Improvement: {:.3}\n",
698 stats.quality_improvement
699 ));
700 report.push('\n');
701
702 report.push_str("=== Missing Value Handling ===\n");
704 report.push_str(&format!(
705 "Missing values before: {}\n",
706 stats.missing_stats.missing_before
707 ));
708 report.push_str(&format!(
709 "Missing values after: {}\n",
710 stats.missing_stats.missing_after
711 ));
712 report.push_str(&format!(
713 "Imputation success rate: {:.1}%\n",
714 stats.missing_stats.imputation_success_rate * 100.0
715 ));
716 report.push('\n');
717
718 if !stats.outliers_per_feature.is_empty() {
720 report.push_str("=== Outlier Detection ===\n");
721 for (i, (&count, &percentage)) in stats
722 .outliers_per_feature
723 .iter()
724 .zip(stats.outlier_percentages.iter())
725 .enumerate()
726 {
727 report.push_str(&format!(
728 "Feature {}: {} outliers ({:.1}%)\n",
729 i, count, percentage
730 ));
731 }
732 report.push('\n');
733 }
734
735 if !stats.transformation_stats.skewness_reduction.is_empty() {
737 report.push_str("=== Transformation Effectiveness ===\n");
738 for (i, (&skew_red, &kurt_red)) in stats
739 .transformation_stats
740 .skewness_reduction
741 .iter()
742 .zip(stats.transformation_stats.kurtosis_reduction.iter())
743 .enumerate()
744 {
745 report.push_str(&format!(
746 "Feature {}: Skewness reduction: {:.3}, Kurtosis reduction: {:.3}\n",
747 i, skew_red, kurt_red
748 ));
749 }
750 report.push('\n');
751 }
752
753 report.push_str("=== Configuration ===\n");
755 report.push_str(&format!("Strategy: {:?}\n", self.config.strategy));
756 report.push_str(&format!(
757 "Outlier detection: {}\n",
758 self.config.enable_outlier_detection
759 ));
760 report.push_str(&format!(
761 "Outlier transformation: {}\n",
762 self.config.enable_outlier_transformation
763 ));
764 report.push_str(&format!(
765 "Outlier imputation: {}\n",
766 self.config.enable_outlier_imputation
767 ));
768 report.push_str(&format!(
769 "Robust scaling: {}\n",
770 self.config.enable_robust_scaling
771 ));
772 report.push_str(&format!(
773 "Adaptive thresholds: {}\n",
774 self.config.adaptive_thresholds
775 ));
776
777 Ok(report)
778 }
779
780 pub fn is_effective(&self) -> bool {
782 if let Some(stats) = &self.preprocessing_stats_ {
783 stats.robustness_score > 0.7 && stats.quality_improvement > 0.5
784 } else {
785 false
786 }
787 }
788
789 pub fn get_recommendations(&self) -> Vec<String> {
791 let mut recommendations = Vec::new();
792
793 if let Some(stats) = &self.preprocessing_stats_ {
794 if stats.robustness_score < 0.5 {
795 recommendations
796 .push("Consider using a more aggressive robust strategy".to_string());
797 }
798
799 let avg_outlier_rate = stats.outlier_percentages.iter().sum::<Float>()
800 / stats.outlier_percentages.len() as Float;
801 if avg_outlier_rate > 20.0 {
802 recommendations.push(
803 "High outlier rate detected - consider additional data cleaning".to_string(),
804 );
805 }
806
807 if stats.missing_stats.imputation_success_rate < 0.8 {
808 recommendations.push(
809 "Low imputation success rate - consider alternative imputation strategies"
810 .to_string(),
811 );
812 }
813
814 let avg_skewness_reduction = stats
815 .transformation_stats
816 .skewness_reduction
817 .iter()
818 .sum::<Float>()
819 / stats.transformation_stats.skewness_reduction.len() as Float;
820 if avg_skewness_reduction < 0.1 {
821 recommendations.push("Low transformation effectiveness - consider alternative transformation methods".to_string());
822 }
823
824 if stats.quality_improvement < 0.3 {
825 recommendations.push(
826 "Low overall quality improvement - consider reviewing preprocessing pipeline"
827 .to_string(),
828 );
829 }
830 }
831
832 if recommendations.is_empty() {
833 recommendations
834 .push("Preprocessing appears effective - no specific recommendations".to_string());
835 }
836
837 recommendations
838 }
839}
840
841impl Default for RobustPreprocessor<Untrained> {
842 fn default() -> Self {
843 Self::new()
844 }
845}
846
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Row-major matrix constructor shared by the tests below.
    fn matrix(shape: (usize, usize), values: Vec<Float>) -> Array2<Float> {
        Array2::from_shape_vec(shape, values).unwrap()
    }

    #[test]
    fn test_robust_preprocessor_creation() {
        let preprocessor = RobustPreprocessor::new();
        let config = &preprocessor.config;
        assert!(matches!(config.strategy, RobustStrategy::Moderate));
        assert!(config.enable_outlier_detection);
        assert!(config.enable_robust_scaling);
    }

    #[test]
    fn test_robust_preprocessor_conservative() {
        let preprocessor = RobustPreprocessor::conservative();
        let config = &preprocessor.config;
        assert!(matches!(config.strategy, RobustStrategy::Conservative));
        assert_eq!(config.contamination_rate, 0.05);
        assert!(!config.adaptive_thresholds);
    }

    #[test]
    fn test_robust_preprocessor_aggressive() {
        let preprocessor = RobustPreprocessor::aggressive();
        let config = &preprocessor.config;
        assert!(matches!(config.strategy, RobustStrategy::Aggressive));
        assert_eq!(config.contamination_rate, 0.15);
        assert_eq!(config.outlier_threshold, Some(2.0));
    }

    #[test]
    fn test_robust_preprocessor_fit_transform() {
        // Row 8 (100, 1000) is the planted outlier.
        let data = matrix(
            (10, 2),
            vec![
                1.0, 10.0, //
                2.0, 20.0, //
                3.0, 30.0, //
                4.0, 40.0, //
                5.0, 50.0, //
                6.0, 60.0, //
                7.0, 70.0, //
                8.0, 80.0, //
                100.0, 1000.0, //
                9.0, 90.0,
            ],
        );

        let fitted = RobustPreprocessor::moderate().fit(&data, &()).unwrap();
        let result = fitted.transform(&data).unwrap();

        assert_eq!(result.dim(), data.dim());

        let robustness = fitted.preprocessing_stats().unwrap().robustness_score;
        assert!(fitted.is_effective() || robustness > 0.3);
    }

    #[test]
    fn test_robust_preprocessor_with_missing_values() {
        let data = matrix(
            (8, 2),
            vec![
                1.0, 10.0, //
                2.0, Float::NAN, //
                3.0, 30.0, //
                Float::NAN, 40.0, //
                5.0, 50.0, //
                100.0, 1000.0, //
                7.0, 70.0, //
                8.0, 80.0,
            ],
        );

        // With imputation and transformation disabled, NaNs must pass through.
        let fitted = RobustPreprocessor::moderate()
            .outlier_imputation(false)
            .outlier_transformation(false)
            .fit(&data, &())
            .unwrap();
        let result = fitted.transform(&data).unwrap();

        assert_eq!(result.dim(), data.dim());

        let nan_count = |m: &Array2<Float>| m.iter().filter(|v| v.is_nan()).count();
        assert_eq!(nan_count(&result), nan_count(&data));

        let stats = fitted.preprocessing_stats().unwrap();
        assert!(stats.robustness_score >= 0.0);
    }

    #[test]
    fn test_robust_preprocessor_configuration() {
        let preprocessor = RobustPreprocessor::new()
            .outlier_detection(false)
            .robust_scaling(true)
            .outlier_threshold(2.0)
            .contamination_rate(0.05);
        let config = &preprocessor.config;

        assert!(!config.enable_outlier_detection);
        assert!(config.enable_robust_scaling);
        assert_eq!(config.outlier_threshold, Some(2.0));
        assert_eq!(config.contamination_rate, 0.05);
    }

    #[test]
    fn test_adaptive_threshold_computation() {
        let data = matrix((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 100.0]);

        let threshold = RobustPreprocessor::new()
            .get_adaptive_threshold(&data, 0.1)
            .unwrap();

        assert!((1.5..=4.0).contains(&threshold));
    }

    #[test]
    fn test_preprocessing_report() {
        let data = matrix(
            (6, 2),
            vec![
                1.0, 10.0, //
                2.0, 20.0, //
                3.0, 30.0, //
                4.0, 40.0, //
                5.0, 50.0, //
                100.0, 1000.0,
            ],
        );

        let fitted = RobustPreprocessor::moderate().fit(&data, &()).unwrap();
        let report = fitted.preprocessing_report().unwrap();

        for section in ["Robust Preprocessing Report", "Robustness Score", "Quality Improvement"].iter() {
            assert!(report.contains(section));
        }
    }

    #[test]
    fn test_recommendations() {
        let data = matrix((4, 1), vec![1.0, 2.0, 3.0, 4.0]);

        let fitted = RobustPreprocessor::conservative().fit(&data, &()).unwrap();

        assert!(!fitted.get_recommendations().is_empty());
    }

    #[test]
    fn test_robust_preprocessor_error_handling() {
        let empty_data = matrix((0, 0), vec![]);
        assert!(RobustPreprocessor::new().fit(&empty_data, &()).is_err());
    }

    #[test]
    fn test_feature_mismatch() {
        let data = matrix((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let wrong_data = matrix(
            (4, 3),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
        );

        let fitted = RobustPreprocessor::moderate().fit(&data, &()).unwrap();

        assert!(fitted.transform(&wrong_data).is_err());
    }
}