1use ndarray::{
6 Array, Array1, ArrayBase, ArrayView, ArrayView1, ArrayView2, ArrayViewMut, ArrayViewMut1,
7 ArrayViewMut2, CowArray, Ix1, Ix2, Ix3, NdFloat, OwnedRepr, RemoveAxis, ScalarOperand,
8};
9
10#[cfg(feature = "ndarray-linalg")]
11use ndarray_linalg::{Lapack, Scalar};
12
13use num_traits::{AsPrimitive, FromPrimitive, NumCast, Signed};
14use rand::distributions::uniform::SampleUniform;
15
16use std::cmp::{Ordering, PartialOrd};
17use std::collections::{HashMap, HashSet};
18use std::convert::{TryFrom, TryInto};
19use std::fmt;
20use std::hash::Hash;
21use std::iter::Sum;
22use std::ops::{AddAssign, Deref, DivAssign, MulAssign, SubAssign};
23
24use crate::error::Result;
25
26mod impl_dataset;
27mod impl_records;
28mod impl_targets;
29
30mod iter;
31
32mod lapack_bounds;
33pub use lapack_bounds::*;
34
35pub trait Float:
41 NdFloat
42 + FromPrimitive
43 + Default
44 + Signed
45 + Sum
46 + AsPrimitive<usize>
47 + for<'a> AddAssign<&'a Self>
48 + for<'a> MulAssign<&'a Self>
49 + for<'a> SubAssign<&'a Self>
50 + for<'a> DivAssign<&'a Self>
51 + num_traits::MulAdd<Output = Self>
52 + SampleUniform
53 + ScalarOperand
54 + approx::AbsDiffEq
55 + std::marker::Unpin
56 + sprs::MulAcc
57{
58 #[cfg(feature = "ndarray-linalg")]
59 type Lapack: Float + Scalar + Lapack;
60 #[cfg(not(feature = "ndarray-linalg"))]
61 type Lapack: Float;
62
63 fn cast<T: NumCast>(x: T) -> Self {
64 NumCast::from(x).unwrap()
65 }
66}
67
68impl Float for f32 {
69 type Lapack = f32;
70}
71
72impl Float for f64 {
73 type Lapack = f64;
74}
75
/// Marker trait for types usable as discrete labels.
///
/// A label must be comparable, totally ordered, hashable, cloneable, printable with `Debug`
/// and default-constructible. (`PartialEq` is implied by `Eq`.)
pub trait Label: Eq + Ord + Hash + Clone + Default + fmt::Debug {}
81
82impl Label for bool {}
83impl Label for usize {}
84impl Label for String {}
85impl Label for () {}
86impl Label for &str {}
87impl<L: Label> Label for Option<L> {}
88
/// Probability value stored as an `f32`.
///
/// [`Pr::new`] and the `TryFrom<f32>` conversion only accept values inside `[0, 1]`;
/// [`Pr::new_unchecked`] performs no validation at all.
#[repr(transparent)]
#[derive(Debug, Copy, Clone, Default)]
pub struct Pr(f32);

impl TryFrom<f32> for Pr {
    /// A rejected input is handed back unchanged.
    type Error = f32;

    /// Wraps `prob` when it lies in `[0, 1]`, otherwise returns it as the error.
    fn try_from(prob: f32) -> std::result::Result<Self, Self::Error> {
        // NaN is never contained in a range, so it is rejected here as well.
        if !(0f32..=1f32).contains(&prob) {
            return Err(prob);
        }
        Ok(Self(prob))
    }
}

impl Pr {
    /// Creates a probability from `prob`.
    ///
    /// # Panics
    ///
    /// Panics when `prob` is rejected by the `TryFrom<f32>` conversion.
    pub fn new(prob: f32) -> Self {
        let checked: std::result::Result<Self, f32> = prob.try_into();
        checked.unwrap()
    }

    /// Wraps `prob` without checking that it is a valid probability.
    pub fn new_unchecked(prob: f32) -> Self {
        Self(prob)
    }

    /// Returns the probability `0.5`.
    pub fn even() -> Pr {
        Self(0.5)
    }
}

impl PartialEq for Pr {
    /// Two probabilities are equal when their wrapped floats are equal.
    fn eq(&self, other: &Self) -> bool {
        let Pr(lhs) = *self;
        let Pr(rhs) = *other;
        lhs == rhs
    }
}

impl PartialOrd for Pr {
    /// Orders probabilities by their wrapped floats.
    fn partial_cmp(&self, other: &Pr) -> Option<Ordering> {
        PartialOrd::partial_cmp(&self.0, &other.0)
    }
}

