ghostflow_data/
sampler.rs1use rand::seq::SliceRandom;
4use rand::thread_rng;
5
/// A source of dataset indices.
///
/// Every `Sampler` is an iterator over `usize` indices that can also report
/// a length through [`Sampler::len`].
pub trait Sampler: Iterator<Item = usize> {
    /// Returns the length reported by the implementing sampler.
    fn len(&self) -> usize;

    /// Returns `true` when [`Sampler::len`] reports zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
13
/// Sampler that walks the index range `0..len` in ascending order.
pub struct SequentialSampler {
    /// Next index to hand out.
    current: usize,
    /// Exclusive upper bound of the index range.
    len: usize,
}

impl SequentialSampler {
    /// Creates a sampler positioned at index 0 with exclusive upper bound `len`.
    pub fn new(len: usize) -> Self {
        Self { len, current: 0 }
    }
}
25
26impl Iterator for SequentialSampler {
27 type Item = usize;
28
29 fn next(&mut self) -> Option<Self::Item> {
30 if self.current < self.len {
31 let idx = self.current;
32 self.current += 1;
33 Some(idx)
34 } else {
35 None
36 }
37 }
38}
39
40impl Sampler for SequentialSampler {
41 fn len(&self) -> usize {
42 self.len
43 }
44}
45
46pub struct RandomSampler {
48 indices: Vec<usize>,
49 current: usize,
50}
51
52impl RandomSampler {
53 pub fn new(len: usize) -> Self {
54 let mut indices: Vec<usize> = (0..len).collect();
55 indices.shuffle(&mut thread_rng());
56 RandomSampler { indices, current: 0 }
57 }
58}
59
60impl Iterator for RandomSampler {
61 type Item = usize;
62
63 fn next(&mut self) -> Option<Self::Item> {
64 if self.current < self.indices.len() {
65 let idx = self.indices[self.current];
66 self.current += 1;
67 Some(idx)
68 } else {
69 None
70 }
71 }
72}
73
74impl Sampler for RandomSampler {
75 fn len(&self) -> usize {
76 self.indices.len()
77 }
78}
79
80pub struct WeightedRandomSampler {
82 indices: Vec<usize>,
83 current: usize,
84}
85
86impl WeightedRandomSampler {
87 pub fn new(weights: &[f32], num_samples: usize, replacement: bool) -> Self {
88 let total_weight: f32 = weights.iter().sum();
89 let normalized: Vec<f32> = weights.iter().map(|w| w / total_weight).collect();
90
91 let mut indices = Vec::with_capacity(num_samples);
92 let mut available: Vec<usize> = (0..weights.len()).collect();
93
94 for _ in 0..num_samples {
95 let r: f32 = rand::random();
97 let mut cumsum = 0.0f32;
98
99 for (i, &w) in normalized.iter().enumerate() {
100 cumsum += w;
101 if r < cumsum {
102 if replacement {
103 indices.push(i);
104 } else if available.contains(&i) {
105 indices.push(i);
106 available.retain(|&x| x != i);
107 }
108 break;
109 }
110 }
111 }
112
113 WeightedRandomSampler { indices, current: 0 }
114 }
115}
116
117impl Iterator for WeightedRandomSampler {
118 type Item = usize;
119
120 fn next(&mut self) -> Option<Self::Item> {
121 if self.current < self.indices.len() {
122 let idx = self.indices[self.current];
123 self.current += 1;
124 Some(idx)
125 } else {
126 None
127 }
128 }
129}
130
131impl Sampler for WeightedRandomSampler {
132 fn len(&self) -> usize {
133 self.indices.len()
134 }
135}
136
137pub struct BatchSampler<S: Sampler> {
139 sampler: S,
140 batch_size: usize,
141 drop_last: bool,
142}
143
144impl<S: Sampler> BatchSampler<S> {
145 pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
146 BatchSampler {
147 sampler,
148 batch_size,
149 drop_last,
150 }
151 }
152}
153
154impl<S: Sampler> Iterator for BatchSampler<S> {
155 type Item = Vec<usize>;
156
157 fn next(&mut self) -> Option<Self::Item> {
158 let mut batch = Vec::with_capacity(self.batch_size);
159
160 for _ in 0..self.batch_size {
161 if let Some(idx) = self.sampler.next() {
162 batch.push(idx);
163 } else {
164 break;
165 }
166 }
167
168 if batch.is_empty() || (self.drop_last && batch.len() < self.batch_size) {
169 None
170 } else {
171 Some(batch)
172 }
173 }
174}