1use crate::dataset::Records;
9use crate::traits::PredictInplace;
10use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2};
11use std::iter::FromIterator;
12
13pub struct MultiTargetModel<R: Records, L> {
19 models: Vec<Box<dyn PredictInplace<R, Array1<L>>>>,
20}
21
22impl<R: Records, L> MultiTargetModel<R, L> {
23 pub fn new(models: Vec<Box<dyn PredictInplace<R, Array1<L>>>>) -> Self {
30 MultiTargetModel { models }
31 }
32}
33
34impl<L: Default, F, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array2<L>>
35 for MultiTargetModel<ArrayBase<D, Ix2>, L>
36{
37 fn predict_inplace(&self, arr: &ArrayBase<D, Ix2>, targets: &mut Array2<L>) {
38 assert_eq!(
39 targets.shape(),
40 &[arr.nrows(), self.models.len()],
41 "The number of data points must match the number of output targets."
42 );
43 assert!(
44 targets.is_standard_layout(),
45 "targets not in row-major layout"
46 );
47 *targets = self
48 .models
49 .iter()
50 .flat_map(|model| {
51 let mut targets = Array1::default(arr.nrows());
52 model.predict_inplace(arr, &mut targets);
53 let (v, _) = targets.into_raw_vec_and_offset();
54 v
55 })
56 .collect::<Array1<L>>()
57 .into_shape_with_order((self.models.len(), arr.len_of(Axis(0))))
58 .unwrap()
59 .reversed_axes();
60 }
61
62 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array2<L> {
63 Array2::default((x.nrows(), self.models.len()))
64 }
65}
66
67impl<F, D: Data<Elem = F>, L, P: PredictInplace<ArrayBase<D, Ix2>, Array1<L>> + 'static>
68 FromIterator<P> for MultiTargetModel<ArrayBase<D, Ix2>, L>
69{
70 fn from_iter<I: IntoIterator<Item = P>>(iter: I) -> Self {
71 let models = iter
72 .into_iter()
73 .map(|x| Box::new(x) as Box<dyn PredictInplace<ArrayBase<D, Ix2>, Array1<L>>>)
74 .collect();
75
76 MultiTargetModel { models }
77 }
78}
79
#[cfg(test)]
mod tests {
    use crate::{
        traits::{Predict, PredictInplace},
        MultiTargetModel,
    };
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array1, Array2, Axis};

    /// Test model that predicts the constant `val` for every sample.
    struct DummyModel {
        val: f32,
    }

    impl PredictInplace<Array2<f32>, Array1<f32>> for DummyModel {
        fn predict_inplace(&self, arr: &Array2<f32>, targets: &mut Array1<f32>) {
            let n_samples = arr.nrows();
            assert_eq!(
                n_samples,
                targets.len(),
                "The number of data points must match the number of output targets."
            );
            targets.iter_mut().for_each(|slot| *slot = self.val);
        }

        fn default_target(&self, x: &Array2<f32>) -> Array1<f32> {
            Array1::zeros(x.nrows())
        }
    }

    /// Test model that predicts the ramp `val, val + 1, ...` over the samples.
    struct DummyModel2 {
        val: f32,
    }

    impl PredictInplace<Array2<f32>, Array1<f32>> for DummyModel2 {
        fn predict_inplace(&self, arr: &Array2<f32>, targets: &mut Array1<f32>) {
            let n_samples = arr.len_of(Axis(0));
            assert_eq!(
                arr.nrows(),
                targets.len(),
                "The number of data points must match the number of output targets."
            );
            let last = self.val + n_samples as f32 - 1.0;
            *targets = Array1::linspace(self.val, last, n_samples);
        }

        fn default_target(&self, x: &Array2<f32>) -> Array1<f32> {
            Array1::zeros(x.nrows())
        }
    }

    #[test]
    fn dummy_constant() {
        let model: MultiTargetModel<_, _> = (0..4)
            .map(|idx| DummyModel { val: idx as f32 })
            .collect();

        let predicted = model.predict(&Array2::zeros((5, 2)));
        assert_abs_diff_eq!(
            predicted,
            array![
                [0., 1., 2., 3.],
                [0., 1., 2., 3.],
                [0., 1., 2., 3.],
                [0., 1., 2., 3.],
                [0., 1., 2., 3.],
            ]
        );
    }

    #[test]
    fn different_dummys() {
        let boxed: Vec<Box<dyn PredictInplace<Array2<f32>, Array1<f32>>>> = vec![
            Box::new(DummyModel { val: 42.0 }),
            Box::new(DummyModel2 { val: 42.0 }),
        ];
        let model = MultiTargetModel::new(boxed);

        let predicted = model.predict(&Array2::zeros((5, 2)));
        assert_abs_diff_eq!(
            predicted,
            array![[42., 42.], [42., 43.], [42., 44.], [42., 45.], [42., 46.]]
        );
    }
}