ai_dataloader/indexable/sampler/sequential_sampler.rs

//! Yields indices from zero (inclusive) to `data_source_len` (exclusive) in ascending order.
2
3use std::ops::Range;
4
5use super::{Len, Sampler};
6
/// Sampler that visits every index of a data source exactly once, in order.
///
/// Iterating over it yields `0, 1, ..., data_source_len - 1`. The upper bound
/// is exclusive, so a data source of length zero yields no indices.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct SequentialSampler {
    /// Number of items in the data source being sampled.
    pub data_source_len: usize,
}
13impl Sampler for SequentialSampler {
14    fn new(data_source_len: usize) -> Self {
15        Self { data_source_len }
16    }
17}
18
19impl Len for SequentialSampler {
20    fn len(&self) -> usize {
21        self.data_source_len
22    }
23}
24impl IntoIterator for SequentialSampler {
25    type Item = usize;
26    type IntoIter = Range<usize>;
27    fn into_iter(self) -> Self::IntoIter {
28        0..self.data_source_len
29    }
30}
31
#[cfg(test)]
mod tests {
    use super::*;

    /// A three-element dataset yields indices 0, 1, 2 in order, then stops.
    #[test]
    fn sequential_sampler() {
        let dataset = [1, 2, 3];
        let sampler = SequentialSampler {
            data_source_len: dataset.len(),
        };
        let mut iter = sampler.into_iter();
        assert_eq!(iter.next(), Some(0));
        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));
        assert_eq!(iter.next(), None);
    }

    /// An empty data source reports length zero and yields no indices.
    #[test]
    fn sequential_sampler_empty() {
        let sampler = SequentialSampler::new(0);
        assert_eq!(sampler.len(), 0);
        assert_eq!(sampler.into_iter().next(), None);
    }

    /// `Sampler::new` stores the length, and `Len::len` matches the number
    /// of indices actually produced by iteration.
    #[test]
    fn new_and_len_match_iteration() {
        let sampler = SequentialSampler::new(5);
        assert_eq!(sampler.len(), 5);
        // `SequentialSampler` is `Copy`, so it is still usable after `len()`.
        let indices: Vec<usize> = sampler.into_iter().collect();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
        assert_eq!(indices.len(), sampler.len());
    }
}