math/
sample.rs

1//! # Blanket implementations for online sampling algorithms
2
3use rand::distributions::{Distribution, Uniform};
4
5use crate::{
6    set::traits::Finite,
7    traits::{Collecting, ToIterator},
8};
9
10pub mod trait_impl;
11
12pub trait Sample<'a, I: Iterator<Item = E>, E, O: Collecting<E> + Default>:
13    Finite + ToIterator<'a, I, E> {
14    /// samples `size` elements without replacement
15    /// `size`: the number of samples to be drawn
16    /// returns Err if `size` is larger than the population size
17    fn sample_subset_without_replacement<'s: 'a>(
18        &'s self,
19        size: usize,
20    ) -> Result<O, String> {
21        let mut remaining = self.size();
22        if size > remaining {
23            return Err(format!(
24                "desired sample size {} > population size {}",
25                size, remaining
26            ));
27        }
28        let mut samples = O::default();
29        let mut needed = size;
30        let mut rng = rand::thread_rng();
31        let uniform = Uniform::new(0., 1.);
32
33        for element in self.to_iter() {
34            if uniform.sample(&mut rng) <= (needed as f64 / remaining as f64) {
35                samples.collect(element);
36                needed -= 1;
37            }
38            remaining -= 1;
39        }
40        Ok(samples)
41    }
42
43    fn sample_with_replacement<'s: 'a>(
44        &'s self,
45        size: usize,
46    ) -> Result<O, String> {
47        let population_size = self.size();
48        if population_size == 0 {
49            return Err(
50                "cannot sample from a population of 0 elements".to_string()
51            );
52        }
53        let mut samples = O::default();
54        let mut rng = rand::thread_rng();
55        let uniform = Uniform::new(0., population_size as f64);
56        for _ in 0..size {
57            samples.collect(
58                self.to_iter()
59                    .nth(uniform.sample(&mut rng) as usize)
60                    .unwrap(),
61            );
62        }
63        Ok(samples)
64    }
65}
66
#[cfg(test)]
mod tests {
    use crate::set::{
        contiguous_integer_set::ContiguousIntegerSet,
        ordered_integer_set::OrderedIntegerSet, traits::Finite,
    };

    use super::Sample;

    /// Typical sample sizes yield exactly the requested number of elements.
    #[test]
    fn test_sampling_without_replacement() {
        let interval = ContiguousIntegerSet::new(0, 100);
        let num_samples = 25;
        let samples = interval
            .sample_subset_without_replacement(num_samples)
            .unwrap();
        assert_eq!(samples.size(), num_samples);

        let set =
            OrderedIntegerSet::from_slice(&[[-89, -23], [-2, 100], [300, 345]]);
        let num_samples = 18;
        let samples =
            set.sample_subset_without_replacement(num_samples).unwrap();
        assert_eq!(samples.size(), num_samples);
    }

    /// Boundary sizes: empty sample, whole population, and oversized request.
    #[test]
    fn test_sampling_without_replacement_boundaries() {
        let interval = ContiguousIntegerSet::new(0, 100);
        let population = interval.size();

        let empty = interval.sample_subset_without_replacement(0).unwrap();
        assert_eq!(empty.size(), 0);

        // Requesting the whole population must select every element.
        let all = interval
            .sample_subset_without_replacement(population)
            .unwrap();
        assert_eq!(all.size(), population);

        assert!(interval
            .sample_subset_without_replacement(population + 1)
            .is_err());
    }

    /// Sampling with replacement repeats elements and rejects empty populations.
    #[test]
    fn test_sampling_with_replacement() {
        let num_samples = 25;
        let v = vec![1];
        let samples = v.sample_with_replacement(num_samples);
        assert_eq!(samples, Ok(vec![1; num_samples]));
        assert_eq!(v.sample_with_replacement(0), Ok(Vec::new()));
        assert!(Vec::<f32>::new()
            .sample_with_replacement(num_samples)
            .is_err());
    }
}