Skip to main content

axonml_data/
sampler.rs

1//! Samplers - Data Access Patterns
2//!
3//! # File
4//! `crates/axonml-data/src/sampler.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr. — AutomataNexus LLC
8//! ORCID: 0009-0005-2158-7060
9//!
10//! # Updated
11//! April 14, 2026 11:15 PM EST
12//!
13//! # Disclaimer
14//! Use at own risk. This software is provided "as is", without warranty of any
15//! kind, express or implied. The author and AutomataNexus shall not be held
16//! liable for any damages arising from the use of this software.
17
18use rand::Rng;
19use rand::seq::SliceRandom;
20
21// =============================================================================
22// Sampler Trait
23// =============================================================================
24
/// Common interface shared by every sampler.
///
/// A sampler produces the indices that determine in which order data is
/// accessed.
pub trait Sampler: Send + Sync {
    /// Reports how many samples this sampler provides.
    fn len(&self) -> usize;

    /// Reports whether the sampler provides no samples at all.
    ///
    /// The default implementation is derived from [`Sampler::len`].
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Returns a boxed iterator over the sampled indices.
    ///
    /// The iterator may borrow from the sampler, hence the `'_` bound.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_>;
}
40
41// =============================================================================
42// SequentialSampler
43// =============================================================================
44
/// Samples elements sequentially.
///
/// Derives `Debug` and `Clone` so the sampler can be logged and duplicated
/// like any other plain configuration value.
#[derive(Debug, Clone)]
pub struct SequentialSampler {
    /// Number of indices handed out by one pass.
    len: usize,
}

impl SequentialSampler {
    /// Creates a new `SequentialSampler` covering `len` elements.
    #[must_use]
    pub fn new(len: usize) -> Self {
        Self { len }
    }
}
57
58impl Sampler for SequentialSampler {
59    fn len(&self) -> usize {
60        self.len
61    }
62
63    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
64        Box::new(0..self.len)
65    }
66}
67
68// =============================================================================
69// RandomSampler
70// =============================================================================
71
/// Samples elements randomly.
///
/// Two modes exist: a shuffled pass without replacement ([`RandomSampler::new`])
/// and independent draws with replacement ([`RandomSampler::with_replacement`]).
///
/// Derives `Debug` and `Clone` so the sampler can be logged and duplicated.
#[derive(Debug, Clone)]
pub struct RandomSampler {
    /// Size of the index range `0..len` that is sampled from.
    len: usize,
    /// Whether the same index may be drawn more than once.
    replacement: bool,
    /// Number of draws per pass; `None` means "one per element".
    num_samples: Option<usize>,
}

impl RandomSampler {
    /// Creates a new `RandomSampler` without replacement.
    #[must_use]
    pub fn new(len: usize) -> Self {
        Self {
            len,
            replacement: false,
            num_samples: None,
        }
    }

    /// Creates a `RandomSampler` with replacement that draws `num_samples`
    /// indices from `0..len` per pass.
    #[must_use]
    pub fn with_replacement(len: usize, num_samples: usize) -> Self {
        Self {
            len,
            replacement: true,
            num_samples: Some(num_samples),
        }
    }
}
100
101impl Sampler for RandomSampler {
102    fn len(&self) -> usize {
103        self.num_samples.unwrap_or(self.len)
104    }
105
106    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
107        if self.replacement {
108            // With replacement: random sampling
109            let len = self.len;
110            let num = self.num_samples.unwrap_or(len);
111            Box::new(RandomWithReplacementIter::new(len, num))
112        } else {
113            // Without replacement: shuffled indices
114            let mut indices: Vec<usize> = (0..self.len).collect();
115            indices.shuffle(&mut rand::thread_rng());
116            Box::new(indices.into_iter())
117        }
118    }
119}
120
/// Iterator that hands out random indices, allowing repeats.
struct RandomWithReplacementIter {
    /// Exclusive upper bound of the indices that are drawn.
    len: usize,
    /// Number of draws that are still outstanding.
    remaining: usize,
}

impl RandomWithReplacementIter {
    /// Builds an iterator that performs `num_samples` draws from `0..len`.
    fn new(len: usize, num_samples: usize) -> Self {
        let remaining = num_samples;
        Self { len, remaining }
    }
}
135
136impl Iterator for RandomWithReplacementIter {
137    type Item = usize;
138
139    fn next(&mut self) -> Option<Self::Item> {
140        if self.remaining == 0 {
141            return None;
142        }
143        self.remaining -= 1;
144        Some(rand::thread_rng().gen_range(0..self.len))
145    }
146}
147
148// =============================================================================
149// SubsetRandomSampler
150// =============================================================================
151
/// Samples randomly from a subset of indices.
///
/// Derives `Debug` and `Clone` so the sampler can be logged and duplicated.
#[derive(Debug, Clone)]
pub struct SubsetRandomSampler {
    /// The indices that may be yielded; stored exactly as supplied.
    indices: Vec<usize>,
}

