Skip to main content

axonml_data/
sampler.rs

//! Samplers - Data Access Patterns
//!
//! Provides different strategies for sampling data indices.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
7
8use rand::seq::SliceRandom;
9use rand::Rng;
10
11// =============================================================================
12// Sampler Trait
13// =============================================================================
14
/// Common interface implemented by every index sampler.
///
/// A sampler decides which indices are visited and in what order.
pub trait Sampler: Send + Sync {
    /// Number of indices the sampler reports it will produce.
    fn len(&self) -> usize;

    /// Reports whether [`Sampler::len`] is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Builds a boxed iterator over the sampled indices.
    ///
    /// The returned iterator is allowed to borrow from `self`.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_>;
}
30
31// =============================================================================
32// SequentialSampler
33// =============================================================================
34
35/// Samples elements sequentially.
36pub struct SequentialSampler {
37    len: usize,
38}
39
40impl SequentialSampler {
41    /// Creates a new `SequentialSampler`.
42    #[must_use]
43    pub fn new(len: usize) -> Self {
44        Self { len }
45    }
46}
47
48impl Sampler for SequentialSampler {
49    fn len(&self) -> usize {
50        self.len
51    }
52
53    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
54        Box::new(0..self.len)
55    }
56}
57
58// =============================================================================
59// RandomSampler
60// =============================================================================
61
62/// Samples elements randomly.
63pub struct RandomSampler {
64    len: usize,
65    replacement: bool,
66    num_samples: Option<usize>,
67}
68
69impl RandomSampler {
70    /// Creates a new `RandomSampler` without replacement.
71    #[must_use]
72    pub fn new(len: usize) -> Self {
73        Self {
74            len,
75            replacement: false,
76            num_samples: None,
77        }
78    }
79
80    /// Creates a `RandomSampler` with replacement.
81    #[must_use]
82    pub fn with_replacement(len: usize, num_samples: usize) -> Self {
83        Self {
84            len,
85            replacement: true,
86            num_samples: Some(num_samples),
87        }
88    }
89}
90
91impl Sampler for RandomSampler {
92    fn len(&self) -> usize {
93        self.num_samples.unwrap_or(self.len)
94    }
95
96    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
97        if self.replacement {
98            // With replacement: random sampling
99            let len = self.len;
100            let num = self.num_samples.unwrap_or(len);
101            Box::new(RandomWithReplacementIter::new(len, num))
102        } else {
103            // Without replacement: shuffled indices
104            let mut indices: Vec<usize> = (0..self.len).collect();
105            indices.shuffle(&mut rand::thread_rng());
106            Box::new(indices.into_iter())
107        }
108    }
109}
110
111/// Iterator for random sampling with replacement.
112struct RandomWithReplacementIter {
113    len: usize,
114    remaining: usize,
115}
116
117impl RandomWithReplacementIter {
118    fn new(len: usize, num_samples: usize) -> Self {
119        Self {
120            len,
121            remaining: num_samples,
122        }
123    }
124}
125
126impl Iterator for RandomWithReplacementIter {
127    type Item = usize;
128
129    fn next(&mut self) -> Option<Self::Item> {
130        if self.remaining == 0 {
131            return None;
132        }
133        self.remaining -= 1;
134        Some(rand::thread_rng().gen_range(0..self.len))
135    }
136}
137
138// =============================================================================
139// SubsetRandomSampler
140// =============================================================================
141
142/// Samples randomly from a subset of indices.
143pub struct SubsetRandomSampler {
144    indices: Vec<usize>,
145}
146
147impl SubsetRandomSampler {
148    /// Creates a new `SubsetRandomSampler`.
149    #[must_use]
150    pub fn new(indices: Vec<usize>) -> Self {
151        Self { indices }
152    }
153}
154
155impl Sampler for SubsetRandomSampler {
156    fn len(&self) -> usize {
157        self.indices.len()
158    }
159
160    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
161        let mut shuffled = self.indices.clone();
162        shuffled.shuffle(&mut rand::thread_rng());
163        Box::new(shuffled.into_iter())
164    }
165}
166
167// =============================================================================
168// WeightedRandomSampler
169// =============================================================================
170
171/// Samples elements with specified weights.
172pub struct WeightedRandomSampler {
173    weights: Vec<f64>,
174    num_samples: usize,
175    replacement: bool,
176}
177
178impl WeightedRandomSampler {
179    /// Creates a new `WeightedRandomSampler`.
180    #[must_use]
181    pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
182        Self {
183            weights,
184            num_samples,
185            replacement,
186        }
187    }
188
189    /// Samples an index based on weights.
190    fn sample_index(&self) -> usize {
191        let total: f64 = self.weights.iter().sum();
192        let mut cumulative = 0.0;
193        let threshold: f64 = rand::thread_rng().gen::<f64>() * total;
194
195        for (i, &weight) in self.weights.iter().enumerate() {
196            cumulative += weight;
197            if cumulative > threshold {
198                return i;
199            }
200        }
201        self.weights.len() - 1
202    }
203}
204
205impl Sampler for WeightedRandomSampler {
206    fn len(&self) -> usize {
207        self.num_samples
208    }
209
210    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
211        if self.replacement {
212            Box::new(WeightedIter::new(self))
213        } else {
214            // Without replacement: sample all unique indices
215            let mut indices = Vec::with_capacity(self.num_samples);
216            let mut available: Vec<usize> = (0..self.weights.len()).collect();
217            let mut weights = self.weights.clone();
218
219            while indices.len() < self.num_samples && !available.is_empty() {
220                let total: f64 = weights.iter().sum();
221                if total <= 0.0 {
222                    break;
223                }
224
225                let threshold: f64 = rand::thread_rng().gen::<f64>() * total;
226                let mut cumulative = 0.0;
227                let mut selected = 0;
228
229                for (i, &weight) in weights.iter().enumerate() {
230                    cumulative += weight;
231                    if cumulative > threshold {
232                        selected = i;
233                        break;
234                    }
235                }
236
237                indices.push(available[selected]);
238                available.remove(selected);
239                weights.remove(selected);
240            }
241
242            Box::new(indices.into_iter())
243        }
244    }
245}
246
247/// Iterator for weighted random sampling with replacement.
248struct WeightedIter<'a> {
249    sampler: &'a WeightedRandomSampler,
250    remaining: usize,
251}
252
253impl<'a> WeightedIter<'a> {
254    fn new(sampler: &'a WeightedRandomSampler) -> Self {
255        Self {
256            sampler,
257            remaining: sampler.num_samples,
258        }
259    }
260}
261
262impl Iterator for WeightedIter<'_> {
263    type Item = usize;
264
265    fn next(&mut self) -> Option<Self::Item> {
266        if self.remaining == 0 {
267            return None;
268        }
269        self.remaining -= 1;
270        Some(self.sampler.sample_index())
271    }
272}
273
274// =============================================================================
275// BatchSampler
276// =============================================================================
277
278/// Wraps a sampler to yield batches of indices.
279pub struct BatchSampler<S: Sampler> {
280    sampler: S,
281    batch_size: usize,
282    drop_last: bool,
283}
284
285impl<S: Sampler> BatchSampler<S> {
286    /// Creates a new `BatchSampler`.
287    pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
288        Self {
289            sampler,
290            batch_size,
291            drop_last,
292        }
293    }
294
295    /// Creates an iterator over batches of indices.
296    pub fn iter(&self) -> BatchIter {
297        let indices: Vec<usize> = self.sampler.iter().collect();
298        BatchIter {
299            indices,
300            batch_size: self.batch_size,
301            drop_last: self.drop_last,
302            position: 0,
303        }
304    }
305
306    /// Returns the number of batches.
307    pub fn len(&self) -> usize {
308        let total = self.sampler.len();
309        if self.drop_last {
310            total / self.batch_size
311        } else {
312            total.div_ceil(self.batch_size)
313        }
314    }
315
316    /// Returns true if empty.
317    pub fn is_empty(&self) -> bool {
318        self.len() == 0
319    }
320}
321
/// Iterator over batches of indices.
///
/// Walks a pre-collected list of indices and hands it out in consecutive
/// chunks of `batch_size`. A `batch_size` of zero produces no batches.
pub struct BatchIter {
    /// All indices of the current pass, in emission order.
    indices: Vec<usize>,
    /// Number of indices per batch.
    batch_size: usize,
    /// Whether a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
    /// Offset into `indices` of the next batch.
    position: usize,
}

