linfa/dataset/
impl_dataset.rs

1use super::{
2    super::traits::{Predict, PredictInplace},
3    iter::{ChunksIter, DatasetIter, Iter},
4    AsSingleTargets, AsTargets, AsTargetsMut, CountedTargets, Dataset, DatasetBase, DatasetView,
5    Float, FromTargetArray, FromTargetArrayOwned, Label, Labels, Records, Result, TargetDim,
6};
7use crate::traits::Fit;
8use ndarray::{concatenate, prelude::*, Data, DataMut, Dimension};
9use rand::{seq::SliceRandom, Rng};
10use std::collections::HashMap;
11use std::ops::AddAssign;
12
13/// Implementation without constraints on records and targets
14///
15/// This implementation block provides methods for the creation and mutation of datasets. This
16/// includes swapping the targets, return the records etc.
17impl<R: Records, S> DatasetBase<R, S> {
18    /// Create a new dataset from records and targets
19    ///
20    /// # Example
21    ///
22    /// ```ignore
23    /// let dataset = Dataset::new(records, targets);
24    /// ```
25    pub fn new(records: R, targets: S) -> DatasetBase<R, S> {
26        let targets = targets;
27
28        DatasetBase {
29            records,
30            targets,
31            weights: Array1::zeros(0),
32            feature_names: Vec::new(),
33            target_names: Vec::new(),
34        }
35    }
36
37    /// Returns reference to targets
38    pub fn targets(&self) -> &S {
39        &self.targets
40    }
41
42    /// Returns optionally weights
43    pub fn weights(&self) -> Option<&[f32]> {
44        if !self.weights.is_empty() {
45            Some(self.weights.as_slice().unwrap())
46        } else {
47            None
48        }
49    }
50
51    /// Return a single weight
52    ///
53    /// The weight of the `idx`th observation is returned. If no weight is specified, then all
54    /// observations are unweighted with default value `1.0`.
55    pub fn weight_for(&self, idx: usize) -> f32 {
56        self.weights.get(idx).copied().unwrap_or(1.0)
57    }
58
59    /// Returns feature names
60    ///
61    /// A feature name gives a human-readable string describing the purpose of a single feature.
62    /// This allow the reader to understand its purpose while analysing results, for example
63    /// correlation analysis or feature importance.
64    pub fn feature_names(&self) -> &[String] {
65        &self.feature_names
66    }
67
68    /// Return records of a dataset
69    ///
70    /// The records are data points from which predictions are made. This functions returns a
71    /// reference to the record field.
72    pub fn records(&self) -> &R {
73        &self.records
74    }
75
76    /// Updates the records of a dataset
77    ///
78    /// This function overwrites the records in a dataset. It also invalidates the weights and
79    /// feature/target names.
80    pub fn with_records<T: Records>(self, records: T) -> DatasetBase<T, S> {
81        DatasetBase {
82            records,
83            targets: self.targets,
84            weights: Array1::zeros(0),
85            feature_names: Vec::new(),
86            target_names: Vec::new(),
87        }
88    }
89
90    /// Updates the targets of a dataset
91    ///
92    /// This function overwrites the targets in a dataset.
93    pub fn with_targets<T>(self, targets: T) -> DatasetBase<R, T> {
94        DatasetBase {
95            records: self.records,
96            targets,
97            weights: self.weights,
98            feature_names: self.feature_names,
99            target_names: self.target_names,
100        }
101    }
102
103    /// Updates the weights of a dataset
104    pub fn with_weights(mut self, weights: Array1<f32>) -> DatasetBase<R, S> {
105        self.weights = weights;
106
107        self
108    }
109
110    /// Updates the feature names of a dataset
111    ///
112    /// **Panics** when given names not empty and length does not equal to the number of features
113    pub fn with_feature_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, S> {
114        assert!(
115            names.is_empty() || names.len() == self.nfeatures(),
116            "Wrong number of feature names"
117        );
118        self.feature_names = names.into_iter().map(|x| x.into()).collect();
119        self
120    }
121}
122
123impl<X, Y> Dataset<X, Y> {
124    // Convert 2D targets to 1D. Only works for targets with shape of form [X, 1], panics otherwise.
125    pub fn into_single_target(self) -> Dataset<X, Y, Ix1> {
126        let nsamples = self.records.nsamples();
127        let targets = self.targets.into_shape_with_order(nsamples).unwrap();
128        let features = self.records;
129        Dataset::new(features, targets)
130    }
131}
132
133impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
134    /// Updates the target names of a dataset
135    ///
136    /// **Panics**  when given names not empty and length does not equal to the number of targets
137    pub fn with_target_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, T> {
138        assert!(
139            names.is_empty() || names.len() == self.ntargets(),
140            "Wrong number of target names"
141        );
142        self.target_names = names.into_iter().map(|x| x.into()).collect();
143        self
144    }
145
146    /// Map targets with a function `f`
147    ///
148    /// # Example
149    ///
150    /// ```
151    /// let dataset = linfa_datasets::winequality()
152    ///     .map_targets(|x| *x > 6);
153    ///
154    /// // dataset has now boolean targets
155    /// println!("{:?}", dataset.targets());
156    /// ```
157    ///
158    /// # Returns
159    ///
160    /// A modified dataset with new target type.
161    ///
162    pub fn map_targets<S, G: FnMut(&L) -> S>(self, fnc: G) -> DatasetBase<R, Array<S, T::Ix>> {
163        let DatasetBase {
164            records,
165            targets,
166            weights,
167            feature_names,
168            target_names,
169            ..
170        } = self;
171
172        let targets = targets.as_targets();
173
174        DatasetBase {
175            records,
176            targets: targets.map(fnc),
177            weights,
178            feature_names,
179            target_names,
180        }
181    }
182
183    /// Returns target names
184    ///
185    /// A target name gives a human-readable string describing the purpose of a single target.
186    pub fn target_names(&self) -> &[String] {
187        &self.target_names
188    }
189
190    /// Return the number of targets in the dataset
191    ///
192    /// # Example
193    ///
194    /// ```
195    /// let dataset = linfa_datasets::winequality();
196    ///
197    /// println!("#targets {}", dataset.ntargets());
198    /// ```
199    ///
200    pub fn ntargets(&self) -> usize {
201        if T::Ix::NDIM.unwrap() == 1 {
202            1
203        } else {
204            self.targets.as_targets().len_of(Axis(1))
205        }
206    }
207}
208
209impl<'a, F, L, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
210where
211    D: Data<Elem = F>,
212    T: AsTargets<Elem = L>,
213{
214    /// Iterate over observations
215    ///
216    /// This function creates an iterator which produces tuples of data points and target value. The
217    /// iterator runs once for each data point and, while doing so, holds an reference to the owned
218    /// dataset.
219    ///
220    /// For multi-target datasets, the yielded target value is `ArrayView1` consisting of the
221    /// different targets. For single-target datasets, the target value is `ArrayView0` containing
222    /// the single target.
223    ///
224    /// # Example
225    /// ```
226    /// let dataset = linfa_datasets::iris();
227    ///
228    /// for (x, y) in dataset.sample_iter() {
229    ///     println!("{} => {}", x, y);
230    /// }
231    /// ```
232    pub fn sample_iter(&'a self) -> Iter<'a, 'a, F, T::Elem, T::Ix> {
233        Iter::new(self.records.view(), self.targets.as_targets())
234    }
235}
236
237impl<'a, F: 'a, L: 'a, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
238where
239    D: Data<Elem = F>,
240    T: AsTargets<Elem = L> + FromTargetArray<'a>,
241    T::View: AsTargets<Elem = L>,
242{
243    /// Creates a view of a dataset
244    pub fn view(&'a self) -> DatasetBase<ArrayView2<'a, F>, T::View> {
245        let records = self.records().view();
246        let targets = T::new_targets_view(self.as_targets());
247
248        DatasetBase::new(records, targets)
249            .with_feature_names(self.feature_names.clone())
250            .with_weights(self.weights.clone())
251            .with_target_names(self.target_names.clone())
252    }
253
254    /// Iterate over features
255    ///
256    /// This iterator produces dataset views with only a single feature, while the set of targets remain
257    /// complete. It can be useful to compare each feature individual to all targets.
258    pub fn feature_iter(&'a self) -> DatasetIter<'a, 'a, ArrayBase<D, Ix2>, T> {
259        DatasetIter::new(self, true)
260    }
261
262    /// Iterate over targets
263    ///
264    /// This functions creates an iterator which produces dataset views complete records, but only
265    /// a single target each. Useful to train multiple single target models for a multi-target
266    /// dataset.
267    pub fn target_iter(&'a self) -> DatasetIter<'a, 'a, ArrayBase<D, Ix2>, T> {
268        DatasetIter::new(self, false)
269    }
270}
271
272impl<L, R: Records, T: AsTargets<Elem = L>> AsTargets for DatasetBase<R, T> {
273    type Elem = L;
274    type Ix = T::Ix;
275
276    fn as_targets(&self) -> ArrayView<'_, Self::Elem, Self::Ix> {
277        self.targets.as_targets()
278    }
279}
280
281impl<L, R: Records, T: AsTargetsMut<Elem = L>> AsTargetsMut for DatasetBase<R, T> {
282    type Elem = L;
283    type Ix = T::Ix;
284
285    fn as_targets_mut(&mut self) -> ArrayViewMut<'_, Self::Elem, Self::Ix> {
286        self.targets.as_targets_mut()
287    }
288}
289
290#[allow(clippy::type_complexity)]
291impl<'a, L: 'a, F, T> DatasetBase<ArrayView2<'a, F>, T>
292where
293    T: AsTargets<Elem = L> + FromTargetArray<'a>,
294    T::View: AsTargets<Elem = L>,
295{
296    /// Split dataset into two disjoint chunks
297    ///
298    /// This function splits the observations in a dataset into two disjoint chunks. The splitting
299    /// threshold is calculated with the `ratio`. For example a ratio of `0.9` allocates 90% to the
300    /// first chunks and 9% to the second. This is often used in training, validation splitting
301    /// procedures.
302    pub fn split_with_ratio(
303        &'a self,
304        ratio: f32,
305    ) -> (
306        DatasetBase<ArrayView2<'a, F>, T::View>,
307        DatasetBase<ArrayView2<'a, F>, T::View>,
308    ) {
309        let n = (self.nsamples() as f32 * ratio).ceil() as usize;
310        let (records_first, records_second) = self.records.view().split_at(Axis(0), n);
311        let (targets_first, targets_second) = self.targets.as_targets().split_at(Axis(0), n);
312
313        let targets_first = T::new_targets_view(targets_first);
314        let targets_second = T::new_targets_view(targets_second);
315
316        let (first_weights, second_weights) = if self.weights.len() == self.nsamples() {
317            let a = self.weights.slice(s![..n]).to_vec();
318            let b = self.weights.slice(s![n..]).to_vec();
319
320            (Array1::from(a), Array1::from(b))
321        } else {
322            (Array1::zeros(0), Array1::zeros(0))
323        };
324        let dataset1 = DatasetBase::new(records_first, targets_first)
325            .with_weights(first_weights)
326            .with_feature_names(self.feature_names.clone())
327            .with_target_names(self.target_names.clone());
328
329        let dataset2 = DatasetBase::new(records_second, targets_second)
330            .with_weights(second_weights)
331            .with_feature_names(self.feature_names.clone())
332            .with_target_names(self.target_names.clone());
333
334        (dataset1, dataset2)
335    }
336}
337
338impl<L: Label, T: Labels<Elem = L>, R: Records> Labels for DatasetBase<R, T> {
339    type Elem = L;
340
341    fn label_count(&self) -> Vec<HashMap<L, usize>> {
342        self.targets().label_count()
343    }
344}
345
346#[allow(clippy::type_complexity)]
347impl<F, L: Label, T, D> DatasetBase<ArrayBase<D, Ix2>, T>
348where
349    D: Data<Elem = F>,
350    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
351{
352    /// Produce N boolean targets from multi-class targets
353    ///
354    /// Some algorithms (like SVM) don't support multi-class targets. This function splits a
355    /// dataset into multiple binary single-target views of the same dataset.
356    pub fn one_vs_all(
357        &self,
358    ) -> Result<
359        Vec<(
360            L,
361            DatasetBase<ArrayView2<'_, F>, CountedTargets<bool, Array1<bool>>>,
362        )>,
363    > {
364        let targets = self.targets().as_single_targets();
365
366        Ok(self
367            .labels()
368            .into_iter()
369            .map(|label| {
370                let targets = targets.iter().map(|x| x == &label).collect::<Array1<_>>();
371
372                let targets = CountedTargets::new(targets);
373
374                (
375                    label,
376                    DatasetBase::new(self.records().view(), targets)
377                        .with_feature_names(self.feature_names.clone())
378                        .with_weights(self.weights.clone())
379                        .with_target_names(self.target_names.clone()),
380                )
381            })
382            .collect())
383    }
384}
385
386impl<L: Label, R: Records, S: AsTargets<Elem = L>> DatasetBase<R, S> {
387    /// Calculates label frequencies from a dataset while masking certain samples.
388    ///
389    /// ### Parameters
390    ///
391    /// * `mask`: a boolean array that specifies which samples to include in the count
392    ///
393    /// ### Returns
394    ///
395    /// A mapping of the Dataset's samples to their frequencies
396    pub fn label_frequencies_with_mask(&self, mask: &[bool]) -> HashMap<L, f32> {
397        let mut freqs = HashMap::new();
398
399        for (elms, val) in self
400            .targets
401            .as_targets()
402            .axis_iter(Axis(0))
403            .enumerate()
404            .filter(|(i, _)| *mask.get(*i).unwrap_or(&true))
405            .map(|(i, x)| (x, self.weight_for(i)))
406        {
407            for elm in elms {
408                if !freqs.contains_key(elm) {
409                    freqs.insert(elm.clone(), 0.0);
410                }
411
412                *freqs.get_mut(elm).unwrap() += val;
413            }
414        }
415
416        freqs
417    }
418
419    /// Calculates label frequencies from a dataset
420    pub fn label_frequencies(&self) -> HashMap<L, f32> {
421        self.label_frequencies_with_mask(&[])
422    }
423}
424
425impl<F, D: Data<Elem = F>, I: Dimension> From<ArrayBase<D, I>>
426    for DatasetBase<ArrayBase<D, I>, Array1<()>>
427{
428    fn from(records: ArrayBase<D, I>) -> Self {
429        let empty_targets = Array1::default(records.len_of(Axis(0)));
430        DatasetBase {
431            records,
432            targets: empty_targets,
433            weights: Array1::zeros(0),
434            feature_names: Vec::new(),
435            target_names: Vec::new(),
436        }
437    }
438}
439
440impl<F, E, D, S, I: TargetDim> From<(ArrayBase<D, Ix2>, ArrayBase<S, I>)>
441    for DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, I>>
442where
443    D: Data<Elem = F>,
444    S: Data<Elem = E>,
445{
446    fn from(rec_tar: (ArrayBase<D, Ix2>, ArrayBase<S, I>)) -> Self {
447        DatasetBase {
448            records: rec_tar.0,
449            targets: rec_tar.1,
450            weights: Array1::zeros(0),
451            feature_names: Vec::new(),
452            target_names: Vec::new(),
453        }
454    }
455}
456
457impl<'b, F: Clone, E: Copy + 'b, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
458where
459    D: Data<Elem = F>,
460    T: FromTargetArrayOwned<Elem = E>,
461    T::Owned: AsTargets,
462{
463    /// Apply bootstrapping for samples and features
464    ///
465    /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
466    /// stability of machine learning algorithms. It samples data uniformly with replacement and
467    /// generates datasets where elements may be shared. This selects a subset of observations as
468    /// well as features.
469    ///
470    /// # Parameters
471    ///
472    ///  * `sample_feature_size`: The number of samples and features per bootstrap
473    ///  * `rng`: The random number generator used in the sampling procedure
474    ///
475    ///  # Returns
476    ///
477    ///  An infinite Iterator yielding at each step a new bootstrapped dataset
478    ///
479    pub fn bootstrap<R: Rng>(
480        &'b self,
481        sample_feature_size: (usize, usize),
482        rng: &'b mut R,
483    ) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
484        std::iter::repeat(()).map(move |_| {
485            // sample with replacement
486            let indices = (0..sample_feature_size.0)
487                .map(|_| rng.gen_range(0..self.nsamples()))
488                .collect::<Vec<_>>();
489
490            let records = self.records().select(Axis(0), &indices);
491            let targets = T::new_targets(self.as_targets().select(Axis(0), &indices));
492
493            let indices = (0..sample_feature_size.1)
494                .map(|_| rng.gen_range(0..self.nfeatures()))
495                .collect::<Vec<_>>();
496
497            let records = records.select(Axis(1), &indices);
498
499            DatasetBase::new(records, targets)
500        })
501    }
502
503    /// Apply bootstrapping for samples and features
504    ///
505    /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
506    /// stability of machine learning algorithms. It samples data uniformly with replacement and
507    /// generates datasets where elements may be shared. This selects a subset of observations as
508    /// well as features.
509    ///
510    /// # Parameters
511    ///
512    ///  * `sample_feature_size`: The number of samples and features per bootstrap
513    ///  * `rng`: The random number generator used in the sampling procedure
514    ///
515    ///  # Returns
516    ///
517    ///  An infinite Iterator yielding at each step a tuple containing a bootstrapped dataset with
518    ///  a vector of the sampled data indices and sampled feature.
519    ///
520    #[allow(clippy::type_complexity)]
521    pub fn bootstrap_with_indices<R: Rng>(
522        &'b self,
523        sample_feature_size: (usize, usize),
524        rng: &'b mut R,
525    ) -> impl Iterator<Item = (DatasetBase<Array2<F>, T::Owned>, Vec<usize>, Vec<usize>)> + 'b {
526        std::iter::repeat(()).map(move |_| {
527            // sample with replacement
528            let data_indices = (0..sample_feature_size.0)
529                .map(|_| rng.gen_range(0..self.nsamples()))
530                .collect::<Vec<_>>();
531
532            let records = self.records().select(Axis(0), &data_indices);
533            let targets = T::new_targets(self.as_targets().select(Axis(0), &data_indices));
534
535            let feat_indices = (0..sample_feature_size.1)
536                .map(|_| rng.gen_range(0..self.nfeatures()))
537                .collect::<Vec<_>>();
538
539            let records = records.select(Axis(1), &feat_indices);
540
541            (
542                DatasetBase::new(records, targets),
543                data_indices,
544                feat_indices,
545            )
546        })
547    }
548
549    /// Apply sample bootstrapping
550    ///
551    /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
552    /// stability of machine learning algorithms. It samples data uniformly with replacement and
553    /// generates datasets where elements may be shared. Only a sample subset is selected which
554    /// retains all features and targets.
555    ///
556    /// # Parameters
557    ///
558    ///  * `num_samples`: The number of samples per bootstrap
559    ///  * `rng`: The random number generator used in the sampling procedure
560    ///
561    ///  # Returns
562    ///
563    ///  An infinite Iterator yielding at each step a new bootstrapped dataset
564    ///
565    pub fn bootstrap_samples<R: Rng>(
566        &'b self,
567        num_samples: usize,
568        rng: &'b mut R,
569    ) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
570        std::iter::repeat(()).map(move |_| {
571            // sample with replacement
572            let indices = (0..num_samples)
573                .map(|_| rng.gen_range(0..self.nsamples()))
574                .collect::<Vec<_>>();
575
576            let records = self.records().select(Axis(0), &indices);
577            let targets = T::new_targets(self.as_targets().select(Axis(0), &indices));
578
579            DatasetBase::new(records, targets)
580        })
581    }
582
583    /// Apply sample bootstrapping
584    ///
585    /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
586    /// stability of machine learning algorithms. It samples data uniformly with replacement and
587    /// generates datasets where elements may be shared. Only a sample subset is selected which
588    /// retains all features and targets.
589    ///
590    /// # Parameters
591    ///
592    ///  * `num_samples`: The number of samples per bootstrap
593    ///  * `rng`: The random number generator used in the sampling procedure
594    ///
595    ///  # Returns
596    ///
597    ///  An infinite Iterator yielding at each step a new bootstrapped dataset and the sampled
598    ///  indices.
599    ///
600    pub fn bootstrap_samples_with_indices<R: Rng>(
601        &'b self,
602        num_samples: usize,
603        rng: &'b mut R,
604    ) -> impl Iterator<Item = (DatasetBase<Array2<F>, T::Owned>, Vec<usize>)> + 'b {
605        std::iter::repeat(()).map(move |_| {
606            // sample with replacement
607            let indices = (0..num_samples)
608                .map(|_| rng.gen_range(0..self.nsamples()))
609                .collect::<Vec<_>>();
610
611            let records = self.records().select(Axis(0), &indices);
612            let targets = T::new_targets(self.as_targets().select(Axis(0), &indices));
613
614            (DatasetBase::new(records, targets), indices)
615        })
616    }
617
618    /// Apply feature bootstrapping
619    ///
620    /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
621    /// stability of machine learning algorithms. It samples data uniformly with replacement and
622    /// generates datasets where elements may be shared. Only a feature subset is selected while
623    /// retaining all samples and targets.
624    ///
625    /// # Parameters
626    ///
627    ///  * `num_features`: The number of features per bootstrap
628    ///  * `rng`: The random number generator used in the sampling procedure
629    ///
630    ///  # Returns
631    ///
632    ///  An infinite Iterator yielding at each step a new bootstrapped dataset
633    ///
634    pub fn bootstrap_features<R: Rng>(
635        &'b self,
636        num_features: usize,
637        rng: &'b mut R,
638    ) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
639        std::iter::repeat(()).map(move |_| {
640            let targets = T::new_targets(self.as_targets().to_owned());
641
642            let indices = (0..num_features)
643                .map(|_| rng.gen_range(0..self.nfeatures()))
644                .collect::<Vec<_>>();
645
646            let records = self.records.select(Axis(1), &indices);
647
648            DatasetBase::new(records, targets)
649        })
650    }
651
652    /// Apply feature bootstrapping
653    ///
654    /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
655    /// stability of machine learning algorithms. It samples data uniformly with replacement and
656    /// generates datasets where elements may be shared. Only a feature subset is selected while
657    /// retaining all samples and targets.
658    ///
659    /// # Parameters
660    ///
661    ///  * `num_features`: The number of features per bootstrap
662    ///  * `rng`: The random number generator used in the sampling procedure
663    ///
664    ///  # Returns
665    ///
666    ///  An infinite Iterator yielding at each step a new bootstrapped dataset with the indices of
667    ///  the features sampled
668    ///
669    pub fn bootstrap_features_with_indices<R: Rng>(
670        &'b self,
671        num_features: usize,
672        rng: &'b mut R,
673    ) -> impl Iterator<Item = (DatasetBase<Array2<F>, T::Owned>, Vec<usize>)> + 'b {
674        std::iter::repeat(()).map(move |_| {
675            let targets = T::new_targets(self.as_targets().to_owned());
676
677            let indices = (0..num_features)
678                .map(|_| rng.gen_range(0..self.nfeatures()))
679                .collect::<Vec<_>>();
680
681            let records = self.records.select(Axis(1), &indices);
682
683            (DatasetBase::new(records, targets), indices)
684        })
685    }
686
687    /// Produces a shuffled version of the current Dataset.
688    ///
689    /// ### Parameters
690    ///
691    /// * `rng`: the random number generator that will be used to shuffle the samples
692    ///
693    /// ### Returns
694    ///
695    /// A new shuffled version of the current Dataset
696    ///
697    pub fn shuffle<R: Rng>(&self, rng: &mut R) -> DatasetBase<Array2<F>, T::Owned> {
698        let mut indices = (0..self.nsamples()).collect::<Vec<_>>();
699        indices.shuffle(rng);
700
701        let records = self.records().select(Axis(0), &indices);
702        let targets = self.as_targets().select(Axis(0), &indices);
703        let targets = T::new_targets(targets);
704
705        DatasetBase::new(records, targets)
706            .with_feature_names(self.feature_names().to_vec())
707            .with_target_names(self.target_names().to_vec())
708    }
709
710    #[allow(clippy::type_complexity)]
711    /// Performs K-folding on the dataset.
712    ///
713    /// The dataset is divided into `k` "folds", each containing `(dataset size)/k` samples, used
714    /// to generate `k` training-validation dataset pairs. Each pair contains a validation
715    /// `Dataset` with `k` samples, the ones contained in the i-th fold, and a training `Dataset`
716    /// composed by the union of all the samples in the remaining folds.
717    ///
718    /// ### Parameters
719    ///
720    /// * `k`: the number of folds to apply
721    ///
722    /// ### Returns
723    ///
724    /// A vector of `k` training-validation Dataset pairs.
725    ///
726    /// ### Example
727    ///
728    /// ```rust
729    /// use linfa::dataset::DatasetView;
730    /// use ndarray::{Ix1, array};
731    ///
732    /// let records = array![[1.,1.], [2.,1.], [3.,2.], [4.,1.],[5., 3.], [6.,2.]];
733    /// let targets = array![1, 1, 0, 1, 0, 0];
734    ///
735    /// let dataset : DatasetView<f64, usize, Ix1> = (records.view(), targets.view()).into();
736    /// let accuracies = dataset.fold(3).into_iter().map(|(train, valid)| {
737    ///     // Here you can train your model and perform validation
738    ///     
739    ///     // let model = params.fit(&dataset);
740    ///     // let predi = model.predict(&valid);
741    ///     // predi.confusion_matrix(&valid).accuracy()  
742    /// });
743    /// ```
744    ///  
745    pub fn fold(
746        &self,
747        k: usize,
748    ) -> Vec<(
749        DatasetBase<Array2<F>, T::Owned>,
750        DatasetBase<Array2<F>, T::Owned>,
751    )> {
752        let targets = self.as_targets();
753        let fold_size = targets.len() / k;
754
755        // Generates all k folds of records and targets
756        let mut records_chunks: Vec<_> =
757            self.records.axis_chunks_iter(Axis(0), fold_size).collect();
758        let mut targets_chunks: Vec<_> = targets.axis_chunks_iter(Axis(0), fold_size).collect();
759
760        let mut res = Vec::with_capacity(k);
761        // For each iteration, take the first chunk for both records and targets as the validation set and
762        // concatenate all the other chunks to create the training set. In the end swap the first chunk with the
763        // one in the next index so that it is ready for the next iteration
764        for i in 0..k {
765            let remaining_records = concatenate(Axis(0), &records_chunks.as_slice()[1..]).unwrap();
766            let remaining_targets = concatenate(Axis(0), &targets_chunks.as_slice()[1..]).unwrap();
767
768            res.push((
769                // training
770                DatasetBase::new(remaining_records, T::new_targets(remaining_targets)),
771                // validation
772                DatasetBase::new(
773                    records_chunks[0].into_owned(),
774                    T::new_targets(targets_chunks[0].clone().into_owned()),
775                ),
776            ));
777
778            // swap
779            if i < k - 1 {
780                records_chunks.swap(0, i + 1);
781                targets_chunks.swap(0, i + 1);
782            }
783        }
784        res
785    }
786
787    pub fn sample_chunks<'a: 'b>(&'b self, chunk_size: usize) -> ChunksIter<'b, 'a, F, T> {
788        ChunksIter::new(self.records().view(), &self.targets, chunk_size, Axis(0))
789    }
790
791    pub fn to_owned(&self) -> DatasetBase<Array2<F>, T::Owned> {
792        DatasetBase::new(
793            self.records().to_owned(),
794            T::new_targets(self.as_targets().to_owned()),
795        )
796    }
797}
798
/// Swap the first fold of a flat buffer with the fold at position `$index`.
///
/// A fold spans `$fold_size * $features` consecutive elements of `$slice` (presumably
/// `$fold_size` rows of a standard-layout 2D array with `$features` columns). Nothing happens
/// when `$index` is `0`; the macro panics if the fold at `$index` does not lie entirely within
/// `$slice`.
macro_rules! assist_swap_array2 {
    ($slice: expr, $index: expr, $fold_size: expr, $features: expr) => {
        if $index != 0 {
            let fold_len = $fold_size * $features;
            let fold_start = fold_len * $index;
            let (before, from_fold) = $slice.split_at_mut(fold_start);
            let (indexed_fold, _) = from_fold.split_at_mut(fold_len);
            let first_fold = &mut before[..$fold_size * $features];
            first_fold.swap_with_slice(indexed_fold);
        }
    };
}
810
811impl<'a, F: 'a + Clone, E: Copy + 'a, D, S, I: TargetDim>
812    DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, I>>
813where
814    D: DataMut<Elem = F>,
815    S: DataMut<Elem = E>,
816{
817    /// Performs k-folding cross validation on fittable algorithms.
818    ///
819    /// Given a dataset as input, a value of k and the desired params for the fittable
820    /// algorithm, returns an iterator over the k trained models and the
821    /// associated validation set.
822    ///
823    /// The models are trained according to a closure specified
824    /// as an input.
825    ///
826    /// ## Parameters
827    ///
828    /// - `k`: the number of folds to apply to the dataset
829    /// - `params`: the desired parameters for the fittable algorithm at hand
830    /// - `fit_closure`: a closure of the type `(params, training_data) -> fitted_model`
831    ///   that will be used to produce the trained model for each fold. The training data given in input
832    ///   won't outlive the closure.
833    ///
834    /// ## Returns
835    ///
836    /// An iterator over couples `(trained_model, validation_set)`.
837    ///
838    /// ## Panics
839    ///
840    /// This method will panic for any of the following three reasons:
841    ///
842    /// - The value of `k` provided is not positive;
843    /// - The value of `k` provided is greater than the total number of samples in the dataset;
844    /// - The dataset's data is not stored contiguously and in standard order;
845    ///
846    /// ## Example
847    /// ```rust
848    /// use linfa::traits::Fit;
849    /// use linfa::dataset::{Dataset, DatasetView, Records};
850    /// use ndarray::{array, ArrayView1, ArrayView2, Ix1};
851    /// use linfa::Error;
852    ///
853    /// struct MockFittable {}
854    ///
855    /// struct MockFittableResult {
856    ///    mock_var: usize,
857    /// }
858    ///
859    /// impl<'a> Fit<ArrayView2<'a,f64>, ArrayView1<'a, f64>, linfa::error::Error> for MockFittable {
860    ///     type Object = MockFittableResult;
861    ///
862    ///     fn fit(&self, training_data: &DatasetView<f64, f64, Ix1>) -> Result<Self::Object, linfa::error::Error> {
863    ///         Ok(MockFittableResult {
864    ///             mock_var: training_data.nsamples(),
865    ///         })
866    ///     }
867    /// }
868    ///
869    /// let records = array![[1.,1.], [2.,2.], [3.,3.], [4.,4.], [5.,5.]];
870    /// let targets = array![1.,2.,3.,4.,5.];
871    /// let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
872    /// let params = MockFittable {};
873    ///
874    /// for (model,validation_set) in dataset.iter_fold(5, |v| params.fit(v).unwrap()){
875    ///     // Here you can use `model` and `validation_set` to
876    ///     // assert the performance of the chosen algorithm
877    /// }
878    /// ```
879    pub fn iter_fold<O, C: Fn(&DatasetView<F, E, I>) -> O>(
880        &'a mut self,
881        k: usize,
882        fit_closure: C,
883    ) -> impl Iterator<Item = (O, DatasetBase<ArrayView2<'a, F>, ArrayView<'a, E, I>>)> {
884        assert!(k > 0);
885        assert!(k <= self.nsamples());
886        let samples_count = self.nsamples();
887        let fold_size = samples_count / k;
888
889        let features = self.nfeatures();
890        let targets = self.ntargets();
891        let tshape = self.targets.raw_dim();
892
893        let mut objs: Vec<O> = Vec::with_capacity(k);
894
895        {
896            let records_sl = self.records.as_slice_mut().unwrap();
897            let mut targets_sl2 = self.targets.as_targets_mut();
898            let targets_sl = targets_sl2.as_slice_mut().unwrap();
899
900            for i in 0..k {
901                assist_swap_array2!(records_sl, i, fold_size, features);
902                assist_swap_array2!(targets_sl, i, fold_size, targets);
903
904                {
905                    let train = DatasetBase::new(
906                        ArrayView2::from_shape(
907                            (samples_count - fold_size, features),
908                            records_sl.split_at(fold_size * features).1,
909                        )
910                        .unwrap(),
911                        ArrayView::from_shape(
912                            tshape.clone().nsamples(samples_count - fold_size),
913                            targets_sl.split_at(fold_size * targets).1,
914                        )
915                        .unwrap(),
916                    );
917
918                    let obj = fit_closure(&train);
919                    objs.push(obj);
920                }
921
922                assist_swap_array2!(records_sl, i, fold_size, features);
923                assist_swap_array2!(targets_sl, i, fold_size, targets);
924            }
925        }
926
927        objs.into_iter().zip(self.sample_chunks(fold_size))
928    }
929
930    /// Cross validation for single and multi-target algorithms
931    ///
932    /// Given a list of fittable models, cross validation is used to compare their performance
933    /// according to some performance metric. To do so, k-folding is applied to the dataset and,
934    /// for each fold, each model is trained on the training set and its performance is evaluated
935    /// on the validation set. The performances collected for each model are then averaged over the
936    /// number of folds.
937    ///
938    /// For single-target datasets, [`Dataset::cross_validate_single`] is recommended.
939    ///
940    /// ### Parameters:
941    ///
942    /// - `k`: the number of folds to apply
943    /// - `parameters`: a list of models to compare
944    /// - `eval`: closure used to evaluate the performance of each trained model. This closure is
945    ///   called on the model output and validation targets of each fold and outputs the performance
946    ///   score for each target. For single-target dataset the signature is `(Array1, Array1) ->
947    ///   Array0`. For multi-target dataset the signature is `(Array2, Array2) -> Array1`.
948    ///
949    /// ### Returns
950    ///
951    /// An array of model performances, for each model and each target, if no errors occur.
952    /// For multi-target dataset, the array has dimensions `(n_models, n_targets)`. For
953    /// single-target dataset, the array has dimensions `(n_models)`.
954    /// Otherwise, it might return an Error in one of the following cases:
955    ///
956    /// - An error occurred during the fitting of one model
957    /// - An error occurred inside the evaluation closure
958    ///
959    /// ### Example
960    ///
961    /// ```rust, ignore
962    ///
963    /// use linfa::prelude::*;
964    /// use ndarray::arr0;
965    /// # use ndarray::{array, ArrayView1, ArrayView2, Ix1};
966    ///
967    /// # struct MockFittable {}
968    ///
969    /// # struct MockFittableResult {
970    /// #    mock_var: usize,
971    /// # }
972    ///
973    /// # impl<'a> Fit<ArrayView2<'a,f64>, ArrayView1<'a, f64>, linfa::error::Error> for MockFittable {
974    /// #     type Object = MockFittableResult;
975    ///
976    /// #     fn fit(&self, training_data: &DatasetView<f64, f64, Ix1>) -> Result<Self::Object, linfa::error::Error> {
977    /// #         Ok(MockFittableResult {
978    /// #             mock_var: training_data.nsamples(),
979    /// #         })
980    /// #     }
981    /// # }
982    ///
983    /// # let model1 = MockFittable {};
984    /// # let model2 = MockFittable {};
985    ///
986    /// // mutability needed for fast cross validation
987    /// let mut dataset = linfa_datasets::diabetes();
988    ///
989    /// let models = vec![model1, model2];
990    ///
991    /// let r2_scores = dataset.cross_validate(5, &models, |prediction, truth| prediction.r2(truth).map(arr0))?;
992    ///
993    /// ```
994    pub fn cross_validate<O, ER, M, FACC, C>(
995        &'a mut self,
996        k: usize,
997        parameters: &[M],
998        eval: C,
999    ) -> std::result::Result<Array<FACC, I>, ER>
1000    where
1001        ER: std::error::Error + std::convert::From<crate::error::Error>,
1002        M: for<'c> Fit<ArrayView2<'c, F>, ArrayView<'c, E, I>, ER, Object = O>,
1003        O: for<'d> PredictInplace<ArrayView2<'a, F>, Array<E, I>>,
1004        FACC: Float,
1005        C: Fn(
1006            &Array<E, I>,
1007            &ArrayView<E, I>,
1008        ) -> std::result::Result<Array<FACC, I::Smaller>, crate::error::Error>,
1009    {
1010        let mut evaluations = Array::from_elem(
1011            self.targets.raw_dim().nsamples(parameters.len()),
1012            FACC::zero(),
1013        );
1014        let folds_evaluations: std::result::Result<Vec<_>, ER> = self
1015            .iter_fold(k, |train| {
1016                let fit_result: std::result::Result<Vec<_>, ER> =
1017                    parameters.iter().map(|p| p.fit(train)).collect();
1018                fit_result
1019            })
1020            .map(|(models, valid)| {
1021                let targets = valid.targets();
1022                let models = models?;
1023                // XXX diverges from master branch
1024                let mut eval_predictions =
1025                    Array::from_elem(targets.raw_dim().nsamples(models.len()), FACC::zero());
1026                for (i, model) in models.iter().enumerate() {
1027                    let predicted = model.predict(valid.records());
1028                    let eval_pred = match eval(&predicted, targets) {
1029                        Err(e) => Err(ER::from(e)),
1030                        Ok(res) => Ok(res),
1031                    }?;
1032                    eval_predictions
1033                        .index_axis_mut(Axis(0), i)
1034                        .add_assign(&eval_pred);
1035                }
1036                Ok(eval_predictions)
1037            })
1038            .collect();
1039
1040        for fold_evaluation in folds_evaluations? {
1041            evaluations.add_assign(&fold_evaluation)
1042        }
1043        Ok(evaluations / FACC::from(k).unwrap())
1044    }
1045}
1046
1047impl<'a, F: 'a + Clone, E: Copy + 'a, D, S> DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, Ix1>>
1048where
1049    D: DataMut<Elem = F>,
1050    S: DataMut<Elem = E>,
1051{
1052    /// Specialized version of `cross_validate` for single-target datasets. Allows the evaluation
1053    /// closure to return a float without wrapping it in `arr0`. See [`Dataset::cross_validate`] for
1054    /// more details.
1055    pub fn cross_validate_single<O, ER, M, FACC, C>(
1056        &'a mut self,
1057        k: usize,
1058        parameters: &[M],
1059        eval: C,
1060    ) -> std::result::Result<Array1<FACC>, ER>
1061    where
1062        ER: std::error::Error + std::convert::From<crate::error::Error>,
1063        M: for<'c> Fit<ArrayView2<'c, F>, ArrayView1<'c, E>, ER, Object = O>,
1064        O: for<'d> PredictInplace<ArrayView2<'a, F>, Array1<E>>,
1065        FACC: Float,
1066        C: Fn(&Array1<E>, &ArrayView1<E>) -> std::result::Result<FACC, crate::error::Error>,
1067    {
1068        self.cross_validate(k, parameters, |a, b| eval(a, b).map(arr0))
1069    }
1070}
1071
1072impl<F, E, I: TargetDim> Dataset<F, E, I> {
1073    /// Split dataset into two disjoint chunks
1074    ///
1075    /// This function splits the observations in a dataset into two disjoint chunks. The splitting
1076    /// threshold is calculated with the `ratio`. If the input Dataset contains `n` samples then the
1077    /// two new Datasets will have respectively `n * ratio` and `n - (n*ratio)` samples.
1078    /// For example a ratio of `0.9` allocates 90% to the
1079    /// first chunks and 10% to the second. This is often used in training, validation splitting
1080    /// procedures.
1081    ///
1082    /// ### Parameters
1083    ///
1084    /// * `ratio`: the ratio of samples in the input Dataset to include in the first output one
1085    ///
1086    /// ### Returns
1087    ///  
1088    /// The input Dataset split into two according to the input ratio.
1089    ///
1090    /// ### Panics
1091    ///
1092    /// Panic occurs when the input record or targets are not in row-major layout.
1093    pub fn split_with_ratio(mut self, ratio: f32) -> (Self, Self) {
1094        assert!(
1095            self.records.is_standard_layout(),
1096            "records not in row-major layout"
1097        );
1098        assert!(
1099            self.targets.is_standard_layout(),
1100            "targets not in row-major layout"
1101        );
1102
1103        let nfeatures = self.nfeatures();
1104
1105        let n1 = (self.nsamples() as f32 * ratio).ceil() as usize;
1106        let n2 = self.nsamples() - n1;
1107
1108        let feature_names = self.feature_names().to_vec();
1109        let target_names = self.target_names().to_vec();
1110
1111        // split records into two disjoint arrays
1112        let (mut array_buf, _) = self.records.into_raw_vec_and_offset();
1113        let second_array_buf = array_buf.split_off(n1 * nfeatures);
1114
1115        let first = Array2::from_shape_vec((n1, nfeatures), array_buf).unwrap();
1116        let second = Array2::from_shape_vec((n2, nfeatures), second_array_buf).unwrap();
1117
1118        // split targets into two disjoint Vec
1119        let dim1 = self.targets.raw_dim().nsamples(n1);
1120        let dim2 = self.targets.raw_dim().nsamples(n2);
1121        let (mut array_buf, _) = self.targets.into_raw_vec_and_offset();
1122        let second_array_buf = array_buf.split_off(dim1.size());
1123
1124        let first_targets = Array::from_shape_vec(dim1, array_buf).unwrap();
1125        let second_targets = Array::from_shape_vec(dim2, second_array_buf).unwrap();
1126
1127        // split weights into two disjoint Vec
1128        let second_weights = if self.weights.len() == n1 + n2 {
1129            let (mut weights, _) = self.weights.into_raw_vec_and_offset();
1130
1131            let weights2 = weights.split_off(n1);
1132            self.weights = Array1::from(weights);
1133
1134            Array1::from(weights2)
1135        } else {
1136            Array1::zeros(0)
1137        };
1138
1139        // create new datasets with attached weights
1140        let dataset1 = Dataset::new(first, first_targets)
1141            .with_weights(self.weights)
1142            .with_feature_names(feature_names.clone())
1143            .with_target_names(target_names.clone());
1144        let dataset2 = Dataset::new(second, second_targets)
1145            .with_weights(second_weights)
1146            .with_feature_names(feature_names.clone())
1147            .with_target_names(target_names.clone());
1148
1149        (dataset1, dataset2)
1150    }
1151}
1152
1153impl<F, D, E, T, O> Predict<ArrayBase<D, Ix2>, DatasetBase<ArrayBase<D, Ix2>, T>> for O
1154where
1155    D: Data<Elem = F>,
1156    T: AsTargets<Elem = E>,
1157    O: PredictInplace<ArrayBase<D, Ix2>, T>,
1158{
1159    fn predict(&self, records: ArrayBase<D, Ix2>) -> DatasetBase<ArrayBase<D, Ix2>, T> {
1160        let mut targets = self.default_target(&records);
1161        self.predict_inplace(&records, &mut targets);
1162        DatasetBase::new(records, targets)
1163    }
1164}
1165
1166impl<F, R, T, E, S, O> Predict<DatasetBase<R, T>, DatasetBase<R, S>> for O
1167where
1168    R: Records<Elem = F>,
1169    S: AsTargets<Elem = E>,
1170    O: PredictInplace<R, S>,
1171{
1172    fn predict(&self, ds: DatasetBase<R, T>) -> DatasetBase<R, S> {
1173        let mut targets = self.default_target(&ds.records);
1174        self.predict_inplace(&ds.records, &mut targets);
1175        DatasetBase::new(ds.records, targets)
1176    }
1177}
1178
1179impl<'a, F, R, T, S, O> Predict<&'a DatasetBase<R, T>, S> for O
1180where
1181    R: Records<Elem = F>,
1182    O: PredictInplace<R, S>,
1183{
1184    fn predict(&self, ds: &'a DatasetBase<R, T>) -> S {
1185        let mut targets = self.default_target(&ds.records);
1186        self.predict_inplace(&ds.records, &mut targets);
1187        targets
1188    }
1189}
1190
1191impl<'a, F, D, DM, T, O> Predict<&'a ArrayBase<D, DM>, T> for O
1192where
1193    D: Data<Elem = F>,
1194    DM: Dimension,
1195    O: PredictInplace<ArrayBase<D, DM>, T>,
1196{
1197    fn predict(&self, records: &'a ArrayBase<D, DM>) -> T {
1198        let mut targets = self.default_target(records);
1199        self.predict_inplace(records, &mut targets);
1200        targets
1201    }
1202}
1203
1204impl<L: Label, S: Labels<Elem = L>> CountedTargets<L, S> {
1205    pub fn new(targets: S) -> Self {
1206        let labels = targets.label_count();
1207
1208        CountedTargets { targets, labels }
1209    }
1210}