Skip to main content

anofox_ml_core/
multi_output.rs

1//! Multi-output meta-estimators: per-output regressor / classifier wrappers.
2//!
3//! Mirrors `sklearn.multioutput.MultiOutputRegressor` (a separate estimator is
4//! fitted per output column). Our existing `Fit` / `Predict` traits assume 1-D
5//! `y`; this wrapper provides a 2-D entry point on top.
6
7use ndarray::{Array1, Array2, Axis};
8
9use crate::error::{Result, RustMlError};
10use crate::float::Float;
11use crate::traits::{Fit, Predict};
12
13/// Internal trait used to abstract over the inner-estimator type without a
14/// blanket impl that conflicts with downstream crates.
15trait MultiFitTemplate<F: Float>: Send + Sync {
16    fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredBox<F>>>;
17}
18
19trait PredBox<F: Float>: Send + Sync {
20    fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>>;
21}
22
23struct Template<T>(T);
24
25impl<F, T> MultiFitTemplate<F> for Template<T>
26where
27    F: Float,
28    T: Fit<F> + Send + Sync + Clone,
29    T::Fitted: Predict<F> + Send + Sync + 'static,
30{
31    fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredBox<F>>> {
32        let est = self.0.clone();
33        let fitted = Fit::fit(&est, x, y)?;
34        Ok(Box::new(PredHolder(fitted)))
35    }
36}
37
38struct PredHolder<P>(P);
39impl<F, P> PredBox<F> for PredHolder<P>
40where
41    F: Float,
42    P: Predict<F> + Send + Sync,
43{
44    fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>> {
45        self.0.predict(x)
46    }
47}
48
49/// Fits one independent estimator per target column.
50pub struct MultiOutputRegressor<F: Float> {
51    template: Box<dyn MultiFitTemplate<F>>,
52}
53
54impl<F: Float> MultiOutputRegressor<F> {
55    pub fn new<T>(estimator: T) -> Self
56    where
57        T: Fit<F> + Send + Sync + Clone + 'static,
58        T::Fitted: Predict<F> + Send + Sync + 'static,
59    {
60        Self {
61            template: Box::new(Template(estimator)),
62        }
63    }
64
65    pub fn fit_2d(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedMultiOutputRegressor<F>> {
66        if x.nrows() != y.nrows() {
67            return Err(RustMlError::ShapeMismatch(format!(
68                "X has {} rows but y has {}",
69                x.nrows(),
70                y.nrows()
71            )));
72        }
73        if y.is_empty() {
74            return Err(RustMlError::EmptyInput("y is empty".into()));
75        }
76        let n_outputs = y.ncols();
77        let mut fitted = Vec::with_capacity(n_outputs);
78        for k in 0..n_outputs {
79            let yk = y.index_axis(Axis(1), k).to_owned();
80            let m = self.template.fit_box(x, &yk)?;
81            fitted.push(m);
82        }
83        Ok(FittedMultiOutputRegressor {
84            models: fitted,
85            n_features: x.ncols(),
86        })
87    }
88}
89
90pub struct FittedMultiOutputRegressor<F: Float> {
91    models: Vec<Box<dyn PredBox<F>>>,
92    n_features: usize,
93}
94
95impl<F: Float> FittedMultiOutputRegressor<F> {
96    pub fn predict_2d(&self, x: &Array2<F>) -> Result<Array2<F>> {
97        if x.ncols() != self.n_features {
98            return Err(RustMlError::ShapeMismatch(format!(
99                "expected {} features, got {}",
100                self.n_features,
101                x.ncols()
102            )));
103        }
104        let n = x.nrows();
105        let mut out = Array2::<F>::zeros((n, self.models.len()));
106        for (k, m) in self.models.iter().enumerate() {
107            let yk = m.predict_box(x)?;
108            for i in 0..n {
109                out[[i, k]] = yk[i];
110            }
111        }
112        Ok(out)
113    }
114
115    pub fn n_features(&self) -> usize {
116        self.n_features
117    }
118
119    pub fn n_outputs(&self) -> usize {
120        self.models.len()
121    }
122}
123
124// ---------------------------------------------------------------------------
// MultiOutputClassifier — same one-per-output pattern as MultiOutputRegressor.
// Just a type alias under a clearer name since the implementation is identical.
127// ---------------------------------------------------------------------------
128
/// Multi-output classifier. Fits one independent classifier per output column;
/// each column of `y` is the class label for that output dimension. Identical
/// implementation to `MultiOutputRegressor` — the distinction is purely about
/// downstream meaning (sklearn ships them as separate classes for the same
/// reason).
///
/// This is a type alias, not a distinct type: the two names are
/// interchangeable to the compiler. Because `fit_2d` takes `y` as
/// `Array2<F>`, class labels must be encoded as values of `F`.
pub type MultiOutputClassifier<F> = MultiOutputRegressor<F>;
/// Fitted form of [`MultiOutputClassifier`]; a type alias for
/// [`FittedMultiOutputRegressor`].
pub type FittedMultiOutputClassifier<F> = FittedMultiOutputRegressor<F>;
136
137// ---------------------------------------------------------------------------
138// RegressorChain — chain feeds previous predictions as features.
139// ---------------------------------------------------------------------------
140
141/// Chain of regressors where each step's prediction becomes a feature for the
142/// next. Mirrors `sklearn.multioutput.RegressorChain`. With `order` = `[2, 0, 1]`,
143/// the regressor for output 2 sees only the original X, the regressor for
144/// output 0 sees X + prediction-of-2, and so on.
145pub struct RegressorChain<F: Float> {
146    template: Box<dyn MultiFitTemplate<F>>,
147    order: Option<Vec<usize>>,
148}
149
150impl<F: Float> RegressorChain<F> {
151    pub fn new<T>(estimator: T) -> Self
152    where
153        T: Fit<F> + Send + Sync + Clone + 'static,
154        T::Fitted: Predict<F> + Send + Sync + 'static,
155    {
156        Self {
157            template: Box::new(Template(estimator)),
158            order: None,
159        }
160    }
161
162    /// Set the chain order. Default is `0..n_outputs`.
163    pub fn with_order(mut self, order: Vec<usize>) -> Self {
164        self.order = Some(order);
165        self
166    }
167
168    pub fn fit_2d(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedRegressorChain<F>> {
169        if x.nrows() != y.nrows() {
170            return Err(RustMlError::ShapeMismatch(format!(
171                "X has {} rows but y has {}",
172                x.nrows(),
173                y.nrows()
174            )));
175        }
176        if y.is_empty() {
177            return Err(RustMlError::EmptyInput("y is empty".into()));
178        }
179        let n = x.nrows();
180        let d = x.ncols();
181        let n_out = y.ncols();
182        let order = self.order.clone().unwrap_or_else(|| (0..n_out).collect());
183        if order.len() != n_out {
184            return Err(RustMlError::InvalidParameter(format!(
185                "order length {} != n_outputs {}",
186                order.len(),
187                n_out
188            )));
189        }
190        let mut models: Vec<Box<dyn PredBox<F>>> = Vec::with_capacity(n_out);
191        // Build per-step features: original X plus all already-predicted columns.
192        let mut x_ext = Array2::<F>::zeros((n, d + n_out));
193        // Copy original X.
194        for i in 0..n {
195            for j in 0..d {
196                x_ext[[i, j]] = x[[i, j]];
197            }
198        }
199        for (step, &out_idx) in order.iter().enumerate() {
200            // Build the feature view containing original + first `step` predicted columns.
201            let cur_cols = d + step;
202            let xs = sub_x(&x_ext, n, cur_cols);
203            let yk = y.index_axis(Axis(1), out_idx).to_owned();
204            let m = self.template.fit_box(&xs, &yk)?;
205            // For subsequent steps we feed the *predicted* values of yk
206            // (sklearn does this — at fit time it's the true value to avoid
207            // exposure bias; we follow sklearn and use the true y at fit).
208            for i in 0..n {
209                x_ext[[i, d + step]] = y[[i, out_idx]];
210            }
211            models.push(m);
212        }
213        Ok(FittedRegressorChain {
214            models,
215            order,
216            n_features: d,
217            n_outputs: n_out,
218        })
219    }
220}
221
222pub struct FittedRegressorChain<F: Float> {
223    models: Vec<Box<dyn PredBox<F>>>,
224    order: Vec<usize>,
225    n_features: usize,
226    n_outputs: usize,
227}
228
229impl<F: Float> FittedRegressorChain<F> {
230    pub fn predict_2d(&self, x: &Array2<F>) -> Result<Array2<F>> {
231        if x.ncols() != self.n_features {
232            return Err(RustMlError::ShapeMismatch(format!(
233                "expected {} features, got {}",
234                self.n_features,
235                x.ncols()
236            )));
237        }
238        let n = x.nrows();
239        let d = self.n_features;
240        let mut x_ext = Array2::<F>::zeros((n, d + self.n_outputs));
241        for i in 0..n {
242            for j in 0..d {
243                x_ext[[i, j]] = x[[i, j]];
244            }
245        }
246        let mut out = Array2::<F>::zeros((n, self.n_outputs));
247        for (step, &out_idx) in self.order.iter().enumerate() {
248            let xs = sub_x(&x_ext, n, d + step);
249            let pred = self.models[step].predict_box(&xs)?;
250            for i in 0..n {
251                out[[i, out_idx]] = pred[i];
252                x_ext[[i, d + step]] = pred[i];
253            }
254        }
255        Ok(out)
256    }
257}
258
259fn sub_x<F: Float>(x_ext: &Array2<F>, n: usize, cols: usize) -> Array2<F> {
260    let mut out = Array2::<F>::zeros((n, cols));
261    for i in 0..n {
262        for j in 0..cols {
263            out[[i, j]] = x_ext[[i, j]];
264        }
265    }
266    out
267}
268
/// `ClassifierChain` is the same as `RegressorChain` for our purposes
/// (per-step prediction is a class label, fed as a feature to the next step).
///
/// This is a type alias, not a distinct type. Labels travel through the chain
/// as values of `F`, because both the targets and the extended feature matrix
/// are `Array2<F>`.
pub type ClassifierChain<F> = RegressorChain<F>;
/// Fitted form of [`ClassifierChain`]; a type alias for
/// [`FittedRegressorChain`].
pub type FittedClassifierChain<F> = FittedRegressorChain<F>;
273
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Trivial base estimator: ignores `x` and always predicts the mean of
    /// the `y` it was fitted on, which makes expected outputs easy to state.
    #[derive(Clone)]
    struct MeanReg;
    struct FittedMeanReg(f64);

    impl Fit<f64> for MeanReg {
        type Fitted = FittedMeanReg;
        fn fit(&self, _x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
            let m = y.iter().sum::<f64>() / y.len() as f64;
            Ok(FittedMeanReg(m))
        }
    }
    impl Predict<f64> for FittedMeanReg {
        fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
            Ok(Array1::from_elem(x.nrows(), self.0))
        }
    }

    /// Four samples, one feature, two targets with column means 4.0 and 25.0.
    fn toy_data() -> (Array2<f64>, Array2<f64>) {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![[1.0, 10.0], [3.0, 20.0], [5.0, 30.0], [7.0, 40.0]];
        (x, y)
    }

    /// Asserts that `p` is 4x2 with every row equal to `[4.0, 25.0]`.
    fn assert_column_means(p: &Array2<f64>) {
        assert_eq!(p.shape(), &[4, 2]);
        for i in 0..4 {
            assert!((p[[i, 0]] - 4.0).abs() < 1e-9);
            assert!((p[[i, 1]] - 25.0).abs() < 1e-9);
        }
    }

    #[test]
    fn test_multi_output_predicts_per_column_mean() {
        let (x, y) = toy_data();

        let model = MultiOutputRegressor::<f64>::new(MeanReg);
        let fitted = model.fit_2d(&x, &y).unwrap();
        assert_eq!(fitted.n_features(), 1);
        assert_eq!(fitted.n_outputs(), 2);
        assert_column_means(&fitted.predict_2d(&x).unwrap());
    }

    #[test]
    fn test_regressor_chain_predicts_2d() {
        // With MeanReg as the base, each step's prediction is the (constant)
        // mean of its target column regardless of the extra chained features,
        // so the exact expected values are known.
        let (x, y) = toy_data();

        let chain = RegressorChain::<f64>::new(MeanReg);
        let fitted = chain.fit_2d(&x, &y).unwrap();
        assert_column_means(&fitted.predict_2d(&x).unwrap());
    }

    #[test]
    fn test_regressor_chain_custom_order_keeps_column_positions() {
        // Reversing the chain order must not swap the output columns: each
        // prediction has to land in the column of the target it was fitted on.
        let (x, y) = toy_data();

        let chain = RegressorChain::<f64>::new(MeanReg).with_order(vec![1, 0]);
        let fitted = chain.fit_2d(&x, &y).unwrap();
        assert_column_means(&fitted.predict_2d(&x).unwrap());
    }

    #[test]
    fn test_fit_rejects_row_count_mismatch() {
        let (_, y) = toy_data();
        let x_short = array![[1.0], [2.0], [3.0]];

        let multi = MultiOutputRegressor::<f64>::new(MeanReg);
        assert!(matches!(
            multi.fit_2d(&x_short, &y),
            Err(RustMlError::ShapeMismatch(_))
        ));

        let chain = RegressorChain::<f64>::new(MeanReg);
        assert!(matches!(
            chain.fit_2d(&x_short, &y),
            Err(RustMlError::ShapeMismatch(_))
        ));
    }

    #[test]
    fn test_fit_rejects_empty_targets() {
        let x = Array2::<f64>::zeros((0, 1));
        let y = Array2::<f64>::zeros((0, 2));

        let multi = MultiOutputRegressor::<f64>::new(MeanReg);
        assert!(matches!(
            multi.fit_2d(&x, &y),
            Err(RustMlError::EmptyInput(_))
        ));

        let chain = RegressorChain::<f64>::new(MeanReg);
        assert!(matches!(
            chain.fit_2d(&x, &y),
            Err(RustMlError::EmptyInput(_))
        ));
    }

    #[test]
    fn test_regressor_chain_rejects_wrong_order_length() {
        let (x, y) = toy_data();

        let chain = RegressorChain::<f64>::new(MeanReg).with_order(vec![0]);
        assert!(matches!(
            chain.fit_2d(&x, &y),
            Err(RustMlError::InvalidParameter(_))
        ));
    }

    #[test]
    fn test_predict_rejects_wrong_feature_count() {
        let (x, y) = toy_data();
        let x_wide = array![[1.0, 2.0], [3.0, 4.0]];

        let multi = MultiOutputRegressor::<f64>::new(MeanReg)
            .fit_2d(&x, &y)
            .unwrap();
        assert!(matches!(
            multi.predict_2d(&x_wide),
            Err(RustMlError::ShapeMismatch(_))
        ));

        let chain = RegressorChain::<f64>::new(MeanReg)
            .fit_2d(&x, &y)
            .unwrap();
        assert!(matches!(
            chain.predict_2d(&x_wide),
            Err(RustMlError::ShapeMismatch(_))
        ));
    }
}
329}