// sklears_preprocessing/cross_validation.rs
1//! Cross-Validation Utilities for Preprocessing
2//!
3//! Provides cross-validation support for preprocessing parameter tuning,
4//! including grid search and random search for optimal preprocessing parameters.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::random::essentials::Uniform;
8use scirs2_core::random::{seeded_rng, Distribution};
9use sklears_core::prelude::SklearsError;
10use std::collections::HashMap;
11
/// K-Fold cross-validation splitter
///
/// Partitions `0..n_samples` into `n_splits` folds; each fold serves once as
/// the test set while the remaining indices form the train set.
#[derive(Debug, Clone)]
pub struct KFold {
    /// Number of folds
    pub n_splits: usize,
    /// Whether to shuffle data before splitting
    pub shuffle: bool,
    /// Random seed for shuffling; `None` falls back to a time-based seed,
    /// making shuffled splits non-deterministic across runs
    pub random_state: Option<u64>,
}
22
23impl KFold {
24    /// Create a new K-Fold splitter
25    pub fn new(n_splits: usize, shuffle: bool, random_state: Option<u64>) -> Self {
26        Self {
27            n_splits,
28            shuffle,
29            random_state,
30        }
31    }
32
33    /// Generate train/test splits
34    pub fn split(&self, n_samples: usize) -> Result<Vec<(Vec<usize>, Vec<usize>)>, SklearsError> {
35        if n_samples < self.n_splits {
36            return Err(SklearsError::InvalidInput(format!(
37                "Cannot split {} samples into {} folds",
38                n_samples, self.n_splits
39            )));
40        }
41
42        let mut indices: Vec<usize> = (0..n_samples).collect();
43
44        if self.shuffle {
45            use std::time::{SystemTime, UNIX_EPOCH};
46
47            let seed = self.random_state.unwrap_or_else(|| {
48                SystemTime::now()
49                    .duration_since(UNIX_EPOCH)
50                    .expect("operation should succeed")
51                    .as_secs()
52            });
53
54            let mut rng = seeded_rng(seed);
55
56            // Fisher-Yates shuffle
57            for i in (1..indices.len()).rev() {
58                let uniform = Uniform::new(0, i + 1).expect("operation should succeed");
59                let j = uniform.sample(&mut rng);
60                indices.swap(i, j);
61            }
62        }
63
64        let fold_size = n_samples / self.n_splits;
65        let mut splits = Vec::new();
66
67        for fold_idx in 0..self.n_splits {
68            let test_start = fold_idx * fold_size;
69            let test_end = if fold_idx == self.n_splits - 1 {
70                n_samples
71            } else {
72                (fold_idx + 1) * fold_size
73            };
74
75            let test_indices: Vec<usize> = indices[test_start..test_end].to_vec();
76            let train_indices: Vec<usize> = indices[..test_start]
77                .iter()
78                .chain(&indices[test_end..])
79                .copied()
80                .collect();
81
82            splits.push((train_indices, test_indices));
83        }
84
85        Ok(splits)
86    }
87}
88
/// Stratified K-Fold cross-validation splitter
///
/// Like `KFold`, but splits each class's indices separately so every fold's
/// test set approximately preserves the overall class distribution.
#[derive(Debug, Clone)]
pub struct StratifiedKFold {
    /// Number of folds
    pub n_splits: usize,
    /// Whether to shuffle data (within each class) before splitting
    pub shuffle: bool,
    /// Random seed for shuffling; `None` falls back to a time-based seed
    pub random_state: Option<u64>,
}
99
100impl StratifiedKFold {
101    /// Create a new Stratified K-Fold splitter
102    pub fn new(n_splits: usize, shuffle: bool, random_state: Option<u64>) -> Self {
103        Self {
104            n_splits,
105            shuffle,
106            random_state,
107        }
108    }
109
110    /// Generate stratified train/test splits
111    pub fn split(&self, y: &Array1<i32>) -> Result<Vec<(Vec<usize>, Vec<usize>)>, SklearsError> {
112        let n_samples = y.len();
113
114        if n_samples < self.n_splits {
115            return Err(SklearsError::InvalidInput(format!(
116                "Cannot split {} samples into {} folds",
117                n_samples, self.n_splits
118            )));
119        }
120
121        // Group indices by class
122        let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
123        for (idx, &label) in y.iter().enumerate() {
124            class_indices.entry(label).or_default().push(idx);
125        }
126
127        // Shuffle within each class
128        if self.shuffle {
129            use std::time::{SystemTime, UNIX_EPOCH};
130
131            let seed = self.random_state.unwrap_or_else(|| {
132                SystemTime::now()
133                    .duration_since(UNIX_EPOCH)
134                    .expect("operation should succeed")
135                    .as_secs()
136            });
137
138            let mut rng = seeded_rng(seed);
139
140            for indices in class_indices.values_mut() {
141                for i in (1..indices.len()).rev() {
142                    let uniform = Uniform::new(0, i + 1).expect("operation should succeed");
143                    let j = uniform.sample(&mut rng);
144                    indices.swap(i, j);
145                }
146            }
147        }
148
149        // Create splits maintaining class distribution
150        let mut splits: Vec<(Vec<usize>, Vec<usize>)> = vec![];
151
152        for fold_idx in 0..self.n_splits {
153            let mut train_indices = Vec::new();
154            let mut test_indices = Vec::new();
155
156            for indices in class_indices.values() {
157                let fold_size = indices.len() / self.n_splits;
158                let test_start = fold_idx * fold_size;
159                let test_end = if fold_idx == self.n_splits - 1 {
160                    indices.len()
161                } else {
162                    (fold_idx + 1) * fold_size
163                };
164
165                test_indices.extend(&indices[test_start..test_end]);
166                train_indices.extend(&indices[..test_start]);
167                train_indices.extend(&indices[test_end..]);
168            }
169
170            splits.push((train_indices, test_indices));
171        }
172
173        Ok(splits)
174    }
175}
176
/// Cross-validation score result
///
/// Aggregates the per-fold scores produced by a cross-validation run.
#[derive(Debug, Clone)]
pub struct CVScore {
    /// Mean score across folds
    pub mean: f64,
    /// Standard deviation of scores across folds
    pub std: f64,
    /// Individual fold scores, one entry per fold
    pub scores: Vec<f64>,
}
187
/// Grid search parameter specification
///
/// Maps each parameter name to the list of candidate values to try;
/// [`ParameterGrid::combinations`] expands these into the full Cartesian
/// product.
#[derive(Debug, Clone)]
pub struct ParameterGrid {
    parameters: HashMap<String, Vec<f64>>,
}

