Skip to main content

axonml_data/
sampler.rs

//! Samplers - Data Access Patterns
//!
//! # File
//! `crates/axonml-data/src/sampler.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
16
17use rand::Rng;
18use rand::seq::SliceRandom;
19
20// =============================================================================
21// Sampler Trait
22// =============================================================================
23
/// Common interface shared by every sampler in this module.
///
/// A sampler decides which dataset indices are visited, and in what order;
/// it never touches the data itself.
pub trait Sampler: Send + Sync {
    /// Reports how many indices this sampler produces per pass.
    fn len(&self) -> usize;

    /// `true` when [`Sampler::len`] reports zero indices.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Returns a boxed iterator over the sampled indices.
    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_>;
}
39
40// =============================================================================
41// SequentialSampler
42// =============================================================================
43
/// Samples elements sequentially: yields `0, 1, ..., len - 1` in order.
#[derive(Debug, Clone)]
pub struct SequentialSampler {
    /// Number of indices to produce.
    len: usize,
}

impl SequentialSampler {
    /// Creates a new `SequentialSampler` over the indices `0..len`.
    #[must_use]
    pub fn new(len: usize) -> Self {
        Self { len }
    }
}
56
57impl Sampler for SequentialSampler {
58    fn len(&self) -> usize {
59        self.len
60    }
61
62    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
63        Box::new(0..self.len)
64    }
65}
66
67// =============================================================================
68// RandomSampler
69// =============================================================================
70
/// Samples elements randomly.
///
/// Without replacement, every index in `0..len` appears exactly once per pass,
/// in shuffled order. With replacement, `num_samples` indices are drawn
/// independently and uniformly from `0..len`, so repeats are possible.
#[derive(Debug, Clone)]
pub struct RandomSampler {
    /// Size of the index space (`0..len`).
    len: usize,
    /// Whether indices are drawn with replacement.
    replacement: bool,
    /// Number of draws per pass; only set by [`RandomSampler::with_replacement`].
    num_samples: Option<usize>,
}

impl RandomSampler {
    /// Creates a new `RandomSampler` without replacement over `len` indices.
    #[must_use]
    pub fn new(len: usize) -> Self {
        Self {
            len,
            replacement: false,
            num_samples: None,
        }
    }

    /// Creates a `RandomSampler` with replacement that draws `num_samples`
    /// indices from `0..len` per pass.
    ///
    /// Iterating panics if `len` is zero while `num_samples` is non-zero,
    /// because there is no index to draw from the empty range.
    #[must_use]
    pub fn with_replacement(len: usize, num_samples: usize) -> Self {
        Self {
            len,
            replacement: true,
            num_samples: Some(num_samples),
        }
    }
}
99
100impl Sampler for RandomSampler {
101    fn len(&self) -> usize {
102        self.num_samples.unwrap_or(self.len)
103    }
104
105    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
106        if self.replacement {
107            // With replacement: random sampling
108            let len = self.len;
109            let num = self.num_samples.unwrap_or(len);
110            Box::new(RandomWithReplacementIter::new(len, num))
111        } else {
112            // Without replacement: shuffled indices
113            let mut indices: Vec<usize> = (0..self.len).collect();
114            indices.shuffle(&mut rand::thread_rng());
115            Box::new(indices.into_iter())
116        }
117    }
118}
119
120/// Iterator for random sampling with replacement.
121struct RandomWithReplacementIter {
122    len: usize,
123    remaining: usize,
124}
125
126impl RandomWithReplacementIter {
127    fn new(len: usize, num_samples: usize) -> Self {
128        Self {
129            len,
130            remaining: num_samples,
131        }
132    }
133}
134
135impl Iterator for RandomWithReplacementIter {
136    type Item = usize;
137
138    fn next(&mut self) -> Option<Self::Item> {
139        if self.remaining == 0 {
140            return None;
141        }
142        self.remaining -= 1;
143        Some(rand::thread_rng().gen_range(0..self.len))
144    }
145}
146
147// =============================================================================
148// SubsetRandomSampler
149// =============================================================================
150
/// Samples randomly from a subset of indices.
///
/// Each pass yields every stored index exactly once, in shuffled order.
#[derive(Debug, Clone)]
pub struct SubsetRandomSampler {
    /// The indices to draw from; duplicates are yielded as many times as they appear.
    indices: Vec<usize>,
}

