use std::io::{self, BufRead, Read, Write};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, MutexGuard};
use std::thread::{self, JoinHandle};
use ringbuf::traits::{Consumer, Observer, Producer, Split};
use ringbuf::{HeapCons, HeapProd, HeapRb};
/// Locks `m`, ignoring poisoning.
///
/// A panic on another thread while it held the lock only marks the mutex as
/// poisoned; the guarded value itself is still returned so callers can keep
/// using it (in this file it guards the optional IO error slots).
fn lock_or_recover<T>(m: &Mutex<T>) -> MutexGuard<'_, T> {
    match m.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Reader that pulls bytes from a source on a dedicated IO thread and hands
/// them out through a ring buffer split into producer and consumer halves.
pub(crate) struct ThreadedReader {
/// Consumer half of the ring; the IO thread owns the producer half.
consumer: HeapCons<u8>,
/// Handle used to unpark the IO thread after bytes are consumed and on drop.
io_thread: thread::Thread,
/// Set by the IO thread when it stops producing (end of input or a read error).
eof: Arc<AtomicBool>,
/// Set on drop to ask the IO thread to exit.
stop: Arc<AtomicBool>,
/// Read error recorded by the IO thread; taken (cleared) when surfaced.
error: Arc<Mutex<Option<io::Error>>>,
/// Join handle of the IO thread; `None` once it has been joined.
join: Option<JoinHandle<()>>,
}
impl ThreadedReader {
pub(crate) fn new<R: Read + Send + 'static>(src: R, ring_bytes: usize) -> Self {
let rb = HeapRb::<u8>::new(ring_bytes.max(64 * 1024));
let (producer, consumer) = rb.split();
let eof = Arc::new(AtomicBool::new(false));
let stop = Arc::new(AtomicBool::new(false));
let error = Arc::new(Mutex::new(None));
let eof_io = eof.clone();
let stop_io = stop.clone();
let error_io = error.clone();
let consumer_thread = thread::current();
let join = thread::Builder::new()
.name("methylsieve-io-read".into())
.spawn(move || io_read_loop(src, producer, eof_io, stop_io, error_io, consumer_thread))
.expect("spawning IO read thread");
let io_thread = join.thread().clone();
Self { consumer, io_thread, eof, stop, error, join: Some(join) }
}
fn take_error(&self) -> Option<io::Error> {
lock_or_recover(&self.error).take()
}
}
impl Read for ThreadedReader {
    /// Copies as many buffered bytes as fit into `dst`.
    ///
    /// Blocks (through `fill_buf`) until bytes are buffered or the stream
    /// ends. Returns the number of bytes copied; 0 means end of input.
    fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
        let copied = {
            let available = self.fill_buf()?;
            let count = available.len().min(dst.len());
            dst[..count].copy_from_slice(&available[..count]);
            count
        };
        self.consume(copied);
        Ok(copied)
    }
}
impl BufRead for ThreadedReader {
    /// Returns the next contiguous run of buffered bytes.
    ///
    /// Blocks until the IO thread has produced bytes or has stopped. An empty
    /// slice means end of input.
    ///
    /// # Errors
    /// Returns the read error recorded by the IO thread. Bytes the IO thread
    /// read before the failure are delivered first; the error is returned once
    /// the ring is drained. Later calls then report end of input.
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        loop {
            if self.consumer.occupied_len() > 0 {
                // Only the first slice; a wrapped-around tail is returned by a
                // later call once this part has been consumed.
                let (first, _second) = self.consumer.as_slices();
                return Ok(first);
            }
            if self.eof.load(Ordering::Acquire) {
                // The IO thread may have pushed its last bytes between the
                // emptiness check above and setting `eof`; drain those first.
                if self.consumer.occupied_len() > 0 {
                    continue;
                }
                // On failure the IO thread stores the error before setting
                // `eof`, so it is visible here. Report it only now so that no
                // bytes read before the failure are lost.
                return match self.take_error() {
                    Some(e) => Err(e),
                    None => Ok(&[]),
                };
            }
            // The IO thread unparks this thread after each push and when it stops.
            thread::park();
        }
    }

    /// Marks `amt` bytes as read and wakes the IO thread, which may be waiting
    /// for free space in the ring.
    fn consume(&mut self, amt: usize) {
        self.consumer.skip(amt);
        self.io_thread.unpark();
    }
}
impl Drop for ThreadedReader {
    /// Asks the IO thread to stop and waits for it to exit.
    ///
    /// NOTE(review): the IO thread only checks `stop` between reads. If it is
    /// blocked inside `src.read` (e.g. waiting on a pipe), this join waits
    /// until that read returns.
    fn drop(&mut self) {
        let handle = self.join.take();
        self.stop.store(true, Ordering::Release);
        // Wake the IO thread in case it is parked on a full ring.
        self.io_thread.unpark();
        if let Some(handle) = handle {
            // A panic on the IO thread is deliberately ignored during drop.
            drop(handle.join());
        }
    }
}
/// Body of the `methylsieve-io-read` thread.
///
/// Reads `src` into the ring until end of input, a read error, or `stop`.
/// On end of input it sets `eof`. On a read error it records the error in
/// `error`, then sets `eof`. `consumer_thread` is unparked after every push
/// and when the loop exits.
///
/// `ErrorKind::Interrupted` is retried rather than treated as fatal, matching
/// the std contract that interrupted reads are retryable. Without the retry,
/// a caller that retries such an error (like `io::copy`) would then see a
/// premature EOF.
fn io_read_loop<R: Read>(
    mut src: R,
    mut producer: HeapProd<u8>,
    eof: Arc<AtomicBool>,
    stop: Arc<AtomicBool>,
    error: Arc<Mutex<Option<io::Error>>>,
    consumer_thread: thread::Thread,
) {
    // `Read::read` may read from its buffer, so it must never be handed
    // uninitialized memory. The ring is still empty here, so its vacant slices
    // span the whole storage: zero it once. `u8` has no drop glue, so later
    // `skip`s on the consumer side never de-initialize these bytes.
    {
        let (head, tail) = producer.vacant_slices_mut();
        for slot in head.iter_mut().chain(tail.iter_mut()) {
            slot.write(0);
        }
    }
    loop {
        if stop.load(Ordering::Acquire) {
            break;
        }
        let (first, _second) = producer.vacant_slices_mut();
        if first.is_empty() {
            // Ring full: make sure the consumer is awake, then wait for space.
            consumer_thread.unpark();
            thread::park();
            continue;
        }
        let len = first.len();
        // SAFETY: `first` is the producer's exclusively owned vacant region of
        // `len` `MaybeUninit<u8>` slots (same layout as `u8`), and every byte
        // of ring storage was initialized by the zero-fill above.
        let dst: &mut [u8] =
            unsafe { std::slice::from_raw_parts_mut(first.as_mut_ptr().cast::<u8>(), len) };
        match src.read(dst) {
            Ok(0) => {
                eof.store(true, Ordering::Release);
                consumer_thread.unpark();
                break;
            }
            Ok(n) if n <= len => {
                // SAFETY: `n <= len`, and the first `n` vacant bytes are
                // initialized (zero-filled, then overwritten by `read`).
                unsafe {
                    producer.advance_write_index(n);
                }
                consumer_thread.unpark();
            }
            Ok(n) => {
                // `Read` is a safe trait, so a buggy implementation can report
                // more bytes than it was given; never publish such a count.
                let e = io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("reader reported {n} bytes read into a {len}-byte buffer"),
                );
                *lock_or_recover(&error) = Some(e);
                eof.store(true, Ordering::Release);
                consumer_thread.unpark();
                break;
            }
            // Documented as retryable: try the read again.
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => {
                *lock_or_recover(&error) = Some(e);
                eof.store(true, Ordering::Release);
                consumer_thread.unpark();
                break;
            }
        }
    }
    consumer_thread.unpark();
}
/// Writer that queues bytes in a ring buffer and writes them to a sink on a
/// dedicated IO thread.
pub(crate) struct ThreadedWriter {
/// Producer half of the ring; the IO thread owns the consumer half.
producer: HeapProd<u8>,
/// Handle used to unpark the IO thread after bytes are queued and at shutdown.
io_thread: thread::Thread,
/// Set by `finish` / drop to tell the IO thread to drain, flush, and exit.
finished: Arc<AtomicBool>,
/// Write or flush error recorded by the IO thread; taken (cleared) when surfaced.
error: Arc<Mutex<Option<io::Error>>>,
/// Join handle of the IO thread; `None` once it has been joined.
join: Option<JoinHandle<()>>,
}
impl ThreadedWriter {
    /// Spawns the `methylsieve-io-write` thread, which drains a ring buffer of
    /// `ring_bytes` bytes (raised to at least 64 KiB) into `dst`.
    ///
    /// Call `finish` to wait for all queued bytes and collect any error.
    ///
    /// # Panics
    /// Panics if the IO thread cannot be spawned.
    ///
    /// NOTE(review): the IO thread unparks the thread that called `new` when
    /// ring space frees up. A writer used from another thread is not woken
    /// that way. Confirm that callers write on the creating thread.
    pub(crate) fn new<W: Write + Send + 'static>(dst: W, ring_bytes: usize) -> Self {
        let rb = HeapRb::<u8>::new(ring_bytes.max(64 * 1024));
        let (producer, consumer) = rb.split();
        let finished = Arc::new(AtomicBool::new(false));
        let error = Arc::new(Mutex::new(None));
        let finished_io = finished.clone();
        let error_io = error.clone();
        let producer_thread = thread::current();
        let join = thread::Builder::new()
            .name("methylsieve-io-write".into())
            .spawn(move || io_write_loop(dst, consumer, finished_io, error_io, producer_thread))
            .expect("spawning IO write thread");
        let io_thread = join.thread().clone();
        Self { producer, io_thread, finished, error, join: Some(join) }
    }

    /// Removes and returns the error recorded by the IO thread, if any.
    fn take_error(&self) -> Option<io::Error> {
        lock_or_recover(&self.error).take()
    }

    /// Signals end of input, waits for the IO thread to write and flush
    /// everything, and reports the outcome.
    ///
    /// # Errors
    /// - The error recorded by the IO thread (a failed write or the final flush).
    /// - `ErrorKind::Other` if the IO thread panicked.
    /// - `ErrorKind::BrokenPipe` if queued bytes were never written. This
    ///   happens when the IO thread stopped after an error that an earlier
    ///   `write`/`flush` call already returned.
    pub(crate) fn finish(mut self) -> io::Result<()> {
        self.finished.store(true, Ordering::Release);
        self.io_thread.unpark();
        let panicked = match self.join.take() {
            Some(handle) => handle.join().is_err(),
            None => false,
        };
        if let Some(e) = self.take_error() {
            return Err(e);
        }
        if panicked {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "methylsieve-io-write thread panicked",
            ));
        }
        // After a clean drain the ring is empty (the join above orders this
        // read after the IO thread's last `skip`). Leftover bytes mean the
        // thread gave up early, so reporting success would hide data loss.
        if self.producer.occupied_len() > 0 {
            return Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "methylsieve-io-write thread stopped with unwritten data",
            ));
        }
        Ok(())
    }
}
impl Write for ThreadedWriter {
fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
if let Some(e) = self.take_error() {
return Err(e);
}
let initial_len = buf.len();
while !buf.is_empty() {
let pushed = self.producer.push_slice(buf);
if pushed > 0 {
buf = &buf[pushed..];
self.io_thread.unpark();
} else {
thread::park();
}
}
Ok(initial_len)
}
fn flush(&mut self) -> io::Result<()> {
if let Some(e) = self.take_error() {
return Err(e);
}
Ok(())
}
}
impl Drop for ThreadedWriter {
    /// Best-effort shutdown when `finish` was not called.
    ///
    /// Tells the IO thread to drain and flush, then waits for it. Errors
    /// cannot be returned from `drop` and are discarded; use `finish` to
    /// observe them.
    fn drop(&mut self) {
        self.finished.store(true, Ordering::Release);
        // Wake the IO thread in case it is parked waiting for data.
        self.io_thread.unpark();
        if let Some(io_worker) = self.join.take() {
            let _ = io_worker.join();
        }
    }
}
/// Body of the `methylsieve-io-write` thread.
///
/// Writes queued bytes to `dst` until `finished` is set and the ring is
/// empty, then flushes `dst`. A write or flush error is recorded in `error`,
/// and a failed write ends the loop. `producer_thread` is unparked after
/// every chunk written and when the loop exits.
fn io_write_loop<W: Write>(
    mut dst: W,
    mut consumer: HeapCons<u8>,
    finished: Arc<AtomicBool>,
    error: Arc<Mutex<Option<io::Error>>>,
    producer_thread: thread::Thread,
) {
    // Stores `e` for the producer side to pick up.
    let record = |e: io::Error| {
        *lock_or_recover(&error) = Some(e);
    };
    loop {
        if consumer.occupied_len() == 0 {
            if !finished.load(Ordering::Acquire) {
                // Nothing queued yet; the producer unparks us after each push.
                thread::park();
                continue;
            }
            // The producer may have queued bytes just before setting
            // `finished`; drain them before the final flush.
            if consumer.occupied_len() > 0 {
                continue;
            }
            if let Err(e) = dst.flush() {
                record(e);
            }
            break;
        }
        let (chunk, _) = consumer.as_slices();
        let chunk_len = chunk.len();
        if let Err(e) = dst.write_all(chunk) {
            record(e);
            producer_thread.unpark();
            break;
        }
        consumer.skip(chunk_len);
        // Space was freed; let a producer blocked on a full ring continue.
        producer_thread.unpark();
    }
    producer_thread.unpark();
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Reads `reader` to completion via `io::copy` and returns everything it produced.
    fn drain(reader: &mut ThreadedReader) -> Vec<u8> {
        let mut collected = Vec::new();
        std::io::copy(reader, &mut collected).unwrap();
        collected
    }

    #[test]
    fn threaded_reader_round_trip_small() {
        let payload: Vec<u8> = (0..1000u32).flat_map(u32::to_le_bytes).collect();
        let mut reader = ThreadedReader::new(Cursor::new(payload.clone()), 64 * 1024);
        assert_eq!(drain(&mut reader), payload);
    }

    #[test]
    fn threaded_reader_round_trip_larger_than_ring() {
        let ring: usize = 4096;
        let payload: Vec<u8> = (0..ring * 8).map(|i| i as u8).collect();
        let mut reader = ThreadedReader::new(Cursor::new(payload.clone()), ring);
        assert_eq!(drain(&mut reader), payload);
    }

    #[test]
    fn threaded_writer_round_trip_with_finish() {
        /// Sink that appends everything written to a shared buffer.
        struct SharedSink(std::sync::Arc<std::sync::Mutex<Vec<u8>>>);
        impl Write for SharedSink {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                let mut captured = self.0.lock().unwrap();
                captured.extend_from_slice(buf);
                Ok(buf.len())
            }
            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }

        let payload: Vec<u8> = (0..50_000u32).map(|i| i as u8).collect();
        let captured = std::sync::Arc::new(std::sync::Mutex::new(Vec::<u8>::new()));
        let mut writer = ThreadedWriter::new(SharedSink(std::sync::Arc::clone(&captured)), 4096);
        writer.write_all(&payload).unwrap();
        writer.finish().unwrap();
        assert_eq!(*captured.lock().unwrap(), payload);
    }
}