use std::sync::mpsc;
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use crate::tensor::{Device, Result, Tensor};
use super::BatchDataSet;
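
/// A batch fetched ahead of time by the prefetch thread.
///
/// On CUDA builds, `ready_event` is recorded on the worker's copy stream after
/// the host-to-device transfers have been enqueued; consumers should wait on it
/// before reading `tensors`. It is `None` when no asynchronous copy was issued.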
pub(crate) struct PrefetchedBatch {
pub tensors: Vec<Tensor>,
#[cfg(feature = "cuda")]
pub ready_event: Option<crate::distributed::cuda_event::CudaEvent>,
}
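
/// Commands understood by the prefetch thread. `StartEpoch` drives a whole
/// epoch from a precomputed index list; `StartDistributedEpoch` followed by
/// `LoadBatch` requests lets an external sampler feed index sets one batch at
/// a time; `Stop` shuts the thread down.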
pub(crate) enum WorkerCmd {
StartEpoch {
indices: Vec<usize>,
batch_size: usize,
drop_last: bool,
batch_tx: mpsc::SyncSender<Result<PrefetchedBatch>>,
},
StartDistributedEpoch {
batch_tx: mpsc::SyncSender<Result<PrefetchedBatch>>,
},
LoadBatch {
indices: Vec<usize>,
},
Stop,
}
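
/// Handle to a background thread that fetches batches from a dataset and
/// stages them on the target device, keeping at most `prefetch_depth`
/// finished batches buffered at once.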
pub(crate) struct PrefetchWorker {
cmd_tx: mpsc::Sender<WorkerCmd>,
handle: Option<JoinHandle<()>>,
prefetch_depth: usize,
}
impl PrefetchWorker {
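    /// Spawns the prefetch thread. It runs until `Stop` is sent or this
    /// handle is dropped.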
pub fn new(
dataset: Arc<dyn BatchDataSet>,
device: Device,
prefetch_depth: usize,
) -> Self {
let (cmd_tx, cmd_rx) = mpsc::channel::<WorkerCmd>();
let handle = thread::spawn(move || {
worker_loop(dataset, device, cmd_rx);
});
PrefetchWorker {
cmd_tx,
handle: Some(handle),
prefetch_depth,
}
}
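
    /// Starts a regular epoch over `indices` and returns the receiving end of
    /// a bounded channel that holds at most `prefetch_depth` ready batches.
    /// The channel disconnects once the epoch is exhausted (or immediately,
    /// if the worker has already stopped).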
pub fn start_epoch(
&self,
indices: Vec<usize>,
batch_size: usize,
drop_last: bool,
) -> mpsc::Receiver<Result<PrefetchedBatch>> {
let (batch_tx, batch_rx) =
mpsc::sync_channel::<Result<PrefetchedBatch>>(self.prefetch_depth);
let _ = self.cmd_tx.send(WorkerCmd::StartEpoch {
indices,
batch_size,
drop_last,
batch_tx,
});
batch_rx
}
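
    /// Starts a distributed epoch: the worker produces batches only in
    /// response to subsequent `load_batch` calls, so an external sampler can
    /// decide the index set of every batch.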
pub fn start_distributed_epoch(&self) -> mpsc::Receiver<Result<PrefetchedBatch>> {
let (batch_tx, batch_rx) =
mpsc::sync_channel::<Result<PrefetchedBatch>>(self.prefetch_depth);
let _ = self.cmd_tx.send(WorkerCmd::StartDistributedEpoch { batch_tx });
batch_rx
}
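
    /// Requests one batch for the current distributed epoch. This is a no-op
    /// when no distributed epoch is active or the worker has stopped.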
pub fn load_batch(&self, indices: Vec<usize>) {
let _ = self.cmd_tx.send(WorkerCmd::LoadBatch { indices });
}
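
    /// The buffer capacity used for newly started epochs.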
pub fn prefetch_depth(&self) -> usize {
self.prefetch_depth
}
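
    /// Changes the buffer capacity for subsequently started epochs; an epoch
    /// already in flight keeps the depth it was started with.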
pub fn set_prefetch_depth(&mut self, depth: usize) {
self.prefetch_depth = depth;
}
}
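
// Dropping the handle asks the thread to stop and joins it, so no batch is
// fetched past the owner's lifetime.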
impl Drop for PrefetchWorker {
fn drop(&mut self) {
let _ = self.cmd_tx.send(WorkerCmd::Stop);
if let Some(h) = self.handle.take() {
let _ = h.join();
}
}
}
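
/// Body of the prefetch thread: services commands until a `Stop` arrives or
/// every command sender has been dropped.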
fn worker_loop(
dataset: Arc<dyn BatchDataSet>,
device: Device,
cmd_rx: mpsc::Receiver<WorkerCmd>,
) {
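    // On CUDA targets, use a dedicated non-blocking stream so host-to-device
    // copies can overlap with compute on the default stream. If stream
    // creation fails we silently fall back to synchronous transfers.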
#[cfg(feature = "cuda")]
let copy_stream = if device.is_cuda() {
crate::distributed::cuda_stream::CudaStream::new(device, false).ok()
} else {
None
};
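    // Sink for the currently active distributed epoch, if any.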
let mut dist_tx: Option<mpsc::SyncSender<Result<PrefetchedBatch>>> = None;
for cmd in &cmd_rx {
match cmd {
WorkerCmd::StartEpoch {
indices,
batch_size,
drop_last,
batch_tx,
} => {
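                // A regular epoch supersedes any distributed epoch in progress.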
dist_tx = None;
let n = indices.len();
let mut start = 0;
while start < n {
let end = (start + batch_size).min(n);
if drop_last && (end - start) < batch_size {
break;
}
let batch_indices = &indices[start..end];
start = end;
let result = fetch_and_transfer(
&*dataset,
batch_indices,
device,
#[cfg(feature = "cuda")]
copy_stream.as_ref(),
);
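                    // The receiver hung up (the loader was dropped mid-epoch),
                    // so stop producing batches for this epoch.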
if batch_tx.send(result).is_err() {
break;
}
}
}
WorkerCmd::StartDistributedEpoch { batch_tx } => {
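                // Keep the channel around; batches are produced on demand by
                // subsequent `LoadBatch` commands.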
dist_tx = Some(batch_tx);
}
WorkerCmd::LoadBatch { indices } => {
if let Some(ref tx) = dist_tx {
let result = fetch_and_transfer(
&*dataset,
&indices,
device,
#[cfg(feature = "cuda")]
copy_stream.as_ref(),
);
                    if tx.send(result).is_err() {
                        dist_tx = None;
                    }
}
}
WorkerCmd::Stop => break,
}
}
}
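
/// Fetches one batch from the dataset and, for CUDA devices, moves it onto
/// the device through pinned host memory.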
fn fetch_and_transfer(
dataset: &dyn BatchDataSet,
indices: &[usize],
device: Device,
#[cfg(feature = "cuda")] copy_stream: Option<&crate::distributed::cuda_stream::CudaStream>,
) -> Result<PrefetchedBatch> {
let tensors = dataset.get_batch(indices)?;
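    // CPU targets need no staging: return the host tensors as-is.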
if !device.is_cuda() {
return Ok(PrefetchedBatch {
tensors,
#[cfg(feature = "cuda")]
ready_event: None,
});
}
#[cfg(feature = "cuda")]
{
use crate::distributed::cuda_event::{CudaEvent, CudaEventFlags};
use crate::distributed::cuda_stream::StreamGuard;
let mut on_device = Vec::with_capacity(tensors.len());
if let Some(stream) = copy_stream {
let _guard = StreamGuard::new(stream);
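            // NOTE: this assumes `to_device_async` keeps the pinned source
            // buffer alive until the copy completes on `stream`; if it does
            // not, the pinned tensors would have to be retained until
            // `ready_event` has fired.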
for t in &tensors {
let pinned = t.pin_memory()?;
on_device.push(pinned.to_device_async(device)?);
}
let event = CudaEvent::new(CudaEventFlags::DisableTiming)?;
event.record_on(stream)?;
return Ok(PrefetchedBatch {
tensors: on_device,
ready_event: Some(event),
});
}
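        // No copy stream (creation failed at startup): fall back to ordinary
        // synchronous transfers, with no completion event to wait on.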
for t in &tensors {
let pinned = t.pin_memory()?;
on_device.push(pinned.to_device(device)?);
}
Ok(PrefetchedBatch {
tensors: on_device,
ready_event: None,
})
}
#[cfg(not(feature = "cuda"))]
{
Ok(PrefetchedBatch { tensors })
}
}