#![deny(missing_docs, rust_2018_idioms)]
use std::io::{self, Read};
use std::sync::{Arc, Mutex};
use std::sync::mpsc;
use std::thread;
/// Failure reported by the chunked parallel reader in this module.
///
/// `E` is the error type produced by the caller-supplied chunk-processing function.
#[derive(Debug)]
pub enum Error<E> {
/// Reading from the input stream failed; carries the underlying I/O error.
Read(io::Error),
/// The processing function returned an error for one of the chunks.
Process {
/// Offset of the failing chunk, counted in bytes from the start of the stream.
chunk_offset: u64,
/// The error value returned by the processing function for that chunk.
error: E,
},
}
impl<E: std::fmt::Display> std::fmt::Display for Error<E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::Read(e) => write!(f, "error while reading: {}", e),
Error::Process { chunk_offset, error } => write!(f,
"error while processing data at chunk offset {}: {}", chunk_offset, error),
}
}
}
impl<E: std::error::Error + 'static> std::error::Error for Error<E> {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(match self {
Error::Read(ref e) => e,
Error::Process { chunk_offset: _, ref error } => error,
})
}
}
/// Spawns `num_threads` worker threads that pull `(offset, data)` chunks from `work_rx` and run
/// `f` on each of them.
///
/// The single receiver is shared by all workers behind a mutex. When `f` fails, the worker sends
/// `(offset, error)` on `job_tx` and then carries on with the next chunk. A worker exits once
/// `work_rx` reports that no more work can arrive.
///
/// Returns the join handles of the spawned threads, in spawn order.
///
/// # Panics
///
/// A worker thread panics if the receiver mutex is poisoned or if its error cannot be delivered
/// because the receiving side of `job_tx` is gone.
fn start_worker_threads<E: Send + 'static>(
    num_threads: usize,
    work_rx: mpsc::Receiver<(u64, Vec<u8>)>,
    job_tx: mpsc::Sender<(u64, E)>,
    f: Arc<impl Fn(u64, &[u8]) -> Result<(), E> + Send + Sync + 'static>,
) -> Vec<thread::JoinHandle<()>> {
    let shared_rx = Arc::new(Mutex::new(work_rx));
    (0..num_threads)
        .map(|_| {
            let rx = Arc::clone(&shared_rx);
            let errors = job_tx.clone();
            let process = Arc::clone(&f);
            thread::spawn(move || loop {
                // The lock guard is a temporary of this statement, so it is released as soon as
                // a chunk has been received and is never held while `process` runs.
                let next = rx.lock().unwrap().recv();
                let (offset, data) = match next {
                    Ok(item) => item,
                    // The work channel is closed and drained: nothing left to do.
                    Err(_) => return,
                };
                if let Err(e) = process(offset, &data) {
                    errors.send((offset, e)).unwrap();
                }
            })
        })
        .collect()
}
pub fn read_stream_and_process_chunks_in_parallel<E: Send + 'static>(
mut reader: impl Read,
chunk_size: usize,
num_threads: usize,
f: Arc<impl Fn(u64, &[u8]) -> Result<(), E> + Send + Sync + 'static>,
) -> Result<(), Error<E>> {
assert!(num_threads > 0, "non-zero number of threads required");
let (work_tx, work_rx) = mpsc::sync_channel::<(u64, Vec<u8>)>(num_threads);
let (job_tx, job_rx) = mpsc::channel::<(u64, E)>();
let threads = start_worker_threads(num_threads, work_rx, job_tx, f);
let mut offset = 0u64;
let loop_result = loop {
match job_rx.try_recv() {
Ok((chunk_offset, error)) => break Err(Error::Process { chunk_offset, error }),
Err(mpsc::TryRecvError::Empty) => (),
Err(mpsc::TryRecvError::Disconnected) => unreachable!("we hold the sender open"),
}
let mut buf = vec![0u8; chunk_size];
match large_read(&mut reader, &mut buf) {
Ok(0) => {
break Ok(());
}
Ok(n) => {
buf.truncate(n);
work_tx.send((offset, buf)).expect("failed to send work to threads");
offset += n as u64;
}
Err(e) => {
break Err(Error::Read(e));
}
}
};
drop(work_tx);
for thread in threads {
thread.join().expect("failed to join on worker thread");
}
if let Err(e) = loop_result {
return Err(e);
}
match job_rx.recv() {
Ok((chunk_offset, error)) => {
Err(Error::Process { chunk_offset, error })
}
Err(mpsc::RecvError) => {
Ok(())
}
}
}
/// Reads from `source` until `buf` is completely filled or the source reports end of stream.
///
/// A single `read` call may return fewer bytes than requested, so the call is repeated on the
/// unfilled tail of `buf`. Reads that fail with `ErrorKind::Interrupted` are retried; any other
/// error is returned immediately.
///
/// Returns the number of bytes written to the front of `buf`; a value smaller than `buf.len()`
/// means the source reported end of stream.
fn large_read(mut source: impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let mut filled = 0;
    loop {
        let got = match source.read(&mut buf[filled..]) {
            Ok(count) => count,
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        };
        if got == 0 {
            // Zero bytes: the source has nothing more to give.
            return Ok(filled);
        }
        filled += got;
        if filled == buf.len() {
            return Ok(filled);
        }
    }
}