// linfa/dataset/impl_dataset.rs

1use super::{
2    super::traits::{Predict, PredictInplace},
3    iter::{ChunksIter, DatasetIter, Iter},
4    AsSingleTargets, AsTargets, AsTargetsMut, CountedTargets, Dataset, DatasetBase, DatasetView,
5    Float, FromTargetArray, FromTargetArrayOwned, Label, Labels, Records, Result, TargetDim,
6};
7use crate::traits::Fit;
8use ndarray::{concatenate, prelude::*, Data, DataMut, Dimension};
9use rand::{seq::SliceRandom, Rng};
10use std::collections::HashMap;
11use std::ops::AddAssign;
12
13/// Implementation without constraints on records and targets
14///
15/// This implementation block provides methods for the creation and mutation of datasets. This
16/// includes swapping the targets, return the records etc.
17impl<R: Records, S> DatasetBase<R, S> {
18    /// Create a new dataset from records and targets
19    ///
20    /// # Example
21    ///
22    /// ```ignore
23    /// let dataset = Dataset::new(records, targets);
24    /// ```
25    pub fn new(records: R, targets: S) -> DatasetBase<R, S> {
26        let targets = targets;
27
28        DatasetBase {
29            records,
30            targets,
31            weights: Array1::zeros(0),
32            feature_names: Vec::new(),
33            target_names: Vec::new(),
34        }
35    }
36
37    /// Returns reference to targets
38    pub fn targets(&self) -> &S {
39        &self.targets
40    }
41
42    /// Returns optionally weights
43    pub fn weights(&self) -> Option<&[f32]> {
44        if !self.weights.is_empty() {
45            Some(self.weights.as_slice().unwrap())
46        } else {
47            None
48        }
49    }
50
51    /// Return a single weight
52    ///
53    /// The weight of the `idx`th observation is returned. If no weight is specified, then all
54    /// observations are unweighted with default value `1.0`.
55    pub fn weight_for(&self, idx: usize) -> f32 {
56        self.weights.get(idx).copied().unwrap_or(1.0)
57    }
58
59    /// Returns feature names
60    ///
61    /// A feature name gives a human-readable string describing the purpose of a single feature.
62    /// This allow the reader to understand its purpose while analysing results, for example
63    /// correlation analysis or feature importance.
64    pub fn feature_names(&self) -> &[String] {
65        &self.feature_names
66    }
67
68    /// Return records of a dataset
69    ///
70    /// The records are data points from which predictions are made. This functions returns a
71    /// reference to the record field.
72    pub fn records(&self) -> &R {
73        &self.records
74    }
75
76    /// Updates the records of a dataset
77    ///
78    /// This function overwrites the records in a dataset. It also invalidates the weights and
79    /// feature/target names.
80    pub fn with_records<T: Records>(self, records: T) -> DatasetBase<T, S> {
81        DatasetBase {
82            records,
83            targets: self.targets,
84            weights: Array1::zeros(0),
85            feature_names: Vec::new(),
86            target_names: Vec::new(),
87        }
88    }
89
90    /// Updates the targets of a dataset
91    ///
92    /// This function overwrites the targets in a dataset.
93    pub fn with_targets<T>(self, targets: T) -> DatasetBase<R, T> {
94        DatasetBase {
95            records: self.records,
96            targets,
97            weights: self.weights,
98            feature_names: self.feature_names,
99            target_names: self.target_names,
100        }
101    }
102
103    /// Updates the weights of a dataset
104    pub fn with_weights(mut self, weights: Array1<f32>) -> DatasetBase<R, S> {
105        self.weights = weights;
106
107        self
108    }
109
110    /// Updates the feature names of a dataset
111    ///
112    /// **Panics** when given names not empty and length does not equal to the number of features
113    pub fn with_feature_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, S> {
114        assert!(
115            names.is_empty() || names.len() == self.nfeatures(),
116            "Wrong number of feature names"
117        );
118        self.feature_names = names.into_iter().map(|x| x.into()).collect();
119        self
120    }
121}
122
123impl<X, Y> Dataset<X, Y> {
124    // Convert 2D targets to 1D. Only works for targets with shape of form [X, 1], panics otherwise.
125    pub fn into_single_target(self) -> Dataset<X, Y, Ix1> {
126        let nsamples = self.records.nsamples();
127        let targets = self.targets.into_shape_with_order(nsamples).unwrap();
128        let features = self.records;
129        Dataset::new(features, targets)
130    }
131}
132
133impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
134    /// Updates the target names of a dataset
135    ///
136    /// **Panics**  when given names not empty and length does not equal to the number of targets
137    pub fn with_target_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, T> {
138        assert!(
139            names.is_empty() || names.len() == self.ntargets(),
140            "Wrong number of target names"
141        );
142        self.target_names = names.into_iter().map(|x| x.into()).collect();
143        self
144    }
145
146    /// Map targets with a function `f`
147    ///
148    /// # Example
149    ///
150    /// ```
151    /// let dataset = linfa_datasets::winequality()
152    ///     .map_targets(|x| *x > 6);
153    ///
154    /// // dataset has now boolean targets
155    /// println!("{:?}", dataset.targets());
156    /// ```
157    ///
158    /// # Returns
159    ///
160    /// A modified dataset with new target type.
161    ///
162    pub fn map_targets<S, G: FnMut(&L) -> S>(self, fnc: G) -> DatasetBase<R, Array<S, T::Ix>> {
163        let DatasetBase {
164            records,
165            targets,
166            weights,
167            feature_names,
168            target_names,
169            ..
170        } = self;
171
172        let targets = targets.as_targets();
173
174        DatasetBase {
175            records,
176            targets: targets.map(fnc),
177            weights,
178            feature_names,
179            target_names,
180        }
181    }
182
183    /// Returns target names
184    ///
185    /// A target name gives a human-readable string describing the purpose of a single target.
186    pub fn target_names(&self) -> &[String] {
187        &self.target_names
188    }
189
190    /// Return the number of targets in the dataset
191    ///
192    /// # Example
193    ///
194    /// ```
195    /// let dataset = linfa_datasets::winequality();
196    ///
197    /// println!("#targets {}", dataset.ntargets());
198    /// ```
199    ///
200    pub fn ntargets(&self) -> usize {
201        if T::Ix::NDIM.unwrap() == 1 {
202            1
203        } else {
204            self.targets.as_targets().len_of(Axis(1))
205        }
206    }
207}
208
209impl<'a, F, L, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
210where
211    D: Data<Elem = F>,
212    T: AsTargets<Elem = L>,
213{
214    /// Iterate over observations
215    ///
216    /// This function creates an iterator which produces tuples of data points and target value. The
217    /// iterator runs once for each data point and, while doing so, holds an reference to the owned
218    /// dataset.
219    ///
220    /// For multi-target datasets, the yielded target value is `ArrayView1` consisting of the
221    /// different targets. For single-target datasets, the target value is `ArrayView0` containing
222    /// the single target.
223    ///
224    /// # Example
225    /// ```
226    /// let dataset = linfa_datasets::iris();
227    ///
228    /// for (x, y) in dataset.sample_iter() {
229    ///     println!("{} => {}", x, y);
230    /// }
231    /// ```
232    pub fn sample_iter(&'a self) -> Iter<'a, 'a, F, T::Elem, T::Ix> {
233        Iter::new(self.records.view(), self.targets.as_targets())
234    }
235}
236
237impl<'a, F: 'a, L: 'a, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
238where
239    D: Data<Elem = F>,
240    T: AsTargets<Elem = L> + FromTargetArray<'a>,
241    T::View: AsTargets<Elem = L>,
242{
243    /// Creates a view of a dataset
244    pub fn view(&'a self) -> DatasetBase<ArrayView2<'a, F>, T::View> {
245        let records = self.records().view();
246        let targets = T::new_targets_view(self.as_targets());
247
248        DatasetBase::new(records, targets)
249            .with_feature_names(self.feature_names.clone())
250            .with_weights(self.weights.clone())
251            .with_target_names(self.target_names.clone())
252    }
253
254    /// Iterate over features
255    ///
256    /// This iterator produces dataset views with only a single feature, while the set of targets remain
257    /// complete. It can be useful to compare each feature individual to all targets.
258    pub fn feature_iter(&'a self) -> DatasetIter<'a, 'a, ArrayBase<D, Ix2>, T> {
259        DatasetIter::new(self, true)
260    }
261
262    /// Iterate over targets
263    ///
264    /// This functions creates an iterator which produces dataset views complete records, but only
265    /// a single target each. Useful to train multiple single target models for a multi-target
266    /// dataset.
267    pub fn target_iter(&'a self) -> DatasetIter<'a, 'a, ArrayBase<D, Ix2>, T> {
268        DatasetIter::new(self, false)
269    }
270}
271
272impl<L, R: Records, T: AsTargets<Elem = L>> AsTargets for DatasetBase<R, T> {
273    type Elem = L;
274    type Ix = T::Ix;
275
276    fn as_targets(&self) -> ArrayView<'_, Self::Elem, Self::Ix> {
277        self.targets.as_targets()
278    }
279}
280
281impl<L, R: Records, T: AsTargetsMut<Elem = L>> AsTargetsMut for DatasetBase<R, T> {
282    type Elem = L;
283    type Ix = T::Ix;
284
285    fn as_targets_mut(&mut self) -> ArrayViewMut<'_, Self::Elem, Self::Ix> {
286        self.targets.as_targets_mut()
287    }
288}
289
290#[allow(clippy::type_complexity)]
291impl<'a, L: 'a, F, T> DatasetBase<ArrayView2<'a, F>, T>
292where
293    T: AsTargets<Elem = L> + FromTargetArray<'a>,
294    T::View: AsTargets<Elem = L>,
295{
296    /// Split dataset into two disjoint chunks
297    ///
298    /// This function splits the observations in a dataset into two disjoint chunks. The splitting
299    /// threshold is calculated with the `ratio`. For example a ratio of `0.9` allocates 90% to the
300    /// first chunks and 9% to the second. This is often used in training, validation splitting
301    /// procedures.
302    pub fn split_with_ratio(
303        &'a self,
304        ratio: f32,
305    ) -> (
306        DatasetBase<ArrayView2<'a, F>, T::View>,
307        DatasetBase<ArrayView2<'a, F>, T::View>,
308    ) {
309        let n = (self.nsamples() as f32 * ratio).ceil() as usize;
310        let (records_first, records_second) = self.records.view().split_at(Axis(0), n);
311        let (targets_first, targets_second) = self.targets.as_targets().split_at(Axis(0), n);
312
313        let targets_first = T::new_targets_view(targets_first);
314        let targets_second = T::new_targets_view(targets_second);
315
316        let (first_weights, second_weights) = if self.weights.len() == self.nsamples() {
317            let a = self.weights.slice(s![..n]).to_vec();
318            let b = self.weights.slice(s![n..]).to_vec();
319
320            (Array1::from(a), Array1::from(b))
321        } else {
322            (Array1::zeros(0), Array1::zeros(0))
323        };
324        let dataset1 = DatasetBase::new(records_first, targets_first)
325            .with_weights(first_weights)
326            .with_feature_names(self.feature_names.clone())
327            .with_target_names(self.target_names.clone());
328
329        let dataset2 = DatasetBase::new(records_second, targets_second)
330            .with_weights(second_weights)
331            .with_feature_names(self.feature_names.clone())
332            .with_target_names(self.target_names.clone());
333
334        (dataset1, dataset2)
335    }
336}
337
338impl<L: Label, T: Labels<Elem = L>, R: Records> Labels for DatasetBase<R, T> {
339    type Elem = L;
340
341    fn label_count(&self) -> Vec<HashMap<L, usize>> {
342        self.targets().label_count()
343    }
344}
345
346#[allow(clippy::type_complexity)]
347impl<F, L: Label, T, D> DatasetBase<ArrayBase<D, Ix2>, T>
348where
349    D: Data<Elem = F>,
350    T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
351{
352    /// Produce N boolean targets from multi-class targets
353    ///
354    /// Some algorithms (like SVM) don't support multi-class targets. This function splits a
355    /// dataset into multiple binary single-target views of the same dataset.
356    pub fn one_vs_all(
357        &self,
358    ) -> Result<
359        Vec<(
360            L,
361            DatasetBase<ArrayView2<'_, F>, CountedTargets<bool, Array1<bool>>>,
362        )>,
363    > {
364        let targets = self.targets().as_single_targets();
365
366        Ok(self
367            .labels()
368            .into_iter()
369            .map(|label| {
370                let targets = targets.iter().map(|x| x == &label).collect::<Array1<_>>();
371
372                let targets = CountedTargets::new(targets);
373
374                (
375                    label,
376                    DatasetBase::new(self.records().view(), targets)
377                        .with_feature_names(self.feature_names.clone())
378                        .with_weights(self.weights.clone())
379                        .with_target_names(self.target_names.clone()),
380                )
381            })
382            .collect())
383    }
384}
385
386impl<L: Label, R: Records, S: AsTargets<Elem = L>> DatasetBase<R, S> {
387    /// Calculates label frequencies from a dataset while masking certain samples.
388    ///
389    /// ### Parameters
390    ///
391    /// * `mask`: a boolean array that specifies which samples to include in the count
392    ///
393    /// ### Returns
394    ///
395    /// A mapping of the Dataset's samples to their frequencies
396    pub fn label_frequencies_with_mask(&self, mask: &[bool]) -> HashMap<L, f32> {
397        let mut freqs = HashMap::new();
398
399        for (elms, val) in self
400            .targets
401            .as_targets()
402            .axis_iter(Axis(0))
403            .enumerate()
404            .filter(|(i, _)| *mask.get(*i).unwrap_or(&true))
405            .map(|(i, x)| (x, self.weight_for(i)))
406        {
407            for elm in elms {
408                if !freqs.contains_key(elm) {
409                    freqs.insert(elm.clone(), 0.0);
410                }
411
412                *freqs.get_mut(elm).unwrap() += val;
413            }
414        }
415
416        freqs
417    }
418
419    /// Calculates label frequencies from a dataset
420    pub fn label_frequencies(&self) -> HashMap<L, f32> {
421        self.label_frequencies_with_mask(&[])
422    }
423}
424
425impl<F, D: Data<Elem = F>, I: Dimension> From<ArrayBase<D, I>>
426    for DatasetBase<ArrayBase<D, I>, Array1<()>>
427{
428    fn from(records: ArrayBase<D, I>) -> Self {
429        let empty_targets = Array1::default(records.len_of(Axis(0)));
430        DatasetBase {
431            records,
432            targets: empty_targets,
433            weights: Array1::zeros(0),
434            feature_names: Vec::new(),
435            target_names: Vec::new(),
436        }
437    }
438}
439
440impl<F, E, D, S, I: TargetDim> From<(ArrayBase<D, Ix2>, ArrayBase<S, I>)>
441    for DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, I>>
442where
443    D: Data<Elem = F>,
444    S: Data<Elem = E>,
445{
446    fn from(rec_tar: (ArrayBase<D, Ix2>, ArrayBase<S, I>)) -> Self {
447        DatasetBase {
448            records: rec_tar.0,
449            targets: rec_tar.1,
450            weights: Array1::zeros(0),
451            feature_names: Vec::new(),
452            target_names: Vec::new(),
453        }
454    }
455}
456
457impl<'b, F: Clone, E: Copy + 'b, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
458where
459    D: Data<Elem = F>,
460    T: FromTargetArrayOwned<Elem = E>,
461    T::Owned: AsTargets,
462{
463    /// Apply bootstrapping for samples and features
464    ///
465    /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
466    /// stability of machine learning algorithms. It samples data uniformly with replacement and
467    /// generates datasets where elements may be shared. This selects a subset of observations as
468    /// well as features.
469    ///
470    /// # Parameters
471    ///
472    ///  * `sample_feature_size`: The number of samples and features per bootstrap
473    ///  * `rng`: The random number generator used in the sampling procedure
474    ///
475    ///  # Returns
476    ///
477    ///  An infinite Iterator yielding at each step a new bootstrapped dataset
478    ///
479    pub fn bootstrap<R: Rng>(
480        &'b self,
481        sample_feature_size: (usize, usize),
482        rng: &'b mut R,
483    ) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
484        std::iter::repeat(()).map(move |_| {
485            // sample with replacement
486            let indices = (0..sample_feature_size.0)
487                .map(|_| rng.gen_range(0..self.nsamples()))
488                .collect::<Vec<_>>();
489
490            let records = self.records().select(Axis(0), &indices);
491            let targets = T::new_targets(self.as_targets().select(Axis(0), &indices));
492
493            let indices = (0..sample_feature_size.1)
494                .map(|_| rng.gen_range(0..self.nfeatures()))
495                .collect::<Vec<_>>();
496
497            let records = records.select(Axis(1), &indices);
498
499            DatasetBase::new(records, targets)
500        })
501    }
502
503    /// Apply sample bootstrapping
504    ///
505    /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
506    /// stability of machine learning algorithms. It samples data uniformly with replacement and
507    /// generates datasets where elements may be shared. Only a sample subset is selected which
508    /// retains all features and targets.
509    ///
510    /// # Parameters
511    ///
512    ///  * `num_samples`: The number of samples per bootstrap
513    ///  * `rng`: The random number generator used in the sampling procedure
514    ///
515    ///  # Returns
516    ///
517    ///  An infinite Iterator yielding at each step a new bootstrapped dataset
518    ///
519    pub fn bootstrap_samples<R: Rng>(
520        &'b self,
521        num_samples: usize,
522        rng: &'b mut R,
523    ) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
524        std::iter::repeat(()).map(move |_| {
525            // sample with replacement
526            let indices = (0..num_samples)
527                .map(|_| rng.gen_range(0..self.nsamples()))
528                .collect::<Vec<_>>();
529
530            let records = self.records().select(Axis(0), &indices);
531            let targets = T::new_targets(self.as_targets().select(Axis(0), &indices));
532
533            DatasetBase::new(records, targets)
534        })
535    }
536
537    /// Apply feature bootstrapping
538    ///
539    /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
540    /// stability of machine learning algorithms. It samples data uniformly with replacement and
541    /// generates datasets where elements may be shared. Only a feature subset is selected while
542    /// retaining all samples and targets.
543    ///
544    /// # Parameters
545    ///
546    ///  * `num_features`: The number of features per bootstrap
547    ///  * `rng`: The random number generator used in the sampling procedure
548    ///
549    ///  # Returns
550    ///
551    ///  An infinite Iterator yielding at each step a new bootstrapped dataset
552    ///
553    pub fn bootstrap_features<R: Rng>(
554        &'b self,
555        num_features: usize,
556        rng: &'b mut R,
557    ) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
558        std::iter::repeat(()).map(move |_| {
559            let targets = T::new_targets(self.as_targets().to_owned());
560
561            let indices = (0..num_features)
562                .map(|_| rng.gen_range(0..self.nfeatures()))
563                .collect::<Vec<_>>();
564
565            let records = self.records.select(Axis(1), &indices);
566
567            DatasetBase::new(records, targets)
568        })
569    }
570
571    /// Produces a shuffled version of the current Dataset.
572    ///
573    /// ### Parameters
574    ///
575    /// * `rng`: the random number generator that will be used to shuffle the samples
576    ///
577    /// ### Returns
578    ///
579    /// A new shuffled version of the current Dataset
580    ///
581    pub fn shuffle<R: Rng>(&self, rng: &mut R) -> DatasetBase<Array2<F>, T::Owned> {
582        let mut indices = (0..self.nsamples()).collect::<Vec<_>>();
583        indices.shuffle(rng);
584
585        let records = self.records().select(Axis(0), &indices);
586        let targets = self.as_targets().select(Axis(0), &indices);
587        let targets = T::new_targets(targets);
588
589        DatasetBase::new(records, targets)
590            .with_feature_names(self.feature_names().to_vec())
591            .with_target_names(self.target_names().to_vec())
592    }
593
594    #[allow(clippy::type_complexity)]
595    /// Performs K-folding on the dataset.
596    ///
597    /// The dataset is divided into `k` "folds", each containing `(dataset size)/k` samples, used
598    /// to generate `k` training-validation dataset pairs. Each pair contains a validation
599    /// `Dataset` with `k` samples, the ones contained in the i-th fold, and a training `Dataset`
600    /// composed by the union of all the samples in the remaining folds.
601    ///
602    /// ### Parameters
603    ///
604    /// * `k`: the number of folds to apply
605    ///
606    /// ### Returns
607    ///
608    /// A vector of `k` training-validation Dataset pairs.
609    ///
610    /// ### Example
611    ///
612    /// ```rust
613    /// use linfa::dataset::DatasetView;
614    /// use ndarray::{Ix1, array};
615    ///
616    /// let records = array![[1.,1.], [2.,1.], [3.,2.], [4.,1.],[5., 3.], [6.,2.]];
617    /// let targets = array![1, 1, 0, 1, 0, 0];
618    ///
619    /// let dataset : DatasetView<f64, usize, Ix1> = (records.view(), targets.view()).into();
620    /// let accuracies = dataset.fold(3).into_iter().map(|(train, valid)| {
621    ///     // Here you can train your model and perform validation
622    ///     
623    ///     // let model = params.fit(&dataset);
624    ///     // let predi = model.predict(&valid);
625    ///     // predi.confusion_matrix(&valid).accuracy()  
626    /// });
627    /// ```
628    ///  
629    pub fn fold(
630        &self,
631        k: usize,
632    ) -> Vec<(
633        DatasetBase<Array2<F>, T::Owned>,
634        DatasetBase<Array2<F>, T::Owned>,
635    )> {
636        let targets = self.as_targets();
637        let fold_size = targets.len() / k;
638
639        // Generates all k folds of records and targets
640        let mut records_chunks: Vec<_> =
641            self.records.axis_chunks_iter(Axis(0), fold_size).collect();
642        let mut targets_chunks: Vec<_> = targets.axis_chunks_iter(Axis(0), fold_size).collect();
643
644        let mut res = Vec::with_capacity(k);
645        // For each iteration, take the first chunk for both records and targets as the validation set and
646        // concatenate all the other chunks to create the training set. In the end swap the first chunk with the
647        // one in the next index so that it is ready for the next iteration
648        for i in 0..k {
649            let remaining_records = concatenate(Axis(0), &records_chunks.as_slice()[1..]).unwrap();
650            let remaining_targets = concatenate(Axis(0), &targets_chunks.as_slice()[1..]).unwrap();
651
652            res.push((
653                // training
654                DatasetBase::new(remaining_records, T::new_targets(remaining_targets)),
655                // validation
656                DatasetBase::new(
657                    records_chunks[0].into_owned(),
658                    T::new_targets(targets_chunks[0].clone().into_owned()),
659                ),
660            ));
661
662            // swap
663            if i < k - 1 {
664                records_chunks.swap(0, i + 1);
665                targets_chunks.swap(0, i + 1);
666            }
667        }
668        res
669    }
670
671    pub fn sample_chunks<'a: 'b>(&'b self, chunk_size: usize) -> ChunksIter<'b, 'a, F, T> {
672        ChunksIter::new(self.records().view(), &self.targets, chunk_size, Axis(0))
673    }
674
675    pub fn to_owned(&self) -> DatasetBase<Array2<F>, T::Owned> {
676        DatasetBase::new(
677            self.records().to_owned(),
678            T::new_targets(self.as_targets().to_owned()),
679        )
680    }
681}
682
/// Swaps fold `$index` of a row-major buffer with fold 0.
///
/// A fold is `$fold_size` rows of `$features` elements each. Fold 0 is left untouched when
/// `$index` is 0. Applying the macro twice with the same arguments restores the original order.
macro_rules! assist_swap_array2 {
    ($slice: expr, $index: expr, $fold_size: expr, $features: expr) => {
        if $index != 0 {
            let block_len = $fold_size * $features;
            let (head, tail) = $slice.split_at_mut(block_len * $index);
            head[..block_len].swap_with_slice(&mut tail[..block_len]);
        }
    };
}
694
695impl<'a, F: 'a + Clone, E: Copy + 'a, D, S, I: TargetDim>
696    DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, I>>
697where
698    D: DataMut<Elem = F>,
699    S: DataMut<Elem = E>,
700{
701    /// Performs k-folding cross validation on fittable algorithms.
702    ///
703    /// Given a dataset as input, a value of k and the desired params for the fittable
704    /// algorithm, returns an iterator over the k trained models and the
705    /// associated validation set.
706    ///
707    /// The models are trained according to a closure specified
708    /// as an input.
709    ///
710    /// ## Parameters
711    ///
712    /// - `k`: the number of folds to apply to the dataset
713    /// - `params`: the desired parameters for the fittable algorithm at hand
714    /// - `fit_closure`: a closure of the type `(params, training_data) -> fitted_model`
715    ///   that will be used to produce the trained model for each fold. The training data given in input
716    ///   won't outlive the closure.
717    ///
718    /// ## Returns
719    ///
720    /// An iterator over couples `(trained_model, validation_set)`.
721    ///
722    /// ## Panics
723    ///
724    /// This method will panic for any of the following three reasons:
725    ///
726    /// - The value of `k` provided is not positive;
727    /// - The value of `k` provided is greater than the total number of samples in the dataset;
728    /// - The dataset's data is not stored contiguously and in standard order;
729    ///
730    /// ## Example
731    /// ```rust
732    /// use linfa::traits::Fit;
733    /// use linfa::dataset::{Dataset, DatasetView, Records};
734    /// use ndarray::{array, ArrayView1, ArrayView2, Ix1};
735    /// use linfa::Error;
736    ///
737    /// struct MockFittable {}
738    ///
739    /// struct MockFittableResult {
740    ///    mock_var: usize,
741    /// }
742    ///
743    /// impl<'a> Fit<ArrayView2<'a,f64>, ArrayView1<'a, f64>, linfa::error::Error> for MockFittable {
744    ///     type Object = MockFittableResult;
745    ///
746    ///     fn fit(&self, training_data: &DatasetView<f64, f64, Ix1>) -> Result<Self::Object, linfa::error::Error> {
747    ///         Ok(MockFittableResult {
748    ///             mock_var: training_data.nsamples(),
749    ///         })
750    ///     }
751    /// }
752    ///
753    /// let records = array![[1.,1.], [2.,2.], [3.,3.], [4.,4.], [5.,5.]];
754    /// let targets = array![1.,2.,3.,4.,5.];
755    /// let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
756    /// let params = MockFittable {};
757    ///
758    /// for (model,validation_set) in dataset.iter_fold(5, |v| params.fit(v).unwrap()){
759    ///     // Here you can use `model` and `validation_set` to
760    ///     // assert the performance of the chosen algorithm
761    /// }
762    /// ```
763    pub fn iter_fold<O, C: Fn(&DatasetView<F, E, I>) -> O>(
764        &'a mut self,
765        k: usize,
766        fit_closure: C,
767    ) -> impl Iterator<Item = (O, DatasetBase<ArrayView2<'a, F>, ArrayView<'a, E, I>>)> {
768        assert!(k > 0);
769        assert!(k <= self.nsamples());
770        let samples_count = self.nsamples();
771        let fold_size = samples_count / k;
772
773        let features = self.nfeatures();
774        let targets = self.ntargets();
775        let tshape = self.targets.raw_dim();
776
777        let mut objs: Vec<O> = Vec::with_capacity(k);
778
779        {
780            let records_sl = self.records.as_slice_mut().unwrap();
781            let mut targets_sl2 = self.targets.as_targets_mut();
782            let targets_sl = targets_sl2.as_slice_mut().unwrap();
783
784            for i in 0..k {
785                assist_swap_array2!(records_sl, i, fold_size, features);
786                assist_swap_array2!(targets_sl, i, fold_size, targets);
787
788                {
789                    let train = DatasetBase::new(
790                        ArrayView2::from_shape(
791                            (samples_count - fold_size, features),
792                            records_sl.split_at(fold_size * features).1,
793                        )
794                        .unwrap(),
795                        ArrayView::from_shape(
796                            tshape.clone().nsamples(samples_count - fold_size),
797                            targets_sl.split_at(fold_size * targets).1,
798                        )
799                        .unwrap(),
800                    );
801
802                    let obj = fit_closure(&train);
803                    objs.push(obj);
804                }
805
806                assist_swap_array2!(records_sl, i, fold_size, features);
807                assist_swap_array2!(targets_sl, i, fold_size, targets);
808            }
809        }
810
811        objs.into_iter().zip(self.sample_chunks(fold_size))
812    }
813
814    /// Cross validation for single and multi-target algorithms
815    ///
816    /// Given a list of fittable models, cross validation is used to compare their performance
817    /// according to some performance metric. To do so, k-folding is applied to the dataset and,
818    /// for each fold, each model is trained on the training set and its performance is evaluated
819    /// on the validation set. The performances collected for each model are then averaged over the
820    /// number of folds.
821    ///
822    /// For single-target datasets, [`Dataset::cross_validate_single`] is recommended.
823    ///
824    /// ### Parameters:
825    ///
826    /// - `k`: the number of folds to apply
827    /// - `parameters`: a list of models to compare
828    /// - `eval`: closure used to evaluate the performance of each trained model. This closure is
829    ///   called on the model output and validation targets of each fold and outputs the performance
830    ///   score for each target. For single-target dataset the signature is `(Array1, Array1) ->
831    ///   Array0`. For multi-target dataset the signature is `(Array2, Array2) -> Array1`.
832    ///
833    /// ### Returns
834    ///
835    /// An array of model performances, for each model and each target, if no errors occur.
836    /// For multi-target dataset, the array has dimensions `(n_models, n_targets)`. For
837    /// single-target dataset, the array has dimensions `(n_models)`.
838    /// Otherwise, it might return an Error in one of the following cases:
839    ///
840    /// - An error occurred during the fitting of one model
841    /// - An error occurred inside the evaluation closure
842    ///
843    /// ### Example
844    ///
845    /// ```rust, ignore
846    ///
847    /// use linfa::prelude::*;
848    /// use ndarray::arr0;
849    /// # use ndarray::{array, ArrayView1, ArrayView2, Ix1};
850    ///
851    /// # struct MockFittable {}
852    ///
853    /// # struct MockFittableResult {
854    /// #    mock_var: usize,
855    /// # }
856    ///
857    /// # impl<'a> Fit<ArrayView2<'a,f64>, ArrayView1<'a, f64>, linfa::error::Error> for MockFittable {
858    /// #     type Object = MockFittableResult;
859    ///
860    /// #     fn fit(&self, training_data: &DatasetView<f64, f64, Ix1>) -> Result<Self::Object, linfa::error::Error> {
861    /// #         Ok(MockFittableResult {
862    /// #             mock_var: training_data.nsamples(),
863    /// #         })
864    /// #     }
865    /// # }
866    ///
867    /// # let model1 = MockFittable {};
868    /// # let model2 = MockFittable {};
869    ///
870    /// // mutability needed for fast cross validation
871    /// let mut dataset = linfa_datasets::diabetes();
872    ///
873    /// let models = vec![model1, model2];
874    ///
875    /// let r2_scores = dataset.cross_validate(5, &models, |prediction, truth| prediction.r2(truth).map(arr0))?;
876    ///
877    /// ```
878    pub fn cross_validate<O, ER, M, FACC, C>(
879        &'a mut self,
880        k: usize,
881        parameters: &[M],
882        eval: C,
883    ) -> std::result::Result<Array<FACC, I>, ER>
884    where
885        ER: std::error::Error + std::convert::From<crate::error::Error>,
886        M: for<'c> Fit<ArrayView2<'c, F>, ArrayView<'c, E, I>, ER, Object = O>,
887        O: for<'d> PredictInplace<ArrayView2<'a, F>, Array<E, I>>,
888        FACC: Float,
889        C: Fn(
890            &Array<E, I>,
891            &ArrayView<E, I>,
892        ) -> std::result::Result<Array<FACC, I::Smaller>, crate::error::Error>,
893    {
894        let mut evaluations = Array::from_elem(
895            self.targets.raw_dim().nsamples(parameters.len()),
896            FACC::zero(),
897        );
898        let folds_evaluations: std::result::Result<Vec<_>, ER> = self
899            .iter_fold(k, |train| {
900                let fit_result: std::result::Result<Vec<_>, ER> =
901                    parameters.iter().map(|p| p.fit(train)).collect();
902                fit_result
903            })
904            .map(|(models, valid)| {
905                let targets = valid.targets();
906                let models = models?;
907                // XXX diverges from master branch
908                let mut eval_predictions =
909                    Array::from_elem(targets.raw_dim().nsamples(models.len()), FACC::zero());
910                for (i, model) in models.iter().enumerate() {
911                    let predicted = model.predict(valid.records());
912                    let eval_pred = match eval(&predicted, targets) {
913                        Err(e) => Err(ER::from(e)),
914                        Ok(res) => Ok(res),
915                    }?;
916                    eval_predictions
917                        .index_axis_mut(Axis(0), i)
918                        .add_assign(&eval_pred);
919                }
920                Ok(eval_predictions)
921            })
922            .collect();
923
924        for fold_evaluation in folds_evaluations? {
925            evaluations.add_assign(&fold_evaluation)
926        }
927        Ok(evaluations / FACC::from(k).unwrap())
928    }
929}
930
931impl<'a, F: 'a + Clone, E: Copy + 'a, D, S> DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, Ix1>>
932where
933    D: DataMut<Elem = F>,
934    S: DataMut<Elem = E>,
935{
936    /// Specialized version of `cross_validate` for single-target datasets. Allows the evaluation
937    /// closure to return a float without wrapping it in `arr0`. See [`Dataset.cross_validate`] for
938    /// more details.
939    pub fn cross_validate_single<O, ER, M, FACC, C>(
940        &'a mut self,
941        k: usize,
942        parameters: &[M],
943        eval: C,
944    ) -> std::result::Result<Array1<FACC>, ER>
945    where
946        ER: std::error::Error + std::convert::From<crate::error::Error>,
947        M: for<'c> Fit<ArrayView2<'c, F>, ArrayView1<'c, E>, ER, Object = O>,
948        O: for<'d> PredictInplace<ArrayView2<'a, F>, Array1<E>>,
949        FACC: Float,
950        C: Fn(&Array1<E>, &ArrayView1<E>) -> std::result::Result<FACC, crate::error::Error>,
951    {
952        self.cross_validate(k, parameters, |a, b| eval(a, b).map(arr0))
953    }
954}
955
956impl<F, E, I: TargetDim> Dataset<F, E, I> {
957    /// Split dataset into two disjoint chunks
958    ///
959    /// This function splits the observations in a dataset into two disjoint chunks. The splitting
960    /// threshold is calculated with the `ratio`. If the input Dataset contains `n` samples then the
961    /// two new Datasets will have respectively `n * ratio` and `n - (n*ratio)` samples.
962    /// For example a ratio of `0.9` allocates 90% to the
963    /// first chunks and 10% to the second. This is often used in training, validation splitting
964    /// procedures.
965    ///
966    /// ### Parameters
967    ///
968    /// * `ratio`: the ratio of samples in the input Dataset to include in the first output one
969    ///
970    /// ### Returns
971    ///  
972    /// The input Dataset split into two according to the input ratio.
973    ///
974    /// ### Panics
975    ///
976    /// Panic occurs when the input record or targets are not in row-major layout.
977    pub fn split_with_ratio(mut self, ratio: f32) -> (Self, Self) {
978        assert!(
979            self.records.is_standard_layout(),
980            "records not in row-major layout"
981        );
982        assert!(
983            self.targets.is_standard_layout(),
984            "targets not in row-major layout"
985        );
986
987        let nfeatures = self.nfeatures();
988
989        let n1 = (self.nsamples() as f32 * ratio).ceil() as usize;
990        let n2 = self.nsamples() - n1;
991
992        let feature_names = self.feature_names().to_vec();
993        let target_names = self.target_names().to_vec();
994
995        // split records into two disjoint arrays
996        let (mut array_buf, _) = self.records.into_raw_vec_and_offset();
997        let second_array_buf = array_buf.split_off(n1 * nfeatures);
998
999        let first = Array2::from_shape_vec((n1, nfeatures), array_buf).unwrap();
1000        let second = Array2::from_shape_vec((n2, nfeatures), second_array_buf).unwrap();
1001
1002        // split targets into two disjoint Vec
1003        let dim1 = self.targets.raw_dim().nsamples(n1);
1004        let dim2 = self.targets.raw_dim().nsamples(n2);
1005        let (mut array_buf, _) = self.targets.into_raw_vec_and_offset();
1006        let second_array_buf = array_buf.split_off(dim1.size());
1007
1008        let first_targets = Array::from_shape_vec(dim1, array_buf).unwrap();
1009        let second_targets = Array::from_shape_vec(dim2, second_array_buf).unwrap();
1010
1011        // split weights into two disjoint Vec
1012        let second_weights = if self.weights.len() == n1 + n2 {
1013            let (mut weights, _) = self.weights.into_raw_vec_and_offset();
1014
1015            let weights2 = weights.split_off(n1);
1016            self.weights = Array1::from(weights);
1017
1018            Array1::from(weights2)
1019        } else {
1020            Array1::zeros(0)
1021        };
1022
1023        // create new datasets with attached weights
1024        let dataset1 = Dataset::new(first, first_targets)
1025            .with_weights(self.weights)
1026            .with_feature_names(feature_names.clone())
1027            .with_target_names(target_names.clone());
1028        let dataset2 = Dataset::new(second, second_targets)
1029            .with_weights(second_weights)
1030            .with_feature_names(feature_names.clone())
1031            .with_target_names(target_names.clone());
1032
1033        (dataset1, dataset2)
1034    }
1035}
1036
1037impl<F, D, E, T, O> Predict<ArrayBase<D, Ix2>, DatasetBase<ArrayBase<D, Ix2>, T>> for O
1038where
1039    D: Data<Elem = F>,
1040    T: AsTargets<Elem = E>,
1041    O: PredictInplace<ArrayBase<D, Ix2>, T>,
1042{
1043    fn predict(&self, records: ArrayBase<D, Ix2>) -> DatasetBase<ArrayBase<D, Ix2>, T> {
1044        let mut targets = self.default_target(&records);
1045        self.predict_inplace(&records, &mut targets);
1046        DatasetBase::new(records, targets)
1047    }
1048}
1049
1050impl<F, R, T, E, S, O> Predict<DatasetBase<R, T>, DatasetBase<R, S>> for O
1051where
1052    R: Records<Elem = F>,
1053    S: AsTargets<Elem = E>,
1054    O: PredictInplace<R, S>,
1055{
1056    fn predict(&self, ds: DatasetBase<R, T>) -> DatasetBase<R, S> {
1057        let mut targets = self.default_target(&ds.records);
1058        self.predict_inplace(&ds.records, &mut targets);
1059        DatasetBase::new(ds.records, targets)
1060    }
1061}
1062
1063impl<'a, F, R, T, S, O> Predict<&'a DatasetBase<R, T>, S> for O
1064where
1065    R: Records<Elem = F>,
1066    O: PredictInplace<R, S>,
1067{
1068    fn predict(&self, ds: &'a DatasetBase<R, T>) -> S {
1069        let mut targets = self.default_target(&ds.records);
1070        self.predict_inplace(&ds.records, &mut targets);
1071        targets
1072    }
1073}
1074
1075impl<'a, F, D, DM, T, O> Predict<&'a ArrayBase<D, DM>, T> for O
1076where
1077    D: Data<Elem = F>,
1078    DM: Dimension,
1079    O: PredictInplace<ArrayBase<D, DM>, T>,
1080{
1081    fn predict(&self, records: &'a ArrayBase<D, DM>) -> T {
1082        let mut targets = self.default_target(records);
1083        self.predict_inplace(records, &mut targets);
1084        targets
1085    }
1086}
1087
1088impl<L: Label, S: Labels<Elem = L>> CountedTargets<L, S> {
1089    pub fn new(targets: S) -> Self {
1090        let labels = targets.label_count();
1091
1092        CountedTargets { targets, labels }
1093    }
1094}