use std::iter::FromIterator;

use ndarray::{Array1, ArrayBase, Data, Ix2};

use crate::dataset::{Pr, Records};
use crate::traits::PredictInplace;

8type MultiClassVec<R, L> = Vec<(L, Box<dyn PredictInplace<R, Array1<Pr>>>)>;
9
10pub struct MultiClassModel<R: Records, L> {
12 models: MultiClassVec<R, L>,
13}
14
15impl<R: Records, L> MultiClassModel<R, L> {
16 pub fn new(models: MultiClassVec<R, L>) -> Self {
17 MultiClassModel { models }
18 }
19}
20
21impl<L: Clone + Default, F, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<L>>
22 for MultiClassModel<ArrayBase<D, Ix2>, L>
23{
24 fn predict_inplace(&self, arr: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
25 assert_eq!(
26 arr.nrows(),
27 y.len(),
28 "The number of data points must match the number of output targets."
29 );
30
31 let mut res = Vec::new();
32
33 for pairs in self.models.iter().map(|(elm, model)| {
34 let mut targets = Array1::default(arr.nrows());
35 model.predict_inplace(arr, &mut targets);
36
37 targets.into_iter().map(|x| (elm.clone(), *x)).collect()
38 }) {
39 if res.is_empty() {
41 res = pairs;
42 continue;
43 }
44
45 res = res
48 .into_iter()
49 .zip(pairs.into_iter())
50 .map(|(c, d)| if d.1 > c.1 { d } else { c })
51 .collect();
52 }
53
54 for (r, target) in res.into_iter().zip(y.iter_mut()) {
56 *target = r.0;
57 }
58 }
59
60 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
61 Array1::default(x.nrows())
62 }
63}
64
65impl<F, D: Data<Elem = F>, L, P: PredictInplace<ArrayBase<D, Ix2>, Array1<Pr>> + 'static>
66 FromIterator<(L, P)> for MultiClassModel<ArrayBase<D, Ix2>, L>
67{
68 fn from_iter<I: IntoIterator<Item = (L, P)>>(iter: I) -> Self {
69 let models = iter
70 .into_iter()
71 .map(|(l, x)| {
72 (
73 l,
74 Box::new(x) as Box<dyn PredictInplace<ArrayBase<D, Ix2>, Array1<Pr>>>,
75 )
76 })
77 .collect();
78
79 MultiClassModel { models }
80 }
81}
82
#[cfg(test)]
mod tests {
    use crate::{
        dataset::Pr,
        traits::{Predict, PredictInplace},
        MultiClassModel,
    };
    use ndarray::{array, Array1, Array2};

    /// Test model that predicts probability 1.0 on every other row and 0.0 elsewhere.
    struct DummyModel {
        /// If `true`, even-indexed rows get 1.0; otherwise odd-indexed rows do.
        on_even: bool,
    }

    impl PredictInplace<Array2<f32>, Array1<Pr>> for DummyModel {
        fn predict_inplace(&self, arr: &Array2<f32>, targets: &mut Array1<Pr>) {
            assert_eq!(
                arr.nrows(),
                targets.len(),
                "The number of data points must match the number of output targets."
            );

            let on_even = self.on_even;
            *targets = Array1::from_shape_fn(arr.nrows(), |i| {
                // A row is a "hit" when its parity matches the configured one.
                let hit = (i % 2 == 0) == on_even;
                Pr::new(if hit { 1.0 } else { 0.0 })
            });
        }

        fn default_target(&self, x: &Array2<f32>) -> Array1<Pr> {
            Array1::default(x.nrows())
        }
    }

    /// Test model that reports the same probability for every row.
    struct ConstantModel(f32);

    impl PredictInplace<Array2<f32>, Array1<Pr>> for ConstantModel {
        fn predict_inplace(&self, arr: &Array2<f32>, targets: &mut Array1<Pr>) {
            assert_eq!(
                arr.nrows(),
                targets.len(),
                "The number of data points must match the number of output targets."
            );

            *targets = Array1::from_elem(arr.nrows(), Pr::new(self.0));
        }

        fn default_target(&self, x: &Array2<f32>) -> Array1<Pr> {
            Array1::default(x.nrows())
        }
    }

    #[test]
    fn correct_dummies() {
        let model1 = DummyModel { on_even: false };
        let model2 = DummyModel { on_even: true };

        let data = Array2::zeros((4, 2));
        assert_eq!(
            model1.predict(&data),
            array![0.0, 1.0, 0.0, 1.0].mapv(Pr::new)
        );
        assert_eq!(
            model2.predict(&data),
            array![1.0, 0.0, 1.0, 0.0].mapv(Pr::new)
        );
    }

    #[test]
    fn choose_correct() {
        let model = vec![
            (0, DummyModel { on_even: false }),
            (1, DummyModel { on_even: true }),
        ]
        .into_iter()
        .collect::<MultiClassModel<_, usize>>();

        let data = Array2::zeros((4, 2));
        assert_eq!(model.predict(&data), array![1, 0, 1, 0]);
    }

    #[test]
    fn ties_prefer_first_model() {
        let model = vec![(7, ConstantModel(0.5)), (3, ConstantModel(0.5))]
            .into_iter()
            .collect::<MultiClassModel<_, usize>>();

        let data = Array2::zeros((3, 2));
        assert_eq!(model.predict(&data), array![7, 7, 7]);
    }

    #[test]
    fn empty_input_yields_empty_prediction() {
        let model = vec![
            (0, DummyModel { on_even: false }),
            (1, DummyModel { on_even: true }),
        ]
        .into_iter()
        .collect::<MultiClassModel<_, usize>>();

        let data = Array2::<f32>::zeros((0, 2));
        assert_eq!(model.predict(&data), Array1::<usize>::zeros(0));
    }
}