axonml_data/
sampler.rs

1//! Samplers - Data Access Patterns
2//!
3//! Provides different strategies for sampling data indices.
4//!
5//! @version 0.1.0
6//! @author `AutomataNexus` Development Team
7
8use rand::seq::SliceRandom;
9use rand::Rng;
10
11// =============================================================================
12// Sampler Trait
13// =============================================================================
14
/// Common interface implemented by every index sampler.
///
/// A sampler decides which indices are visited and in what order.
pub trait Sampler: Send + Sync {
    /// Number of indices this sampler reports.
    fn len(&self) -> usize;

    /// Builds a boxed iterator over the sampled indices.
    ///
    /// The returned iterator is allowed to borrow from `self`.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_>;

    /// Reports whether [`Sampler::len`] is zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
30
31// =============================================================================
32// SequentialSampler
33// =============================================================================
34
35/// Samples elements sequentially.
36pub struct SequentialSampler {
37    len: usize,
38}
39
40impl SequentialSampler {
41    /// Creates a new `SequentialSampler`.
42    #[must_use] pub fn new(len: usize) -> Self {
43        Self { len }
44    }
45}
46
47impl Sampler for SequentialSampler {
48    fn len(&self) -> usize {
49        self.len
50    }
51
52    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
53        Box::new(0..self.len)
54    }
55}
56
57// =============================================================================
58// RandomSampler
59// =============================================================================
60
61/// Samples elements randomly.
62pub struct RandomSampler {
63    len: usize,
64    replacement: bool,
65    num_samples: Option<usize>,
66}
67
68impl RandomSampler {
69    /// Creates a new `RandomSampler` without replacement.
70    #[must_use] pub fn new(len: usize) -> Self {
71        Self {
72            len,
73            replacement: false,
74            num_samples: None,
75        }
76    }
77
78    /// Creates a `RandomSampler` with replacement.
79    #[must_use] pub fn with_replacement(len: usize, num_samples: usize) -> Self {
80        Self {
81            len,
82            replacement: true,
83            num_samples: Some(num_samples),
84        }
85    }
86}
87
88impl Sampler for RandomSampler {
89    fn len(&self) -> usize {
90        self.num_samples.unwrap_or(self.len)
91    }
92
93    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
94        if self.replacement {
95            // With replacement: random sampling
96            let len = self.len;
97            let num = self.num_samples.unwrap_or(len);
98            Box::new(RandomWithReplacementIter::new(len, num))
99        } else {
100            // Without replacement: shuffled indices
101            let mut indices: Vec<usize> = (0..self.len).collect();
102            indices.shuffle(&mut rand::thread_rng());
103            Box::new(indices.into_iter())
104        }
105    }
106}
107
108/// Iterator for random sampling with replacement.
109struct RandomWithReplacementIter {
110    len: usize,
111    remaining: usize,
112}
113
114impl RandomWithReplacementIter {
115    fn new(len: usize, num_samples: usize) -> Self {
116        Self {
117            len,
118            remaining: num_samples,
119        }
120    }
121}
122
123impl Iterator for RandomWithReplacementIter {
124    type Item = usize;
125
126    fn next(&mut self) -> Option<Self::Item> {
127        if self.remaining == 0 {
128            return None;
129        }
130        self.remaining -= 1;
131        Some(rand::thread_rng().gen_range(0..self.len))
132    }
133}
134
135// =============================================================================
136// SubsetRandomSampler
137// =============================================================================
138
139/// Samples randomly from a subset of indices.
140pub struct SubsetRandomSampler {
141    indices: Vec<usize>,
142}
143
144impl SubsetRandomSampler {
145    /// Creates a new `SubsetRandomSampler`.
146    #[must_use] pub fn new(indices: Vec<usize>) -> Self {
147        Self { indices }
148    }
149}
150
151impl Sampler for SubsetRandomSampler {
152    fn len(&self) -> usize {
153        self.indices.len()
154    }
155
156    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
157        let mut shuffled = self.indices.clone();
158        shuffled.shuffle(&mut rand::thread_rng());
159        Box::new(shuffled.into_iter())
160    }
161}
162
163// =============================================================================
164// WeightedRandomSampler
165// =============================================================================
166
167/// Samples elements with specified weights.
168pub struct WeightedRandomSampler {
169    weights: Vec<f64>,
170    num_samples: usize,
171    replacement: bool,
172}
173
174impl WeightedRandomSampler {
175    /// Creates a new `WeightedRandomSampler`.
176    #[must_use] pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
177        Self {
178            weights,
179            num_samples,
180            replacement,
181        }
182    }
183
184    /// Samples an index based on weights.
185    fn sample_index(&self) -> usize {
186        let total: f64 = self.weights.iter().sum();
187        let mut cumulative = 0.0;
188        let threshold: f64 = rand::thread_rng().gen::<f64>() * total;
189
190        for (i, &weight) in self.weights.iter().enumerate() {
191            cumulative += weight;
192            if cumulative > threshold {
193                return i;
194            }
195        }
196        self.weights.len() - 1
197    }
198}
199
200impl Sampler for WeightedRandomSampler {
201    fn len(&self) -> usize {
202        self.num_samples
203    }
204
205    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
206        if self.replacement {
207            Box::new(WeightedIter::new(self))
208        } else {
209            // Without replacement: sample all unique indices
210            let mut indices = Vec::with_capacity(self.num_samples);
211            let mut available: Vec<usize> = (0..self.weights.len()).collect();
212            let mut weights = self.weights.clone();
213
214            while indices.len() < self.num_samples && !available.is_empty() {
215                let total: f64 = weights.iter().sum();
216                if total <= 0.0 {
217                    break;
218                }
219
220                let threshold: f64 = rand::thread_rng().gen::<f64>() * total;
221                let mut cumulative = 0.0;
222                let mut selected = 0;
223
224                for (i, &weight) in weights.iter().enumerate() {
225                    cumulative += weight;
226                    if cumulative > threshold {
227                        selected = i;
228                        break;
229                    }
230                }
231
232                indices.push(available[selected]);
233                available.remove(selected);
234                weights.remove(selected);
235            }
236
237            Box::new(indices.into_iter())
238        }
239    }
240}
241
242/// Iterator for weighted random sampling with replacement.
243struct WeightedIter<'a> {
244    sampler: &'a WeightedRandomSampler,
245    remaining: usize,
246}
247
248impl<'a> WeightedIter<'a> {
249    fn new(sampler: &'a WeightedRandomSampler) -> Self {
250        Self {
251            sampler,
252            remaining: sampler.num_samples,
253        }
254    }
255}
256
257impl Iterator for WeightedIter<'_> {
258    type Item = usize;
259
260    fn next(&mut self) -> Option<Self::Item> {
261        if self.remaining == 0 {
262            return None;
263        }
264        self.remaining -= 1;
265        Some(self.sampler.sample_index())
266    }
267}
268
269// =============================================================================
270// BatchSampler
271// =============================================================================
272
273/// Wraps a sampler to yield batches of indices.
274pub struct BatchSampler<S: Sampler> {
275    sampler: S,
276    batch_size: usize,
277    drop_last: bool,
278}
279
280impl<S: Sampler> BatchSampler<S> {
281    /// Creates a new `BatchSampler`.
282    pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
283        Self {
284            sampler,
285            batch_size,
286            drop_last,
287        }
288    }
289
290    /// Creates an iterator over batches of indices.
291    pub fn iter(&self) -> BatchIter {
292        let indices: Vec<usize> = self.sampler.iter().collect();
293        BatchIter {
294            indices,
295            batch_size: self.batch_size,
296            drop_last: self.drop_last,
297            position: 0,
298        }
299    }
300
301    /// Returns the number of batches.
302    pub fn len(&self) -> usize {
303        let total = self.sampler.len();
304        if self.drop_last {
305            total / self.batch_size
306        } else {
307            total.div_ceil(self.batch_size)
308        }
309    }
310
311    /// Returns true if empty.
312    pub fn is_empty(&self) -> bool {
313        self.len() == 0
314    }
315}
316
/// Iterator over batches of indices.
///
/// Walks a pre-collected list of indices and hands it out in consecutive
/// chunks of at most `batch_size` elements.
pub struct BatchIter {
    /// All indices to be batched, in the order they will be yielded.
    indices: Vec<usize>,
    /// Maximum number of indices per batch.
    batch_size: usize,
    /// Whether a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
    /// Offset into `indices` of the first element of the next batch.
    position: usize,
}

