#![warn(missing_docs, missing_debug_implementations, unreachable_pub)]
mod loom_exports;
use std::error::Error;
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::mem::{ManuallyDrop, MaybeUninit};
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::pin::Pin;
use std::ptr::{self, NonNull};
use std::sync::atomic::Ordering;
use std::task::{Context, Poll, Waker};
use crate::loom_exports::cell::UnsafeCell;
use crate::loom_exports::sync::atomic::AtomicUsize;
/// State flag meaning "no value is stored". While `OPEN` is also set, the flag
/// plays a second role: the receiver clears it to announce that it has put a
/// new waker in the slot *not* selected by `INDEX`.
const EMPTY: usize = 1 << 0;
/// State flag set while a sender is alive.
const OPEN: usize = 1 << 1;
/// State bit selecting the waker slot (0 or 1) currently used by the sender.
const INDEX: usize = 1 << 2;
/// State shared between a [`Receiver`] and its current [`Sender`].
///
/// The allocation is freed by whichever side observes, through `state`, that
/// the other side is already gone.
struct Inner<T> {
    /// Bit field combining the `EMPTY`, `OPEN` and `INDEX` flags.
    state: AtomicUsize,
    /// Storage for the transmitted value; only initialized when `state` says
    /// a value is present.
    value: UnsafeCell<MaybeUninit<T>>,
    /// Two waker slots: the receiver writes a new waker into the slot the
    /// sender is not using and then publishes it through `state`.
    waker: [UnsafeCell<Option<Waker>>; 2],
}

impl<T> Inner<T> {
    /// Moves `t` into the value slot without dropping whatever was there.
    ///
    /// # Safety
    ///
    /// The caller must have exclusive access to the value slot. A value
    /// previously stored in the slot is overwritten without being dropped.
    unsafe fn write_value(&self, t: T) {
        self.value.with_mut(|slot| {
            (*slot).write(t);
        });
    }

    /// Moves the stored value out by bitwise copy.
    ///
    /// # Safety
    ///
    /// The slot must be initialized and the caller must have exclusive access
    /// to it. Afterwards the slot must be treated as uninitialized, otherwise
    /// the value would be dropped twice.
    unsafe fn read_value(&self) -> T {
        self.value.with(|slot| ptr::read((*slot).as_ptr()))
    }

    /// Drops the stored value in place; the allocation itself is untouched.
    ///
    /// # Safety
    ///
    /// The slot must be initialized and the caller must have exclusive access
    /// to it. Afterwards the slot is uninitialized.
    unsafe fn drop_value_in_place(&self) {
        self.value
            .with_mut(|slot| (*slot).as_mut_ptr().drop_in_place());
    }

    /// Stores `new` in waker slot `idx`, dropping the previous waker.
    ///
    /// # Safety
    ///
    /// The caller must have exclusive access to waker slot `idx`.
    unsafe fn set_waker(&self, idx: usize, new: Option<Waker>) {
        self.waker[idx].with_mut(|slot| *slot = new);
    }

    /// Takes the waker out of slot `idx`, leaving `None` behind.
    ///
    /// # Safety
    ///
    /// The caller must have exclusive access to waker slot `idx`.
    unsafe fn take_waker(&self, idx: usize) -> Option<Waker> {
        self.waker[idx].with_mut(|slot| Option::take(&mut *slot))
    }
}
/// The reusable receiving half of a multi-shot channel.
///
/// A receiver hands out at most one live [`Sender`] at a time (see
/// [`Receiver::sender`]). Once that sender has sent a value or been dropped,
/// a new sender can be created, so the same allocation is reused across
/// successive single-value transfers.
#[derive(Debug)]
pub struct Receiver<T> {
    /// Pointer to the heap-allocated state shared with the sender.
    inner: NonNull<Inner<T>>,
    /// Tells the drop checker that a `Receiver<T>` may drop an `Inner<T>`.
    _phantom: PhantomData<Inner<T>>,
}

