rustkernel_ml/
ensemble.rs

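//! Ensemble method kernels.
//!
//! This module provides ensemble methods:
//! - Weighted majority voting (hard voting)
//! - Soft voting (probability averaging)
//! - Weighted averaging for regression ensembles
//! - Median aggregation for robust regression ensembles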
6
7use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
8use std::collections::HashMap;
9
10// ============================================================================
11// Ensemble Voting Kernel
12// ============================================================================
13
/// Strategy used to combine the outputs of the ensemble members.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum VotingStrategy {
    /// Hard voting: every member votes for a class and the majority class
    /// wins. This is the default strategy.
    #[default]
    Hard,
    /// Soft voting: class probabilities are averaged across members.
    Soft,
}
23
24/// Ensemble voting kernel.
25///
26/// Combines predictions from multiple classifiers using
27/// majority voting (hard) or probability averaging (soft).
28#[derive(Debug, Clone)]
29pub struct EnsembleVoting {
30    metadata: KernelMetadata,
31}
32
33impl Default for EnsembleVoting {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl EnsembleVoting {
40    /// Create a new ensemble voting kernel.
41    #[must_use]
42    pub fn new() -> Self {
43        Self {
44            metadata: KernelMetadata::batch("ml/ensemble-voting", Domain::StatisticalML)
45                .with_description("Weighted majority voting ensemble")
46                .with_throughput(100_000)
47                .with_latency_us(10.0),
48        }
49    }
50
51    /// Compute hard voting (majority vote) for classification.
52    ///
53    /// # Arguments
54    /// * `predictions` - Matrix of predictions (n_classifiers x n_samples)
55    /// * `weights` - Optional classifier weights (defaults to equal)
56    pub fn hard_vote(predictions: &[Vec<i32>], weights: Option<&[f64]>) -> Vec<i32> {
57        if predictions.is_empty() || predictions[0].is_empty() {
58            return Vec::new();
59        }
60
61        let n_classifiers = predictions.len();
62        let n_samples = predictions[0].len();
63
64        // Default to equal weights
65        let default_weights: Vec<f64> = vec![1.0 / n_classifiers as f64; n_classifiers];
66        let weights = weights.unwrap_or(&default_weights);
67
68        (0..n_samples)
69            .map(|i| {
70                let mut class_weights: HashMap<i32, f64> = HashMap::new();
71
72                for (j, pred) in predictions.iter().enumerate() {
73                    let class = pred[i];
74                    *class_weights.entry(class).or_insert(0.0) += weights[j];
75                }
76
77                *class_weights
78                    .iter()
79                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
80                    .map(|(class, _)| class)
81                    .unwrap_or(&0)
82            })
83            .collect()
84    }
85
86    /// Compute soft voting (probability averaging) for classification.
87    ///
88    /// # Arguments
89    /// * `probabilities` - 3D matrix: (n_classifiers, n_samples, n_classes)
90    ///   Outer vec: classifiers, middle vec: samples, inner vec: class probabilities
91    /// * `weights` - Optional classifier weights (defaults to equal)
92    pub fn soft_vote(probabilities: &[Vec<Vec<f64>>], weights: Option<&[f64]>) -> Vec<usize> {
93        if probabilities.is_empty() || probabilities[0].is_empty() {
94            return Vec::new();
95        }
96
97        let n_classifiers = probabilities.len();
98        let n_samples = probabilities[0].len();
99        let n_classes = probabilities[0][0].len();
100
101        // Default to equal weights
102        let default_weights: Vec<f64> = vec![1.0 / n_classifiers as f64; n_classifiers];
103        let weights = weights.unwrap_or(&default_weights);
104
105        (0..n_samples)
106            .map(|sample_idx| {
107                // Average probabilities across classifiers
108                let mut avg_probs = vec![0.0f64; n_classes];
109
110                for (classifier_idx, probs) in probabilities.iter().enumerate() {
111                    for (class_idx, &prob) in probs[sample_idx].iter().enumerate() {
112                        avg_probs[class_idx] += weights[classifier_idx] * prob;
113                    }
114                }
115
116                // Return class with highest average probability
117                avg_probs
118                    .iter()
119                    .enumerate()
120                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
121                    .map(|(idx, _)| idx)
122                    .unwrap_or(0)
123            })
124            .collect()
125    }
126
127    /// Compute weighted average for regression ensemble.
128    ///
129    /// # Arguments
130    /// * `predictions` - Matrix of predictions (n_regressors x n_samples)
131    /// * `weights` - Optional regressor weights (defaults to equal)
132    pub fn weighted_average(predictions: &[Vec<f64>], weights: Option<&[f64]>) -> Vec<f64> {
133        if predictions.is_empty() || predictions[0].is_empty() {
134            return Vec::new();
135        }
136
137        let n_regressors = predictions.len();
138        let n_samples = predictions[0].len();
139
140        // Default to equal weights
141        let default_weights: Vec<f64> = vec![1.0; n_regressors];
142        let weights = weights.unwrap_or(&default_weights);
143        let weight_sum: f64 = weights.iter().sum();
144
145        (0..n_samples)
146            .map(|i| {
147                let weighted_sum: f64 = predictions
148                    .iter()
149                    .zip(weights.iter())
150                    .map(|(preds, &w)| preds[i] * w)
151                    .sum();
152                weighted_sum / weight_sum
153            })
154            .collect()
155    }
156
157    /// Compute median for robust regression ensemble.
158    ///
159    /// # Arguments
160    /// * `predictions` - Matrix of predictions (n_regressors x n_samples)
161    pub fn median_prediction(predictions: &[Vec<f64>]) -> Vec<f64> {
162        if predictions.is_empty() || predictions[0].is_empty() {
163            return Vec::new();
164        }
165
166        let n_samples = predictions[0].len();
167
168        (0..n_samples)
169            .map(|i| {
170                let mut values: Vec<f64> = predictions.iter().map(|p| p[i]).collect();
171                values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
172
173                let n = values.len();
174                if n % 2 == 0 {
175                    (values[n / 2 - 1] + values[n / 2]) / 2.0
176                } else {
177                    values[n / 2]
178                }
179            })
180            .collect()
181    }
182}
183
184impl GpuKernel for EnsembleVoting {
185    fn metadata(&self) -> &KernelMetadata {
186        &self.metadata
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// The kernel exposes the id and domain it was registered with.
    #[test]
    fn test_ensemble_voting_metadata() {
        let kernel = EnsembleVoting::new();
        assert_eq!(kernel.metadata().id, "ml/ensemble-voting");
        assert_eq!(kernel.metadata().domain, Domain::StatisticalML);
    }

    /// Equal-weight hard voting returns the majority class for every sample.
    #[test]
    fn test_hard_vote() {
        // 3 classifiers, 5 samples
        let predictions = vec![
            vec![0, 1, 0, 1, 0], // Classifier 1
            vec![0, 0, 0, 1, 1], // Classifier 2
            vec![1, 1, 0, 1, 0], // Classifier 3
        ];

        let result = EnsembleVoting::hard_vote(&predictions, None);

        // Majority votes: 0, 1, 0, 1, 0
        assert_eq!(result[0], 0); // 2 votes for 0, 1 vote for 1
        assert_eq!(result[1], 1); // 1 vote for 0, 2 votes for 1
        assert_eq!(result[2], 0); // 3 votes for 0
        assert_eq!(result[3], 1); // 3 votes for 1
        assert_eq!(result[4], 0); // 2 votes for 0, 1 vote for 1 (not a tie)
    }

    /// A heavier classifier outvotes a lighter one.
    #[test]
    fn test_hard_vote_weighted() {
        let predictions = vec![vec![0, 0, 0], vec![1, 1, 1]];

        // Give second classifier higher weight
        let weights = vec![0.3, 0.7];
        let result = EnsembleVoting::hard_vote(&predictions, Some(&weights));

        // Class 1 should win due to higher weight
        assert_eq!(result, vec![1, 1, 1]);
    }

    /// Soft voting picks the class with the highest averaged probability.
    #[test]
    fn test_soft_vote() {
        // 2 classifiers, 3 samples, 2 classes
        let probabilities = vec![
            // Classifier 1
            vec![
                vec![0.9, 0.1], // Sample 1: strongly class 0
                vec![0.4, 0.6], // Sample 2: slightly class 1
                vec![0.5, 0.5], // Sample 3: tied
            ],
            // Classifier 2
            vec![
                vec![0.8, 0.2], // Sample 1: strongly class 0
                vec![0.3, 0.7], // Sample 2: class 1
                vec![0.2, 0.8], // Sample 3: class 1
            ],
        ];

        let result = EnsembleVoting::soft_vote(&probabilities, None);

        assert_eq!(result[0], 0); // Average: [0.85, 0.15] -> class 0
        assert_eq!(result[1], 1); // Average: [0.35, 0.65] -> class 1
        assert_eq!(result[2], 1); // Average: [0.35, 0.65] -> class 1
    }

    /// With no explicit weights the result is the plain mean per sample.
    #[test]
    fn test_weighted_average() {
        let predictions = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 3.0, 4.0],
            vec![3.0, 4.0, 5.0],
        ];

        let result = EnsembleVoting::weighted_average(&predictions, None);

        // Equal weights: average = [2.0, 3.0, 4.0]
        assert!((result[0] - 2.0).abs() < 0.01);
        assert!((result[1] - 3.0).abs() < 0.01);
        assert!((result[2] - 4.0).abs() < 0.01);
    }

    /// The median ignores a single extreme prediction.
    #[test]
    fn test_median_prediction() {
        let predictions = vec![
            vec![1.0, 100.0, 3.0],
            vec![2.0, 2.0, 4.0],
            vec![3.0, 3.0, 5.0],
        ];

        let result = EnsembleVoting::median_prediction(&predictions);

        // Median is robust to outliers
        assert!((result[0] - 2.0).abs() < 0.01);
        assert!((result[1] - 3.0).abs() < 0.01); // 100 is outlier
        assert!((result[2] - 4.0).abs() < 0.01);
    }

    /// Every aggregation returns an empty result for empty input.
    #[test]
    fn test_empty_predictions() {
        let empty: Vec<Vec<i32>> = vec![];
        assert!(EnsembleVoting::hard_vote(&empty, None).is_empty());

        let empty_probs: Vec<Vec<Vec<f64>>> = vec![];
        assert!(EnsembleVoting::soft_vote(&empty_probs, None).is_empty());

        let empty_reg: Vec<Vec<f64>> = vec![];
        assert!(EnsembleVoting::weighted_average(&empty_reg, None).is_empty());
        assert!(EnsembleVoting::median_prediction(&empty_reg).is_empty());
    }
}