// sklears_utils/ensemble.rs

1//! Ensemble utilities for machine learning
2//!
3//! This module provides utilities for ensemble methods including bootstrap sampling,
4//! bagging, and ensemble combination strategies.
5
6use crate::{UtilsError, UtilsResult};
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::random::{Rng, SeedableRng};
10use std::collections::HashMap;
11
/// Bootstrap sample generator
///
/// Creates bootstrap samples (sampling with replacement) for ensemble methods.
#[derive(Clone, Debug)]
pub struct Bootstrap {
    // Number of indices to draw per bootstrap; None means "same as the
    // population size passed to `sample`".
    n_samples: Option<usize>,
    // Seed for reproducible sampling; None falls back to a fixed default
    // seed inside `sample`, so results are deterministic either way.
    random_state: Option<u64>,
}
20
21impl Bootstrap {
22    /// Create a new bootstrap sampler
23    ///
24    /// # Arguments
25    /// * `n_samples` - Number of samples to draw (None = same as input size)
26    /// * `random_state` - Random seed for reproducibility
27    pub fn new(n_samples: Option<usize>, random_state: Option<u64>) -> Self {
28        Self {
29            n_samples,
30            random_state,
31        }
32    }
33
34    /// Generate bootstrap sample indices
35    ///
36    /// # Arguments
37    /// * `n_population` - Size of the population to sample from
38    ///
39    /// # Returns
40    /// Tuple of (in-bag indices, out-of-bag indices)
41    pub fn sample(&self, n_population: usize) -> UtilsResult<(Vec<usize>, Vec<usize>)> {
42        if n_population == 0 {
43            return Err(UtilsError::InvalidParameter(
44                "Population size must be positive".to_string(),
45            ));
46        }
47
48        let n_samples = self.n_samples.unwrap_or(n_population);
49        let mut rng = self
50            .random_state
51            .map(StdRng::seed_from_u64)
52            .unwrap_or_else(|| StdRng::seed_from_u64(42));
53
54        // Generate in-bag samples
55        let mut in_bag = Vec::with_capacity(n_samples);
56        let mut in_bag_set = vec![false; n_population];
57
58        for _ in 0..n_samples {
59            let idx = rng.gen_range(0..n_population);
60            in_bag.push(idx);
61            in_bag_set[idx] = true;
62        }
63
64        // Collect out-of-bag samples
65        let out_of_bag: Vec<usize> = (0..n_population).filter(|&i| !in_bag_set[i]).collect();
66
67        Ok((in_bag, out_of_bag))
68    }
69
70    /// Generate multiple bootstrap samples
71    ///
72    /// # Arguments
73    /// * `n_population` - Size of the population
74    /// * `n_bootstraps` - Number of bootstrap samples to generate
75    ///
76    /// # Returns
77    /// Vector of (in-bag, out-of-bag) index pairs
78    pub fn sample_multiple(
79        &self,
80        n_population: usize,
81        n_bootstraps: usize,
82    ) -> UtilsResult<Vec<(Vec<usize>, Vec<usize>)>> {
83        let mut samples = Vec::with_capacity(n_bootstraps);
84
85        for i in 0..n_bootstraps {
86            // Create new sampler with different seed for each bootstrap
87            let seed = self.random_state.map(|s| s + i as u64);
88            let sampler = Bootstrap::new(self.n_samples, seed);
89            samples.push(sampler.sample(n_population)?);
90        }
91
92        Ok(samples)
93    }
94}
95
96impl Default for Bootstrap {
97    fn default() -> Self {
98        Self::new(None, Some(42))
99    }
100}
101
/// Bagging utility for creating bagged ensemble predictions
#[derive(Clone, Debug)]
pub struct BaggingPredictor {
    // How the per-estimator predictions are combined; see
    // `AggregationStrategy` for the available options.
    aggregation: AggregationStrategy,
}
107
/// Strategy for aggregating predictions from ensemble members
#[derive(Clone, Debug, PartialEq)]
pub enum AggregationStrategy {
    /// Average predictions (for regression)
    Mean,
    /// Median of predictions (robust to outliers)
    Median,
    /// Majority voting (for classification only; `aggregate_regression`
    /// rejects this variant)
    MajorityVote,
    /// Weighted average with given weights (weights must be supplied to
    /// `aggregate_regression`)
    WeightedMean,
}
120
121impl BaggingPredictor {
122    /// Create a new bagging predictor
123    pub fn new(aggregation: AggregationStrategy) -> Self {
124        Self { aggregation }
125    }
126
127    /// Aggregate regression predictions
128    ///
129    /// # Arguments
130    /// * `predictions` - Matrix of predictions (n_samples × n_estimators)
131    /// * `weights` - Optional weights for weighted averaging
132    ///
133    /// # Returns
134    /// Aggregated predictions for each sample
135    pub fn aggregate_regression(
136        &self,
137        predictions: &Array2<f64>,
138        weights: Option<&Array1<f64>>,
139    ) -> UtilsResult<Array1<f64>> {
140        if predictions.nrows() == 0 || predictions.ncols() == 0 {
141            return Err(UtilsError::InvalidParameter(
142                "Predictions array cannot be empty".to_string(),
143            ));
144        }
145
146        match &self.aggregation {
147            AggregationStrategy::Mean => Ok(predictions
148                .mean_axis(scirs2_core::ndarray::Axis(1))
149                .unwrap()),
150            AggregationStrategy::Median => {
151                let mut result = Array1::zeros(predictions.nrows());
152                for (i, row) in predictions.outer_iter().enumerate() {
153                    let mut sorted: Vec<f64> = row.to_vec();
154                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
155                    let mid = sorted.len() / 2;
156                    result[i] = if sorted.len() % 2 == 0 {
157                        (sorted[mid - 1] + sorted[mid]) / 2.0
158                    } else {
159                        sorted[mid]
160                    };
161                }
162                Ok(result)
163            }
164            AggregationStrategy::WeightedMean => {
165                let weights = weights.ok_or_else(|| {
166                    UtilsError::InvalidParameter("Weights required for weighted mean".to_string())
167                })?;
168
169                if weights.len() != predictions.ncols() {
170                    return Err(UtilsError::InvalidParameter(
171                        "Number of weights must match number of estimators".to_string(),
172                    ));
173                }
174
175                let weight_sum: f64 = weights.sum();
176                if weight_sum <= 0.0 {
177                    return Err(UtilsError::InvalidParameter(
178                        "Weight sum must be positive".to_string(),
179                    ));
180                }
181
182                let normalized_weights = weights / weight_sum;
183                Ok(predictions.dot(&normalized_weights))
184            }
185            AggregationStrategy::MajorityVote => Err(UtilsError::InvalidParameter(
186                "Use aggregate_classification for majority voting".to_string(),
187            )),
188        }
189    }
190
191    /// Aggregate classification predictions (voting)
192    ///
193    /// # Arguments
194    /// * `predictions` - Matrix of class predictions (n_samples × n_estimators)
195    ///
196    /// # Returns
197    /// Predicted class for each sample (majority vote)
198    pub fn aggregate_classification(
199        &self,
200        predictions: &Array2<usize>,
201    ) -> UtilsResult<Array1<usize>> {
202        if predictions.nrows() == 0 || predictions.ncols() == 0 {
203            return Err(UtilsError::InvalidParameter(
204                "Predictions array cannot be empty".to_string(),
205            ));
206        }
207
208        let mut result = Array1::zeros(predictions.nrows());
209
210        for (i, row) in predictions.outer_iter().enumerate() {
211            // Count votes for each class
212            let mut vote_counts: HashMap<usize, usize> = HashMap::new();
213            for &pred in row.iter() {
214                *vote_counts.entry(pred).or_insert(0) += 1;
215            }
216
217            // Find class with maximum votes
218            let (predicted_class, _) = vote_counts
219                .iter()
220                .max_by_key(|(_, &count)| count)
221                .ok_or_else(|| UtilsError::InvalidParameter("No votes found".to_string()))?;
222
223            result[i] = *predicted_class;
224        }
225
226        Ok(result)
227    }
228
229    /// Aggregate classification probabilities (soft voting)
230    ///
231    /// # Arguments
232    /// * `probabilities` - Array of probability matrices from each estimator
233    ///   `Vec<Array2>` where each Array2 is (n_samples × n_classes)
234    ///
235    /// # Returns
236    /// Averaged probability matrix (n_samples × n_classes)
237    pub fn aggregate_probabilities(
238        &self,
239        probabilities: &[Array2<f64>],
240    ) -> UtilsResult<Array2<f64>> {
241        if probabilities.is_empty() {
242            return Err(UtilsError::InvalidParameter(
243                "Probabilities array cannot be empty".to_string(),
244            ));
245        }
246
247        let (n_samples, n_classes) = probabilities[0].dim();
248
249        // Validate all arrays have same shape
250        for probs in probabilities.iter() {
251            if probs.dim() != (n_samples, n_classes) {
252                return Err(UtilsError::InvalidParameter(
253                    "All probability matrices must have the same shape".to_string(),
254                ));
255            }
256        }
257
258        // Average probabilities
259        let mut result = Array2::zeros((n_samples, n_classes));
260        for probs in probabilities {
261            result += probs;
262        }
263        result /= probabilities.len() as f64;
264
265        Ok(result)
266    }
267}
268
269impl Default for BaggingPredictor {
270    fn default() -> Self {
271        Self::new(AggregationStrategy::Mean)
272    }
273}
274
/// Out-of-bag score estimator
///
/// Estimates model performance using out-of-bag samples from bootstrap sampling.
// Unit struct: purely a namespace for the associated scoring functions below;
// it carries no state.
#[derive(Clone, Debug)]
pub struct OOBScoreEstimator;
280
281impl OOBScoreEstimator {
282    /// Compute OOB score for regression
283    ///
284    /// # Arguments
285    /// * `y_true` - True target values
286    /// * `oob_predictions` - OOB predictions for each sample
287    ///
288    /// # Returns
289    /// R² score on OOB samples
290    pub fn oob_score_regression(
291        y_true: &Array1<f64>,
292        oob_predictions: &Array1<f64>,
293    ) -> UtilsResult<f64> {
294        if y_true.len() != oob_predictions.len() {
295            return Err(UtilsError::InvalidParameter(
296                "y_true and predictions must have same length".to_string(),
297            ));
298        }
299
300        if y_true.is_empty() {
301            return Err(UtilsError::InvalidParameter(
302                "Cannot compute score on empty array".to_string(),
303            ));
304        }
305
306        // Compute R² score
307        let y_mean = y_true.mean().unwrap();
308        let ss_tot: f64 = y_true.iter().map(|&y| (y - y_mean).powi(2)).sum();
309        let ss_res: f64 = y_true
310            .iter()
311            .zip(oob_predictions.iter())
312            .map(|(&y, &pred)| (y - pred).powi(2))
313            .sum();
314
315        if ss_tot <= 0.0 {
316            Ok(0.0)
317        } else {
318            Ok(1.0 - ss_res / ss_tot)
319        }
320    }
321
322    /// Compute OOB accuracy for classification
323    ///
324    /// # Arguments
325    /// * `y_true` - True class labels
326    /// * `oob_predictions` - OOB predicted class labels
327    ///
328    /// # Returns
329    /// Accuracy score on OOB samples
330    pub fn oob_accuracy(
331        y_true: &Array1<usize>,
332        oob_predictions: &Array1<usize>,
333    ) -> UtilsResult<f64> {
334        if y_true.len() != oob_predictions.len() {
335            return Err(UtilsError::InvalidParameter(
336                "y_true and predictions must have same length".to_string(),
337            ));
338        }
339
340        if y_true.is_empty() {
341            return Err(UtilsError::InvalidParameter(
342                "Cannot compute score on empty array".to_string(),
343            ));
344        }
345
346        let correct: usize = y_true
347            .iter()
348            .zip(oob_predictions.iter())
349            .filter(|(&y, &pred)| y == pred)
350            .count();
351
352        Ok(correct as f64 / y_true.len() as f64)
353    }
354}
355
/// Stacking ensemble utilities
// Unit struct: a namespace for the cross-validation fold helpers below.
#[derive(Clone, Debug)]
pub struct StackingHelper;
359
360impl StackingHelper {
361    /// Generate cross-validated predictions for stacking
362    ///
363    /// # Arguments
364    /// * `n_samples` - Number of samples
365    /// * `n_folds` - Number of cross-validation folds
366    /// * `random_state` - Random seed
367    ///
368    /// # Returns
369    /// Vector of (train_indices, test_indices) for each fold
370    pub fn generate_cv_folds(
371        n_samples: usize,
372        n_folds: usize,
373        random_state: Option<u64>,
374    ) -> UtilsResult<Vec<(Vec<usize>, Vec<usize>)>> {
375        if n_folds < 2 {
376            return Err(UtilsError::InvalidParameter(
377                "n_folds must be at least 2".to_string(),
378            ));
379        }
380
381        if n_samples < n_folds {
382            return Err(UtilsError::InvalidParameter(
383                "n_samples must be >= n_folds".to_string(),
384            ));
385        }
386
387        // Create shuffled indices
388        let mut indices: Vec<usize> = (0..n_samples).collect();
389        let mut rng = random_state
390            .map(StdRng::seed_from_u64)
391            .unwrap_or_else(|| StdRng::seed_from_u64(42));
392
393        // Fisher-Yates shuffle
394        for i in (1..indices.len()).rev() {
395            let j = rng.gen_range(0..=i);
396            indices.swap(i, j);
397        }
398
399        // Distribute indices into folds
400        let fold_sizes = Self::compute_fold_sizes(n_samples, n_folds);
401        let mut folds = Vec::with_capacity(n_folds);
402        let mut start = 0;
403
404        for size in fold_sizes {
405            let test_indices = indices[start..start + size].to_vec();
406            let train_indices: Vec<usize> = indices
407                .iter()
408                .enumerate()
409                .filter(|(i, _)| *i < start || *i >= start + size)
410                .map(|(_, &idx)| idx)
411                .collect();
412
413            folds.push((train_indices, test_indices));
414            start += size;
415        }
416
417        Ok(folds)
418    }
419
420    fn compute_fold_sizes(n_samples: usize, n_folds: usize) -> Vec<usize> {
421        let base_size = n_samples / n_folds;
422        let remainder = n_samples % n_folds;
423
424        (0..n_folds)
425            .map(|i| {
426                if i < remainder {
427                    base_size + 1
428                } else {
429                    base_size
430                }
431            })
432            .collect()
433    }
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    // With a fixed seed (42) the draw is deterministic, so the loose
    // bounds on the out-of-bag count below are stable across runs.
    #[test]
    fn test_bootstrap_sample() {
        let bootstrap = Bootstrap::new(Some(10), Some(42));
        let (in_bag, out_of_bag) = bootstrap.sample(10).unwrap();

        assert_eq!(in_bag.len(), 10);
        assert!(out_of_bag.len() > 0); // Typically ~37% are OOB
        assert!(out_of_bag.len() < 10);

        // Check all indices are valid
        for &idx in &in_bag {
            assert!(idx < 10);
        }
        for &idx in &out_of_bag {
            assert!(idx < 10);
        }
    }

    // n_samples = None: each bootstrap draws as many indices as the population.
    #[test]
    fn test_bootstrap_multiple() {
        let bootstrap = Bootstrap::new(None, Some(42));
        let samples = bootstrap.sample_multiple(10, 5).unwrap();

        assert_eq!(samples.len(), 5);

        // Check each sample is valid
        for (in_bag, out_of_bag) in &samples {
            assert_eq!(in_bag.len(), 10);
            assert!(out_of_bag.len() <= 10);
        }
    }

    #[test]
    fn test_bagging_mean_aggregation() {
        let predictor = BaggingPredictor::new(AggregationStrategy::Mean);

        // 3 samples, 4 estimators
        let predictions = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 2.0, 2.0, 2.0],
            [1.0, 3.0, 2.0, 4.0]
        ];

        let result = predictor.aggregate_regression(&predictions, None).unwrap();

        assert_abs_diff_eq!(result[0], 2.5, epsilon = 1e-10);
        assert_abs_diff_eq!(result[1], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[2], 2.5, epsilon = 1e-10);
    }

    // Exercises both the even (average of two middle values) and odd
    // (single middle value) code paths of the median aggregation.
    #[test]
    fn test_bagging_median_aggregation() {
        let predictor = BaggingPredictor::new(AggregationStrategy::Median);

        // Test with 4 estimators per sample
        let predictions = array![
            [1.0, 2.0, 100.0, 3.0], // Median = 2.5 (robust to outlier 100)
            [1.0, 2.0, 3.0, 4.0]    // Median = 2.5
        ];

        let result = predictor.aggregate_regression(&predictions, None).unwrap();

        assert_abs_diff_eq!(result[0], 2.5, epsilon = 1e-10);
        assert_abs_diff_eq!(result[1], 2.5, epsilon = 1e-10);

        // Test with 5 estimators (odd number)
        let predictions2 = array![
            [1.0, 2.0, 3.0, 4.0, 5.0],    // Median = 3.0
            [10.0, 1.0, 2.0, 3.0, 100.0]  // Median = 3.0 (robust to outliers)
        ];

        let result2 = predictor.aggregate_regression(&predictions2, None).unwrap();
        assert_abs_diff_eq!(result2[0], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result2[1], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_bagging_weighted_mean() {
        let predictor = BaggingPredictor::new(AggregationStrategy::WeightedMean);

        let predictions = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let weights = array![0.5, 0.3, 0.2]; // Sum = 1.0

        let result = predictor
            .aggregate_regression(&predictions, Some(&weights))
            .unwrap();

        // Sample 0: 1.0*0.5 + 2.0*0.3 + 3.0*0.2 = 0.5 + 0.6 + 0.6 = 1.7
        assert_abs_diff_eq!(result[0], 1.7, epsilon = 1e-10);
        // Sample 1: 4.0*0.5 + 5.0*0.3 + 6.0*0.2 = 2.0 + 1.5 + 1.2 = 4.7
        assert_abs_diff_eq!(result[1], 4.7, epsilon = 1e-10);
    }

    // All rows have a strict majority, so no tie-breaking is involved.
    #[test]
    fn test_majority_vote() {
        let predictor = BaggingPredictor::new(AggregationStrategy::MajorityVote);

        let predictions = array![
            [0, 0, 1, 0, 0], // Majority: 0 (4 votes)
            [1, 1, 0, 1, 1], // Majority: 1 (4 votes)
            [2, 2, 2, 0, 1]  // Majority: 2 (3 votes)
        ];

        let result = predictor.aggregate_classification(&predictions).unwrap();

        assert_eq!(result[0], 0);
        assert_eq!(result[1], 1);
        assert_eq!(result[2], 2);
    }

    // Soft voting: result should be the element-wise average of the
    // two per-estimator probability matrices.
    #[test]
    fn test_aggregate_probabilities() {
        let predictor = BaggingPredictor::default();

        let probs1 = array![[0.8, 0.2], [0.3, 0.7]];
        let probs2 = array![[0.6, 0.4], [0.4, 0.6]];

        let result = predictor
            .aggregate_probabilities(&[probs1, probs2])
            .unwrap();

        assert_abs_diff_eq!(result[[0, 0]], 0.7, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[0, 1]], 0.3, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[1, 0]], 0.35, epsilon = 1e-10);
        assert_abs_diff_eq!(result[[1, 1]], 0.65, epsilon = 1e-10);
    }

    #[test]
    fn test_oob_score_regression() {
        let y_true = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y_pred = array![1.1, 1.9, 3.1, 3.9, 5.1];

        let score = OOBScoreEstimator::oob_score_regression(&y_true, &y_pred).unwrap();

        // Should be close to 1.0 (perfect predictions)
        assert!(score > 0.95);
    }

    #[test]
    fn test_oob_accuracy() {
        let y_true = array![0, 1, 2, 0, 1];
        let y_pred = array![0, 1, 2, 0, 2]; // 4/5 correct

        let accuracy = OOBScoreEstimator::oob_accuracy(&y_true, &y_pred).unwrap();

        assert_abs_diff_eq!(accuracy, 0.8, epsilon = 1e-10);
    }

    // Fold invariants: 3 folds, every fold non-empty, train+test always
    // covers all samples, and the test sets partition 0..10 exactly.
    #[test]
    fn test_stacking_cv_folds() {
        let folds = StackingHelper::generate_cv_folds(10, 3, Some(42)).unwrap();

        assert_eq!(folds.len(), 3);

        // Check all samples are covered exactly once in test sets
        let mut all_test_indices: Vec<usize> = Vec::new();
        for (train, test) in &folds {
            assert!(train.len() > 0);
            assert!(test.len() > 0);
            assert_eq!(train.len() + test.len(), 10);
            all_test_indices.extend(test);
        }

        all_test_indices.sort_unstable();
        assert_eq!(all_test_indices, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }
}