impl<T> Receiver<T> {
    /// Creates a new receiver with no associated sender.
    ///
    /// Use [`Receiver::sender`] to create a sender for it, or [`channel`] to
    /// create both halves at once.
    pub fn new() -> Self {
        let inner = Box::new(Inner {
            state: AtomicUsize::new(EMPTY),
            value: UnsafeCell::new(MaybeUninit::uninit()),
            waker: [UnsafeCell::new(None), UnsafeCell::new(None)],
        });

        Self {
            // `Box::leak` can never yield a null pointer, so no runtime check
            // is needed. Ownership is reclaimed with `Box::from_raw` by
            // whichever side frees the channel last.
            inner: NonNull::from(Box::leak(inner)),
            _phantom: PhantomData,
        }
    }

    /// Creates a sender for this receiver if no sender is currently alive.
    ///
    /// Returns `None` if a previously created sender is still alive, i.e. it
    /// has neither sent a value nor been dropped. Otherwise the channel is
    /// re-initialized and a fresh sender is returned. Any value that was sent
    /// but never received is dropped in the process.
    pub fn sender(&mut self) -> Option<Sender<T>> {
        // Acquire ordering synchronizes with a previous sender's release of
        // the value, which may have to be dropped below.
        //
        // SAFETY: `inner` stays allocated for as long as the receiver exists.
        let state = unsafe { self.inner.as_ref().state.load(Ordering::Acquire) };
        if state & OPEN == 0 {
            // SAFETY: the channel is closed, as `sender_with_waker` requires.
            Some(unsafe { self.sender_with_waker(state, None) })
        } else {
            None
        }
    }

    /// Returns a future that resolves to the next value sent on the channel.
    ///
    /// The future completes with `Ok(value)` once the current sender has sent
    /// a value. It completes with `Err(RecvError)` if there is no live sender
    /// and no value to receive, for example because the sender was dropped
    /// without sending or no sender was ever created.
    pub fn recv(&mut self) -> Recv<'_, T> {
        Recv { receiver: self }
    }

    /// Re-initializes the channel and creates a new sender.
    ///
    /// Any unread value is dropped and waker slot 0 is set to `waker`. The
    /// state is then reset to `OPEN | EMPTY` with waker index 0.
    ///
    /// # Safety
    ///
    /// The channel must be closed, i.e. `state` must not have `OPEN` set. In
    /// addition, `state` must be the current value of the state field, since
    /// it decides whether a stored value is dropped.
    unsafe fn sender_with_waker(&mut self, state: usize, waker: Option<Waker>) -> Sender<T> {
        debug_assert!(state & OPEN == 0);

        // Drop a value that was sent but never received. There is no race,
        // since no sender is alive.
        if state & EMPTY == 0 {
            self.inner.as_ref().drop_value_in_place();
        }

        self.inner.as_ref().set_waker(0, waker);

        // Relaxed is sufficient: the new sender is handed out from this
        // thread, and moving it to another thread provides its own
        // synchronization.
        self.inner
            .as_ref()
            .state
            .store(OPEN | EMPTY, Ordering::Relaxed);

        Sender {
            inner: self.inner,
            _phantom: PhantomData,
        }
    }
}
// SAFETY: the receiver moves `T` values out of the channel (or drops them), so
// it may be sent to another thread whenever `T: Send`. Every method that acts
// on the channel takes `&mut self`, so a shared `&Receiver<T>` exposes nothing
// beyond the derived `Debug` output of the pointer, which makes `Sync` harmless.
unsafe impl<T: Send> Send for Receiver<T> {}
unsafe impl<T: Send> Sync for Receiver<T> {}
// NOTE(review): unconditional unwind-safety markers; presumably justified
// because the shared state is only mutated through atomics and owned slots.
// Confirm that a panic cannot leave `Inner` in an inconsistent state.
impl<T> UnwindSafe for Receiver<T> {}
impl<T> RefUnwindSafe for Receiver<T> {}
impl<T> Default for Receiver<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Drop for Receiver<T> {
    /// Closes the receiving side and frees the channel if the sender is gone.
    fn drop(&mut self) {
        // Clear the whole state. A sender that is still alive will then see
        // `OPEN` unset and free the allocation itself.
        //
        // AcqRel: Acquire pairs with a sender's release of the value, which
        // may have to be dropped below. Release publishes this side's last
        // accesses to a sender that may free `inner` after us.
        let previous = unsafe { self.inner.as_ref() }
            .state
            .swap(0, Ordering::AcqRel);

        let sender_alive = previous & OPEN != 0;
        if sender_alive {
            return;
        }

        // SAFETY: the sender is gone, so this side is the last owner of
        // `inner`. It is not touched again after being freed.
        unsafe {
            let holds_value = previous & EMPTY == 0;
            if holds_value {
                self.inner.as_ref().drop_value_in_place();
            }
            drop(Box::from_raw(self.inner.as_ptr()));
        }
    }
}
/// Future returned by [`Receiver::recv`].
///
/// Resolves to `Ok(value)` when the sender sends a value. Resolves to
/// `Err(RecvError)` when there is no live sender and no value to receive.
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct Recv<'a, T> {
    /// The receiver being polled. It is borrowed mutably, so no sender can be
    /// created and no other `recv` issued while this future exists.
    receiver: &'a mut Receiver<T>,
}