impl SubsetRandomSampler {
    /// Creates a new `SubsetRandomSampler` over the given `indices`.
    #[must_use]
    pub fn new(indices: Vec<usize>) -> Self {
        Self { indices }
    }
}
164
165impl Sampler for SubsetRandomSampler {
166    fn len(&self) -> usize {
167        self.indices.len()
168    }
169
170    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
171        let mut shuffled = self.indices.clone();
172        shuffled.shuffle(&mut rand::thread_rng());
173        Box::new(shuffled.into_iter())
174    }
175}
176
177// =============================================================================
178// WeightedRandomSampler
179// =============================================================================
180
181/// Samples elements with specified weights.
182///
183/// Uses precomputed cumulative sums with binary search for O(log n) per sample
184/// instead of O(n) linear scan.
185pub struct WeightedRandomSampler {
186    weights: Vec<f64>,
187    /// Precomputed cumulative sum for O(log n) sampling.
188    cumulative: Vec<f64>,
189    num_samples: usize,
190    replacement: bool,
191}
192
193impl WeightedRandomSampler {
194    /// Creates a new `WeightedRandomSampler`.
195    #[must_use]
196    pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
197        let cumulative = Self::build_cumulative(&weights);
198        Self {
199            weights,
200            cumulative,
201            num_samples,
202            replacement,
203        }
204    }
205
206    /// Builds cumulative sum array from weights.
207    fn build_cumulative(weights: &[f64]) -> Vec<f64> {
208        let mut cum = Vec::with_capacity(weights.len());
209        let mut total = 0.0;
210        for &w in weights {
211            total += w;
212            cum.push(total);
213        }
214        cum
215    }
216
217    /// Samples an index based on weights using binary search on cumulative sums.
218    fn sample_index(&self) -> usize {
219        let total = *self.cumulative.last().unwrap_or(&1.0);
220        let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
221
222        // Binary search: find first index where cumulative > threshold
223        match self
224            .cumulative
225            .binary_search_by(|c| c.partial_cmp(&threshold).unwrap())
226        {
227            Ok(i) => i,
228            Err(i) => i.min(self.cumulative.len() - 1),
229        }
230    }
231}
232
233impl Sampler for WeightedRandomSampler {
234    fn len(&self) -> usize {
235        self.num_samples
236    }
237
238    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
239        if self.replacement {
240            Box::new(WeightedIter::new(self))
241        } else {
242            // Without replacement: use swap-remove for O(1) removal instead of O(n)
243            let mut indices = Vec::with_capacity(self.num_samples);
244            let mut available: Vec<usize> = (0..self.weights.len()).collect();
245            let mut weights = self.weights.clone();
246            let mut cumulative = self.cumulative.clone();
247
248            while indices.len() < self.num_samples && !available.is_empty() {
249                let total = *cumulative.last().unwrap_or(&0.0);
250                if total <= 0.0 {
251                    break;
252                }
253
254                let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
255                let selected =
256                    match cumulative.binary_search_by(|c| c.partial_cmp(&threshold).unwrap()) {
257                        Ok(i) => i,
258                        Err(i) => i.min(cumulative.len() - 1),
259                    };
260
261                indices.push(available[selected]);
262
263                // O(1) swap-remove instead of O(n) remove
264                available.swap_remove(selected);
265                weights.swap_remove(selected);
266
267                // Rebuild cumulative (needed after swap_remove reorders)
268                cumulative = Self::build_cumulative(&weights);
269            }
270
271            Box::new(indices.into_iter())
272        }
273    }
274}
275
276/// Iterator for weighted random sampling with replacement.
277struct WeightedIter<'a> {
278    sampler: &'a WeightedRandomSampler,
279    remaining: usize,
280}
281
282impl<'a> WeightedIter<'a> {
283    fn new(sampler: &'a WeightedRandomSampler) -> Self {
284        Self {
285            sampler,
286            remaining: sampler.num_samples,
287        }
288    }
289}
290
291impl Iterator for WeightedIter<'_> {
292    type Item = usize;
293
294    fn next(&mut self) -> Option<Self::Item> {
295        if self.remaining == 0 {
296            return None;
297        }
298        self.remaining -= 1;
299        Some(self.sampler.sample_index())
300    }
301}
302
303// =============================================================================
304// BatchSampler
305// =============================================================================
306
307/// Wraps a sampler to yield batches of indices.
308pub struct BatchSampler<S: Sampler> {
309    sampler: S,
310    batch_size: usize,
311    drop_last: bool,
312}
313
314impl<S: Sampler> BatchSampler<S> {
315    /// Creates a new `BatchSampler`.
316    pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
317        Self {
318            sampler,
319            batch_size,
320            drop_last,
321        }
322    }
323
324    /// Creates an iterator over batches of indices.
325    pub fn iter(&self) -> BatchIter {
326        let indices: Vec<usize> = self.sampler.iter().collect();
327        BatchIter {
328            indices,
329            batch_size: self.batch_size,
330            drop_last: self.drop_last,
331            position: 0,
332        }
333    }
334
335    /// Returns the number of batches.
336    pub fn len(&self) -> usize {
337        let total = self.sampler.len();
338        if self.drop_last {
339            total / self.batch_size
340        } else {
341            total.div_ceil(self.batch_size)
342        }
343    }
344
345    /// Returns true if empty.
346    pub fn is_empty(&self) -> bool {
347        self.len() == 0
348    }
349}
350
/// Iterator over batches of indices.
///
/// Owns the complete list of indices and hands it out in consecutive
/// chunks of `batch_size`.
#[derive(Debug, Clone)]
pub struct BatchIter {
    /// All indices of the pass, in the order they will be batched.
    indices: Vec<usize>,
    /// Number of indices per batch.
    batch_size: usize,
    /// Whether a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
    /// Offset into `indices` of the first index not yet handed out.
    position: usize,
}

