unsync-pipe 0.2.0

Ringbuffer-backed !Send !Sync binary safe repr(C) AsyncWrite/AsyncRead pair
Documentation
//! Single-threaded binary safe AsyncWrite/AsyncRead pair
//!
//! The main entry point is [pipe]. The returned [Writer] and [Reader] can be
//! passed as-is across binary boundaries, but both must stay on the thread that
//! created them. A change to the ABI constitutes a major version break.

use std::alloc::{Layout, alloc, dealloc};
use std::pin::Pin;
use std::process::abort;
use std::ptr::{null, null_mut, slice_from_raw_parts};
use std::task::{Context, Poll, Waker};
use std::{io, mem};

use futures::{AsyncRead, AsyncWrite};

/// Byte layout (alignment 1) of a pipe data buffer holding `bs` bytes.
///
/// # Panics
/// If `bs` exceeds `isize::MAX`, the only way a 1-aligned layout can fail.
fn pipe_layout(bs: usize) -> Layout {
	Layout::from_size_align(bs, 1).expect("pipe buffer size must not exceed isize::MAX")
}

/// Create a ringbuffer that can hold up to `size` unread bytes.
///
/// Once `size` bytes are buffered, writes on the [Writer] return
/// [Poll::Pending] until the [Reader] consumes some of them.
///
/// # Panics
/// Panics if `size` is zero or the buffer size cannot be laid out. Aborts via
/// [std::alloc::handle_alloc_error] if the allocation fails.
pub fn pipe(size: usize) -> (Writer, Reader) {
	assert!(0 < size, "cannot create async pipe without buffer");
	// The ring keeps one slot permanently free so that `read_idx == write_idx`
	// only ever means "empty". Allocate one extra byte so the usable capacity
	// matches the requested `size`, which also makes `pipe(1)` usable.
	let buf_size = size.checked_add(1).expect("pipe size overflows usize");
	let layout = pipe_layout(buf_size);
	// SAFETY: `layout` has a non-zero size because `buf_size >= 2`.
	let start = unsafe { alloc(layout) };
	if start.is_null() {
		std::alloc::handle_alloc_error(layout);
	}
	/// Free the shared state and its data buffer.
	///
	/// Invoked through `AsyncRingbuffer::drop` once both halves are released.
	extern "C" fn drop(val: *const ()) {
		// SAFETY: `val` is the pointer produced by `Box::into_raw` in `pipe`,
		// and ownership is reclaimed here only after both halves let go of it.
		let AsyncRingbuffer {
			start,
			size,
			mut read_waker,
			mut write_waker,
			reader_dropped,
			writer_dropped,
			// irrelevant if correctly dropped
			read_idx: _,
			write_idx: _,
			// data used to make this call
			drop: _,
			state: _,
		} = *unsafe { Box::from_raw(val as *mut AsyncRingbuffer) };
		if !writer_dropped || !reader_dropped {
			eprintln!("Pipe dropped in err before reader or writer");
			abort()
		}
		read_waker.drop();
		write_waker.drop();
		// SAFETY: `start` was allocated in `pipe` with `pipe_layout(size)`,
		// where `size` is the stored buffer length.
		unsafe { dealloc(start, pipe_layout(size)) }
	}
	let state = Box::into_raw(Box::new(AsyncRingbuffer {
		start,
		size: buf_size,
		state: null(),
		read_idx: 0,
		write_idx: 0,
		read_waker: Trigger::empty(),
		write_waker: Trigger::empty(),
		reader_dropped: false,
		writer_dropped: false,
		drop,
	}));
	// SAFETY: `state` comes straight from `Box::into_raw`, so it is valid and
	// no reference to it exists. Writing through the raw pointer, and handing
	// both halves copies of that same pointer, avoids deriving two raw
	// pointers from one temporary `&mut`.
	unsafe { (*state).state = state as *const () };
	(Writer(state), Reader(state))
}

