pub use super::{
NoRecv,
RecvErr::{self, *},
};
use incin::Pause;
use owned_alloc::OwnedAlloc;
use ptr::{bypass_null, check_null_align};
use removable::Removable;
use std::{
fmt,
ptr::{null_mut, NonNull},
sync::{
atomic::{AtomicPtr, Ordering::*},
Arc,
},
};
/// Creates a new channel backed by a freshly created incinerator and
/// returns its sender and receiver halves.
pub fn create<T>() -> (Sender<T>, Receiver<T>) {
    let incin = SharedIncin::new();
    with_incin(incin)
}
/// Creates a new channel whose receivers retire nodes into the given shared
/// incinerator, and returns its sender and receiver halves.
pub fn with_incin<T>(incin: SharedIncin<T>) -> (Sender<T>, Receiver<T>) {
    // Alignment sanity check from `ptr::check_null_align`; presumably a
    // debug-time assertion — confirm against its definition.
    check_null_align::<Node<T>>();

    // The queue starts with a single empty sentinel node that is both the
    // front and the back.
    let sentinel = OwnedAlloc::new(Node {
        message: Removable::empty(),
        next: AtomicPtr::new(null_mut()),
    })
    .into_raw();
    let sentinel_ptr = sentinel.as_ptr();

    // The back pointer lives in its own allocation, shared by both halves.
    let back = OwnedAlloc::new(SharedBack { ptr: AtomicPtr::new(sentinel_ptr) })
        .into_raw();

    let sender_inner = SenderInner { back };
    let receiver_inner = ReceiverInner {
        front: AtomicPtr::new(sentinel_ptr),
        back,
        incin,
    };

    let sender = Sender { inner: Arc::new(sender_inner) };
    let receiver = Receiver { inner: Arc::new(receiver_inner) };
    (sender, receiver)
}
/// The sending half of an MPMC channel. Cloning it yields another sender
/// for the same channel.
pub struct Sender<T> {
    // Shared by all clones; dropping the last clone runs `SenderInner::drop`.
    inner: Arc<SenderInner<T>>,
}
impl<T> Sender<T> {
    /// Sends `message` by appending a new node at the back of the queue.
    ///
    /// # Errors
    ///
    /// Returns `Err(NoRecv { message })`, handing the message back, if the
    /// back pointer already carries the disconnection mark (its lowest bit),
    /// which `ReceiverInner::drop` sets once every receiver is gone.
    ///
    /// NOTE(review): if the receivers disconnect after our node became the
    /// new back but before it was linked to its predecessor, the message is
    /// never delivered (it is dropped with its node), yet `Ok(())` is still
    /// returned.
    pub fn send(&self, message: T) -> Result<(), NoRecv<T>> {
        // Allocate the node up front; nobody else can see it until it is
        // published through the back pointer.
        let alloc = OwnedAlloc::new(Node {
            message: Removable::new(message),
            next: AtomicPtr::new(null_mut()),
        });
        let node = alloc.into_raw();
        let mut loaded = unsafe { self.inner.back.as_ref().ptr.load(Relaxed) };
        loop {
            // Lowest bit set: the receivers disconnected. Our node was never
            // published, so reclaim it and give the message back.
            if loaded as usize & 1 == 1 {
                let mut alloc = unsafe { OwnedAlloc::from_raw(node) };
                let message = alloc.message.replace(None).unwrap();
                break Err(NoRecv { message });
            }
            // Make our node the new back. This fails if another sender
            // appended first or the receivers set the mark meanwhile.
            let res = unsafe {
                self.inner.back.as_ref().ptr.compare_exchange(
                    loaded,
                    node.as_ptr(),
                    AcqRel,
                    Relaxed,
                )
            };
            match res {
                Ok(_) => {
                    // We now own the link from the previous back to our
                    // node. Release publishes our node's contents to the
                    // receiver that loads `next` with Acquire.
                    let prev = unsafe { bypass_null(loaded) };
                    let res = unsafe {
                        prev.as_ref().next.compare_exchange(
                            null_mut(),
                            node.as_ptr(),
                            Release,
                            Relaxed,
                        )
                    };
                    // Failure means `delete_before_last` (run from
                    // `ReceiverInner::drop`) already swapped the mark into
                    // `prev.next` and stopped there without freeing `prev`.
                    // So we free `prev` and let `delete_before_last` reclaim
                    // what it can of the chain starting at our node.
                    if res.is_err() {
                        unsafe {
                            OwnedAlloc::from_raw(prev);
                            delete_before_last(node, None);
                        }
                    }
                    break Ok(());
                },
                // The back pointer changed under us; retry with the new value.
                Err(new) => loaded = new,
            }
        }
    }