impl<'a, T> Recv<'a, T> {
    /// Finishes a receive operation once the channel has been closed.
    ///
    /// `state` is a state observed by the caller with `OPEN` cleared. If its
    /// `EMPTY` flag is clear, the value is moved out of the channel; otherwise
    /// `RecvError` is returned. In both cases the state is reset to `EMPTY`.
    fn poll_complete(self: Pin<&mut Self>, state: usize) -> Poll<Result<T, RecvError>> {
        debug_assert!(state & OPEN == 0);
        let ret = if state & EMPTY == 0 {
            // SAFETY: the channel is closed and holds a value, so the slot is
            // initialized and no sender can access it. Resetting the state to
            // `EMPTY` below marks the slot as uninitialized again.
            let value = unsafe { self.receiver.inner.as_ref().read_value() };
            Ok(value)
        } else {
            Err(RecvError {})
        };
        // Relaxed is sufficient: with the channel closed, only the receiver
        // accesses the state.
        unsafe {
            self.receiver
                .inner
                .as_ref()
                .state
                .store(EMPTY, Ordering::Relaxed);
        }
        Poll::Ready(ret)
    }
}
impl<'a, T> Future for Recv<'a, T> {
    type Output = Result<T, RecvError>;

    /// Polls for the value.
    ///
    /// If the channel is already closed, completes immediately. Otherwise
    /// stores `cx`'s waker in the slot not used by the sender and publishes it
    /// by clearing `EMPTY`. If the sender closed the channel in the meantime,
    /// completes instead of returning `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Acquire: synchronizes with the sender's release of the value.
        let mut state = unsafe { self.receiver.inner.as_ref().state.load(Ordering::Acquire) };
        // Channel closed: a value was sent, the sender was dropped, or there
        // never was a sender.
        if state & OPEN == 0 {
            return self.poll_complete(state);
        }
        // While `OPEN` is set, a cleared `EMPTY` means that a waker registered
        // by an earlier poll has not been picked up by the sender yet. Set
        // `EMPTY` again to withdraw that notification before overwriting the
        // slot. The previous state tells whether the sender closed the channel
        // in the meantime.
        if state & EMPTY == 0 {
            unsafe {
                state = self
                    .receiver
                    .inner
                    .as_ref()
                    .state
                    .fetch_or(EMPTY, Ordering::Acquire);
            }
            if state & OPEN == 0 {
                return self.poll_complete(state);
            }
        }
        // Until the swap below clears `EMPTY`, the sender only touches the
        // waker slot at `current_idx`, so the other slot can be written
        // without racing with it.
        let current_idx = state_to_index(state);
        let new_idx = 1 - current_idx;
        unsafe {
            self.receiver
                .inner
                .as_ref()
                .set_waker(new_idx, Some(cx.waker().clone()));
        }
        // Publish the new waker: keep the index and `OPEN`, and clear `EMPTY`.
        // Release makes the waker write visible to the sender. Acquire is
        // needed in case the sender has closed the channel and the value must
        // be read.
        let state = unsafe {
            self.receiver
                .inner
                .as_ref()
                .state
                .swap(index_to_state(current_idx) | OPEN, Ordering::AcqRel)
        };
        // The sender closed the channel before the swap. The swap has just
        // overwritten the closed state, but `poll_complete` resets it.
        if state & OPEN == 0 {
            return self.poll_complete(state);
        }
        Poll::Pending
    }
}
/// The sending half of a multi-shot channel.
///
/// A sender is single-use: it is consumed by [`Sender::send`]. Dropping it
/// without sending makes the pending receive complete with [`RecvError`].
#[derive(Debug)]
pub struct Sender<T> {
    /// Pointer to the heap-allocated state shared with the receiver.
    inner: NonNull<Inner<T>>,
    /// Tells the drop checker that a `Sender<T>` may drop an `Inner<T>`.
    _phantom: PhantomData<Inner<T>>,
}

