use std::cell::UnsafeCell;
use std::fmt;
use std::future::Future;
use std::future::IntoFuture;
use std::hint;
use std::mem;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::ptr;
use std::ptr::NonNull;
use std::sync::atomic::AtomicU8;
use std::sync::atomic::Ordering;
use std::sync::atomic::fence;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
#[cfg(test)]
mod tests;
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
let channel_ptr = NonNull::from(Box::leak(Box::new(Channel::new())));
(Sender { channel_ptr }, Receiver { channel_ptr })
}
/// Sending half of a oneshot channel created by [`channel`].
///
/// [`Sender::send`] consumes the sender, so at most one message is ever sent. Dropping the
/// sender without sending disconnects the channel.
pub struct Sender<T> {
    // Shared allocation created by `channel`.
    channel_ptr: NonNull<Channel<T>>,
}
impl<T> fmt::Debug for Sender<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Sender").finish_non_exhaustive()
    }
}
// SAFETY: the sender moves a `T` to whichever thread owns the receiving half, so `T: Send`.
unsafe impl<T: Send> Send for Sender<T> {}
// SAFETY: in this file `&Sender` only exposes `is_closed`, a single atomic load that never touches
// a `T`. NOTE(review): the `T: Sync` bound therefore looks stricter than required; kept as is.
unsafe impl<T: Sync> Sync for Sender<T> {}
/// Wakes a `Recv` future that registered a waker.
///
/// Called by the sender right after its read-modify-write moved the state from `RECEIVING` to
/// `AWAKING`. `state` is the final state to publish (`MESSAGE` from `send`, `DISCONNECTED` from
/// `Drop`).
#[inline(always)]
fn sender_wake_up_receiver<T>(channel: &Channel<T>, state: u8) {
    // ORDERING: the caller's RMW was not Acquire; this fence synchronizes with the receiver's
    // Release compare-exchange that published the waker, so the waker write is visible.
    fence(Ordering::Acquire);
    // SAFETY: while the state is AWAKING the receiving side does not touch the waker slot, so the
    // waker can be moved out here.
    let waker = unsafe { channel.take_waker() };
    // Publish the final state only after the waker is moved out. Release orders that read (and,
    // from `send`, the message write) before the state change; the receiver may free the
    // allocation as soon as it observes `state`.
    channel.state.swap(state, Ordering::AcqRel);
    // Wake through our own copy; the channel allocation is not accessed after the swap.
    waker.wake();
}
impl<T> Sender<T> {
pub fn send(self, message: T) -> Result<(), SendError<T>> {
let channel_ptr = self.channel_ptr;
mem::forget(self);
let channel = unsafe { channel_ptr.as_ref() };
unsafe { channel.write_message(message) };
match channel.state.fetch_add(1, Ordering::Release) {
EMPTY => Ok(()),
RECEIVING => {
sender_wake_up_receiver(channel, MESSAGE);
Ok(())
}
DISCONNECTED => Err(SendError { channel_ptr }),
state => unreachable!("unexpected channel state: {}", state),
}
}
pub fn is_closed(&self) -> bool {
let channel = unsafe { self.channel_ptr.as_ref() };
matches!(channel.state.load(Ordering::Relaxed), DISCONNECTED)
}
}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
let channel = unsafe { self.channel_ptr.as_ref() };
match channel.state.fetch_xor(0b001, Ordering::Relaxed) {
EMPTY => {}
RECEIVING => sender_wake_up_receiver(channel, DISCONNECTED),
DISCONNECTED => {
unsafe { dealloc(self.channel_ptr) };
}
state => unreachable!("unexpected channel state: {}", state),
}
}
}
/// Receiving half of a oneshot channel created by [`channel`].
///
/// Check for the message with [`Receiver::try_recv`], or await the receiver (via `IntoFuture`)
/// to wait for it.
pub struct Receiver<T> {
    // Shared allocation created by `channel`.
    channel_ptr: NonNull<Channel<T>>,
}
impl<T> fmt::Debug for Receiver<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Receiver").finish_non_exhaustive()
    }
}
// SAFETY: the receiver takes ownership of a `T` produced on another thread, so `T: Send`.
// `Receiver` is not `Sync` (there is no impl here, and `NonNull` is `!Sync`). That matters because
// `try_recv` takes `&self`, and two concurrent calls could both observe MESSAGE and take the value.
unsafe impl<T: Send> Send for Receiver<T> {}
impl<T> IntoFuture for Receiver<T> {
    type Output = Result<T, RecvError>;
    type IntoFuture = Recv<T>;
    /// Turns the receiver into a [`Recv`] future that resolves with the message.
    ///
    /// Ownership of the channel moves into the future, so the receiver's own `Drop` is skipped.
    fn into_future(self) -> Self::IntoFuture {
        let channel_ptr = self.channel_ptr;
        mem::forget(self);
        Recv { channel_ptr }
    }
}
impl<T> Receiver<T> {
pub fn is_closed(&self) -> bool {
let channel = unsafe { self.channel_ptr.as_ref() };
matches!(channel.state.load(Ordering::Relaxed), DISCONNECTED)
}
pub fn has_message(&self) -> bool {
let channel = unsafe { self.channel_ptr.as_ref() };
matches!(channel.state.load(Ordering::Acquire), MESSAGE)
}
pub fn try_recv(&self) -> Result<T, TryRecvError> {
let channel = unsafe { self.channel_ptr.as_ref() };
match channel.state.load(Ordering::Acquire) {
EMPTY => Err(TryRecvError::Empty),
DISCONNECTED => Err(TryRecvError::Disconnected),
MESSAGE => {
channel.state.store(DISCONNECTED, Ordering::Relaxed);
Ok(unsafe { channel.take_message() })
}
state => unreachable!("unexpected channel state: {}", state),
}
}
}
impl<T> Drop for Receiver<T> {
    /// Disconnects the receiving side.
    ///
    /// Frees the channel if the sender has already sent or been dropped, dropping an unread
    /// message first.
    fn drop(&mut self) {
        let ptr = self.channel_ptr;
        let shared = unsafe { ptr.as_ref() };
        let previous = shared.state.swap(DISCONNECTED, Ordering::Acquire);
        match previous {
            // The sender is still alive; it will see DISCONNECTED and free the channel itself.
            EMPTY => return,
            MESSAGE => unsafe { shared.drop_message() },
            DISCONNECTED => {}
            other => unreachable!("unexpected channel state: {}", other),
        }
        unsafe { dealloc(ptr) };
    }
}
/// Future produced by awaiting a [`Receiver`] (via `IntoFuture`).
///
/// Resolves to the sent message, or to [`RecvError::Disconnected`] if the sender was dropped
/// without sending.
pub struct Recv<T> {
    // Shared allocation created by `channel`, taken over from the `Receiver`.
    channel_ptr: NonNull<Channel<T>>,
}
impl<T> fmt::Debug for Recv<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Recv").finish_non_exhaustive()
    }
}
// SAFETY: like `Receiver`, the future takes ownership of a `T` produced on another thread.
// It is not `Sync` (there is no impl here, and `NonNull` is `!Sync`).
unsafe impl<T: Send> Send for Recv<T> {}
/// Waits out the short window in which the sender is taking the waker (state `AWAKING`).
///
/// Returns the outcome once the sender publishes its final state. The sender leaves `AWAKING`
/// only by swapping in `MESSAGE` or `DISCONNECTED` (see `sender_wake_up_receiver`), with no yield
/// point in between, so this spins only while that short sequence completes.
fn recv_awaken<T>(channel: &Channel<T>) -> Poll<Result<T, RecvError>> {
    loop {
        hint::spin_loop();
        // ORDERING: Acquire pairs with the sender's AcqRel swap out of AWAKING. It makes the
        // message write, and the sender's read of the waker slot, happen-before we take the
        // message and before the allocation is later freed. A Relaxed load leaves both
        // unsynchronized when we arrive here via the failed Relaxed compare-exchange in `poll`.
        match channel.state.load(Ordering::Acquire) {
            AWAKING => {}
            DISCONNECTED => break Poll::Ready(Err(RecvError::Disconnected)),
            MESSAGE => {
                // The sender no longer touches the state once it leaves AWAKING, so a plain store
                // marking the message as consumed is sufficient.
                channel.state.store(DISCONNECTED, Ordering::Relaxed);
                break Poll::Ready(Ok(unsafe { channel.take_message() }));
            }
            state => unreachable!("unexpected channel state: {}", state),
        }
    }
}
/// Resolves once the sender either sends (yielding the message) or is dropped without sending
/// (yielding `RecvError::Disconnected`).
impl<T> Future for Recv<T> {
    type Output = Result<T, RecvError>;
    /// Checks the channel state and either completes or (re)registers `cx`'s waker.
    ///
    /// After the future has completed, further polls observe `DISCONNECTED` and return
    /// `Err(RecvError::Disconnected)`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: the sending side frees the allocation only after this side has disconnected,
        // and this future is still alive.
        let channel = unsafe { self.channel_ptr.as_ref() };
        // ORDERING: Acquire pairs with the sender's Release publication of the state.
        match channel.state.load(Ordering::Acquire) {
            // Nothing sent yet: store a clone of the current waker and try to move to RECEIVING.
            EMPTY => {
                let waker = cx.waker().clone();
                unsafe { channel.write_waker(waker) }
            }
            // Message published: mark it consumed, then move it out.
            MESSAGE => {
                channel.state.store(DISCONNECTED, Ordering::Relaxed);
                Poll::Ready(Ok(unsafe { channel.take_message() }))
            }
            // A waker from an earlier poll is registered. The executor may have supplied a
            // different waker this time, so reclaim the slot (RECEIVING -> EMPTY) and register
            // again.
            RECEIVING => {
                match channel.state.compare_exchange(
                    RECEIVING,
                    EMPTY,
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                ) {
                    // The sender never touches the waker in EMPTY, so the old waker is ours to
                    // drop. NOTE(review): the state is already EMPTY when `clone()` runs; if
                    // cloning panics, the old waker is never dropped (a leak, not UB).
                    Ok(_) => {
                        let waker = cx.waker().clone();
                        unsafe { channel.drop_waker() };
                        unsafe { channel.write_waker(waker) }
                    }
                    // The sender went RECEIVING -> AWAKING -> MESSAGE and already took the waker.
                    // The Acquire swap synchronizes with the sender's AcqRel swap.
                    Err(MESSAGE) => {
                        channel.state.swap(DISCONNECTED, Ordering::Acquire);
                        Poll::Ready(Ok(unsafe { channel.take_message() }))
                    }
                    // The sender is in the middle of taking the waker; wait for its final state.
                    Err(AWAKING) => recv_awaken(channel),
                    // The sender was dropped after taking the waker.
                    Err(DISCONNECTED) => Poll::Ready(Err(RecvError::Disconnected)),
                    Err(state) => unreachable!("unexpected channel state: {}", state),
                }
            }
            // The sender is in the middle of taking the waker; wait for its final state.
            AWAKING => recv_awaken(channel),
            // The sender was dropped without sending, or this future already completed.
            DISCONNECTED => Poll::Ready(Err(RecvError::Disconnected)),
            state => unreachable!("unexpected channel state: {}", state),
        }
    }
}
impl<T> Drop for Recv<T> {
fn drop(&mut self) {
let channel = unsafe { self.channel_ptr.as_ref() };
match channel.state.swap(DISCONNECTED, Ordering::Acquire) {
EMPTY => {}
MESSAGE => {
unsafe { channel.drop_message() };
unsafe { dealloc(self.channel_ptr) };
}
RECEIVING => {
unsafe { channel.drop_waker() };
}
DISCONNECTED => {
unsafe { dealloc(self.channel_ptr) };
}
AWAKING => {
loop {
hint::spin_loop();
match channel.state.load(Ordering::Relaxed) {
AWAKING => {}
DISCONNECTED => break,
MESSAGE => {
unsafe { channel.drop_message() };
break;
}
state => unreachable!("unexpected channel state: {}", state),
}
}
unsafe { dealloc(self.channel_ptr) };
}
state => unreachable!("unexpected channel state: {}", state),
}
}
}
/// Heap allocation shared by both halves of one channel.
///
/// Both slots are `MaybeUninit`, so freeing the allocation never drops a message or a waker;
/// the side that frees it must drop any initialized contents first.
struct Channel<T> {
    /// Current protocol state: one of the `EMPTY`/`MESSAGE`/`RECEIVING`/`AWAKING`/`DISCONNECTED`
    /// constants.
    state: AtomicU8,
    /// The message. `send` writes it before publishing the state change.
    message: UnsafeCell<MaybeUninit<T>>,
    /// Waker stored by a pending `Recv` future before it publishes `RECEIVING`. The sender moves
    /// it out while the state is `AWAKING`.
    waker: UnsafeCell<MaybeUninit<Waker>>,
}
impl<T> Channel<T> {
    /// Creates a channel in the `EMPTY` state with both slots uninitialized.
    const fn new() -> Self {
        Self {
            state: AtomicU8::new(EMPTY),
            message: UnsafeCell::new(MaybeUninit::uninit()),
            waker: UnsafeCell::new(MaybeUninit::uninit()),
        }
    }
    /// Borrows the message slot.
    ///
    /// # Safety
    ///
    /// No one may write to the message slot while the returned reference is alive.
    #[inline(always)]
    unsafe fn message(&self) -> &MaybeUninit<T> {
        unsafe { &*self.message.get() }
    }
    /// Writes `message` into the slot without dropping any previous contents.
    ///
    /// # Safety
    ///
    /// The caller must have exclusive access to the message slot (no concurrent reader or
    /// writer). Any previously initialized message would be overwritten without being dropped.
    #[inline(always)]
    unsafe fn write_message(&self, message: T) {
        unsafe {
            let slot = &mut *self.message.get();
            slot.as_mut_ptr().write(message);
        }
    }
    /// Drops the message in place, leaving the slot logically uninitialized.
    ///
    /// # Safety
    ///
    /// The message must be initialized, the caller must have exclusive access to the slot, and
    /// the message must not be read or dropped again afterwards.
    #[inline(always)]
    unsafe fn drop_message(&self) {
        unsafe {
            let slot = &mut *self.message.get();
            slot.assume_init_drop();
        }
    }
    /// Moves the message out by bitwise copy, leaving the slot logically uninitialized.
    ///
    /// # Safety
    ///
    /// Same requirements as `drop_message`: initialized, exclusively accessed, and never
    /// read or dropped again.
    #[inline(always)]
    unsafe fn take_message(&self) -> T {
        unsafe { ptr::read(self.message.get()).assume_init() }
    }
    /// Stores `waker` and tries to publish `RECEIVING` (compare-exchange from `EMPTY`).
    ///
    /// Returns `Poll::Pending` if the waker was registered. If the sender sent in the meantime
    /// (`MESSAGE`), drops the waker, marks the channel `DISCONNECTED` and returns the message.
    /// If the sender was dropped (`DISCONNECTED`), drops the waker and returns
    /// `Err(RecvError::Disconnected)`.
    ///
    /// # Safety
    ///
    /// Called only by the receiving side. The state must have been observed as `EMPTY`, and the
    /// waker slot must not hold a live waker (one would be overwritten without being dropped).
    unsafe fn write_waker(&self, waker: Waker) -> Poll<Result<T, RecvError>> {
        // The sender does not read the waker slot until it observes RECEIVING, so this write
        // cannot race with it.
        unsafe {
            let slot = &mut *self.waker.get();
            slot.as_mut_ptr().write(waker);
        }
        // ORDERING: Release on success lets the sender's Acquire fence see the waker write.
        match self
            .state
            .compare_exchange(EMPTY, RECEIVING, Ordering::Release, Ordering::Relaxed)
        {
            Ok(_) => Poll::Pending,
            Err(MESSAGE) => {
                // Synchronize with the sender's Release publication of the message.
                fence(Ordering::Acquire);
                // The sender went EMPTY -> MESSAGE and never saw our waker; we still own it.
                unsafe { self.drop_waker() };
                self.state.store(DISCONNECTED, Ordering::Relaxed);
                Poll::Ready(Ok(unsafe { self.take_message() }))
            }
            Err(DISCONNECTED) => {
                // The sender went EMPTY -> DISCONNECTED and never saw our waker; we still own it.
                unsafe { self.drop_waker() };
                Poll::Ready(Err(RecvError::Disconnected))
            }
            Err(state) => unreachable!("unexpected channel state: {}", state),
        }
    }
    /// Drops the waker in place, leaving the slot logically uninitialized.
    ///
    /// # Safety
    ///
    /// The waker must be initialized and exclusively owned by the caller.
    #[inline(always)]
    unsafe fn drop_waker(&self) {
        unsafe {
            let slot = &mut *self.waker.get();
            slot.assume_init_drop();
        }
    }
    /// Moves the waker out by bitwise copy, leaving the slot logically uninitialized.
    ///
    /// # Safety
    ///
    /// The waker must be initialized and exclusively owned by the caller, and must not be read
    /// or dropped from the slot again.
    #[inline(always)]
    unsafe fn take_waker(&self) -> Waker {
        unsafe { ptr::read(self.waker.get()).assume_init() }
    }
}
/// Frees a channel allocation created by `channel`.
///
/// # Safety
///
/// `channel` must come from the `Box::leak` in `channel` and must not have been freed already.
/// No reference to the channel may be used afterwards. Initialized message or waker contents are
/// *not* dropped (both slots are `MaybeUninit`), so the caller must drop them first if needed.
unsafe fn dealloc<T>(channel: NonNull<Channel<T>>) {
    let owned = unsafe { Box::from_raw(channel.as_ptr()) };
    drop(owned);
}
/// Error returned by [`Sender::send`] when the receiving half has already been dropped.
///
/// The error owns the undelivered message and the channel allocation, and frees both when
/// dropped. Use [`SendError::into_inner`] to get the message back.
pub struct SendError<T> {
    // Channel whose message slot holds the undelivered message.
    channel_ptr: NonNull<Channel<T>>,
}
// SAFETY: the error owns a `T`; moving the error to another thread moves that `T`.
unsafe impl<T: Send> Send for SendError<T> {}
// SAFETY: `&SendError` exposes the message only as `&T` (via `as_inner`), which requires `T: Sync`.
unsafe impl<T: Sync> Sync for SendError<T> {}
impl<T> SendError<T> {
    /// Borrows the message that could not be delivered.
    pub fn as_inner(&self) -> &T {
        // SAFETY: `send` wrote the message before creating this error, and the receiving side is
        // gone, so nothing else accesses the slot.
        let shared = unsafe { self.channel_ptr.as_ref() };
        unsafe { shared.message().assume_init_ref() }
    }