impl ParameterGrid {
    /// Create a new, empty parameter grid
    pub fn new() -> Self {
        Self {
            parameters: HashMap::new(),
        }
    }

    /// Add a parameter with its candidate values (builder style)
    pub fn add_parameter(mut self, name: String, values: Vec<f64>) -> Self {
        self.parameters.insert(name, values);
        self
    }

    /// Generate all parameter combinations (the Cartesian product).
    ///
    /// An empty grid yields a single empty combination. Note that
    /// `n_combinations` reports 0 for an empty grid, so the two methods
    /// deliberately disagree in that one case.
    pub fn combinations(&self) -> Vec<HashMap<String, f64>> {
        // Start from one empty assignment and extend it with every value of
        // each parameter in turn.
        let mut combos: Vec<HashMap<String, f64>> = vec![HashMap::new()];

        for (name, values) in &self.parameters {
            combos = combos
                .iter()
                .flat_map(|partial| {
                    values.iter().map(move |&value| {
                        let mut extended = partial.clone();
                        extended.insert(name.clone(), value);
                        extended
                    })
                })
                .collect();
        }

        combos
    }

    /// Get total number of combinations (0 for an empty grid)
    pub fn n_combinations(&self) -> usize {
        if self.parameters.is_empty() {
            0
        } else {
            self.parameters.values().map(|v| v.len()).product()
        }
    }
}

