#![allow(unsafe_op_in_unsafe_fn)]
use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use std::task::{Context, Poll, Waker};
use atomic_waker::AtomicWaker;
use pin_project_lite::pin_project;
/// Sending half of a connector that carries no extra shared state (`S = ()`).
pub type Sender<T> = SenderExt<T, ()>;
/// Receiving half of a connector that carries no extra shared state (`S = ()`).
pub type Receiver<T> = ReceiverExt<T, ()>;
/// Creates a connected [`Sender`]/[`Receiver`] pair with `()` as the shared state.
///
/// Both halves point at the same freshly allocated, empty `Connector`.
pub fn connector<T>() -> (Sender<T>, Receiver<T>) {
    let rx_state = Arc::new(Connector::new(()));
    let tx_state = Arc::clone(&rx_state);
    let sender = Sender {
        connector: tx_state,
    };
    let receiver = Receiver {
        connector: rx_state,
    };
    (sender, receiver)
}
/// Creates a connected [`SenderExt`]/[`ReceiverExt`] pair that additionally carries `shared`.
///
/// `shared` is stored once inside the common `Connector` and is reachable from either half
/// through its `shared()` accessor.
pub fn connector_with<T, S>(shared: S) -> (SenderExt<T, S>, ReceiverExt<T, S>) {
    let rx_state = Arc::new(Connector::new(shared));
    let tx_state = Arc::clone(&rx_state);
    let sender = SenderExt {
        connector: tx_state,
    };
    let receiver = ReceiverExt {
        connector: rx_state,
    };
    (sender, receiver)
}
/// Set in `Connector::state` while the value slot holds a value that has not been received yet.
const FULL_BIT: u8 = 0b1;
/// Set in `Connector::state` by `close_send`/`close_recv` once either half has been closed.
const CLOSED_BIT: u8 = 0b10;
/// Set in `Connector::state` by `poll_send`/`poll_recv` after registering a waker, so that the
/// other side knows it has to wake the waiting task.
const WAITING_BIT: u8 = 0b100;
/// State shared by a `SenderExt`/`ReceiverExt` pair: one value slot, a state byte made of
/// `FULL_BIT`/`CLOSED_BIT`/`WAITING_BIT`, one waker per side and a user-supplied `shared` value.
// NOTE(review): the 128-byte alignment presumably keeps the connector on its own cache line(s)
// to avoid false sharing — confirm the intent.
#[repr(align(128))]
struct Connector<T, S> {
/// Waker registered by `poll_send` while the slot is full.
send_waker: AtomicWaker,
/// Waker registered by `poll_recv` while the slot is empty.
recv_waker: AtomicWaker,
/// The single value slot; only initialised while `FULL_BIT` is set in `state`.
value: UnsafeCell<MaybeUninit<T>>,
/// Bit set of `FULL_BIT`, `CLOSED_BIT` and `WAITING_BIT`.
state: AtomicU8,
/// Extra value handed to `Connector::new`, exposed by the `shared()` accessors of both halves.
shared: S,
}
impl<T, S> Connector<T, S> {
    /// Builds an empty, open connector (state `0`, uninitialised slot) that owns `shared`.
    fn new(shared: S) -> Self {
        let empty_slot = UnsafeCell::new(MaybeUninit::uninit());
        let initial_state = AtomicU8::new(0);
        Self {
            send_waker: AtomicWaker::new(),
            recv_waker: AtomicWaker::new(),
            value: empty_slot,
            state: initial_state,
            shared,
        }
    }
}
/// Error returned by a failed `try_send`; both variants hand the unsent value back.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SendError<T> {
    /// The slot still holds a value that has not been received yet.
    Full(T),
    /// The channel has been closed.
    Closed(T),
}
/// Error returned by a failed `try_recv`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// No value is currently stored in the slot.
    Empty,
    /// No value is stored and the channel has been closed.
    Closed,
}
impl<T, S> Connector<T, S> {
    /// Tries to send the value held in `value`, registering `waker` if the slot is occupied.
    ///
    /// Returns `Poll::Ready(Ok(()))` once the value has been stored (or when `value` was
    /// already `None`), `Poll::Ready(Err(v))` with the value handed back when the channel is
    /// closed, and `Poll::Pending` (value put back into `value`) while the slot is still full.
    ///
    /// # Safety
    ///
    /// The value slot is accessed without locking.
    /// NOTE(review): callers presumably must guarantee that only one task drives the sending
    /// side at a time, as the `&mut self` methods of `SenderExt` do — confirm.
    unsafe fn poll_send(&self, value: &mut Option<T>, waker: &Waker) -> Poll<Result<(), T>> {
        if let Some(v) = value.take() {
            let mut state = self.state.load(Ordering::Acquire);
            if state & FULL_BIT == FULL_BIT {
                // Register first, then publish WAITING_BIT, so a receiver that observes the bit
                // finds a waker to wake.
                self.send_waker.register(waker);
                // Success yields the previous (still full) state and we end up pending; failure
                // yields the fresh state, which is used for the send attempt below.
                let (Ok(s) | Err(s)) = self.state.compare_exchange(
                    state,
                    state | WAITING_BIT,
                    Ordering::Relaxed,
                    Ordering::Acquire,
                );
                state = s;
            }
            match self.try_send_impl(v, state) {
                Ok(()) => {}
                Err(SendError::Closed(v)) => return Poll::Ready(Err(v)),
                Err(SendError::Full(v)) => {
                    *value = Some(v);
                    return Poll::Pending;
                }
            }
        }
        Poll::Ready(Ok(()))
    }

