#![deny(missing_docs)]
extern crate futures_io;
#[cfg(test)]
extern crate futures;
use core::pin::Pin;
use core::task::Context;
use std::cell::RefCell;
use std::cmp::min;
use std::ptr::copy_nonoverlapping;
use std::rc::Rc;
use futures_io::{AsyncRead, AsyncWrite, Result};
use std::task::{Poll, Poll::Pending, Poll::Ready, Waker};
mod duplex;
pub use duplex::Duplex;
/// Creates a ring buffer with room for at least `capacity` bytes and returns
/// its two halves: a [`Writer`] that fills it and a [`Reader`] that drains it.
///
/// Both halves share the same state through an `Rc<RefCell<_>>`, so they are
/// meant to be used from a single thread.
///
/// # Panics
///
/// Panics if `capacity` is `0` or larger than `isize::max_value()`.
pub fn ring_buffer(capacity: usize) -> (Writer, Reader) {
    if capacity == 0 || capacity > (isize::max_value() as usize) {
        panic!("Invalid ring buffer capacity.");
    }
    let mut data: Vec<u8> = Vec::with_capacity(capacity);
    // Take the pointer straight from the `Vec`. Going through
    // `as_mut_slice()` would derive it from an empty slice (the length is 0),
    // which does not cover the bytes that are later read and written through
    // this pointer.
    let ptr = data.as_mut_ptr();
    let rb = Rc::new(RefCell::new(RingBuffer {
        data,
        read: ptr,
        amount: 0,
        waker: None,
        did_shutdown: false,
    }));
    (Writer(Rc::clone(&rb)), Reader(rb))
}
/// State shared by the `Writer` and `Reader` halves of one ring buffer.
struct RingBuffer {
/// Backing allocation. Only its capacity is used as storage; the code in
/// this file never changes the `Vec`'s length and accesses the bytes through
/// raw pointers instead.
data: Vec<u8>,
/// Position of the next byte to be read, pointing into `data`'s allocation.
read: *mut u8,
/// Number of bytes currently stored and not yet read.
amount: usize,
/// Waker of the task that parked itself waiting for the other half, if any.
waker: Option<Waker>,
/// Set once the writer has been closed through `poll_close`.
did_shutdown: bool,
}
/// Returns the distance from `other` to `x`, measured in elements of `T`.
///
/// The result is negative when `x` lies before `other`. The pointers are only
/// compared by address; they are never dereferenced.
///
/// # Panics
///
/// Panics if `T` is a zero-sized type.
fn offset_from<T>(x: *const T, other: *const T) -> isize {
    let size = std::mem::size_of::<T>() as isize;
    assert!(size != 0);
    // Address difference in bytes, then scaled down to whole elements.
    (x as isize).wrapping_sub(other as isize) / size
}
impl RingBuffer {
    /// Remembers a clone of `waker` so that a later [`RingBuffer::wake`] can
    /// notify the task; any previously parked waker is replaced.
    fn park(&mut self, waker: &Waker) {
        self.waker = Some(waker.clone());
    }

    /// Wakes the parked task, if there is one, and clears the stored waker.
    fn wake(&mut self) {
        if let Some(w) = self.waker.take() {
            w.wake()
        }
    }