impl Default for ParameterGrid {
    fn default() -> Self {
        Self::new()
    }
}
248
249/// Random search parameter specification
250#[derive(Debug, Clone)]
251pub struct ParameterDistribution {
252    parameters: HashMap<String, (f64, f64)>, // (min, max) for uniform distribution
253}
254
255impl ParameterDistribution {
256    /// Create a new parameter distribution
257    pub fn new() -> Self {
258        Self {
259            parameters: HashMap::new(),
260        }
261    }
262
263    /// Add a parameter with range
264    pub fn add_parameter(mut self, name: String, min: f64, max: f64) -> Self {
265        self.parameters.insert(name, (min, max));
266        self
267    }
268
269    /// Sample random parameters
270    pub fn sample(&self, n_iter: usize, random_state: Option<u64>) -> Vec<HashMap<String, f64>> {
271        use std::time::{SystemTime, UNIX_EPOCH};
272
273        let seed = random_state.unwrap_or_else(|| {
274            SystemTime::now()
275                .duration_since(UNIX_EPOCH)
276                .expect("operation should succeed")
277                .as_secs()
278        });
279
280        let mut rng = seeded_rng(seed);
281
282        (0..n_iter)
283            .map(|_| {
284                self.parameters
285                    .iter()
286                    .map(|(name, &(min, max))| {
287                        let uniform =
288                            Uniform::new_inclusive(min, max).expect("operation should succeed");
289                        (name.clone(), uniform.sample(&mut rng))
290                    })
291                    .collect()
292            })
293            .collect()
294    }
295}
296
297impl Default for ParameterDistribution {
298    fn default() -> Self {
299        Self::new()
300    }
301}
302
/// Evaluation metric for preprocessing quality
pub trait PreprocessingMetric {
    /// Evaluate preprocessing quality by comparing the data before
    /// (`x_original`) and after (`x_transformed`) the transform. The metrics
    /// implemented in this module return values near 1.0 when the transform
    /// preserves the data.
    fn evaluate(&self, x_original: &Array2<f64>, x_transformed: &Array2<f64>) -> f64;
}
308
309/// Variance preservation metric
310pub struct VariancePreservationMetric;
311
312impl PreprocessingMetric for VariancePreservationMetric {
313    fn evaluate(&self, x_original: &Array2<f64>, x_transformed: &Array2<f64>) -> f64 {
314        let mut total_variance_ratio = 0.0;
315
316        for j in 0..x_original.ncols() {
317            let original_col = x_original.column(j);
318            let transformed_col = x_transformed.column(j);
319
320            let original_var = Self::compute_variance(original_col);
321            let transformed_var = Self::compute_variance(transformed_col);
322
323            if original_var > 1e-10 {
324                total_variance_ratio += transformed_var / original_var;
325            }
326        }
327
328        total_variance_ratio / x_original.ncols() as f64
329    }
330}
331
332impl VariancePreservationMetric {
333    fn compute_variance<'a, I>(values: I) -> f64
334    where
335        I: IntoIterator<Item = &'a f64>,
336    {
337        let vals: Vec<f64> = values
338            .into_iter()
339            .copied()
340            .filter(|v| !v.is_nan())
341            .collect();
342
343        if vals.is_empty() {
344            return 0.0;
345        }
346
347        let mean = vals.iter().sum::<f64>() / vals.len() as f64;
348        vals.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / vals.len() as f64
349    }
350}
351
352/// Information preservation metric (measures mutual information preservation)
353pub struct InformationPreservationMetric;
354
355impl PreprocessingMetric for InformationPreservationMetric {
356    fn evaluate(&self, x_original: &Array2<f64>, x_transformed: &Array2<f64>) -> f64 {
357        // Simplified: Use correlation as proxy for information preservation
358        let mut total_correlation = 0.0;
359        let mut count = 0;
360
361        for j in 0..x_original.ncols().min(x_transformed.ncols()) {
362            let corr = Self::compute_correlation(x_original, x_transformed, j);
363            if !corr.is_nan() {
364                total_correlation += corr.abs();
365                count += 1;
366            }
367        }
368
369        if count > 0 {
370            total_correlation / count as f64
371        } else {
372            0.0
373        }
374    }
375}
376
377impl InformationPreservationMetric {
378    fn compute_correlation(x1: &Array2<f64>, x2: &Array2<f64>, col_idx: usize) -> f64 {
379        let col1 = x1.column(col_idx);
380        let col2 = x2.column(col_idx);
381
382        let pairs: Vec<(f64, f64)> = col1
383            .iter()
384            .zip(col2.iter())
385            .filter(|(a, b)| !a.is_nan() && !b.is_nan())
386            .map(|(&a, &b)| (a, b))
387            .collect();
388
389        if pairs.len() < 2 {
390            return 0.0;
391        }
392
393        let mean1 = pairs.iter().map(|(a, _)| a).sum::<f64>() / pairs.len() as f64;
394        let mean2 = pairs.iter().map(|(_, b)| b).sum::<f64>() / pairs.len() as f64;
395
396        let mut cov = 0.0;
397        let mut var1 = 0.0;
398        let mut var2 = 0.0;
399
400        for (a, b) in &pairs {
401            let d1 = a - mean1;
402            let d2 = b - mean2;
403            cov += d1 * d2;
404            var1 += d1 * d1;
405            var2 += d2 * d2;
406        }
407
408        if var1 < 1e-10 || var2 < 1e-10 {
409            return 0.0;
410        }
411
412        cov / (var1 * var2).sqrt()
413    }
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::essentials::Normal;
    use scirs2_core::random::{seeded_rng, Distribution};

    /// Build an `nrows x ncols` matrix of seeded standard-normal draws, so
    /// fixtures are deterministic per seed.
    fn generate_test_data(nrows: usize, ncols: usize, seed: u64) -> Array2<f64> {
        let mut rng = seeded_rng(seed);
        let normal = Normal::new(0.0, 1.0).expect("operation should succeed");

        let data: Vec<f64> = (0..nrows * ncols)
            .map(|_| normal.sample(&mut rng))
            .collect();

        Array2::from_shape_vec((nrows, ncols), data).expect("shape and data length should match")
    }

    /// 100 samples into 5 folds: every fold is non-empty and train + test
    /// together cover all samples exactly once.
    #[test]
    fn test_kfold_split() {
        let kfold = KFold::new(5, false, Some(42));
        let splits = kfold.split(100).expect("operation should succeed");

        assert_eq!(splits.len(), 5);

        for (train, test) in &splits {
            assert!(train.len() > 0);
            assert!(test.len() > 0);
            assert_eq!(train.len() + test.len(), 100);
        }
    }

    /// A shuffled split (fixed seed) should differ from the unshuffled,
    /// in-order split. Deterministic because seed 42 is fixed.
    #[test]
    fn test_kfold_shuffle() {
        let kfold1 = KFold::new(3, true, Some(42));
        let splits1 = kfold1.split(30).expect("operation should succeed");

        let kfold2 = KFold::new(3, false, None);
        let splits2 = kfold2.split(30).expect("operation should succeed");

        // Shuffled and non-shuffled should be different
        let different = splits1[0].0 != splits2[0].0;
        assert!(different);
    }

    /// With 5 samples per class and 3 folds, every fold's test set should
    /// contain at least one sample of each class.
    #[test]
    fn test_stratified_kfold() {
        let y = Array1::from_vec(vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2]);

        let stratified = StratifiedKFold::new(3, false, Some(42));
        let splits = stratified.split(&y).expect("operation should succeed");

        assert_eq!(splits.len(), 3);

        // Check that each split maintains class distribution
        for (train_indices, test_indices) in &splits {
            let _train_classes: Vec<i32> = train_indices.iter().map(|&i| y[i]).collect();
            let test_classes: Vec<i32> = test_indices.iter().map(|&i| y[i]).collect();

            // Count classes in test set
            let test_0 = test_classes.iter().filter(|&&c| c == 0).count();
            let test_1 = test_classes.iter().filter(|&&c| c == 1).count();
            let test_2 = test_classes.iter().filter(|&&c| c == 2).count();

            // Each class should appear roughly equally
            assert!(test_0 > 0);
            assert!(test_1 > 0);
            assert!(test_2 > 0);
        }
    }

    /// 3 alpha values x 2 beta values = 6 combinations.
    #[test]
    fn test_parameter_grid() {
        let grid = ParameterGrid::new()
            .add_parameter("alpha".to_string(), vec![0.1, 1.0, 10.0])
            .add_parameter("beta".to_string(), vec![0.5, 1.5]);

        let combinations = grid.combinations();

        assert_eq!(combinations.len(), 6); // 3 * 2 = 6
        assert_eq!(grid.n_combinations(), 6);

        // Check that all combinations are present
        let has_alpha_0_1 = combinations.iter().any(|c| c.get("alpha") == Some(&0.1));
        assert!(has_alpha_0_1);
    }

    /// Every sampled value must land inside its declared [min, max] range.
    #[test]
    fn test_parameter_distribution() {
        let dist = ParameterDistribution::new()
            .add_parameter("alpha".to_string(), 0.0, 1.0)
            .add_parameter("beta".to_string(), 0.0, 10.0);

        let samples = dist.sample(10, Some(42));

        assert_eq!(samples.len(), 10);

        for sample in &samples {
            let alpha = sample.get("alpha").expect("sampling should succeed");
            let beta = sample.get("beta").expect("sampling should succeed");

            assert!(*alpha >= 0.0 && *alpha <= 1.0);
            assert!(*beta >= 0.0 && *beta <= 10.0);
        }
    }

    /// An identity "transform" preserves all variance, so the score is ~1.
    #[test]
    fn test_variance_preservation_metric() {
        let x_original = generate_test_data(100, 5, 42);
        let x_transformed = x_original.clone();

        let metric = VariancePreservationMetric;
        let score = metric.evaluate(&x_original, &x_transformed);

        // Same data should have score close to 1.0
        assert!((score - 1.0).abs() < 0.1);
    }

    /// An identity "transform" is perfectly correlated with the input.
    #[test]
    fn test_information_preservation_metric() {
        let x_original = generate_test_data(100, 5, 123);
        let x_transformed = x_original.clone();

        let metric = InformationPreservationMetric;
        let score = metric.evaluate(&x_original, &x_transformed);

        // Same data should have high correlation
        assert!(score > 0.9);
    }

    /// Fewer samples than folds must be rejected, not split.
    #[test]
    fn test_kfold_edge_case_small_dataset() {
        let kfold = KFold::new(5, false, Some(42));
        let result = kfold.split(3);

        assert!(result.is_err());
    }

    /// Pins the empty-grid behavior: combinations() yields one empty
    /// assignment, while n_combinations() reports 0.
    #[test]
    fn test_empty_parameter_grid() {
        let grid = ParameterGrid::new();
        let combinations = grid.combinations();

        assert_eq!(combinations.len(), 1);
        assert!(combinations[0].is_empty());
        assert_eq!(grid.n_combinations(), 0);
    }
}