ai_dataloader/indexable/sampler/
random_sampler.rs

1use rand::seq::SliceRandom;
2use rand::thread_rng;
3
4use super::{Len, Sampler};
5
/// Sampler that hands out random indexes of a data source.
///
/// Every index lies between zero (inclusive) and `data_source_len` (exclusive).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RandomSampler {
    /// Number of elements in the data source.
    data_source_len: usize,
    /// Whether sampling is done with replacement.
    /// With replacement, the same sample may be drawn more than once.
    replacement: bool,
}
15
16impl Sampler for RandomSampler {
17    fn new(data_source_len: usize) -> Self {
18        Self {
19            data_source_len,
20            replacement: false,
21        }
22    }
23}
24impl Len for RandomSampler {
25    fn len(&self) -> usize {
26        self.data_source_len
27    }
28}
29impl IntoIterator for RandomSampler {
30    type Item = usize;
31    type IntoIter = RandomSamplerIter;
32    fn into_iter(self) -> Self::IntoIter {
33        RandomSamplerIter::new(self.data_source_len, self.replacement)
34    }
35}
/// Iterator over the random indexes produced by a `RandomSampler`.
#[derive(Debug)]
pub struct RandomSamplerIter {
    /// The indexes to yield, stored in the order they will be returned.
    indexes: Vec<usize>,
    /// Position in `indexes` of the next element to return.
    idx: usize,
}
44
45impl RandomSamplerIter {
46    /// Create a new `RandomSamplerIter`.
47    ///
48    /// # Arguments
49    ///
50    /// * `data_source_len` - The length of the dataset.
51    /// * `replacement` - Whether we can have the same sample twice over one iteration or not.
52    // FIXME: change this parameters in the next breaking release
53    #[allow(clippy::fn_params_excessive_bools)]
54    fn new(data_source_len: usize, replacement: bool) -> Self {
55        if replacement {
56            todo!()
57        } else {
58            let mut vec: Vec<usize> = (0..data_source_len).collect();
59            vec.shuffle(&mut thread_rng());
60            Self {
61                indexes: vec,
62                idx: 0,
63            }
64        }
65    }
66}
67
68impl Iterator for RandomSamplerIter {
69    type Item = usize;
70    fn next(&mut self) -> Option<Self::Item> {
71        if self.idx < self.indexes.len() {
72            self.idx += 1;
73            Some(self.indexes[self.idx - 1])
74        } else {
75            None
76        }
77    }
78
79    fn size_hint(&self) -> (usize, Option<usize>) {
80        let len = self.indexes.len() - self.idx;
81        (len, Some(len))
82    }
83}
84
85impl ExactSizeIterator for RandomSamplerIter {}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Without replacement, every index of the data source is yielded exactly once.
    #[test]
    fn random_sampler() {
        let random_sampler = RandomSampler {
            data_source_len: 10,
            replacement: false,
        };
        let mut indexes: Vec<usize> = random_sampler.into_iter().collect();
        // The order is random, so compare the sorted output with the full range.
        indexes.sort_unstable();
        assert_eq!(indexes, (0..10).collect::<Vec<usize>>());
    }

    /// An empty data source produces an iterator that yields nothing.
    #[test]
    fn empty_data_source() {
        let random_sampler = RandomSampler {
            data_source_len: 0,
            replacement: false,
        };
        let mut iter = random_sampler.into_iter();
        assert_eq!(iter.len(), 0);
        assert_eq!(iter.next(), None);
    }

    /// The sampler reports the data source length, and the iterator reports
    /// how many indexes are still to be yielded.
    #[test]
    fn len() {
        let random_sampler = RandomSampler {
            data_source_len: 10,
            replacement: false,
        };

        assert_eq!(random_sampler.len(), 10);
        let mut iter = random_sampler.into_iter();
        assert_eq!(iter.len(), 10);
        let _ = iter.next();
        assert_eq!(iter.len(), 9);
        let _ = iter.next();
        let _ = iter.next();
        assert_eq!(iter.len(), 7);
    }
}