linfa/composing/
multi_class_model.rs

//! Combine per-label binary classifiers into a single multi-class classifier
//!
3use crate::dataset::{Pr, Records};
4use crate::traits::PredictInplace;
5use ndarray::{Array1, ArrayBase, Data, Ix2};
6use std::iter::FromIterator;
7
8type MultiClassVec<R, L> = Vec<(L, Box<dyn PredictInplace<R, Array1<Pr>>>)>;
9
10/// Merge models with binary to multi-class classification
11pub struct MultiClassModel<R: Records, L> {
12    models: MultiClassVec<R, L>,
13}
14
15impl<R: Records, L> MultiClassModel<R, L> {
16    pub fn new(models: MultiClassVec<R, L>) -> Self {
17        MultiClassModel { models }
18    }
19}
20
21impl<L: Clone + Default, F, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<L>>
22    for MultiClassModel<ArrayBase<D, Ix2>, L>
23{
24    fn predict_inplace(&self, arr: &ArrayBase<D, Ix2>, y: &mut Array1<L>) {
25        assert_eq!(
26            arr.nrows(),
27            y.len(),
28            "The number of data points must match the number of output targets."
29        );
30
31        let mut res = Vec::new();
32
33        for pairs in self.models.iter().map(|(elm, model)| {
34            let mut targets = Array1::default(arr.nrows());
35            model.predict_inplace(arr, &mut targets);
36
37            targets.into_iter().map(|x| (elm.clone(), *x)).collect()
38        }) {
39            // initialize result with guess of first model
40            if res.is_empty() {
41                res = pairs;
42                continue;
43            }
44
45            // compare probability to each subsequent model and replace label
46            // if probability is higher
47            res = res
48                .into_iter()
49                .zip(pairs.into_iter())
50                .map(|(c, d)| if d.1 > c.1 { d } else { c })
51                .collect();
52        }
53
54        // remove probabilities from array and convert to `Array1`
55        for (r, target) in res.into_iter().zip(y.iter_mut()) {
56            *target = r.0;
57        }
58    }
59
60    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<L> {
61        Array1::default(x.nrows())
62    }
63}
64
65impl<F, D: Data<Elem = F>, L, P: PredictInplace<ArrayBase<D, Ix2>, Array1<Pr>> + 'static>
66    FromIterator<(L, P)> for MultiClassModel<ArrayBase<D, Ix2>, L>
67{
68    fn from_iter<I: IntoIterator<Item = (L, P)>>(iter: I) -> Self {
69        let models = iter
70            .into_iter()
71            .map(|(l, x)| {
72                (
73                    l,
74                    Box::new(x) as Box<dyn PredictInplace<ArrayBase<D, Ix2>, Array1<Pr>>>,
75                )
76            })
77            .collect();
78
79        MultiClassModel { models }
80    }
81}
82
#[cfg(test)]
mod tests {
    use crate::{
        dataset::Pr,
        traits::{Predict, PredictInplace},
        MultiClassModel,
    };
    use ndarray::{array, Array1, Array2};

    /// Dummy binary model that assigns probability 1 to every other row.
    ///
    /// It hits odd rows when `on_even` is `false` and even rows when it is
    /// `true`. Every other row gets probability 0.
    struct DummyModel {
        on_even: bool,
    }

    impl PredictInplace<Array2<f32>, Array1<Pr>> for DummyModel {
        fn predict_inplace(&self, arr: &Array2<f32>, targets: &mut Array1<Pr>) {
            assert_eq!(
                arr.nrows(),
                targets.len(),
                "The number of data points must match the number of output targets."
            );

            // A row is a hit when its parity matches the configured one.
            let on_even = self.on_even;
            *targets = Array1::from_shape_fn(arr.nrows(), |row| {
                let is_even = row % 2 == 0;
                if is_even == on_even {
                    Pr::new(1.0)
                } else {
                    Pr::new(0.0)
                }
            });
        }

        fn default_target(&self, x: &Array2<f32>) -> Array1<Pr> {
            Array1::default(x.nrows())
        }
    }

    #[test]
    fn correct_dummies() {
        let model1 = DummyModel { on_even: false };
        let model2 = DummyModel { on_even: true };

        let data = Array2::zeros((4, 2));
        assert_eq!(
            model1.predict(&data),
            array![0.0, 1.0, 0.0, 1.0].mapv(Pr::new)
        );
        assert_eq!(
            model2.predict(&data),
            array![1.0, 0.0, 1.0, 0.0].mapv(Pr::new)
        );
    }

    #[test]
    fn choose_correct() {
        let model = vec![
            (0, DummyModel { on_even: false }),
            (1, DummyModel { on_even: true }),
        ]
        .into_iter()
        .collect::<MultiClassModel<_, usize>>();

        let data = Array2::zeros((4, 2));
        assert_eq!(model.predict(&data), array![1, 0, 1, 0]);
    }

    /// When several models report the same probability for a row, the label
    /// of the model listed first is predicted.
    #[test]
    fn ties_prefer_earlier_model() {
        let model = vec![
            (0, DummyModel { on_even: false }),
            (1, DummyModel { on_even: false }),
        ]
        .into_iter()
        .collect::<MultiClassModel<_, usize>>();

        let data = Array2::zeros((4, 2));
        assert_eq!(model.predict(&data), array![0, 0, 0, 0]);
    }
}