impl Deref for Pr {
    type Target = f32;

    /// Gives read access to the wrapped float.
    fn deref(&self) -> &f32 {
        let Pr(inner) = self;
        inner
    }
}
152
153#[derive(Debug, Clone, PartialEq)]
176pub struct DatasetBase<R, T>
177where
178 R: Records,
179{
180 pub records: R,
181 pub targets: T,
182
183 pub weights: Array1<f32>,
184 feature_names: Vec<String>,
185 target_names: Vec<String>,
186}
187
188#[derive(Debug, Clone, PartialEq, Eq)]
199pub struct CountedTargets<L: Label, P> {
200 targets: P,
201 labels: Vec<HashMap<L, usize>>,
202}
203
204pub type Dataset<D, T, I = Ix2> =
210 DatasetBase<ArrayBase<OwnedRepr<D>, Ix2>, ArrayBase<OwnedRepr<T>, I>>;
211
212pub type DatasetView<'a, D, T, I = Ix2> = DatasetBase<ArrayView<'a, D, Ix2>, ArrayView<'a, T, I>>;
216
217pub type DatasetPr<D, L> =
223 DatasetBase<ArrayBase<OwnedRepr<D>, Ix2>, CountedTargets<L, ArrayBase<OwnedRepr<Pr>, Ix3>>>;
224
/// A collection of records.
pub trait Records: Sized {
    /// Element type of the records.
    type Elem;

    /// Number of samples in the collection.
    fn nsamples(&self) -> usize;
    /// Number of features in the collection.
    fn nfeatures(&self) -> usize;
}
232
233pub trait TargetDim: RemoveAxis {
234 fn nsamples(mut self, nsamples: usize) -> Self {
235 self.as_array_view_mut()[0] = nsamples;
236 self
237 }
238}
239
/// Read-only access to targets as an array view.
pub trait AsTargets {
    /// Element type of the targets.
    type Elem;
    /// Dimensionality of the target array.
    type Ix: TargetDim;

    /// Returns a view of the targets.
    fn as_targets(&self) -> ArrayView<'_, Self::Elem, Self::Ix>;
}
250
251pub trait AsSingleTargets: AsTargets<Ix = Ix1> {
253 fn as_single_targets(&self) -> ArrayView1<'_, Self::Elem> {
254 self.as_targets()
255 }
256}
257
258pub trait AsMultiTargets: AsTargets<Ix = Ix2> {
260 fn as_multi_targets(&self) -> ArrayView2<'_, Self::Elem> {
261 self.as_targets()
262 }
263}
264
/// Construction of an owned target container from an owned array.
pub trait FromTargetArrayOwned: AsTargets {
    /// The owned container produced by `new_targets`.
    type Owned;

    /// Builds the owned container from `targets`.
    fn new_targets(targets: Array<Self::Elem, Self::Ix>) -> Self::Owned;
}
271
/// Construction of a borrowed target container from an array view with lifetime `'a`.
pub trait FromTargetArray<'a>: AsTargets {
    /// The borrowed container produced by `new_targets_view`.
    type View;

    /// Builds the borrowed container from `targets`.
    fn new_targets_view(targets: ArrayView<'a, Self::Elem, Self::Ix>) -> Self::View;
}
283
/// Mutable access to targets as an array view.
pub trait AsTargetsMut {
    /// Element type of the targets.
    type Elem;
    /// Dimensionality of the target array.
    type Ix: TargetDim;