    /// Returns whether the receivers are still connected, i.e. the back
    /// pointer does not carry the disconnection mark. The answer may be
    /// stale by the time the caller acts on it.
    pub fn is_connected(&self) -> bool {
        let back = unsafe { self.inner.back.as_ref() };
        back.ptr.load(Relaxed) as usize & 1 == 0
    }
}
// SAFETY: a `Sender` only moves `T` values into the queue (`send` takes the
// message by value) and never hands out `&T`, so `T: Send` suffices both for
// moving the handle to another thread and for sharing it between threads.
// These impls are needed because the `NonNull` inside `SenderInner` keeps
// the auto traits from being derived.
unsafe impl<T> Send for Sender<T> where T: Send {}
unsafe impl<T> Sync for Sender<T> where T: Send {}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
Self { inner: self.inner.clone() }
}
}
impl<T> fmt::Debug for Sender<T> {
    /// Formats as `mpmc::Sender { ptr: 0x… }`, showing the address of the
    /// state shared by all clones of this sender.
    ///
    /// The label previously read `spmc::Sender` (a copy-paste slip from the
    /// SPMC channel); this type belongs to the MPMC channel.
    fn fmt(&self, fmtr: &mut fmt::Formatter) -> fmt::Result {
        write!(fmtr, "mpmc::Sender {{ ptr: {:p} }}", self.inner)
    }
}
/// The receiving half of an MPMC channel. Cloning it yields another
/// receiver for the same channel; each message is taken by at most one
/// receiver.
pub struct Receiver<T> {
    // Shared by all clones; dropping the last clone runs `ReceiverInner::drop`.
    inner: Arc<ReceiverInner<T>>,
}
impl<T> Receiver<T> {
    /// Tries to receive a message without blocking.
    ///
    /// # Errors
    ///
    /// Returns `RecvErr::NoMessage` if the queue is currently empty while a
    /// sender is still connected. Returns `RecvErr::NoSender` if every
    /// sender has disconnected and every message they sent has already been
    /// received.
    pub fn recv(&self) -> Result<T, RecvErr> {
        // While paused, nodes handed to the incinerator by `try_clear_first`
        // are not freed, so the nodes we read here stay valid.
        let pause = self.inner.incin.inner.pause();
        // Acquire pairs with the Release half of the `front` CAS in
        // `try_clear_first`. The receiver that advanced `front` had itself
        // acquired the node from its sender, so this makes the node's
        // contents visible to us as well.
        let mut front_nnptr = unsafe {
            bypass_null(self.inner.front.load(Acquire))
        };

        loop {
            match unsafe { front_nnptr.as_ref().message.take(AcqRel) } {
                Some(val) => {
                    // Best effort: advance `front` past the node we just
                    // emptied. The message is already ours, so any error
                    // here is irrelevant.
                    let _ = unsafe { self.try_clear_first(front_nnptr, &pause) };
                    break Ok(val);
                },
                // The front node holds no message (the sentinel, or already
                // taken by another receiver). Move to its successor, or
                // report why there is none.
                None => unsafe {
                    front_nnptr = self.try_clear_first(front_nnptr, &pause)?;
                },
            }
        }
    }

