linfa/dataset/impl_dataset.rs
1use super::{
2 super::traits::{Predict, PredictInplace},
3 iter::{ChunksIter, DatasetIter, Iter},
4 AsSingleTargets, AsTargets, AsTargetsMut, CountedTargets, Dataset, DatasetBase, DatasetView,
5 Float, FromTargetArray, FromTargetArrayOwned, Label, Labels, Records, Result, TargetDim,
6};
7use crate::traits::Fit;
8use ndarray::{concatenate, prelude::*, Data, DataMut, Dimension};
9use rand::{seq::SliceRandom, Rng};
10use std::collections::HashMap;
11use std::ops::AddAssign;
12
13/// Implementation without constraints on records and targets
14///
15/// This implementation block provides methods for the creation and mutation of datasets. This
16/// includes swapping the targets, return the records etc.
17impl<R: Records, S> DatasetBase<R, S> {
18 /// Create a new dataset from records and targets
19 ///
20 /// # Example
21 ///
22 /// ```ignore
23 /// let dataset = Dataset::new(records, targets);
24 /// ```
25 pub fn new(records: R, targets: S) -> DatasetBase<R, S> {
26 let targets = targets;
27
28 DatasetBase {
29 records,
30 targets,
31 weights: Array1::zeros(0),
32 feature_names: Vec::new(),
33 target_names: Vec::new(),
34 }
35 }
36
37 /// Returns reference to targets
38 pub fn targets(&self) -> &S {
39 &self.targets
40 }
41
42 /// Returns optionally weights
43 pub fn weights(&self) -> Option<&[f32]> {
44 if !self.weights.is_empty() {
45 Some(self.weights.as_slice().unwrap())
46 } else {
47 None
48 }
49 }
50
51 /// Return a single weight
52 ///
53 /// The weight of the `idx`th observation is returned. If no weight is specified, then all
54 /// observations are unweighted with default value `1.0`.
55 pub fn weight_for(&self, idx: usize) -> f32 {
56 self.weights.get(idx).copied().unwrap_or(1.0)
57 }
58
59 /// Returns feature names
60 ///
61 /// A feature name gives a human-readable string describing the purpose of a single feature.
62 /// This allow the reader to understand its purpose while analysing results, for example
63 /// correlation analysis or feature importance.
64 pub fn feature_names(&self) -> &[String] {
65 &self.feature_names
66 }
67
68 /// Return records of a dataset
69 ///
70 /// The records are data points from which predictions are made. This functions returns a
71 /// reference to the record field.
72 pub fn records(&self) -> &R {
73 &self.records
74 }
75
76 /// Updates the records of a dataset
77 ///
78 /// This function overwrites the records in a dataset. It also invalidates the weights and
79 /// feature/target names.
80 pub fn with_records<T: Records>(self, records: T) -> DatasetBase<T, S> {
81 DatasetBase {
82 records,
83 targets: self.targets,
84 weights: Array1::zeros(0),
85 feature_names: Vec::new(),
86 target_names: Vec::new(),
87 }
88 }
89
90 /// Updates the targets of a dataset
91 ///
92 /// This function overwrites the targets in a dataset.
93 pub fn with_targets<T>(self, targets: T) -> DatasetBase<R, T> {
94 DatasetBase {
95 records: self.records,
96 targets,
97 weights: self.weights,
98 feature_names: self.feature_names,
99 target_names: self.target_names,
100 }
101 }
102
103 /// Updates the weights of a dataset
104 pub fn with_weights(mut self, weights: Array1<f32>) -> DatasetBase<R, S> {
105 self.weights = weights;
106
107 self
108 }
109
110 /// Updates the feature names of a dataset
111 ///
112 /// **Panics** when given names not empty and length does not equal to the number of features
113 pub fn with_feature_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, S> {
114 assert!(
115 names.is_empty() || names.len() == self.nfeatures(),
116 "Wrong number of feature names"
117 );
118 self.feature_names = names.into_iter().map(|x| x.into()).collect();
119 self
120 }
121}
122
123impl<X, Y> Dataset<X, Y> {
124 // Convert 2D targets to 1D. Only works for targets with shape of form [X, 1], panics otherwise.
125 pub fn into_single_target(self) -> Dataset<X, Y, Ix1> {
126 let nsamples = self.records.nsamples();
127 let targets = self.targets.into_shape_with_order(nsamples).unwrap();
128 let features = self.records;
129 Dataset::new(features, targets)
130 }
131}
132
133impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
134 /// Updates the target names of a dataset
135 ///
136 /// **Panics** when given names not empty and length does not equal to the number of targets
137 pub fn with_target_names<I: Into<String>>(mut self, names: Vec<I>) -> DatasetBase<R, T> {
138 assert!(
139 names.is_empty() || names.len() == self.ntargets(),
140 "Wrong number of target names"
141 );
142 self.target_names = names.into_iter().map(|x| x.into()).collect();
143 self
144 }
145
146 /// Map targets with a function `f`
147 ///
148 /// # Example
149 ///
150 /// ```
151 /// let dataset = linfa_datasets::winequality()
152 /// .map_targets(|x| *x > 6);
153 ///
154 /// // dataset has now boolean targets
155 /// println!("{:?}", dataset.targets());
156 /// ```
157 ///
158 /// # Returns
159 ///
160 /// A modified dataset with new target type.
161 ///
162 pub fn map_targets<S, G: FnMut(&L) -> S>(self, fnc: G) -> DatasetBase<R, Array<S, T::Ix>> {
163 let DatasetBase {
164 records,
165 targets,
166 weights,
167 feature_names,
168 target_names,
169 ..
170 } = self;
171
172 let targets = targets.as_targets();
173
174 DatasetBase {
175 records,
176 targets: targets.map(fnc),
177 weights,
178 feature_names,
179 target_names,
180 }
181 }
182
183 /// Returns target names
184 ///
185 /// A target name gives a human-readable string describing the purpose of a single target.
186 pub fn target_names(&self) -> &[String] {
187 &self.target_names
188 }
189
190 /// Return the number of targets in the dataset
191 ///
192 /// # Example
193 ///
194 /// ```
195 /// let dataset = linfa_datasets::winequality();
196 ///
197 /// println!("#targets {}", dataset.ntargets());
198 /// ```
199 ///
200 pub fn ntargets(&self) -> usize {
201 if T::Ix::NDIM.unwrap() == 1 {
202 1
203 } else {
204 self.targets.as_targets().len_of(Axis(1))
205 }
206 }
207}
208
209impl<'a, F, L, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
210where
211 D: Data<Elem = F>,
212 T: AsTargets<Elem = L>,
213{
214 /// Iterate over observations
215 ///
216 /// This function creates an iterator which produces tuples of data points and target value. The
217 /// iterator runs once for each data point and, while doing so, holds an reference to the owned
218 /// dataset.
219 ///
220 /// For multi-target datasets, the yielded target value is `ArrayView1` consisting of the
221 /// different targets. For single-target datasets, the target value is `ArrayView0` containing
222 /// the single target.
223 ///
224 /// # Example
225 /// ```
226 /// let dataset = linfa_datasets::iris();
227 ///
228 /// for (x, y) in dataset.sample_iter() {
229 /// println!("{} => {}", x, y);
230 /// }
231 /// ```
232 pub fn sample_iter(&'a self) -> Iter<'a, 'a, F, T::Elem, T::Ix> {
233 Iter::new(self.records.view(), self.targets.as_targets())
234 }
235}
236
237impl<'a, F: 'a, L: 'a, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
238where
239 D: Data<Elem = F>,
240 T: AsTargets<Elem = L> + FromTargetArray<'a>,
241 T::View: AsTargets<Elem = L>,
242{
243 /// Creates a view of a dataset
244 pub fn view(&'a self) -> DatasetBase<ArrayView2<'a, F>, T::View> {
245 let records = self.records().view();
246 let targets = T::new_targets_view(self.as_targets());
247
248 DatasetBase::new(records, targets)
249 .with_feature_names(self.feature_names.clone())
250 .with_weights(self.weights.clone())
251 .with_target_names(self.target_names.clone())
252 }
253
254 /// Iterate over features
255 ///
256 /// This iterator produces dataset views with only a single feature, while the set of targets remain
257 /// complete. It can be useful to compare each feature individual to all targets.
258 pub fn feature_iter(&'a self) -> DatasetIter<'a, 'a, ArrayBase<D, Ix2>, T> {
259 DatasetIter::new(self, true)
260 }
261
262 /// Iterate over targets
263 ///
264 /// This functions creates an iterator which produces dataset views complete records, but only
265 /// a single target each. Useful to train multiple single target models for a multi-target
266 /// dataset.
267 pub fn target_iter(&'a self) -> DatasetIter<'a, 'a, ArrayBase<D, Ix2>, T> {
268 DatasetIter::new(self, false)
269 }
270}
271
272impl<L, R: Records, T: AsTargets<Elem = L>> AsTargets for DatasetBase<R, T> {
273 type Elem = L;
274 type Ix = T::Ix;
275
276 fn as_targets(&self) -> ArrayView<'_, Self::Elem, Self::Ix> {
277 self.targets.as_targets()
278 }
279}
280
281impl<L, R: Records, T: AsTargetsMut<Elem = L>> AsTargetsMut for DatasetBase<R, T> {
282 type Elem = L;
283 type Ix = T::Ix;
284
285 fn as_targets_mut(&mut self) -> ArrayViewMut<'_, Self::Elem, Self::Ix> {
286 self.targets.as_targets_mut()
287 }
288}
289
290#[allow(clippy::type_complexity)]
291impl<'a, L: 'a, F, T> DatasetBase<ArrayView2<'a, F>, T>
292where
293 T: AsTargets<Elem = L> + FromTargetArray<'a>,
294 T::View: AsTargets<Elem = L>,
295{
296 /// Split dataset into two disjoint chunks
297 ///
298 /// This function splits the observations in a dataset into two disjoint chunks. The splitting
299 /// threshold is calculated with the `ratio`. For example a ratio of `0.9` allocates 90% to the
300 /// first chunks and 9% to the second. This is often used in training, validation splitting
301 /// procedures.
302 pub fn split_with_ratio(
303 &'a self,
304 ratio: f32,
305 ) -> (
306 DatasetBase<ArrayView2<'a, F>, T::View>,
307 DatasetBase<ArrayView2<'a, F>, T::View>,
308 ) {
309 let n = (self.nsamples() as f32 * ratio).ceil() as usize;
310 let (records_first, records_second) = self.records.view().split_at(Axis(0), n);
311 let (targets_first, targets_second) = self.targets.as_targets().split_at(Axis(0), n);
312
313 let targets_first = T::new_targets_view(targets_first);
314 let targets_second = T::new_targets_view(targets_second);
315
316 let (first_weights, second_weights) = if self.weights.len() == self.nsamples() {
317 let a = self.weights.slice(s![..n]).to_vec();
318 let b = self.weights.slice(s![n..]).to_vec();
319
320 (Array1::from(a), Array1::from(b))
321 } else {
322 (Array1::zeros(0), Array1::zeros(0))
323 };
324 let dataset1 = DatasetBase::new(records_first, targets_first)
325 .with_weights(first_weights)
326 .with_feature_names(self.feature_names.clone())
327 .with_target_names(self.target_names.clone());
328
329 let dataset2 = DatasetBase::new(records_second, targets_second)
330 .with_weights(second_weights)
331 .with_feature_names(self.feature_names.clone())
332 .with_target_names(self.target_names.clone());
333
334 (dataset1, dataset2)
335 }
336}
337
338impl<L: Label, T: Labels<Elem = L>, R: Records> Labels for DatasetBase<R, T> {
339 type Elem = L;
340
341 fn label_count(&self) -> Vec<HashMap<L, usize>> {
342 self.targets().label_count()
343 }
344}
345
346#[allow(clippy::type_complexity)]
347impl<F, L: Label, T, D> DatasetBase<ArrayBase<D, Ix2>, T>
348where
349 D: Data<Elem = F>,
350 T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
351{
352 /// Produce N boolean targets from multi-class targets
353 ///
354 /// Some algorithms (like SVM) don't support multi-class targets. This function splits a
355 /// dataset into multiple binary single-target views of the same dataset.
356 pub fn one_vs_all(
357 &self,
358 ) -> Result<
359 Vec<(
360 L,
361 DatasetBase<ArrayView2<'_, F>, CountedTargets<bool, Array1<bool>>>,
362 )>,
363 > {
364 let targets = self.targets().as_single_targets();
365
366 Ok(self
367 .labels()
368 .into_iter()
369 .map(|label| {
370 let targets = targets.iter().map(|x| x == &label).collect::<Array1<_>>();
371
372 let targets = CountedTargets::new(targets);
373
374 (
375 label,
376 DatasetBase::new(self.records().view(), targets)
377 .with_feature_names(self.feature_names.clone())
378 .with_weights(self.weights.clone())
379 .with_target_names(self.target_names.clone()),
380 )
381 })
382 .collect())
383 }
384}
385
386impl<L: Label, R: Records, S: AsTargets<Elem = L>> DatasetBase<R, S> {
387 /// Calculates label frequencies from a dataset while masking certain samples.
388 ///
389 /// ### Parameters
390 ///
391 /// * `mask`: a boolean array that specifies which samples to include in the count
392 ///
393 /// ### Returns
394 ///
395 /// A mapping of the Dataset's samples to their frequencies
396 pub fn label_frequencies_with_mask(&self, mask: &[bool]) -> HashMap<L, f32> {
397 let mut freqs = HashMap::new();
398
399 for (elms, val) in self
400 .targets
401 .as_targets()
402 .axis_iter(Axis(0))
403 .enumerate()
404 .filter(|(i, _)| *mask.get(*i).unwrap_or(&true))
405 .map(|(i, x)| (x, self.weight_for(i)))
406 {
407 for elm in elms {
408 if !freqs.contains_key(elm) {
409 freqs.insert(elm.clone(), 0.0);
410 }
411
412 *freqs.get_mut(elm).unwrap() += val;
413 }
414 }
415
416 freqs
417 }
418
419 /// Calculates label frequencies from a dataset
420 pub fn label_frequencies(&self) -> HashMap<L, f32> {
421 self.label_frequencies_with_mask(&[])
422 }
423}
424
425impl<F, D: Data<Elem = F>, I: Dimension> From<ArrayBase<D, I>>
426 for DatasetBase<ArrayBase<D, I>, Array1<()>>
427{
428 fn from(records: ArrayBase<D, I>) -> Self {
429 let empty_targets = Array1::default(records.len_of(Axis(0)));
430 DatasetBase {
431 records,
432 targets: empty_targets,
433 weights: Array1::zeros(0),
434 feature_names: Vec::new(),
435 target_names: Vec::new(),
436 }
437 }
438}
439
440impl<F, E, D, S, I: TargetDim> From<(ArrayBase<D, Ix2>, ArrayBase<S, I>)>
441 for DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, I>>
442where
443 D: Data<Elem = F>,
444 S: Data<Elem = E>,
445{
446 fn from(rec_tar: (ArrayBase<D, Ix2>, ArrayBase<S, I>)) -> Self {
447 DatasetBase {
448 records: rec_tar.0,
449 targets: rec_tar.1,
450 weights: Array1::zeros(0),
451 feature_names: Vec::new(),
452 target_names: Vec::new(),
453 }
454 }
455}
456
457impl<'b, F: Clone, E: Copy + 'b, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
458where
459 D: Data<Elem = F>,
460 T: FromTargetArrayOwned<Elem = E>,
461 T::Owned: AsTargets,
462{
463 /// Apply bootstrapping for samples and features
464 ///
465 /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
466 /// stability of machine learning algorithms. It samples data uniformly with replacement and
467 /// generates datasets where elements may be shared. This selects a subset of observations as
468 /// well as features.
469 ///
470 /// # Parameters
471 ///
472 /// * `sample_feature_size`: The number of samples and features per bootstrap
473 /// * `rng`: The random number generator used in the sampling procedure
474 ///
475 /// # Returns
476 ///
477 /// An infinite Iterator yielding at each step a new bootstrapped dataset
478 ///
479 pub fn bootstrap<R: Rng>(
480 &'b self,
481 sample_feature_size: (usize, usize),
482 rng: &'b mut R,
483 ) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
484 std::iter::repeat(()).map(move |_| {
485 // sample with replacement
486 let indices = (0..sample_feature_size.0)
487 .map(|_| rng.gen_range(0..self.nsamples()))
488 .collect::<Vec<_>>();
489
490 let records = self.records().select(Axis(0), &indices);
491 let targets = T::new_targets(self.as_targets().select(Axis(0), &indices));
492
493 let indices = (0..sample_feature_size.1)
494 .map(|_| rng.gen_range(0..self.nfeatures()))
495 .collect::<Vec<_>>();
496
497 let records = records.select(Axis(1), &indices);
498
499 DatasetBase::new(records, targets)
500 })
501 }
502
503 /// Apply bootstrapping for samples and features
504 ///
505 /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
506 /// stability of machine learning algorithms. It samples data uniformly with replacement and
507 /// generates datasets where elements may be shared. This selects a subset of observations as
508 /// well as features.
509 ///
510 /// # Parameters
511 ///
512 /// * `sample_feature_size`: The number of samples and features per bootstrap
513 /// * `rng`: The random number generator used in the sampling procedure
514 ///
515 /// # Returns
516 ///
517 /// An infinite Iterator yielding at each step a tuple containing a bootstrapped dataset with
518 /// a vector of the sampled data indices and sampled feature.
519 ///
520 #[allow(clippy::type_complexity)]
521 pub fn bootstrap_with_indices<R: Rng>(
522 &'b self,
523 sample_feature_size: (usize, usize),
524 rng: &'b mut R,
525 ) -> impl Iterator<Item = (DatasetBase<Array2<F>, T::Owned>, Vec<usize>, Vec<usize>)> + 'b {
526 std::iter::repeat(()).map(move |_| {
527 // sample with replacement
528 let data_indices = (0..sample_feature_size.0)
529 .map(|_| rng.gen_range(0..self.nsamples()))
530 .collect::<Vec<_>>();
531
532 let records = self.records().select(Axis(0), &data_indices);
533 let targets = T::new_targets(self.as_targets().select(Axis(0), &data_indices));
534
535 let feat_indices = (0..sample_feature_size.1)
536 .map(|_| rng.gen_range(0..self.nfeatures()))
537 .collect::<Vec<_>>();
538
539 let records = records.select(Axis(1), &feat_indices);
540
541 (
542 DatasetBase::new(records, targets),
543 data_indices,
544 feat_indices,
545 )
546 })
547 }
548
549 /// Apply sample bootstrapping
550 ///
551 /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
552 /// stability of machine learning algorithms. It samples data uniformly with replacement and
553 /// generates datasets where elements may be shared. Only a sample subset is selected which
554 /// retains all features and targets.
555 ///
556 /// # Parameters
557 ///
558 /// * `num_samples`: The number of samples per bootstrap
559 /// * `rng`: The random number generator used in the sampling procedure
560 ///
561 /// # Returns
562 ///
563 /// An infinite Iterator yielding at each step a new bootstrapped dataset
564 ///
565 pub fn bootstrap_samples<R: Rng>(
566 &'b self,
567 num_samples: usize,
568 rng: &'b mut R,
569 ) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
570 std::iter::repeat(()).map(move |_| {
571 // sample with replacement
572 let indices = (0..num_samples)
573 .map(|_| rng.gen_range(0..self.nsamples()))
574 .collect::<Vec<_>>();
575
576 let records = self.records().select(Axis(0), &indices);
577 let targets = T::new_targets(self.as_targets().select(Axis(0), &indices));
578
579 DatasetBase::new(records, targets)
580 })
581 }
582
583 /// Apply sample bootstrapping
584 ///
585 /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
586 /// stability of machine learning algorithms. It samples data uniformly with replacement and
587 /// generates datasets where elements may be shared. Only a sample subset is selected which
588 /// retains all features and targets.
589 ///
590 /// # Parameters
591 ///
592 /// * `num_samples`: The number of samples per bootstrap
593 /// * `rng`: The random number generator used in the sampling procedure
594 ///
595 /// # Returns
596 ///
597 /// An infinite Iterator yielding at each step a new bootstrapped dataset and the sampled
598 /// indices.
599 ///
600 pub fn bootstrap_samples_with_indices<R: Rng>(
601 &'b self,
602 num_samples: usize,
603 rng: &'b mut R,
604 ) -> impl Iterator<Item = (DatasetBase<Array2<F>, T::Owned>, Vec<usize>)> + 'b {
605 std::iter::repeat(()).map(move |_| {
606 // sample with replacement
607 let indices = (0..num_samples)
608 .map(|_| rng.gen_range(0..self.nsamples()))
609 .collect::<Vec<_>>();
610
611 let records = self.records().select(Axis(0), &indices);
612 let targets = T::new_targets(self.as_targets().select(Axis(0), &indices));
613
614 (DatasetBase::new(records, targets), indices)
615 })
616 }
617
618 /// Apply feature bootstrapping
619 ///
620 /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
621 /// stability of machine learning algorithms. It samples data uniformly with replacement and
622 /// generates datasets where elements may be shared. Only a feature subset is selected while
623 /// retaining all samples and targets.
624 ///
625 /// # Parameters
626 ///
627 /// * `num_features`: The number of features per bootstrap
628 /// * `rng`: The random number generator used in the sampling procedure
629 ///
630 /// # Returns
631 ///
632 /// An infinite Iterator yielding at each step a new bootstrapped dataset
633 ///
634 pub fn bootstrap_features<R: Rng>(
635 &'b self,
636 num_features: usize,
637 rng: &'b mut R,
638 ) -> impl Iterator<Item = DatasetBase<Array2<F>, T::Owned>> + 'b {
639 std::iter::repeat(()).map(move |_| {
640 let targets = T::new_targets(self.as_targets().to_owned());
641
642 let indices = (0..num_features)
643 .map(|_| rng.gen_range(0..self.nfeatures()))
644 .collect::<Vec<_>>();
645
646 let records = self.records.select(Axis(1), &indices);
647
648 DatasetBase::new(records, targets)
649 })
650 }
651
652 /// Apply feature bootstrapping
653 ///
654 /// Bootstrap aggregating is used for sub-sample generation and improves the accuracy and
655 /// stability of machine learning algorithms. It samples data uniformly with replacement and
656 /// generates datasets where elements may be shared. Only a feature subset is selected while
657 /// retaining all samples and targets.
658 ///
659 /// # Parameters
660 ///
661 /// * `num_features`: The number of features per bootstrap
662 /// * `rng`: The random number generator used in the sampling procedure
663 ///
664 /// # Returns
665 ///
666 /// An infinite Iterator yielding at each step a new bootstrapped dataset with the indices of
667 /// the features sampled
668 ///
669 pub fn bootstrap_features_with_indices<R: Rng>(
670 &'b self,
671 num_features: usize,
672 rng: &'b mut R,
673 ) -> impl Iterator<Item = (DatasetBase<Array2<F>, T::Owned>, Vec<usize>)> + 'b {
674 std::iter::repeat(()).map(move |_| {
675 let targets = T::new_targets(self.as_targets().to_owned());
676
677 let indices = (0..num_features)
678 .map(|_| rng.gen_range(0..self.nfeatures()))
679 .collect::<Vec<_>>();
680
681 let records = self.records.select(Axis(1), &indices);
682
683 (DatasetBase::new(records, targets), indices)
684 })
685 }
686
687 /// Produces a shuffled version of the current Dataset.
688 ///
689 /// ### Parameters
690 ///
691 /// * `rng`: the random number generator that will be used to shuffle the samples
692 ///
693 /// ### Returns
694 ///
695 /// A new shuffled version of the current Dataset
696 ///
697 pub fn shuffle<R: Rng>(&self, rng: &mut R) -> DatasetBase<Array2<F>, T::Owned> {
698 let mut indices = (0..self.nsamples()).collect::<Vec<_>>();
699 indices.shuffle(rng);
700
701 let records = self.records().select(Axis(0), &indices);
702 let targets = self.as_targets().select(Axis(0), &indices);
703 let targets = T::new_targets(targets);
704
705 DatasetBase::new(records, targets)
706 .with_feature_names(self.feature_names().to_vec())
707 .with_target_names(self.target_names().to_vec())
708 }
709
710 #[allow(clippy::type_complexity)]
711 /// Performs K-folding on the dataset.
712 ///
713 /// The dataset is divided into `k` "folds", each containing `(dataset size)/k` samples, used
714 /// to generate `k` training-validation dataset pairs. Each pair contains a validation
715 /// `Dataset` with `k` samples, the ones contained in the i-th fold, and a training `Dataset`
716 /// composed by the union of all the samples in the remaining folds.
717 ///
718 /// ### Parameters
719 ///
720 /// * `k`: the number of folds to apply
721 ///
722 /// ### Returns
723 ///
724 /// A vector of `k` training-validation Dataset pairs.
725 ///
726 /// ### Example
727 ///
728 /// ```rust
729 /// use linfa::dataset::DatasetView;
730 /// use ndarray::{Ix1, array};
731 ///
732 /// let records = array![[1.,1.], [2.,1.], [3.,2.], [4.,1.],[5., 3.], [6.,2.]];
733 /// let targets = array![1, 1, 0, 1, 0, 0];
734 ///
735 /// let dataset : DatasetView<f64, usize, Ix1> = (records.view(), targets.view()).into();
736 /// let accuracies = dataset.fold(3).into_iter().map(|(train, valid)| {
737 /// // Here you can train your model and perform validation
738 ///
739 /// // let model = params.fit(&dataset);
740 /// // let predi = model.predict(&valid);
741 /// // predi.confusion_matrix(&valid).accuracy()
742 /// });
743 /// ```
744 ///
745 pub fn fold(
746 &self,
747 k: usize,
748 ) -> Vec<(
749 DatasetBase<Array2<F>, T::Owned>,
750 DatasetBase<Array2<F>, T::Owned>,
751 )> {
752 let targets = self.as_targets();
753 let fold_size = targets.len() / k;
754
755 // Generates all k folds of records and targets
756 let mut records_chunks: Vec<_> =
757 self.records.axis_chunks_iter(Axis(0), fold_size).collect();
758 let mut targets_chunks: Vec<_> = targets.axis_chunks_iter(Axis(0), fold_size).collect();
759
760 let mut res = Vec::with_capacity(k);
761 // For each iteration, take the first chunk for both records and targets as the validation set and
762 // concatenate all the other chunks to create the training set. In the end swap the first chunk with the
763 // one in the next index so that it is ready for the next iteration
764 for i in 0..k {
765 let remaining_records = concatenate(Axis(0), &records_chunks.as_slice()[1..]).unwrap();
766 let remaining_targets = concatenate(Axis(0), &targets_chunks.as_slice()[1..]).unwrap();
767
768 res.push((
769 // training
770 DatasetBase::new(remaining_records, T::new_targets(remaining_targets)),
771 // validation
772 DatasetBase::new(
773 records_chunks[0].into_owned(),
774 T::new_targets(targets_chunks[0].clone().into_owned()),
775 ),
776 ));
777
778 // swap
779 if i < k - 1 {
780 records_chunks.swap(0, i + 1);
781 targets_chunks.swap(0, i + 1);
782 }
783 }
784 res
785 }
786
787 pub fn sample_chunks<'a: 'b>(&'b self, chunk_size: usize) -> ChunksIter<'b, 'a, F, T> {
788 ChunksIter::new(self.records().view(), &self.targets, chunk_size, Axis(0))
789 }
790
791 pub fn to_owned(&self) -> DatasetBase<Array2<F>, T::Owned> {
792 DatasetBase::new(
793 self.records().to_owned(),
794 T::new_targets(self.as_targets().to_owned()),
795 )
796 }
797}
798
// Swaps the first fold of a flat, row-major buffer with the fold at position `$index`.
//
// A fold spans `$fold_size * $features` consecutive elements. Nothing happens for index `0`,
// as the first fold would be swapped with itself.
macro_rules! assist_swap_array2 {
    ($slice: expr, $index: expr, $fold_size: expr, $features: expr) => {
        if $index != 0 {
            let fold_len = $fold_size * $features;
            let fold_start = fold_len * $index;
            // `head` ends right before the fold to swap, `tail` starts with it
            let (head, tail) = $slice.split_at_mut(fold_start);
            head[..fold_len].swap_with_slice(&mut tail[..fold_len]);
        }
    };
}
810
811impl<'a, F: 'a + Clone, E: Copy + 'a, D, S, I: TargetDim>
812 DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, I>>
813where
814 D: DataMut<Elem = F>,
815 S: DataMut<Elem = E>,
816{
817 /// Performs k-folding cross validation on fittable algorithms.
818 ///
819 /// Given a dataset as input, a value of k and the desired params for the fittable
820 /// algorithm, returns an iterator over the k trained models and the
821 /// associated validation set.
822 ///
823 /// The models are trained according to a closure specified
824 /// as an input.
825 ///
826 /// ## Parameters
827 ///
828 /// - `k`: the number of folds to apply to the dataset
829 /// - `params`: the desired parameters for the fittable algorithm at hand
830 /// - `fit_closure`: a closure of the type `(params, training_data) -> fitted_model`
831 /// that will be used to produce the trained model for each fold. The training data given in input
832 /// won't outlive the closure.
833 ///
834 /// ## Returns
835 ///
836 /// An iterator over couples `(trained_model, validation_set)`.
837 ///
838 /// ## Panics
839 ///
840 /// This method will panic for any of the following three reasons:
841 ///
842 /// - The value of `k` provided is not positive;
843 /// - The value of `k` provided is greater than the total number of samples in the dataset;
844 /// - The dataset's data is not stored contiguously and in standard order;
845 ///
846 /// ## Example
847 /// ```rust
848 /// use linfa::traits::Fit;
849 /// use linfa::dataset::{Dataset, DatasetView, Records};
850 /// use ndarray::{array, ArrayView1, ArrayView2, Ix1};
851 /// use linfa::Error;
852 ///
853 /// struct MockFittable {}
854 ///
855 /// struct MockFittableResult {
856 /// mock_var: usize,
857 /// }
858 ///
859 /// impl<'a> Fit<ArrayView2<'a,f64>, ArrayView1<'a, f64>, linfa::error::Error> for MockFittable {
860 /// type Object = MockFittableResult;
861 ///
862 /// fn fit(&self, training_data: &DatasetView<f64, f64, Ix1>) -> Result<Self::Object, linfa::error::Error> {
863 /// Ok(MockFittableResult {
864 /// mock_var: training_data.nsamples(),
865 /// })
866 /// }
867 /// }
868 ///
869 /// let records = array![[1.,1.], [2.,2.], [3.,3.], [4.,4.], [5.,5.]];
870 /// let targets = array![1.,2.,3.,4.,5.];
871 /// let mut dataset: Dataset<f64, f64, Ix1> = (records, targets).into();
872 /// let params = MockFittable {};
873 ///
874 /// for (model,validation_set) in dataset.iter_fold(5, |v| params.fit(v).unwrap()){
875 /// // Here you can use `model` and `validation_set` to
876 /// // assert the performance of the chosen algorithm
877 /// }
878 /// ```
879 pub fn iter_fold<O, C: Fn(&DatasetView<F, E, I>) -> O>(
880 &'a mut self,
881 k: usize,
882 fit_closure: C,
883 ) -> impl Iterator<Item = (O, DatasetBase<ArrayView2<'a, F>, ArrayView<'a, E, I>>)> {
884 assert!(k > 0);
885 assert!(k <= self.nsamples());
886 let samples_count = self.nsamples();
887 let fold_size = samples_count / k;
888
889 let features = self.nfeatures();
890 let targets = self.ntargets();
891 let tshape = self.targets.raw_dim();
892
893 let mut objs: Vec<O> = Vec::with_capacity(k);
894
895 {
896 let records_sl = self.records.as_slice_mut().unwrap();
897 let mut targets_sl2 = self.targets.as_targets_mut();
898 let targets_sl = targets_sl2.as_slice_mut().unwrap();
899
900 for i in 0..k {
901 assist_swap_array2!(records_sl, i, fold_size, features);
902 assist_swap_array2!(targets_sl, i, fold_size, targets);
903
904 {
905 let train = DatasetBase::new(
906 ArrayView2::from_shape(
907 (samples_count - fold_size, features),
908 records_sl.split_at(fold_size * features).1,
909 )
910 .unwrap(),
911 ArrayView::from_shape(
912 tshape.clone().nsamples(samples_count - fold_size),
913 targets_sl.split_at(fold_size * targets).1,
914 )
915 .unwrap(),
916 );
917
918 let obj = fit_closure(&train);
919 objs.push(obj);
920 }
921
922 assist_swap_array2!(records_sl, i, fold_size, features);
923 assist_swap_array2!(targets_sl, i, fold_size, targets);
924 }
925 }
926
927 objs.into_iter().zip(self.sample_chunks(fold_size))
928 }
929
930 /// Cross validation for single and multi-target algorithms
931 ///
932 /// Given a list of fittable models, cross validation is used to compare their performance
933 /// according to some performance metric. To do so, k-folding is applied to the dataset and,
934 /// for each fold, each model is trained on the training set and its performance is evaluated
935 /// on the validation set. The performances collected for each model are then averaged over the
936 /// number of folds.
937 ///
938 /// For single-target datasets, [`Dataset::cross_validate_single`] is recommended.
939 ///
940 /// ### Parameters:
941 ///
942 /// - `k`: the number of folds to apply
943 /// - `parameters`: a list of models to compare
944 /// - `eval`: closure used to evaluate the performance of each trained model. This closure is
945 /// called on the model output and validation targets of each fold and outputs the performance
946 /// score for each target. For single-target dataset the signature is `(Array1, Array1) ->
947 /// Array0`. For multi-target dataset the signature is `(Array2, Array2) -> Array1`.
948 ///
949 /// ### Returns
950 ///
951 /// An array of model performances, for each model and each target, if no errors occur.
952 /// For multi-target dataset, the array has dimensions `(n_models, n_targets)`. For
953 /// single-target dataset, the array has dimensions `(n_models)`.
954 /// Otherwise, it might return an Error in one of the following cases:
955 ///
956 /// - An error occurred during the fitting of one model
957 /// - An error occurred inside the evaluation closure
958 ///
959 /// ### Example
960 ///
961 /// ```rust, ignore
962 ///
963 /// use linfa::prelude::*;
964 /// use ndarray::arr0;
965 /// # use ndarray::{array, ArrayView1, ArrayView2, Ix1};
966 ///
967 /// # struct MockFittable {}
968 ///
969 /// # struct MockFittableResult {
970 /// # mock_var: usize,
971 /// # }
972 ///
973 /// # impl<'a> Fit<ArrayView2<'a,f64>, ArrayView1<'a, f64>, linfa::error::Error> for MockFittable {
974 /// # type Object = MockFittableResult;
975 ///
976 /// # fn fit(&self, training_data: &DatasetView<f64, f64, Ix1>) -> Result<Self::Object, linfa::error::Error> {
977 /// # Ok(MockFittableResult {
978 /// # mock_var: training_data.nsamples(),
979 /// # })
980 /// # }
981 /// # }
982 ///
983 /// # let model1 = MockFittable {};
984 /// # let model2 = MockFittable {};
985 ///
986 /// // mutability needed for fast cross validation
987 /// let mut dataset = linfa_datasets::diabetes();
988 ///
989 /// let models = vec![model1, model2];
990 ///
991 /// let r2_scores = dataset.cross_validate(5, &models, |prediction, truth| prediction.r2(truth).map(arr0))?;
992 ///
993 /// ```
994 pub fn cross_validate<O, ER, M, FACC, C>(
995 &'a mut self,
996 k: usize,
997 parameters: &[M],
998 eval: C,
999 ) -> std::result::Result<Array<FACC, I>, ER>
1000 where
1001 ER: std::error::Error + std::convert::From<crate::error::Error>,
1002 M: for<'c> Fit<ArrayView2<'c, F>, ArrayView<'c, E, I>, ER, Object = O>,
1003 O: for<'d> PredictInplace<ArrayView2<'a, F>, Array<E, I>>,
1004 FACC: Float,
1005 C: Fn(
1006 &Array<E, I>,
1007 &ArrayView<E, I>,
1008 ) -> std::result::Result<Array<FACC, I::Smaller>, crate::error::Error>,
1009 {
1010 let mut evaluations = Array::from_elem(
1011 self.targets.raw_dim().nsamples(parameters.len()),
1012 FACC::zero(),
1013 );
1014 let folds_evaluations: std::result::Result<Vec<_>, ER> = self
1015 .iter_fold(k, |train| {
1016 let fit_result: std::result::Result<Vec<_>, ER> =
1017 parameters.iter().map(|p| p.fit(train)).collect();
1018 fit_result
1019 })
1020 .map(|(models, valid)| {
1021 let targets = valid.targets();
1022 let models = models?;
1023 // XXX diverges from master branch
1024 let mut eval_predictions =
1025 Array::from_elem(targets.raw_dim().nsamples(models.len()), FACC::zero());
1026 for (i, model) in models.iter().enumerate() {
1027 let predicted = model.predict(valid.records());
1028 let eval_pred = match eval(&predicted, targets) {
1029 Err(e) => Err(ER::from(e)),
1030 Ok(res) => Ok(res),
1031 }?;
1032 eval_predictions
1033 .index_axis_mut(Axis(0), i)
1034 .add_assign(&eval_pred);
1035 }
1036 Ok(eval_predictions)
1037 })
1038 .collect();
1039
1040 for fold_evaluation in folds_evaluations? {
1041 evaluations.add_assign(&fold_evaluation)
1042 }
1043 Ok(evaluations / FACC::from(k).unwrap())
1044 }
1045}
1046
1047impl<'a, F: 'a + Clone, E: Copy + 'a, D, S> DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, Ix1>>
1048where
1049 D: DataMut<Elem = F>,
1050 S: DataMut<Elem = E>,
1051{
1052 /// Specialized version of `cross_validate` for single-target datasets. Allows the evaluation
1053 /// closure to return a float without wrapping it in `arr0`. See [`Dataset::cross_validate`] for
1054 /// more details.
1055 pub fn cross_validate_single<O, ER, M, FACC, C>(
1056 &'a mut self,
1057 k: usize,
1058 parameters: &[M],
1059 eval: C,
1060 ) -> std::result::Result<Array1<FACC>, ER>
1061 where
1062 ER: std::error::Error + std::convert::From<crate::error::Error>,
1063 M: for<'c> Fit<ArrayView2<'c, F>, ArrayView1<'c, E>, ER, Object = O>,
1064 O: for<'d> PredictInplace<ArrayView2<'a, F>, Array1<E>>,
1065 FACC: Float,
1066 C: Fn(&Array1<E>, &ArrayView1<E>) -> std::result::Result<FACC, crate::error::Error>,
1067 {
1068 self.cross_validate(k, parameters, |a, b| eval(a, b).map(arr0))
1069 }
1070}
1071
1072impl<F, E, I: TargetDim> Dataset<F, E, I> {
1073 /// Split dataset into two disjoint chunks
1074 ///
1075 /// This function splits the observations in a dataset into two disjoint chunks. The splitting
1076 /// threshold is calculated with the `ratio`. If the input Dataset contains `n` samples then the
1077 /// two new Datasets will have respectively `n * ratio` and `n - (n*ratio)` samples.
1078 /// For example a ratio of `0.9` allocates 90% to the
1079 /// first chunks and 10% to the second. This is often used in training, validation splitting
1080 /// procedures.
1081 ///
1082 /// ### Parameters
1083 ///
1084 /// * `ratio`: the ratio of samples in the input Dataset to include in the first output one
1085 ///
1086 /// ### Returns
1087 ///
1088 /// The input Dataset split into two according to the input ratio.
1089 ///
1090 /// ### Panics
1091 ///
1092 /// Panic occurs when the input record or targets are not in row-major layout.
1093 pub fn split_with_ratio(mut self, ratio: f32) -> (Self, Self) {
1094 assert!(
1095 self.records.is_standard_layout(),
1096 "records not in row-major layout"
1097 );
1098 assert!(
1099 self.targets.is_standard_layout(),
1100 "targets not in row-major layout"
1101 );
1102
1103 let nfeatures = self.nfeatures();
1104
1105 let n1 = (self.nsamples() as f32 * ratio).ceil() as usize;
1106 let n2 = self.nsamples() - n1;
1107
1108 let feature_names = self.feature_names().to_vec();
1109 let target_names = self.target_names().to_vec();
1110
1111 // split records into two disjoint arrays
1112 let (mut array_buf, _) = self.records.into_raw_vec_and_offset();
1113 let second_array_buf = array_buf.split_off(n1 * nfeatures);
1114
1115 let first = Array2::from_shape_vec((n1, nfeatures), array_buf).unwrap();
1116 let second = Array2::from_shape_vec((n2, nfeatures), second_array_buf).unwrap();
1117
1118 // split targets into two disjoint Vec
1119 let dim1 = self.targets.raw_dim().nsamples(n1);
1120 let dim2 = self.targets.raw_dim().nsamples(n2);
1121 let (mut array_buf, _) = self.targets.into_raw_vec_and_offset();
1122 let second_array_buf = array_buf.split_off(dim1.size());
1123
1124 let first_targets = Array::from_shape_vec(dim1, array_buf).unwrap();
1125 let second_targets = Array::from_shape_vec(dim2, second_array_buf).unwrap();
1126
1127 // split weights into two disjoint Vec
1128 let second_weights = if self.weights.len() == n1 + n2 {
1129 let (mut weights, _) = self.weights.into_raw_vec_and_offset();
1130
1131 let weights2 = weights.split_off(n1);
1132 self.weights = Array1::from(weights);
1133
1134 Array1::from(weights2)
1135 } else {
1136 Array1::zeros(0)
1137 };
1138
1139 // create new datasets with attached weights
1140 let dataset1 = Dataset::new(first, first_targets)
1141 .with_weights(self.weights)
1142 .with_feature_names(feature_names.clone())
1143 .with_target_names(target_names.clone());
1144 let dataset2 = Dataset::new(second, second_targets)
1145 .with_weights(second_weights)
1146 .with_feature_names(feature_names.clone())
1147 .with_target_names(target_names.clone());
1148
1149 (dataset1, dataset2)
1150 }
1151}
1152
1153impl<F, D, E, T, O> Predict<ArrayBase<D, Ix2>, DatasetBase<ArrayBase<D, Ix2>, T>> for O
1154where
1155 D: Data<Elem = F>,
1156 T: AsTargets<Elem = E>,
1157 O: PredictInplace<ArrayBase<D, Ix2>, T>,
1158{
1159 fn predict(&self, records: ArrayBase<D, Ix2>) -> DatasetBase<ArrayBase<D, Ix2>, T> {
1160 let mut targets = self.default_target(&records);
1161 self.predict_inplace(&records, &mut targets);
1162 DatasetBase::new(records, targets)
1163 }
1164}
1165
1166impl<F, R, T, E, S, O> Predict<DatasetBase<R, T>, DatasetBase<R, S>> for O
1167where
1168 R: Records<Elem = F>,
1169 S: AsTargets<Elem = E>,
1170 O: PredictInplace<R, S>,
1171{
1172 fn predict(&self, ds: DatasetBase<R, T>) -> DatasetBase<R, S> {
1173 let mut targets = self.default_target(&ds.records);
1174 self.predict_inplace(&ds.records, &mut targets);
1175 DatasetBase::new(ds.records, targets)
1176 }
1177}
1178
1179impl<'a, F, R, T, S, O> Predict<&'a DatasetBase<R, T>, S> for O
1180where
1181 R: Records<Elem = F>,
1182 O: PredictInplace<R, S>,
1183{
1184 fn predict(&self, ds: &'a DatasetBase<R, T>) -> S {
1185 let mut targets = self.default_target(&ds.records);
1186 self.predict_inplace(&ds.records, &mut targets);
1187 targets
1188 }
1189}
1190
1191impl<'a, F, D, DM, T, O> Predict<&'a ArrayBase<D, DM>, T> for O
1192where
1193 D: Data<Elem = F>,
1194 DM: Dimension,
1195 O: PredictInplace<ArrayBase<D, DM>, T>,
1196{
1197 fn predict(&self, records: &'a ArrayBase<D, DM>) -> T {
1198 let mut targets = self.default_target(records);
1199 self.predict_inplace(records, &mut targets);
1200 targets
1201 }
1202}
1203
1204impl<L: Label, S: Labels<Elem = L>> CountedTargets<L, S> {
1205 pub fn new(targets: S) -> Self {
1206 let labels = targets.label_count();
1207
1208 CountedTargets { targets, labels }
1209 }
1210}