Skip to main content

anofox_ml_ensemble/
stacking_classifier.rs

//! Stacking classifier: two-level ensemble where base classifiers' outputs
//! are used as features for a meta-classifier.
//!
//! Mirrors `sklearn.ensemble.StackingClassifier` with an explicit,
//! per-estimator `stack_method`: each base classifier contributes either its
//! hard *predictions* (`'predict'`) or its class probabilities
//! (`'predict_proba'`) as inputs to the meta-estimator. Out-of-fold outputs are
//! generated via k-fold CV during fitting to avoid leakage.
8
9use anofox_ml_core::{Fit, Float, Predict, PredictProba, Result, RustMlError};
10use ndarray::{Array1, Array2};
11
/// Choice of base-estimator output used as meta-features.
///
/// - `Predict`: hard class labels (sklearn `stack_method='predict'`).
/// - `PredictProba`: class probabilities (sklearn `stack_method='predict_proba'`).
///
/// sklearn's default is `'auto'`, which prefers `predict_proba` then
/// `decision_function` then `predict`. We require an explicit choice and
/// only support the two forms above.
///
/// The enum is a plain fieldless tag, so it derives `Eq` and `Hash` in
/// addition to `PartialEq`; this lets it be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StackMethod {
    /// Use the base estimator's hard predictions (one meta-feature).
    Predict,
    /// Use the base estimator's class-probability matrix.
    PredictProba,
}
25
26trait FitPredBox<F: Float>: Send + Sync {
27    fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredBox<F>>>;
28}
29
30trait FitProbaBox<F: Float>: Send + Sync {
31    fn fit_proba_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn ProbaBox<F>>>;
32}
33
34trait PredBox<F: Float>: Send + Sync {
35    fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>>;
36}
37
38trait ProbaBox<F: Float>: Send + Sync {
39    fn predict_proba_box(&self, x: &Array2<F>) -> Result<Array2<F>>;
40}
41
42impl<F, T> FitPredBox<F> for T
43where
44    F: Float,
45    T: Fit<F> + Send + Sync,
46    T::Fitted: Predict<F> + Send + Sync + 'static,
47{
48    fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredBox<F>>> {
49        let fitted = Fit::fit(self, x, y)?;
50        Ok(Box::new(fitted))
51    }
52}
53
54impl<F, T> PredBox<F> for T
55where
56    F: Float,
57    T: Predict<F> + Send + Sync,
58{
59    fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>> {
60        self.predict(x)
61    }
62}
63
64/// Wrapper for estimators whose Fitted type implements both Predict and PredictProba.
65struct ProbaWrap<T>(T);
66
67impl<F, T> FitProbaBox<F> for ProbaWrap<T>
68where
69    F: Float,
70    T: Fit<F> + Send + Sync,
71    T::Fitted: Predict<F> + PredictProba<F> + Send + Sync + 'static,
72{
73    fn fit_proba_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn ProbaBox<F>>> {
74        let fitted = Fit::fit(&self.0, x, y)?;
75        Ok(Box::new(fitted))
76    }
77}
78
79impl<F, T> ProbaBox<F> for T
80where
81    F: Float,
82    T: PredictProba<F> + Send + Sync,
83{
84    fn predict_proba_box(&self, x: &Array2<F>) -> Result<Array2<F>> {
85        self.predict_proba(x)
86    }
87}
88
89/// Stacking classifier.
90///
91/// Base estimators contribute either hard predictions (one meta-feature each)
92/// or class probabilities (n_classes - 1 meta-features each, dropping the last
93/// column to avoid colinearity — sklearn's convention).
94pub struct StackingClassifier<F: Float> {
95    base_estimators: Vec<(String, BaseEstimator<F>)>,
96    meta_estimator: Box<dyn FitPredBox<F>>,
97    cv_folds: usize,
98}
99
100enum BaseEstimator<F: Float> {
101    Predict(Box<dyn FitPredBox<F>>),
102    PredictProba(Box<dyn FitProbaBox<F>>),
103}
104
105impl<F: Float> StackingClassifier<F> {
106    pub fn new<M>(meta_estimator: M) -> Self
107    where
108        M: Fit<F> + Send + Sync + 'static,
109        M::Fitted: Predict<F> + Send + Sync + 'static,
110    {
111        Self {
112            base_estimators: Vec::new(),
113            meta_estimator: Box::new(meta_estimator),
114            cv_folds: 5,
115        }
116    }
117
118    /// Add a base estimator using hard predictions (`stack_method='predict'`).
119    pub fn push<T>(mut self, name: impl Into<String>, estimator: T) -> Self
120    where
121        T: Fit<F> + Send + Sync + 'static,
122        T::Fitted: Predict<F> + Send + Sync + 'static,
123    {
124        self.base_estimators
125            .push((name.into(), BaseEstimator::Predict(Box::new(estimator))));
126        self
127    }
128
129    /// Add a base estimator using `predict_proba` outputs (sklearn's
130    /// `stack_method='predict_proba'`).
131    pub fn push_proba<T>(mut self, name: impl Into<String>, estimator: T) -> Self
132    where
133        T: Fit<F> + Send + Sync + 'static,
134        T::Fitted: Predict<F> + PredictProba<F> + Send + Sync + 'static,
135    {
136        self.base_estimators.push((
137            name.into(),
138            BaseEstimator::PredictProba(Box::new(ProbaWrap(estimator))),
139        ));
140        self
141    }
142
143    pub fn with_cv_folds(mut self, k: usize) -> Self {
144        self.cv_folds = k;
145        self
146    }
147}
148
149pub struct FittedStackingClassifier<F: Float> {
150    fitted_base: Vec<(String, FittedBase<F>)>,
151    fitted_meta: Box<dyn PredBox<F>>,
152    n_features: usize,
153}
154
155enum FittedBase<F: Float> {
156    Predict(Box<dyn PredBox<F>>),
157    PredictProba(Box<dyn ProbaBox<F>>),
158}
159
160impl<F: Float> FittedStackingClassifier<F> {
161    pub fn estimator_names(&self) -> Vec<&str> {
162        self.fitted_base.iter().map(|(n, _)| n.as_str()).collect()
163    }
164}
165
166impl<F: Float + 'static> Fit<F> for StackingClassifier<F> {
167    type Fitted = FittedStackingClassifier<F>;
168
169    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
170        if self.base_estimators.is_empty() {
171            return Err(RustMlError::InvalidParameter(
172                "StackingClassifier needs at least one base estimator".into(),
173            ));
174        }
175        if x.nrows() != y.len() {
176            return Err(RustMlError::ShapeMismatch(format!(
177                "X has {} rows but y has {} elements",
178                x.nrows(),
179                y.len()
180            )));
181        }
182        let n = x.nrows();
183        if n < 2 {
184            return Err(RustMlError::EmptyInput("need at least 2 samples".into()));
185        }
186
187        let k = self.cv_folds.min(n);
188        let folds = simple_k_fold(n, k);
189
190        // Two passes through base estimators. Pass 1: out-of-fold predictions
191        // build the meta-feature matrix. We need to know each base's output
192        // width — for Predict it's 1, for PredictProba it's n_classes. We
193        // discover n_classes from the first proba estimator's prediction;
194        // otherwise default to 1.
195
196        // Generate meta-features per estimator first, accumulate into a Vec<Vec<f64>>
197        // (one column per meta-feature, length n).
198        let mut meta_cols: Vec<Array1<F>> = Vec::new();
199        for (_name, est) in self.base_estimators.iter() {
200            match est {
201                BaseEstimator::Predict(b) => {
202                    let mut col = Array1::<F>::zeros(n);
203                    for (train_idx, test_idx) in &folds {
204                        let x_train = select_rows(x, train_idx);
205                        let y_train = select_elements(y, train_idx);
206                        let x_test = select_rows(x, test_idx);
207                        let fitted = b.fit_box(&x_train, &y_train)?;
208                        let preds = fitted.predict_box(&x_test)?;
209                        for (li, &gi) in test_idx.iter().enumerate() {
210                            col[gi] = preds[li];
211                        }
212                    }
213                    meta_cols.push(col);
214                }
215                BaseEstimator::PredictProba(b) => {
216                    // Need to know n_classes; defer column creation until we
217                    // see the first proba output.
218                    let mut buf: Option<Array2<F>> = None;
219                    for (train_idx, test_idx) in &folds {
220                        let x_train = select_rows(x, train_idx);
221                        let y_train = select_elements(y, train_idx);
222                        let x_test = select_rows(x, test_idx);
223                        let fitted = b.fit_proba_box(&x_train, &y_train)?;
224                        let probs = fitted.predict_proba_box(&x_test)?;
225                        let nc = probs.ncols();
226                        let bufm = buf.get_or_insert_with(|| Array2::<F>::zeros((n, nc)));
227                        for (li, &gi) in test_idx.iter().enumerate() {
228                            for c in 0..nc {
229                                bufm[[gi, c]] = probs[[li, c]];
230                            }
231                        }
232                    }
233                    if let Some(bufm) = buf {
234                        for c in 0..bufm.ncols() {
235                            meta_cols.push(bufm.column(c).to_owned());
236                        }
237                    }
238                }
239            }
240        }
241
242        let n_meta = meta_cols.len();
243        let mut meta_features = Array2::<F>::zeros((n, n_meta));
244        for (c, col) in meta_cols.iter().enumerate() {
245            for i in 0..n {
246                meta_features[[i, c]] = col[i];
247            }
248        }
249
250        let fitted_meta = self.meta_estimator.fit_box(&meta_features, y)?;
251
252        let mut fitted_base = Vec::with_capacity(self.base_estimators.len());
253        for (name, est) in &self.base_estimators {
254            let f = match est {
255                BaseEstimator::Predict(b) => FittedBase::Predict(b.fit_box(x, y)?),
256                BaseEstimator::PredictProba(b) => FittedBase::PredictProba(b.fit_proba_box(x, y)?),
257            };
258            fitted_base.push((name.clone(), f));
259        }
260
261        Ok(FittedStackingClassifier {
262            fitted_base,
263            fitted_meta,
264            n_features: x.ncols(),
265        })
266    }
267}
268
269impl<F: Float> Predict<F> for FittedStackingClassifier<F> {
270    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
271        if x.ncols() != self.n_features {
272            return Err(RustMlError::ShapeMismatch(format!(
273                "expected {} features, got {}",
274                self.n_features,
275                x.ncols()
276            )));
277        }
278
279        let n = x.nrows();
280        let mut meta_cols: Vec<Array1<F>> = Vec::new();
281        for (_name, m) in &self.fitted_base {
282            match m {
283                FittedBase::Predict(p) => {
284                    meta_cols.push(p.predict_box(x)?);
285                }
286                FittedBase::PredictProba(p) => {
287                    let probs = p.predict_proba_box(x)?;
288                    for c in 0..probs.ncols() {
289                        meta_cols.push(probs.column(c).to_owned());
290                    }
291                }
292            }
293        }
294        let mut meta_features = Array2::<F>::zeros((n, meta_cols.len()));
295        for (c, col) in meta_cols.iter().enumerate() {
296            for i in 0..n {
297                meta_features[[i, c]] = col[i];
298            }
299        }
300        self.fitted_meta.predict_box(&meta_features)
301    }
302}
303
/// Splits the indices `0..n` into `k` contiguous, unshuffled folds.
///
/// Returns one `(train, test)` pair per fold, where `test` is the fold's own
/// index range and `train` is every other index in ascending order. The first
/// `n % k` folds receive one extra element so all indices are covered.
///
/// `k == 0` yields no folds (instead of dividing by zero). If `k > n`, the
/// trailing folds have an empty `test` set.
fn simple_k_fold(n: usize, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    if k == 0 {
        return Vec::new();
    }
    let fold_size = n / k;
    let remainder = n % k;
    let mut folds = Vec::with_capacity(k);
    let mut start = 0;
    for f in 0..k {
        // Spread the remainder over the leading folds.
        let end = start + fold_size + usize::from(f < remainder);
        let test: Vec<usize> = (start..end).collect();
        let train: Vec<usize> = (0..start).chain(end..n).collect();
        folds.push((train, test));
        start = end;
    }
    folds
}
318
319fn select_rows<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
320    let ncols = x.ncols();
321    let mut data = Vec::with_capacity(indices.len() * ncols);
322    for &i in indices {
323        for j in 0..ncols {
324            data.push(x[[i, j]]);
325        }
326    }
327    Array2::from_shape_vec((indices.len(), ncols), data).unwrap()
328}
329
330fn select_elements<F: Float>(y: &Array1<F>, indices: &[usize]) -> Array1<F> {
331    Array1::from_vec(indices.iter().map(|&i| y[i]).collect())
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use anofox_ml_trees::DecisionTreeClassifier;
    use ndarray::array;

    /// Two well-separated clusters, interleaved so simple k-fold sees both
    /// classes in each fold.
    fn interleaved_clusters() -> (Array2<f64>, Array1<f64>) {
        let x = array![
            [0.0, 0.0],
            [5.0, 5.0],
            [0.1, 0.1],
            [5.1, 5.0],
            [0.2, -0.1],
            [4.9, 5.1],
            [-0.1, 0.2],
            [5.2, 4.8],
        ];
        let y = array![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
        (x, y)
    }

    /// Asserts that every prediction equals the matching label.
    fn assert_matches_labels(preds: &Array1<f64>, y: &Array1<f64>) {
        preds
            .iter()
            .zip(y.iter())
            .for_each(|(p, t)| assert_eq!(*p, *t, "p={p}, t={t}"));
    }

    #[test]
    fn test_stacking_classifier_basic() {
        let (x, y) = interleaved_clusters();

        // Stack two DT classifiers via hard predictions into a DT meta.
        let shallow = DecisionTreeClassifier {
            max_depth: Some(2),
            ..Default::default()
        };
        let deeper = DecisionTreeClassifier {
            max_depth: Some(3),
            ..Default::default()
        };
        let sc = StackingClassifier::new(DecisionTreeClassifier::default())
            .push("t1", shallow)
            .push("t2", deeper)
            .with_cv_folds(2);

        let fitted: FittedStackingClassifier<f64> = sc.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_matches_labels(&preds, &y);
    }

    #[test]
    fn test_stacking_classifier_proba_path() {
        let (x, y) = interleaved_clusters();

        // Stack two DT classifiers via predict_proba into a DT meta.
        let shallow = DecisionTreeClassifier {
            max_depth: Some(2),
            ..Default::default()
        };
        let deeper = DecisionTreeClassifier {
            max_depth: Some(3),
            ..Default::default()
        };
        let sc = StackingClassifier::new(DecisionTreeClassifier::default())
            .push_proba("t1", shallow)
            .push_proba("t2", deeper)
            .with_cv_folds(2);

        let fitted: FittedStackingClassifier<f64> = sc.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_matches_labels(&preds, &y);
    }

    #[test]
    fn test_stacking_classifier_empty_base_error() {
        let x = array![[1.0], [2.0]];
        let y = array![0.0, 1.0];

        // No base estimator registered: fitting must be rejected.
        let sc = StackingClassifier::<f64>::new(DecisionTreeClassifier::default());
        let outcome = sc.fit(&x, &y);
        assert!(outcome.is_err());
    }
}
427}