ai-dataloader 0.6.2

Rust implementation of the PyTorch DataLoader
Documentation
//! Yields indices from zero up to, but not including, `data_source_len`, in ascending order.

use std::ops::Range;

use super::{Len, Sampler};

/// A sampler that yields every index of a dataset exactly once, in ascending order.
///
/// Iterating the sampler produces `0, 1, ..., data_source_len - 1`, i.e. the half-open
/// range `0..data_source_len`. A length of zero yields no indices at all.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Hash, Eq, Ord)]
pub struct SequentialSampler {
    /// The length of the dataset that will be sampled. This is also the exclusive
    /// upper bound of the yielded indices.
    pub data_source_len: usize,
}
impl Sampler for SequentialSampler {
    fn new(data_source_len: usize) -> Self {
        Self { data_source_len }
    }
}

impl Len for SequentialSampler {
    /// Returns the stored dataset length, which is also the number of indices yielded.
    fn len(&self) -> usize {
        let Self { data_source_len } = *self;
        data_source_len
    }
}
impl IntoIterator for SequentialSampler {
    type IntoIter = Range<usize>;
    type Item = usize;

    /// Consumes the sampler and returns the half-open index range `0..data_source_len`.
    fn into_iter(self) -> Self::IntoIter {
        Range {
            start: 0,
            end: self.data_source_len,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A sampler over a three-element dataset yields `0, 1, 2` and then stops.
    #[test]
    fn sequential_sampler() {
        let dataset = [1, 2, 3];
        let sampler = SequentialSampler {
            data_source_len: dataset.len(),
        };
        let mut iter = sampler.into_iter();
        assert_eq!(iter.next(), Some(0));
        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), Some(2));
        assert_eq!(iter.next(), None);
    }

    /// An empty dataset produces no indices at all.
    #[test]
    fn sequential_sampler_empty() {
        let sampler = SequentialSampler { data_source_len: 0 };
        assert_eq!(sampler.into_iter().next(), None);
    }

    /// `Sampler::new` stores the given length, and `Len::len` matches the number of
    /// indices the sampler actually yields.
    #[test]
    fn sequential_sampler_new_and_len() {
        let sampler = SequentialSampler::new(5);
        assert_eq!(sampler, SequentialSampler { data_source_len: 5 });
        assert_eq!(sampler.len(), 5);
        // `SequentialSampler` is `Copy`, so `into_iter` consumes a copy here.
        assert_eq!(sampler.into_iter().count(), sampler.len());
    }
}