use std::sync::mpsc;
use std::sync::Arc;
use std::thread;
use crate::data::dataset::{DataItem, Dataset};
use crate::tensor;
/// Configurable mini-batch loader over a [`Dataset`].
///
/// Produces collated batches either on the calling thread (`num_workers == 0`)
/// or via a pool of background worker threads fed through bounded channels.
pub struct DataLoader {
    /// Shared handle to the underlying dataset; cloned into worker threads.
    dataset: Arc<dyn Dataset>,
    /// Number of samples per batch; must be > 0 (asserted in `new`).
    batch_size: usize,
    /// When true, the sample indices are reshuffled on every call to `iter`.
    shuffle: bool,
    /// When true, a trailing batch smaller than `batch_size` is discarded.
    drop_last: bool,
    /// Number of background worker threads; 0 means load on the caller's thread.
    num_workers: usize,
    /// Batch-channel capacity multiplier: capacity = prefetch_factor * num_workers.
    prefetch_factor: usize,
}
impl DataLoader {
    /// Creates a new `DataLoader`.
    ///
    /// * `dataset` — shared dataset to draw samples from.
    /// * `batch_size` — samples per batch (must be > 0).
    /// * `shuffle` — reshuffle sample indices at the start of each epoch.
    /// * `drop_last` — discard a trailing partial batch.
    /// * `num_workers` — background threads; 0 loads on the caller's thread.
    /// * `prefetch_factor` — batch-channel capacity multiplier (must be > 0).
    ///
    /// # Panics
    /// Panics if `batch_size == 0` or `prefetch_factor == 0`.
    pub fn new(
        dataset: Arc<dyn Dataset>,
        batch_size: usize,
        shuffle: bool,
        drop_last: bool,
        num_workers: usize,
        prefetch_factor: usize,
    ) -> Self {
        assert!(batch_size > 0, "DataLoader: batch_size must be > 0");
        assert!(prefetch_factor > 0, "DataLoader: prefetch_factor must be > 0");
        Self {
            dataset,
            batch_size,
            shuffle,
            drop_last,
            num_workers,
            prefetch_factor,
        }
    }