    /// Returns the position where the next written byte goes: `amount` bytes
    /// after the read position, wrapped around the end of the allocation.
    fn write_ptr(&mut self) -> *mut u8 {
        let start = self.data.as_mut_ptr();
        let capacity = self.data.capacity();
        // Work with integer offsets: `read + amount` may lie well past the
        // end of the allocation, and forming such a pointer with `add` is
        // not allowed even if it is never dereferenced.
        let read_offset = offset_from(self.read, start) as usize;
        let mut write_offset = read_offset + self.amount;
        if write_offset >= capacity {
            write_offset -= capacity;
        }
        // SAFETY: `read_offset < capacity` and `amount <= capacity`, so after
        // the single wrap above `write_offset < capacity`, which is inside
        // the allocation owned by `data`.
        unsafe { start.add(write_offset) }
    }
}
/// Writing half of a ring buffer created by [`ring_buffer`].
///
/// Implements `AsyncWrite`; the bytes written become available to the paired
/// [`Reader`].
pub struct Writer(Rc<RefCell<RingBuffer>>);
impl Writer {
    /// Returns `true` once the writer has been closed (its `poll_close` has
    /// run), and `false` before that.
    pub fn is_closed(&self) -> bool {
        let shared = self.0.borrow();
        shared.did_shutdown
    }
}
impl Drop for Writer {
    /// Wakes whichever task is parked on the shared buffer — presumably so a
    /// reader waiting for data polls again and notices the writer is gone.
    fn drop(&mut self) {
        let mut shared = self.0.borrow_mut();
        shared.wake();
    }
}
impl AsyncWrite for Writer {
    /// Copies as much of `buf` as currently fits into the ring buffer.
    ///
    /// * Returns `Ready(Ok(0))` when `buf` is empty or the writer was closed.
    /// * When the buffer is full, returns `Ready(Ok(0))` if this writer holds
    ///   the only remaining handle to the shared state (the reader is gone);
    ///   otherwise parks the current task and returns `Pending`.
    /// * Otherwise writes `min(buf.len(), free space)` bytes, wrapping around
    ///   the end of the storage if needed, wakes the parked task and returns
    ///   the number of bytes written.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize>> {
        let mut rb = self.0.borrow_mut();
        if buf.is_empty() || rb.did_shutdown {
            return Ready(Ok(0));
        }
        let capacity = rb.data.capacity();
        // `Vec::as_mut_ptr` gives a pointer valid for the whole allocation;
        // a pointer taken from the (empty) slice view would not.
        let start = rb.data.as_mut_ptr();
        // SAFETY: one past the end of the allocation is a valid pointer value.
        let end = unsafe { start.add(capacity) };
        if rb.amount == capacity {
            if Rc::strong_count(&self.0) == 1 {
                return Ready(Ok(0));
            } else {
                rb.park(cx.waker());
                return Pending;
            }
        }
        let buf_ptr = buf.as_ptr();
        let write_total = min(buf.len(), capacity - rb.amount);
        let write_ptr = rb.write_ptr();
        // Compare lengths rather than pointers: `write_ptr + write_total` may
        // lie far past `end`, and such a pointer must not be formed with `add`.
        let distance_to_end = offset_from(end, write_ptr) as usize;
        if write_total < distance_to_end {
            // SAFETY: `[write_ptr, write_ptr + write_total)` ends before `end`
            // and holds no unread data because `write_total` is at most the
            // free space; `buf` is a separate allocation.
            unsafe { copy_nonoverlapping(buf_ptr, write_ptr, write_total) };
        } else {
            // The write reaches the end of the storage: fill the tail first,
            // then continue from the start.
            let remaining: usize = write_total - distance_to_end;
            // SAFETY: the tail `[write_ptr, end)` and the first `remaining`
            // bytes at `start` are free space inside the allocation, and
            // `distance_to_end + remaining == write_total <= buf.len()`.
            unsafe {
                copy_nonoverlapping(buf_ptr, write_ptr, distance_to_end);
                copy_nonoverlapping(buf_ptr.add(distance_to_end), start, remaining);
            }
        }
        rb.amount += write_total;
        debug_assert!(rb.read >= start);
        debug_assert!(rb.read < end);
        debug_assert!(rb.amount <= capacity);
        rb.wake();
        Ready(Ok(write_total))
    }

    /// Always ready: written bytes go straight into the shared buffer, so
    /// there is nothing to flush.
    fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<()>> {
        Ready(Ok(()))
    }

    /// Marks the buffer as shut down. The first call also wakes the parked
    /// task so it can observe the shutdown; later calls only return `Ok`.
    fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<()>> {
        let mut rb = self.0.borrow_mut();
        if !rb.did_shutdown {
            rb.wake();
        }
        rb.did_shutdown = true;
        Ready(Ok(()))
    }
}
/// Reading half of a ring buffer created by [`ring_buffer`].
///
/// Implements `AsyncRead`; it yields the bytes written by the paired
/// [`Writer`].
pub struct Reader(Rc<RefCell<RingBuffer>>);
impl Reader {
    /// Returns `true` once the paired writer has been closed through its
    /// `poll_close`, and `false` before that.
    pub fn is_closed(&self) -> bool {
        let shared = self.0.borrow();
        shared.did_shutdown
    }
}
impl Drop for Reader {
    /// Wakes whichever task is parked on the shared buffer — presumably so a
    /// writer waiting for free space polls again and notices the reader is
    /// gone.
    fn drop(&mut self) {
        let mut shared = self.0.borrow_mut();
        shared.wake();
    }
}
impl AsyncRead for Reader {
    /// Moves buffered bytes into `buf`.
    ///
    /// * Returns `Ready(Ok(0))` when `buf` is empty.
    /// * When nothing is buffered, returns `Ready(Ok(0))` if this reader holds
    ///   the only remaining handle to the shared state (the writer is gone)
    ///   or the writer was closed; otherwise parks the current task and
    ///   returns `Pending`.
    /// * Otherwise copies `min(buf.len(), buffered bytes)` bytes, wrapping
    ///   around the end of the storage if needed, wakes the parked task and
    ///   returns the number of bytes read.
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<Result<usize>> {
        let mut rb = self.0.borrow_mut();
        if buf.is_empty() {
            return Ready(Ok(0));
        }
        let capacity = rb.data.capacity();
        // `Vec::as_mut_ptr` gives a pointer valid for the whole allocation;
        // a pointer taken from the (empty) slice view would not.
        let start = rb.data.as_mut_ptr();
        // SAFETY: one past the end of the allocation is a valid pointer value.
        let end = unsafe { start.add(capacity) };
        if rb.amount == 0 {
            if Rc::strong_count(&self.0) == 1 || rb.did_shutdown {
                return Ready(Ok(0));
            } else {
                rb.park(cx.waker());
                return Pending;
            }
        }
        let buf_ptr = buf.as_mut_ptr();
        let read_total = min(buf.len(), rb.amount);
        let read_ptr = rb.read;
        // Compare lengths rather than pointers: `read + read_total` may lie
        // far past `end`, and such a pointer must not be formed with `add`.
        let distance_to_end = offset_from(end, read_ptr) as usize;
        if read_total < distance_to_end {
            // SAFETY: `[read_ptr, read_ptr + read_total)` ends before `end`
            // and contains only written bytes because `read_total <= amount`;
            // `buf` is a separate allocation of at least `read_total` bytes.
            unsafe {
                copy_nonoverlapping(read_ptr, buf_ptr, read_total);
                rb.read = read_ptr.add(read_total);
            }
        } else {
            // The read reaches the end of the storage: drain the tail first,
            // then continue from the start.
            let remaining: usize = read_total - distance_to_end;
            // SAFETY: the tail `[read_ptr, end)` and the first `remaining`
            // bytes at `start` hold written data, and
            // `distance_to_end + remaining == read_total <= buf.len()`.
            // `distance_to_end >= 1`, so `remaining < capacity` and the new
            // read position stays inside the allocation.
            unsafe {
                copy_nonoverlapping(read_ptr, buf_ptr, distance_to_end);
                copy_nonoverlapping(start, buf_ptr.add(distance_to_end), remaining);
                rb.read = start.add(remaining);
            }
        }
        rb.amount -= read_total;
        debug_assert!(rb.read >= start);
        debug_assert!(rb.read < end);
        debug_assert!(rb.amount <= capacity);
        rb.wake();
        Ready(Ok(read_total))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use futures::future::join;
    use futures::io::{AsyncReadExt, AsyncWriteExt};

    // Streams 255 bytes through an 8-byte buffer, which forces many
    // wrap-arounds and park/wake hand-offs between the two halves.
    #[test]
    fn it_works() {
        let (mut writer, mut reader) = ring_buffer(8);
        let data: Vec<u8> = (0..255).collect();
        let write_all = async {
            writer.write_all(&data).await.unwrap();
            writer.close().await.unwrap();
        };
        let mut out: Vec<u8> = Vec::with_capacity(256);
        let read_all = reader.read_to_end(&mut out);
        block_on(async { join(write_all, read_all).await.1.unwrap() });
        // Compare the whole output: checking only the bytes that happen to be
        // in `out` would also pass if nothing (or too little) was read.
        assert_eq!(out, data);
    }

    #[test]
    #[should_panic]
    fn panic_on_capacity_0() {
        let _ = ring_buffer(0);
    }

    #[test]
    #[should_panic]
    fn panic_on_capacity_too_large() {
        let _ = ring_buffer((isize::max_value() as usize) + 1);
    }

    // After `close`, both halves report closed, further writes fail, and the
    // reader still drains what was buffered before seeing end-of-stream.
    #[test]
    fn close() {
        let (mut writer, mut reader) = ring_buffer(8);
        block_on(async {
            writer.write_all(&[1, 2, 3, 4, 5]).await.unwrap();
            assert!(!writer.is_closed());
            assert!(!reader.is_closed());
            writer.close().await.unwrap();
            assert!(writer.is_closed());
            assert!(reader.is_closed());
            let r = writer.write_all(&[6, 7, 8]).await;
            assert!(r.is_err());
            let mut buf = [0; 8];
            let n = reader.read(&mut buf).await.unwrap();
            assert_eq!(n, 5);
            let n = reader.read(&mut buf).await.unwrap();
            assert_eq!(n, 0);
        });
    }
}