    /// Stores `value` in the slot if the previously sampled `state` says it is open and empty.
    ///
    /// Returns `SendError::Closed`/`SendError::Full` (with the value) when `state` has the
    /// corresponding bit set, or `SendError::Closed` when the channel turns out to have been
    /// closed between sampling `state` and publishing the value.
    ///
    /// # Safety
    ///
    /// `state` must be a value read from `self.state`.
    /// NOTE(review): presumably also requires exclusive use of the sending side — confirm.
    unsafe fn try_send_impl(&self, value: T, state: u8) -> Result<(), SendError<T>> {
        if state & CLOSED_BIT == CLOSED_BIT {
            return Err(SendError::Closed(value));
        }
        if state & FULL_BIT == FULL_BIT {
            return Err(SendError::Full(value));
        }
        unsafe {
            self.value.get().write(MaybeUninit::new(value));
            // Publish the value; this also clears any WAITING_BIT set by the receiver.
            let state = self.state.swap(FULL_BIT, Ordering::Release);
            if state & WAITING_BIT == WAITING_BIT {
                self.recv_waker.wake();
            }
            if state & CLOSED_BIT == CLOSED_BIT {
                // Closed after `state` was sampled: restore the closed state and take the value
                // back out so it is returned to the caller instead of staying in the slot.
                self.state.store(CLOSED_BIT, Ordering::Relaxed);
                return Err(SendError::Closed(self.value.get().read().assume_init()));
            }
        }
        Ok(())
    }

    /// Tries to receive a value, registering `waker` if the slot is empty.
    ///
    /// Returns `Poll::Ready(Ok(v))` when a value was taken, `Poll::Ready(Err(()))` when the
    /// slot is empty and the channel is closed, and `Poll::Pending` otherwise.
    ///
    /// # Safety
    ///
    /// The value slot is accessed without locking.
    /// NOTE(review): callers presumably must guarantee that only one task drives the receiving
    /// side at a time, as the `&mut self` methods of `ReceiverExt` do — confirm.
    unsafe fn poll_recv(&self, waker: &Waker) -> Poll<Result<T, ()>> {
        let mut state = self.state.load(Ordering::Acquire);
        if state & FULL_BIT == 0 {
            // Register first, then publish WAITING_BIT, so a sender that observes the bit finds
            // a waker to wake.
            self.recv_waker.register(waker);
            // Success yields the previous (still empty) state; failure yields the fresh state,
            // which may already contain a value.
            let (Ok(s) | Err(s)) = self.state.compare_exchange(
                state,
                state | WAITING_BIT,
                Ordering::Relaxed,
                Ordering::Acquire,
            );
            state = s;
        }
        match self.try_recv_impl(state) {
            Ok(v) => Poll::Ready(Ok(v)),
            Err(RecvError::Empty) => Poll::Pending,
            Err(RecvError::Closed) => Poll::Ready(Err(())),
        }
    }

