ghostflow_data/
sampler.rs1use rand::seq::SliceRandom;
4use rand::thread_rng;
5
/// An iterator over dataset indices that can also report its size.
pub trait Sampler: Iterator<Item = usize> {
    /// Number of indices this sampler produces over a full pass.
    ///
    /// `SequentialSampler`, `RandomSampler` and `WeightedRandomSampler` report
    /// a total fixed at construction. The value does not shrink as indices are
    /// consumed, so it is not a count of *remaining* indices.
    fn len(&self) -> usize;

    /// Returns `true` when [`len`](Sampler::len) is zero.
    fn is_empty(&self) -> bool {
        Sampler::len(self) == 0
    }
}
10
11pub struct SequentialSampler {
13 current: usize,
14 len: usize,
15}
16
17impl SequentialSampler {
18 pub fn new(len: usize) -> Self {
19 SequentialSampler { current: 0, len }
20 }
21}
22
23impl Iterator for SequentialSampler {
24 type Item = usize;
25
26 fn next(&mut self) -> Option<Self::Item> {
27 if self.current < self.len {
28 let idx = self.current;
29 self.current += 1;
30 Some(idx)
31 } else {
32 None
33 }
34 }
35}
36
37impl Sampler for SequentialSampler {
38 fn len(&self) -> usize {
39 self.len
40 }
41}
42
43pub struct RandomSampler {
45 indices: Vec<usize>,
46 current: usize,
47}
48
49impl RandomSampler {
50 pub fn new(len: usize) -> Self {
51 let mut indices: Vec<usize> = (0..len).collect();
52 indices.shuffle(&mut thread_rng());
53 RandomSampler { indices, current: 0 }
54 }
55}
56
57impl Iterator for RandomSampler {
58 type Item = usize;
59
60 fn next(&mut self) -> Option<Self::Item> {
61 if self.current < self.indices.len() {
62 let idx = self.indices[self.current];
63 self.current += 1;
64 Some(idx)
65 } else {
66 None
67 }
68 }
69}
70
71impl Sampler for RandomSampler {
72 fn len(&self) -> usize {
73 self.indices.len()
74 }
75}
76
77pub struct WeightedRandomSampler {
79 indices: Vec<usize>,
80 current: usize,
81}
82
83impl WeightedRandomSampler {
84 pub fn new(weights: &[f32], num_samples: usize, replacement: bool) -> Self {
85 let total_weight: f32 = weights.iter().sum();
86 let normalized: Vec<f32> = weights.iter().map(|w| w / total_weight).collect();
87
88 let mut indices = Vec::with_capacity(num_samples);
89 let mut available: Vec<usize> = (0..weights.len()).collect();
90
91 for _ in 0..num_samples {
92 let r: f32 = rand::random();
94 let mut cumsum = 0.0f32;
95
96 for (i, &w) in normalized.iter().enumerate() {
97 cumsum += w;
98 if r < cumsum {
99 if replacement {
100 indices.push(i);
101 } else if available.contains(&i) {
102 indices.push(i);
103 available.retain(|&x| x != i);
104 }
105 break;
106 }
107 }
108 }
109
110 WeightedRandomSampler { indices, current: 0 }
111 }
112}
113
114impl Iterator for WeightedRandomSampler {
115 type Item = usize;
116
117 fn next(&mut self) -> Option<Self::Item> {
118 if self.current < self.indices.len() {
119 let idx = self.indices[self.current];
120 self.current += 1;
121 Some(idx)
122 } else {
123 None
124 }
125 }
126}
127
128impl Sampler for WeightedRandomSampler {
129 fn len(&self) -> usize {
130 self.indices.len()
131 }
132}
133
134pub struct BatchSampler<S: Sampler> {
136 sampler: S,
137 batch_size: usize,
138 drop_last: bool,
139}
140
141impl<S: Sampler> BatchSampler<S> {
142 pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
143 BatchSampler {
144 sampler,
145 batch_size,
146 drop_last,
147 }
148 }
149}
150
151impl<S: Sampler> Iterator for BatchSampler<S> {
152 type Item = Vec<usize>;
153
154 fn next(&mut self) -> Option<Self::Item> {
155 let mut batch = Vec::with_capacity(self.batch_size);
156
157 for _ in 0..self.batch_size {
158 if let Some(idx) = self.sampler.next() {
159 batch.push(idx);
160 } else {
161 break;
162 }
163 }
164
165 if batch.is_empty() {
166 None
167 } else if self.drop_last && batch.len() < self.batch_size {
168 None
169 } else {
170 Some(batch)
171 }
172 }
173}