1use anofox_ml_core::{Fit, Float, Predict, PredictProba, Result, RustMlError};
10use ndarray::{Array1, Array2};
11
/// Selects which base-estimator method supplies stacked features: `Predict`
/// or `PredictProba`.
///
/// NOTE(review): `StackingClassifier` chooses this per estimator through
/// `push` / `push_proba` rather than through this enum; confirm where it is used.
///
/// `Eq` and `Hash` are derived (the enum is fieldless) so it can be used as a
/// set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StackMethod {
    /// Use the output of `Predict::predict`.
    Predict,
    /// Use the output of `PredictProba::predict_proba`.
    PredictProba,
}
25
26trait FitPredBox<F: Float>: Send + Sync {
27 fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredBox<F>>>;
28}
29
30trait FitProbaBox<F: Float>: Send + Sync {
31 fn fit_proba_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn ProbaBox<F>>>;
32}
33
34trait PredBox<F: Float>: Send + Sync {
35 fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>>;
36}
37
38trait ProbaBox<F: Float>: Send + Sync {
39 fn predict_proba_box(&self, x: &Array2<F>) -> Result<Array2<F>>;
40}
41
42impl<F, T> FitPredBox<F> for T
43where
44 F: Float,
45 T: Fit<F> + Send + Sync,
46 T::Fitted: Predict<F> + Send + Sync + 'static,
47{
48 fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredBox<F>>> {
49 let fitted = Fit::fit(self, x, y)?;
50 Ok(Box::new(fitted))
51 }
52}
53
54impl<F, T> PredBox<F> for T
55where
56 F: Float,
57 T: Predict<F> + Send + Sync,
58{
59 fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>> {
60 self.predict(x)
61 }
62}
63
/// Newtype used by `StackingClassifier::push_proba` to pick the
/// [`FitProbaBox`] implementation for the wrapped estimator.
struct ProbaWrap<T>(
    /// The wrapped, unfitted estimator.
    T,
);
66
67impl<F, T> FitProbaBox<F> for ProbaWrap<T>
68where
69 F: Float,
70 T: Fit<F> + Send + Sync,
71 T::Fitted: Predict<F> + PredictProba<F> + Send + Sync + 'static,
72{
73 fn fit_proba_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn ProbaBox<F>>> {
74 let fitted = Fit::fit(&self.0, x, y)?;
75 Ok(Box::new(fitted))
76 }
77}
78
79impl<F, T> ProbaBox<F> for T
80where
81 F: Float,
82 T: PredictProba<F> + Send + Sync,
83{
84 fn predict_proba_box(&self, x: &Array2<F>) -> Result<Array2<F>> {
85 self.predict_proba(x)
86 }
87}
88
89pub struct StackingClassifier<F: Float> {
95 base_estimators: Vec<(String, BaseEstimator<F>)>,
96 meta_estimator: Box<dyn FitPredBox<F>>,
97 cv_folds: usize,
98}
99
100enum BaseEstimator<F: Float> {
101 Predict(Box<dyn FitPredBox<F>>),
102 PredictProba(Box<dyn FitProbaBox<F>>),
103}
104
105impl<F: Float> StackingClassifier<F> {
106 pub fn new<M>(meta_estimator: M) -> Self
107 where
108 M: Fit<F> + Send + Sync + 'static,
109 M::Fitted: Predict<F> + Send + Sync + 'static,
110 {
111 Self {
112 base_estimators: Vec::new(),
113 meta_estimator: Box::new(meta_estimator),
114 cv_folds: 5,
115 }
116 }
117
118 pub fn push<T>(mut self, name: impl Into<String>, estimator: T) -> Self
120 where
121 T: Fit<F> + Send + Sync + 'static,
122 T::Fitted: Predict<F> + Send + Sync + 'static,
123 {
124 self.base_estimators
125 .push((name.into(), BaseEstimator::Predict(Box::new(estimator))));
126 self
127 }
128
129 pub fn push_proba<T>(mut self, name: impl Into<String>, estimator: T) -> Self
132 where
133 T: Fit<F> + Send + Sync + 'static,
134 T::Fitted: Predict<F> + PredictProba<F> + Send + Sync + 'static,
135 {
136 self.base_estimators.push((
137 name.into(),
138 BaseEstimator::PredictProba(Box::new(ProbaWrap(estimator))),
139 ));
140 self
141 }
142
143 pub fn with_cv_folds(mut self, k: usize) -> Self {
144 self.cv_folds = k;
145 self
146 }
147}
148
149pub struct FittedStackingClassifier<F: Float> {
150 fitted_base: Vec<(String, FittedBase<F>)>,
151 fitted_meta: Box<dyn PredBox<F>>,
152 n_features: usize,
153}
154
155enum FittedBase<F: Float> {
156 Predict(Box<dyn PredBox<F>>),
157 PredictProba(Box<dyn ProbaBox<F>>),
158}
159
160impl<F: Float> FittedStackingClassifier<F> {
161 pub fn estimator_names(&self) -> Vec<&str> {
162 self.fitted_base.iter().map(|(n, _)| n.as_str()).collect()
163 }
164}
165
166impl<F: Float + 'static> Fit<F> for StackingClassifier<F> {
167 type Fitted = FittedStackingClassifier<F>;
168
169 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
170 if self.base_estimators.is_empty() {
171 return Err(RustMlError::InvalidParameter(
172 "StackingClassifier needs at least one base estimator".into(),
173 ));
174 }
175 if x.nrows() != y.len() {
176 return Err(RustMlError::ShapeMismatch(format!(
177 "X has {} rows but y has {} elements",
178 x.nrows(),
179 y.len()
180 )));
181 }
182 let n = x.nrows();
183 if n < 2 {
184 return Err(RustMlError::EmptyInput("need at least 2 samples".into()));
185 }
186
187 let k = self.cv_folds.min(n);
188 let folds = simple_k_fold(n, k);
189
190 let mut meta_cols: Vec<Array1<F>> = Vec::new();
199 for (_name, est) in self.base_estimators.iter() {
200 match est {
201 BaseEstimator::Predict(b) => {
202 let mut col = Array1::<F>::zeros(n);
203 for (train_idx, test_idx) in &folds {
204 let x_train = select_rows(x, train_idx);
205 let y_train = select_elements(y, train_idx);
206 let x_test = select_rows(x, test_idx);
207 let fitted = b.fit_box(&x_train, &y_train)?;
208 let preds = fitted.predict_box(&x_test)?;
209 for (li, &gi) in test_idx.iter().enumerate() {
210 col[gi] = preds[li];
211 }
212 }
213 meta_cols.push(col);
214 }
215 BaseEstimator::PredictProba(b) => {
216 let mut buf: Option<Array2<F>> = None;
219 for (train_idx, test_idx) in &folds {
220 let x_train = select_rows(x, train_idx);
221 let y_train = select_elements(y, train_idx);
222 let x_test = select_rows(x, test_idx);
223 let fitted = b.fit_proba_box(&x_train, &y_train)?;
224 let probs = fitted.predict_proba_box(&x_test)?;
225 let nc = probs.ncols();
226 let bufm = buf.get_or_insert_with(|| Array2::<F>::zeros((n, nc)));
227 for (li, &gi) in test_idx.iter().enumerate() {
228 for c in 0..nc {
229 bufm[[gi, c]] = probs[[li, c]];
230 }
231 }
232 }
233 if let Some(bufm) = buf {
234 for c in 0..bufm.ncols() {
235 meta_cols.push(bufm.column(c).to_owned());
236 }
237 }
238 }
239 }
240 }
241
242 let n_meta = meta_cols.len();
243 let mut meta_features = Array2::<F>::zeros((n, n_meta));
244 for (c, col) in meta_cols.iter().enumerate() {
245 for i in 0..n {
246 meta_features[[i, c]] = col[i];
247 }
248 }
249
250 let fitted_meta = self.meta_estimator.fit_box(&meta_features, y)?;
251
252 let mut fitted_base = Vec::with_capacity(self.base_estimators.len());
253 for (name, est) in &self.base_estimators {
254 let f = match est {
255 BaseEstimator::Predict(b) => FittedBase::Predict(b.fit_box(x, y)?),
256 BaseEstimator::PredictProba(b) => FittedBase::PredictProba(b.fit_proba_box(x, y)?),
257 };
258 fitted_base.push((name.clone(), f));
259 }
260
261 Ok(FittedStackingClassifier {
262 fitted_base,
263 fitted_meta,
264 n_features: x.ncols(),
265 })
266 }
267}
268
269impl<F: Float> Predict<F> for FittedStackingClassifier<F> {
270 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
271 if x.ncols() != self.n_features {
272 return Err(RustMlError::ShapeMismatch(format!(
273 "expected {} features, got {}",
274 self.n_features,
275 x.ncols()
276 )));
277 }
278
279 let n = x.nrows();
280 let mut meta_cols: Vec<Array1<F>> = Vec::new();
281 for (_name, m) in &self.fitted_base {
282 match m {
283 FittedBase::Predict(p) => {
284 meta_cols.push(p.predict_box(x)?);
285 }
286 FittedBase::PredictProba(p) => {
287 let probs = p.predict_proba_box(x)?;
288 for c in 0..probs.ncols() {
289 meta_cols.push(probs.column(c).to_owned());
290 }
291 }
292 }
293 }
294 let mut meta_features = Array2::<F>::zeros((n, meta_cols.len()));
295 for (c, col) in meta_cols.iter().enumerate() {
296 for i in 0..n {
297 meta_features[[i, c]] = col[i];
298 }
299 }
300 self.fitted_meta.predict_box(&meta_features)
301 }
302}
303
/// Splits `0..n` into `k` contiguous, unshuffled test folds.
///
/// Returns one `(train, test)` index pair per fold. The first `n % k` folds get
/// one extra test row. Each training set is every index outside its test fold,
/// in ascending order. When `k > n`, the trailing folds have empty test sets.
///
/// # Panics
///
/// Panics on division by zero when `k == 0`.
fn simple_k_fold(n: usize, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    let (base, extra) = (n / k, n % k);
    let mut cursor = 0;
    (0..k)
        .map(|fold| {
            let len = if fold < extra { base + 1 } else { base };
            let test_range = cursor..cursor + len;
            cursor = test_range.end;
            let train: Vec<usize> = (0..test_range.start).chain(test_range.end..n).collect();
            (train, test_range.collect())
        })
        .collect()
}
318
319fn select_rows<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
320 let ncols = x.ncols();
321 let mut data = Vec::with_capacity(indices.len() * ncols);
322 for &i in indices {
323 for j in 0..ncols {
324 data.push(x[[i, j]]);
325 }
326 }
327 Array2::from_shape_vec((indices.len(), ncols), data).unwrap()
328}
329
330fn select_elements<F: Float>(y: &Array1<F>, indices: &[usize]) -> Array1<F> {
331 Array1::from_vec(indices.iter().map(|&i| y[i]).collect())
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use anofox_ml_trees::DecisionTreeClassifier;
    use ndarray::array;

    /// Two well-separated clusters (near the origin and near (5, 5)) with
    /// alternating labels. Both halves used by 2-fold CV therefore contain both
    /// classes.
    fn two_clusters() -> (Array2<f64>, Array1<f64>) {
        let x = array![
            [0.0, 0.0],
            [5.0, 5.0],
            [0.1, 0.1],
            [5.1, 5.0],
            [0.2, -0.1],
            [4.9, 5.1],
            [-0.1, 0.2],
            [5.2, 4.8],
        ];
        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
        (x, y)
    }

    /// Asserts element-wise equality over the common prefix of `preds` and `truth`.
    fn assert_same_labels(preds: &Array1<f64>, truth: &Array1<f64>) {
        for (p, t) in preds.iter().zip(truth.iter()) {
            assert_eq!(*p, *t, "p={p}, t={t}");
        }
    }

    #[test]
    fn test_stacking_classifier_basic() {
        let (x, y) = two_clusters();
        let shallow = DecisionTreeClassifier {
            max_depth: Some(2),
            ..Default::default()
        };
        let deeper = DecisionTreeClassifier {
            max_depth: Some(3),
            ..Default::default()
        };

        let sc = StackingClassifier::new(DecisionTreeClassifier::default())
            .push("t1", shallow)
            .push("t2", deeper)
            .with_cv_folds(2);

        let fitted: FittedStackingClassifier<f64> = sc.fit(&x, &y).unwrap();
        assert_same_labels(&fitted.predict(&x).unwrap(), &y);
    }

    #[test]
    fn test_stacking_classifier_proba_path() {
        let (x, y) = two_clusters();
        let shallow = DecisionTreeClassifier {
            max_depth: Some(2),
            ..Default::default()
        };
        let deeper = DecisionTreeClassifier {
            max_depth: Some(3),
            ..Default::default()
        };

        let sc = StackingClassifier::new(DecisionTreeClassifier::default())
            .push_proba("t1", shallow)
            .push_proba("t2", deeper)
            .with_cv_folds(2);

        let fitted: FittedStackingClassifier<f64> = sc.fit(&x, &y).unwrap();
        assert_same_labels(&fitted.predict(&x).unwrap(), &y);
    }

    #[test]
    fn test_stacking_classifier_empty_base_error() {
        let (x, y) = (array![[1.0], [2.0]], array![0.0, 1.0]);
        let sc = StackingClassifier::<f64>::new(DecisionTreeClassifier::default());
        assert!(sc.fit(&x, &y).is_err());
    }
}