impl Iterator for BatchIter {
    type Item = Vec<usize>;

    /// Returns the next batch, or `None` when the indices are exhausted or
    /// only a dropped partial batch remains.
    fn next(&mut self) -> Option<Self::Item> {
        // A zero batch size would never advance `position`, which would make
        // the iterator yield empty batches forever.
        if self.batch_size == 0 {
            return None;
        }

        let rest = self.indices.get(self.position..)?;
        if rest.is_empty() {
            return None;
        }

        let take = self.batch_size.min(rest.len());
        if take < self.batch_size && self.drop_last {
            // Incomplete trailing batch: discard it without copying.
            return None;
        }

        let batch = rest[..take].to_vec();
        self.position += take;
        Some(batch)
    }

    /// Exact number of batches still to come.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.indices.len().saturating_sub(self.position);
        let batches = if self.batch_size == 0 {
            0
        } else if self.drop_last {
            left / self.batch_size
        } else {
            left.div_ceil(self.batch_size)
        };
        (batches, Some(batches))
    }
}
349
350// =============================================================================
351// Tests
352// =============================================================================
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Sequential sampling visits `0..len` in ascending order.
    #[test]
    fn test_sequential_sampler() {
        let sampler = SequentialSampler::new(5);
        let drawn = sampler.iter().collect::<Vec<usize>>();
        assert_eq!(drawn, vec![0, 1, 2, 3, 4]);
    }

    /// Shuffled sampling yields every index exactly once.
    #[test]
    fn test_random_sampler() {
        let sampler = RandomSampler::new(10);
        let drawn = sampler.iter().collect::<Vec<usize>>();
        assert_eq!(drawn.len(), 10);

        // No replacement, so removing duplicates must not shrink the list.
        let mut unique = drawn.clone();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(unique.len(), 10);
    }

    /// Sampling with replacement yields the requested count, all in range.
    #[test]
    fn test_random_sampler_with_replacement() {
        let sampler = RandomSampler::with_replacement(5, 20);
        let drawn = sampler.iter().collect::<Vec<usize>>();

        assert_eq!(drawn.len(), 20);
        // Every draw must be a valid index into the population.
        assert!(drawn.iter().all(|&idx| idx < 5));
    }

    /// Subset sampling returns exactly the supplied indices, in some order.
    #[test]
    fn test_subset_random_sampler() {
        let sampler = SubsetRandomSampler::new(vec![0, 5, 10, 15]);
        let drawn = sampler.iter().collect::<Vec<usize>>();
        assert_eq!(drawn.len(), 4);

        // Order is random, so compare after sorting.
        let mut ordered = drawn.clone();
        ordered.sort_unstable();
        assert_eq!(ordered, vec![0, 5, 10, 15]);
    }

    /// A dominant weight makes its index dominate the draws.
    #[test]
    fn test_weighted_random_sampler() {
        // Index 0 carries almost all of the weight.
        let sampler = WeightedRandomSampler::new(vec![100.0, 1.0, 1.0, 1.0], 100, true);
        let drawn = sampler.iter().collect::<Vec<usize>>();
        assert_eq!(drawn.len(), 100);

        // The heavy index should account for most of the draws.
        let zeros = drawn.iter().filter(|&&idx| idx == 0).count();
        assert!(zeros > 50, "Expected mostly zeros, got {zeros}");
    }

    /// Without `drop_last`, the trailing partial batch is kept.
    #[test]
    fn test_batch_sampler() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, false);
        let batches = sampler.iter().collect::<Vec<Vec<usize>>>();

        // 10 / 3 = 3 full batches + 1 partial batch.
        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0], vec![0, 1, 2]);
        assert_eq!(batches[1], vec![3, 4, 5]);
        assert_eq!(batches[2], vec![6, 7, 8]);
        assert_eq!(batches[3], vec![9]);
    }

    /// With `drop_last`, the trailing partial batch is discarded.
    #[test]
    fn test_batch_sampler_drop_last() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, true);
        let batches = sampler.iter().collect::<Vec<Vec<usize>>>();

        // 10 / 3 = 3 full batches; the partial one is dropped.
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0], vec![0, 1, 2]);
        assert_eq!(batches[1], vec![3, 4, 5]);
        assert_eq!(batches[2], vec![6, 7, 8]);
    }
}