use crate::util::poll_proceed;
use bytes::Buf;
use bytes::BytesMut;
use futures_core::ready;
use std::io::Error as IoError;
use std::io::ErrorKind as IoErrorKind;
use std::io::IoSlice;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// Shorthand for a `Result` whose error type is `std::io::Error`.
type IoResult<T> = Result<T, IoError>;
/// Message attached to the `BrokenPipe` error that the sender's write and
/// flush methods return once the simplex has been closed.
const CLOSED_ERROR_MSG: &str = "simplex has been closed";
/// State shared between the [`Sender`] and the [`Receiver`], guarded by a mutex.
#[derive(Debug)]
struct Inner {
/// Upper bound on `buf.len()`; writes are truncated (or left pending) so
/// that the buffered byte count never exceeds this value.
backpressure_boundary: usize,
/// Set once either half has closed the channel (drop of either half, or
/// shutdown of the sender).
is_closed: bool,
/// Waker of a receiver that is waiting for data, if any.
receiver_waker: Option<Waker>,
/// Waker of a sender that is waiting for free space, if any.
sender_waker: Option<Waker>,
/// Bytes that have been written but not yet read.
buf: BytesMut,
}
impl Inner {
    /// Builds the shared state for a channel that buffers at most `capacity`
    /// bytes. The backing buffer is pre-allocated with that capacity.
    fn with_capacity(capacity: usize) -> Self {
        let buf = BytesMut::with_capacity(capacity);
        Self {
            backpressure_boundary: capacity,
            is_closed: false,
            receiver_waker: None,
            sender_waker: None,
            buf,
        }
    }

    /// Stores `waker` as the receiver's waker.
    ///
    /// Returns the previously stored waker when it was replaced, so the
    /// caller can drop it after releasing the lock. Returns `None` when no
    /// waker was stored, or when the stored one already wakes the same task
    /// (in which case nothing is cloned or replaced).
    fn register_receiver_waker(&mut self, waker: &Waker) -> Option<Waker> {
        if let Some(current) = &self.receiver_waker {
            if current.will_wake(waker) {
                return None;
            }
        }
        self.receiver_waker.replace(waker.clone())
    }

    /// Stores `waker` as the sender's waker.
    ///
    /// Same contract as [`Inner::register_receiver_waker`]: the displaced
    /// waker (if any) is handed back, and an equivalent waker is left as is.
    fn register_sender_waker(&mut self, waker: &Waker) -> Option<Waker> {
        if let Some(current) = &self.sender_waker {
            if current.will_wake(waker) {
                return None;
            }
        }
        self.sender_waker.replace(waker.clone())
    }

    /// Removes and returns the receiver's waker, leaving `None` behind.
    fn take_receiver_waker(&mut self) -> Option<Waker> {
        self.receiver_waker.take()
    }

    /// Removes and returns the sender's waker, leaving `None` behind.
    fn take_sender_waker(&mut self) -> Option<Waker> {
        self.sender_waker.take()
    }

    /// Reports whether either half has closed the channel.
    fn is_closed(&self) -> bool {
        self.is_closed
    }

    /// Marks the channel closed on behalf of the receiver and returns the
    /// sender's waker (if one is registered) so the caller can wake it.
    fn close_receiver(&mut self) -> Option<Waker> {
        self.is_closed = true;
        self.sender_waker.take()
    }

