Skip to main content

anofox_ml_ensemble/
voting_classifier.rs

1//! Voting classifier: combines predictions from multiple heterogeneous models.
2//!
//! Currently only hard voting (majority vote) is supported. Soft voting
//! (averaging predicted probabilities) needs `predict_proba` on the fitted
//! models and is not yet implemented.
5
6use anofox_ml_core::{Fit, Float, Predict, Result, RustMlError};
7use ndarray::{Array1, Array2};
8use std::collections::HashMap;
9
10/// A named estimator for the voting ensemble.
11struct NamedEstimator<F: Float> {
12    name: String,
13    estimator: Box<dyn FitPredictClone<F>>,
14}
15
16/// Internal trait combining Fit + Send + Sync for trait objects.
17trait FitPredictClone<F: Float>: Send + Sync {
18    fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredictBox<F>>>;
19}
20
21trait PredictBox<F: Float>: Send + Sync {
22    fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>>;
23}
24
25/// Blanket impl: any Fit+Predict type can be a voting member.
26impl<F, T> FitPredictClone<F> for T
27where
28    F: Float,
29    T: Fit<F> + Send + Sync,
30    T::Fitted: Predict<F> + Send + Sync + 'static,
31{
32    fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredictBox<F>>> {
33        let fitted = Fit::fit(self, x, y)?;
34        Ok(Box::new(fitted))
35    }
36}
37
38impl<F, T> PredictBox<F> for T
39where
40    F: Float,
41    T: Predict<F> + Send + Sync,
42{
43    fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>> {
44        self.predict(x)
45    }
46}
47
48/// Voting classifier that combines multiple models via majority vote.
49pub struct VotingClassifier<F: Float> {
50    estimators: Vec<NamedEstimator<F>>,
51}
52
53impl<F: Float> VotingClassifier<F> {
54    /// Create a new empty VotingClassifier.
55    pub fn new() -> Self {
56        Self {
57            estimators: Vec::new(),
58        }
59    }
60
61    /// Add a named estimator to the ensemble.
62    pub fn push<T>(mut self, name: impl Into<String>, estimator: T) -> Self
63    where
64        T: Fit<F> + Send + Sync + 'static,
65        T::Fitted: Predict<F> + Send + Sync + 'static,
66    {
67        self.estimators.push(NamedEstimator {
68            name: name.into(),
69            estimator: Box::new(estimator),
70        });
71        self
72    }
73}
74
75/// Fitted voting classifier.
76pub struct FittedVotingClassifier<F: Float> {
77    fitted_models: Vec<(String, Box<dyn PredictBox<F>>)>,
78    n_features: usize,
79}
80
81impl<F: Float> FittedVotingClassifier<F> {
82    /// Return the names of the constituent estimators.
83    pub fn estimator_names(&self) -> Vec<&str> {
84        self.fitted_models.iter().map(|(n, _)| n.as_str()).collect()
85    }
86
87    /// Compute classification accuracy on the given data.
88    pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<f64> {
89        let preds = self.predict(x)?;
90        let n = y.len();
91        let correct = preds
92            .iter()
93            .zip(y.iter())
94            .filter(|(&p, &t)| (p - t).abs() < F::from_f64(1e-9).unwrap())
95            .count();
96        Ok(correct as f64 / n as f64)
97    }
98}
99
100impl<F: Float + 'static> Fit<F> for VotingClassifier<F> {
101    type Fitted = FittedVotingClassifier<F>;
102
103    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
104        if self.estimators.is_empty() {
105            return Err(RustMlError::InvalidParameter(
106                "VotingClassifier needs at least one estimator".into(),
107            ));
108        }
109        if x.nrows() != y.len() {
110            return Err(RustMlError::ShapeMismatch(format!(
111                "X has {} rows but y has {} elements",
112                x.nrows(),
113                y.len()
114            )));
115        }
116        if x.is_empty() {
117            return Err(RustMlError::EmptyInput("training data is empty".into()));
118        }
119
120        let mut fitted_models = Vec::with_capacity(self.estimators.len());
121        for est in &self.estimators {
122            let fitted = est.estimator.fit_box(x, y)?;
123            fitted_models.push((est.name.clone(), fitted));
124        }
125
126        Ok(FittedVotingClassifier {
127            fitted_models,
128            n_features: x.ncols(),
129        })
130    }
131}
132
133impl<F: Float> Predict<F> for FittedVotingClassifier<F> {
134    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
135        if x.ncols() != self.n_features {
136            return Err(RustMlError::ShapeMismatch(format!(
137                "expected {} features, got {}",
138                self.n_features,
139                x.ncols()
140            )));
141        }
142
143        let n = x.nrows();
144        let all_preds: Vec<Array1<F>> = self
145            .fitted_models
146            .iter()
147            .map(|(_, model)| model.predict_box(x))
148            .collect::<Result<Vec<_>>>()?;
149
150        let mut result = Array1::zeros(n);
151        for i in 0..n {
152            let mut votes: HashMap<u64, (F, usize)> = HashMap::new();
153            for preds in &all_preds {
154                let key = preds[i].to_f64().unwrap().to_bits();
155                votes
156                    .entry(key)
157                    .and_modify(|e| e.1 += 1)
158                    .or_insert((preds[i], 1));
159            }
160            result[i] = votes
161                .into_values()
162                .max_by_key(|&(_, count)| count)
163                .unwrap()
164                .0;
165        }
166
167        Ok(result)
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use anofox_ml_trees::DecisionTreeClassifier;
    use ndarray::array;

    /// Three trees of different depth must reproduce the training labels of a
    /// cleanly separable data set.
    #[test]
    fn test_voting_classifier_basic() {
        let features = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let labels = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let first = DecisionTreeClassifier {
            max_depth: Some(3),
            ..Default::default()
        };
        let second = DecisionTreeClassifier {
            max_depth: Some(2),
            ..Default::default()
        };
        let third = DecisionTreeClassifier {
            max_depth: Some(5),
            ..Default::default()
        };
        let ensemble = VotingClassifier::new()
            .push("tree1", first)
            .push("tree2", second)
            .push("tree3", third);

        let fitted: FittedVotingClassifier<f64> = ensemble.fit(&features, &labels).unwrap();
        let predicted = fitted.predict(&features).unwrap();

        let all_match = predicted
            .iter()
            .zip(labels.iter())
            .all(|(p, t)| (p - t).abs() < 1e-10);
        assert!(all_match);
    }

    /// Estimator names are reported in insertion order.
    #[test]
    fn test_voting_classifier_names() {
        let features = array![[1.0], [2.0], [3.0], [4.0]];
        let labels = array![0.0, 0.0, 1.0, 1.0];

        let ensemble = VotingClassifier::new()
            .push("a", DecisionTreeClassifier::default())
            .push("b", DecisionTreeClassifier::default());

        let fitted: FittedVotingClassifier<f64> = ensemble.fit(&features, &labels).unwrap();
        let names = fitted.estimator_names();
        assert_eq!(names, vec!["a", "b"]);
    }

    /// Accuracy on the training data is at least chance level.
    #[test]
    fn test_voting_classifier_score() {
        let features = array![[1.0, 0.0], [2.0, 0.0], [10.0, 1.0], [11.0, 1.0]];
        let labels = array![0.0, 0.0, 1.0, 1.0];

        let ensemble = VotingClassifier::new()
            .push("t1", DecisionTreeClassifier::default())
            .push("t2", DecisionTreeClassifier::default());

        let fitted: FittedVotingClassifier<f64> = ensemble.fit(&features, &labels).unwrap();
        let accuracy = fitted.score(&features, &labels).unwrap();
        assert!(accuracy >= 0.5);
    }

    /// Fitting an ensemble without members is rejected.
    #[test]
    fn test_voting_classifier_empty_error() {
        let features = array![[1.0], [2.0]];
        let labels = array![0.0, 1.0];
        let ensemble = VotingClassifier::<f64>::new();
        let outcome = ensemble.fit(&features, &labels);
        assert!(outcome.is_err());
    }

    /// Mismatched sample counts between `x` and `y` are rejected.
    #[test]
    fn test_voting_classifier_shape_mismatch() {
        let features = array![[1.0], [2.0]];
        let labels = array![0.0, 1.0, 2.0];
        let ensemble = VotingClassifier::new().push("t", DecisionTreeClassifier::default());
        let outcome = Fit::<f64>::fit(&ensemble, &features, &labels);
        assert!(outcome.is_err());
    }
}