impl SubsetRandomSampler {
    /// Creates a new `SubsetRandomSampler` over the given `indices`.
    #[must_use]
    pub fn new(indices: Vec<usize>) -> Self {
        Self { indices }
    }
}
163
164impl Sampler for SubsetRandomSampler {
165    fn len(&self) -> usize {
166        self.indices.len()
167    }
168
169    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
170        let mut shuffled = self.indices.clone();
171        shuffled.shuffle(&mut rand::thread_rng());
172        Box::new(shuffled.into_iter())
173    }
174}
175
176// =============================================================================
177// WeightedRandomSampler
178// =============================================================================
179
180/// Samples elements with specified weights.
181pub struct WeightedRandomSampler {
182    weights: Vec<f64>,
183    num_samples: usize,
184    replacement: bool,
185}
186
187impl WeightedRandomSampler {
188    /// Creates a new `WeightedRandomSampler`.
189    #[must_use]
190    pub fn new(weights: Vec<f64>, num_samples: usize, replacement: bool) -> Self {
191        Self {
192            weights,
193            num_samples,
194            replacement,
195        }
196    }
197
198    /// Samples an index based on weights.
199    fn sample_index(&self) -> usize {
200        let total: f64 = self.weights.iter().sum();
201        let mut cumulative = 0.0;
202        let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
203
204        for (i, &weight) in self.weights.iter().enumerate() {
205            cumulative += weight;
206            if cumulative > threshold {
207                return i;
208            }
209        }
210        self.weights.len() - 1
211    }
212}
213
214impl Sampler for WeightedRandomSampler {
215    fn len(&self) -> usize {
216        self.num_samples
217    }
218
219    fn iter(&self) -> Box<dyn Iterator<Item = usize> + '_> {
220        if self.replacement {
221            Box::new(WeightedIter::new(self))
222        } else {
223            // Without replacement: sample all unique indices
224            let mut indices = Vec::with_capacity(self.num_samples);
225            let mut available: Vec<usize> = (0..self.weights.len()).collect();
226            let mut weights = self.weights.clone();
227
228            while indices.len() < self.num_samples && !available.is_empty() {
229                let total: f64 = weights.iter().sum();
230                if total <= 0.0 {
231                    break;
232                }
233
234                let threshold: f64 = rand::thread_rng().r#gen::<f64>() * total;
235                let mut cumulative = 0.0;
236                let mut selected = 0;
237
238                for (i, &weight) in weights.iter().enumerate() {
239                    cumulative += weight;
240                    if cumulative > threshold {
241                        selected = i;
242                        break;
243                    }
244                }
245
246                indices.push(available[selected]);
247                available.remove(selected);
248                weights.remove(selected);
249            }
250
251            Box::new(indices.into_iter())
252        }
253    }
254}
255
256/// Iterator for weighted random sampling with replacement.
257struct WeightedIter<'a> {
258    sampler: &'a WeightedRandomSampler,
259    remaining: usize,
260}
261
262impl<'a> WeightedIter<'a> {
263    fn new(sampler: &'a WeightedRandomSampler) -> Self {
264        Self {
265            sampler,
266            remaining: sampler.num_samples,
267        }
268    }
269}
270
271impl Iterator for WeightedIter<'_> {
272    type Item = usize;
273
274    fn next(&mut self) -> Option<Self::Item> {
275        if self.remaining == 0 {
276            return None;
277        }
278        self.remaining -= 1;
279        Some(self.sampler.sample_index())
280    }
281}
282
283// =============================================================================
284// BatchSampler
285// =============================================================================
286
287/// Wraps a sampler to yield batches of indices.
288pub struct BatchSampler<S: Sampler> {
289    sampler: S,
290    batch_size: usize,
291    drop_last: bool,
292}
293
294impl<S: Sampler> BatchSampler<S> {
295    /// Creates a new `BatchSampler`.
296    pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
297        Self {
298            sampler,
299            batch_size,
300            drop_last,
301        }
302    }
303
304    /// Creates an iterator over batches of indices.
305    pub fn iter(&self) -> BatchIter {
306        let indices: Vec<usize> = self.sampler.iter().collect();
307        BatchIter {
308            indices,
309            batch_size: self.batch_size,
310            drop_last: self.drop_last,
311            position: 0,
312        }
313    }
314
315    /// Returns the number of batches.
316    pub fn len(&self) -> usize {
317        let total = self.sampler.len();
318        if self.drop_last {
319            total / self.batch_size
320        } else {
321            total.div_ceil(self.batch_size)
322        }
323    }
324
325    /// Returns true if empty.
326    pub fn is_empty(&self) -> bool {
327        self.len() == 0
328    }
329}
330
/// Iterator over batches of indices.
///
/// Holds the full index list for one pass and hands it out in consecutive
/// chunks of `batch_size`. A trailing short chunk is yielded only when
/// `drop_last` is `false`.
#[derive(Debug, Clone)]
pub struct BatchIter {
    /// All indices for this pass, in yield order.
    indices: Vec<usize>,
    /// Maximum number of indices per batch.
    batch_size: usize,
    /// Whether a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
    /// Offset into `indices` of the next batch's first element.
    position: usize,
}

impl Iterator for BatchIter {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        // A zero batch size would never advance `position` and would yield
        // empty batches forever; treat it as exhausted instead.
        if self.batch_size == 0 {
            return None;
        }