    /// Marks the channel closed on behalf of the sender and returns the
    /// receiver's waker (if one is registered) so the caller can wake it.
    fn close_sender(&mut self) -> Option<Waker> {
        self.is_closed = true;
        self.receiver_waker.take()
    }
}
/// Reading half of the simplex channel returned by [`new`].
///
/// Implements [`AsyncRead`]. Dropping it closes the channel and wakes a
/// sender that is waiting for free space.
#[derive(Debug)]
pub struct Receiver {
/// State shared with the paired [`Sender`].
inner: Arc<Mutex<Inner>>,
}
impl Drop for Receiver {
    /// Closes the channel and wakes the sender if it is parked.
    fn drop(&mut self) {
        let mut guard = self.inner.lock().unwrap();
        let parked_sender = guard.close_receiver();
        // Release the lock before running arbitrary waker code.
        drop(guard);

        if let Some(waker) = parked_sender {
            waker.wake();
        }
    }
}
impl AsyncRead for Receiver {
    /// Copies buffered bytes into `buf`.
    ///
    /// * If bytes are available, as many as fit are moved into `buf` and the
    ///   sender (if parked) is woken because space was freed.
    /// * If nothing can be copied and either `buf` has no room or the channel
    ///   is closed, returns `Ready(Ok(()))` without filling anything (for a
    ///   closed channel this is how end-of-stream is reported).
    /// * Otherwise registers the task's waker, wakes the sender, and returns
    ///   `Pending`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<IoResult<()>> {
        let coop = ready!(poll_proceed(cx));
        let mut guard = self.inner.lock().unwrap();

        let room = buf.remaining();
        let count = room.min(guard.buf.remaining());

        if count > 0 {
            coop.made_progress();
            buf.put_slice(&guard.buf[..count]);
            guard.buf.advance(count);

            let parked_sender = guard.take_sender_waker();
            // Unlock before waking.
            drop(guard);
            if let Some(waker) = parked_sender {
                waker.wake();
            }
            return Poll::Ready(Ok(()));
        }

        if guard.is_closed() || room == 0 {
            return Poll::Ready(Ok(()));
        }

        let displaced = guard.register_receiver_waker(cx.waker());
        let parked_sender = guard.take_sender_waker();
        // Unlock before dropping the old waker and before waking the sender.
        drop(guard);
        drop(displaced);
        if let Some(waker) = parked_sender {
            waker.wake();
        }
        Poll::Pending
    }
}
/// Writing half of the simplex channel returned by [`new`].
///
/// Implements [`AsyncWrite`]. Dropping it (or shutting it down) closes the
/// channel and wakes a receiver that is waiting for data.
#[derive(Debug)]
pub struct Sender {
/// State shared with the paired [`Receiver`].
inner: Arc<Mutex<Inner>>,
}
impl Drop for Sender {
    /// Closes the channel and wakes the receiver if it is parked.
    fn drop(&mut self) {
        let mut guard = self.inner.lock().unwrap();
        let parked_receiver = guard.close_sender();
        // Release the lock before running arbitrary waker code.
        drop(guard);

        if let Some(waker) = parked_receiver {
            waker.wake();
        }
    }
}
impl AsyncWrite for Sender {
    /// Appends as much of `buf` as fits below the backpressure boundary.
    ///
    /// Returns the number of bytes accepted. An empty `buf` yields
    /// `Ready(Ok(0))`. When `buf` is non-empty but the buffer is full, the
    /// task's waker is registered, the receiver is woken, and `Pending` is
    /// returned.
    ///
    /// # Errors
    ///
    /// Returns `BrokenPipe` if the channel has been closed.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
        let coop = ready!(poll_proceed(cx));
        let mut inner = self.inner.lock().unwrap();

        if inner.is_closed() {
            return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
        }

        // Writes never push `buf.len()` past the boundary, so this cannot underflow.
        let free = inner
            .backpressure_boundary
            .checked_sub(inner.buf.len())
            .expect("backpressure boundary overflow");
        let to_write = buf.len().min(free);

        if to_write == 0 {
            if buf.is_empty() {
                return Poll::Ready(Ok(0));
            }

            let old_waker = inner.register_sender_waker(cx.waker());
            let waker = inner.take_receiver_waker();

            // Unlock before dropping the old waker and before waking the receiver.
            drop(inner);
            drop(old_waker);
            if let Some(waker) = waker {
                waker.wake();
            }
            return Poll::Pending;
        }

        coop.made_progress();
        inner.buf.extend_from_slice(&buf[..to_write]);