    /// Returns the undelivered message and frees the channel allocation.
    pub fn into_inner(self) -> T {
        let channel_ptr = self.channel_ptr;
        // The message is moved out and the allocation freed right here, so skip `Drop`.
        mem::forget(self);
        let message = unsafe { channel_ptr.as_ref().take_message() };
        unsafe { dealloc(channel_ptr) };
        message
    }
}
impl<T> Drop for SendError<T> {
    /// Drops the undelivered message, then frees the channel allocation owned by this error.
    fn drop(&mut self) {
        let channel_ptr = self.channel_ptr;
        // SAFETY: `send` wrote the message before creating this error and the receiving side is
        // gone, so the error solely owns an initialized message and the allocation.
        unsafe { channel_ptr.as_ref().drop_message() };
        // SAFETY: nothing refers to the channel after this point.
        unsafe { dealloc(channel_ptr) };
    }
}
impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` is what `str`'s `Display` impl calls, so width/fill/precision flags are honored.
        f.pad("sending on a closed channel")
    }
}
impl<T> fmt::Debug for SendError<T> {
    /// Formats as `SendError<TypeName>(..)` without requiring `T: Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `stringify!(T)` always expands to the literal text "T", whatever `T` is instantiated
        // with. `type_name` reports the actual message type; its exact text is best-effort and
        // not guaranteed stable.
        write!(f, "SendError<{}>(..)", std::any::type_name::<T>())
    }
}
impl<T> std::error::Error for SendError<T> {}
/// Error returned by `Receiver::try_recv`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TryRecvError {
    /// Nothing has been sent yet, but the sender is still alive.
    Empty,
    /// The sender was dropped without sending, or the message was already taken.
    Disconnected,
}
impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match *self {
            Self::Empty => "receiving on an empty channel",
            Self::Disconnected => "receiving on a closed channel",
        };
        f.write_str(text)
    }
}
impl std::error::Error for TryRecvError {}
/// Error produced by awaiting a receiver (the `Recv` future).
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum RecvError {
    /// The sender was dropped without sending, or the future had already completed.
    Disconnected,
}
impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("receiving on a closed channel")
    }
}
impl std::error::Error for RecvError {}
// Channel states stored in `Channel::state`.
//
// The values are chosen so the sender changes state with a single read-modify-write:
// - `Sender::send` uses `fetch_add(1)`: EMPTY -> MESSAGE, RECEIVING -> AWAKING.
// - `Sender::drop` uses `fetch_xor(0b001)`: EMPTY -> DISCONNECTED, RECEIVING -> AWAKING,
//   DISCONNECTED -> EMPTY.
// Applied to DISCONNECTED, either operation yields EMPTY. That value is never acted on, because
// in that case the sending side owns and frees the allocation.
/// Initial state: no message sent and no waker registered.
const EMPTY: u8 = 0b011;
/// A message has been written and published by `Sender::send`.
const MESSAGE: u8 = 0b100;
/// A `Recv` future has stored its waker and is waiting to be woken.
const RECEIVING: u8 = 0b000;
/// The sender saw `RECEIVING` and is taking the waker; it next stores `MESSAGE` or `DISCONNECTED`.
const AWAKING: u8 = 0b001;
/// The channel is finished from the observer's point of view: the other side is gone, or the
/// message has already been taken.
const DISCONNECTED: u8 = 0b010;