1use anofox_ml_core::{Fit, Float, Predict, Result, RustMlError};
7use ndarray::{Array1, Array2};
8use std::collections::HashMap;
9
10struct NamedEstimator<F: Float> {
12 name: String,
13 estimator: Box<dyn FitPredictClone<F>>,
14}
15
16trait FitPredictClone<F: Float>: Send + Sync {
18 fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredictBox<F>>>;
19}
20
21trait PredictBox<F: Float>: Send + Sync {
22 fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>>;
23}
24
25impl<F, T> FitPredictClone<F> for T
27where
28 F: Float,
29 T: Fit<F> + Send + Sync,
30 T::Fitted: Predict<F> + Send + Sync + 'static,
31{
32 fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredictBox<F>>> {
33 let fitted = Fit::fit(self, x, y)?;
34 Ok(Box::new(fitted))
35 }
36}
37
38impl<F, T> PredictBox<F> for T
39where
40 F: Float,
41 T: Predict<F> + Send + Sync,
42{
43 fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>> {
44 self.predict(x)
45 }
46}
47
48pub struct VotingClassifier<F: Float> {
50 estimators: Vec<NamedEstimator<F>>,
51}
52
53impl<F: Float> VotingClassifier<F> {
54 pub fn new() -> Self {
56 Self {
57 estimators: Vec::new(),
58 }
59 }
60
61 pub fn push<T>(mut self, name: impl Into<String>, estimator: T) -> Self
63 where
64 T: Fit<F> + Send + Sync + 'static,
65 T::Fitted: Predict<F> + Send + Sync + 'static,
66 {
67 self.estimators.push(NamedEstimator {
68 name: name.into(),
69 estimator: Box::new(estimator),
70 });
71 self
72 }
73}
74
75pub struct FittedVotingClassifier<F: Float> {
77 fitted_models: Vec<(String, Box<dyn PredictBox<F>>)>,
78 n_features: usize,
79}
80
81impl<F: Float> FittedVotingClassifier<F> {
82 pub fn estimator_names(&self) -> Vec<&str> {
84 self.fitted_models.iter().map(|(n, _)| n.as_str()).collect()
85 }
86
87 pub fn score(&self, x: &Array2<F>, y: &Array1<F>) -> Result<f64> {
89 let preds = self.predict(x)?;
90 let n = y.len();
91 let correct = preds
92 .iter()
93 .zip(y.iter())
94 .filter(|(&p, &t)| (p - t).abs() < F::from_f64(1e-9).unwrap())
95 .count();
96 Ok(correct as f64 / n as f64)
97 }
98}
99
100impl<F: Float + 'static> Fit<F> for VotingClassifier<F> {
101 type Fitted = FittedVotingClassifier<F>;
102
103 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
104 if self.estimators.is_empty() {
105 return Err(RustMlError::InvalidParameter(
106 "VotingClassifier needs at least one estimator".into(),
107 ));
108 }
109 if x.nrows() != y.len() {
110 return Err(RustMlError::ShapeMismatch(format!(
111 "X has {} rows but y has {} elements",
112 x.nrows(),
113 y.len()
114 )));
115 }
116 if x.is_empty() {
117 return Err(RustMlError::EmptyInput("training data is empty".into()));
118 }
119
120 let mut fitted_models = Vec::with_capacity(self.estimators.len());
121 for est in &self.estimators {
122 let fitted = est.estimator.fit_box(x, y)?;
123 fitted_models.push((est.name.clone(), fitted));
124 }
125
126 Ok(FittedVotingClassifier {
127 fitted_models,
128 n_features: x.ncols(),
129 })
130 }
131}
132
133impl<F: Float> Predict<F> for FittedVotingClassifier<F> {
134 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
135 if x.ncols() != self.n_features {
136 return Err(RustMlError::ShapeMismatch(format!(
137 "expected {} features, got {}",
138 self.n_features,
139 x.ncols()
140 )));
141 }
142
143 let n = x.nrows();
144 let all_preds: Vec<Array1<F>> = self
145 .fitted_models
146 .iter()
147 .map(|(_, model)| model.predict_box(x))
148 .collect::<Result<Vec<_>>>()?;
149
150 let mut result = Array1::zeros(n);
151 for i in 0..n {
152 let mut votes: HashMap<u64, (F, usize)> = HashMap::new();
153 for preds in &all_preds {
154 let key = preds[i].to_f64().unwrap().to_bits();
155 votes
156 .entry(key)
157 .and_modify(|e| e.1 += 1)
158 .or_insert((preds[i], 1));
159 }
160 result[i] = votes
161 .into_values()
162 .max_by_key(|&(_, count)| count)
163 .unwrap()
164 .0;
165 }
166
167 Ok(result)
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use anofox_ml_trees::DecisionTreeClassifier;
    use ndarray::array;

    /// Three trees of different depth must reproduce the training labels.
    #[test]
    fn test_voting_classifier_basic() {
        let features = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [3.0, 0.0],
            [10.0, 1.0],
            [11.0, 1.0],
            [12.0, 1.0]
        ];
        let labels = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];

        let ensemble = VotingClassifier::new()
            .push(
                "tree1",
                DecisionTreeClassifier {
                    max_depth: Some(3),
                    ..Default::default()
                },
            )
            .push(
                "tree2",
                DecisionTreeClassifier {
                    max_depth: Some(2),
                    ..Default::default()
                },
            )
            .push(
                "tree3",
                DecisionTreeClassifier {
                    max_depth: Some(5),
                    ..Default::default()
                },
            );

        let model: FittedVotingClassifier<f64> = ensemble.fit(&features, &labels).unwrap();
        let predicted = model.predict(&features).unwrap();

        assert!(predicted
            .iter()
            .zip(labels.iter())
            .all(|(p, t)| (p - t).abs() < 1e-10));
    }

    /// Estimator names are reported in the order they were pushed.
    #[test]
    fn test_voting_classifier_names() {
        let features = array![[1.0], [2.0], [3.0], [4.0]];
        let labels = array![0.0, 0.0, 1.0, 1.0];

        let ensemble = VotingClassifier::new()
            .push("a", DecisionTreeClassifier::default())
            .push("b", DecisionTreeClassifier::default());

        let model: FittedVotingClassifier<f64> = ensemble.fit(&features, &labels).unwrap();
        let names = model.estimator_names();
        assert_eq!(names, vec!["a", "b"]);
    }

    /// Accuracy on the training data is at least one half.
    #[test]
    fn test_voting_classifier_score() {
        let features = array![[1.0, 0.0], [2.0, 0.0], [10.0, 1.0], [11.0, 1.0]];
        let labels = array![0.0, 0.0, 1.0, 1.0];

        let ensemble = VotingClassifier::new()
            .push("t1", DecisionTreeClassifier::default())
            .push("t2", DecisionTreeClassifier::default());

        let model: FittedVotingClassifier<f64> = ensemble.fit(&features, &labels).unwrap();
        let accuracy = model.score(&features, &labels).unwrap();
        assert!(accuracy >= 0.5);
    }

    /// Fitting an ensemble with no estimators is an error.
    #[test]
    fn test_voting_classifier_empty_error() {
        let features = array![[1.0], [2.0]];
        let labels = array![0.0, 1.0];
        let ensemble = VotingClassifier::<f64>::new();
        let outcome = ensemble.fit(&features, &labels);
        assert!(outcome.is_err());
    }

    /// Two rows of features against three labels must be rejected.
    #[test]
    fn test_voting_classifier_shape_mismatch() {
        let features = array![[1.0], [2.0]];
        let labels = array![0.0, 1.0, 2.0];
        let ensemble = VotingClassifier::new().push("t", DecisionTreeClassifier::default());
        let outcome = Fit::<f64>::fit(&ensemble, &features, &labels);
        assert!(outcome.is_err());
    }
}