1use ndarray::{Array1, Array2, Axis};
8
9use crate::error::{Result, RustMlError};
10use crate::float::Float;
11use crate::traits::{Fit, Predict};
12
13trait MultiFitTemplate<F: Float>: Send + Sync {
16 fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredBox<F>>>;
17}
18
19trait PredBox<F: Float>: Send + Sync {
20 fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>>;
21}
22
23struct Template<T>(T);
24
25impl<F, T> MultiFitTemplate<F> for Template<T>
26where
27 F: Float,
28 T: Fit<F> + Send + Sync + Clone,
29 T::Fitted: Predict<F> + Send + Sync + 'static,
30{
31 fn fit_box(&self, x: &Array2<F>, y: &Array1<F>) -> Result<Box<dyn PredBox<F>>> {
32 let est = self.0.clone();
33 let fitted = Fit::fit(&est, x, y)?;
34 Ok(Box::new(PredHolder(fitted)))
35 }
36}
37
38struct PredHolder<P>(P);
39impl<F, P> PredBox<F> for PredHolder<P>
40where
41 F: Float,
42 P: Predict<F> + Send + Sync,
43{
44 fn predict_box(&self, x: &Array2<F>) -> Result<Array1<F>> {
45 self.0.predict(x)
46 }
47}
48
49pub struct MultiOutputRegressor<F: Float> {
51 template: Box<dyn MultiFitTemplate<F>>,
52}
53
54impl<F: Float> MultiOutputRegressor<F> {
55 pub fn new<T>(estimator: T) -> Self
56 where
57 T: Fit<F> + Send + Sync + Clone + 'static,
58 T::Fitted: Predict<F> + Send + Sync + 'static,
59 {
60 Self {
61 template: Box::new(Template(estimator)),
62 }
63 }
64
65 pub fn fit_2d(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedMultiOutputRegressor<F>> {
66 if x.nrows() != y.nrows() {
67 return Err(RustMlError::ShapeMismatch(format!(
68 "X has {} rows but y has {}",
69 x.nrows(),
70 y.nrows()
71 )));
72 }
73 if y.is_empty() {
74 return Err(RustMlError::EmptyInput("y is empty".into()));
75 }
76 let n_outputs = y.ncols();
77 let mut fitted = Vec::with_capacity(n_outputs);
78 for k in 0..n_outputs {
79 let yk = y.index_axis(Axis(1), k).to_owned();
80 let m = self.template.fit_box(x, &yk)?;
81 fitted.push(m);
82 }
83 Ok(FittedMultiOutputRegressor {
84 models: fitted,
85 n_features: x.ncols(),
86 })
87 }
88}
89
90pub struct FittedMultiOutputRegressor<F: Float> {
91 models: Vec<Box<dyn PredBox<F>>>,
92 n_features: usize,
93}
94
95impl<F: Float> FittedMultiOutputRegressor<F> {
96 pub fn predict_2d(&self, x: &Array2<F>) -> Result<Array2<F>> {
97 if x.ncols() != self.n_features {
98 return Err(RustMlError::ShapeMismatch(format!(
99 "expected {} features, got {}",
100 self.n_features,
101 x.ncols()
102 )));
103 }
104 let n = x.nrows();
105 let mut out = Array2::<F>::zeros((n, self.models.len()));
106 for (k, m) in self.models.iter().enumerate() {
107 let yk = m.predict_box(x)?;
108 for i in 0..n {
109 out[[i, k]] = yk[i];
110 }
111 }
112 Ok(out)
113 }
114
115 pub fn n_features(&self) -> usize {
116 self.n_features
117 }
118
119 pub fn n_outputs(&self) -> usize {
120 self.models.len()
121 }
122}
123
124pub type MultiOutputClassifier<F> = MultiOutputRegressor<F>;
135pub type FittedMultiOutputClassifier<F> = FittedMultiOutputRegressor<F>;
136
137pub struct RegressorChain<F: Float> {
146 template: Box<dyn MultiFitTemplate<F>>,
147 order: Option<Vec<usize>>,
148}
149
150impl<F: Float> RegressorChain<F> {
151 pub fn new<T>(estimator: T) -> Self
152 where
153 T: Fit<F> + Send + Sync + Clone + 'static,
154 T::Fitted: Predict<F> + Send + Sync + 'static,
155 {
156 Self {
157 template: Box::new(Template(estimator)),
158 order: None,
159 }
160 }
161
162 pub fn with_order(mut self, order: Vec<usize>) -> Self {
164 self.order = Some(order);
165 self
166 }
167
168 pub fn fit_2d(&self, x: &Array2<F>, y: &Array2<F>) -> Result<FittedRegressorChain<F>> {
169 if x.nrows() != y.nrows() {
170 return Err(RustMlError::ShapeMismatch(format!(
171 "X has {} rows but y has {}",
172 x.nrows(),
173 y.nrows()
174 )));
175 }
176 if y.is_empty() {
177 return Err(RustMlError::EmptyInput("y is empty".into()));
178 }
179 let n = x.nrows();
180 let d = x.ncols();
181 let n_out = y.ncols();
182 let order = self.order.clone().unwrap_or_else(|| (0..n_out).collect());
183 if order.len() != n_out {
184 return Err(RustMlError::InvalidParameter(format!(
185 "order length {} != n_outputs {}",
186 order.len(),
187 n_out
188 )));
189 }
190 let mut models: Vec<Box<dyn PredBox<F>>> = Vec::with_capacity(n_out);
191 let mut x_ext = Array2::<F>::zeros((n, d + n_out));
193 for i in 0..n {
195 for j in 0..d {
196 x_ext[[i, j]] = x[[i, j]];
197 }
198 }
199 for (step, &out_idx) in order.iter().enumerate() {
200 let cur_cols = d + step;
202 let xs = sub_x(&x_ext, n, cur_cols);
203 let yk = y.index_axis(Axis(1), out_idx).to_owned();
204 let m = self.template.fit_box(&xs, &yk)?;
205 for i in 0..n {
209 x_ext[[i, d + step]] = y[[i, out_idx]];
210 }
211 models.push(m);
212 }
213 Ok(FittedRegressorChain {
214 models,
215 order,
216 n_features: d,
217 n_outputs: n_out,
218 })
219 }
220}
221
222pub struct FittedRegressorChain<F: Float> {
223 models: Vec<Box<dyn PredBox<F>>>,
224 order: Vec<usize>,
225 n_features: usize,
226 n_outputs: usize,
227}
228
229impl<F: Float> FittedRegressorChain<F> {
230 pub fn predict_2d(&self, x: &Array2<F>) -> Result<Array2<F>> {
231 if x.ncols() != self.n_features {
232 return Err(RustMlError::ShapeMismatch(format!(
233 "expected {} features, got {}",
234 self.n_features,
235 x.ncols()
236 )));
237 }
238 let n = x.nrows();
239 let d = self.n_features;
240 let mut x_ext = Array2::<F>::zeros((n, d + self.n_outputs));
241 for i in 0..n {
242 for j in 0..d {
243 x_ext[[i, j]] = x[[i, j]];
244 }
245 }
246 let mut out = Array2::<F>::zeros((n, self.n_outputs));
247 for (step, &out_idx) in self.order.iter().enumerate() {
248 let xs = sub_x(&x_ext, n, d + step);
249 let pred = self.models[step].predict_box(&xs)?;
250 for i in 0..n {
251 out[[i, out_idx]] = pred[i];
252 x_ext[[i, d + step]] = pred[i];
253 }
254 }
255 Ok(out)
256 }
257}
258
259fn sub_x<F: Float>(x_ext: &Array2<F>, n: usize, cols: usize) -> Array2<F> {
260 let mut out = Array2::<F>::zeros((n, cols));
261 for i in 0..n {
262 for j in 0..cols {
263 out[[i, j]] = x_ext[[i, j]];
264 }
265 }
266 out
267}
268
269pub type ClassifierChain<F> = RegressorChain<F>;
272pub type FittedClassifierChain<F> = FittedRegressorChain<F>;
273
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Toy estimator that ignores `x` and always predicts the mean of `y`.
    #[derive(Clone)]
    struct MeanReg;
    struct FittedMeanReg(f64);

    impl Fit<f64> for MeanReg {
        type Fitted = FittedMeanReg;
        fn fit(&self, _x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
            let m = y.iter().sum::<f64>() / y.len() as f64;
            Ok(FittedMeanReg(m))
        }
    }
    impl Predict<f64> for FittedMeanReg {
        fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
            Ok(Array1::from_elem(x.nrows(), self.0))
        }
    }

    /// Four samples, one feature, two targets whose column means are 4 and 25.
    fn toy_data() -> (Array2<f64>, Array2<f64>) {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![[1.0, 10.0], [3.0, 20.0], [5.0, 30.0], [7.0, 40.0]];
        (x, y)
    }

    /// Asserts that `p` is 4x2 with the toy column means in columns 0 and 1.
    fn assert_column_means(p: &Array2<f64>) {
        assert_eq!(p.shape(), &[4, 2]);
        for i in 0..4 {
            assert!((p[[i, 0]] - 4.0).abs() < 1e-9);
            assert!((p[[i, 1]] - 25.0).abs() < 1e-9);
        }
    }

    #[test]
    fn test_multi_output_predicts_per_column_mean() {
        let (x, y) = toy_data();

        let model = MultiOutputRegressor::<f64>::new(MeanReg);
        let fitted = model.fit_2d(&x, &y).unwrap();
        assert_eq!(fitted.n_features(), 1);
        assert_eq!(fitted.n_outputs(), 2);
        assert_column_means(&fitted.predict_2d(&x).unwrap());
    }

    #[test]
    fn test_regressor_chain_predicts_2d() {
        let (x, y) = toy_data();

        let chain = RegressorChain::<f64>::new(MeanReg);
        let fitted = chain.fit_2d(&x, &y).unwrap();
        // The mean estimator ignores its features, so every link must
        // reproduce the mean of its own target column exactly.
        assert_column_means(&fitted.predict_2d(&x).unwrap());
    }

    #[test]
    fn test_regressor_chain_custom_order_keeps_output_columns() {
        let (x, y) = toy_data();

        // Reversing the chain order must not swap the output columns.
        let chain = RegressorChain::<f64>::new(MeanReg).with_order(vec![1, 0]);
        let fitted = chain.fit_2d(&x, &y).unwrap();
        assert_column_means(&fitted.predict_2d(&x).unwrap());
    }

    #[test]
    fn test_fit_rejects_row_mismatch() {
        let (x, _) = toy_data();
        let y_short = array![[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]];

        assert!(matches!(
            MultiOutputRegressor::<f64>::new(MeanReg).fit_2d(&x, &y_short),
            Err(RustMlError::ShapeMismatch(_))
        ));
        assert!(matches!(
            RegressorChain::<f64>::new(MeanReg).fit_2d(&x, &y_short),
            Err(RustMlError::ShapeMismatch(_))
        ));
    }

    #[test]
    fn test_predict_rejects_wrong_feature_count() {
        let (x, y) = toy_data();
        let x_wide = array![[1.0, 2.0], [3.0, 4.0]];

        let multi = MultiOutputRegressor::<f64>::new(MeanReg)
            .fit_2d(&x, &y)
            .unwrap();
        assert!(matches!(
            multi.predict_2d(&x_wide),
            Err(RustMlError::ShapeMismatch(_))
        ));

        let chain = RegressorChain::<f64>::new(MeanReg).fit_2d(&x, &y).unwrap();
        assert!(matches!(
            chain.predict_2d(&x_wide),
            Err(RustMlError::ShapeMismatch(_))
        ));
    }
}