/// A one-shot, type-erased wake-up signal with a `repr(C)` layout.
///
/// An armed trigger owns a boxed [Waker] behind `state`. Exactly one of the
/// `invoke` or `drop` callbacks consumes it. An empty trigger has a null
/// `state` and aborting callbacks, which are never reached because
/// [Trigger::take] filters empty triggers out.
#[repr(C)]
struct Trigger {
	state: *const (),
	invoke: extern "C" fn(*const ()),
	drop: extern "C" fn(*const ()),
}
impl Trigger {
	/// Box `waker` and arm a trigger that will either wake it or discard it.
	fn new(waker: Waker) -> Self {
		extern "C" fn wake_boxed(state: *const ()) {
			// SAFETY: `state` is the `Box<Waker>` leaked by `Trigger::new`.
			// `take` empties the trigger first, so it is reclaimed only once.
			let waker = unsafe { *Box::from_raw(state as *mut Waker) };
			waker.wake();
		}
		extern "C" fn drop_boxed(state: *const ()) {
			// SAFETY: same single-ownership argument as `wake_boxed`.
			mem::drop(unsafe { Box::from_raw(state as *mut Waker) });
		}
		let boxed: *mut Waker = Box::into_raw(Box::new(waker));
		Self { state: boxed as *const (), invoke: wake_boxed, drop: drop_boxed }
	}
	/// A trigger that holds nothing. Its callbacks abort if ever reached.
	fn empty() -> Self {
		extern "C" fn never_called(_: *const ()) { abort() }
		Self { state: null(), invoke: never_called, drop: never_called }
	}
	/// Whether the trigger has already fired or never held a waker.
	fn is_empty(&self) -> bool { self.state.is_null() }
	/// Wake the stored waker, if any, and leave the trigger empty.
	fn invoke(&mut self) {
		let Some(armed) = self.take() else { return };
		(armed.invoke)(armed.state);
	}
	/// Discard the stored waker without waking it, and leave the trigger empty.
	fn drop(&mut self) {
		let Some(armed) = self.take() else { return };
		(armed.drop)(armed.state);
	}
	/// Move an armed trigger out, replacing it with an empty one.
	fn take(&mut self) -> Option<Self> {
		if self.is_empty() { None } else { Some(mem::replace(self, Self::empty())) }
	}
}