impl Iterator for BatchIter {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        // A zero batch size can never advance `position`; end immediately
        // rather than yielding empty batches forever.
        if self.batch_size == 0 {
            return None;
        }

        // Everything not yet handed out; finished once nothing is left.
        let rest = self
            .indices
            .get(self.position..)
            .filter(|tail| !tail.is_empty())?;

        let take = rest.len().min(self.batch_size);
        if take < self.batch_size && self.drop_last {
            // Trailing partial batch is discarded without being allocated.
            return None;
        }

        let batch = rest[..take].to_vec();
        self.position += take;
        Some(batch)
    }
}
344
345// =============================================================================
346// Tests
347// =============================================================================
348
#[cfg(test)]
mod tests {
    use super::*;

    /// A sequential sampler walks `0..len` in ascending order.
    #[test]
    fn test_sequential_sampler() {
        let source = SequentialSampler::new(5);
        let drawn: Vec<usize> = source.iter().collect();
        assert_eq!(drawn, vec![0, 1, 2, 3, 4]);
    }

    /// Without replacement, ten draws contain ten distinct indices.
    #[test]
    fn test_random_sampler() {
        let source = RandomSampler::new(10);
        let mut drawn: Vec<usize> = source.iter().collect();
        assert_eq!(drawn.len(), 10);

        // Sorting and deduplicating must not remove anything if all are unique.
        drawn.sort_unstable();
        drawn.dedup();
        assert_eq!(drawn.len(), 10);
    }