    /// Takes the value out of the slot if the previously sampled `state` has `FULL_BIT` set.
    ///
    /// A stored value is returned even when the channel is closed; otherwise the result is
    /// `RecvError::Closed` if `state` has `CLOSED_BIT` set and `RecvError::Empty` if not.
    ///
    /// # Safety
    ///
    /// `state` must be a value read from `self.state`; if it has `FULL_BIT` set the slot is
    /// read as initialised.
    /// NOTE(review): presumably also requires exclusive use of the receiving side — confirm.
    unsafe fn try_recv_impl(&self, state: u8) -> Result<T, RecvError> {
        if state & FULL_BIT == FULL_BIT {
            unsafe {
                let ret = self.value.get().read().assume_init();
                // Clear FULL_BIT/WAITING_BIT in a single atomic step while keeping CLOSED_BIT.
                // Resetting to 0 and re-storing CLOSED_BIT afterwards would briefly show an
                // open, empty channel: a concurrent sender could then store a value that the
                // later CLOSED_BIT store marks as absent, so it would never be dropped.
                let state = self.state.fetch_and(CLOSED_BIT, Ordering::Release);
                if state & WAITING_BIT == WAITING_BIT {
                    self.send_waker.wake();
                }
                return Ok(ret);
            }
        }
        if state & CLOSED_BIT == CLOSED_BIT {
            return Err(RecvError::Closed);
        }
        Err(RecvError::Empty)
    }

    /// Non-waiting send: samples the state and delegates to `try_send_impl`.
    ///
    /// # Safety
    ///
    /// Same requirements as `try_send_impl`.
    unsafe fn try_send(&self, value: T) -> Result<(), SendError<T>> {
        self.try_send_impl(value, self.state.load(Ordering::Acquire))
    }

    /// Non-waiting receive: samples the state and delegates to `try_recv_impl`.
    ///
    /// # Safety
    ///
    /// Same requirements as `try_recv_impl`.
    unsafe fn try_recv(&self) -> Result<T, RecvError> {
        self.try_recv_impl(self.state.load(Ordering::Acquire))
    }

    /// Marks the channel closed from the sending side and wakes a waiting receiver.
    ///
    /// A value that is already stored stays in the slot and can still be received.
    ///
    /// # Safety
    ///
    /// NOTE(review): no contract is visible here; presumably must only be called by the
    /// sending half (it is called from `SenderExt`'s `Drop`) — confirm.
    unsafe fn close_send(&self) {
        self.state.fetch_or(CLOSED_BIT, Ordering::Relaxed);
        self.recv_waker.wake();
    }