    /// Returns a mutable view of the targets.
    fn as_targets_mut(&mut self) -> ArrayViewMut<'_, Self::Elem, Self::Ix>;
}
294
295pub trait AsSingleTargetsMut: AsTargetsMut<Ix = Ix1> {
297 fn as_single_targets_mut(&mut self) -> ArrayViewMut1<'_, Self::Elem> {
298 self.as_targets_mut()
299 }
300}
301
302pub trait AsMultiTargetsMut: AsTargetsMut<Ix = Ix2> {
304 fn as_multi_targets_mut(&mut self) -> ArrayViewMut2<'_, Self::Elem> {
305 self.as_targets_mut()
306 }
307}
308
/// Access to targets expressed as probabilities.
pub trait AsProbabilities {
    /// Returns a three-dimensional array of [`Pr`], either borrowed or owned.
    fn as_multi_target_probabilities(&self) -> CowArray<'_, Pr, Ix3>;
}
316
317pub trait Labels {
320 type Elem: Label;
321
322 fn label_count(&self) -> Vec<HashMap<Self::Elem, usize>>;
323
324 fn label_set(&self) -> Vec<HashSet<Self::Elem>> {
325 self.label_count()
326 .iter()
327 .map(|x| x.keys().cloned().collect::<HashSet<_>>())
328 .collect()
329 }
330
331 fn labels(&self) -> Vec<Self::Elem> {
332 self.label_set()
333 .into_iter()
334 .flatten()
335 .collect::<HashSet<_>>()
336 .into_iter()
337 .collect()
338 }
339
340 fn combined_labels<T>(&self, other: &T) -> Vec<Self::Elem>
341 where
342 T: Labels<Elem = <Self as Labels>::Elem>,
343 {
344 let mut combined = self.label_set();
345 combined.extend(other.label_set());
346
347 combined
348 .iter()
349 .flatten()
350 .collect::<HashSet<_>>()
351 .into_iter()
352 .cloned()
353 .collect()
354 }
355}
356
357#[cfg(test)]
358mod tests {
359 use super::*;
360 use crate::error::Error;
361 use approx::{assert_abs_diff_eq, assert_abs_diff_ne};
362 use linfa_datasets::generate::make_dataset;
363 use ndarray::{array, Array1, Array2, Axis};
364 use rand::{rngs::SmallRng, SeedableRng};
365 use statrs::distribution::{DiscreteUniform, Laplace};
366
367 #[test]
368 fn into_single_target() {
369 let feat_distr = Laplace::new(0.5, 5.).unwrap();
370 let target_distr = DiscreteUniform::new(0, 5).unwrap();
371 let dataset = make_dataset(10, 5, 1, feat_distr, target_distr);
372 assert!(dataset.into_single_target().targets.shape() == [10]);
373 }
374
375 #[test]
376 fn set_target_name() {
377 let dataset = Dataset::new(array![[1., 2.], [1., 2.]], array![0., 1.])
378 .with_target_names(vec!["test"]);
379 assert_eq!(dataset.target_names, vec!["test"]);
380 }
381
382 #[test]
383 fn empty_target_name() {
384 let dataset = Dataset::new(array![[1., 2.], [1., 2.]], array![[0., 1.], [2., 3.]]);
385 assert_eq!(dataset.target_names, Vec::<String>::new());
386 }
387
388 #[test]
389 #[should_panic]
390 fn test_wrong_feature_names_lenght() {
391 let _dataset = Dataset::new(array![[1., 2.], [1., 2.]], array![0., 1.])
392 .with_feature_names(vec!["test"]);
393 }
394
395 #[test]
396 #[should_panic]
397 fn test_wrong_target_names_lenght() {
398 let _dataset = Dataset::new(array![[1., 2.], [1., 2.]], array![0., 1.])
399 .with_target_names(vec!["test", "bad"]);
400 }
401
402 #[test]
403 fn dataset_implements_required_methods() {
404 let mut rng = SmallRng::seed_from_u64(42);
405
406 let mut dataset = Dataset::new(array![[1., 2.], [1., 2.]], array![0., 1.]);
410
411 dataset = dataset.shuffle(&mut rng);
413
414 {
416 let mut iter = dataset.bootstrap_samples(3, &mut rng);
417 for _ in 1..5 {
418 let b_dataset = iter.next().unwrap();
419 assert_eq!(b_dataset.records().dim().0, 3);
420 }
421 }
422
423 {
425 let mut iter = dataset.bootstrap_features(3, &mut rng);
426 for _ in 1..5 {
427 let dataset = iter.next().unwrap();
428 assert_eq!(dataset.records().dim(), (2, 3));
429 }
430 }
431
432 {
434 let mut iter = dataset.bootstrap((10, 10), &mut rng);
435 for _ in 1..5 {
436 let dataset = iter.next().unwrap();
437 assert_eq!(dataset.records().dim(), (10, 10));
438 }
439 }
440
441 let linspace: Array1<f64> = Array1::linspace(0.0, 0.8, 100);
442 let records = Array2::from_shape_vec((50, 2), linspace.to_vec()).unwrap();
443 let targets: Array1<f64> = Array1::linspace(0.0, 0.8, 50);
444 let dataset = Dataset::from((records, targets));
445
446 let dataset_view = dataset.view();
448 let (train, val) = dataset_view.split_with_ratio(0.5);
449 assert_eq!(train.nsamples(), 25);
450 assert_eq!(val.nsamples(), 25);
451
452 let (train, val) = dataset.split_with_ratio(0.25);
454 assert_eq!(train.targets().dim(), 13);
455 assert_eq!(val.targets().dim(), 37);
456 assert_eq!(train.records().dim().0, 13);
457 assert_eq!(val.records().dim().0, 37);
458
459 let dataset_multiclass =
461 Dataset::from((array![[1., 2.], [2., 1.], [0., 0.]], array![0usize, 1, 2]));
462
463 let datasets_one_vs_all = dataset_multiclass.one_vs_all().unwrap();
465
466 assert_eq!(datasets_one_vs_all.len(), 3);
467
468 for (_, dataset) in datasets_one_vs_all.iter() {
469 assert_eq!(dataset.labels().iter().filter(|x| **x).count(), 1);
470 }
471
472 let dataset_multiclass = Dataset::from((
473 array![[1., 2.], [2., 1.], [0., 0.], [2., 2.]],
474 array![0, 1, 2, 2],
475 ));
476
477 let freqs = dataset_multiclass.label_frequencies_with_mask(&[true, true, true, true]);
479 assert_eq!(*freqs.get(&0).unwrap() as usize, 1);
480 assert_eq!(*freqs.get(&1).unwrap() as usize, 1);
481 assert_eq!(*freqs.get(&2).unwrap() as usize, 2);
482
483 let freqs = dataset_multiclass.label_frequencies_with_mask(&[true, true, true, false]);
484 assert_eq!(*freqs.get(&0).unwrap() as usize, 1);
485 assert_eq!(*freqs.get(&1).unwrap() as usize, 1);
486 assert_eq!(*freqs.get(&2).unwrap() as usize, 1);
487 }
488
489 #[test]
490 fn dataset_view_implements_required_methods() -> Result<()> {
491 let mut rng = SmallRng::seed_from_u64(42);
492 let records = array![[1., 2.], [1., 2.]];
493 let targets = array![0., 1.];
494
495 let dataset_view = DatasetView::from((records.view(), targets.view()));
499
500 let _shuffled_owned = dataset_view.shuffle(&mut rng);
502
503 let mut iter = dataset_view.bootstrap_samples(3, &mut rng);
505 for _ in 1..5 {
506 let b_dataset = iter.next().unwrap();
507 assert_eq!(b_dataset.records().dim().0, 3);
508 }
509
510 let linspace: Array1<f64> = Array1::linspace(0.0, 0.8, 100);
511 let records = Array2::from_shape_vec((50, 2), linspace.to_vec()).unwrap();
512 let targets: Array1<f64> = Array1::linspace(0.0, 0.8, 50);
513 let dataset = Dataset::from((records, targets));
514
515 let view: DatasetView<f64, f64, Ix1> = dataset.view();
517
518 let (train, val) = view.split_with_ratio(0.5);
519 assert_eq!(train.targets().len(), 25);
520 assert_eq!(val.targets().len(), 25);
521 assert_eq!(train.nsamples(), 25);
522 assert_eq!(val.nsamples(), 25);
523
524 let dataset_multiclass =
526 Dataset::from((array![[1., 2.], [2., 1.], [0., 0.]], array![0, 1, 2]));
527 let view: DatasetView<f64, usize, Ix1> = dataset_multiclass.view();
528
529 let datasets_one_vs_all = view.one_vs_all()?;
531 assert_eq!(datasets_one_vs_all.len(), 3);
532
533 for (_, dataset) in datasets_one_vs_all.iter() {
534 assert_eq!(dataset.labels().iter().filter(|x| **x).count(), 1);
535 }
536
537 let dataset_multiclass = Dataset::from((
538 array![[1., 2.], [2., 1.], [0., 0.], [2., 2.]],
539 array![0, 1, 2, 2],
540 ));
541
542 let view: DatasetView<f64, usize, Ix1> = dataset_multiclass.view();
543
544 let freqs = view.label_frequencies_with_mask(&[true, true, true, true]);
546 assert_eq!(*freqs.get(&0).unwrap() as usize, 1);
547 assert_eq!(*freqs.get(&1).unwrap() as usize, 1);
548 assert_eq!(*freqs.get(&2).unwrap() as usize, 2);
549
550 let freqs = view.label_frequencies_with_mask(&[true, true, true, false]);
551 assert_eq!(*freqs.get(&0).unwrap() as usize, 1);
552 assert_eq!(*freqs.get(&1).unwrap() as usize, 1);
553 assert_eq!(*freqs.get(&2).unwrap() as usize, 1);
554
555 Ok(())
556 }
557
558 #[test]
559 fn datasets_have_k_fold() {
560 let linspace: Array1<f64> = Array1::linspace(0.0, 0.8, 100);
561 let records = Array2::from_shape_vec((50, 2), linspace.to_vec()).unwrap();
562 let targets: Array1<f64> = Array1::linspace(0.0, 0.8, 50);
563 for (train, val) in DatasetView::from((records.view(), targets.view()))
564 .fold(2)
565 .into_iter()
566 {
567 assert_eq!(train.records().dim(), (25, 2));
568 assert_eq!(val.records().dim(), (25, 2));
569 assert_eq!(train.targets().dim(), 25);
570 assert_eq!(val.targets().dim(), 25);
571 }
572 assert_eq!(Dataset::from((records, targets)).fold(10).len(), 10);
573
574 let records =
575 Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
576 let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
577 for (i, (train, val)) in Dataset::from((records, targets))
578 .fold(5)
579 .into_iter()
580 .enumerate()
581 {
582 assert_eq!(val.records.row(0)[0] as usize, (i + 1));
583 assert_eq!(val.records.row(0)[1] as usize, (i + 1));
584 assert_eq!(val.targets[0] as usize, (i + 1));
585
586 for j in 0..4 {
587 assert!(train.records.row(j)[0] as usize != (i + 1));
588 assert!(train.records.row(j)[1] as usize != (i + 1));
589 assert!(train.targets[j] as usize != (i + 1));
590 }
591 }
592 }
593
594 #[test]
595 fn check_iteration() {
596 let dataset = Dataset::new(
597 array![[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]],
598 array![[1, 2], [3, 4], [5, 6]],
599 )
600 .with_target_names(vec!["a", "b"]);
601
602 let res = dataset
603 .target_iter()
604 .map(|x| x.as_targets().remove_axis(Axis(1)).to_owned())
605 .collect::<Vec<_>>();
606
607 assert_eq!(res, &[array![1, 3, 5], array![2, 4, 6]]);
608
609 let mut iter = dataset.target_iter();
610 let first = iter.next();
611 let second = iter.next();
612
613 assert_eq!(vec!["a"], first.unwrap().target_names());
614 assert_eq!(vec!["b"], second.unwrap().target_names());
615
616 let res = dataset
617 .feature_iter()
618 .map(|x| x.records)
619 .collect::<Vec<_>>();
620
621 assert_eq!(
622 res,
623 &[
624 array![[1.], [5.], [9.]],
625 array![[2.], [6.], [10.]],
626 array![[3.], [7.], [11.]],
627 array![[4.], [8.], [12.]],
628 ]
629 );
630
631 let res = dataset
632 .sample_iter()
633 .map(|(a, b)| (a.to_owned(), b.to_owned()))
634 .collect::<Vec<_>>();
635
636 assert_eq!(
637 res,
638 &[
639 (array![1., 2., 3., 4.], array![1, 2]),
640 (array![5., 6., 7., 8.], array![3, 4]),
641 (array![9., 10., 11., 12.], array![5, 6]),
642 ]
643 );
644 }
645
646 use crate::traits::{Fit, PredictInplace};
647 use ndarray::ArrayView2;
648 use thiserror::Error;
649
    /// Mock parameter set; the `Fit` impls below fail when `mock_var` is `0`.
    struct MockFittable {
        mock_var: usize,
    }

    /// Mock fitted model; the `Fit` impls below store the number of training samples here.
    struct MockFittableResult {
        mock_var: usize,
    }

    /// Error type of the mock model, wrapping the crate error transparently.
    #[derive(Error, Debug)]
    enum MockError {
        #[error(transparent)]
        LinfaError(#[from] crate::error::Error),
    }

    /// Result alias using [`MockError`].
    type MockResult<T> = std::result::Result<T, MockError>;
665
666 impl<'a> Fit<ArrayView2<'a, f64>, ArrayView1<'a, f64>, MockError> for MockFittable {
667 type Object = MockFittableResult;
668
669 fn fit(
670 &self,
671 training_data: &DatasetView<f64, f64, Ix1>,
672 ) -> std::result::Result<Self::Object, MockError> {
673 if self.mock_var == 0 {
674 Err(MockError::LinfaError(Error::Parameters("0".to_string())))
675 } else {
676 Ok(MockFittableResult {
677 mock_var: training_data.nsamples(),
678 })
679 }
680 }
681 }
682
683 impl<'a> Fit<ArrayView2<'a, f64>, ArrayView2<'a, f64>, MockError> for MockFittable {
684 type Object = MockFittableResult;
685
686 fn fit(
687 &self,
688 training_data: &DatasetView<f64, f64, Ix2>,
689 ) -> std::result::Result<Self::Object, MockError> {
690 if self.mock_var == 0 {
691 Err(MockError::LinfaError(Error::Parameters("0".to_string())))
692 } else {
693 Ok(MockFittableResult {
694 mock_var: training_data.nsamples(),
695 })
696 }
697 }
698 }
699
700 impl<'b> PredictInplace<ArrayView2<'b, f64>, Array1<f64>> for MockFittableResult {
701 fn predict_inplace<'a>(&'a self, x: &'a ArrayView2<'b, f64>, y: &mut Array1<f64>) {
702 assert_eq!(
703 x.nrows(),
704 y.len(),
705 "The number of data points must match the number of output targets."
706 );
707 *y = array![0.];
708 }
709
710 fn default_target(&self, x: &ArrayView2<f64>) -> Array1<f64> {
711 Array1::zeros(x.nrows())
712 }
713 }
714
715 impl<'b> PredictInplace<ArrayView2<'b, f64>, Array2<f64>> for MockFittableResult {
716 fn predict_inplace<'a>(&'a self, x: &'a ArrayView2<'b, f64>, y: &mut Array2<f64>) {
717 assert_eq!(
718 y.shape(),
719 &[x.nrows(), 2],
720 "The number of data points must match the number of output targets."
721 );
722 *y = array![[0., 0.]];
723 }
724
725 fn default_target(&self, x: &ArrayView2<f64>) -> Array2<f64> {
726 Array2::zeros((x.nrows(), 2))
727 }
728 }
729
730 #[test]
731 fn test_iter_fold() {
732 let records =
733 Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
734 let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
735 let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
736 let params = MockFittable { mock_var: 1 };
737
738 for (i, (model, validation_set)) in
739 dataset.iter_fold(5, |v| params.fit(v).unwrap()).enumerate()
740 {
741 assert_eq!(model.mock_var, 4);
742 assert_eq!(validation_set.records().row(0)[0] as usize, i + 1);
743 assert_eq!(validation_set.records().row(0)[1] as usize, i + 1);
744 assert_eq!(validation_set.targets()[0] as usize, i + 1);
745 assert_eq!(validation_set.records().dim(), (1, 2));
746 assert_eq!(validation_set.targets().dim(), 1);
747 }
748 }
749
750 #[test]
751 fn test_iter_fold_uneven_folds() {
752 let records =
753 Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
754 let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
755 let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
756 let params = MockFittable { mock_var: 1 };
757
758 for (i, (model, validation_set)) in
762 dataset.iter_fold(3, |v| params.fit(v).unwrap()).enumerate()
763 {
764 assert_eq!(model.mock_var, 4);
765 assert_eq!(validation_set.records().row(0)[0] as usize, i + 1);
766 assert_eq!(validation_set.records().row(0)[1] as usize, i + 1);
767 assert_eq!(validation_set.targets()[0] as usize, i + 1);
768 assert_eq!(validation_set.records().dim(), (1, 2));
769 assert_eq!(validation_set.targets().dim(), 1);
770 assert!(i < 3);
771 }
772
773 for (i, (model, validation_set)) in
775 dataset.iter_fold(4, |v| params.fit(v).unwrap()).enumerate()
776 {
777 assert_eq!(model.mock_var, 4);
778 assert_eq!(validation_set.records().row(0)[0] as usize, i + 1);
779 assert_eq!(validation_set.records().row(0)[1] as usize, i + 1);
780 assert_eq!(validation_set.targets()[0] as usize, i + 1);
781 assert_eq!(validation_set.records().dim(), (1, 2));
782 assert_eq!(validation_set.targets().dim(), 1);
783 assert!(i < 4);
784 }
785
786 for (i, (model, validation_set)) in
789 dataset.iter_fold(2, |v| params.fit(v).unwrap()).enumerate()
790 {
791 assert_eq!(model.mock_var, 3);
792 assert_eq!(validation_set.targets().dim(), 2);
793 assert!(i < 2);
794 }
795 }
796
797 #[test]
798 #[should_panic]
799 fn iter_fold_panics_k_0() {
800 let records =
801 Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
802 let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
803 let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
804 let params = MockFittable { mock_var: 1 };
805 let _ = dataset.iter_fold(0, |v| params.fit(v)).enumerate();
806 }
807
808 #[test]
809 #[should_panic]
810 fn iter_fold_panics_k_more_than_samples() {
811 let records =
812 Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
813 let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
814 let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
815 let params = MockFittable { mock_var: 1 };
816 let _ = dataset.iter_fold(6, |v| params.fit(v)).enumerate();
817 }
818
819 #[test]
820 fn test_st_cv_all_correct() {
821 let records =
822 Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
823 let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
824 let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
825 let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 2 }];
826 let acc = dataset
827 .cross_validate_single(5, ¶ms, |_pred, _truth| Ok(3.))
828 .unwrap();
829 assert_eq!(acc, array![3., 3.]);
830
831 let mut dataset: Dataset<f64, f64> =
832 (array![[1., 1.], [2., 2.]], array![[1., 2.], [3., 4.]]).into();
833
834 let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 2 }];
835 let acc = dataset
836 .cross_validate(2, ¶ms, |_pred, _truth| Ok(array![3., 3.]))
837 .unwrap();
838 assert_eq!(acc, array![[3., 3.], [3., 3.]]);
839 }
840 #[test]
841 #[should_panic(
842 expected = "called `Result::unwrap()` on an `Err` value: LinfaError(Parameters(\"0\"))"
843 )]
844 fn test_st_cv_one_incorrect() {
845 let records =
846 Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
847 let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
848 let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
849 let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 0 }];
851 let acc: MockResult<Array1<_>> =
852 dataset.cross_validate_single(5, ¶ms, |_pred, _truth| Ok(0.));
853
854 acc.unwrap();
855 }
856
857 #[test]
858 #[should_panic(
859 expected = "called `Result::unwrap()` on an `Err` value: LinfaError(Parameters(\"eval\"))"
860 )]
861 fn test_st_cv_incorrect_eval() {
862 let records =
863 Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
864 let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
865 let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
866 let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 1 }];
868 let err: MockResult<Array1<_>> =
869 dataset.cross_validate_single(5, ¶ms, |_pred, _truth| {
870 if false {
871 Ok(0f32)
872 } else {
873 Err(Error::Parameters("eval".to_string()))
874 }
875 });
876
877 err.unwrap();
878 }
879
880 #[test]
881 fn test_st_cv_mt_all_correct() {
882 let records =
883 Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
884 let targets = array![[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]];
885 let mut dataset: Dataset<f64, f64> = (records, targets).into();
886 let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 2 }];
887 let acc = dataset
888 .cross_validate(5, ¶ms, |_pred, _truth| Ok(array![5., 6.]))
889 .unwrap();
890 assert_eq!(acc.dim(), (params.len(), dataset.ntargets()));
891 assert_eq!(acc, array![[5., 6.], [5., 6.]])
892 }
893 #[test]
894 fn test_st_cv_mt_one_incorrect() {
895 let records =
896 Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
897 let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
898 let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
899 let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 0 }];
901 let err = dataset
902 .cross_validate_single(5, ¶ms, |_pred, _truth| Ok(5.))
903 .unwrap_err();
904 assert_eq!(err.to_string(), "invalid parameter 0".to_string());
905 }
906
907 #[test]
908 fn test_st_cv_mt_incorrect_eval() {
909 let records =
910 Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
911 let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
912 let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
913 let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 1 }];
915 let err = dataset
916 .cross_validate_single(5, ¶ms, |_pred, _truth| {
917 if false {
918 Ok(0f32)
919 } else {
920 Err(Error::Parameters("eval".to_string()))
921 }
922 })
923 .unwrap_err();
924 assert_eq!(err.to_string(), "invalid parameter eval".to_string());
925 }
926
927 #[test]
928 fn test_with_labels_st() {
929 let records = array![
930 [0., 1.],
931 [1., 2.],
932 [2., 3.],
933 [0., 4.],
934 [1., 5.],
935 [2., 6.],
936 [0., 7.],
937 [1., 8.],
938 [2., 9.],
939 [0., 10.]
940 ];
941 let targets = array![0, 1, 2, 0, 1, 2, 0, 1, 2, 0];
942 let dataset = DatasetBase::from((records, targets));
943 assert_eq!(dataset.nsamples(), 10);
944 assert_eq!(dataset.ntargets(), 1);
945 let dataset_no_0 = dataset.with_labels(&[1, 2]);
946 assert_eq!(dataset_no_0.nsamples(), 6);
947 assert_eq!(dataset_no_0.ntargets(), 1);
948 assert_abs_diff_eq!(
949 dataset_no_0.records,
950 array![[1., 2.], [2., 3.], [1., 5.], [2., 6.], [1., 8.], [2., 9.]]
951 );
952 assert_abs_diff_eq!(dataset_no_0.as_single_targets(), array![1, 2, 1, 2, 1, 2]);
953 let dataset_no_1 = dataset.with_labels(&[0, 2]);
954 assert_eq!(dataset_no_1.nsamples(), 7);
955 assert_eq!(dataset_no_1.ntargets(), 1);
956 assert_abs_diff_eq!(
957 dataset_no_1.records,
958 array![
959 [0., 1.],
960 [2., 3.],
961 [0., 4.],
962 [2., 6.],
963 [0., 7.],
964 [2., 9.],
965 [0., 10.]
966 ]
967 );
968 assert_abs_diff_eq!(
969 dataset_no_1.as_single_targets(),
970 array![0, 2, 0, 2, 0, 2, 0]
971 );
972 let dataset_no_2 = dataset.with_labels(&[0, 1]);
973 assert_eq!(dataset_no_2.nsamples(), 7);
974 assert_eq!(dataset_no_2.ntargets(), 1);
975 assert_abs_diff_eq!(
976 dataset_no_2.records,
977 array![
978 [0., 1.],
979 [1., 2.],
980 [0., 4.],
981 [1., 5.],
982 [0., 7.],
983 [1., 8.],
984 [0., 10.]
985 ]
986 );
987 assert_abs_diff_eq!(
988 dataset_no_2.as_single_targets(),
989 array![0, 1, 0, 1, 0, 1, 0]
990 );
991 }
992
993 #[test]
994 fn test_with_labels_mt() {
995 let records = array![
996 [0., 1.],
997 [1., 2.],
998 [2., 3.],
999 [0., 4.],
1000 [1., 5.],
1001 [2., 6.],
1002 [0., 7.],
1003 [1., 8.],
1004 [2., 9.],
1005 [0., 10.]
1006 ];
1007 let targets = array![
1008 [0, 7],
1009 [1, 8],
1010 [2, 9],
1011 [0, 7],
1012 [1, 8],
1013 [2, 9],
1014 [0, 7],
1015 [1, 8],
1016 [2, 9],
1017 [0, 7]
1018 ];
1019 let dataset = DatasetBase::from((records, targets));
1020 assert_eq!(dataset.nsamples(), 10);
1021 assert_eq!(dataset.ntargets(), 2);
1022 let dataset_no_07 = dataset.with_labels(&[1, 2, 8, 9]);
1024 assert_eq!(dataset_no_07.nsamples(), 6);
1025 assert_eq!(dataset_no_07.ntargets(), 2);
1026 assert_abs_diff_eq!(
1027 dataset_no_07.records,
1028 array![[1., 2.], [2., 3.], [1., 5.], [2., 6.], [1., 8.], [2., 9.]]
1029 );
1030 assert_abs_diff_eq!(
1031 dataset_no_07.as_multi_targets(),
1032 array![[1, 8], [2, 9], [1, 8], [2, 9], [1, 8], [2, 9]]
1033 );
1034 let dataset_no_17 = dataset.with_labels(&[0, 2, 8, 9]);
1036 assert_eq!(dataset_no_17.nsamples(), 10);
1037 assert_eq!(dataset_no_17.ntargets(), 2);
1038 }
1039
1040 #[test]
1041 fn correct_probability_creation() {
1042 let prob = 0.5;
1043 assert_abs_diff_eq!(Pr::new(prob).0, prob);
1044 }
1045
1046 #[test]
1047 #[should_panic]
1048 fn negative_probability_panics() {
1049 let prob = -0.5;
1050 Pr::new(prob);
1051 }
1052
1053 #[test]
1054 fn negative_probability_unchecked() {
1055 let prob = -0.5;
1056 assert_abs_diff_eq!(Pr::new_unchecked(prob).0, prob);
1057 }
1058
1059 #[test]
1060 fn test_dataset_shuffle() {
1061 let mut rng = SmallRng::seed_from_u64(42);
1062 let f_names = vec!["f1", "f2", "f3"];
1063 let t_names = vec!["t1"];
1064 let dataset = Dataset::new(
1065 array![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
1066 array![0., 1., 3.],
1067 )
1068 .with_feature_names(f_names.clone())
1069 .with_target_names(t_names.clone());
1070
1071 let shuffled = dataset.shuffle(&mut rng);
1072
1073 assert_abs_diff_ne!(dataset.records(), shuffled.records());
1074 assert_abs_diff_ne!(dataset.targets(), shuffled.targets());
1075 assert_eq!(f_names, shuffled.feature_names());
1076 assert_eq!(t_names, shuffled.target_names());
1077 }
1078}