use std::sync::Mutex;
use burn::data::dataloader::batcher::Batcher;
use burn::data::dataloader::{DataLoaderIterator, Progress};
use burn::prelude::Backend;
use rand::SeedableRng;
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset};
/// A dataset whose elements are already-collated batches rather than single items.
///
/// The batches are built once by `BatchTensorDataset::new` and afterwards only
/// looked up by position and cloned.
#[derive(Clone)]
pub(crate) struct BatchTensorDataset<B: Backend> {
/// Pre-built batches, stored in the order the source items were chunked.
dataset: Vec<FSRSBatch<B>>,
}
impl<B: Backend> BatchTensorDataset<B> {
pub fn new(dataset: FSRSDataset, batch_size: usize, device: B::Device) -> Self {
let batcher = FSRSBatcher::<B>::new(device.clone());
let dataset = dataset
.items
.chunks(batch_size)
.map(|items| batcher.batch(items.to_vec(), &device))
.collect();
Self { dataset }
}
}
impl<B: Backend> BatchTensorDataset<B> {
    /// Returns a clone of the batch stored at `index`, or `None` when `index`
    /// is out of range.
    fn get(&self, index: usize) -> Option<FSRSBatch<B>> {
        let batch = self.dataset.get(index)?;
        Some(batch.clone())
    }

    /// Returns the number of stored batches (not the number of underlying items).
    fn len(&self) -> usize {
        let Self { dataset } = self;
        dataset.len()
    }
}
/// Data loader that serves pre-built batches in a newly shuffled order for
/// every iterator it creates.
pub struct ShuffleDataLoader<B: Backend> {
/// Batches to be served.
dataset: BatchTensorDataset<B>,
/// Seeded RNG that drives the shuffling; held in a `Mutex` so that it can be
/// advanced through a shared `&self` reference.
rng: Mutex<StdRng>,
}
impl<B: Backend> ShuffleDataLoader<B> {
    /// Creates a loader over `dataset` whose shuffling RNG is initialised
    /// from `seed`.
    pub fn new(dataset: BatchTensorDataset<B>, seed: u64) -> Self {
        let rng = Mutex::new(StdRng::seed_from_u64(seed));
        Self { dataset, rng }
    }
}
/// Iterator that yields batches of a `BatchTensorDataset` in the order given
/// by a list of batch indices.
pub(crate) struct ShuffleDataLoaderIterator<B: Backend> {
/// Position in `indices` of the next batch to yield.
current_index: usize,
/// Batch indices in the order they will be visited.
indices: Vec<usize>,
/// Batches that `indices` refers to.
dataset: BatchTensorDataset<B>,
}
impl<B: Backend> ShuffleDataLoaderIterator<B> {
    /// Creates an iterator that visits `dataset` in the order given by
    /// `indices`, starting from the first entry.
    pub(crate) fn new(dataset: BatchTensorDataset<B>, indices: Vec<usize>) -> Self {
        Self {
            dataset,
            indices,
            current_index: 0,
        }
    }

    /// Reports how many batches have been consumed so far out of the number
    /// of batches held by the dataset.
    pub(crate) fn progress(&self) -> Progress {
        let items_processed = self.current_index;
        let items_total = self.dataset.len();
        Progress {
            items_processed,
            items_total,
        }
    }
}
impl<B: Backend> Iterator for ShuffleDataLoaderIterator<B> {
    type Item = FSRSBatch<B>;