    /// Marks the channel closed from the receiving side, drops a value still sitting in the
    /// slot, and wakes a waiting sender.
    ///
    /// # Safety
    ///
    /// NOTE(review): no contract is visible here; presumably must only be called by the
    /// receiving half (it is called from `ReceiverExt`'s `Drop`) — confirm.
    unsafe fn close_recv(&self) {
        let state = self.state.fetch_or(CLOSED_BIT, Ordering::Acquire);
        // `state` is the value before closing; a pending value is taken out and dropped here.
        drop(self.try_recv_impl(state));
        self.send_waker.wake();
    }
}
/// Sending half of a connector; dropping it closes the channel from the sending side
/// (its `Drop` impl calls `Connector::close_send`).
pub struct SenderExt<T, S> {
/// State shared with the matching `ReceiverExt`.
connector: Arc<Connector<T, S>>,
}
// SAFETY: values of `T` are handed over to the receiving half, which may live on another
// thread, hence `T: Send`. `S` lives inside the shared `Arc`: it is read through `&S` from both
// halves (`S: Sync`) and is dropped by whichever half is released last, possibly on a different
// thread than the one that created it (`S: Send`). This mirrors the `T: Send + Sync` bound of
// `Arc<T>: Send`.
unsafe impl<T: Send, S: Send + Sync> Send for SenderExt<T, S> {}
impl<T, S> Drop for SenderExt<T, S> {
fn drop(&mut self) {
unsafe { self.connector.close_send() }
}
}
/// Receiving half of a connector; dropping it closes the channel from the receiving side
/// (its `Drop` impl calls `Connector::close_recv`).
pub struct ReceiverExt<T, S> {
/// State shared with the matching `SenderExt`.
connector: Arc<Connector<T, S>>,
}
// SAFETY: values of `T` arrive from the sending half, which may live on another thread, hence
// `T: Send`. `S` lives inside the shared `Arc`: it is read through `&S` from both halves
// (`S: Sync`) and is dropped by whichever half is released last, possibly on a different thread
// than the one that created it (`S: Send`). This mirrors the `T: Send + Sync` bound of
// `Arc<T>: Send`.
unsafe impl<T: Send, S: Send + Sync> Send for ReceiverExt<T, S> {}
impl<T, S> Drop for ReceiverExt<T, S> {
fn drop(&mut self) {
unsafe { self.connector.close_recv() }
}
}
pin_project! {
/// Future created by `SenderExt::send`; resolves to `Ok(())` once the value has been stored
/// in the connector, or to `Err(value)` when the channel is closed.
pub struct SendFuture<'a, T, S> {
// Connector of the sender this future was created from.
connector: &'a Connector<T, S>,
// Value still waiting to be sent; `poll_send` takes it and only puts it back when pending.
value: Option<T>,
}
}
// SAFETY: NOTE(review): the future holds a pending `T` and a `&Connector<T, S>`; the bounds
// presumably cover moving `T` across threads and sharing `S` by reference — confirm.
unsafe impl<T, S> Send for SendFuture<'_, T, S>
where
    T: Send,
    S: Sync,
{
}
impl<T: Send, S: Sync> SenderExt<T, S> {
#[must_use]
pub fn send(&mut self, value: T) -> SendFuture<'_, T, S> {
SendFuture {
connector: &self.connector,
value: Some(value),
}
}
#[allow(unused)]
pub fn try_send(&mut self, value: T) -> Result<(), SendError<T>> {
unsafe { self.connector.try_send(value) }
}
pub fn shared(&self) -> &S {
&self.connector.shared
}
}
impl<T, S> std::future::Future for SendFuture<'_, T, S> {
    type Output = Result<(), T>;

    /// Drives the send by delegating to `Connector::poll_send`.
    ///
    /// # Panics
    ///
    /// Panics when polled while no value is pending, i.e. after the future already completed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        assert!(
            this.value.is_some(),
            "re-poll after Poll::Ready in connector SendFuture"
        );
        let waker = cx.waker();
        unsafe { this.connector.poll_send(this.value, waker) }
    }
}
pin_project! {
/// Future created by `ReceiverExt::recv`; resolves to `Ok(value)` when a value was received,
/// or to `Err(())` when the slot is empty and the channel is closed.
pub struct RecvFuture<'a, T, S> {
// Connector of the receiver this future was created from.
connector: &'a Connector<T, S>,
// Checked by `poll`, which asserts that this flag is not set; starts out `false`.
done: bool,
}
}
// SAFETY: NOTE(review): the future holds a `&Connector<T, S>` through which a `T` is received;
// the bounds presumably cover moving `T` across threads and sharing `S` by reference — confirm.
unsafe impl<T, S> Send for RecvFuture<'_, T, S>
where
    T: Send,
    S: Sync,
{
}
impl<T: Send, S: Sync> ReceiverExt<T, S> {
#[must_use]
pub fn recv(&mut self) -> RecvFuture<'_, T, S> {
RecvFuture {
connector: &self.connector,
done: false,
}
}
#[allow(unused)]
pub fn try_recv(&mut self) -> Result<T, RecvError> {
unsafe { self.connector.try_recv() }
}
pub fn shared(&self) -> &S {
&self.connector.shared
}
}
impl<T, S> std::future::Future for RecvFuture<'_, T, S> {
    type Output = Result<T, ()>;

    /// Drives the receive by delegating to `Connector::poll_recv`.
    ///
    /// Once a `Poll::Ready` result has been produced the future is marked as done.
    ///
    /// # Panics
    ///
    /// Panics when polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        assert!(
            !*this.done,
            "re-poll after Poll::Ready in connector RecvFuture"
        );
        // SAFETY: NOTE(review): `poll_recv` states no contract in this file; the future is
        // created from `&mut ReceiverExt` in `recv`, so no other receive operation can run
        // at the same time — confirm that suffices.
        let outcome = unsafe { this.connector.poll_recv(cx.waker()) };
        // Record completion so that the assertion above can actually detect a re-poll.
        if outcome.is_ready() {
            *this.done = true;
        }
        outcome
    }
}