impl<T> Sender<T> {
    /// Sends a value to the receiver, consuming the sender.
    ///
    /// The receiver is notified by waking its registered waker, if any. If
    /// the receiver has already been dropped, the value is dropped and the
    /// channel is freed.
    pub fn send(self, t: T) {
        // This method closes the channel itself, so `Drop` must not run.
        let this = ManuallyDrop::new(self);
        // SAFETY: at most one sender is alive at a time, and the receiver does
        // not read the value until `OPEN` is cleared below.
        unsafe { this.inner.as_ref().write_value(t) };
        // Relaxed: while both sides are alive, only the sender changes the
        // waker index.
        let mut idx = state_to_index(unsafe { this.inner.as_ref().state.load(Ordering::Relaxed) });
        loop {
            // The receiver never writes the waker slot the sender is using.
            let waker = unsafe { this.inner.as_ref().take_waker(idx) };
            // Subtracting `OPEN | EMPTY` (0b011) acts as a single RMW:
            // - `OPEN | EMPTY` set (no newer waker): both flags are cleared,
            //   so the channel is closed with a value present.
            // - `OPEN` set, `EMPTY` clear (a newer waker waits in the other
            //   slot): the borrow turns the low bits into `OPEN | EMPTY` with
            //   `INDEX` toggled, so the loop retries with the newer waker. Bits
            //   above `INDEX` may change, but every state check masks a single
            //   flag.
            // - State 0 (receiver dropped): the result is irrelevant, since the
            //   channel is freed below.
            let state = unsafe {
                this.inner
                    .as_ref()
                    .state
                    .fetch_sub(OPEN | EMPTY, Ordering::AcqRel)
            };
            unsafe {
                // The receiver was already dropped: free the value and the
                // shared state.
                if state & OPEN == 0 {
                    this.inner.as_ref().drop_value_in_place();
                    drop(Box::from_raw(this.inner.as_ptr()));
                    return;
                }
            }
            // Channel closed with the value in place: notify the receiver.
            if state & EMPTY == EMPTY {
                if let Some(waker) = waker {
                    waker.wake()
                }
                return;
            }
            // Switch to the newer waker in the other slot and retry.
            idx = 1 - idx;
        }
    }
}
// SAFETY: a sender moves a `T` into the channel (or drops it when the receiver
// is gone), so it may be sent to another thread whenever `T: Send`. Its only
// method, `send`, takes `self` by value, so a shared `&Sender<T>` exposes
// nothing beyond the derived `Debug` output of the pointer, which makes `Sync`
// harmless.
unsafe impl<T: Send> Send for Sender<T> {}
unsafe impl<T: Send> Sync for Sender<T> {}
// NOTE(review): unconditional unwind-safety markers, mirroring the receiver's.
// Confirm that a panic cannot leave the shared state inconsistent.
impl<T> UnwindSafe for Sender<T> {}
impl<T> RefUnwindSafe for Sender<T> {}
impl<T> Drop for Sender<T> {
    /// Closes the channel without a value and notifies the receiver by waking
    /// its registered waker, if any. If the receiver has already been dropped,
    /// frees the shared state instead.
    fn drop(&mut self) {
        // Relaxed: while both sides are alive, only the sender changes the
        // waker index.
        let mut state = unsafe { self.inner.as_ref().state.load(Ordering::Relaxed) };
        let mut idx = state_to_index(state);
        loop {
            // The receiver never writes the waker slot the sender is using.
            let waker = unsafe { self.inner.as_ref().take_waker(idx) };
            // Retry until the state is updated from its latest value:
            // - `EMPTY` set (no newer waker): close the channel without a
            //   value by storing just `EMPTY`.
            // - `EMPTY` clear (a newer waker waits in the other slot): set
            //   `EMPTY` and toggle `INDEX`, leaving `OPEN` unchanged.
            loop {
                let new_state = if state & EMPTY == EMPTY {
                    EMPTY
                } else {
                    state ^ (EMPTY | INDEX)
                };
                unsafe {
                    match self.inner.as_ref().state.compare_exchange_weak(
                        state,
                        new_state,
                        Ordering::AcqRel,
                        Ordering::Relaxed,
                    ) {
                        Ok(s) => {
                            state = s;
                            break;
                        }
                        Err(s) => state = s,
                    }
                }
            }
            // `state` now holds the value replaced by the successful CAS.
            unsafe {
                // The receiver was already dropped: free the shared state. No
                // value needs dropping, because this sender never wrote one.
                if state & OPEN == 0 {
                    drop(Box::from_raw(self.inner.as_ptr()));
                    return;
                }
            }
            // Channel closed: notify the receiver.
            if state & EMPTY == EMPTY {
                if let Some(waker) = waker {
                    waker.wake()
                }
                return;
            }
            // Switch to the newer waker in the other slot and retry.
            idx = 1 - idx;
        }
    }
}
/// Error returned by [`Recv`] when the channel closes without a value.
///
/// This happens, for example, when the sender was dropped without sending or
/// when no sender was ever created.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct RecvError {}

