1use super::{
2 super::traits::{Predict, PredictInplace},
3 iter::{ChunksIter, DatasetIter, Iter},
4 AsSingleTargets, AsTargets, AsTargetsMut, CountedTargets, Dataset, DatasetBase, DatasetView,
5 Float, FromTargetArray, FromTargetArrayOwned, Label, Labels, Records, Result, TargetDim,
6};
7use crate::traits::Fit;
8use ndarray::{concatenate, prelude::*, Data, DataMut, Dimension};
9use rand::{seq::SliceRandom, Rng};
10use std::collections::HashMap;
11use std::ops::AddAssign;
12
13impl<R: Records, S> DatasetBase<R, S> {
18 pub fn new(records: R, targets: S) -> DatasetBase<R, S> {
26 let targets = targets;
27
28 DatasetBase {
29 records,
30 targets,
31 weights: Array1::zeros(0),
32 feature_names: Vec::new(),
33 target_names: Vec::new(),
34 }
35 }
36
37 pub fn targets(&self) -> &S {
39 &self.targets
40 }
41
42 pub fn weights(&self) -> Option<&[f32]> {
44 if !self.weights.is_empty() {
45 Some(self.weights.as_slice().unwrap())
46 } else {
47 None
48 }
49 }
50
51 pub fn weight_for(&self, idx: usize) -> f32 {
56 self.weights.get(idx).copied().unwrap_or(1.0)
57 }
58
59 pub fn feature_names(&self) -> &[String] {
65 &self.feature_names
66 }
67
68 pub fn records(&self) -> &R {
73 &self.records
74 }
75
76 pub fn with_records<T: Records>(self, records: T) -> DatasetBase<T, S> {
81 DatasetBase {
82 records,
83 targets: self.targets,
84 weights: Array1::zeros(0),
85 feature_names: Vec::new(),
86 target_names: Vec::new(),
87 }
88 }
89
90 pub fn with_targets<T>(self, targets: T) -> DatasetBase<R, T> {
94 DatasetBase {
95 records: self.records,
96 targets,
97 weights: self.weights,
98 feature_names: self.feature_names,
99 target_names: self.target_names,
100 }
101 }
102
103 pub fn with_weights(mut self, weights: Array1<f32>) -> DatasetBase<R, S> {
105 self.weights = weights;
106
107 self
108 }
109
110 pub fn with_feature_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, S> {
114 assert!(
115 names.is_empty() || names.len() == self.nfeatures(),
116 "Wrong number of feature names"
117 );
118 self.feature_names = names.into_iter().map(|x| x.into()).collect();
119 self
120 }
121}
122
123impl<X, Y> Dataset<X, Y> {
124 pub fn into_single_target(self) -> Dataset<X, Y, Ix1> {
126 let nsamples = self.records.nsamples();
127 let targets = self.targets.into_shape_with_order(nsamples).unwrap();
128 let features = self.records;
129 Dataset::new(features, targets)
130 }
131}
132
133impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
134 pub fn with_target_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, T> {
138 assert!(
139 names.is_empty() || names.len() == self.ntargets(),
140 "Wrong number of target names"
141 );
142 self.target_names = names.into_iter().map(|x| x.into()).collect();
143 self
144 }
145
146 pub fn map_targets<S, G: FnMut(&L) -> S>(self, fnc: G) -> DatasetBase<R, Array<S, T::Ix>> {
163 let DatasetBase {
164 records,
165 targets,
166 weights,
167 feature_names,
168 target_names,
169 ..
170 } = self;
171
172 let targets = targets.as_targets();
173
174 DatasetBase {
175 records,
176 targets: targets.map(fnc),
177 weights,
178 feature_names,
179 target_names,
180 }
181 }
182
183 pub fn target_names(&self) -> &[String] {
187 &self.target_names
188 }
189
190 pub fn ntargets(&self) -> usize {
201 if T::Ix::NDIM.unwrap() == 1 {
202 1
203 } else {
204 self.targets.as_targets().len_of(Axis(1))
205 }
206 }
207}
208
209impl<'a, F, L, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
210where
211 D: Data<Elem = F>,
212 T: AsTargets<Elem = L>,
213{
214 pub fn sample_iter(&'a self) -> Iter<'a, 'a, F, T::Elem, T::Ix> {
233 Iter::new(self.records.view(), self.targets.as_targets())
234 }
235}
236
237impl<'a, F: 'a, L: 'a, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
238where
239 D: Data<Elem = F>,
240 T: AsTargets<Elem = L> + FromTargetArray<'a>,
241 T::View: AsTargets<Elem = L>,
242{
243 pub fn view(&'a self) -> DatasetBase<ArrayView2<'a, F>, T::View> {
245 let records = self.records().view();
246 let targets = T::new_targets_view(self.as_targets());
247
248 DatasetBase::new(records, targets)
249 .with_feature_names(self.feature_names.clone())
250 .with_weights(self.weights.clone())
251 .with_target_names(self.target_names.clone())
252 }
253
254 pub fn feature_iter(&'a self) -> DatasetIter<'a, 'a, ArrayBase<D, Ix2>, T> {
259 DatasetIter::new(self, true)
260 }
261
262 pub fn target_iter(&'a self) -> DatasetIter<'a, 'a, ArrayBase<D, Ix2>, T> {
268 DatasetIter::new(self, false)
269 }
270}
271
272impl<L, R: Records, T: AsTargets<Elem = L>> AsTargets for DatasetBase<R, T> {
273 type Elem = L;
274 type Ix = T::Ix;
275
276 fn as_targets(&self) -> ArrayView<'_, Self::Elem, Self::Ix> {
277 self.targets.as_targets()
278 }
279}
280
281impl<L, R: Records, T: AsTargetsMut<Elem = L>> AsTargetsMut for DatasetBase<R, T> {
282 type Elem = L;
283 type Ix = T::Ix;
284
285 fn as_targets_mut(&mut self) -> ArrayViewMut<'_, Self::Elem, Self::Ix> {
286 self.targets.as_targets_mut()
287 }
288}
289
290#[allow(clippy::type_complexity)]
291impl<'a, L: 'a, F, T> DatasetBase<ArrayView2<'a, F>, T>
292where
293 T: AsTargets<Elem = L> + FromTargetArray<'a>,
294 T::View: AsTargets<Elem = L>,
295{
296 pub fn split_with_ratio(
303 &'a self,
304 ratio: f32,
305 ) -> (
306 DatasetBase<ArrayView2<'a, F>, T::View>,
307 DatasetBase<ArrayView2<'a, F>, T::View>,
308 ) {
309 let n = (self.nsamples() as f32 * ratio).ceil() as usize;
310 let (records_first, records_second) = self.records.view().split_at(Axis(0), n);
311 let (targets_first, targets_second) = self.targets.as_targets().split_at(Axis(0), n);
312
313 let targets_first = T::new_targets_view(targets_first);
314 let targets_second = T::new_targets_view(targets_second);
315
316 let (first_weights, second_weights) = if self.weights.len() == self.nsamples() {
317 let a = self.weights.slice(s![..n]).to_vec();
318 let b = self.weights.slice(s![n..]).to_vec();
319
320 (Array1::from(a), Array1::from(b))
321 } else {
322 (Array1::zeros(0), Array1::zeros(0))
323 };
324 let dataset1 = DatasetBase::new(records_first, targets_first)
325 .with_weights(first_weights)
326 .with_feature_names(self.feature_names.clone())
327 .with_target_names(self.target_names.clone());
328
329 let dataset2 = DatasetBase::new(records_second, targets_second)
330 .with_weights(second_weights)
331 .with_feature_names(self.feature_names.clone())
332 .with_target_names(self.target_names.clone());
333
334 (dataset1, dataset2)
335 }
336}
337
338impl<L: Label, T: Labels<Elem = L>, R: Records> Labels for DatasetBase<R, T> {
339 type Elem = L;
340
341 fn label_count(&self) -> Vec<HashMap<L, usize>> {
342 self.targets().label_count()
343 }
344}
345
346#[allow(clippy::type_complexity)]
347impl<F, L: Label, T, D> DatasetBase<ArrayBase<D, Ix2>, T>
348where
349 D: Data<Elem = F>,
350 T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
351{
352 pub fn one_vs_all(
357 &self,
358 ) -> Result<
359 Vec<(
360 L,
361 DatasetBase<ArrayView2<'_, F>, CountedTargets<bool, Array1<bool>>>,
362 )>,
363 > {
364 let targets = self.targets().as_single_targets();
365
366 Ok(self
367 .labels()
368 .into_iter()
369 .map(|label| {
370 let targets = targets.iter().map(|x| x == &label).collect::<Array1<_>>();
371
372 let targets = CountedTargets::new(targets);
373
374 (
375 label,
376 DatasetBase::new(self.records().view(), targets)
377 .with_feature_names(self.feature_names.clone())
378 .with_weights(self.weights.clone())
379 .with_target_names(self.target_names.clone()),
380 )
381 })
382 .collect())
383 }
384}
385
386impl<L: Label, R: Records, S: AsTargets<Elem = L>> DatasetBase<R, S> {
387 pub fn label_frequencies_with_mask(&self, mask: &[bool]) -> HashMap<L, f32> {
397 let mut freqs = HashMap::new();
398
399 for (elms, val) in self
400 .targets
401 .as_targets()
402 .axis_iter(Axis(0))
403 .enumerate()
404 .filter(|(i, _)| *mask.get(*i).unwrap_or(&true))
405 .map(|(i, x)| (x, self.weight_for(i)))
406 {
407 for elm in elms {
408 if !freqs.contains_key(elm) {
409 freqs.insert(elm.clone(), 0.0);
410 }
411
412 *freqs.get_mut(elm).unwrap() += val;
413 }
414 }
415
416 freqs
417 }
418
419 pub fn label_frequencies(&self) -> HashMap<L, f32> {
421 self.label_frequencies_with_mask(&[])
422 }
423}
424
425impl<F, D: Data<Elem = F>, I: Dimension> From<ArrayBase<D, I>>
426 for DatasetBase<ArrayBase<D, I>, Array1<()>>
427{
428 fn from(records: ArrayBase<D, I>) -> Self {
429 let empty_targets = Array1::default(records.len_of(Axis(0)));
430 DatasetBase {
431 records,
432 targets: empty_targets,
433 weights: Array1::zeros(0),
434 feature_names: Vec::new(),
435 target_names: Vec::new(),
436 }
437 }
438}
439
440impl<F, E, D, S, I: TargetDim> From<(ArrayBase<D, Ix2>, ArrayBase<S, I>)>
441 for DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, I>>
442where
443 D: Data<Elem = F>,
444 S: Data<Elem = E>,
445{
446 fn from(rec_tar: (ArrayBase<D, Ix2>, ArrayBase<S, I>)) -> Self {
447 DatasetBase {
448 records: rec_tar.0,
449 targets: rec_tar.1,
450 weights: Array1::zeros(0),
451 feature_names: Vec::new(),
452 target_names: Vec::new(),
453 }
454 }
455}
456
457impl<'b, F: Clone, E: Copy + 'b, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
458where
459 D: Data<Elem = F>,
460 T: FromTargetArrayOwned<Elem = E>,
461 T::Owned: AsTargets,
462{
463 pub fn bootstrap<R: Rng>(
480 &'b self,
481 sample_feature_size: (usize, usize),
482 rng: &'b mut R,
483 ) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
484 std::iter::repeat(()).map(move |_| {
485 let indices = (0..sample_feature_size.0)
487 .map(|_| rng.gen_range(0..self.nsamples()))
488 .collect::<Vec<_>>();
489
490 let records = self.records().select(Axis(0), &indices);
491 let targets = T::new_targets(self.as_targets().select(Axis(0), &indices));
492
493 let indices = (0..sample_feature_size.1)
494 .map(|_| rng.gen_range(0..self.nfeatures()))
495 .collect::<Vec<_>>();
496
497 let records = records.select(Axis(1), &indices);
498
499 DatasetBase::new(records, targets)
500 })
501 }
502
503 pub fn bootstrap_samples<R: Rng>(
520 &'b self,
521 num_samples: usize,
522 rng: &'b mut R,
523 ) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
524 std::iter::repeat(()).map(move |_| {
525 let indices = (0..num_samples)
527 .map(|_| rng.gen_range(0..self.nsamples()))
528 .collect::<Vec<_>>();
529
530 let records = self.records().select(Axis(0), &indices);
531 let targets = T::new_targets(self.as_targets().select(Axis(0), &indices));
532
533 DatasetBase::new(records, targets)
534 })
535 }
536
537 pub fn bootstrap_features<R: Rng>(
554 &'b self,
555 num_features: usize,
556 rng: &'b mut R,
557 ) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
558 std::iter::repeat(()).map(move |_| {
559 let targets = T::new_targets(self.as_targets().to_owned());
560
561 let indices = (0..num_features)
562 .map(|_| rng.gen_range(0..self.nfeatures()))
563 .collect::<Vec<_>>();
564
565 let records = self.records.select(Axis(1), &indices);
566
567 DatasetBase::new(records, targets)
568 })
569 }
570
571 pub fn shuffle<R: Rng>(&self, rng: &mut R) -> DatasetBase<Array2<F>, T::Owned> {
582 let mut indices = (0..self.nsamples()).collect::<Vec<_>>();
583 indices.shuffle(rng);
584
585 let records = self.records().select(Axis(0), &indices);
586 let targets = self.as_targets().select(Axis(0), &indices);
587 let targets = T::new_targets(targets);
588
589 DatasetBase::new(records, targets)
590 .with_feature_names(self.feature_names().to_vec())
591 .with_target_names(self.target_names().to_vec())
592 }
593
594 #[allow(clippy::type_complexity)]
595 pub fn fold(
630 &self,
631 k: usize,
632 ) -> Vec<(
633 DatasetBase<Array2<F>, T::Owned>,
634 DatasetBase<Array2<F>, T::Owned>,
635 )> {
636 let targets = self.as_targets();
637 let fold_size = targets.len() / k;
638
639 let mut records_chunks: Vec<_> =
641 self.records.axis_chunks_iter(Axis(0), fold_size).collect();
642 let mut targets_chunks: Vec<_> = targets.axis_chunks_iter(Axis(0), fold_size).collect();
643
644 let mut res = Vec::with_capacity(k);
645 for i in 0..k {
649 let remaining_records = concatenate(Axis(0), &records_chunks.as_slice()[1..]).unwrap();
650 let remaining_targets = concatenate(Axis(0), &targets_chunks.as_slice()[1..]).unwrap();
651
652 res.push((
653 DatasetBase::new(remaining_records, T::new_targets(remaining_targets)),
655 DatasetBase::new(
657 records_chunks[0].into_owned(),
658 T::new_targets(targets_chunks[0].clone().into_owned()),
659 ),
660 ));
661
662 if i < k - 1 {
664 records_chunks.swap(0, i + 1);
665 targets_chunks.swap(0, i + 1);
666 }
667 }
668 res
669 }
670
671 pub fn sample_chunks<'a: 'b>(&'b self, chunk_size: usize) -> ChunksIter<'b, 'a, F, T> {
672 ChunksIter::new(self.records().view(), &self.targets, chunk_size, Axis(0))
673 }
674
675 pub fn to_owned(&self) -> DatasetBase<Array2<F>, T::Owned> {
676 DatasetBase::new(
677 self.records().to_owned(),
678 T::new_targets(self.as_targets().to_owned()),
679 )
680 }
681}
682
/// Swaps fold `$index` with fold 0 inside a flat row-major buffer.
///
/// Each fold spans `$fold_size * $features` consecutive elements. Swapping is its own
/// inverse, so invoking the macro twice with the same arguments restores the buffer.
/// Fold 0 is a no-op.
macro_rules! assist_swap_array2 {
    ($slice: expr, $index: expr, $fold_size: expr, $features: expr) => {
        if $index != 0 {
            let width = $fold_size * $features;
            let (head, tail) = $slice.split_at_mut(width * $index);
            head[..width].swap_with_slice(&mut tail[..width]);
        }
    };
}
694
impl<'a, F: 'a + Clone, E: Copy + 'a, D, S, I: TargetDim>
    DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, I>>
