ai_dataloader/indexable/sampler/
sequential_sampler.rs1use std::ops::Range;
4
5use super::{Len, Sampler};
6
/// Sampler that visits every index of a data source in ascending order.
///
/// Iterating the sampler (see its `IntoIterator` impl) yields the half-open
/// range `0..data_source_len`, i.e. no shuffling and no repetition.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Hash, Eq, Ord)]
pub struct SequentialSampler {
    /// Number of elements in the data source being sampled.
    pub data_source_len: usize,
}

13impl Sampler for SequentialSampler {
14 fn new(data_source_len: usize) -> Self {
15 Self { data_source_len }
16 }
17}
18
19impl Len for SequentialSampler {
20 fn len(&self) -> usize {
21 self.data_source_len
22 }
23}
24impl IntoIterator for SequentialSampler {
25 type Item = usize;
26 type IntoIter = Range<usize>;
27 fn into_iter(self) -> Self::IntoIter {
28 0..self.data_source_len
29 }
30}
31
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sequential_sampler() {
        let dataset = [1, 2, 3];
        let sampler = SequentialSampler {
            data_source_len: dataset.len(),
        };
        let mut iter = sampler.into_iter();
        assert_eq!(iter.next(), Some(0));
        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));
        assert_eq!(iter.next(), None);
    }

    /// An empty data source must produce no indices at all.
    #[test]
    fn sequential_sampler_empty() {
        let sampler = SequentialSampler::new(0);
        assert_eq!(sampler.len(), 0);
        assert_eq!(sampler.into_iter().next(), None);
    }

    /// `Sampler::new` and `Len::len` must agree with what iteration yields.
    #[test]
    fn len_matches_yielded_indices() {
        let sampler = SequentialSampler::new(5);
        assert_eq!(sampler.len(), 5);
        assert_eq!(sampler.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3, 4]);
    }
}