use rand::{prelude::SliceRandom, thread_rng};
use super::Dataloader;
use crate::pipeline::Node;
/// Test source node over the integers `0..max`.
///
/// Values are handed out front-to-back from `nums_to_make`, and
/// `current_progress` marks how many have been handed out so far.
struct CreateRange {
    /// The values this node emits, in emission order.
    nums_to_make: Vec<usize>,
    /// Count of values already emitted in the current pass (index of the next one).
    current_progress: usize,
}

impl CreateRange {
    /// Creates a node holding `0..max` (exclusive of `max`) in ascending order,
    /// with nothing emitted yet.
    pub fn new(max: usize) -> Self {
        let values: Vec<usize> = (0..max).collect();
        Self {
            current_progress: 0,
            nums_to_make: values,
        }
    }
}
impl Node for CreateRange {
    type Input = ();
    type Output = usize;

    /// Emits the next `input.len()` values (one per unit input) and advances
    /// the cursor by that many.
    ///
    /// # Panics
    ///
    /// Panics if more values are requested than remain in the current pass
    /// (i.e. `input.len() > data_remaining()`).
    fn process(&mut self, input: Vec<Self::Input>) -> Vec<Self::Output> {
        let start = self.current_progress;
        let end = start + input.len();
        // Over-requesting means the caller ignored `data_remaining`; fail with a
        // descriptive message instead of an opaque slice-index panic.
        assert!(
            end <= self.nums_to_make.len(),
            "CreateRange::process: requested {} values but only {} remain",
            input.len(),
            self.nums_to_make.len() - start
        );
        self.current_progress = end;
        self.nums_to_make[start..end].to_vec()
    }

    /// Starts a new pass: shuffles the values with the thread-local RNG and
    /// rewinds the cursor to the beginning.
    fn reset(&mut self) {
        self.nums_to_make.shuffle(&mut thread_rng());
        self.current_progress = 0;
    }

    /// Number of values not yet emitted in the current pass.
    fn data_remaining(&self) -> usize {
        // Cannot underflow: `process` never advances past `nums_to_make.len()`.
        self.nums_to_make.len() - self.current_progress
    }
}
/// End-to-end check of `Dataloader` over a `CreateRange` pipeline that maps
/// each value to `value * 10`.
///
/// It consumes half the data, checks the count the loader reports, then drains
/// the rest. Finally it verifies that every expected value was produced exactly
/// once.
#[test]
fn test_dataloader() {
    const TOTAL: usize = 10_000;
    const HALF: usize = TOTAL / 2;
    // Second argument to `Dataloader::new`; presumably the batch size — confirm.
    const BATCH_SIZE: usize = 10;

    let pipeline = CreateRange::new(TOTAL)
        .add_fn(|i| i.into_iter().map(|i| i * 10).collect());
    let mut loader = Dataloader::new(pipeline, BATCH_SIZE);
    assert_eq!(loader.len(), TOTAL);

    let mut data = Vec::with_capacity(TOTAL);
    // Stop after exactly HALF items. HALF is a multiple of BATCH_SIZE, so the
    // equality check is expected to land on a batch boundary.
    for example in &mut loader {
        data.extend(example);
        if data.len() == HALF {
            break;
        }
    }
    assert_eq!(data.len(), HALF);
    assert_eq!(loader.len(), HALF);

    // Drain the remainder of the pass.
    for example in &mut loader {
        data.extend(example);
    }
    // NOTE(review): after exhaustion the loader reports the full length again,
    // presumably because it resets for the next pass — confirm in Dataloader.
    assert_eq!(loader.len(), TOTAL);

    // Emission order may be shuffled (`CreateRange::reset` shuffles), so sort
    // before comparing against the expected values.
    data.sort_unstable();
    let expected: Vec<usize> = (0..TOTAL).map(|i| i * 10).collect();
    assert_eq!(data, expected);
}