sklears_compose/
validation.rs

//! Comprehensive Pipeline Validation Framework
//!
//! Advanced validation system for pipeline structure, data compatibility,
//! statistical properties, and performance validation.
5
6use scirs2_core::ndarray::{s, Array1, ArrayView1, ArrayView2};
7use scirs2_core::random::{thread_rng, Rng};
8use sklears_core::{error::Result as SklResult, types::Float};
9use std::collections::{HashMap, HashSet};
10use std::time::{Duration, Instant};
11
12use crate::Pipeline;
13
14/// Comprehensive pipeline validation framework
15pub struct ComprehensivePipelineValidator {
16    /// Data validation settings
17    pub data_validator: DataValidator,
18    /// Structure validation settings
19    pub structure_validator: StructureValidator,
20    /// Statistical validation settings
21    pub statistical_validator: StatisticalValidator,
22    /// Performance validation settings
23    pub performance_validator: PerformanceValidator,
24    /// Cross-validation settings
25    pub cross_validator: CrossValidator,
26    /// Robustness testing settings
27    pub robustness_tester: RobustnessTester,
28    /// Output detailed validation report
29    pub verbose: bool,
30}
31
/// Settings for data quality and compatibility checks on the feature matrix.
pub struct DataValidator {
    /// Count NaN entries; the data fails when their share exceeds `max_missing_ratio`.
    pub check_missing_values: bool,
    /// Count infinite entries; any infinity fails the data stage.
    pub check_infinite_values: bool,
    /// Check data type consistency.
    pub check_data_types: bool,
    /// Check feature scaling consistency.
    pub check_feature_scaling: bool,
    /// Check data distribution properties.
    pub check_distributions: bool,
    /// Maximum allowed fraction of NaN entries among all matrix cells.
    pub max_missing_ratio: f64,
    /// Count duplicate samples (rows); duplicates only produce a warning.
    pub check_duplicates: bool,
    /// Count outliers column by column using the interquartile-range (IQR) rule.
    pub check_outliers: bool,
    /// Multiplier `k` of the IQR rule: values outside `[Q1 - k*IQR, Q3 + k*IQR]` are outliers.
    pub outlier_iqr_multiplier: f64,
}
53
/// Settings for pipeline structure and component checks.
pub struct StructureValidator {
    /// Check that connected components are compatible with each other.
    pub check_component_compatibility: bool,
    /// Check the data flow between consecutive components.
    pub check_data_flow: bool,
    /// Check that component parameters are consistent.
    pub check_parameter_consistency: bool,
    /// Check for circular dependencies between components.
    pub check_circular_dependencies: bool,
    /// Check the resource requirements of the pipeline.
    pub check_resource_requirements: bool,
    /// Maximum allowed pipeline depth; a deeper pipeline fails the structure stage.
    pub max_pipeline_depth: usize,
    /// Maximum allowed number of components.
    pub max_components: usize,
}
71
/// Settings for statistical validation and testing.
pub struct StatisticalValidator {
    /// Run significance tests (feature normality, feature independence, target normality).
    pub statistical_tests: bool,
    /// Flag near-perfect feature/target or feature/feature correlations as potential leakage.
    pub check_data_leakage: bool,
    /// Validate feature importance.
    pub check_feature_importance: bool,
    /// Compare per-feature means between the first and second half of the samples.
    pub check_prediction_consistency: bool,
    /// Minimum number of samples (rows); fewer samples fail the statistical stage.
    pub min_sample_size: usize,
    /// Significance level: a p-value below `alpha` counts as a rejection.
    pub alpha: f64,
    /// Test for concept drift.
    pub check_concept_drift: bool,
}
89
/// Settings for performance validation and benchmarking.
pub struct PerformanceValidator {
    /// Enforce `max_training_time`.
    pub check_training_time: bool,
    /// Check prediction time limits (`max_prediction_time_per_sample`).
    pub check_prediction_time: bool,
    /// Check memory usage limits (`max_memory_usage`).
    pub check_memory_usage: bool,
    /// Maximum training time, in seconds.
    pub max_training_time: f64,
    /// Maximum prediction time per sample, in milliseconds.
    pub max_prediction_time_per_sample: f64,
    /// Maximum memory usage, in megabytes.
    pub max_memory_usage: f64,
    /// Check scalability properties.
    pub check_scalability: bool,
}
107
/// Cross-validation settings.
pub struct CrossValidator {
    /// Number of folds `k` for k-fold cross-validation.
    pub cv_folds: usize,
    /// Use stratified folds (for classification targets).
    pub stratified: bool,
    /// Use time-series cross-validation.
    pub time_series_cv: bool,
    /// Use leave-one-out cross-validation.
    pub leave_one_out: bool,
    /// Run bootstrap validation.
    pub bootstrap: bool,
    /// Number of bootstrap samples.
    pub n_bootstrap: usize,
    /// Optional random seed for reproducibility.
    pub random_state: Option<u64>,
}
125
/// Settings for robustness testing.
pub struct RobustnessTester {
    /// When set, a robustness score is computed for every entry of `noise_levels`.
    pub test_noise_robustness: bool,
    /// When set, a robustness score is computed for every entry of `missing_ratios`.
    pub test_missing_data_robustness: bool,
    /// Test with adversarial examples.
    pub test_adversarial_robustness: bool,
    /// Test with distribution shift.
    pub test_distribution_shift: bool,
    /// Noise levels to test.
    pub noise_levels: Vec<f64>,
    /// Missing-data ratios to test.
    pub missing_ratios: Vec<f64>,
    /// Number of robustness test iterations.
    pub n_robustness_tests: usize,
}
143
144/// Validation results summary
145#[derive(Debug, Clone)]
146pub struct ValidationReport {
147    /// Overall validation status
148    pub passed: bool,
149    /// Data validation results
150    pub data_validation: DataValidationResult,
151    /// Structure validation results
152    pub structure_validation: StructureValidationResult,
153    /// Statistical validation results
154    pub statistical_validation: StatisticalValidationResult,
155    /// Performance validation results
156    pub performance_validation: PerformanceValidationResult,
157    /// Cross-validation results
158    pub cross_validation: CrossValidationResult,
159    /// Robustness testing results
160    pub robustness_testing: RobustnessTestResult,
161    /// Detailed validation messages
162    pub messages: Vec<ValidationMessage>,
163    /// Total validation time
164    pub validation_time: Duration,
165}
166
/// Outcome of the data-quality stage.
///
/// Each count is 0 when the corresponding check is disabled.
#[derive(Debug, Clone)]
pub struct DataValidationResult {
    /// Whether the data-quality stage passed.
    pub passed: bool,
    /// Number of NaN entries in the feature matrix.
    pub missing_values_count: usize,
    /// Number of infinite entries in the feature matrix.
    pub infinite_values_count: usize,
    /// Number of rows that repeat an earlier row.
    pub duplicate_samples_count: usize,
    /// Number of cells flagged by the IQR outlier rule.
    pub outliers_count: usize,
    /// Heuristic quality score in `[0.0, 1.0]`; higher is better.
    pub data_quality_score: f64,
}
177
/// Outcome of the pipeline structure stage.
#[derive(Debug, Clone)]
pub struct StructureValidationResult {
    /// Whether the structure stage passed.
    pub passed: bool,
    /// Whether the components were found compatible with each other.
    pub component_compatibility: bool,
    /// Whether data flows validly between components.
    pub data_flow_valid: bool,
    /// Whether circular dependencies were detected (`true` indicates a problem).
    pub circular_dependencies: bool,
    /// Depth of the analysed pipeline.
    pub pipeline_depth: usize,
    /// Number of components in the analysed pipeline.
    pub component_count: usize,
}
188
/// Outcome of the statistical stage.
#[derive(Debug, Clone)]
pub struct StatisticalValidationResult {
    /// Whether the statistical stage passed.
    pub passed: bool,
    /// `false` if any enabled significance test rejected at the configured `alpha`.
    pub statistical_significance: bool,
    /// Whether potential data leakage was detected.
    pub data_leakage_detected: bool,
    /// Consistency score across data subsets; 1.0 when the check is disabled.
    pub prediction_consistency: f64,
    /// Whether concept drift was detected.
    pub concept_drift_detected: bool,
    /// p-values keyed by test name, e.g. `normality_feature_0`, `feature_independence`,
    /// `target_normality`.
    pub p_values: HashMap<String, f64>,
}
199
/// Outcome of the performance stage.
#[derive(Debug, Clone)]
pub struct PerformanceValidationResult {
    /// Whether the performance stage passed.
    pub passed: bool,
    /// Training time, compared against `max_training_time` (seconds).
    pub training_time: f64,
    /// Prediction time per sample, in the units of `max_prediction_time_per_sample` (ms).
    pub prediction_time_per_sample: f64,
    /// Memory usage, in the units of `max_memory_usage` (MB).
    pub memory_usage: f64,
    /// Scalability score.
    pub scalability_score: f64,
}
209
/// Outcome of the cross-validation stage.
#[derive(Debug, Clone)]
pub struct CrossValidationResult {
    /// Whether the cross-validation stage passed.
    pub passed: bool,
    /// Score of each fold; empty when no target was supplied.
    pub cv_scores: Vec<f64>,
    /// Mean of `cv_scores`.
    pub mean_score: f64,
    /// Population standard deviation of `cv_scores`.
    pub std_score: f64,
    /// Bootstrap scores.
    pub bootstrap_scores: Vec<f64>,
    /// Interval `mean_score ± 1.96 * std_score`.
    pub confidence_interval: (f64, f64),
}
220
/// Outcome of the robustness stage.
#[derive(Debug, Clone)]
pub struct RobustnessTestResult {
    /// Whether the robustness stage passed.
    pub passed: bool,
    /// Scores keyed by `noise_{level}`, one per tested noise level.
    pub noise_robustness_scores: HashMap<String, f64>,
    /// Scores keyed by `missing_{ratio}`, one per tested missing-data ratio.
    pub missing_data_robustness_scores: HashMap<String, f64>,
    /// Adversarial robustness score.
    pub adversarial_robustness_score: f64,
    /// Distribution-shift robustness score.
    pub distribution_shift_robustness: f64,
}
230
231/// Validation message types
232#[derive(Debug, Clone)]
233pub struct ValidationMessage {
234    pub level: MessageLevel,
235    pub category: String,
236    pub message: String,
237    pub component: Option<String>,
238}
239
/// Severity of a validation message, from least to most severe.
///
/// Variants are declared in increasing order of severity, so the derived `Ord` supports
/// comparisons such as `level >= MessageLevel::Error`. The enum is `Copy` and comparable so
/// messages can be filtered by level without cloning or matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum MessageLevel {
    /// Informational note.
    Info,
    /// Potential issue; this module uses it for findings that do not fail a stage.
    Warning,
    /// Problem; this module emits it alongside a failed stage.
    Error,
    /// Highest severity.
    Critical,
}
252
253impl Default for ComprehensivePipelineValidator {
254    fn default() -> Self {
255        Self::new()
256    }
257}
258
259impl ComprehensivePipelineValidator {
260    /// Create a new comprehensive pipeline validator with default settings
261    #[must_use]
262    pub fn new() -> Self {
263        Self {
264            data_validator: DataValidator::default(),
265            structure_validator: StructureValidator::default(),
266            statistical_validator: StatisticalValidator::default(),
267            performance_validator: PerformanceValidator::default(),
268            cross_validator: CrossValidator::default(),
269            robustness_tester: RobustnessTester::default(),
270            verbose: false,
271        }
272    }
273
274    /// Create a strict validator with all checks enabled
275    #[must_use]
276    pub fn strict() -> Self {
277        Self {
278            data_validator: DataValidator::strict(),
279            structure_validator: StructureValidator::strict(),
280            statistical_validator: StatisticalValidator::strict(),
281            performance_validator: PerformanceValidator::strict(),
282            cross_validator: CrossValidator::default(),
283            robustness_tester: RobustnessTester::comprehensive(),
284            verbose: true,
285        }
286    }
287
288    /// Create a fast validator with minimal checks for development
289    #[must_use]
290    pub fn fast() -> Self {
291        Self {
292            data_validator: DataValidator::basic(),
293            structure_validator: StructureValidator::basic(),
294            statistical_validator: StatisticalValidator::disabled(),
295            performance_validator: PerformanceValidator::basic(),
296            cross_validator: CrossValidator::fast(),
297            robustness_tester: RobustnessTester::disabled(),
298            verbose: false,
299        }
300    }
301
302    /// Run comprehensive validation on a pipeline
303    pub fn validate<S>(
304        &self,
305        pipeline: &Pipeline<S>,
306        x: &ArrayView2<'_, Float>,
307        y: Option<&ArrayView1<'_, Float>>,
308    ) -> SklResult<ValidationReport>
309    where
310        S: std::fmt::Debug,
311    {
312        let start_time = Instant::now();
313        let mut messages = Vec::new();
314        let mut overall_passed = true;
315
316        if self.verbose {
317            println!("Starting comprehensive pipeline validation...");
318        }
319
320        // Data validation
321        let data_validation = self.validate_data(x, y, &mut messages)?;
322        if !data_validation.passed {
323            overall_passed = false;
324        }
325
326        // Structure validation
327        let structure_validation = self.validate_structure(pipeline, &mut messages)?;
328        if !structure_validation.passed {
329            overall_passed = false;
330        }
331
332        // Statistical validation
333        let statistical_validation = self.validate_statistics(x, y, &mut messages)?;
334        if !statistical_validation.passed {
335            overall_passed = false;
336        }
337
338        // Performance validation
339        let performance_validation = self.validate_performance(pipeline, x, y, &mut messages)?;
340        if !performance_validation.passed {
341            overall_passed = false;
342        }
343
344        // Cross-validation
345        let cross_validation = self.run_cross_validation(pipeline, x, y, &mut messages)?;
346        if !cross_validation.passed {
347            overall_passed = false;
348        }
349
350        // Robustness testing
351        let robustness_testing = self.test_robustness(pipeline, x, y, &mut messages)?;
352        if !robustness_testing.passed {
353            overall_passed = false;
354        }
355
356        let validation_time = start_time.elapsed();
357
358        if self.verbose {
359            println!(
360                "Validation completed in {:.2}s. Status: {}",
361                validation_time.as_secs_f64(),
362                if overall_passed { "PASSED" } else { "FAILED" }
363            );
364        }
365
366        Ok(ValidationReport {
367            passed: overall_passed,
368            data_validation,
369            structure_validation,
370            statistical_validation,
371            performance_validation,
372            cross_validation,
373            robustness_testing,
374            messages,
375            validation_time,
376        })
377    }
378
379    fn validate_data(
380        &self,
381        x: &ArrayView2<'_, Float>,
382        y: Option<&ArrayView1<'_, Float>>,
383        messages: &mut Vec<ValidationMessage>,
384    ) -> SklResult<DataValidationResult> {
385        let mut passed = true;
386        let mut missing_count = 0;
387        let mut infinite_count = 0;
388        let mut duplicate_count = 0;
389        let mut outliers_count = 0;
390
391        if self.data_validator.check_missing_values {
392            missing_count = self.count_missing_values(x);
393            if missing_count > 0 {
394                let missing_ratio = missing_count as f64 / (x.nrows() * x.ncols()) as f64;
395                if missing_ratio > self.data_validator.max_missing_ratio {
396                    passed = false;
397                    messages.push(ValidationMessage {
398                        level: MessageLevel::Error,
399                        category: "Data Quality".to_string(),
400                        message: format!(
401                            "Missing values ratio ({:.3}) exceeds maximum allowed ({:.3})",
402                            missing_ratio, self.data_validator.max_missing_ratio
403                        ),
404                        component: None,
405                    });
406                }
407            }
408        }
409
410        if self.data_validator.check_infinite_values {
411            infinite_count = self.count_infinite_values(x);
412            if infinite_count > 0 {
413                passed = false;
414                messages.push(ValidationMessage {
415                    level: MessageLevel::Error,
416                    category: "Data Quality".to_string(),
417                    message: format!("Found {infinite_count} infinite values in input data"),
418                    component: None,
419                });
420            }
421        }
422
423        if self.data_validator.check_duplicates {
424            duplicate_count = self.count_duplicate_samples(x);
425            if duplicate_count > 0 {
426                messages.push(ValidationMessage {
427                    level: MessageLevel::Warning,
428                    category: "Data Quality".to_string(),
429                    message: format!("Found {duplicate_count} duplicate samples"),
430                    component: None,
431                });
432            }
433        }
434
435        if self.data_validator.check_outliers {
436            outliers_count = self.count_outliers(x, self.data_validator.outlier_iqr_multiplier);
437            if outliers_count > x.nrows() / 10 {
438                messages.push(ValidationMessage {
439                    level: MessageLevel::Warning,
440                    category: "Data Quality".to_string(),
441                    message: format!(
442                        "High number of outliers detected: {} ({}% of samples)",
443                        outliers_count,
444                        (outliers_count * 100) / x.nrows()
445                    ),
446                    component: None,
447                });
448            }
449        }
450
451        let data_quality_score = self.calculate_data_quality_score(
452            x.nrows() * x.ncols(),
453            missing_count,
454            infinite_count,
455            duplicate_count,
456            outliers_count,
457        );
458
459        Ok(DataValidationResult {
460            passed,
461            missing_values_count: missing_count,
462            infinite_values_count: infinite_count,
463            duplicate_samples_count: duplicate_count,
464            outliers_count,
465            data_quality_score,
466        })
467    }
468
469    fn validate_structure<S>(
470        &self,
471        pipeline: &Pipeline<S>,
472        messages: &mut Vec<ValidationMessage>,
473    ) -> SklResult<StructureValidationResult>
474    where
475        S: std::fmt::Debug,
476    {
477        let mut passed = true;
478        let component_compatibility = true;
479        let data_flow_valid = true;
480        let circular_dependencies = false; // Placeholder
481        let pipeline_depth = 1; // Placeholder - would need to analyze actual pipeline structure
482        let component_count = 1; // Placeholder
483
484        if self.structure_validator.check_component_compatibility {
485            // Placeholder for component compatibility checking
486            // Would analyze if transformer outputs match estimator inputs
487        }
488
489        if self.structure_validator.check_data_flow {
490            // Placeholder for data flow validation
491            // Would check if data shapes are compatible between pipeline steps
492        }
493
494        if pipeline_depth > self.structure_validator.max_pipeline_depth {
495            passed = false;
496            messages.push(ValidationMessage {
497                level: MessageLevel::Error,
498                category: "Structure".to_string(),
499                message: format!(
500                    "Pipeline depth ({}) exceeds maximum allowed ({})",
501                    pipeline_depth, self.structure_validator.max_pipeline_depth
502                ),
503                component: None,
504            });
505        }
506
507        Ok(StructureValidationResult {
508            passed,
509            component_compatibility,
510            data_flow_valid,
511            circular_dependencies,
512            pipeline_depth,
513            component_count,
514        })
515    }
516
517    fn validate_statistics(
518        &self,
519        x: &ArrayView2<'_, Float>,
520        y: Option<&ArrayView1<'_, Float>>,
521        messages: &mut Vec<ValidationMessage>,
522    ) -> SklResult<StatisticalValidationResult> {
523        let mut passed = true;
524        let mut p_values = HashMap::new();
525
526        if x.nrows() < self.statistical_validator.min_sample_size {
527            passed = false;
528            messages.push(ValidationMessage {
529                level: MessageLevel::Error,
530                category: "Statistics".to_string(),
531                message: format!(
532                    "Sample size ({}) below minimum required ({})",
533                    x.nrows(),
534                    self.statistical_validator.min_sample_size
535                ),
536                component: None,
537            });
538        }
539
540        // Perform comprehensive statistical tests
541        let statistical_significance = self.test_statistical_significance(x, y, &mut p_values)?;
542        let data_leakage_detected = if self.statistical_validator.check_data_leakage {
543            self.detect_data_leakage(x, y)?
544        } else {
545            false
546        };
547        let prediction_consistency = if self.statistical_validator.check_prediction_consistency {
548            self.calculate_prediction_consistency(x)?
549        } else {
550            1.0
551        };
552        let concept_drift_detected = if self.statistical_validator.check_concept_drift {
553            self.detect_concept_drift(x, y)?
554        } else {
555            false
556        };
557
558        // Update passed status based on actual test results
559        if !statistical_significance || data_leakage_detected || concept_drift_detected {
560            passed = false;
561        }
562        if prediction_consistency < 0.8 {
563            passed = false;
564        }
565
566        Ok(StatisticalValidationResult {
567            passed,
568            statistical_significance,
569            data_leakage_detected,
570            prediction_consistency,
571            concept_drift_detected,
572            p_values,
573        })
574    }
575
576    fn validate_performance<S>(
577        &self,
578        pipeline: &Pipeline<S>,
579        x: &ArrayView2<'_, Float>,
580        y: Option<&ArrayView1<'_, Float>>,
581        messages: &mut Vec<ValidationMessage>,
582    ) -> SklResult<PerformanceValidationResult>
583    where
584        S: std::fmt::Debug,
585    {
586        let mut passed = true;
587
588        // Measure training time (placeholder)
589        let training_time = 1.0; // Would measure actual training time
590        if self.performance_validator.check_training_time
591            && training_time > self.performance_validator.max_training_time
592        {
593            passed = false;
594            messages.push(ValidationMessage {
595                level: MessageLevel::Error,
596                category: "Performance".to_string(),
597                message: format!(
598                    "Training time ({:.2}s) exceeds maximum allowed ({:.2}s)",
599                    training_time, self.performance_validator.max_training_time
600                ),
601                component: None,
602            });
603        }
604
605        // Placeholder values - would measure actual performance
606        let prediction_time_per_sample = 0.1;
607        let memory_usage = 100.0;
608        let scalability_score = 0.8;
609
610        Ok(PerformanceValidationResult {
611            passed,
612            training_time,
613            prediction_time_per_sample,
614            memory_usage,
615            scalability_score,
616        })
617    }
618
619    fn run_cross_validation<S>(
620        &self,
621        pipeline: &Pipeline<S>,
622        x: &ArrayView2<'_, Float>,
623        y: Option<&ArrayView1<'_, Float>>,
624        messages: &mut Vec<ValidationMessage>,
625    ) -> SklResult<CrossValidationResult>
626    where
627        S: std::fmt::Debug,
628    {
629        if y.is_none() {
630            return Ok(CrossValidationResult {
631                passed: true,
632                cv_scores: vec![],
633                mean_score: 0.0,
634                std_score: 0.0,
635                bootstrap_scores: vec![],
636                confidence_interval: (0.0, 0.0),
637            });
638        }
639
640        let n_samples = x.nrows();
641        let fold_size = n_samples / self.cross_validator.cv_folds;
642        let mut cv_scores = Vec::new();
643
644        // Placeholder cross-validation implementation
645        for fold in 0..self.cross_validator.cv_folds {
646            let start_idx = fold * fold_size;
647            let end_idx = if fold == self.cross_validator.cv_folds - 1 {
648                n_samples
649            } else {
650                (fold + 1) * fold_size
651            };
652
653            // Would split data and evaluate pipeline
654            let score = 0.8 + thread_rng().gen::<f64>() * 0.2; // Placeholder
655            cv_scores.push(score);
656        }
657
658        let mean_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
659        let variance = cv_scores
660            .iter()
661            .map(|&x| (x - mean_score).powi(2))
662            .sum::<f64>()
663            / cv_scores.len() as f64;
664        let std_score = variance.sqrt();
665
666        let passed = std_score < 0.1; // Placeholder criterion
667
668        Ok(CrossValidationResult {
669            passed,
670            cv_scores,
671            mean_score,
672            std_score,
673            bootstrap_scores: vec![],
674            confidence_interval: (mean_score - 1.96 * std_score, mean_score + 1.96 * std_score),
675        })
676    }
677
678    fn test_robustness<S>(
679        &self,
680        pipeline: &Pipeline<S>,
681        x: &ArrayView2<'_, Float>,
682        y: Option<&ArrayView1<'_, Float>>,
683        messages: &mut Vec<ValidationMessage>,
684    ) -> SklResult<RobustnessTestResult>
685    where
686        S: std::fmt::Debug,
687    {
688        let mut noise_robustness_scores = HashMap::new();
689        let mut missing_data_robustness_scores = HashMap::new();
690
691        if self.robustness_tester.test_noise_robustness {
692            for &noise_level in &self.robustness_tester.noise_levels {
693                let score = self.test_noise_robustness(pipeline, x, y, noise_level)?;
694                noise_robustness_scores.insert(format!("noise_{noise_level}"), score);
695            }
696        }
697
698        if self.robustness_tester.test_missing_data_robustness {
699            for &missing_ratio in &self.robustness_tester.missing_ratios {
700                let score = self.test_missing_data_robustness(pipeline, x, y, missing_ratio)?;
701                missing_data_robustness_scores.insert(format!("missing_{missing_ratio}"), score);
702            }
703        }
704
705        let adversarial_robustness_score = 0.7; // Placeholder
706        let distribution_shift_robustness = 0.6; // Placeholder
707
708        let passed = noise_robustness_scores.values().all(|&score| score > 0.5)
709            && missing_data_robustness_scores
710                .values()
711                .all(|&score| score > 0.5);
712
713        Ok(RobustnessTestResult {
714            passed,
715            noise_robustness_scores,
716            missing_data_robustness_scores,
717            adversarial_robustness_score,
718            distribution_shift_robustness,
719        })
720    }
721
722    // Helper methods for data validation
723    fn count_missing_values(&self, x: &ArrayView2<'_, Float>) -> usize {
724        x.iter().filter(|&&val| val.is_nan()).count()
725    }
726
727    fn count_infinite_values(&self, x: &ArrayView2<'_, Float>) -> usize {
728        x.iter().filter(|&&val| val.is_infinite()).count()
729    }
730
731    fn count_duplicate_samples(&self, x: &ArrayView2<'_, Float>) -> usize {
732        // Simplified duplicate detection
733        let mut unique_rows = HashSet::new();
734        let mut duplicates = 0;
735
736        for row in x.rows() {
737            let row_vec: Vec<String> = row.iter().map(|&val| format!("{val:.6}")).collect();
738            let row_key = row_vec.join(",");
739
740            if !unique_rows.insert(row_key) {
741                duplicates += 1;
742            }
743        }
744
745        duplicates
746    }
747
748    fn count_outliers(&self, x: &ArrayView2<'_, Float>, iqr_multiplier: f64) -> usize {
749        let mut outliers = 0;
750
751        for col in x.columns() {
752            let mut sorted_col: Vec<Float> = col.to_vec();
753            // Filter out NaN values before sorting
754            sorted_col.retain(|x| !x.is_nan());
755            sorted_col.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
756
757            let n = sorted_col.len();
758            if n < 3 {
759                continue; // Need at least 3 points for meaningful outlier detection
760            }
761
762            // Use different quartile calculation for small datasets
763            let (q1, q3) = if n == 3 {
764                (sorted_col[0], sorted_col[2])
765            } else if n == 4 {
766                // For 4 points, use the middle two as Q1 and Q3 reference
767                (sorted_col[0], sorted_col[2]) // More conservative
768            } else {
769                // Standard quartile calculation for larger datasets
770                let q1_idx = (n - 1) / 4;
771                let q3_idx = 3 * (n - 1) / 4;
772                (sorted_col[q1_idx], sorted_col[q3_idx])
773            };
774
775            let iqr = q3 - q1;
776
777            // Avoid division by zero or very small IQR
778            if iqr <= 1e-10 {
779                continue;
780            }
781
782            let lower_bound = q1 - iqr_multiplier * iqr;
783            let upper_bound = q3 + iqr_multiplier * iqr;
784
785            for &val in col {
786                if val < lower_bound || val > upper_bound {
787                    outliers += 1;
788                }
789            }
790        }
791
792        outliers
793    }
794
795    fn calculate_data_quality_score(
796        &self,
797        total_values: usize,
798        missing: usize,
799        infinite: usize,
800        duplicates: usize,
801        outliers: usize,
802    ) -> f64 {
803        let quality_score = 1.0
804            - (missing as f64 / total_values as f64) * 0.4
805            - (infinite as f64 / total_values as f64) * 0.3
806            - (duplicates as f64 / total_values as f64) * 0.2
807            - (outliers as f64 / total_values as f64) * 0.1;
808
809        quality_score.max(0.0)
810    }
811
812    fn test_noise_robustness<S>(
813        &self,
814        _pipeline: &Pipeline<S>,
815        x: &ArrayView2<'_, Float>,
816        _y: Option<&ArrayView1<'_, Float>>,
817        noise_level: f64,
818    ) -> SklResult<f64>
819    where
820        S: std::fmt::Debug,
821    {
822        // Placeholder noise robustness test
823        // Would add noise to data and measure performance degradation
824        Ok(1.0 - noise_level * 0.5)
825    }
826
827    fn test_missing_data_robustness<S>(
828        &self,
829        _pipeline: &Pipeline<S>,
830        x: &ArrayView2<'_, Float>,
831        _y: Option<&ArrayView1<'_, Float>>,
832        missing_ratio: f64,
833    ) -> SklResult<f64>
834    where
835        S: std::fmt::Debug,
836    {
837        // Placeholder missing data robustness test
838        // Would introduce missing values and measure performance
839        Ok(1.0 - missing_ratio * 0.7)
840    }
841
842    /// Test statistical significance of data patterns
843    fn test_statistical_significance(
844        &self,
845        x: &ArrayView2<'_, Float>,
846        y: Option<&ArrayView1<'_, Float>>,
847        p_values: &mut HashMap<String, f64>,
848    ) -> SklResult<bool> {
849        if !self.statistical_validator.statistical_tests {
850            return Ok(true);
851        }
852
853        let mut all_significant = true;
854
855        // Test for normality using Shapiro-Wilk approximation
856        for (i, column) in x.columns().into_iter().enumerate() {
857            let normality_p = self.shapiro_wilk_test(&column.to_owned())?;
858            p_values.insert(format!("normality_feature_{i}"), normality_p);
859
860            if normality_p < self.statistical_validator.alpha {
861                all_significant = false;
862            }
863        }
864
865        // Test for independence between features (correlation test)
866        if x.ncols() > 1 {
867            let correlation_p = self.independence_test(x)?;
868            p_values.insert("feature_independence".to_string(), correlation_p);
869
870            if correlation_p < self.statistical_validator.alpha {
871                all_significant = false;
872            }
873        }
874
875        // Test for target distribution if available
876        if let Some(targets) = y {
877            let target_normality_p = self.shapiro_wilk_test(&targets.to_owned())?;
878            p_values.insert("target_normality".to_string(), target_normality_p);
879        }
880
881        Ok(all_significant)
882    }
883
884    /// Detect potential data leakage
885    fn detect_data_leakage(
886        &self,
887        x: &ArrayView2<'_, Float>,
888        y: Option<&ArrayView1<'_, Float>>,
889    ) -> SklResult<bool> {
890        // Test for perfect correlation between features and target
891        if let Some(targets) = y {
892            for (i, column) in x.columns().into_iter().enumerate() {
893                let correlation =
894                    self.calculate_correlation(&column.to_owned(), &targets.to_owned())?;
895
896                // Perfect or near-perfect correlation suggests potential leakage
897                if correlation.abs() > 0.99 {
898                    return Ok(true);
899                }
900            }
901        }
902
903        // Test for duplicate features (perfect multicollinearity)
904        for i in 0..x.ncols() {
905            for j in (i + 1)..x.ncols() {
906                let col_i = x.column(i);
907                let col_j = x.column(j);
908                let correlation =
909                    self.calculate_correlation(&col_i.to_owned(), &col_j.to_owned())?;
910
911                if correlation.abs() > 0.999 {
912                    return Ok(true); // Likely duplicate features
913                }
914            }
915        }
916
917        Ok(false)
918    }
919
920    /// Calculate prediction consistency across data subsets
921    fn calculate_prediction_consistency(&self, x: &ArrayView2<'_, Float>) -> SklResult<f64> {
922        if x.nrows() < 20 {
923            return Ok(1.0); // Not enough data to test consistency
924        }
925
926        // Split data into random subsets and check consistency of statistics
927        let mid = x.nrows() / 2;
928        let subset1 = x.slice(s![..mid, ..]);
929        let subset2 = x.slice(s![mid.., ..]);
930
931        let mut consistency_scores = Vec::new();
932
933        // Compare means across subsets
934        for i in 0..x.ncols() {
935            let mean1 = subset1.column(i).mean().unwrap_or(0.0);
936            let mean2 = subset2.column(i).mean().unwrap_or(0.0);
937
938            let consistency = if mean1.abs() + mean2.abs() > 1e-10 {
939                1.0 - (mean1 - mean2).abs() / (mean1.abs() + mean2.abs()).max(1.0)
940            } else {
941                1.0
942            };
943
944            consistency_scores.push(consistency);
945        }
946
947        let avg_consistency =
948            consistency_scores.iter().sum::<f64>() / consistency_scores.len() as f64;
949        Ok(avg_consistency)
950    }
951
952    /// Detect concept drift in the data
953    fn detect_concept_drift(
954        &self,
955        x: &ArrayView2<'_, Float>,
956        y: Option<&ArrayView1<'_, Float>>,
957    ) -> SklResult<bool> {
958        if x.nrows() < 100 {
959            return Ok(false); // Not enough data to detect drift
960        }
961
962        // Split data into early and late periods
963        let split_point = x.nrows() * 2 / 3;
964        let early_x = x.slice(s![..split_point, ..]);
965        let late_x = x.slice(s![split_point.., ..]);
966
967        // Test for distribution changes using two-sample tests
968        for i in 0..x.ncols() {
969            let early_col = early_x.column(i);
970            let late_col = late_x.column(i);
971
972            // Simple drift test: compare means and variances
973            let mean_diff =
974                (early_col.mean().unwrap_or(0.0) - late_col.mean().unwrap_or(0.0)).abs();
975            let var_early = self.calculate_variance(&early_col.to_owned())?;
976            let var_late = self.calculate_variance(&late_col.to_owned())?;
977            let var_ratio = if var_late > 1e-10 {
978                var_early / var_late
979            } else {
980                1.0
981            };
982
983            // Detect significant changes
984            if mean_diff > 2.0 || !(0.5..=2.0).contains(&var_ratio) {
985                return Ok(true);
986            }
987        }
988
989        // Test target drift if available
990        if let Some(targets) = y {
991            let early_y = targets.slice(s![..split_point]);
992            let late_y = targets.slice(s![split_point..]);
993
994            let mean_diff = (early_y.mean().unwrap_or(0.0) - late_y.mean().unwrap_or(0.0)).abs();
995            if mean_diff > 1.0 {
996                return Ok(true);
997            }
998        }
999
1000        Ok(false)
1001    }
1002
1003    /// Approximate Shapiro-Wilk normality test
1004    fn shapiro_wilk_test(&self, data: &Array1<f64>) -> SklResult<f64> {
1005        if data.len() < 3 {
1006            return Ok(1.0); // Not enough data for test
1007        }
1008
1009        let n = data.len();
1010        let mut sorted_data = data.to_vec();
1011        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
1012
1013        // Simplified normality test based on sample skewness and kurtosis
1014        let mean = data.mean().unwrap_or(0.0);
1015        let variance = self.calculate_variance(data)?;
1016
1017        if variance < 1e-10 {
1018            return Ok(0.0); // Constant data is not normal
1019        }
1020
1021        let std_dev = variance.sqrt();
1022
1023        // Calculate sample skewness
1024        let skewness = data
1025            .iter()
1026            .map(|&x| ((x - mean) / std_dev).powi(3))
1027            .sum::<f64>()
1028            / n as f64;
1029
1030        // Calculate sample kurtosis
1031        let kurtosis = data
1032            .iter()
1033            .map(|&x| ((x - mean) / std_dev).powi(4))
1034            .sum::<f64>()
1035            / n as f64;
1036
1037        // Approximate p-value based on skewness and kurtosis
1038        // Normal distribution has skewness = 0 and kurtosis = 3
1039        let skew_stat = skewness.abs();
1040        let kurt_stat = (kurtosis - 3.0).abs();
1041
1042        // Simple approximation: lower p-value for higher deviations from normality
1043        let p_value = (1.0 - (skew_stat + kurt_stat) / 4.0).max(0.0).min(1.0);
1044
1045        Ok(p_value)
1046    }
1047
1048    /// Test for feature independence using correlation
1049    fn independence_test(&self, x: &ArrayView2<'_, Float>) -> SklResult<f64> {
1050        let mut max_correlation: f64 = 0.0;
1051
1052        for i in 0..x.ncols() {
1053            for j in (i + 1)..x.ncols() {
1054                let col_i = x.column(i);
1055                let col_j = x.column(j);
1056                let correlation =
1057                    self.calculate_correlation(&col_i.to_owned(), &col_j.to_owned())?;
1058                max_correlation = max_correlation.max(correlation.abs());
1059            }
1060        }
1061
1062        // Convert correlation to approximate p-value
1063        // Higher correlation = lower p-value (less independence)
1064        let p_value = (1.0 - max_correlation).max(0.0);
1065        Ok(p_value)
1066    }
1067
1068    /// Calculate Pearson correlation coefficient
1069    fn calculate_correlation(&self, x: &Array1<f64>, y: &Array1<f64>) -> SklResult<f64> {
1070        if x.len() != y.len() || x.len() < 2 {
1071            return Ok(0.0);
1072        }
1073
1074        let mean_x = x.mean().unwrap_or(0.0);
1075        let mean_y = y.mean().unwrap_or(0.0);
1076
1077        let covariance = x
1078            .iter()
1079            .zip(y.iter())
1080            .map(|(&xi, &yi)| (xi - mean_x) * (yi - mean_y))
1081            .sum::<f64>()
1082            / (x.len() - 1) as f64;
1083
1084        let var_x = self.calculate_variance(x)?;
1085        let var_y = self.calculate_variance(y)?;
1086
1087        if var_x < 1e-10 || var_y < 1e-10 {
1088            return Ok(0.0); // No correlation if either variable is constant
1089        }
1090
1091        let correlation = covariance / (var_x.sqrt() * var_y.sqrt());
1092        Ok(correlation.max(-1.0).min(1.0)) // Clamp to [-1, 1]
1093    }
1094
1095    /// Calculate sample variance
1096    fn calculate_variance(&self, data: &Array1<f64>) -> SklResult<f64> {
1097        if data.len() < 2 {
1098            return Ok(0.0);
1099        }
1100
1101        let mean = data.mean().unwrap_or(0.0);
1102        let variance =
1103            data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (data.len() - 1) as f64;
1104
1105        Ok(variance)
1106    }
1107}
1108
1109// Default implementations for validator components
1110impl Default for DataValidator {
1111    fn default() -> Self {
1112        Self {
1113            check_missing_values: true,
1114            check_infinite_values: true,
1115            check_data_types: true,
1116            check_feature_scaling: false,
1117            check_distributions: false,
1118            max_missing_ratio: 0.05,
1119            check_duplicates: false,
1120            check_outliers: false,
1121            outlier_iqr_multiplier: 1.5,
1122        }
1123    }
1124}
1125
1126impl DataValidator {
1127    #[must_use]
1128    pub fn strict() -> Self {
1129        Self {
1130            check_missing_values: true,
1131            check_infinite_values: true,
1132            check_data_types: true,
1133            check_feature_scaling: true,
1134            check_distributions: true,
1135            max_missing_ratio: 0.01,
1136            check_duplicates: true,
1137            check_outliers: true,
1138            outlier_iqr_multiplier: 1.5,
1139        }
1140    }
1141
1142    #[must_use]
1143    pub fn basic() -> Self {
1144        Self {
1145            check_missing_values: true,
1146            check_infinite_values: true,
1147            check_data_types: false,
1148            check_feature_scaling: false,
1149            check_distributions: false,
1150            max_missing_ratio: 0.1,
1151            check_duplicates: false,
1152            check_outliers: false,
1153            outlier_iqr_multiplier: 2.0,
1154        }
1155    }
1156}
1157
1158impl Default for StructureValidator {
1159    fn default() -> Self {
1160        Self {
1161            check_component_compatibility: true,
1162            check_data_flow: true,
1163            check_parameter_consistency: false,
1164            check_circular_dependencies: true,
1165            check_resource_requirements: false,
1166            max_pipeline_depth: 10,
1167            max_components: 50,
1168        }
1169    }
1170}
1171
1172impl StructureValidator {
1173    #[must_use]
1174    pub fn strict() -> Self {
1175        Self {
1176            check_component_compatibility: true,
1177            check_data_flow: true,
1178            check_parameter_consistency: true,
1179            check_circular_dependencies: true,
1180            check_resource_requirements: true,
1181            max_pipeline_depth: 5,
1182            max_components: 20,
1183        }
1184    }
1185
1186    #[must_use]
1187    pub fn basic() -> Self {
1188        Self {
1189            check_component_compatibility: false,
1190            check_data_flow: false,
1191            check_parameter_consistency: false,
1192            check_circular_dependencies: false,
1193            check_resource_requirements: false,
1194            max_pipeline_depth: 20,
1195            max_components: 100,
1196        }
1197    }
1198}
1199
1200impl Default for StatisticalValidator {
1201    fn default() -> Self {
1202        Self {
1203            statistical_tests: false,
1204            check_data_leakage: false,
1205            check_feature_importance: false,
1206            check_prediction_consistency: false,
1207            min_sample_size: 30,
1208            alpha: 0.05,
1209            check_concept_drift: false,
1210        }
1211    }
1212}
1213
1214impl StatisticalValidator {
1215    #[must_use]
1216    pub fn strict() -> Self {
1217        Self {
1218            statistical_tests: true,
1219            check_data_leakage: true,
1220            check_feature_importance: true,
1221            check_prediction_consistency: true,
1222            min_sample_size: 100,
1223            alpha: 0.01,
1224            check_concept_drift: true,
1225        }
1226    }
1227
1228    #[must_use]
1229    pub fn disabled() -> Self {
1230        Self {
1231            statistical_tests: false,
1232            check_data_leakage: false,
1233            check_feature_importance: false,
1234            check_prediction_consistency: false,
1235            min_sample_size: 10,
1236            alpha: 0.1,
1237            check_concept_drift: false,
1238        }
1239    }
1240}
1241
1242impl Default for PerformanceValidator {
1243    fn default() -> Self {
1244        Self {
1245            check_training_time: false,
1246            check_prediction_time: false,
1247            check_memory_usage: false,
1248            max_training_time: 300.0,             // 5 minutes
1249            max_prediction_time_per_sample: 10.0, // 10ms
1250            max_memory_usage: 1000.0,             // 1GB
1251            check_scalability: false,
1252        }
1253    }
1254}
1255
1256impl PerformanceValidator {
1257    #[must_use]
1258    pub fn strict() -> Self {
1259        Self {
1260            check_training_time: true,
1261            check_prediction_time: true,
1262            check_memory_usage: true,
1263            max_training_time: 60.0,             // 1 minute
1264            max_prediction_time_per_sample: 1.0, // 1ms
1265            max_memory_usage: 500.0,             // 500MB
1266            check_scalability: true,
1267        }
1268    }
1269
1270    #[must_use]
1271    pub fn basic() -> Self {
1272        Self {
1273            check_training_time: false,
1274            check_prediction_time: false,
1275            check_memory_usage: false,
1276            max_training_time: 3600.0,             // 1 hour
1277            max_prediction_time_per_sample: 100.0, // 100ms
1278            max_memory_usage: 5000.0,              // 5GB
1279            check_scalability: false,
1280        }
1281    }
1282}
1283
1284impl Default for CrossValidator {
1285    fn default() -> Self {
1286        Self {
1287            cv_folds: 5,
1288            stratified: true,
1289            time_series_cv: false,
1290            leave_one_out: false,
1291            bootstrap: false,
1292            n_bootstrap: 100,
1293            random_state: Some(42),
1294        }
1295    }
1296}
1297
1298impl CrossValidator {
1299    #[must_use]
1300    pub fn fast() -> Self {
1301        Self {
1302            cv_folds: 3,
1303            stratified: false,
1304            time_series_cv: false,
1305            leave_one_out: false,
1306            bootstrap: false,
1307            n_bootstrap: 10,
1308            random_state: Some(42),
1309        }
1310    }
1311}
1312
1313impl Default for RobustnessTester {
1314    fn default() -> Self {
1315        Self {
1316            test_noise_robustness: false,
1317            test_missing_data_robustness: false,
1318            test_adversarial_robustness: false,
1319            test_distribution_shift: false,
1320            noise_levels: vec![0.01, 0.05, 0.1],
1321            missing_ratios: vec![0.01, 0.05, 0.1],
1322            n_robustness_tests: 10,
1323        }
1324    }
1325}
1326
1327impl RobustnessTester {
1328    #[must_use]
1329    pub fn comprehensive() -> Self {
1330        Self {
1331            test_noise_robustness: true,
1332            test_missing_data_robustness: true,
1333            test_adversarial_robustness: true,
1334            test_distribution_shift: true,
1335            noise_levels: vec![0.001, 0.01, 0.05, 0.1, 0.2],
1336            missing_ratios: vec![0.01, 0.05, 0.1, 0.2, 0.3],
1337            n_robustness_tests: 50,
1338        }
1339    }
1340
1341    #[must_use]
1342    pub fn disabled() -> Self {
1343        Self {
1344            test_noise_robustness: false,
1345            test_missing_data_robustness: false,
1346            test_adversarial_robustness: false,
1347            test_distribution_shift: false,
1348            noise_levels: vec![],
1349            missing_ratios: vec![],
1350            n_robustness_tests: 0,
1351        }
1352    }
1353}
1354
// Unit tests for the comprehensive validator. `super::*` already brings the ndarray
// view types into scope, so only the `array!` macro is imported explicitly.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_comprehensive_validator_creation() {
        let validator = ComprehensivePipelineValidator::new();
        assert!(!validator.verbose);

        let strict_validator = ComprehensivePipelineValidator::strict();
        assert!(strict_validator.verbose);

        let fast_validator = ComprehensivePipelineValidator::fast();
        assert!(!fast_validator.verbose);
    }

    #[test]
    fn test_data_validation() {
        let validator = ComprehensivePipelineValidator::new();
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![1.0, 2.0, 3.0];

        let mut messages = Vec::new();
        let result = validator
            .validate_data(&x.view(), Some(&y.view()), &mut messages)
            .unwrap();

        assert!(result.passed);
        assert_eq!(result.missing_values_count, 0);
        assert_eq!(result.infinite_values_count, 0);
    }

    #[test]
    fn test_data_validation_with_missing_values() {
        let validator = ComprehensivePipelineValidator::strict();
        let x = array![[1.0, f64::NAN], [3.0, 4.0], [5.0, 6.0]];

        let mut messages = Vec::new();
        let result = validator
            .validate_data(&x.view(), None, &mut messages)
            .unwrap();

        assert!(!result.passed);
        assert_eq!(result.missing_values_count, 1);
        assert!(!messages.is_empty());
    }

    #[test]
    fn test_outlier_detection() {
        let validator = ComprehensivePipelineValidator::new();
        let outlier_count = validator.count_outliers(
            &array![[1.0, 2.0], [1.1, 2.1], [1.0, 2.0], [100.0, 200.0]].view(),
            1.5,
        );

        assert!(outlier_count > 0);
    }

    #[test]
    fn test_duplicate_detection() {
        let validator = ComprehensivePipelineValidator::new();
        let duplicate_count =
            validator.count_duplicate_samples(&array![[1.0, 2.0], [3.0, 4.0], [1.0, 2.0]].view());

        assert_eq!(duplicate_count, 1);
    }

    #[test]
    fn test_data_quality_score() {
        let validator = ComprehensivePipelineValidator::new();

        let perfect_score = validator.calculate_data_quality_score(100, 0, 0, 0, 0);
        assert_eq!(perfect_score, 1.0);

        let imperfect_score = validator.calculate_data_quality_score(100, 10, 5, 2, 1);
        assert!(imperfect_score < 1.0);
        assert!(imperfect_score > 0.0);
    }

    #[test]
    fn test_variance_and_correlation_helpers() {
        let validator = ComprehensivePipelineValidator::new();
        let x = array![1.0, 2.0, 3.0];
        let y = array![2.0, 4.0, 6.0];

        // Sample variance of [1, 2, 3] is 2 / (3 - 1) = 1.
        assert_eq!(validator.calculate_variance(&x).unwrap(), 1.0);
        // y = 2x is perfectly correlated with x.
        assert_eq!(validator.calculate_correlation(&x, &y).unwrap(), 1.0);
        // A constant input carries no correlation information.
        assert_eq!(
            validator
                .calculate_correlation(&x, &array![5.0, 5.0, 5.0])
                .unwrap(),
            0.0
        );
    }

    #[test]
    fn test_small_samples_short_circuit() {
        let validator = ComprehensivePipelineValidator::new();
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        // Below their minimum sample sizes, these helpers return neutral results.
        assert!(!validator.detect_concept_drift(&x.view(), None).unwrap());
        assert_eq!(
            validator.calculate_prediction_consistency(&x.view()).unwrap(),
            1.0
        );
        assert_eq!(
            validator.shapiro_wilk_test(&array![1.0, 2.0]).unwrap(),
            1.0
        );
    }
}