where
    D: DataMut<Elem = F>,
    S: DataMut<Elem = E>,
{
    /// Performs k-fold iteration in place, calling `fit_closure` once per fold.
    ///
    /// Samples are grouped into `k` folds of `nsamples / k` consecutive samples. For
    /// fold `i`, `assist_swap_array2!` swaps fold `i` with fold 0 directly in the record
    /// and target buffers. `fit_closure` then receives a view of every sample after the
    /// first fold (the training set), and the swap is undone. Samples left over when
    /// `nsamples` is not divisible by `k` therefore always belong to the training set.
    ///
    /// Returns the fitted objects zipped with `self.sample_chunks(fold_size)`. Those
    /// chunks are presumably the consecutive validation folds; `zip` stops after `k`
    /// items.
    ///
    /// # Panics
    ///
    /// Panics if `k == 0`, if `k > nsamples`, or if records or targets are not contiguous
    /// (`as_slice_mut().unwrap()`).
    pub fn iter_fold<O, C: Fn(&DatasetView<F, E, I>) -> O>(
        &'a mut self,
        k: usize,
        fit_closure: C,
    ) -> impl Iterator<Item = (O, DatasetBase<ArrayView2<'a, F>, ArrayView<'a, E, I>>)> {
        assert!(k > 0);
        assert!(k <= self.nsamples());
        let samples_count = self.nsamples();
        let fold_size = samples_count / k;

        let features = self.nfeatures();
        // Number of target columns; one fold of targets spans `fold_size * targets` elements.
        let targets = self.ntargets();
        // Shape template used to build the training-target view with fewer samples.
        let tshape = self.targets.raw_dim();

        let mut objs: Vec<O> = Vec::with_capacity(k);

        {
            let records_sl = self.records.as_slice_mut().unwrap();
            let mut targets_sl2 = self.targets.as_targets_mut();
            let targets_sl = targets_sl2.as_slice_mut().unwrap();

            for i in 0..k {
                // Move fold `i` to the front so the training data is one contiguous tail.
                assist_swap_array2!(records_sl, i, fold_size, features);
                assist_swap_array2!(targets_sl, i, fold_size, targets);

                {
                    // Training set: everything after the first `fold_size` samples.
                    let train = DatasetBase::new(
                        ArrayView2::from_shape(
                            (samples_count - fold_size, features),
                            records_sl.split_at(fold_size * features).1,
                        )
                        .unwrap(),
                        ArrayView::from_shape(
                            tshape.clone().nsamples(samples_count - fold_size),
                            targets_sl.split_at(fold_size * targets).1,
                        )
                        .unwrap(),
                    );

                    let obj = fit_closure(&train);
                    objs.push(obj);
                }

                // Undo the swap so the next iteration starts from the original order.
                assist_swap_array2!(records_sl, i, fold_size, features);
                assist_swap_array2!(targets_sl, i, fold_size, targets);
            }
        }

        objs.into_iter().zip(self.sample_chunks(fold_size))
    }

    /// Runs k-fold cross validation for every parameter set in `parameters`.
    ///
    /// For each fold produced by `iter_fold`:
    /// - every parameter set is fitted on the training set;
    /// - each model predicts the validation records;
    /// - `eval(predicted, validation_targets)` scores the prediction and is stored in row
    ///   `i` for parameter set `i`.
    ///
    /// Scores are summed over folds and divided by `k`. The result has the target
    /// array's shape with the sample axis replaced by `parameters.len()`.
    ///
    /// # Errors
    ///
    /// Returns an error if any fit fails or if `eval` fails (converted with `ER::from`).
    ///
    /// # Panics
    ///
    /// Same conditions as `iter_fold`.
    pub fn cross_validate<O, ER, M, FACC, C>(
        &'a mut self,
        k: usize,
        parameters: &[M],
        eval: C,
    ) -> std::result::Result<Array<FACC, I>, ER>
    where
        ER: std::error::Error + std::convert::From<crate::error::Error>,
        M: for<'c> Fit<ArrayView2<'c, F>, ArrayView<'c, E, I>, ER, Object = O>,
        O: for<'d> PredictInplace<ArrayView2<'a, F>, Array<E, I>>,
        FACC: Float,
        C: Fn(
            &Array<E, I>,
            &ArrayView<E, I>,
        ) -> std::result::Result<Array<FACC, I::Smaller>, crate::error::Error>,
    {
        // One accumulator row per parameter set.
        let mut evaluations = Array::from_elem(
            self.targets.raw_dim().nsamples(parameters.len()),
            FACC::zero(),
        );
        let folds_evaluations: std::result::Result<Vec<_>, ER> = self
            .iter_fold(k, |train| {
                let fit_result: std::result::Result<Vec<_>, ER> =
                    parameters.iter().map(|p| p.fit(train)).collect();
                fit_result
            })
            .map(|(models, valid)| {
                let targets = valid.targets();
                let models = models?;
                let mut eval_predictions =
                    Array::from_elem(targets.raw_dim().nsamples(models.len()), FACC::zero());
                for (i, model) in models.iter().enumerate() {
                    let predicted = model.predict(valid.records());
                    let eval_pred = match eval(&predicted, targets) {
                        Err(e) => Err(ER::from(e)),
                        Ok(res) => Ok(res),
                    }?;
                    eval_predictions
                        .index_axis_mut(Axis(0), i)
                        .add_assign(&eval_pred);
                }
                Ok(eval_predictions)
            })
            .collect();

        for fold_evaluation in folds_evaluations? {
            evaluations.add_assign(&fold_evaluation)
        }
        // Average over the folds.
        Ok(evaluations / FACC::from(k).unwrap())
    }
}
930
931impl<'a, F: 'a + Clone, E: Copy + 'a, D, S> DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, Ix1>>
932where
933 D: DataMut<Elem = F>,
934 S: DataMut<Elem = E>,
935{
936 pub fn cross_validate_single<O, ER, M, FACC, C>(
940 &'a mut self,
941 k: usize,
942 parameters: &[M],
943 eval: C,
944 ) -> std::result::Result<Array1<FACC>, ER>
945 where
946 ER: std::error::Error + std::convert::From<crate::error::Error>,
947 M: for<'c> Fit<ArrayView2<'c, F>, ArrayView1<'c, E>, ER, Object = O>,
948 O: for<'d> PredictInplace<ArrayView2<'a, F>, Array1<E>>,
949 FACC: Float,
950 C: Fn(&Array1<E>, &ArrayView1<E>) -> std::result::Result<FACC, crate::error::Error>,
951 {
952 self.cross_validate(k, parameters, |a, b| eval(a, b).map(arr0))
953 }
954}
955
956impl<F, E, I: TargetDim> Dataset<F, E, I> {
957 pub fn split_with_ratio(mut self, ratio: f32) -> (Self, Self) {
978 assert!(
979 self.records.is_standard_layout(),
980 "records not in row-major layout"
981 );
982 assert!(
983 self.targets.is_standard_layout(),
984 "targets not in row-major layout"
985 );
986
987 let nfeatures = self.nfeatures();
988
989 let n1 = (self.nsamples() as f32 * ratio).ceil() as usize;
990 let n2 = self.nsamples() - n1;
991
992 let feature_names = self.feature_names().to_vec();
993 let target_names = self.target_names().to_vec();
994
995 let (mut array_buf, _) = self.records.into_raw_vec_and_offset();
997 let second_array_buf = array_buf.split_off(n1 * nfeatures);
998
999 let first = Array2::from_shape_vec((n1, nfeatures), array_buf).unwrap();
1000 let second = Array2::from_shape_vec((n2, nfeatures), second_array_buf).unwrap();
1001
1002 let dim1 = self.targets.raw_dim().nsamples(n1);
1004 let dim2 = self.targets.raw_dim().nsamples(n2);
1005 let (mut array_buf, _) = self.targets.into_raw_vec_and_offset();
1006 let second_array_buf = array_buf.split_off(dim1.size());
1007
1008 let first_targets = Array::from_shape_vec(dim1, array_buf).unwrap();
1009 let second_targets = Array::from_shape_vec(dim2, second_array_buf).unwrap();
1010
1011 let second_weights = if self.weights.len() == n1 + n2 {
1013 let (mut weights, _) = self.weights.into_raw_vec_and_offset();
1014
1015 let weights2 = weights.split_off(n1);
1016 self.weights = Array1::from(weights);
1017
1018 Array1::from(weights2)
1019 } else {
1020 Array1::zeros(0)
1021 };
1022
1023 let dataset1 = Dataset::new(first, first_targets)
1025 .with_weights(self.weights)
1026 .with_feature_names(feature_names.clone())
1027 .with_target_names(target_names.clone());
1028 let dataset2 = Dataset::new(second, second_targets)
1029 .with_weights(second_weights)
1030 .with_feature_names(feature_names.clone())
1031 .with_target_names(target_names.clone());
1032
1033 (dataset1, dataset2)
1034 }
1035}
1036
1037impl<F, D, E, T, O> Predict<ArrayBase<D, Ix2>, DatasetBase<ArrayBase<D, Ix2>, T>> for O
1038where
1039 D: Data<Elem = F>,
1040 T: AsTargets<Elem = E>,
1041 O: PredictInplace<ArrayBase<D, Ix2>, T>,
1042{
1043 fn predict(&self, records: ArrayBase<D, Ix2>) -> DatasetBase<ArrayBase<D, Ix2>, T> {
1044 let mut targets = self.default_target(&records);
1045 self.predict_inplace(&records, &mut targets);
1046 DatasetBase::new(records, targets)
1047 }
1048}
1049
1050impl<F, R, T, E, S, O> Predict<DatasetBase<R, T>, DatasetBase<R, S>> for O
1051where
1052 R: Records<Elem = F>,
1053 S: AsTargets<Elem = E>,
1054 O: PredictInplace<R, S>,
1055{
1056 fn predict(&self, ds: DatasetBase<R, T>) -> DatasetBase<R, S> {
1057 let mut targets = self.default_target(&ds.records);
1058 self.predict_inplace(&ds.records, &mut targets);
1059 DatasetBase::new(ds.records, targets)
1060 }
1061}
1062
1063impl<'a, F, R, T, S, O> Predict<&'a DatasetBase<R, T>, S> for O
1064where
1065 R: Records<Elem = F>,
1066 O: PredictInplace<R, S>,
1067{
1068 fn predict(&self, ds: &'a DatasetBase<R, T>) -> S {
1069 let mut targets = self.default_target(&ds.records);
1070 self.predict_inplace(&ds.records, &mut targets);
1071 targets
1072 }
1073}
1074
1075impl<'a, F, D, DM, T, O> Predict<&'a ArrayBase<D, DM>, T> for O
1076where
1077 D: Data<Elem = F>,
1078 DM: Dimension,
1079 O: PredictInplace<ArrayBase<D, DM>, T>,
1080{
1081 fn predict(&self, records: &'a ArrayBase<D, DM>) -> T {
1082 let mut targets = self.default_target(records);
1083 self.predict_inplace(records, &mut targets);
1084 targets
1085 }
1086}
1087
1088impl<L: Label, S: Labels<Elem = L>> CountedTargets<L, S> {
1089 pub fn new(targets: S) -> Self {
1090 let labels = targets.label_count();
1091
1092 CountedTargets { targets, labels }
1093 }
1094}