Skip to main content

axonml_data/
sampler.rs

//! Samplers - Data Access Patterns
//!
//! # File
//! `crates/axonml-data/src/sampler.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
16
17use rand::Rng;
18use rand::seq::SliceRandom;
19
20// =============================================================================
21// Sampler Trait
22// =============================================================================
23
/// Common interface implemented by every sampler in this module.
///
/// A sampler decides the order in which dataset positions are visited by
/// handing out a sequence of `usize` indices.
pub trait Sampler: Send + Sync {
    /// Number of samples this sampler reports.
    fn len(&self) -> usize;

    /// Produces a fresh iterator over the sampled indices.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_>;

    /// `true` when [`Sampler::len`] is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
39
40// =============================================================================
41// SequentialSampler
42// =============================================================================
43
44/// Samples elements sequentially.
45pub struct SequentialSampler {
46    len: usize,
47}
48
49impl SequentialSampler {
50    /// Creates a new `SequentialSampler`.
51    #[must_use]
52    pub fn new(len: usize) -> Self {
53        Self { len }
54    }
55}
56
57impl Sampler for SequentialSampler {
58    fn len(&self) -> usize {
59        self.len
60    }
61
62    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
63        Box::new(0..self.len)
64    }
65}
66
67// =============================================================================
68// RandomSampler
69// =============================================================================
70
71/// Samples elements randomly.
72pub struct RandomSampler {
73    len: usize,
74    replacement: bool,
75    num_samples: Option<usize>,
76}
77
78impl RandomSampler {
79    /// Creates a new `RandomSampler` without replacement.
80    #[must_use]
81    pub fn new(len: usize) -> Self {
82        Self {
83            len,
84            replacement: false,
85            num_samples: None,
86        }
87    }
88
89    /// Creates a `RandomSampler` with replacement.
90    #[must_use]
91    pub fn with_replacement(len: usize, num_samples: usize) -> Self {
92        Self {
93            len,
94            replacement: true,
95            num_samples: Some(num_samples),
96        }
97    }
98}
99
100impl Sampler for RandomSampler {
101    fn len(&self) -> usize {
102        self.num_samples.unwrap_or(self.len)
103    }
104
105    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
106        if self.replacement {
107            // With replacement: random sampling
108            let len = self.len;
109            let num = self.num_samples.unwrap_or(len);
110            Box::new(RandomWithReplacementIter::new(len, num))
111        } else {
112            // Without replacement: shuffled indices
113            let mut indices: Vec<usize> = (0..self.len).collect();
114            indices.shuffle(&mut rand::thread_rng());
115            Box::new(indices.into_iter())
116        }
117    }
118}
119
120/// Iterator for random sampling with replacement.
121struct RandomWithReplacementIter {
122    len: usize,
123    remaining: usize,
124}
125
126impl RandomWithReplacementIter {
127    fn new(len: usize, num_samples: usize) -> Self {
128        Self {
129            len,
130            remaining: num_samples,
131        }
132    }
133}
134
135impl Iterator for RandomWithReplacementIter {
136    type Item = usize;
137
138    fn next(&mut self) -> Option<Self::Item> {
139        if self.remaining == 0 {
140            return None;
141        }
142        self.remaining -= 1;
143        Some(rand::thread_rng().gen_range(0..self.len))
144    }
145}
146
147// =============================================================================
148// SubsetRandomSampler
149// =============================================================================
150
151/// Samples randomly from a subset of indices.
152pub struct SubsetRandomSampler {
153    indices: Vec<usize>,
154}
155
156impl SubsetRandomSampler {
157    /// Creates a new `SubsetRandomSampler`.
158    #[must_use]
159    pub fn new(indices: Vec<usize>) -> Self {
160        Self { indices }
161    }
162}
163
164impl Sampler for SubsetRandomSampler {
165    fn len(&self) -> usize {
166        self.indices.len()
167    }
168
169    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
170        let mut shuffled = self.indices.clone();
171        shuffled.shuffle(&mut rand::thread_rng());
172        Box::new(shuffled.into_iter())
173    }
174}
175
176// =============================================================================
177// WeightedRandomSampler
178// =============================================================================
179
180/// Samples elements with specified weights.
181///
182/// Uses precomputed cumulative sums with binary search for O(log n) per sample
183/// instead of O(n) linear scan.
184pub struct WeightedRandomSampler {
185    weights: Vec<f64>,
186    /// Precomputed cumulative sum for O(log n) sampling.
187    cumulative: Vec<f64>,
188    num_samples: usize,
189    replacement: bool,
190}
191
192impl WeightedRandomSampler {
193    /// Creates a new `WeightedRandomSampler`.
194    #[must_use]
195    pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
196        let cumulative = Self::build_cumulative(&weights);
197        Self {
198            weights,
199            cumulative,
200            num_samples,
201            replacement,
202        }
203    }
204
205    /// Builds cumulative sum array from weights.
206    fn build_cumulative(weights: &[f64]) -> Vec<f64> {
207        let mut cum = Vec::with_capacity(weights.len());
208        let mut total = 0.0;
209        for &w in weights {
210            total += w;
211            cum.push(total);
212        }
213        cum
214    }
215
216    /// Samples an index based on weights using binary search on cumulative sums.
217    fn sample_index(&self) -> usize {
218        let total = *self.cumulative.last().unwrap_or(&1.0);
219        let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
220
221        // Binary search: find first index where cumulative > threshold
222        match self
223            .cumulative
224            .binary_search_by(|c| c.partial_cmp(&threshold).unwrap())
225        {
226            Ok(i) => i,
227            Err(i) => i.min(self.cumulative.len() - 1),
228        }
229    }
230}
231
232impl Sampler for WeightedRandomSampler {
233    fn len(&self) -> usize {
234        self.num_samples
235    }
236
237    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
238        if self.replacement {
239            Box::new(WeightedIter::new(self))
240        } else {
241            // Without replacement: use swap-remove for O(1) removal instead of O(n)
242            let mut indices = Vec::with_capacity(self.num_samples);
243            let mut available: Vec<usize> = (0..self.weights.len()).collect();
244            let mut weights = self.weights.clone();
245            let mut cumulative = self.cumulative.clone();
246
247            while indices.len() < self.num_samples && !available.is_empty() {
248                let total = *cumulative.last().unwrap_or(&0.0);
249                if total <= 0.0 {
250                    break;
251                }
252
253                let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
254                let selected =
255                    match cumulative.binary_search_by(|c| c.partial_cmp(&threshold).unwrap()) {
256                        Ok(i) => i,
257                        Err(i) => i.min(cumulative.len() - 1),
258                    };
259
260                indices.push(available[selected]);
261
262                // O(1) swap-remove instead of O(n) remove
263                available.swap_remove(selected);
264                weights.swap_remove(selected);
265
266                // Rebuild cumulative (needed after swap_remove reorders)
267                cumulative = Self::build_cumulative(&weights);
268            }
269
270            Box::new(indices.into_iter())
271        }
272    }
273}
274
275/// Iterator for weighted random sampling with replacement.
276struct WeightedIter<'a> {
277    sampler: &'a WeightedRandomSampler,
278    remaining: usize,
279}
280
281impl<'a> WeightedIter<'a> {
282    fn new(sampler: &'a WeightedRandomSampler) -> Self {
283        Self {
284            sampler,
285            remaining: sampler.num_samples,
286        }
287    }
288}
289
290impl Iterator for WeightedIter<'_> {
291    type Item = usize;
292
293    fn next(&mut self) -> Option<Self::Item> {
294        if self.remaining == 0 {
295            return None;
296        }
297        self.remaining -= 1;
298        Some(self.sampler.sample_index())
299    }
300}
301
302// =============================================================================
303// BatchSampler
304// =============================================================================
305
306/// Wraps a sampler to yield batches of indices.
307pub struct BatchSampler<S: Sampler> {
308    sampler: S,
309    batch_size: usize,
310    drop_last: bool,
311}
312
313impl<S: Sampler> BatchSampler<S> {
314    /// Creates a new `BatchSampler`.
315    pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
316        Self {
317            sampler,
318            batch_size,
319            drop_last,
320        }
321    }
322
323    /// Creates an iterator over batches of indices.
324    pub fn iter(&self) -> BatchIter {
325        let indices: Vec<usize> = self.sampler.iter().collect();
326        BatchIter {
327            indices,
328            batch_size: self.batch_size,
329            drop_last: self.drop_last,
330            position: 0,
331        }
332    }
333
334    /// Returns the number of batches.
335    pub fn len(&self) -> usize {
336        let total = self.sampler.len();
337        if self.drop_last {
338            total / self.batch_size
339        } else {
340            total.div_ceil(self.batch_size)
341        }
342    }
343
344    /// Returns true if empty.
345    pub fn is_empty(&self) -> bool {
346        self.len() == 0
347    }
348}
349
350/// Iterator over batches of indices.
351pub struct BatchIter {
352    indices: Vec<usize>,
353    batch_size: usize,
354    drop_last: bool,
355    position: usize,
356}
357
358impl Iterator for BatchIter {
359    type Item = Vec<usize>;
360
361    fn next(&mut self) -> Option<Self::Item> {
362        if self.position >= self.indices.len() {
363            return None;
364        }
365
366        let end = (self.position + self.batch_size).min(self.indices.len());
367        let batch: Vec<usize> = self.indices[self.position..end].to_vec();
368
369        if batch.len() < self.batch_size && self.drop_last {
370            return None;
371        }
372
373        self.position = end;
374        Some(batch)
375    }
376}
377
378// =============================================================================
379// Tests
380// =============================================================================
381
#[cfg(test)]
mod tests {
    use super::*;

