use std::{
io::{BufRead, Cursor, Error, Read, Take},
sync::mpsc,
thread::JoinHandle,
};
/// Moves `reader` onto a background thread and returns the two handles that talk to it.
///
/// The spawned thread reads from `reader` into an 8 KiB buffer and forwards every result
/// over the event channel: a successful read becomes `Event::Buf` (the buffer together with
/// the number of bytes read), a failed read becomes `Event::Err`. After shipping a buffer the
/// thread waits for a buffer to be sent back before it reads again, so at most one chunk is
/// in flight. After shipping an error it retries the read straight away.
///
/// The thread ends, yielding `reader` through its join handle, as soon as sending an event
/// or receiving a returned buffer fails, i.e. once the `InterruptReader` side is gone.
///
/// The returned `Interruptor` owns a second sender for the same event channel.
pub fn pair<R: Read + Send + 'static>(mut reader: R) -> (InterruptReader<R>, Interruptor) {
    let (event_tx, event_rx) = mpsc::channel();
    let (buffer_tx, buffer_rx) = mpsc::channel();

    // The worker gets its own sender; the original one ends up inside the `Interruptor`.
    let worker_tx = event_tx.clone();
    let join_handle = std::thread::spawn(move || {
        let mut chunk = vec![0u8; 8 * 1024];
        loop {
            let filled = match reader.read(&mut chunk) {
                Ok(filled) => filled,
                Err(err) => {
                    if worker_tx.send(Event::Err(err)).is_err() {
                        return reader;
                    }
                    // The error was delivered; try the read again.
                    continue;
                }
            };

            // Hand the whole buffer over; `chunk` is left empty until one is returned.
            let full = std::mem::take(&mut chunk);
            if worker_tx.send(Event::Buf(full, filled)).is_err() {
                return reader;
            }

            let Ok(recycled) = buffer_rx.recv() else {
                return reader;
            };
            chunk = recycled;
        }
    });

    (
        InterruptReader {
            cursor: None,
            buffer_tx,
            event_rx,
            join_handle,
        },
        Interruptor(event_tx),
    )
}
/// Reading half of the pair created by `pair`.
///
/// Data is not read from `R` directly; it arrives as events from the background thread
/// whose handle is stored in `join_handle`.
#[derive(Debug)]
pub struct InterruptReader<R> {
    /// Chunk currently being handed out, limited via `Take` to the byte count that arrived
    /// with it; `None` while no chunk is held.
    cursor: Option<Take<Cursor<Vec<u8>>>>,
    /// Sends drained buffers back to the background thread.
    buffer_tx: mpsc::Sender<Vec<u8>>,
    /// Receives chunks, read errors and interrupt requests.
    event_rx: mpsc::Receiver<Event>,
    /// Handle of the background thread; joining it yields the wrapped reader.
    join_handle: JoinHandle<R>,
}
impl<R: Read> InterruptReader<R> {
    /// Stops using the background thread and returns the wrapped reader.
    ///
    /// Both channel ends held by this value are closed first, then the thread is joined, so
    /// this call blocks until the thread has finished. Bytes of the current chunk that were
    /// not read yet are dropped along with the rest of `self`.
    ///
    /// # Errors
    ///
    /// Returns the panic payload if the background thread panicked.
    pub fn into_inner(self) -> std::thread::Result<R> {
        // Close the event receiver before the buffer sender, then wait for the thread.
        drop(self.event_rx);
        drop(self.buffer_tx);
        self.join_handle.join()
    }
}
impl<R: Read> Read for InterruptReader<R> {
    /// Copies bytes of the current chunk into `buf`, fetching the next chunk when needed.
    ///
    /// * An empty `buf` returns `Ok(0)` immediately and leaves all state untouched. Without
    ///   this guard the zero-length read of the chunk would be mistaken for "chunk
    ///   exhausted" and the unread bytes would be handed back to the background thread
    ///   and lost.
    /// * While a chunk is held, a pending interrupt is reported before any bytes are copied.
    /// * When the chunk is drained its buffer is sent back to the background thread and the
    ///   call blocks until the next event arrives.
    ///
    /// Returns `Ok(0)` when a zero-length chunk arrives (the underlying reader read nothing)
    /// or when one of the channels has been closed.
    ///
    /// # Errors
    ///
    /// Returns the underlying reader's error when one is forwarded, or an error for which
    /// `is_interrupt` is true when an interrupt was requested.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        loop {
            if let Some(cursor) = self.cursor.as_mut() {
                deal_with_interrupt(&self.event_rx)?;
                match cursor.read(buf) {
                    // `buf` is non-empty, so zero bytes means the chunk is used up:
                    // recycle its buffer and go wait for the next event.
                    Ok(0) => {
                        let buffer = self.cursor.take().unwrap().into_inner().into_inner();
                        if self.buffer_tx.send(buffer).is_err() {
                            return Ok(0);
                        }
                    }
                    Ok(num_bytes) => return Ok(num_bytes),
                    Err(_) => unreachable!("Afaik, this shouldn't happen if T is Vec<u8>"),
                }
            } else {
                match self.event_rx.recv() {
                    Ok(Event::Buf(buffer, len)) => {
                        // Only the first `len` bytes of the buffer hold data.
                        self.cursor = Some(Cursor::new(buffer).take(len as u64));
                        if len == 0 {
                            return Ok(0);
                        }
                    }
                    Ok(Event::Err(err)) => return Err(err),
                    Ok(Event::Interrupt) => return Err(interrupt_error()),
                    Err(_) => return Ok(0),
                }
            }
        }
    }
}
impl<R: Read> BufRead for InterruptReader<R> {
    /// Returns the unread part of the current chunk, fetching the next chunk when needed.
    ///
    /// With no chunk held, this blocks for the next event: a chunk is stored (an empty one
    /// yields an empty slice), a forwarded error or an interrupt is returned as `Err`, and
    /// a closed channel yields an empty slice. With a chunk held, a pending interrupt is
    /// reported first; a drained chunk is sent back to the background thread before the
    /// next event is awaited.
    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
        let Some(cursor) = self.cursor.as_mut() else {
            return match self.event_rx.recv() {
                Ok(Event::Buf(buffer, len)) => {
                    // Only the first `len` bytes of the buffer hold data.
                    self.cursor = Some(Cursor::new(buffer).take(len as u64));
                    if len == 0 {
                        Ok(&[])
                    } else {
                        self.fill_buf()
                    }
                }
                Ok(Event::Err(err)) => Err(err),
                Ok(Event::Interrupt) => Err(interrupt_error()),
                Err(_) => Ok(&[]),
            };
        };

