1use std::collections::HashMap;
2
3use super::{
4 AsMultiTargets, AsMultiTargetsMut, AsProbabilities, AsSingleTargets, AsSingleTargetsMut,
5 AsTargets, AsTargetsMut, CountedTargets, DatasetBase, FromTargetArray, FromTargetArrayOwned,
6 Label, Labels, Pr, TargetDim,
7};
8use ndarray::{
9 Array, Array1, Array2, ArrayBase, ArrayView, ArrayViewMut, Axis, CowArray, Data, DataMut,
10 Dimension, Ix1, Ix2, Ix3, OwnedRepr, ViewRepr,
11};
12
/// One-dimensional index type is accepted as a target dimension.
/// The body is empty, so no `TargetDim` items are overridden here.
impl TargetDim for Ix1 {}
/// Two-dimensional index type is accepted as a target dimension.
/// The body is empty, so no `TargetDim` items are overridden here.
impl TargetDim for Ix2 {}
15
16impl<L, S: Data<Elem = L>, I: TargetDim> AsTargets for ArrayBase<S, I> {
17 type Elem = L;
18 type Ix = I;
19
20 fn as_targets(&self) -> ArrayView<'_, L, I> {
21 self.view()
22 }
23}
24
25impl<T: AsTargets<Ix = Ix1>> AsSingleTargets for T {}
26impl<T: AsTargets<Ix = Ix2>> AsMultiTargets for T {}
27
28impl<L: Clone, S: Data<Elem = L>, I: TargetDim> FromTargetArrayOwned for ArrayBase<S, I> {
29 type Owned = ArrayBase<OwnedRepr<L>, I>;
30
31 fn new_targets(targets: Array<L, I>) -> Self::Owned {
33 targets
34 }
35}
36
37impl<'a, L: Clone + 'a, S: Data<Elem = L>, I: TargetDim> FromTargetArray<'a> for ArrayBase<S, I> {
38 type View = ArrayBase<ViewRepr<&'a L>, I>;
39
40 fn new_targets_view(targets: ArrayView<'a, L, I>) -> Self::View {
42 targets
43 }
44}
45
46impl<L, S: DataMut<Elem = L>, I: TargetDim> AsTargetsMut for ArrayBase<S, I> {
47 type Elem = L;
48 type Ix = I;
49
50 fn as_targets_mut(&mut self) -> ArrayViewMut<'_, Self::Elem, I> {
51 self.view_mut()
52 }
53}
54
55impl<T: AsTargetsMut<Ix = Ix1>> AsSingleTargetsMut for T {}
56impl<T: AsTargetsMut<Ix = Ix2>> AsMultiTargetsMut for T {}
57
58impl<T: AsTargets> AsTargets for &T {
59 type Elem = T::Elem;
60 type Ix = T::Ix;
61
62 fn as_targets(&self) -> ArrayView<'_, Self::Elem, Self::Ix> {
63 (*self).as_targets()
64 }
65}
66
67impl<L: Label, T: AsTargets<Elem = L>> AsTargets for CountedTargets<L, T> {
68 type Elem = L;
69 type Ix = T::Ix;
70
71 fn as_targets(&self) -> ArrayView<'_, Self::Elem, Self::Ix> {
72 self.targets.as_targets()
73 }
74}
75
76impl<L: Label, T: AsTargetsMut<Elem = L>> AsTargetsMut for CountedTargets<L, T> {
77 type Elem = L;
78 type Ix = T::Ix;
79
80 fn as_targets_mut(&mut self) -> ArrayViewMut<'_, Self::Elem, Self::Ix> {
81 self.targets.as_targets_mut()
82 }
83}
84
85impl<L: Label, T> FromTargetArrayOwned for CountedTargets<L, T>
86where
87 T: FromTargetArrayOwned<Elem = L>,
88 T::Owned: Labels<Elem = L>,
89{
90 type Owned = CountedTargets<L, T::Owned>;
91
92 fn new_targets(targets: Array<L, T::Ix>) -> Self::Owned {
93 let targets = T::new_targets(targets);
94 CountedTargets {
95 labels: targets.label_count(),
96 targets,
97 }
98 }
99}
100
101impl<'a, L: Label + 'a, T> FromTargetArray<'a> for CountedTargets<L, T>
102where
103 T: FromTargetArray<'a, Elem = L>,
104 T::View: Labels<Elem = L>,
105{
106 type View = CountedTargets<L, T::View>;
107
108 fn new_targets_view(targets: ArrayView<'a, L, T::Ix>) -> Self::View {
109 let targets = T::new_targets_view(targets);
110
111 CountedTargets {
112 labels: targets.label_count(),
113 targets,
114 }
115 }
116}
117impl<S: Data<Elem = Pr>> AsProbabilities for ArrayBase<S, Ix3> {
138 fn as_multi_target_probabilities(&self) -> CowArray<'_, Pr, Ix3> {
139 CowArray::from(self.view())
140 }
141}
142
143impl<L: Label, S: Data<Elem = L>, I: Dimension> Labels for ArrayBase<S, I> {
145 type Elem = L;
146
147 fn label_count(&self) -> Vec<HashMap<L, usize>> {
148 self.columns()
149 .into_iter()
150 .map(|x| {
151 let mut map = HashMap::new();
152
153 for i in x {
154 *map.entry(i.clone()).or_insert(0) += 1;
155 }
156
157 map
158 })
159 .collect()
160 }
161}
162
163impl<L: Label, T: AsTargets<Elem = L>> Labels for CountedTargets<L, T> {
165 type Elem = L;
166
167 fn label_count(&self) -> Vec<HashMap<L, usize>> {
168 self.labels.clone()
169 }
170}
171
172impl<F: Copy, L: Copy + Label, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
173where
174 D: Data<Elem = F>,
175 T: AsTargets<Elem = L>,
176{
177 #[allow(clippy::type_complexity)]
183 pub fn with_labels(
184 &self,
185 labels: &[L],
186 ) -> DatasetBase<Array2<F>, CountedTargets<L, Array<L, T::Ix>>> {
187 let targets = self.targets.as_targets();
188 let old_weights = self.weights();
189
190 let mut records_arr = Vec::new();
191 let mut targets_arr = Vec::new();
192 let mut weights = Vec::new();
193
194 let mut map = vec![HashMap::new(); self.ntargets()];
195
196 for (i, (r, t)) in self
197 .records()
198 .rows()
199 .into_iter()
200 .zip(targets.axis_iter(Axis(0)))
201 .enumerate()
202 {
203 let any_exists = t.iter().any(|a| labels.contains(a));
204
205 if any_exists {
206 for (map, val) in map.iter_mut().zip(t.iter()) {
207 *map.entry(*val).or_insert(0) += 1;
208 }
209
210 records_arr.push(r);
211 targets_arr.push(t);
212
213 if let Some(weight) = old_weights {
214 weights.push(weight[i]);
215 }
216 }
217 }
218
219 let nsamples = records_arr.len();
220
221 let records_arr = records_arr.into_iter().flatten().copied().collect();
222 let targets_arr = targets_arr.into_iter().flatten().copied().collect();
223
224 let records =
225 Array2::from_shape_vec(self.records.raw_dim().nsamples(nsamples), records_arr).unwrap();
226 let targets = Array::from_shape_vec(
227 self.targets.as_targets().raw_dim().nsamples(nsamples),
228 targets_arr,
229 )
230 .unwrap();
231
232 let targets = CountedTargets {
233 targets,
234 labels: map,
235 };
236
237 DatasetBase {
238 records,
239 weights: Array1::from(weights),
240 targets,
241 feature_names: self.feature_names.clone(),
242 target_names: self.target_names.clone(),
243 }
244 }
245}