linfa/composing/
multi_target_model.rs

//! Merge models with single target to multi-target models
//!
//! Many models assume that the target variables are uncorrelated and therefore support only a
//! single target variable. This wrapper allows the user to merge multiple single-target models
//! into one multi-target model.
//!
8use crate::dataset::Records;
9use crate::traits::PredictInplace;
10use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2};
11use std::iter::FromIterator;
12
13/// Merge models with single target to multi-target models
14///
15/// Many models assume that the target variables are uncorrelated and support therefore only a
16/// single target variable. This wrapper allows the user to merge multiple models with only a
17/// single-target variable into a multi-target model.
18pub struct MultiTargetModel<R: Records, L> {
19    models: Vec<Box<dyn PredictInplace<R, Array1<L>>>>,
20}
21
22impl<R: Records, L> MultiTargetModel<R, L> {
23    /// Create a wrapper model from a list of single-target models
24    ///
25    /// The type parameter of the single-target models are only constraint to implement the
26    /// prediction trait and can otherwise contain any object. This allows the mixture of different
27    /// models into the same wrapper. If you want to use the same model for all predictions, just
28    /// use the `FromIterator` implementation.
29    pub fn new(models: Vec<Box<dyn PredictInplace<R, Array1<L>>>>) -> Self {
30        MultiTargetModel { models }
31    }
32}
33
34impl<L: Default, F, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array2<L>>
35    for MultiTargetModel<ArrayBase<D, Ix2>, L>
36{
37    fn predict_inplace(&self, arr: &ArrayBase<D, Ix2>, targets: &mut Array2<L>) {
38        assert_eq!(
39            targets.shape(),
40            &[arr.nrows(), self.models.len()],
41            "The number of data points must match the number of output targets."
42        );
43        assert!(
44            targets.is_standard_layout(),
45            "targets not in row-major layout"
46        );
47        *targets = self
48            .models
49            .iter()
50            .flat_map(|model| {
51                let mut targets = Array1::default(arr.nrows());
52                model.predict_inplace(arr, &mut targets);
53                let (v, _) = targets.into_raw_vec_and_offset();
54                v
55            })
56            .collect::<Array1<L>>()
57            .into_shape_with_order((self.models.len(), arr.len_of(Axis(0))))
58            .unwrap()
59            .reversed_axes();
60    }
61
62    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array2<L> {
63        Array2::default((x.nrows(), self.models.len()))
64    }
65}
66
67impl<F, D: Data<Elem = F>, L, P: PredictInplace<ArrayBase<D, Ix2>, Array1<L>> + 'static>
68    FromIterator<P> for MultiTargetModel<ArrayBase<D, Ix2>, L>
69{
70    fn from_iter<I: IntoIterator<Item = P>>(iter: I) -> Self {
71        let models = iter
72            .into_iter()
73            .map(|x| Box::new(x) as Box<dyn PredictInplace<ArrayBase<D, Ix2>, Array1<L>>>)
74            .collect();
75
76        MultiTargetModel { models }
77    }
78}
79
#[cfg(test)]
mod tests {
    use crate::{
        traits::{Predict, PredictInplace},
        MultiTargetModel,
    };
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array1, Array2, Axis};

    /// Test model that predicts the same constant value for every record.
    struct DummyModel {
        val: f32,
    }

    impl PredictInplace<Array2<f32>, Array1<f32>> for DummyModel {
        fn predict_inplace(&self, arr: &Array2<f32>, targets: &mut Array1<f32>) {
            let n_records = arr.nrows();
            assert_eq!(
                n_records,
                targets.len(),
                "The number of data points must match the number of output targets."
            );
            targets.fill(self.val);
        }

        fn default_target(&self, x: &Array2<f32>) -> Array1<f32> {
            let n_records = x.nrows();
            Array1::zeros(n_records)
        }
    }

    /// Test model that predicts `val`, `val + 1`, `val + 2`, ... (one step per record).
    struct DummyModel2 {
        val: f32,
    }

    impl PredictInplace<Array2<f32>, Array1<f32>> for DummyModel2 {
        fn predict_inplace(&self, arr: &Array2<f32>, targets: &mut Array1<f32>) {
            assert_eq!(
                arr.nrows(),
                targets.len(),
                "The number of data points must match the number of output targets."
            );
            let n_records = arr.len_of(Axis(0));
            let start = self.val;
            let end = start + n_records as f32 - 1.0;
            *targets = Array1::linspace(start, end, n_records);
        }

        fn default_target(&self, x: &Array2<f32>) -> Array1<f32> {
            let n_records = x.nrows();
            Array1::zeros(n_records)
        }
    }

    #[test]
    fn dummy_constant() {
        // four constant models (predicting 0, 1, 2 and 3) merged through `FromIterator`
        let model: MultiTargetModel<_, _> = (0..4)
            .map(|idx| DummyModel { val: idx as f32 })
            .collect();

        // every record must receive the same row of constants
        let records: Array2<f32> = Array2::zeros((5, 2));
        let predicted = model.predict(&records);
        assert_abs_diff_eq!(
            predicted,
            array![
                [0., 1., 2., 3.],
                [0., 1., 2., 3.],
                [0., 1., 2., 3.],
                [0., 1., 2., 3.],
                [0., 1., 2., 3.],
            ]
        );
    }

    #[test]
    fn different_dummys() {
        // mix two model types: one constantly predicts 42, the other counts up from 42 by
        // one per record (42, 43, ..., 46 for five records)
        let constant = DummyModel { val: 42.0 };
        let counter = DummyModel2 { val: 42.0 };

        let model = MultiTargetModel::new(vec![Box::new(constant), Box::new(counter)]);

        let records: Array2<f32> = Array2::zeros((5, 2));
        let predicted = model.predict(&records);
        assert_abs_diff_eq!(
            predicted,
            array![[42., 42.], [42., 43.], [42., 44.], [42., 45.], [42., 46.]]
        );
    }
}