1use anofox_ml_core::{Fit, Float, Predict, Result, RustMlError};
5use ndarray::{Array1, Array2};
6
7trait FitPredBox<F: Float>: Send + Sync {
9 fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredBox<F>>>;
10}
11
12trait PredBox<F: Float>: Send + Sync {
13 fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>>;
14}
15
16impl<F, T> FitPredBox<F> for T
17where
18 F: Float,
19 T: Fit<F> + Send + Sync,
20 T::Fitted: Predict<F> + Send + Sync + 'static,
21{
22 fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredBox<F>>> {
23 let fitted = Fit::fit(self, x, y)?;
24 Ok(Box::new(fitted))
25 }
26}
27
28impl<F, T> PredBox<F> for T
29where
30 F: Float,
31 T: Predict<F> + Send + Sync,
32{
33 fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>> {
34 self.predict(x)
35 }
36}
37
38pub struct StackingRegressor<F: Float> {
44 base_estimators: Vec<(String, Box<dyn FitPredBox<F>>)>,
45 meta_estimator: Box<dyn FitPredBox<F>>,
46 cv_folds: usize,
47}
48
49impl<F: Float> StackingRegressor<F> {
50 pub fn new<M>(meta_estimator: M) -> Self
52 where
53 M: Fit<F> + Send + Sync + 'static,
54 M::Fitted: Predict<F> + Send + Sync + 'static,
55 {
56 Self {
57 base_estimators: Vec::new(),
58 meta_estimator: Box::new(meta_estimator),
59 cv_folds: 5,
60 }
61 }
62
63 pub fn push<T>(mut self, name: impl Into<String>, estimator: T) -> Self
65 where
66 T: Fit<F> + Send + Sync + 'static,
67 T::Fitted: Predict<F> + Send + Sync + 'static,
68 {
69 self.base_estimators
70 .push((name.into(), Box::new(estimator)));
71 self
72 }
73
74 pub fn with_cv_folds(mut self, cv_folds: usize) -> Self {
76 self.cv_folds = cv_folds;
77 self
78 }
79}
80
81pub struct FittedStackingRegressor<F: Float> {
83 fitted_base: Vec<(String, Box<dyn PredBox<F>>)>,
84 fitted_meta: Box<dyn PredBox<F>>,
85 n_features: usize,
86}
87
88impl<F: Float> FittedStackingRegressor<F> {
89 pub fn estimator_names(&self) -> Vec<&str> {
90 self.fitted_base.iter().map(|(n, _)| n.as_str()).collect()
91 }
92}
93
94impl<F: Float + 'static> Fit<F> for StackingRegressor<F> {
95 type Fitted = FittedStackingRegressor<F>;
96
97 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Self::Fitted> {
98 if self.base_estimators.is_empty() {
99 return Err(RustMlError::InvalidParameter(
100 "StackingRegressor needs at least one base estimator".into(),
101 ));
102 }
103 if x.nrows() != y.len() {
104 return Err(RustMlError::ShapeMismatch(format!(
105 "X has {} rows but y has {} elements",
106 x.nrows(),
107 y.len()
108 )));
109 }
110 let n = x.nrows();
111 if n < 2 {
112 return Err(RustMlError::EmptyInput("need at least 2 samples".into()));
113 }
114
115 let n_base = self.base_estimators.len();
116 let k = self.cv_folds.min(n);
117
118 let folds = simple_k_fold(n, k);
120 let mut meta_features = Array2::zeros((n, n_base));
121
122 for (bi, (_, est)) in self.base_estimators.iter().enumerate() {
123 for (train_idx, test_idx) in &folds {
124 let x_train = select_rows(x, train_idx);
125 let y_train = select_elements(y, train_idx);
126 let x_test = select_rows(x, test_idx);
127
128 let fitted = est.fit_box(&x_train, &y_train)?;
129 let preds = fitted.predict_box(&x_test)?;
130
131 for (li, &gi) in test_idx.iter().enumerate() {
132 meta_features[[gi, bi]] = preds[li];
133 }
134 }
135 }
136
137 let fitted_meta = self.meta_estimator.fit_box(&meta_features, y)?;
139
140 let mut fitted_base = Vec::with_capacity(n_base);
142 for (name, est) in &self.base_estimators {
143 let fitted = est.fit_box(x, y)?;
144 fitted_base.push((name.clone(), fitted));
145 }
146
147 Ok(FittedStackingRegressor {
148 fitted_base,
149 fitted_meta,
150 n_features: x.ncols(),
151 })
152 }
153}
154
155impl<F: Float> Predict<F> for FittedStackingRegressor<F> {
156 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>> {
157 if x.ncols() != self.n_features {
158 return Err(RustMlError::ShapeMismatch(format!(
159 "expected {} features, got {}",
160 self.n_features,
161 x.ncols()
162 )));
163 }
164
165 let n = x.nrows();
166 let n_base = self.fitted_base.len();
167 let mut meta_features = Array2::zeros((n, n_base));
168
169 for (bi, (_, model)) in self.fitted_base.iter().enumerate() {
170 let preds = model.predict_box(x)?;
171 for i in 0..n {
172 meta_features[[i, bi]] = preds[i];
173 }
174 }
175
176 self.fitted_meta.predict_box(&meta_features)
177 }
178}
179
/// Splits the indices `0..n` into `k` contiguous, unshuffled folds.
///
/// Each entry of the result is `(train, test)`: `test` is the fold's own
/// index range and `train` is every other index in ascending order. The
/// first `n % k` folds hold one more test index than the rest; when `k > n`
/// the trailing folds have an empty `test`.
///
/// # Panics
///
/// Panics (division by zero) if `k` is 0.
fn simple_k_fold(n: usize, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    let base = n / k;
    let extra = n % k;
    // Start of the fold currently being built; advanced as folds are emitted.
    let mut lo = 0;

    (0..k)
        .map(|fold| {
            let hi = lo + base + usize::from(fold < extra);
            let held_out: Vec<usize> = (lo..hi).collect();
            let kept: Vec<usize> = (0..lo).chain(hi..n).collect();
            lo = hi;
            (kept, held_out)
        })
        .collect()
}
196
197fn select_rows<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
198 let ncols = x.ncols();
199 let mut data = Vec::with_capacity(indices.len() * ncols);
200 for &i in indices {
201 for j in 0..ncols {
202 data.push(x[[i, j]]);
203 }
204 }
205 Array2::from_shape_vec((indices.len(), ncols), data).unwrap()
206}
207
208fn select_elements<F: Float>(y: &Array1<F>, indices: &[usize]) -> Array1<F> {
209 Array1::from_vec(indices.iter().map(|&i| y[i]).collect())
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use anofox_ml_trees::DecisionTreeRegressor;
    use ndarray::array;

    /// Two depth-limited trees stacked under a tree meta model yield one
    /// finite prediction per input row.
    #[test]
    fn test_stacking_regressor_basic() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0];

        let shallow = DecisionTreeRegressor {
            max_depth: Some(2),
            ..Default::default()
        };
        let deeper = DecisionTreeRegressor {
            max_depth: Some(3),
            ..Default::default()
        };

        let ensemble = StackingRegressor::new(DecisionTreeRegressor::default())
            .push("t1", shallow)
            .push("t2", deeper)
            .with_cv_folds(2);

        let model: FittedStackingRegressor<f64> = ensemble.fit(&x, &y).unwrap();
        let preds = model.predict(&x).unwrap();

        assert_eq!(preds.len(), 8);
        assert!(preds.iter().all(|&p| p.is_finite()));
    }

    /// Base-model names are reported in insertion order.
    #[test]
    fn test_stacking_regressor_names() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let ensemble = StackingRegressor::new(DecisionTreeRegressor::default())
            .push("a", DecisionTreeRegressor::default())
            .push("b", DecisionTreeRegressor::default())
            .with_cv_folds(2);

        let model: FittedStackingRegressor<f64> = ensemble.fit(&x, &y).unwrap();
        let names = model.estimator_names();
        assert_eq!(names, vec!["a", "b"]);
    }

    /// Fitting without any base estimator is rejected.
    #[test]
    fn test_stacking_regressor_empty_base_error() {
        let x = array![[1.0], [2.0]];
        let y = array![1.0, 2.0];

        let ensemble = StackingRegressor::<f64>::new(DecisionTreeRegressor::default());
        let outcome = ensemble.fit(&x, &y);
        assert!(outcome.is_err());
    }

    /// Predicting with a different column count than at fit time is rejected.
    #[test]
    fn test_stacking_regressor_predict_shape_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let ensemble = StackingRegressor::new(DecisionTreeRegressor::default())
            .push("t1", DecisionTreeRegressor::default())
            .with_cv_folds(2);

        let model: FittedStackingRegressor<f64> = ensemble.fit(&x, &y).unwrap();

        let one_column = array![[1.0]];
        let outcome = model.predict(&one_column);
        assert!(outcome.is_err());
    }
}