/// A ringbuffer for single-threaded synchronized communication.
///
/// One [Writer] and one [Reader] share this `repr(C)` struct through raw
/// pointers. Unread bytes occupy `read_idx..write_idx`, wrapping at `size`.
/// One slot always stays free, so `read_idx == write_idx` means "empty" and
/// the usable capacity is `size - 1`.
#[repr(C)]
struct AsyncRingbuffer {
	/// Address of this struct's own allocation, handed to `drop`.
	state: *const (),
	/// First byte of the `size`-byte data buffer.
	start: *mut u8,
	/// Length of the data buffer in bytes.
	size: usize,
	/// Index of the next byte the reader will consume.
	read_idx: usize,
	/// Index of the next byte the writer will fill.
	write_idx: usize,
	/// Wakes a reader parked on an empty ring.
	read_waker: Trigger,
	/// Wakes a writer parked on a full ring or on an unfinished flush.
	write_waker: Trigger,
	/// Set once the [Reader] has been released.
	reader_dropped: bool,
	/// Set once the [Writer] has been closed or dropped.
	writer_dropped: bool,
	/// Frees the pipe; invoked once both halves are released.
	drop: extern "C" fn(*const ()),
}
impl AsyncRingbuffer {
	/// Mark the writer as released and free the pipe if the reader is gone too.
	///
	/// A parked reader is woken; otherwise it would never be polled again and
	/// could not observe that no more data will arrive.
	///
	/// NOTE(review): when this frees the pipe, it deallocates the memory behind
	/// `&mut self` while the call is still running. Stacked/Tree Borrows treat
	/// this as UB (protected reference). A sound fix needs callers to pass a
	/// raw pointer instead.
	fn drop_writer(&mut self) {
		self.writer_dropped = true;
		self.read_waker.invoke();
		if self.reader_dropped {
			(self.drop)(self.state)
		}
	}
	/// Mark the reader as released and free the pipe if the writer is gone too.
	///
	/// A parked writer is woken so its next poll can see the reader is gone.
	/// The NOTE on [AsyncRingbuffer::drop_writer] about freeing behind
	/// `&mut self` applies here as well.
	fn drop_reader(&mut self) {
		self.reader_dropped = true;
		self.write_waker.invoke();
		if self.writer_dropped {
			(self.drop)(self.state)
		}
	}
	/// Park the writer until the reader makes progress.
	///
	/// Fails with `BrokenPipe` if the reader is gone, since no progress can
	/// follow.
	fn writer_wait<T>(&mut self, waker: &Waker) -> Poll<io::Result<T>> {
		if self.reader_dropped {
			return Poll::Ready(Err(broken_pipe_error()));
		}
		self.read_waker.invoke();
		self.write_waker.drop();
		self.write_waker = Trigger::new(waker.clone());
		Poll::Pending
	}
	/// Park the reader on an empty ring until the writer makes progress.
	///
	/// Once the writer is gone, an empty ring can never refill, so this reports
	/// end-of-stream as `Ok(0)`, per the [AsyncRead] convention, rather than an
	/// error.
	fn reader_wait(&mut self, waker: &Waker) -> Poll<io::Result<usize>> {
		if self.writer_dropped {
			return Poll::Ready(Ok(0));
		}
		self.write_waker.invoke();
		self.read_waker.drop();
		self.read_waker = Trigger::new(waker.clone());
		Poll::Pending
	}
	/// Copy `buf` into the ring at `write_idx` and advance `write_idx`.
	///
	/// # Safety
	/// `write_idx + buf.len() <= size`, and the target range must not contain
	/// unread bytes.
	unsafe fn non_wrapping_write_unchecked(&mut self, buf: &[u8]) {
		// A raw copy avoids forming a `&mut [u8]` over bytes that may still be
		// uninitialized (the buffer comes from `alloc`, not `alloc_zeroed`).
		unsafe { self.start.add(self.write_idx).copy_from_nonoverlapping(buf.as_ptr(), buf.len()) };
		self.write_idx = (self.write_idx + buf.len()) % self.size;
	}
	/// Copy `buf.len()` bytes out of the ring at `read_idx` and advance `read_idx`.
	///
	/// # Safety
	/// `read_idx + buf.len() <= size`, and every byte in that range must have
	/// been written and not yet read.
	unsafe fn non_wrapping_read_unchecked(&mut self, buf: &mut [u8]) {
		// These bytes were written earlier, so a shared slice over them is valid.
		let src = unsafe { &*slice_from_raw_parts(self.start.add(self.read_idx), buf.len()) };
		buf.copy_from_slice(src);
		self.read_idx = (self.read_idx + buf.len()) % self.size;
	}
	/// True when the next write would collide with the reader.
	fn is_full(&self) -> bool { (self.write_idx + 1) % self.size == self.read_idx }
	/// True when there are no unread bytes.
	fn is_empty(&self) -> bool { self.write_idx == self.read_idx }
}

/// Error for an operation on a pipe half that this side has already closed.
fn already_closed_error() -> io::Error {
	let kind = io::ErrorKind::BrokenPipe;
	io::Error::new(kind, "Pipe already closed from this end")
}
/// Error reported when the opposite half of the pipe has been released.
fn broken_pipe_error() -> io::Error {
	let kind = io::ErrorKind::BrokenPipe;
	io::Error::new(kind, "Pipe already closed from other end")
}