impl fmt::Display for RecvError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.write_str("channel closed")
    }
}

impl Error for RecvError {}
/// Creates a new multi-shot channel, returning its sender and receiver.
///
/// This is shorthand for [`Receiver::new`] followed by [`Receiver::sender`].
/// Once the returned sender has been consumed or dropped, further senders can
/// be obtained from the receiver with [`Receiver::sender`]. A value sent but
/// not yet received is discarded at that point.
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    let mut receiver = Receiver::new();
    // A freshly created receiver has no live sender, so this cannot fail.
    let sender = receiver
        .sender()
        .expect("a new receiver always has a sender available");
    (sender, receiver)
}
/// Extracts the waker slot index (0 or 1) encoded by the `INDEX` bit of
/// `state`.
fn state_to_index(state: usize) -> usize {
    usize::from(state & INDEX != 0)
}
/// Encodes a waker slot index (0 or 1) at the position of the `INDEX` bit.
fn index_to_state(index: usize) -> usize {
    const INDEX_SHIFT: u32 = 2;
    index << INDEX_SHIFT
}
// Tests using the standard library primitives. They are compiled only when
// the `multishot_loom` cfg flag is not set.
#[cfg(all(test, not(multishot_loom)))]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::task::Wake;
    use std::thread;
    // Waker that counts how many times it has been woken.
    struct TestWaker {
        count: AtomicUsize,
    }
    impl TestWaker {
        fn new() -> Self {
            Self {
                count: AtomicUsize::new(0),
            }
        }
        // Returns the number of wake-ups since the last call and resets the
        // counter.
        fn take_count(&self) -> usize {
            self.count.swap(0, Ordering::Acquire)
        }
    }
    impl Wake for TestWaker {
        fn wake(self: Arc<Self>) {
            self.count.fetch_add(1, Ordering::Release);
        }
    }
    // Checks the outcome of `f(sender)` twice on the same receiver. In the
    // first round `f` runs before the first poll, so no wake-up is expected.
    // In the second round (a recycled sender) `f` runs after a pending poll,
    // so exactly one wake-up is expected.
    fn multishot_notify_single_threaded<F>(f: F, expect: Result<i32, RecvError>)
    where
        F: FnOnce(Sender<Box<i32>>) + Send + Copy + 'static,
    {
        let test_waker = Arc::new(TestWaker::new());
        let waker = test_waker.clone().into();
        let mut cx = Context::from_waker(&waker);
        let mut receiver: Receiver<Box<i32>> = Receiver::new();
        {
            let sender = receiver.sender().expect("could not create sender");
            let mut fut = receiver.recv();
            let mut fut = Pin::new(&mut fut);
            f(sender);
            let res = fut.as_mut().poll(&mut cx);
            assert_eq!(test_waker.take_count(), 0);
            assert_eq!(res.map_ok(|v| *v), Poll::Ready(expect));
        }
        {
            let sender = receiver.sender().expect("could not create sender");
            let mut fut = receiver.recv();
            let mut fut = Pin::new(&mut fut);
            let res = fut.as_mut().poll(&mut cx);
            assert_eq!(res, Poll::Pending);
            f(sender);
            assert_eq!(test_waker.take_count(), 1);
            let res = fut.as_mut().poll(&mut cx);
            assert_eq!(res.map_ok(|v| *v), Poll::Ready(expect));
        }
    }
    #[test]
    fn multishot_send_notify_single_threaded() {
        multishot_notify_single_threaded(|sender| sender.send(Box::new(42)), Ok(42));
    }
    #[test]
    fn multishot_drop_notify_single_threaded() {
        multishot_notify_single_threaded(|sender| drop(sender), Err(RecvError {}));
    }
    // Polls the future with two, then three, different wakers before running
    // `f(sender)`. Checks that only the most recently registered waker is
    // notified.
    fn multishot_change_waker_single_threaded<F>(f: F, expect: Result<i32, RecvError>)
    where
        F: FnOnce(Sender<Box<i32>>) + Send + Copy + 'static,
    {
        let test_waker1 = Arc::new(TestWaker::new());
        let waker1 = test_waker1.clone().into();
        let mut cx1 = Context::from_waker(&waker1);
        let test_waker2 = Arc::new(TestWaker::new());
        let waker2 = test_waker2.clone().into();
        let mut cx2 = Context::from_waker(&waker2);
        let test_waker3 = Arc::new(TestWaker::new());
        let waker3 = test_waker3.clone().into();
        let mut cx3 = Context::from_waker(&waker3);
        // Two successive wakers: only waker 2 must be woken.
        {
            let (sender, mut receiver) = channel::<Box<i32>>();
            let mut fut = receiver.recv();
            let mut fut = Pin::new(&mut fut);
            let res = fut.as_mut().poll(&mut cx1);
            assert_eq!(res, Poll::Pending);
            let res = fut.as_mut().poll(&mut cx2);
            assert_eq!(res, Poll::Pending);
            f(sender);
            assert_eq!(test_waker2.take_count(), 1);
            let res = fut.as_mut().poll(&mut cx1);
            assert_eq!(test_waker1.take_count(), 0);
            assert_eq!(test_waker2.take_count(), 0);
            assert_eq!(res.map_ok(|v| *v), Poll::Ready(expect));
        }
        // Three successive wakers: only waker 3 must be woken.
        {
            let (sender, mut receiver) = channel::<Box<i32>>();
            let mut fut = receiver.recv();
            let mut fut = Pin::new(&mut fut);
            let res = fut.as_mut().poll(&mut cx1);
            assert_eq!(res, Poll::Pending);
            let res = fut.as_mut().poll(&mut cx2);
            assert_eq!(res, Poll::Pending);
            let res = fut.as_mut().poll(&mut cx3);
            assert_eq!(res, Poll::Pending);
            f(sender);
            assert_eq!(test_waker3.take_count(), 1);
            let res = fut.as_mut().poll(&mut cx2);
            assert_eq!(test_waker1.take_count(), 0);
            assert_eq!(test_waker2.take_count(), 0);
            assert_eq!(res.map_ok(|v| *v), Poll::Ready(expect));
        }
    }
    #[test]
    fn multishot_send_change_waker_single_threaded() {
        multishot_change_waker_single_threaded(|sender| sender.send(Box::new(42)), Ok(42));
    }
    #[test]
    fn multishot_drop_change_waker_single_threaded() {
        multishot_change_waker_single_threaded(|sender| drop(sender), Err(RecvError {}));
    }
    // Races `f(sender)` on another thread against a single poll. If the poll
    // was pending, exactly one wake-up must have happened once the thread is
    // joined, and a second poll must then complete.
    fn multishot_notify_multi_threaded<F>(f: F, expect: Result<i32, RecvError>)
    where
        F: FnOnce(Sender<Box<i32>>) + Send + Copy + 'static,
    {
        let test_waker = Arc::new(TestWaker::new());
        let waker = test_waker.clone().into();
        let mut cx = Context::from_waker(&waker);
        let mut receiver: Receiver<Box<i32>> = Receiver::new();
        let sender = receiver.sender().expect("could not create sender");
        let mut fut = receiver.recv();
        let mut fut = Pin::new(&mut fut);
        let th = thread::spawn(move || f(sender));
        let res = fut.as_mut().poll(&mut cx);
        th.join().unwrap();
        match res {
            Poll::Pending => {
                assert_eq!(test_waker.take_count(), 1);
                assert_eq!(
                    fut.as_mut().poll(&mut cx).map_ok(|v| *v),
                    Poll::Ready(expect)
                );
            }
            Poll::Ready(res) => assert_eq!(res.map(|v| *v), expect),
        }
    }
    #[test]
    fn multishot_send_notify_multi_threaded() {
        multishot_notify_multi_threaded(|sender| sender.send(Box::new(42)), Ok(42));
    }
    #[test]
    fn multishot_drop_notify_multi_threaded() {
        multishot_notify_multi_threaded(|sender| drop(sender), Err(RecvError {}));
    }
    // Drops both halves concurrently; exercises the shared deallocation path.
    #[test]
    fn multishot_drop_both_multi_threaded() {
        let mut receiver: Receiver<Box<i32>> = Receiver::new();
        let sender = receiver.sender().expect("could not create sender");
        let th = thread::spawn(move || drop(sender));
        drop(receiver);
        th.join().unwrap();
    }
    // Sends while the receiver is being dropped; exercises the path where the
    // sender must drop the value and free the channel.
    #[test]
    fn multishot_send_and_drop_multi_threaded() {
        let mut receiver: Receiver<Box<i32>> = Receiver::new();
        let sender = receiver.sender().expect("could not create sender");
        let th = thread::spawn(move || sender.send(Box::new(123)));
        drop(receiver);
        th.join().unwrap();
    }
}
// Model-checked tests run with loom when the `multishot_loom` cfg flag is set.
#[cfg(all(test, multishot_loom))]
mod tests {
    use super::*;
    use std::future::Future;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake};
    use loom::sync::atomic::{AtomicBool, AtomicUsize};
    use loom::thread;
    // Waker that counts how many times it has been woken.
    struct TestWaker {
        count: AtomicUsize,
    }
    impl TestWaker {
        fn new() -> Self {
            Self {
                count: AtomicUsize::new(0),
            }
        }
        // Returns the number of wake-ups since the last call and resets the
        // counter.
        fn take_count(&self) -> usize {
            self.count.swap(0, Ordering::Acquire)
        }
    }
    impl Wake for TestWaker {
        fn wake(self: Arc<Self>) {
            self.count.fetch_add(1, Ordering::Release);
        }
    }
    // Races `f(sender)` on another thread against a single poll. If the poll
    // is pending and no wake-up has been observed, the sender thread must not
    // have finished yet. If a wake-up was observed, it must be the only one,
    // and a second poll must complete with `expect`.
    fn multishot_loom_notify<F>(f: F, expect: Result<i32, RecvError>)
    where
        F: FnOnce(Sender<i32>) + Send + Sync + Copy + 'static,
    {
        loom::model(move || {
            let test_waker = Arc::new(TestWaker::new());
            let waker = test_waker.clone().into();
            let mut cx = Context::from_waker(&waker);
            let (sender, mut receiver) = channel();
            let has_message = Arc::new(AtomicBool::new(false));
            thread::spawn({
                let has_message = has_message.clone();
                move || {
                    f(sender);
                    has_message.store(true, Ordering::Release);
                }
            });
            let mut fut = receiver.recv();
            let mut fut = Pin::new(&mut fut);
            let res = fut.as_mut().poll(&mut cx);
            match res {
                Poll::Pending => {
                    let msg = has_message.load(Ordering::Acquire);
                    let event_count = test_waker.take_count();
                    if event_count == 0 {
                        assert_eq!(msg, false);
                    } else {
                        assert_eq!(event_count, 1);
                        let res = fut.as_mut().poll(&mut cx);
                        assert_eq!(test_waker.take_count(), 0);
                        assert_eq!(res, Poll::Ready(expect));
                    }
                }
                Poll::Ready(val) => {
                    assert_eq!(val, expect);
                }
            }
        });
    }
    // Polls up to three times, alternating two wakers, while `f(sender)` runs
    // concurrently. Checks that only the waker from the latest poll is woken.
    fn multishot_loom_change_waker<F>(f: F, expect: Result<i32, RecvError>)
    where
        F: FnOnce(Sender<i32>) + Send + Sync + Copy + 'static,
    {
        loom::model(move || {
            let test_waker1 = Arc::new(TestWaker::new());
            let waker1 = test_waker1.clone().into();
            let mut cx1 = Context::from_waker(&waker1);
            let test_waker2 = Arc::new(TestWaker::new());
            let waker2 = test_waker2.clone().into();
            let mut cx2 = Context::from_waker(&waker2);
            let (sender, mut receiver) = channel();
            thread::spawn({
                move || {
                    f(sender);
                }
            });
            let mut fut = receiver.recv();
            let mut fut = Pin::new(&mut fut);
            // Polls with `cx`. Returns `true` if the receive completed, either
            // immediately or after `test_waker` was woken once; in the latter
            // case it checks that a follow-up poll with `other_cx` is ready
            // and that `other_test_waker` was never woken.
            fn try_complete(
                fut: &mut Pin<&mut Recv<i32>>,
                cx: &mut Context,
                other_cx: &mut Context,
                test_waker: &TestWaker,
                other_test_waker: &TestWaker,
                expect: Result<i32, RecvError>,
            ) -> bool {
                let res = fut.as_mut().poll(cx);
                if let Poll::Ready(val) = res {
                    assert_eq!(val, expect);
                    return true;
                }
                assert_eq!(other_test_waker.take_count(), 0);
                let event_count = test_waker.take_count();
                if event_count != 0 {
                    assert_eq!(event_count, 1);
                    let res = fut.as_mut().poll(other_cx);
                    assert_eq!(test_waker.take_count(), 0);
                    assert_eq!(other_test_waker.take_count(), 0);
                    assert_eq!(res, Poll::Ready(expect));
                    return true;
                }
                false
            }
            if try_complete(
                &mut fut,
                &mut cx1,
                &mut cx2,
                &test_waker1,
                &test_waker2,
                expect,
            ) {
                return;
            }
            if try_complete(
                &mut fut,
                &mut cx2,
                &mut cx1,
                &test_waker2,
                &test_waker1,
                expect,
            ) {
                return;
            }
            if try_complete(
                &mut fut,
                &mut cx1,
                &mut cx2,
                &test_waker1,
                &test_waker2,
                expect,
            ) {
                return;
            }
        });
    }
    // Checks that, once a first receive has completed, a new sender can be
    // obtained from the same receiver and that the recycled channel delivers
    // the new value.
    fn multishot_loom_recycle<F>(f: F)
    where
        F: FnOnce(Sender<i32>) + Send + Sync + Copy + 'static,
    {
        loom::model(move || {
            let test_waker = Arc::new(TestWaker::new());
            let waker = test_waker.clone().into();
            let mut cx = Context::from_waker(&waker);
            let (sender, mut receiver) = channel();
            // First round: poll up to twice; give up on this interleaving if
            // the sender thread has not completed by then.
            {
                thread::spawn({
                    move || {
                        f(sender);
                    }
                });
                let mut fut = receiver.recv();
                let mut fut = Pin::new(&mut fut);
                let res = fut.as_mut().poll(&mut cx);
                if res == Poll::Pending {
                    let res = fut.as_mut().poll(&mut cx);
                    if res == Poll::Pending {
                        return;
                    }
                }
            }
            let sender = receiver
                .sender()
                .expect("Could not recycle the sender after it was consumed");
            // Second round on the recycled channel: the value is only checked
            // if the single poll completes.
            {
                thread::spawn({
                    move || {
                        sender.send(13);
                    }
                });
                let mut fut = receiver.recv();
                let mut fut = Pin::new(&mut fut);
                let res = fut.as_mut().poll(&mut cx);
                if let Poll::Ready(val) = res {
                    assert_eq!(val, Ok(13));
                }
            }
        });
    }
    #[test]
    fn multishot_loom_send_notify() {
        multishot_loom_notify(|sender| sender.send(42), Ok(42));
    }
    #[test]
    fn multishot_loom_drop_notify() {
        multishot_loom_notify(|sender| drop(sender), Err(RecvError {}));
    }
    #[test]
    fn multishot_loom_send_change_waker() {
        multishot_loom_change_waker(|sender| sender.send(42), Ok(42));
    }
    #[test]
    fn multishot_loom_drop_change_waker() {
        multishot_loom_change_waker(|sender| drop(sender), Err(RecvError {}));
    }
    #[test]
    fn multishot_loom_send_recycle() {
        multishot_loom_recycle(|sender| sender.send(42));
    }
    #[test]
    fn multishot_loom_drop_recycle() {
        multishot_loom_recycle(|sender| drop(sender));
    }
}