    /// With replacement, the requested count is drawn and every index is in range.
    #[test]
    fn test_random_sampler_with_replacement() {
        let source = RandomSampler::with_replacement(5, 20);
        let drawn: Vec<usize> = source.iter().collect();

        assert_eq!(drawn.len(), 20);
        assert!(drawn.iter().all(|&index| index < 5));
    }

    /// The subset sampler yields exactly the configured indices, in some order.
    #[test]
    fn test_subset_random_sampler() {
        let source = SubsetRandomSampler::new(vec![0, 5, 10, 15]);
        let mut drawn: Vec<usize> = source.iter().collect();
        assert_eq!(drawn.len(), 4);

        drawn.sort_unstable();
        assert_eq!(drawn, vec![0, 5, 10, 15]);
    }

    /// A dominant weight on index 0 makes index 0 the majority of the draws.
    #[test]
    fn test_weighted_random_sampler() {
        let source = WeightedRandomSampler::new(vec![100.0, 1.0, 1.0, 1.0], 100, true);
        let drawn: Vec<usize> = source.iter().collect();
        assert_eq!(drawn.len(), 100);

        let zeros = drawn.iter().filter(|&&index| index == 0).count();
        assert!(zeros > 50, "Expected mostly zeros, got {zeros}");
    }

    /// Ten indices in batches of three: three full batches plus one partial.
    #[test]
    fn test_batch_sampler() {
        let batcher = BatchSampler::new(SequentialSampler::new(10), 3, false);
        let batches: Vec<Vec<usize>> = batcher.iter().collect();

        let expected: Vec<Vec<usize>> =
            vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8], vec![9]];
        assert_eq!(batches.len(), 4);
        assert_eq!(batches, expected);
    }

    /// With `drop_last`, the trailing partial batch is discarded.
    #[test]
    fn test_batch_sampler_drop_last() {
        let batcher = BatchSampler::new(SequentialSampler::new(10), 3, true);
        let batches: Vec<Vec<usize>> = batcher.iter().collect();

        let expected: Vec<Vec<usize>> = vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]];
        assert_eq!(batches.len(), 3);
        assert_eq!(batches, expected);
    }
}
433}