// linfa/dataset/mod.rs

//! Datasets
//!
//! This module implements the dataset struct and various helper traits that extend its
//! functionality.
5use ndarray::{
6    Array, Array1, ArrayBase, ArrayView, ArrayView1, ArrayView2, ArrayViewMut, ArrayViewMut1,
7    ArrayViewMut2, CowArray, Ix1, Ix2, Ix3, NdFloat, OwnedRepr, RemoveAxis, ScalarOperand,
8};
9
10#[cfg(feature = "ndarray-linalg")]
11use ndarray_linalg::{Lapack, Scalar};
12
13use num_traits::{AsPrimitive, FromPrimitive, NumCast, Signed};
14use rand::distributions::uniform::SampleUniform;
15
16use std::cmp::{Ordering, PartialOrd};
17use std::collections::{HashMap, HashSet};
18use std::convert::{TryFrom, TryInto};
19use std::fmt;
20use std::hash::Hash;
21use std::iter::Sum;
22use std::ops::{AddAssign, Deref, DivAssign, MulAssign, SubAssign};
23
24use crate::error::Result;
25
26mod impl_dataset;
27mod impl_records;
28mod impl_targets;
29
30mod iter;
31
32mod lapack_bounds;
33pub use lapack_bounds::*;
34
35/// Floating point numbers
36///
37/// This trait bound multiplexes to the most common assumption of floating point number and
38/// implement them for 32bit and 64bit floating points. They are used in records of a dataset and, for
39/// regression task, in the targets as well.
40pub trait Float:
41    NdFloat
42    + FromPrimitive
43    + Default
44    + Signed
45    + Sum
46    + AsPrimitive<usize>
47    + for<'a> AddAssign<&'a Self>
48    + for<'a> MulAssign<&'a Self>
49    + for<'a> SubAssign<&'a Self>
50    + for<'a> DivAssign<&'a Self>
51    + num_traits::MulAdd<Output = Self>
52    + SampleUniform
53    + ScalarOperand
54    + approx::AbsDiffEq
55    + std::marker::Unpin
56    + sprs::MulAcc
57{
58    #[cfg(feature = "ndarray-linalg")]
59    type Lapack: Float + Scalar + Lapack;
60    #[cfg(not(feature = "ndarray-linalg"))]
61    type Lapack: Float;
62
63    fn cast<T: NumCast>(x: T) -> Self {
64        NumCast::from(x).unwrap()
65    }
66}
67
68impl Float for f32 {
69    type Lapack = f32;
70}
71
72impl Float for f64 {
73    type Lapack = f64;
74}
75
/// Discrete labels
///
/// Labels are countable, comparable and hashable. Implementations are provided for the null type
/// `()` (no targets), `bool` (binary tasks), `usize`, `String` and `&str` (multi-class tasks), as
/// well as `Option<L>` wrapping any other label type `L`.
pub trait Label: PartialEq + Eq + Hash + Clone + Ord + fmt::Debug + Default {}

impl Label for bool {}
impl Label for usize {}
impl Label for String {}
impl Label for () {}
impl Label for &str {}
impl<L: Label> Label for Option<L> {}
88
/// Probability types
///
/// This new-type wraps an `f32` to distinguish probabilities from plain floating point numbers.
/// For example SVM selects regression or classification training based on the target type, and
/// could not distinguish the two without a new-type definition.
#[repr(transparent)]
#[derive(Debug, Copy, Clone, Default)]
pub struct Pr(f32);

/// Tries to convert a float into a probability.
///
/// # Returns
///
/// `Ok(Pr(prob))` if `prob` lies in the closed interval `[0, 1]`, otherwise `Err(prob)` carrying
/// the rejected value. `NaN` is rejected as well, since it is not contained in any range.
impl TryFrom<f32> for Pr {
    type Error = f32;

    fn try_from(prob: f32) -> std::result::Result<Self, Self::Error> {
        if (0. ..=1.).contains(&prob) {
            Ok(Pr(prob))
        } else {
            Err(prob)
        }
    }
}

impl Pr {
    /// Creates a probability from the given float.
    ///
    /// # Panics
    ///
    /// Panics if `prob` is negative, bigger than one or `NaN`; the panic message names the
    /// rejected value.
    pub fn new(prob: f32) -> Self {
        let checked: std::result::Result<Self, f32> = prob.try_into();
        checked.unwrap_or_else(|rejected| {
            panic!("probability must lie in [0, 1], got {}", rejected)
        })
    }

    /// Creates a probability from the given float without validating it.
    ///
    /// Unlike [`Pr::new`] and `Pr::try_from`, values outside `[0, 1]` (or `NaN`) are accepted
    /// as-is; the caller is responsible for passing a valid probability.
    pub fn new_unchecked(prob: f32) -> Self {
        Pr(prob)
    }

    /// Returns the probability `0.5`, i.e. an even chance.
    pub fn even() -> Pr {
        Pr(0.5)
    }
}

