1use rand::distributions::{Distribution, Uniform};
4
5use crate::{
6 set::traits::Finite,
7 traits::{Collecting, ToIterator},
8};
9
10pub mod trait_impl;
11
12pub trait Sample<'a, I: Iterator<Item = E>, E, O: Collecting<E> + Default>:
13 Finite + ToIterator<'a, I, E> {
14 fn sample_subset_without_replacement<'s: 'a>(
18 &'s self,
19 size: usize,
20 ) -> Result<O, String> {
21 let mut remaining = self.size();
22 if size > remaining {
23 return Err(format!(
24 "desired sample size {} > population size {}",
25 size, remaining
26 ));
27 }
28 let mut samples = O::default();
29 let mut needed = size;
30 let mut rng = rand::thread_rng();
31 let uniform = Uniform::new(0., 1.);
32
33 for element in self.to_iter() {
34 if uniform.sample(&mut rng) <= (needed as f64 / remaining as f64) {
35 samples.collect(element);
36 needed -= 1;
37 }
38 remaining -= 1;
39 }
40 Ok(samples)
41 }
42
43 fn sample_with_replacement<'s: 'a>(
44 &'s self,
45 size: usize,
46 ) -> Result<O, String> {
47 let population_size = self.size();
48 if population_size == 0 {
49 return Err(
50 "cannot sample from a population of 0 elements".to_string()
51 );
52 }
53 let mut samples = O::default();
54 let mut rng = rand::thread_rng();
55 let uniform = Uniform::new(0., population_size as f64);
56 for _ in 0..size {
57 samples.collect(
58 self.to_iter()
59 .nth(uniform.sample(&mut rng) as usize)
60 .unwrap(),
61 );
62 }
63 Ok(samples)
64 }
65}
66
#[cfg(test)]
mod tests {
    use crate::set::{
        contiguous_integer_set::ContiguousIntegerSet,
        ordered_integer_set::OrderedIntegerSet, traits::Finite,
    };

    use super::Sample;

    #[test]
    fn test_sampling_without_replacement() {
        let interval = ContiguousIntegerSet::new(0, 100);
        let num_samples = 25;
        let samples = interval
            .sample_subset_without_replacement(num_samples)
            .unwrap();
        assert_eq!(samples.size(), num_samples);

        let set =
            OrderedIntegerSet::from_slice(&[[-89, -23], [-2, 100], [300, 345]]);
        let num_samples = 18;
        let samples =
            set.sample_subset_without_replacement(num_samples).unwrap();
        assert_eq!(samples.size(), num_samples);
    }

    /// Boundary sizes: empty sample, the whole population, and one more than
    /// the population.
    #[test]
    fn test_sampling_without_replacement_boundaries() {
        let interval = ContiguousIntegerSet::new(0, 100);
        let population = interval.size();

        let samples = interval.sample_subset_without_replacement(0).unwrap();
        assert_eq!(samples.size(), 0);

        // When the whole population is requested, every element must be kept.
        let samples = interval
            .sample_subset_without_replacement(population)
            .unwrap();
        assert_eq!(samples.size(), population);

        assert!(interval
            .sample_subset_without_replacement(population + 1)
            .is_err());
    }

    #[test]
    fn test_sampling_with_replacement() {
        let num_samples = 25;
        let v = vec![1];
        let samples = v.sample_with_replacement(num_samples);
        assert_eq!(samples, Ok(vec![1; num_samples]));
        assert_eq!(v.sample_with_replacement(0), Ok(vec![]));

        // Every draw must come from the population, and all draws are kept.
        let v = vec![3, 5];
        let samples = v.sample_with_replacement(num_samples).unwrap();
        assert_eq!(samples.len(), num_samples);
        assert!(samples.iter().all(|x| *x == 3 || *x == 5));

        assert!(Vec::<f32>::new()
            .sample_with_replacement(num_samples)
            .is_err());
    }
}