impl Iterator for BatchIter {
    type Item = Vec<usize>;

    /// Returns the next batch, or `None` when the indices are used up.
    ///
    /// A trailing batch shorter than `batch_size` is returned only when
    /// `drop_last` is false. A `batch_size` of zero yields nothing: such a
    /// batch could never advance `position`, so the iteration would
    /// otherwise produce empty batches forever.
    fn next(&mut self) -> Option<Self::Item> {
        if self.batch_size == 0 {
            return None;
        }

        let remaining = self.indices.len().saturating_sub(self.position);
        if remaining == 0 {
            return None;
        }

        let take = remaining.min(self.batch_size);
        // Decide before copying, so a dropped partial batch costs no allocation.
        if take < self.batch_size && self.drop_last {
            return None;
        }

        let start = self.position;
        self.position += take;
        Some(self.indices[start..self.position].to_vec())
    }

    /// Reports the exact number of batches still to come.
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.batch_size == 0 {
            return (0, Some(0));
        }
        let remaining = self.indices.len().saturating_sub(self.position);
        let batches = if self.drop_last {
            remaining / self.batch_size
        } else {
            remaining.div_ceil(self.batch_size)
        };
        (batches, Some(batches))
    }
}
378
379// =============================================================================
380// Tests
381// =============================================================================
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Sequential sampling visits every index once, in ascending order.
    #[test]
    fn test_sequential_sampler() {
        let sampler = SequentialSampler::new(5);
        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn, vec![0, 1, 2, 3, 4]);
    }

    /// Sampling without replacement yields every index exactly once.
    #[test]
    fn test_random_sampler() {
        let sampler = RandomSampler::new(10);
        let mut drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 10);

        // After removing duplicates the count must be unchanged.
        drawn.sort_unstable();
        drawn.dedup();
        assert_eq!(drawn.len(), 10);
    }

    /// Sampling with replacement yields the requested number of in-range indices.
    #[test]
    fn test_random_sampler_with_replacement() {
        let sampler = RandomSampler::with_replacement(5, 20);
        let drawn: Vec<usize> = sampler.iter().collect();

        assert_eq!(drawn.len(), 20);
        // Every drawn index has to lie inside `0..5`.
        assert!(drawn.iter().all(|&idx| idx < 5));
    }

    /// The subset sampler returns a permutation of exactly the given indices.
    #[test]
    fn test_subset_random_sampler() {
        let sampler = SubsetRandomSampler::new(vec![0, 5, 10, 15]);
        let mut drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 4);

        // Sorting undoes the shuffle, so the subset itself must reappear.
        drawn.sort_unstable();
        assert_eq!(drawn, vec![0, 5, 10, 15]);
    }

    /// A dominant weight dominates the drawn indices.
    #[test]
    fn test_weighted_random_sampler() {
        // Index 0 carries almost all of the weight.
        let sampler = WeightedRandomSampler::new(vec![100.0, 1.0, 1.0, 1.0], 100, true);
        let drawn: Vec<usize> = sampler.iter().collect();
        assert_eq!(drawn.len(), 100);

        // So index 0 should make up the majority of the draws.
        let zeros = drawn.iter().filter(|&&idx| idx == 0).count();
        assert!(zeros > 50, "Expected mostly zeros, got {zeros}");
    }

    /// Without `drop_last` the trailing partial batch is kept.
    #[test]
    fn test_batch_sampler() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, false);
        let batches: Vec<Vec<usize>> = sampler.iter().collect();

        // 10 / 3 = 3 full batches + 1 partial batch.
        assert_eq!(batches.len(), 4);
        let expected: Vec<Vec<usize>> =
            vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8], vec![9]];
        assert_eq!(batches, expected);
    }

    /// With `drop_last` the trailing partial batch is discarded.
    #[test]
    fn test_batch_sampler_drop_last() {
        let sampler = BatchSampler::new(SequentialSampler::new(10), 3, true);
        let batches: Vec<Vec<usize>> = sampler.iter().collect();

        // 10 / 3 = 3 full batches; the partial one is dropped.
        assert_eq!(batches.len(), 3);
        let expected: Vec<Vec<usize>> = vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]];
        assert_eq!(batches, expected);
    }
}