    /// Starts one epoch and returns an iterator over collated batches.
    ///
    /// Index order is (optionally) shuffled, then chunked into groups of
    /// `batch_size`. With `num_workers == 0` batches are built lazily on the
    /// calling thread; otherwise a feeder thread distributes index groups to
    /// worker threads over bounded channels.
    ///
    /// NOTE(review): with workers, batches are yielded in worker *completion*
    /// order, which may differ from the (shuffled) index order.
    pub fn iter(&self) -> DataLoaderIter {
        let n = self.dataset.len();
        let mut indices: Vec<usize> = (0..n).collect();
        if self.shuffle {
            fisher_yates_shuffle(&mut indices);
        }
        // One inner Vec per batch; only the final chunk may be short.
        let mut batch_groups: Vec<Vec<usize>> = indices
            .chunks(self.batch_size)
            .map(|c| c.to_vec())
            .collect();
        if self.drop_last {
            if let Some(last) = batch_groups.last() {
                if last.len() < self.batch_size {
                    batch_groups.pop();
                }
            }
        }
        if self.num_workers == 0 {
            DataLoaderIter::SingleThread {
                dataset: Arc::clone(&self.dataset),
                batch_groups,
                cursor: 0,
            }
        } else {
            // Bounded channels provide backpressure: the feeder blocks when
            // workers fall behind, and workers block when the consumer does.
            let (index_tx, index_rx) = mpsc::sync_channel::<Vec<usize>>(self.num_workers);
            let channel_cap = self.prefetch_factor * self.num_workers;
            let (batch_tx, batch_rx) = mpsc::sync_channel::<DataItem>(channel_cap);
            // mpsc receivers are single-consumer, so the one index receiver is
            // shared by all workers behind a mutex; a worker holds the lock
            // only while pulling its next index group.
            let index_rx = Arc::new(parking_lot::Mutex::new(index_rx));
            let mut handles = Vec::with_capacity(self.num_workers);
            for _ in 0..self.num_workers {
                let ds = Arc::clone(&self.dataset);
                let irx = Arc::clone(&index_rx);
                let btx = batch_tx.clone();
                let handle = thread::spawn(move || {
                    loop {
                        let batch_indices = {
                            let rx = irx.lock();
                            match rx.recv() {
                                Ok(indices) => indices,
                                // Feeder finished and dropped its sender.
                                Err(_) => break,
                            }
                        };
                        let items: Vec<DataItem> = batch_indices
                            .iter()
                            .map(|&i| ds.get(i))
                            .collect();
                        let batch = collate(items);
                        // Consumer dropped the receiving end: stop early.
                        if btx.send(batch).is_err() {
                            break;
                        }
                    }
                });
                handles.push(handle);
            }
            // Drop the original sender so the batch channel disconnects as
            // soon as the last worker exits (each worker holds its own clone);
            // otherwise the consumer's recv() would block forever.
            drop(batch_tx);
            let feeder = thread::spawn(move || {
                for group in batch_groups {
                    // All workers exited (receiver gone): stop feeding.
                    if index_tx.send(group).is_err() {
                        break;
                    }
                }
            });
            DataLoaderIter::MultiThread {
                batch_rx: Some(batch_rx),
                feeder: Some(feeder),
                workers: handles,
            }
        }
    }
}
/// Iterator over collated batches, produced by [`DataLoader::iter`].
pub enum DataLoaderIter {
    /// Loads and collates each batch lazily on the thread calling `next`.
    SingleThread {
        /// Dataset handle used to fetch samples by index.
        dataset: Arc<dyn Dataset>,
        /// Pre-computed index groups, one inner Vec per batch.
        batch_groups: Vec<Vec<usize>>,
        /// Position of the next unserved batch within `batch_groups`.
        cursor: usize,
    },
    /// Receives batches prepared by background worker threads.
    MultiThread {
        /// Receiving end of the bounded batch channel; `Option` so `Drop` can
        /// release it first to unblock workers during teardown.
        batch_rx: Option<mpsc::Receiver<DataItem>>,
        /// Thread that feeds index groups to the workers; joined in `Drop`.
        feeder: Option<thread::JoinHandle<()>>,
        /// Worker threads that fetch samples and collate batches.
        workers: Vec<thread::JoinHandle<()>>,
    },
}
impl Iterator for DataLoaderIter {
type Item = DataItem;
fn next(&mut self) -> Option<DataItem> {
match self {
DataLoaderIter::SingleThread {
dataset,
batch_groups,
cursor,
} => {
if *cursor >= batch_groups.len() {
return None;
}
let indices = &batch_groups[*cursor];
*cursor += 1;
let items: Vec<DataItem> = indices.iter().map(|&i| dataset.get(i)).collect();
Some(collate(items))
}
DataLoaderIter::MultiThread { batch_rx, .. } => {
batch_rx.as_ref().and_then(|rx| rx.recv().ok())
}
}
}
}
impl Drop for DataLoaderIter {
    /// Tears down the worker pipeline in an order that cannot deadlock, even
    /// when the consumer abandons the iterator mid-epoch.
    fn drop(&mut self) {
        if let DataLoaderIter::MultiThread {
            batch_rx,
            feeder,
            workers,
        } = self
        {
            // 1. Drop the batch receiver first: any worker blocked sending on
            //    a full batch channel sees it disconnect, its send fails, and
            //    it breaks out of its loop.
            drop(batch_rx.take());
            // 2. Join the workers. When the last worker exits, the shared
            //    index receiver is dropped too, so the feeder's pending send
            //    fails and it also terminates.
            for handle in workers.drain(..) {
                let _ = handle.join();
            }
            // 3. Only then is joining the feeder guaranteed not to block.
            if let Some(f) = feeder.take() {
                let _ = f.join();
            }
        }
    }
}
/// Stacks a batch of individual samples into a single batched `DataItem`.
///
/// Consumes `items`, so the per-sample `input`/`target` tensors are moved out
/// rather than cloned — the previous per-field `.clone()` calls copied every
/// tensor in the batch for no benefit, since the samples are owned here.
fn collate(items: Vec<DataItem>) -> DataItem {
    // Split the owned (input, target) pairs into two parallel vectors in one pass.
    let (inputs, targets): (Vec<_>, Vec<_>) = items
        .into_iter()
        .map(|item| (item.input, item.target))
        .unzip();
    DataItem {
        input: tensor::stack(&inputs),
        target: tensor::stack(&targets),
    }
}
use std::cell::Cell;

thread_local! {
    /// Per-thread RNG state for shuffling; fixed seed makes each thread's
    /// shuffle sequence deterministic, while the state advances across calls
    /// so successive epochs get different permutations.
    static SHUFFLE_RNG: Cell<u64> = Cell::new(0xDEADBEEFCAFE1234);
}

/// Advances the thread-local LCG and returns the next 64-bit state.
fn lcg_next() -> u64 {
    SHUFFLE_RNG.with(|state| {
        // Knuth's MMIX multiplier/increment (modulus 2^64 via wrapping ops).
        let s = state
            .get()
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        state.set(s);
        s
    })
}

/// Shuffles `indices` in place with the Fisher–Yates algorithm, driven by the
/// thread-local LCG. A no-op for empty and single-element slices.
fn fisher_yates_shuffle(indices: &mut [usize]) {
    let n = indices.len();
    for i in (1..n).rev() {
        // Reduce the HIGH bits of the LCG output. The low-order bits of a
        // power-of-two-modulus LCG have tiny periods (bit k cycles with
        // period 2^(k+1); bit 0 strictly alternates), so the previous
        // `state % (i + 1)` produced badly patterned, non-random swaps.
        let j = ((lcg_next() >> 33) as usize) % (i + 1);
        indices.swap(i, j);
    }
}