        let waker = inner.take_receiver_waker();
        // Unlock before waking.
        drop(inner);
        if let Some(waker) = waker {
            waker.wake();
        }

        Poll::Ready(Ok(to_write))
    }

    /// Data is handed to the shared buffer directly by the write methods, so
    /// there is nothing to flush; this only reports the channel state.
    ///
    /// # Errors
    ///
    /// Returns `BrokenPipe` if the channel has been closed.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let inner = self.inner.lock().unwrap();
        if inner.is_closed() {
            Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)))
        } else {
            Poll::Ready(Ok(()))
        }
    }

    /// Closes the channel and wakes a parked receiver. Always completes
    /// immediately with `Ok(())`.
    ///
    /// Subsequent writes and flushes fail with `BrokenPipe`; the receiver can
    /// still drain bytes that are already buffered.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let maybe_waker = {
            let mut inner = self.inner.lock().unwrap();
            inner.close_sender()
        };
        if let Some(waker) = maybe_waker {
            waker.wake();
        }
        Poll::Ready(Ok(()))
    }

    /// Vectored writes are implemented natively by
    /// [`Sender::poll_write_vectored`].
    fn is_write_vectored(&self) -> bool {
        true
    }

    /// Appends bytes from `bufs`, in order, until the backpressure boundary
    /// is reached or all slices are consumed.
    ///
    /// Returns the total number of bytes accepted. When there is nothing to
    /// write (`bufs` is empty or every slice is empty) this returns
    /// `Ready(Ok(0))` immediately, matching [`Sender::poll_write`] for an
    /// empty buffer. When there is data but no free space, the task's waker
    /// is registered, the receiver is woken, and `Pending` is returned.
    ///
    /// # Errors
    ///
    /// Returns `BrokenPipe` if the channel has been closed.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, IoError>> {
        let coop = ready!(poll_proceed(cx));
        let mut inner = self.inner.lock().unwrap();

        if inner.is_closed() {
            return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
        }

        // A zero-length write must not wait for free space (it would hang if
        // the receiver never reads) and must not wake the receiver either,
        // since no data was produced.
        if bufs.iter().all(|buf| buf.is_empty()) {
            return Poll::Ready(Ok(0));
        }

        // Writes never push `buf.len()` past the boundary, so this cannot underflow.
        let free = inner
            .backpressure_boundary
            .checked_sub(inner.buf.len())
            .expect("backpressure boundary overflow");

        if free == 0 {
            let old_waker = inner.register_sender_waker(cx.waker());
            let maybe_waker = inner.take_receiver_waker();

            // Unlock before dropping the old waker and before waking the receiver.
            drop(inner);
            drop(old_waker);
            if let Some(waker) = maybe_waker {
                waker.wake();
            }
            return Poll::Pending;
        }

        coop.made_progress();

        let mut rem = free;
        for buf in bufs {
            if rem == 0 {
                break;
            }
            // `rem > 0` here, so a zero-length copy can only come from an
            // empty slice; skip it.
            if buf.is_empty() {
                continue;
            }
            let to_write = buf.len().min(rem);
            inner.buf.extend_from_slice(&buf[..to_write]);
            rem -= to_write;
        }

        let waker = inner.take_receiver_waker();
        // Unlock before waking.
        drop(inner);
        if let Some(waker) = waker {
            waker.wake();
        }

        Poll::Ready(Ok(free - rem))
    }
}
/// Creates a simplex channel that buffers at most `capacity` bytes and
/// returns its two halves.
///
/// # Panics
///
/// Panics if `capacity` is zero.
pub fn new(capacity: usize) -> (Sender, Receiver) {
    assert_ne!(capacity, 0, "capacity must be greater than zero");

    let shared = Arc::new(Mutex::new(Inner::with_capacity(capacity)));
    let sender = Sender {
        inner: Arc::clone(&shared),
    };
    let receiver = Receiver { inner: shared };

    (sender, receiver)
}