    /// Returns whether this receiver may still get messages. That is the
    /// case when the senders are still connected, when the front node still
    /// holds a message, or when the front node has a successor.
    ///
    /// The answer may be stale by the time the caller acts on it.
    pub fn is_connected(&self) -> bool {
        let _pause = self.inner.incin.inner.pause();
        // Acquire for the same reason as in `recv`: we dereference the node.
        let front = unsafe { &*self.inner.front.load(Acquire) };
        let back = unsafe { self.inner.back.as_ref() };
        // Acquire on `back` pairs with the Release swap in
        // `SenderInner::drop`. If we see the mark, we also see every link
        // made by completed sends, so `front.next` cannot be stale-null then.
        back.ptr.load(Acquire) as usize & 1 == 0
            || front.message.is_present(Relaxed)
            || !front.next.load(Relaxed).is_null()
    }

    /// Returns a handle to the incinerator used by this channel's receivers.
    pub fn incin(&self) -> SharedIncin<T> {
        self.inner.incin.clone()
    }

    /// Tries to advance `front` from `expected` to its successor.
    ///
    /// On success the old front node is handed to the incinerator through
    /// `pause`, and the successor is returned. If another receiver advanced
    /// `front` first, the node it installed is returned instead.
    ///
    /// # Errors
    ///
    /// If `expected` has no successor, returns `RecvErr::NoSender` when the
    /// back pointer carries the disconnection mark, else
    /// `RecvErr::NoMessage`.
    ///
    /// # Safety
    ///
    /// `expected` must be a node of this queue obtained while `pause` was
    /// active, so that it has not been freed.
    unsafe fn try_clear_first(
        &self,
        expected: NonNull<Node<T>>,
        pause: &Pause<OwnedAlloc<Node<T>>>,
    ) -> Result<NonNull<Node<T>>, RecvErr> {
        // Acquire pairs with the sender's Release when linking `next`.
        let mut next = expected.as_ref().next.load(Acquire);

        if next.is_null() {
            // Acquire pairs with the Release swap in `SenderInner::drop`.
            if self.inner.back.as_ref().ptr.load(Acquire) as usize & 1 == 0 {
                return Err(RecvErr::NoMessage);
            }
            // The senders are gone, but a last message may have been linked
            // after our first load of `next`. Having acquired the mark, we
            // now see every completed link, so re-read before reporting
            // `NoSender`; otherwise that message would be silently lost.
            next = expected.as_ref().next.load(Acquire);
        }

        let next_nnptr = match NonNull::new(next) {
            Some(nnptr) => nnptr,
            None => return Err(RecvErr::NoSender),
        };

        // Success (AcqRel): the Release half publishes the successor, whose
        // contents we acquired above, to receivers that load `front`.
        // Failure (Acquire): we are about to dereference the node another
        // receiver installed, so synchronize with that receiver.
        let res = self.inner.front.compare_exchange(
            expected.as_ptr(),
            next,
            AcqRel,
            Acquire,
        );

        match res {
            Ok(_) => {
                pause.add_to_incin(OwnedAlloc::from_raw(expected));
                Ok(next_nnptr)
            },
            Err(found) => Ok(bypass_null(found)),
        }
    }
}
// SAFETY: a `Receiver` only hands out `T` by value (`recv` returns the
// owned message) and never exposes `&T`, so `T: Send` suffices both for
// moving the handle to another thread and for sharing it between threads.
// These impls are needed because the raw pointers inside `ReceiverInner`
// keep the auto traits from being derived.
unsafe impl<T> Send for Receiver<T> where T: Send {}
unsafe impl<T> Sync for Receiver<T> where T: Send {}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
Self { inner: self.inner.clone() }
}
}
impl<T> fmt::Debug for Receiver<T> {
    /// Formats as `mpmc::Receiver { ptr: 0x… }`, showing the address of the
    /// state shared by all clones of this receiver.
    ///
    /// The label previously read `spmc::Receiver` (a copy-paste slip from
    /// the SPMC channel); this type belongs to the MPMC channel.
    fn fmt(&self, fmtr: &mut fmt::Formatter) -> fmt::Result {
        write!(fmtr, "mpmc::Receiver {{ ptr: {:p} }}", self.inner)
    }
}
/// State shared by all clones of a `Sender`.
struct SenderInner<T> {
    // Shared with the receivers; freed by whichever side disconnects last.
    back: NonNull<SharedBack<T>>,
}
impl<T> Drop for SenderInner<T> {
    /// Runs once the last `Sender` handle is gone. Marks the back pointer as
    /// disconnected; if the receivers disconnected first, frees what they
    /// left behind (the last node and the `SharedBack`).
    fn drop(&mut self) {
        let ptr = unsafe { self.back.as_ref().ptr.load(Relaxed) };
        // Not marked yet: the receivers may still be alive, so try to set the
        // mark ourselves. Release makes the senders' earlier writes visible
        // to whoever later acquires the marked pointer.
        if ptr as usize & 1 == 0 {
            let res = unsafe {
                self.back
                    .as_ref()
                    .ptr
                    .swap((ptr as usize | 1) as *mut _, Release)
            };
            // We set the mark first: the receivers still own the queue and
            // free the nodes and the `SharedBack` when they drop.
            if res == ptr {
                return;
            }
            // No other sender exists at this point, so a different value
            // means the receivers set the mark in the meantime.
        }
        // The receivers marked the pointer first. `ReceiverInner::drop`
        // handles the nodes before the last one and leaves that last node
        // (the one the back pointer refers to) and the `SharedBack` to us.
        let ptr = (ptr as usize & !1) as *mut Node<T>;
        unsafe {
            OwnedAlloc::from_raw(bypass_null(ptr));
            OwnedAlloc::from_raw(self.back);
        }
    }
}
/// State shared by all clones of a `Receiver`.
struct ReceiverInner<T> {
    // Current first node of the queue; never null (see `with_incin`).
    front: AtomicPtr<Node<T>>,
    // Shared with the senders; freed by whichever side disconnects last.
    back: NonNull<SharedBack<T>>,
    // Incinerator that nodes removed from the front are handed to, so that
    // freeing them is deferred while other receivers may still read them.
    incin: SharedIncin<T>,
}
impl<T> ReceiverInner<T> {
unsafe fn delete_all(&mut self) {
let mut node_ptr = NonNull::new(*self.front.get_mut());
while let Some(mut node) = node_ptr {
node_ptr = NonNull::new(node.as_mut().next.load(Acquire));
OwnedAlloc::from_raw(node);
}
}
}
impl<T> Drop for ReceiverInner<T> {
    /// Runs once the last `Receiver` handle is gone. Marks the back pointer
    /// as disconnected and frees queued nodes, cooperating with
    /// `SenderInner::drop` and with `Sender::send` calls still in flight.
    fn drop(&mut self) {
        let mut ptr = unsafe { self.back.as_ref().ptr.load(Relaxed) };
        loop {
            // Already marked: every sender is gone (only `SenderInner::drop`
            // sets the mark from that side), so we are the last owner and
            // free everything, including the `SharedBack`.
            if ptr as usize & 1 == 1 {
                unsafe {
                    self.delete_all();
                    OwnedAlloc::from_raw(self.back);
                }
                break;
            }
            // Senders may still be appending and moving the back pointer, so
            // the mark is set with a CAS loop rather than a plain store.
            let res = unsafe {
                self.back.as_ref().ptr.compare_exchange(
                    ptr,
                    (ptr as usize | 1) as *mut _,
                    Relaxed,
                    Relaxed,
                )
            };
            match res {
                Ok(_) => {
                    debug_assert!(!ptr.is_null());
                    // We set the mark. Free the nodes from the front up to
                    // (not including) `ptr`, the final back. `ptr` and the
                    // `SharedBack` are left to `SenderInner::drop`. A node
                    // whose `next` is still null stops the walk; the sender
                    // whose link into it then fails frees it.
                    unsafe {
                        delete_before_last(
                            NonNull::new_unchecked(self.front.load(Relaxed)),
                            NonNull::new(ptr),
                        )
                    }
                    break;
                },
                Err(new) => ptr = new,
            }
        }
    }
}
/// Heap-allocated back pointer shared by `SenderInner` and `ReceiverInner`;
/// freed by whichever side disconnects last.
struct SharedBack<T> {
    // Points to the last node. Its lowest bit is set once either side has
    // disconnected.
    ptr: AtomicPtr<Node<T>>,
}
/// A queue node. The queue always holds at least one node; the front node
/// acts as a sentinel whose message may already have been taken.
///
/// An alignment of at least 2 keeps the lowest bit of every `*mut Node<T>`
/// zero. This module uses that bit as the "disconnected" tag on the back
/// pointer and on `next`.
#[repr(align(/* at least */ 2))]
struct Node<T> {
    // The payload: `Removable::empty()` for the initial sentinel, otherwise
    // set by `send` and taken by a receiver in `recv`.
    message: Removable<T>,
    // Successor node: null while this node is the back, or the tagged value
    // `1` once `delete_before_last` has visited it during teardown.
    next: AtomicPtr<Node<T>>,
}
// Declares the public `SharedIncin<T>` incinerator type, which holds retired
// `OwnedAlloc<Node<T>>` values. `Receiver::try_clear_first` puts nodes into
// it via `Pause::add_to_incin`. The string argument presumably names the
// owning type in generated docs — confirm against the macro definition.
make_shared_incin! {
    { "`mpmc::Receiver`" }
    pub SharedIncin<T> of OwnedAlloc<Node<T>>
}
/// Frees the chain of nodes starting at `curr` during teardown, stopping at
/// `last`.
///
/// Each visited node's `next` is swapped with the tagged null pointer `1`,
/// so a sender that later tries to link into that node fails its CAS. A node
/// is freed only if its `next` had already been linked. The node equal to
/// `last` is never freed. A node whose `next` was still null is not freed
/// either: the walk stops there, and the sender whose link then fails frees
/// it (see `Sender::send`).
///
/// # Safety
///
/// `curr` must point to a live node of this queue. Callers in this file
/// invoke it only after the receivers' drop has marked the back pointer,
/// when no receiver can still be reading the chain.
unsafe fn delete_before_last<T>(
    mut curr: NonNull<Node<T>>,
    last: Option<NonNull<Node<T>>>,
) {
    while last != Some(curr) {
        // Acquire so that following `next` into the successor sees the
        // contents its sender wrote.
        let next = curr
            .as_ref()
            .next
            .swap((null_mut::<Node<T>>() as usize | 1) as *mut _, Acquire);
        match NonNull::new(next) {
            Some(next) => {
                OwnedAlloc::from_raw(curr);
                curr = next;
            },
            // Not linked yet: a sender still owns this node's link.
            None => break,
        }
    }
}
#[cfg(test)]
mod test {
    use channel::mpmc;
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering::*},
            Arc,
        },
        thread,
    };

    /// Every message sent by the producer threads must be received exactly
    /// once by the consumer threads, and none may be lost.
    #[test]
    fn correct_numbers() {
        const THREADS: usize = 8;
        const MSGS_PER_THREAD: usize = 64;
        const MSGS: usize = THREADS * MSGS_PER_THREAD;

        // One flag per message, set by whichever receiver gets it.
        let flags: Vec<AtomicBool> =
            (0 .. MSGS).map(|_| AtomicBool::new(false)).collect();
        let done = Arc::<[AtomicBool]>::from(flags);

        let (sender, receiver) = mpmc::create::<usize>();
        let mut handles = Vec::with_capacity(THREADS);

        for id in 0 .. THREADS {
            let tx = sender.clone();
            handles.push(thread::spawn(move || {
                let first = id * MSGS_PER_THREAD;
                (first .. first + MSGS_PER_THREAD)
                    .for_each(|msg| tx.send(msg).unwrap());
            }));

            let rx = receiver.clone();
            let seen = done.clone();
            handles.push(thread::spawn(move || loop {
                match rx.recv() {
                    // A message received twice would flip an already-set flag.
                    Ok(msg) => assert!(!seen[msg].swap(true, AcqRel)),
                    Err(mpmc::NoMessage) => (),
                    Err(mpmc::NoSender) => break,
                }
            }));
        }

        // Drop the original handles so disconnection is driven by the
        // thread-owned clones only.
        drop(sender);
        drop(receiver);

        for handle in handles {
            handle.join().unwrap();
        }

        for flag in done.iter() {
            assert!(flag.load(Relaxed));
        }
    }
}