        deal_with_interrupt(&self.event_rx)?;

        // Only the length is kept here so that the mutable borrow of the cursor ends
        // before `self.cursor` is touched again below.
        let available = cursor.fill_buf()?.len();
        if available == 0 {
            let buffer = self.cursor.take().unwrap().into_inner().into_inner();
            return match self.buffer_tx.send(buffer) {
                Ok(()) => self.fill_buf(),
                Err(_) => Ok(&[]),
            };
        }

        // With bytes available, the unread region starts at the cursor position, which is
        // strictly inside the buffer and therefore fits in `usize`.
        let start = cursor.get_ref().position() as usize;
        let chunk = self.cursor.as_ref().unwrap().get_ref().get_ref();
        Ok(&chunk[start..start + available])
    }

    /// Marks `amount` bytes of the current chunk as read; does nothing when no chunk is held.
    fn consume(&mut self, amount: usize) {
        match self.cursor.as_mut() {
            Some(cursor) => cursor.consume(amount),
            None => {}
        }
    }
}
/// Cloneable handle that posts interrupt requests onto the event channel read by the
/// `InterruptReader`.
#[derive(Debug, Clone)]
pub struct Interruptor(mpsc::Sender<Event>);

impl Interruptor {
    /// Queues an interrupt request.
    ///
    /// # Errors
    ///
    /// Returns `InterruptSendError` when the event could not be sent because the receiving
    /// end of the channel no longer exists.
    pub fn interrupt(&self) -> Result<(), InterruptSendError> {
        match self.0.send(Event::Interrupt) {
            Ok(()) => Ok(()),
            Err(_) => Err(InterruptSendError),
        }
    }
}
/// Error returned when an interrupt request could not be delivered.
#[derive(Debug, Clone, Copy)]
pub struct InterruptSendError;

impl std::fmt::Display for InterruptSendError {
    /// Writes the fixed, human-readable description of this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "InterruptReader has been dropped")
    }
}

impl std::error::Error for InterruptSendError {}
/// Marker error carried inside the `std::io::Error` that reports an interrupt.
#[derive(Debug, Clone, Copy)]
pub struct InterruptReceived;

impl std::fmt::Display for InterruptReceived {
    /// Writes the fixed, human-readable description of this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Interruptor has interrupted")
    }
}

impl std::error::Error for InterruptReceived {}
/// Messages travelling over the event channel towards the `InterruptReader`.
#[derive(Debug)]
enum Event {
    /// A buffer together with the number of bytes the underlying reader placed in it.
    Buf(Vec<u8>, usize),
    /// An error returned by the underlying reader.
    Err(std::io::Error),
    /// An interrupt request posted through an `Interruptor`.
    Interrupt,
}
/// Reports whether `err` wraps an `InterruptReceived` payload, i.e. whether it was produced
/// because an interrupt was requested rather than by the underlying reader.
pub fn is_interrupt(err: &Error) -> bool {
    match err.get_ref() {
        Some(inner) => inner.downcast_ref::<InterruptReceived>().is_some(),
        // Errors without a custom payload can never be interrupts.
        None => false,
    }
}
fn interrupt_error() -> Error {
Error::other(InterruptReceived)
}
/// Polls `event_rx` without blocking and turns a pending interrupt into an error.
///
/// Returns `Ok(())` when nothing could be received (queue empty or channel closed).
///
/// # Panics
///
/// Panics if the received event is anything other than an interrupt; callers only invoke
/// this while they hold a chunk, when no other event is expected on the channel.
fn deal_with_interrupt(event_rx: &mpsc::Receiver<Event>) -> std::io::Result<()> {
    let Ok(event) = event_rx.try_recv() else {
        return Ok(());
    };
    if matches!(event, Event::Interrupt) {
        Err(interrupt_error())
    } else {
        unreachable!("This should not be possible")
    }
}