Skip to main content

axonml_data/
sampler.rs

1//! Samplers - Data Access Patterns
2//!
3//! # File
4//! `crates/axonml-data/src/sampler.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use rand::Rng;
18use rand::seq::SliceRandom;
19
20// =============================================================================
21// Sampler Trait
22// =============================================================================
23
/// Common interface shared by every sampler in this module.
///
/// A sampler decides the order in which dataset indices are visited; the
/// indices themselves are produced by [`Sampler::iter`].
pub trait Sampler: Send + Sync {
    /// Number of indices this sampler reports it will produce.
    fn len(&self) -> usize;

    /// `true` when [`Sampler::len`] reports zero indices.
    fn is_empty(&self) -> bool {
        let count = self.len();
        count == 0
    }

    /// Returns a boxed iterator over the sampled indices.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_>;
}
39
40// =============================================================================
41// SequentialSampler
42// =============================================================================
43
44/// Samples elements sequentially.
45pub struct SequentialSampler {
46    len: usize,
47}
48
49impl SequentialSampler {
50    /// Creates a new `SequentialSampler`.
51    #[must_use]
52    pub fn new(len: usize) -> Self {
53        Self { len }
54    }
55}
56
57impl Sampler for SequentialSampler {
58    fn len(&self) -> usize {
59        self.len
60    }
61
62    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
63        Box::new(0..self.len)
64    }
65}
66
67// =============================================================================
68// RandomSampler
69// =============================================================================
70
71/// Samples elements randomly.
72pub struct RandomSampler {
73    len: usize,
74    replacement: bool,
75    num_samples: Option<usize>,
76}
77
78impl RandomSampler {
79    /// Creates a new `RandomSampler` without replacement.
80    #[must_use]
81    pub fn new(len: usize) -> Self {
82        Self {
83            len,
84            replacement: false,
85            num_samples: None,
86        }
87    }
88
89    /// Creates a `RandomSampler` with replacement.
90    #[must_use]
91    pub fn with_replacement(len: usize, num_samples: usize) -> Self {
92        Self {
93            len,
94            replacement: true,
95            num_samples: Some(num_samples),
96        }
97    }
98}
99
100impl Sampler for RandomSampler {
101    fn len(&self) -> usize {
102        self.num_samples.unwrap_or(self.len)
103    }
104
105    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
106        if self.replacement {
107            // With replacement: random sampling
108            let len = self.len;
109            let num = self.num_samples.unwrap_or(len);
110            Box::new(RandomWithReplacementIter::new(len, num))
111        } else {
112            // Without replacement: shuffled indices
113            let mut indices: Vec<usize> = (0..self.len).collect();
114            indices.shuffle(&mut rand::thread_rng());
115            Box::new(indices.into_iter())
116        }
117    }
118}
119
120/// Iterator for random sampling with replacement.
121struct RandomWithReplacementIter {
122    len: usize,
123    remaining: usize,
124}
125
126impl RandomWithReplacementIter {
127    fn new(len: usize, num_samples: usize) -> Self {
128        Self {
129            len,
130            remaining: num_samples,
131        }
132    }
133}
134
135impl Iterator for RandomWithReplacementIter {
136    type Item = usize;
137
138    fn next(&mut self) -> Option<Self::Item> {
139        if self.remaining == 0 {
140            return None;
141        }
142        self.remaining -= 1;
143        Some(rand::thread_rng().gen_range(0..self.len))
144    }
145}
146
147// =============================================================================
148// SubsetRandomSampler
149// =============================================================================
150
151/// Samples randomly from a subset of indices.
152pub struct SubsetRandomSampler {
153    indices: Vec<usize>,
154}
155
156impl SubsetRandomSampler {
157    /// Creates a new `SubsetRandomSampler`.
158    #[must_use]
159    pub fn new(indices: Vec<usize>) -> Self {
160        Self { indices }
161    }
162}
163
164impl Sampler for SubsetRandomSampler {
165    fn len(&self) -> usize {
166        self.indices.len()
167    }
168
169    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
170        let mut shuffled = self.indices.clone();
171        shuffled.shuffle(&mut rand::thread_rng());
172        Box::new(shuffled.into_iter())
173    }
174}
175
176// =============================================================================
177// WeightedRandomSampler
178// =============================================================================
179
180/// Samples elements with specified weights.
181///
182/// Uses precomputed cumulative sums with binary search for O(log n) per sample
183/// instead of O(n) linear scan.
184pub struct WeightedRandomSampler {
185    weights: Vec<f64>,
186    /// Precomputed cumulative sum for O(log n) sampling.
187    cumulative: Vec<f64>,
188    num_samples: usize,
189    replacement: bool,
190}
191
192impl WeightedRandomSampler {
193    /// Creates a new `WeightedRandomSampler`.
194    #[must_use]
195    pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
196        let cumulative = Self::build_cumulative(&weights);
197        Self {
198            weights,
199            cumulative,
200            num_samples,
201            replacement,
202        }
203    }
204
205    /// Builds cumulative sum array from weights.
206    fn build_cumulative(weights: &[f64]) -> Vec<f64> {
207        let mut cum = Vec::with_capacity(weights.len());
208        let mut total = 0.0;
209        for &w in weights {
210            total += w;
211            cum.push(total);
212        }
213        cum
214    }
215
216    /// Samples an index based on weights using binary search on cumulative sums.
217    fn sample_index(&self) -> usize {
218        let total = *self.cumulative.last().unwrap_or(&1.0);
219        let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
220
221        // Binary search: find first index where cumulative > threshold
222        match self.cumulative.binary_search_by(|c| c.partial_cmp(&threshold).unwrap()) {
223            Ok(i) => i,
224            Err(i) => i.min(self.cumulative.len() - 1),
225        }
226    }
227}
228
229impl Sampler for WeightedRandomSampler {
230    fn len(&self) -> usize {
231        self.num_samples
232    }
233
234    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
235        if self.replacement {
236            Box::new(WeightedIter::new(self))
237        } else {
238            // Without replacement: use swap-remove for O(1) removal instead of O(n)
239            let mut indices = Vec::with_capacity(self.num_samples);
240            let mut available: Vec<usize> = (0..self.weights.len()).collect();
241            let mut weights = self.weights.clone();
242            let mut cumulative = self.cumulative.clone();
243
244            while indices.len() < self.num_samples && !available.is_empty() {
245                let total = *cumulative.last().unwrap_or(&0.0);
246                if total <= 0.0 {
247                    break;
248                }
249
250                let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
251                let selected = match cumulative.binary_search_by(|c| {
252                    c.partial_cmp(&threshold).unwrap()
253                }) {
254                    Ok(i) => i,
255                    Err(i) => i.min(cumulative.len() - 1),
256                };
257
258                indices.push(available[selected]);
259
260                // O(1) swap-remove instead of O(n) remove
261                available.swap_remove(selected);
262                weights.swap_remove(selected);
263
264                // Rebuild cumulative (needed after swap_remove reorders)
265                cumulative = Self::build_cumulative(&weights);
266            }
267
268            Box::new(indices.into_iter())
269        }
270    }
271}
272
273/// Iterator for weighted random sampling with replacement.
274struct WeightedIter<'a> {
275    sampler: &'a WeightedRandomSampler,
276    remaining: usize,
277}
278
279impl<'a> WeightedIter<'a> {
280    fn new(sampler: &'a WeightedRandomSampler) -> Self {
281        Self {
282            sampler,
283            remaining: sampler.num_samples,
284        }
285    }
286}
287
288impl Iterator for WeightedIter<'_> {
289    type Item = usize;
290
291    fn next(&mut self) -> Option<Self::Item> {
292        if self.remaining == 0 {
293            return None;
294        }
295        self.remaining -= 1;
296        Some(self.sampler.sample_index())
297    }
298}
299
300// =============================================================================
301// BatchSampler
302// =============================================================================
303
304/// Wraps a sampler to yield batches of indices.
305pub struct BatchSampler<S: Sampler> {
306    sampler: S,
307    batch_size: usize,
308    drop_last: bool,
309}
310
311impl<S: Sampler> BatchSampler<S> {
312    /// Creates a new `BatchSampler`.
313    pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
314        Self {
315            sampler,
316            batch_size,
317            drop_last,
318        }
319    }
320
321    /// Creates an iterator over batches of indices.
322    pub fn iter(&self) -> BatchIter {
323        let indices: Vec<usize> = self.sampler.iter().collect();
324        BatchIter {
325            indices,
326            batch_size: self.batch_size,
327            drop_last: self.drop_last,
328            position: 0,
329        }
330    }
331
332    /// Returns the number of batches.
333    pub fn len(&self) -> usize {
334        let total = self.sampler.len();
335        if self.drop_last {
336            total / self.batch_size
337        } else {
338            total.div_ceil(self.batch_size)
339        }
340    }
341
342    /// Returns true if empty.
343    pub fn is_empty(&self) -> bool {
344        self.len() == 0
345    }
346}
347
/// Iterator over batches of indices produced by [`BatchSampler::iter`].
///
/// Every batch holds `batch_size` indices except possibly the last one, which
/// is shorter, or skipped when `drop_last` is set.
pub struct BatchIter {
    indices: Vec<usize>,
    batch_size: usize,
    drop_last: bool,
    position: usize,
}

