linfa/dataset/
impl_targets.rs

1use std::collections::HashMap;
2
3use super::{
4    AsMultiTargets, AsMultiTargetsMut, AsProbabilities, AsSingleTargets, AsSingleTargetsMut,
5    AsTargets, AsTargetsMut, CountedTargets, DatasetBase, FromTargetArray, FromTargetArrayOwned,
6    Label, Labels, Pr, TargetDim,
7};
8use ndarray::{
9    Array, Array1, Array2, ArrayBase, ArrayView, ArrayViewMut, Axis, CowArray, Data, DataMut,
10    Dimension, Ix1, Ix2, Ix3, OwnedRepr, ViewRepr,
11};
12
// Marker impls: one- and two-dimensional index types are accepted wherever an
// `I: TargetDim` bound appears in this file.
impl TargetDim for Ix1 {}
impl TargetDim for Ix2 {}
15
16impl<L, S: Data<Elem = L>, I: TargetDim> AsTargets for ArrayBase<S, I> {
17    type Elem = L;
18    type Ix = I;
19
20    fn as_targets(&self) -> ArrayView<'_, L, I> {
21        self.view()
22    }
23}
24
25impl<T: AsTargets<Ix = Ix1>> AsSingleTargets for T {}
26impl<T: AsTargets<Ix = Ix2>> AsMultiTargets for T {}
27
28impl<L: Clone, S: Data<Elem = L>, I: TargetDim> FromTargetArrayOwned for ArrayBase<S, I> {
29    type Owned = ArrayBase<OwnedRepr<L>, I>;
30
31    /// Returns an owned representation of the target array
32    fn new_targets(targets: Array<L, I>) -> Self::Owned {
33        targets
34    }
35}
36
37impl<'a, L: Clone + 'a, S: Data<Elem = L>, I: TargetDim> FromTargetArray<'a> for ArrayBase<S, I> {
38    type View = ArrayBase<ViewRepr<&'a L>, I>;
39
40    /// Returns a reference to the target array
41    fn new_targets_view(targets: ArrayView<'a, L, I>) -> Self::View {
42        targets
43    }
44}
45
46impl<L, S: DataMut<Elem = L>, I: TargetDim> AsTargetsMut for ArrayBase<S, I> {
47    type Elem = L;
48    type Ix = I;
49
50    fn as_targets_mut(&mut self) -> ArrayViewMut<'_, Self::Elem, I> {
51        self.view_mut()
52    }
53}
54
55impl<T: AsTargetsMut<Ix = Ix1>> AsSingleTargetsMut for T {}
56impl<T: AsTargetsMut<Ix = Ix2>> AsMultiTargetsMut for T {}
57
58impl<T: AsTargets> AsTargets for &T {
59    type Elem = T::Elem;
60    type Ix = T::Ix;
61
62    fn as_targets(&self) -> ArrayView<'_, Self::Elem, Self::Ix> {
63        (*self).as_targets()
64    }
65}
66
67impl<L: Label, T: AsTargets<Elem = L>> AsTargets for CountedTargets<L, T> {
68    type Elem = L;
69    type Ix = T::Ix;
70
71    fn as_targets(&self) -> ArrayView<'_, Self::Elem, Self::Ix> {
72        self.targets.as_targets()
73    }
74}
75
76impl<L: Label, T: AsTargetsMut<Elem = L>> AsTargetsMut for CountedTargets<L, T> {
77    type Elem = L;
78    type Ix = T::Ix;
79
80    fn as_targets_mut(&mut self) -> ArrayViewMut<'_, Self::Elem, Self::Ix> {
81        self.targets.as_targets_mut()
82    }
83}
84
85impl<L: Label, T> FromTargetArrayOwned for CountedTargets<L, T>
86where
87    T: FromTargetArrayOwned<Elem = L>,
88    T::Owned: Labels<Elem = L>,
89{
90    type Owned = CountedTargets<L, T::Owned>;
91
92    fn new_targets(targets: Array<L, T::Ix>) -> Self::Owned {
93        let targets = T::new_targets(targets);
94        CountedTargets {
95            labels: targets.label_count(),
96            targets,
97        }
98    }
99}
100
101impl<'a, L: Label + 'a, T> FromTargetArray<'a> for CountedTargets<L, T>
102where
103    T: FromTargetArray<'a, Elem = L>,
104    T::View: Labels<Elem = L>,
105{
106    type View = CountedTargets<L, T::View>;
107
108    fn new_targets_view(targets: ArrayView<'a, L, T::Ix>) -> Self::View {
109        let targets = T::new_targets_view(targets);
110
111        CountedTargets {
112            labels: targets.label_count(),
113            targets,
114        }
115    }
116}
117/*
118impl<L: Label, S: Data<Elem = Pr>> AsTargets for TargetsWithLabels<L, ArrayBase<S, Ix3>> {
119    type Elem = L;
120
121    fn as_multi_targets(&self) -> CowArray<L, Ix2> {
122        /*let init_vals = (..self.labels.len()).map(|i| (i, f32::INFINITY)).collect();
123        let res = self.targets.fold_axis(Axis(2), init_vals, |a, b| {
124            if a.1 > b.1 {
125                return b;
126            } else {
127                return a;
128            }
129        });*/
130
131        //let labels = self.labels.into_iter().collect::<Vec<_>>();
132        //res.map_axis(Axis(1), |a| {});
133        panic!("")
134    }
135}*/
136
137impl<S: Data<Elem = Pr>> AsProbabilities for ArrayBase<S, Ix3> {
138    fn as_multi_target_probabilities(&self) -> CowArray<'_, Pr, Ix3> {
139        CowArray::from(self.view())
140    }
141}
142
143/// A NdArray with discrete labels can act as labels
144impl<L: Label, S: Data<Elem = L>, I: Dimension> Labels for ArrayBase<S, I> {
145    type Elem = L;
146
147    fn label_count(&self) -> Vec<HashMap<L, usize>> {
148        self.columns()
149            .into_iter()
150            .map(|x| {
151                let mut map = HashMap::new();
152
153                for i in x {
154                    *map.entry(i.clone()).or_insert(0) += 1;
155                }
156
157                map
158            })
159            .collect()
160    }
161}
162
163/// Counted labels can act as labels
164impl<L: Label, T: AsTargets<Elem = L>> Labels for CountedTargets<L, T> {
165    type Elem = L;
166
167    fn label_count(&self) -> Vec<HashMap<L, usize>> {
168        self.labels.clone()
169    }
170}
171
172impl<F: Copy, L: Copy + Label, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
173where
174    D: Data<Elem = F>,
175    T: AsTargets<Elem = L>,
176{
177    /// Transforms the input dataset by keeping only those samples whose label appears in `labels`.
178    ///
179    /// In the multi-target case a sample is kept if *any* of its targets appears in `labels`.
180    ///
181    /// Sample weights and feature names are preserved by this transformation.
182    #[allow(clippy::type_complexity)]
183    pub fn with_labels(
184        &self,
185        labels: &[L],
186    ) -> DatasetBase<Array2<F>, CountedTargets<L, Array<L, T::Ix>>> {
187        let targets = self.targets.as_targets();
188        let old_weights = self.weights();
189
190        let mut records_arr = Vec::new();
191        let mut targets_arr = Vec::new();
192        let mut weights = Vec::new();
193
194        let mut map = vec![HashMap::new(); self.ntargets()];
195
196        for (i, (r, t)) in self
197            .records()
198            .rows()
199            .into_iter()
200            .zip(targets.axis_iter(Axis(0)))
201            .enumerate()
202        {
203            let any_exists = t.iter().any(|a| labels.contains(a));
204
205            if any_exists {
206                for (map, val) in map.iter_mut().zip(t.iter()) {
207                    *map.entry(*val).or_insert(0) += 1;
208                }
209
210                records_arr.push(r);
211                targets_arr.push(t);
212
213                if let Some(weight) = old_weights {
214                    weights.push(weight[i]);
215                }
216            }
217        }
218
219        let nsamples = records_arr.len();
220
221        let records_arr = records_arr.into_iter().flatten().copied().collect();
222        let targets_arr = targets_arr.into_iter().flatten().copied().collect();
223
224        let records =
225            Array2::from_shape_vec(self.records.raw_dim().nsamples(nsamples), records_arr).unwrap();
226        let targets = Array::from_shape_vec(
227            self.targets.as_targets().raw_dim().nsamples(nsamples),
228            targets_arr,
229        )
230        .unwrap();
231
232        let targets = CountedTargets {
233            targets,
234            labels: map,
235        };
236
237        DatasetBase {
238            records,
239            weights: Array1::from(weights),
240            targets,
241            feature_names: self.feature_names.clone(),
242            target_names: self.target_names.clone(),
243        }
244    }
245}