    /// Gathers everything one pass of `sampler` yields.
    fn collect_indices<S: Sampler>(sampler: &S) -> Vec<usize> {
        sampler.iter().collect()
    }

    #[test]
    fn test_sequential_sampler() {
        let got = collect_indices(&SequentialSampler::new(5));
        assert_eq!(got, (0..5).collect::<Vec<usize>>());
    }

    #[test]
    fn test_random_sampler() {
        let got = collect_indices(&RandomSampler::new(10));
        assert_eq!(got.len(), 10);

        // Without replacement no index may repeat.
        let mut unique = got.clone();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(unique.len(), 10);
    }

    #[test]
    fn test_random_sampler_with_replacement() {
        let got = collect_indices(&RandomSampler::with_replacement(5, 20));
        assert_eq!(got.len(), 20);
        // Every draw must land inside the population.
        assert!(got.iter().all(|&idx| idx < 5));
    }

    #[test]
    fn test_subset_random_sampler() {
        let mut got = collect_indices(&SubsetRandomSampler::new(vec![0, 5, 10, 15]));
        assert_eq!(got.len(), 4);
        // Same members as the subset, just permuted.
        got.sort_unstable();
        assert_eq!(got, vec![0, 5, 10, 15]);
    }

    #[test]
    fn test_weighted_random_sampler() {
        // Index 0 carries almost all of the weight.
        let sampler = WeightedRandomSampler::new(vec![100.0, 1.0, 1.0, 1.0], 100, true);
        let got = collect_indices(&sampler);
        assert_eq!(got.len(), 100);

        let zeros = got.iter().filter(|&&idx| idx == 0).count();
        assert!(zeros > 50, "Expected mostly zeros, got {zeros}");
    }

    #[test]
    fn test_batch_sampler() {
        let batcher = BatchSampler::new(SequentialSampler::new(10), 3, false);
        let batches: Vec<Vec<usize>> = batcher.iter().collect();
        // Three full batches followed by the partial tail.
        assert_eq!(
            batches,
            vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8], vec![9]]
        );
    }

    #[test]
    fn test_batch_sampler_drop_last() {
        let batcher = BatchSampler::new(SequentialSampler::new(10), 3, true);
        let batches: Vec<Vec<usize>> = batcher.iter().collect();
        // The short tail `[9]` is discarded.
        assert_eq!(batches, vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]]);
    }
}