impl BatchIter {
    /// Number of batches still to be yielded.
    fn remaining_batches(&self) -> usize {
        let remaining = self.indices.len().saturating_sub(self.position);
        match (self.batch_size, self.drop_last) {
            // A zero batch size never yields anything (see `next`).
            (0, _) => 0,
            (size, true) => remaining / size,
            (size, false) => remaining.div_ceil(size),
        }
    }
}

impl Iterator for BatchIter {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        // With a zero batch size `position` would never advance and the
        // iterator would yield empty batches forever; stop instead.
        if self.batch_size == 0 {
            return None;
        }
        let remaining = self.indices.len().saturating_sub(self.position);
        if remaining == 0 {
            return None;
        }
        let take = remaining.min(self.batch_size);
        // Check before allocating: a short trailing batch is discarded
        // under `drop_last`.
        if take < self.batch_size && self.drop_last {
            return None;
        }
        let end = self.position + take;
        let batch = self.indices[self.position..end].to_vec();
        self.position = end;
        Some(batch)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let n = self.remaining_batches();
        (n, Some(n))
    }
}

impl ExactSizeIterator for BatchIter {}
375
376// =============================================================================
377// Tests
378// =============================================================================
379
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sequential_sampler() {
        let sampler = SequentialSampler::new(5);
        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_sequential_sampler_empty() {
        let sampler = SequentialSampler::new(0);
        assert!(sampler.is_empty());
        assert_eq!(sampler.iter().count(), 0);
    }

    #[test]
    fn test_random_sampler() {
        let sampler = RandomSampler::new(10);
        let indices: Vec<usize> = sampler.iter().collect();

        assert_eq!(indices.len(), 10);
        // All indices should be unique (no replacement)
        let mut sorted = indices.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(sorted.len(), 10);
    }

    #[test]
    fn test_random_sampler_with_replacement() {
        let sampler = RandomSampler::with_replacement(5, 20);
        let indices: Vec<usize> = sampler.iter().collect();

        assert_eq!(indices.len(), 20);
        // All indices should be in valid range
        assert!(indices.iter().all(|&i| i < 5));
    }

    #[test]
    fn test_subset_random_sampler() {
        let sampler = SubsetRandomSampler::new(vec![0, 5, 10, 15]);
        let indices: Vec<usize> = sampler.iter().collect();

        assert_eq!(indices.len(), 4);
        // All returned indices should be from the subset
        let mut sorted = indices.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 5, 10, 15]);
    }

    #[test]
    fn test_weighted_random_sampler() {
        // Heavy weight on index 0
        let sampler = WeightedRandomSampler::new(vec![100.0, 1.0, 1.0, 1.0], 100, true);
        let indices: Vec<usize> = sampler.iter().collect();

        assert_eq!(indices.len(), 100);
        // Most samples should be index 0
        let zeros = indices.iter().filter(|&&i| i == 0).count();
        assert!(zeros > 50, "Expected mostly zeros, got {zeros}");
    }

    #[test]
    fn test_weighted_random_sampler_without_replacement() {
        // Drawing every index without replacement must yield each exactly once.
        let sampler = WeightedRandomSampler::new(vec![1.0, 2.0, 3.0, 4.0], 4, false);
        let mut indices: Vec<usize> = sampler.iter().collect();
        indices.sort_unstable();
        assert_eq!(indices, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_weighted_without_replacement_skips_zero_weights() {
        // Once the positive-weight indices are exhausted, sampling stops early.
        let sampler = WeightedRandomSampler::new(vec![1.0, 0.0, 2.0], 3, false);
        let mut indices: Vec<usize> = sampler.iter().collect();
        indices.sort_unstable();
        assert_eq!(indices, vec![0, 2]);
    }

    #[test]
    fn test_batch_sampler() {
        let base = SequentialSampler::new(10);
        let sampler = BatchSampler::new(base, 3, false);

        let batches: Vec<Vec<usize>> = sampler.iter().collect();
        assert_eq!(batches.len(), 4); // 10 / 3 = 3 full + 1 partial

        assert_eq!(batches[0], vec![0, 1, 2]);
        assert_eq!(batches[1], vec![3, 4, 5]);
        assert_eq!(batches[2], vec![6, 7, 8]);
        assert_eq!(batches[3], vec![9]); // Partial batch
    }

    #[test]
    fn test_batch_sampler_drop_last() {
        let base = SequentialSampler::new(10);
        let sampler = BatchSampler::new(base, 3, true);

        let batches: Vec<Vec<usize>> = sampler.iter().collect();
        assert_eq!(batches.len(), 3); // 10 / 3 = 3, drop the partial

        assert_eq!(batches[0], vec![0, 1, 2]);
        assert_eq!(batches[1], vec![3, 4, 5]);
        assert_eq!(batches[2], vec![6, 7, 8]);
    }

    #[test]
    fn test_batch_sampler_len_matches_iter() {
        // `len()` must agree with the number of batches actually produced.
        for drop_last in [false, true] {
            let sampler = BatchSampler::new(SequentialSampler::new(10), 4, drop_last);
            assert_eq!(sampler.len(), sampler.iter().count());
        }
    }
}