ai_dataloader/indexable/sampler/
random_sampler.rs1use rand::seq::SliceRandom;
2use rand::thread_rng;
3
4use super::{Len, Sampler};
5
/// Sampler that produces the indices of a data source in random order.
///
/// When built through [`Sampler::new`], `replacement` is `false`. In that mode a
/// pass yields every index in `0..data_source_len` exactly once, in shuffled order.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct RandomSampler {
    /// Number of items in the underlying data source.
    data_source_len: usize,
    /// Whether indices are drawn with replacement.
    replacement: bool,
}
15
16impl Sampler for RandomSampler {
17 fn new(data_source_len: usize) -> Self {
18 Self {
19 data_source_len,
20 replacement: false,
21 }
22 }
23}
24impl Len for RandomSampler {
25 fn len(&self) -> usize {
26 self.data_source_len
27 }
28}
29impl IntoIterator for RandomSampler {
30 type Item = usize;
31 type IntoIter = RandomSamplerIter;
32 fn into_iter(self) -> Self::IntoIter {
33 RandomSamplerIter::new(self.data_source_len, self.replacement)
34 }
35}
/// Iterator over the indices produced by a [`RandomSampler`].
///
/// The whole index sequence is generated when the iterator is created. Iteration
/// then walks through it front to back.
#[derive(Debug)]
pub struct RandomSamplerIter {
    /// Pre-generated sequence of sampled indices.
    indexes: Vec<usize>,
    /// Position in `indexes` of the next index to yield.
    idx: usize,
}
44
45impl RandomSamplerIter {
46 #[allow(clippy::fn_params_excessive_bools)]
54 fn new(data_source_len: usize, replacement: bool) -> Self {
55 if replacement {
56 todo!()
57 } else {
58 let mut vec: Vec<usize> = (0..data_source_len).collect();
59 vec.shuffle(&mut thread_rng());
60 Self {
61 indexes: vec,
62 idx: 0,
63 }
64 }
65 }
66}
67
68impl Iterator for RandomSamplerIter {
69 type Item = usize;
70 fn next(&mut self) -> Option<Self::Item> {
71 if self.idx < self.indexes.len() {
72 self.idx += 1;
73 Some(self.indexes[self.idx - 1])
74 } else {
75 None
76 }
77 }
78
79 fn size_hint(&self) -> (usize, Option<usize>) {
80 let len = self.indexes.len() - self.idx;
81 (len, Some(len))
82 }
83}
84
85impl ExactSizeIterator for RandomSamplerIter {}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// One non-replacement pass must yield every index of the data source exactly once.
    #[test]
    fn random_sampler() {
        let random_sampler = RandomSampler {
            data_source_len: 10,
            replacement: false,
        };
        let mut indices: Vec<usize> = random_sampler.into_iter().collect();
        indices.sort_unstable();
        assert_eq!(indices, (0..10).collect::<Vec<_>>());
    }

    /// An empty data source has length zero and yields no indices.
    #[test]
    fn empty_data_source() {
        let random_sampler = RandomSampler {
            data_source_len: 0,
            replacement: false,
        };
        assert_eq!(random_sampler.len(), 0);
        let mut iter = random_sampler.into_iter();
        assert_eq!(iter.len(), 0);
        assert_eq!(iter.next(), None);
    }

    /// The sampler length and the iterator's remaining length track each other.
    #[test]
    fn len() {
        let random_sampler = RandomSampler {
            data_source_len: 10,
            replacement: false,
        };

        assert_eq!(random_sampler.len(), 10);
        let mut iter = random_sampler.into_iter();
        assert_eq!(iter.len(), 10);
        let _ = iter.next();
        assert_eq!(iter.len(), 9);
        let _ = iter.next();
        let _ = iter.next();
        assert_eq!(iter.len(), 7);

        // Draining the rest must match the reported length and leave the iterator exhausted.
        assert_eq!(iter.by_ref().count(), 7);
        assert_eq!(iter.len(), 0);
        assert_eq!(iter.next(), None);
    }
}