    /// Yields the batch referenced by the next entry of the index list.
    ///
    /// Returns `None` once every index has been visited. If an index does not
    /// exist in the dataset, the cursor still advances and `None` is returned
    /// for that step.
    fn next(&mut self) -> Option<Self::Item> {
        let batch_index = *self.indices.get(self.current_index)?;
        self.current_index += 1;
        self.dataset.get(batch_index)
    }
}
impl<B: Backend> DataLoaderIterator<FSRSBatch<B>> for ShuffleDataLoaderIterator<B> {
    /// Reports the number of batches consumed so far and the number of
    /// batches held by the dataset.
    fn progress(&self) -> Progress {
        let processed = self.current_index;
        let total = self.dataset.len();
        Progress::new(processed, total)
    }
}
impl<B: Backend> ShuffleDataLoader<B> {
    /// Starts a new pass over the batches in a freshly shuffled order.
    ///
    /// Every call advances the loader's RNG, so consecutive calls on the same
    /// loader generally produce different orders. The returned iterator owns
    /// a clone of the batch dataset.
    ///
    /// # Panics
    ///
    /// Panics if the RNG mutex is poisoned.
    pub(crate) fn iter(&self) -> ShuffleDataLoaderIterator<B> {
        let batch_count = self.dataset.len();
        let mut order = (0..batch_count).collect::<Vec<usize>>();
        {
            // Keep the guard only for the duration of the shuffle.
            let mut rng = self.rng.lock().unwrap();
            order.shuffle(&mut *rng);
        }
        ShuffleDataLoaderIterator::new(self.dataset.clone(), order)
    }
}
// Regression tests for the shuffling data loader. The expected values below are
// tied to the sample data set, the batch size and the seed used in the test.
#[cfg(test)]
mod tests {
use burn::{
backend::{NdArray, ndarray::NdArrayDevice},
tensor::Shape,
};
use itertools::Itertools;
use super::*;
use crate::{
convertor_tests::anki21_sample_file_converted_to_fsrs,
dataset::{constant_weighted_fsrs_items, prepare_training_data},
};
// Checks the shapes of the batches produced by two consecutive passes over the
// same loader, which also pins the shuffle order obtained from the fixed seed.
#[test]
fn test_simple_dataloader() {
// Sort the sample items by their number of reviews before batching.
let train_set = anki21_sample_file_converted_to_fsrs()
.into_iter()
.sorted_by_cached_key(|item| item.reviews.len())
.collect();
let (_pre_train_set, train_set) = prepare_training_data(train_set);
let dataset = FSRSDataset::from(constant_weighted_fsrs_items(train_set));
let batch_size = 512;
let seed = 114514;
let device = NdArrayDevice::Cpu;
type Backend = NdArray<f32>;
let dataset = BatchTensorDataset::<Backend>::new(dataset, batch_size, device);
let dataloader = ShuffleDataLoader::new(dataset, seed);
// First pass: the first two batches are checked individually.
let mut iterator = dataloader.iter();
let batch = iterator.next().unwrap();
assert_eq!(
batch.t_historys.shape(),
Shape {
dims: vec![5, batch_size]
}
);
let batch = iterator.next().unwrap();
assert_eq!(
batch.t_historys.shape(),
Shape {
dims: vec![3, batch_size]
}
);
// For the remaining batches only the first dimension of `t_historys` is compared
// (presumably the review-history length of the batch; confirm against FSRSBatcher).
let lengths = iterator
.map(|batch| batch.t_historys.shape().dims[0])
.collect::<Vec<_>>();
assert_eq!(
lengths,
[
3, 5, 19, 7, 2, 4, 4, 3, 6, 13, 4, 4, 7, 4, 6, 48, 11, 8, 9, 1, 2, 5, 3, 5, 6, 3
]
);
// Second pass over the same loader: the shared RNG has advanced, so a different
// order is expected.
let mut iterator = dataloader.iter();
let batch = iterator.next().unwrap();
assert_eq!(
batch.t_historys.shape(),
Shape {
dims: vec![4, batch_size]
}
);
let batch = iterator.next().unwrap();
assert_eq!(
batch.t_historys.shape(),
Shape {
dims: vec![2, batch_size]
}
);
let lengths = iterator
.map(|batch| batch.t_historys.shape().dims[0])
.collect::<Vec<_>>();
assert_eq!(
lengths,
[
11, 4, 5, 3, 1, 3, 13, 5, 4, 6, 2, 6, 19, 6, 3, 7, 4, 3, 48, 9, 5, 8, 5, 4, 3, 7
]
);
}
}