/// Probabilities compare equal when their wrapped values do.
impl PartialEq for Pr {
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

/// Probabilities are ordered by their wrapped values.
impl PartialOrd for Pr {
    fn partial_cmp(&self, other: &Pr) -> Option<Ordering> {
        self.0.partial_cmp(&other.0)
    }
}

/// Gives read access to the wrapped `f32`.
impl Deref for Pr {
    type Target = f32;

    fn deref(&self) -> &f32 {
        &self.0
    }
}
152
153/// DatasetBase
154///
155/// This is the fundamental structure of a dataset. It contains a number of records about the data
156/// and may contain targets, weights and feature names. In order to keep the type complexity low
157/// the dataset base is only generic over the records and targets and introduces a trait bound on
158/// the records. `weights` and `feature_names`, on the other hand, are always assumed to be owned
159/// and copied when views are created.
160///
161/// # Fields
162///
163/// * `records`: a two-dimensional matrix with dimensionality (nsamples, nfeatures), in case of
164///   kernel methods a quadratic matrix with dimensionality (nsamples, nsamples), which may be sparse
165/// * `targets`: a two-/one-dimension matrix with dimensionality (nsamples, ntargets)
166/// * `weights`: optional weights for each sample with dimensionality (nsamples)
167/// * `feature_names`: optional descriptive feature names with dimensionality (nfeatures)
168/// * `target_names`: optional descriptive target names with dimensionality (ntargets)
169///
170/// # Trait bounds
171///
172/// * `R: Records`: generic over feature matrices or kernel matrices
173/// * `T`: generic over any `ndarray` matrix which can be used as targets. The `AsTargets` trait
174///   bound is omitted here to avoid some repetition in implementation `src/dataset/impl_dataset.rs`
175#[derive(Debug, Clone, PartialEq)]
176pub struct DatasetBase<R, T>
177where
178    R: Records,
179{
180    pub records: R,
181    pub targets: T,
182
183    pub weights: Array1<f32>,
184    feature_names: Vec<String>,
185    target_names: Vec<String>,
186}
187
188/// Targets with precomputed, counted labels
189///
190/// This extends plain targets with pre-counted labels. The label map is useful when, for example,
191/// a prior probability is estimated (e.g. in Naive Bayesian implementation) or the samples are
192/// weighted inverse to their occurence.
193///
194/// # Fields
195///
196/// * `targets`: wrapped target field
197/// * `labels`: counted labels with label-count association
198#[derive(Debug, Clone, PartialEq, Eq)]
199pub struct CountedTargets<L: Label, P> {
200    targets: P,
201    labels: Vec<HashMap<L, usize>>,
202}
203
204/// Dataset
205///
206/// The most commonly used typed of dataset. It contains a number of records
207/// stored as an `Array2` and each record may correspond to multiple targets. The
208/// targets are stored as an `Array1` or `Array2`.
209pub type Dataset<D, T, I = Ix2> =
210    DatasetBase<ArrayBase<OwnedRepr<D>, Ix2>, ArrayBase<OwnedRepr<T>, I>>;
211
212/// DatasetView
213///
214/// A read only view of a Dataset
215pub type DatasetView<'a, D, T, I = Ix2> = DatasetBase<ArrayView<'a, D, Ix2>, ArrayView<'a, T, I>>;
216
217/// DatasetPr
218///
219/// Dataset with probabilities as targets. Useful for multiclass probabilities.
220/// It stores records as an `Array2` of elements of type `D`, and targets as an `Array3`
221/// of elements of type `Pr`
222pub type DatasetPr<D, L> =
223    DatasetBase<ArrayBase<OwnedRepr<D>, Ix2>, CountedTargets<L, ArrayBase<OwnedRepr<Pr>, Ix3>>>;
224
/// Record trait
///
/// Implemented by everything that can serve as the records of a [`DatasetBase`]. It exposes the
/// element type together with the number of samples and features.
pub trait Records
where
    Self: Sized,
{
    /// Element type stored in the records.
    type Elem;

    /// Number of samples, i.e. the first dimension of (nsamples, nfeatures).
    fn nsamples(&self) -> usize;

    /// Number of features, i.e. the second dimension of (nsamples, nfeatures).
    fn nfeatures(&self) -> usize;
}
232
233pub trait TargetDim: RemoveAxis {
234    fn nsamples(mut self, nsamples: usize) -> Self {
235        self.as_array_view_mut()[0] = nsamples;
236        self
237    }
238}
239
240/// Return a reference to single or multiple target variables.
241///
242/// This is generic over the dimension of the target array to support both single-target and
243/// multi-target variables.
244pub trait AsTargets {
245    type Elem;
246    type Ix: TargetDim;
247
248    fn as_targets(&self) -> ArrayView<'_, Self::Elem, Self::Ix>;
249}
250
251/// Return a reference to single-target variables.
252pub trait AsSingleTargets: AsTargets<Ix = Ix1> {
253    fn as_single_targets(&self) -> ArrayView1<'_, Self::Elem> {
254        self.as_targets()
255    }
256}
257
258/// Return a reference to multi-target variables.
259pub trait AsMultiTargets: AsTargets<Ix = Ix2> {
260    fn as_multi_targets(&self) -> ArrayView2<'_, Self::Elem> {
261        self.as_targets()
262    }
263}
264
265pub trait FromTargetArrayOwned: AsTargets {
266    type Owned;
267
268    /// Create self object from new target array
269    fn new_targets(targets: Array<Self::Elem, Self::Ix>) -> Self::Owned;
270}
271
272/// Helper trait to construct counted labels
273///
274/// This is implemented for objects which can act as targets and created from a target matrix. For
275/// targets represented as `ndarray` matrix this is identity, for counted labels, i.e.
276/// `TargetsWithLabels`, it creates the corresponding wrapper struct.
277pub trait FromTargetArray<'a>: AsTargets {
278    type View;
279
280    /// Create self object from new target array
281    fn new_targets_view(targets: ArrayView<'a, Self::Elem, Self::Ix>) -> Self::View;
282}
283
284/// Return a mutable reference to single or multiple target variables.
285///
286/// This is generic over the dimension of the target array to support both single-target and
287/// multi-target variables.
288pub trait AsTargetsMut {
289    type Elem;
290    type Ix: TargetDim;
291
292    fn as_targets_mut(&mut self) -> ArrayViewMut<'_, Self::Elem, Self::Ix>;
293}
294
295/// Returns a mutable reference to single-target variables.
296pub trait AsSingleTargetsMut: AsTargetsMut<Ix = Ix1> {
297    fn as_single_targets_mut(&mut self) -> ArrayViewMut1<'_, Self::Elem> {
298        self.as_targets_mut()
299    }
300}
301
302/// Returns a mutable reference to multi-target variables.
303pub trait AsMultiTargetsMut: AsTargetsMut<Ix = Ix2> {
304    fn as_multi_targets_mut(&mut self) -> ArrayViewMut2<'_, Self::Elem> {
305        self.as_targets_mut()
306    }
307}
308
309/// Convert to probability matrix
310///
311/// Some algorithms are working with probabilities. Targets which allow an implicit conversion into
312/// probabilities can implement this trait.
313pub trait AsProbabilities {
314    fn as_multi_target_probabilities(&self) -> CowArray<'_, Pr, Ix3>;
315}
316
317/// Get the labels in all targets
318///
319pub trait Labels {
320    type Elem: Label;
321
322    fn label_count(&self) -> Vec<HashMap<Self::Elem, usize>>;
323
324    fn label_set(&self) -> Vec<HashSet<Self::Elem>> {
325        self.label_count()
326            .iter()
327            .map(|x| x.keys().cloned().collect::<HashSet<_>>())
328            .collect()
329    }
330
331    fn labels(&self) -> Vec<Self::Elem> {
332        self.label_set()
333            .into_iter()
334            .flatten()
335            .collect::<HashSet<_>>()
336            .into_iter()
337            .collect()
338    }
339
340    fn combined_labels<T>(&self, other: &T) -> Vec<Self::Elem>
341    where
342        T: Labels<Elem = <Self as Labels>::Elem>,
343    {
344        let mut combined = self.label_set();
345        combined.extend(other.label_set());
346
347        combined
348            .iter()
349            .flatten()
350            .collect::<HashSet<_>>()
351            .into_iter()
352            .cloned()
353            .collect()
354    }
355}
356
357#[cfg(test)]
358mod tests {
359    use super::*;
360    use crate::error::Error;
361    use approx::{assert_abs_diff_eq, assert_abs_diff_ne};
362    use linfa_datasets::generate::make_dataset;
363    use ndarray::{array, Array1, Array2, Axis};
364    use rand::{rngs::SmallRng, SeedableRng};
365    use statrs::distribution::{DiscreteUniform, Laplace};
366
367    #[test]
368    fn into_single_target() {
369        let feat_distr = Laplace::new(0.5, 5.).unwrap();
370        let target_distr = DiscreteUniform::new(0, 5).unwrap();
371        let dataset = make_dataset(10, 5, 1, feat_distr, target_distr);
372        assert!(dataset.into_single_target().targets.shape() == [10]);
373    }
374
375    #[test]
376    fn set_target_name() {
377        let dataset = Dataset::new(array![[1., 2.], [1., 2.]], array![0., 1.])
378            .with_target_names(vec!["test"]);
379        assert_eq!(dataset.target_names, vec!["test"]);
380    }
381
382    #[test]
383    fn empty_target_name() {
384        let dataset = Dataset::new(array![[1., 2.], [1., 2.]], array![[0., 1.], [2., 3.]]);
385        assert_eq!(dataset.target_names, Vec::<String>::new());
386    }
387
388    #[test]
389    #[should_panic]
390    fn test_wrong_feature_names_lenght() {
391        let _dataset = Dataset::new(array![[1., 2.], [1., 2.]], array![0., 1.])
392            .with_feature_names(vec!["test"]);
393    }
394
395    #[test]
396    #[should_panic]
397    fn test_wrong_target_names_lenght() {
398        let _dataset = Dataset::new(array![[1., 2.], [1., 2.]], array![0., 1.])
399            .with_target_names(vec!["test", "bad"]);
400    }
401
402    #[test]
403    fn dataset_implements_required_methods() {
404        let mut rng = SmallRng::seed_from_u64(42);
405
406        // ------ Targets ------
407
408        // New
409        let mut dataset = Dataset::new(array![[1., 2.], [1., 2.]], array![0., 1.]);
410
411        // Shuffle
412        dataset = dataset.shuffle(&mut rng);
413
414        // Bootstrap samples
415        {
416            let mut iter = dataset.bootstrap_samples(3, &mut rng);
417            for _ in 1..5 {
418                let b_dataset = iter.next().unwrap();
419                assert_eq!(b_dataset.records().dim().0, 3);
420            }
421        }
422
423        // Bootstrap features
424        {
425            let mut iter = dataset.bootstrap_features(3, &mut rng);
426            for _ in 1..5 {
427                let dataset = iter.next().unwrap();
428                assert_eq!(dataset.records().dim(), (2, 3));
429            }
430        }
431
432        // Bootstrap both
433        {
434            let mut iter = dataset.bootstrap((10, 10), &mut rng);
435            for _ in 1..5 {
436                let dataset = iter.next().unwrap();
437                assert_eq!(dataset.records().dim(), (10, 10));
438            }
439        }
440
441        let linspace: Array1<f64> = Array1::linspace(0.0, 0.8, 100);
442        let records = Array2::from_shape_vec((50, 2), linspace.to_vec()).unwrap();
443        let targets: Array1<f64> = Array1::linspace(0.0, 0.8, 50);
444        let dataset = Dataset::from((records, targets));
445
446        //Split with ratio view
447        let dataset_view = dataset.view();
448        let (train, val) = dataset_view.split_with_ratio(0.5);
449        assert_eq!(train.nsamples(), 25);
450        assert_eq!(val.nsamples(), 25);
451
452        // Split with ratio
453        let (train, val) = dataset.split_with_ratio(0.25);
454        assert_eq!(train.targets().dim(), 13);
455        assert_eq!(val.targets().dim(), 37);
456        assert_eq!(train.records().dim().0, 13);
457        assert_eq!(val.records().dim().0, 37);
458
459        // ------ Labels ------
460        let dataset_multiclass =
461            Dataset::from((array![[1., 2.], [2., 1.], [0., 0.]], array![0usize, 1, 2]));
462
463        // One Vs All
464        let datasets_one_vs_all = dataset_multiclass.one_vs_all().unwrap();
465
466        assert_eq!(datasets_one_vs_all.len(), 3);
467
468        for (_, dataset) in datasets_one_vs_all.iter() {
469            assert_eq!(dataset.labels().iter().filter(|x| **x).count(), 1);
470        }
471
472        let dataset_multiclass = Dataset::from((
473            array![[1., 2.], [2., 1.], [0., 0.], [2., 2.]],
474            array![0, 1, 2, 2],
475        ));
476
477        // Frequencies with mask
478        let freqs = dataset_multiclass.label_frequencies_with_mask(&[true, true, true, true]);
479        assert_eq!(*freqs.get(&0).unwrap() as usize, 1);
480        assert_eq!(*freqs.get(&1).unwrap() as usize, 1);
481        assert_eq!(*freqs.get(&2).unwrap() as usize, 2);
482
483        let freqs = dataset_multiclass.label_frequencies_with_mask(&[true, true, true, false]);
484        assert_eq!(*freqs.get(&0).unwrap() as usize, 1);
485        assert_eq!(*freqs.get(&1).unwrap() as usize, 1);
486        assert_eq!(*freqs.get(&2).unwrap() as usize, 1);
487    }
488
489    #[test]
490    fn dataset_view_implements_required_methods() -> Result<()> {
491        let mut rng = SmallRng::seed_from_u64(42);
492        let records = array![[1., 2.], [1., 2.]];
493        let targets = array![0., 1.];
494
495        // ------ Targets ------
496
497        // New
498        let dataset_view = DatasetView::from((records.view(), targets.view()));
499
500        // Shuffle
501        let _shuffled_owned = dataset_view.shuffle(&mut rng);
502
503        // Bootstrap
504        let mut iter = dataset_view.bootstrap_samples(3, &mut rng);
505        for _ in 1..5 {
506            let b_dataset = iter.next().unwrap();
507            assert_eq!(b_dataset.records().dim().0, 3);
508        }
509
510        let linspace: Array1<f64> = Array1::linspace(0.0, 0.8, 100);
511        let records = Array2::from_shape_vec((50, 2), linspace.to_vec()).unwrap();
512        let targets: Array1<f64> = Array1::linspace(0.0, 0.8, 50);
513        let dataset = Dataset::from((records, targets));
514
515        // view ,Split with ratio view
516        let view: DatasetView<f64, f64, Ix1> = dataset.view();
517
518        let (train, val) = view.split_with_ratio(0.5);
519        assert_eq!(train.targets().len(), 25);
520        assert_eq!(val.targets().len(), 25);
521        assert_eq!(train.nsamples(), 25);
522        assert_eq!(val.nsamples(), 25);
523
524        // ------ Labels ------
525        let dataset_multiclass =
526            Dataset::from((array![[1., 2.], [2., 1.], [0., 0.]], array![0, 1, 2]));
527        let view: DatasetView<f64, usize, Ix1> = dataset_multiclass.view();
528
529        // One Vs All
530        let datasets_one_vs_all = view.one_vs_all()?;
531        assert_eq!(datasets_one_vs_all.len(), 3);
532
533        for (_, dataset) in datasets_one_vs_all.iter() {
534            assert_eq!(dataset.labels().iter().filter(|x| **x).count(), 1);
535        }
536
537        let dataset_multiclass = Dataset::from((
538            array![[1., 2.], [2., 1.], [0., 0.], [2., 2.]],
539            array![0, 1, 2, 2],
540        ));
541
542        let view: DatasetView<f64, usize, Ix1> = dataset_multiclass.view();
543
544        // Frequencies with mask
545        let freqs = view.label_frequencies_with_mask(&[true, true, true, true]);
546        assert_eq!(*freqs.get(&0).unwrap() as usize, 1);
547        assert_eq!(*freqs.get(&1).unwrap() as usize, 1);
548        assert_eq!(*freqs.get(&2).unwrap() as usize, 2);
549
550        let freqs = view.label_frequencies_with_mask(&[true, true, true, false]);
551        assert_eq!(*freqs.get(&0).unwrap() as usize, 1);
552        assert_eq!(*freqs.get(&1).unwrap() as usize, 1);
553        assert_eq!(*freqs.get(&2).unwrap() as usize, 1);
554
555        Ok(())
556    }
557
558    #[test]
559    fn datasets_have_k_fold() {
560        let linspace: Array1<f64> = Array1::linspace(0.0, 0.8, 100);
561        let records = Array2::from_shape_vec((50, 2), linspace.to_vec()).unwrap();
562        let targets: Array1<f64> = Array1::linspace(0.0, 0.8, 50);
563        for (train, val) in DatasetView::from((records.view(), targets.view()))
564            .fold(2)
565            .into_iter()
566        {
567            assert_eq!(train.records().dim(), (25, 2));
568            assert_eq!(val.records().dim(), (25, 2));
569            assert_eq!(train.targets().dim(), 25);
570            assert_eq!(val.targets().dim(), 25);
571        }
572        assert_eq!(Dataset::from((records, targets)).fold(10).len(), 10);
573
574        let records =
575            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
576        let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
577        for (i, (train, val)) in Dataset::from((records, targets))
578            .fold(5)
579            .into_iter()
580            .enumerate()
581        {
582            assert_eq!(val.records.row(0)[0] as usize, (i + 1));
583            assert_eq!(val.records.row(0)[1] as usize, (i + 1));
584            assert_eq!(val.targets[0] as usize, (i + 1));
585
586            for j in 0..4 {
587                assert!(train.records.row(j)[0] as usize != (i + 1));
588                assert!(train.records.row(j)[1] as usize != (i + 1));
589                assert!(train.targets[j] as usize != (i + 1));
590            }
591        }
592    }
593
594    #[test]
595    fn check_iteration() {
596        let dataset = Dataset::new(
597            array![[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]],
598            array![[1, 2], [3, 4], [5, 6]],
599        )
600        .with_target_names(vec!["a", "b"]);
601
602        let res = dataset
603            .target_iter()
604            .map(|x| x.as_targets().remove_axis(Axis(1)).to_owned())
605            .collect::<Vec<_>>();
606
607        assert_eq!(res, &[array![1, 3, 5], array![2, 4, 6]]);
608
609        let mut iter = dataset.target_iter();
610        let first = iter.next();
611        let second = iter.next();
612
613        assert_eq!(vec!["a"], first.unwrap().target_names());
614        assert_eq!(vec!["b"], second.unwrap().target_names());
615
616        let res = dataset
617            .feature_iter()
618            .map(|x| x.records)
619            .collect::<Vec<_>>();
620
621        assert_eq!(
622            res,
623            &[
624                array![[1.], [5.], [9.]],
625                array![[2.], [6.], [10.]],
626                array![[3.], [7.], [11.]],
627                array![[4.], [8.], [12.]],
628            ]
629        );
630
631        let res = dataset
632            .sample_iter()
633            .map(|(a, b)| (a.to_owned(), b.to_owned()))
634            .collect::<Vec<_>>();
635
636        assert_eq!(
637            res,
638            &[
639                (array![1., 2., 3., 4.], array![1, 2]),
640                (array![5., 6., 7., 8.], array![3, 4]),
641                (array![9., 10., 11., 12.], array![5, 6]),
642            ]
643        );
644    }
645
646    use crate::traits::{Fit, PredictInplace};
647    use ndarray::ArrayView2;
648    use thiserror::Error;
649
650    struct MockFittable {
651        mock_var: usize,
652    }
653
654    struct MockFittableResult {
655        mock_var: usize,
656    }
657
658    #[derive(Error, Debug)]
659    enum MockError {
660        #[error(transparent)]
661        LinfaError(#[from] crate::error::Error),
662    }
663
664    type MockResult<T> = std::result::Result<T, MockError>;
665
666    impl<'a> Fit<ArrayView2<'a, f64>, ArrayView1<'a, f64>, MockError> for MockFittable {
667        type Object = MockFittableResult;
668
669        fn fit(
670            &self,
671            training_data: &DatasetView<f64, f64, Ix1>,
672        ) -> std::result::Result<Self::Object, MockError> {
673            if self.mock_var == 0 {
674                Err(MockError::LinfaError(Error::Parameters("0".to_string())))
675            } else {
676                Ok(MockFittableResult {
677                    mock_var: training_data.nsamples(),
678                })
679            }
680        }
681    }
682
683    impl<'a> Fit<ArrayView2<'a, f64>, ArrayView2<'a, f64>, MockError> for MockFittable {
684        type Object = MockFittableResult;
685
686        fn fit(
687            &self,
688            training_data: &DatasetView<f64, f64, Ix2>,
689        ) -> std::result::Result<Self::Object, MockError> {
690            if self.mock_var == 0 {
691                Err(MockError::LinfaError(Error::Parameters("0".to_string())))
692            } else {
693                Ok(MockFittableResult {
694                    mock_var: training_data.nsamples(),
695                })
696            }
697        }
698    }
699
700    impl<'b> PredictInplace<ArrayView2<'b, f64>, Array1<f64>> for MockFittableResult {
701        fn predict_inplace<'a>(&'a self, x: &'a ArrayView2<'b, f64>, y: &mut Array1<f64>) {
702            assert_eq!(
703                x.nrows(),
704                y.len(),
705                "The number of data points must match the number of output targets."
706            );
707            *y = array![0.];
708        }
709
710        fn default_target(&self, x: &ArrayView2<f64>) -> Array1<f64> {
711            Array1::zeros(x.nrows())
712        }
713    }
714
715    impl<'b> PredictInplace<ArrayView2<'b, f64>, Array2<f64>> for MockFittableResult {
716        fn predict_inplace<'a>(&'a self, x: &'a ArrayView2<'b, f64>, y: &mut Array2<f64>) {
717            assert_eq!(
718                y.shape(),
719                &[x.nrows(), 2],
720                "The number of data points must match the number of output targets."
721            );
722            *y = array![[0., 0.]];
723        }
724
725        fn default_target(&self, x: &ArrayView2<f64>) -> Array2<f64> {
726            Array2::zeros((x.nrows(), 2))
727        }
728    }
729
730    #[test]
731    fn test_iter_fold() {
732        let records =
733            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
734        let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
735        let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
736        let params = MockFittable { mock_var: 1 };
737
738        for (i, (model, validation_set)) in
739            dataset.iter_fold(5, |v| params.fit(v).unwrap()).enumerate()
740        {
741            assert_eq!(model.mock_var, 4);
742            assert_eq!(validation_set.records().row(0)[0] as usize, i + 1);
743            assert_eq!(validation_set.records().row(0)[1] as usize, i + 1);
744            assert_eq!(validation_set.targets()[0] as usize, i + 1);
745            assert_eq!(validation_set.records().dim(), (1, 2));
746            assert_eq!(validation_set.targets().dim(), 1);
747        }
748    }
749
750    #[test]
751    fn test_iter_fold_uneven_folds() {
752        let records =
753            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
754        let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
755        let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
756        let params = MockFittable { mock_var: 1 };
757
758        // If we request three folds from a dataset with 5 samples it will cut the
759        // last two samples from the folds and always add them as a tail of the training
760        // data
761        for (i, (model, validation_set)) in
762            dataset.iter_fold(3, |v| params.fit(v).unwrap()).enumerate()
763        {
764            assert_eq!(model.mock_var, 4);
765            assert_eq!(validation_set.records().row(0)[0] as usize, i + 1);
766            assert_eq!(validation_set.records().row(0)[1] as usize, i + 1);
767            assert_eq!(validation_set.targets()[0] as usize, i + 1);
768            assert_eq!(validation_set.records().dim(), (1, 2));
769            assert_eq!(validation_set.targets().dim(), 1);
770            assert!(i < 3);
771        }
772
773        // the same goes for the last sample if we choose 4 folds
774        for (i, (model, validation_set)) in
775            dataset.iter_fold(4, |v| params.fit(v).unwrap()).enumerate()
776        {
777            assert_eq!(model.mock_var, 4);
778            assert_eq!(validation_set.records().row(0)[0] as usize, i + 1);
779            assert_eq!(validation_set.records().row(0)[1] as usize, i + 1);
780            assert_eq!(validation_set.targets()[0] as usize, i + 1);
781            assert_eq!(validation_set.records().dim(), (1, 2));
782            assert_eq!(validation_set.targets().dim(), 1);
783            assert!(i < 4);
784        }
785
786        // if we choose 2 folds then again the last sample will be only
787        // used for trainig
788        for (i, (model, validation_set)) in
789            dataset.iter_fold(2, |v| params.fit(v).unwrap()).enumerate()
790        {
791            assert_eq!(model.mock_var, 3);
792            assert_eq!(validation_set.targets().dim(), 2);
793            assert!(i < 2);
794        }
795    }
796
797    #[test]
798    #[should_panic]
799    fn iter_fold_panics_k_0() {
800        let records =
801            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
802        let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
803        let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
804        let params = MockFittable { mock_var: 1 };
805        let _ = dataset.iter_fold(0, |v| params.fit(v)).enumerate();
806    }
807
808    #[test]
809    #[should_panic]
810    fn iter_fold_panics_k_more_than_samples() {
811        let records =
812            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
813        let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
814        let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
815        let params = MockFittable { mock_var: 1 };
816        let _ = dataset.iter_fold(6, |v| params.fit(v)).enumerate();
817    }
818
819    #[test]
820    fn test_st_cv_all_correct() {
821        let records =
822            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
823        let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
824        let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
825        let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 2 }];
826        let acc = dataset
827            .cross_validate_single(5, &params, |_pred, _truth| Ok(3.))
828            .unwrap();
829        assert_eq!(acc, array![3., 3.]);
830
831        let mut dataset: Dataset<f64, f64> =
832            (array![[1., 1.], [2., 2.]], array![[1., 2.], [3., 4.]]).into();
833
834        let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 2 }];
835        let acc = dataset
836            .cross_validate(2, &params, |_pred, _truth| Ok(array![3., 3.]))
837            .unwrap();
838        assert_eq!(acc, array![[3., 3.], [3., 3.]]);
839    }
840    #[test]
841    #[should_panic(
842        expected = "called `Result::unwrap()` on an `Err` value: LinfaError(Parameters(\"0\"))"
843    )]
844    fn test_st_cv_one_incorrect() {
845        let records =
846            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
847        let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
848        let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
849        // second one should throw an error
850        let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 0 }];
851        let acc: MockResult<Array1<_>> =
852            dataset.cross_validate_single(5, &params, |_pred, _truth| Ok(0.));
853
854        acc.unwrap();
855    }
856
857    #[test]
858    #[should_panic(
859        expected = "called `Result::unwrap()` on an `Err` value: LinfaError(Parameters(\"eval\"))"
860    )]
861    fn test_st_cv_incorrect_eval() {
862        let records =
863            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
864        let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
865        let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
866        // second one should throw an error
867        let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 1 }];
868        let err: MockResult<Array1<_>> =
869            dataset.cross_validate_single(5, &params, |_pred, _truth| {
870                if false {
871                    Ok(0f32)
872                } else {
873                    Err(Error::Parameters("eval".to_string()))
874                }
875            });
876
877        err.unwrap();
878    }
879
880    #[test]
881    fn test_st_cv_mt_all_correct() {
882        let records =
883            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
884        let targets = array![[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]];
885        let mut dataset: Dataset<f64, f64> = (records, targets).into();
886        let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 2 }];
887        let acc = dataset
888            .cross_validate(5, &params, |_pred, _truth| Ok(array![5., 6.]))
889            .unwrap();
890        assert_eq!(acc.dim(), (params.len(), dataset.ntargets()));
891        assert_eq!(acc, array![[5., 6.], [5., 6.]])
892    }
893    #[test]
894    fn test_st_cv_mt_one_incorrect() {
895        let records =
896            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
897        let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
898        let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
899        // second one should throw an error
900        let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 0 }];
901        let err = dataset
902            .cross_validate_single(5, &params, |_pred, _truth| Ok(5.))
903            .unwrap_err();
904        assert_eq!(err.to_string(), "invalid parameter 0".to_string());
905    }
906
907    #[test]
908    fn test_st_cv_mt_incorrect_eval() {
909        let records =
910            Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
911        let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
912        let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
913        // second one should throw an error
914        let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 1 }];
915        let err = dataset
916            .cross_validate_single(5, &params, |_pred, _truth| {
917                if false {
918                    Ok(0f32)
919                } else {
920                    Err(Error::Parameters("eval".to_string()))
921                }
922            })
923            .unwrap_err();
924        assert_eq!(err.to_string(), "invalid parameter eval".to_string());
925    }
926
927    #[test]
928    fn test_with_labels_st() {
929        let records = array![
930            [0., 1.],
931            [1., 2.],
932            [2., 3.],
933            [0., 4.],
934            [1., 5.],
935            [2., 6.],
936            [0., 7.],
937            [1., 8.],
938            [2., 9.],
939            [0., 10.]
940        ];
941        let targets = array![0, 1, 2, 0, 1, 2, 0, 1, 2, 0];
942        let dataset = DatasetBase::from((records, targets));
943        assert_eq!(dataset.nsamples(), 10);
944        assert_eq!(dataset.ntargets(), 1);
945        let dataset_no_0 = dataset.with_labels(&[1, 2]);
946        assert_eq!(dataset_no_0.nsamples(), 6);
947        assert_eq!(dataset_no_0.ntargets(), 1);
948        assert_abs_diff_eq!(
949            dataset_no_0.records,
950            array![[1., 2.], [2., 3.], [1., 5.], [2., 6.], [1., 8.], [2., 9.]]
951        );
952        assert_abs_diff_eq!(dataset_no_0.as_single_targets(), array![1, 2, 1, 2, 1, 2]);
953        let dataset_no_1 = dataset.with_labels(&[0, 2]);
954        assert_eq!(dataset_no_1.nsamples(), 7);
955        assert_eq!(dataset_no_1.ntargets(), 1);
956        assert_abs_diff_eq!(
957            dataset_no_1.records,
958            array![
959                [0., 1.],
960                [2., 3.],
961                [0., 4.],
962                [2., 6.],
963                [0., 7.],
964                [2., 9.],
965                [0., 10.]
966            ]
967        );
968        assert_abs_diff_eq!(
969            dataset_no_1.as_single_targets(),
970            array![0, 2, 0, 2, 0, 2, 0]
971        );
972        let dataset_no_2 = dataset.with_labels(&[0, 1]);
973        assert_eq!(dataset_no_2.nsamples(), 7);
974        assert_eq!(dataset_no_2.ntargets(), 1);
975        assert_abs_diff_eq!(
976            dataset_no_2.records,
977            array![
978                [0., 1.],
979                [1., 2.],
980                [0., 4.],
981                [1., 5.],
982                [0., 7.],
983                [1., 8.],
984                [0., 10.]
985            ]
986        );
987        assert_abs_diff_eq!(
988            dataset_no_2.as_single_targets(),
989            array![0, 1, 0, 1, 0, 1, 0]
990        );
991    }
992
993    #[test]
994    fn test_with_labels_mt() {
995        let records = array![
996            [0., 1.],
997            [1., 2.],
998            [2., 3.],
999            [0., 4.],
1000            [1., 5.],
1001            [2., 6.],
1002            [0., 7.],
1003            [1., 8.],
1004            [2., 9.],
1005            [0., 10.]
1006        ];
1007        let targets = array![
1008            [0, 7],
1009            [1, 8],
1010            [2, 9],
1011            [0, 7],
1012            [1, 8],
1013            [2, 9],
1014            [0, 7],
1015            [1, 8],
1016            [2, 9],
1017            [0, 7]
1018        ];
1019        let dataset = DatasetBase::from((records, targets));
1020        assert_eq!(dataset.nsamples(), 10);
1021        assert_eq!(dataset.ntargets(), 2);
1022        // remove 0 from target 1 and 7 from target 2
1023        let dataset_no_07 = dataset.with_labels(&[1, 2, 8, 9]);
1024        assert_eq!(dataset_no_07.nsamples(), 6);
1025        assert_eq!(dataset_no_07.ntargets(), 2);
1026        assert_abs_diff_eq!(
1027            dataset_no_07.records,
1028            array![[1., 2.], [2., 3.], [1., 5.], [2., 6.], [1., 8.], [2., 9.]]
1029        );
1030        assert_abs_diff_eq!(
1031            dataset_no_07.as_multi_targets(),
1032            array![[1, 8], [2, 9], [1, 8], [2, 9], [1, 8], [2, 9]]
1033        );
1034        // remove label 1 from target 1 and label 7 from target 2: with labels is an "any" so all targets should be kept
1035        let dataset_no_17 = dataset.with_labels(&[0, 2, 8, 9]);
1036        assert_eq!(dataset_no_17.nsamples(), 10);
1037        assert_eq!(dataset_no_17.ntargets(), 2);
1038    }
1039
1040    #[test]
1041    fn correct_probability_creation() {
1042        let prob = 0.5;
1043        assert_abs_diff_eq!(Pr::new(prob).0, prob);
1044    }
1045
1046    #[test]
1047    #[should_panic]
1048    fn negative_probability_panics() {
1049        let prob = -0.5;
1050        Pr::new(prob);
1051    }
1052
1053    #[test]
1054    fn negative_probability_unchecked() {
1055        let prob = -0.5;
1056        assert_abs_diff_eq!(Pr::new_unchecked(prob).0, prob);
1057    }
1058
1059    #[test]
1060    fn test_dataset_shuffle() {
1061        let mut rng = SmallRng::seed_from_u64(42);
1062        let f_names = vec!["f1", "f2", "f3"];
1063        let t_names = vec!["t1"];
1064        let dataset = Dataset::new(
1065            array![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
1066            array![0., 1., 3.],
1067        )
1068        .with_feature_names(f_names.clone())
1069        .with_target_names(t_names.clone());
1070
1071        let shuffled = dataset.shuffle(&mut rng);
1072
1073        assert_abs_diff_ne!(dataset.records(), shuffled.records());
1074        assert_abs_diff_ne!(dataset.targets(), shuffled.targets());
1075        assert_eq!(f_names, shuffled.feature_names());
1076        assert_eq!(t_names, shuffled.target_names());
1077    }
1078}