/// A binary safe [AsyncWrite] implementor writing to a ringbuffer created by
/// [pipe].
#[repr(C)]
pub struct Writer(*mut AsyncRingbuffer);
impl Writer {
	unsafe fn get_state(self: Pin<&mut Self>) -> io::Result<&mut AsyncRingbuffer> {
		match unsafe { self.0.as_mut() } {
			Some(data) => Ok(data),
			None => Err(already_closed_error()),
		}
	}
}
impl AsyncWrite for Writer {
	fn poll_close(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
		unsafe {
			match self.as_mut().get_state() {
				Err(e) => return Poll::Ready(Err(e)),
				Ok(data) => {
					data.drop_writer();
				},
			}
		}
		self.0 = null_mut();
		Poll::Ready(Ok(()))
	}
	fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
		unsafe {
			let data = self.as_mut().get_state()?;
			if data.is_empty() { Poll::Ready(Ok(())) } else { data.writer_wait(cx.waker()) }
		}
	}
	fn poll_write(
		mut self: Pin<&mut Self>,
		cx: &mut Context<'_>,
		buf: &[u8],
	) -> Poll<io::Result<usize>> {
		unsafe {
			let data = self.as_mut().get_state()?;
			let AsyncRingbuffer { write_idx, read_idx, size, .. } = *data;
			if !buf.is_empty() && data.is_empty() {
				data.read_waker.invoke();
			}
			if !buf.is_empty() && data.is_full() {
				// Writer is blocked
				data.writer_wait(cx.waker())
			} else if write_idx < read_idx {
				// Non-wrapping backside write w < r <= s
				let count = buf.len().min(read_idx - write_idx - 1);
				data.non_wrapping_write_unchecked(&buf[0..count]);
				Poll::Ready(Ok(count))
			} else if data.write_idx + buf.len() < size {
				// Non-wrapping frontside write r <= w + b < s
				data.non_wrapping_write_unchecked(&buf[0..buf.len()]);
				Poll::Ready(Ok(buf.len()))
			} else if read_idx == 0 {
				// Frontside write up to origin r=0 < s < w + b
				data.non_wrapping_write_unchecked(&buf[0..size - write_idx - 1]);
				Poll::Ready(Ok(size - write_idx - 1))
			} else {
				let (end, start) = buf.split_at(size - write_idx);
				// Wrapping write r < s < w + b
				data.non_wrapping_write_unchecked(end);
				let start_count = start.len().min(read_idx - 1);
				data.non_wrapping_write_unchecked(&start[0..start_count]);
				Poll::Ready(Ok(end.len() + start_count))
			}
		}
	}
}
impl Drop for Writer {
	/// Release the writer's half unless `poll_close` already did.
	fn drop(&mut self) {
		// SAFETY: a non-null pointer still refers to the live ring. `poll_close`
		// nulls it after releasing the writer, so the writer is never released twice.
		let Some(ring) = (unsafe { self.0.as_mut() }) else {
			return;
		};
		ring.drop_writer();
	}
}