        let rest = self.indices.get(self.position..).unwrap_or(&[]);
        if rest.is_empty() {
            return None;
        }

        let take = rest.len().min(self.batch_size);
        if take < self.batch_size && self.drop_last {
            // Discard the short tail without allocating it, and stay
            // exhausted on later calls.
            self.position = self.indices.len();
            return None;
        }

        let batch = rest[..take].to_vec();
        self.position += take;
        Some(batch)
    }

    // Exact number of batches still to come: full batches plus one partial
    // batch if a non-empty tail remains and is not being dropped.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let batches = if self.batch_size == 0 {
            0
        } else {
            let left = self.indices.len().saturating_sub(self.position);
            let full = left / self.batch_size;
            let partial = usize::from(left % self.batch_size != 0 && !self.drop_last);
            full + partial
        };
        (batches, Some(batches))
    }
}

impl ExactSizeIterator for BatchIter {}
358
359// =============================================================================
360// Tests
361// =============================================================================
362
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sequential_sampler() {
        let sampler = SequentialSampler::new(5);
        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_sampler_is_empty() {
        assert!(SequentialSampler::new(0).is_empty());
        assert_eq!(SequentialSampler::new(0).iter().count(), 0);
        assert!(!RandomSampler::new(3).is_empty());
    }

    #[test]
    fn test_random_sampler() {
        let sampler = RandomSampler::new(10);
        let indices: Vec<usize> = sampler.iter().collect();

        assert_eq!(indices.len(), 10);
        // All indices should be unique (no replacement)
        let mut sorted = indices.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(sorted.len(), 10);
    }

    #[test]
    fn test_random_sampler_len() {
        assert_eq!(RandomSampler::new(10).len(), 10);
        assert_eq!(RandomSampler::with_replacement(5, 20).len(), 20);
    }

    #[test]
    fn test_random_sampler_with_replacement() {
        let sampler = RandomSampler::with_replacement(5, 20);
        let indices: Vec<usize> = sampler.iter().collect();

        assert_eq!(indices.len(), 20);
        // All indices should be in valid range
        assert!(indices.iter().all(|&i| i < 5));
    }

    #[test]
    fn test_subset_random_sampler() {
        let sampler = SubsetRandomSampler::new(vec![0, 5, 10, 15]);
        let indices: Vec<usize> = sampler.iter().collect();

        assert_eq!(indices.len(), 4);
        // All returned indices should be from the subset
        let mut sorted = indices.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 5, 10, 15]);
    }

    #[test]
    fn test_weighted_random_sampler() {
        // Heavy weight on index 0
        let sampler = WeightedRandomSampler::new(vec![100.0, 1.0, 1.0, 1.0], 100, true);
        let indices: Vec<usize> = sampler.iter().collect();

        assert_eq!(indices.len(), 100);
        // Most samples should be index 0
        let zeros = indices.iter().filter(|&&i| i == 0).count();
        assert!(zeros > 50, "Expected mostly zeros, got {zeros}");
    }

    #[test]
    fn test_weighted_random_sampler_without_replacement() {
        // Every index has positive weight, so all four must appear exactly once.
        let sampler = WeightedRandomSampler::new(vec![1.0, 2.0, 3.0, 4.0], 4, false);
        let mut indices: Vec<usize> = sampler.iter().collect();
        indices.sort_unstable();
        assert_eq!(indices, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_batch_sampler() {
        let base = SequentialSampler::new(10);
        let sampler = BatchSampler::new(base, 3, false);

        let batches: Vec<Vec<usize>> = sampler.iter().collect();
        assert_eq!(batches.len(), 4); // 10 / 3 = 3 full + 1 partial

        assert_eq!(batches[0], vec![0, 1, 2]);
        assert_eq!(batches[1], vec![3, 4, 5]);
        assert_eq!(batches[2], vec![6, 7, 8]);
        assert_eq!(batches[3], vec![9]); // Partial batch
    }

    #[test]
    fn test_batch_sampler_drop_last() {
        let base = SequentialSampler::new(10);
        let sampler = BatchSampler::new(base, 3, true);

        let batches: Vec<Vec<usize>> = sampler.iter().collect();
        assert_eq!(batches.len(), 3); // 10 / 3 = 3, drop the partial

        assert_eq!(batches[0], vec![0, 1, 2]);
        assert_eq!(batches[1], vec![3, 4, 5]);
        assert_eq!(batches[2], vec![6, 7, 8]);
    }

    #[test]
    fn test_batch_sampler_len_matches_iter() {
        for &drop_last in &[false, true] {
            for batch_size in 1..=11 {
                let sampler =
                    BatchSampler::new(SequentialSampler::new(10), batch_size, drop_last);
                assert_eq!(
                    sampler.len(),
                    sampler.iter().count(),
                    "batch_size={batch_size}, drop_last={drop_last}"
                );
            }
        }
    }
}