/// A binary safe [AsyncRead] implementor reading from a ringbuffer created by
/// [pipe].
///
/// This is a `repr(C)` wrapper around one raw pointer to the shared ring. The
/// pointer is never null while the reader is alive.
#[repr(C)]
pub struct Reader(*mut AsyncRingbuffer);
impl AsyncRead for Reader {
	/// Copy up to `buf.len()` buffered bytes into `buf`.
	///
	/// Returns `Ok(0)` for an empty `buf`. On an empty ring, the outcome is
	/// decided by `AsyncRingbuffer::reader_wait`, which parks the task while a
	/// writer may still produce data.
	fn poll_read(
		self: Pin<&mut Self>,
		cx: &mut Context<'_>,
		buf: &mut [u8],
	) -> Poll<io::Result<usize>> {
		// SAFETY: `pipe` gives the reader a pointer to the live ring, and the
		// ring is freed only after the reader has been dropped.
		let data = unsafe { self.0.as_mut() }.expect("Cannot be null");
		if buf.is_empty() {
			return Poll::Ready(Ok(0));
		}
		if data.is_empty() {
			// Nothing to read, waiting...
			return data.reader_wait(cx.waker());
		}
		let AsyncRingbuffer { read_idx, write_idx, size, .. } = *data;
		// SAFETY (every read below): each slice starts at `read_idx`, ends at or
		// before `size`, and never passes `write_idx`, so it only covers bytes
		// that were written and not yet read.
		let count = unsafe {
			if read_idx < write_idx {
				// Unread bytes are contiguous up to the writer.
				let count = buf.len().min(write_idx - read_idx);
				data.non_wrapping_read_unchecked(&mut buf[..count]);
				count
			} else if read_idx + buf.len() < size {
				// The request ends before the buffer wraps.
				data.non_wrapping_read_unchecked(buf);
				buf.len()
			} else {
				// Read to the end of the buffer, then continue from the front.
				let (tail, head) = buf.split_at_mut(size - read_idx);
				data.non_wrapping_read_unchecked(tail);
				let head_count = head.len().min(write_idx);
				data.non_wrapping_read_unchecked(&mut head[..head_count]);
				tail.len() + head_count
			}
		};
		// Space was freed, and the ring may now be empty. Wake a writer parked on
		// a full ring or on a pending flush; otherwise a flush could wait forever.
		data.write_waker.invoke();
		Poll::Ready(Ok(count))
	}
}
impl Drop for Reader {
	/// Release the reader's half of the pipe.
	fn drop(&mut self) {
		// SAFETY: the pointer handed out by `pipe` stays valid until this half
		// releases it here.
		if let Some(ring) = unsafe { self.0.as_mut() } {
			ring.drop_reader();
		}
	}
}

#[cfg(test)]
mod tests {
	use std::pin::pin;

	use futures::future::join;
	use futures::{AsyncReadExt, AsyncWriteExt};
	use itertools::Itertools;
	use rand::{Rng, SeedableRng};
	use rand_chacha::ChaCha8Rng;
	use test_executors::spin_on;

	use super::*;

	/// Format bytes as hex, 32 per line, for readable mismatch reports.
	fn print_bytes(bytes: &[u8]) -> String {
		let hex = bytes.iter().map(|byte| format!("{byte:>2x}"));
		hex.chunks(32).into_iter().map(|row| row.into_iter().join(" ")).join("\n")
	}

	/// Stream ten million big-endian `u32` counters through a 1024-byte pipe.
	/// Both ends use randomly sized chunks, and every byte must arrive in order.
	#[test]
	fn basic_io() {
		const COUNT: u32 = 10_000_000;
		const TOTAL: u32 = COUNT * 4;
		let mut writer_rng = ChaCha8Rng::seed_from_u64(2);
		let mut reader_rng = ChaCha8Rng::seed_from_u64(1);
		spin_on(async {
			let (writer, reader) = pipe(1024);
			let stream = (0u32..COUNT).flat_map(u32::to_be_bytes);
			let produce = async {
				let mut writer = pin!(writer);
				let mut source = stream.clone();
				let mut sent = 0u32;
				while sent < TOTAL {
					let chunk_len = writer_rng.random_range(0..200);
					let chunk: Vec<u8> = source.by_ref().take(chunk_len).collect();
					sent += chunk.len() as u32;
					writer.write_all(&chunk).await.unwrap();
				}
				writer.flush().await.unwrap();
			};
			let consume = async {
				let mut reader = pin!(reader);
				let mut expected = stream.clone();
				let mut received = 0u32;
				while received < TOTAL {
					let chunk_len = reader_rng.random_range(0..200);
					let want: Vec<u8> = expected.by_ref().take(chunk_len).collect();
					received += want.len() as u32;
					let mut got = vec![0; want.len()];
					reader
						.read_exact(&mut got[..])
						.await
						.unwrap_or_else(|e| panic!("At {received} bytes: {e}"));
					assert!(
						got == want,
						"Difference in generated numbers\n{}\n{}",
						print_bytes(&got),
						print_bytes(&want),
					);
				}
			};
			join(produce, consume).await;
		})
	}
}