use crate::util::{Backoff, CachePadded};
use core::cell::UnsafeCell;
use core::mem::MaybeUninit;
use core::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release, SeqCst};
use core::sync::atomic::{fence, AtomicPtr, AtomicUsize};
use std::sync::Arc;
use std::{fmt, io};
/// Slot state bit set by a sender once the message has been written into the slot.
const WRITE: usize = 1;
/// Slot state bit set by a receiver once it has read the message out of the slot.
const READ: usize = 1 << 1;
/// Slot state bit set while a block is being torn down, telling the reader of
/// that slot to continue the teardown.
const DESTROY: usize = 1 << 2;
/// Number of index values consumed per block: `BLOCK_SIZE` real slots plus one
/// sentinel offset that is used while the next block is being installed.
const ROUND: usize = 64;
/// Number of message slots in a block. Because `ROUND` is a power of two this
/// value is also the bit mask that extracts the in-block offset from an index.
const BLOCK_SIZE: usize = ROUND - 1;
/// Most significant bit of the tail index; set once the channel is closed.
///
/// Derived from the word width instead of a literal `1 << 63`, which overflows
/// (and fails to compile) on targets whose `usize` is narrower than 64 bits.
/// All-ones shifted right by one clears the top bit; inverting keeps only it.
const CLOSED_FLAG: usize = !(!0usize >> 1);
/// Most significant bit of the head index; set while the head is known to lag
/// at least one block behind the tail. Shares the bit position of
/// `CLOSED_FLAG`, but lives in a different index.
const CROSSED_FLAG: usize = CLOSED_FLAG;
/// One message cell inside a `Block`.
struct Slot<T> {
/// Bit set built from `WRITE`, `READ` and `DESTROY` that tracks the cell's progress.
state: AtomicUsize,
/// Storage for the message. A sender writes it before setting `WRITE` in
/// `state`; a receiver reads it only after observing `WRITE`.
message: UnsafeCell<MaybeUninit<T>>,
}
impl<T> fmt::Debug for Slot<T> {
    /// Renders the slot as `Slot { state: .. }`; the message payload is never
    /// inspected, so no bound on `T` is needed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let state_bits = self.state.load(Acquire);
        let mut out = f.debug_struct("Slot");
        out.field("state", &state_bits);
        out.finish()
    }
}
/// Fixed-capacity segment of the queue; segments are chained through `next`.
struct Block<T> {
/// Pointer to the following block; null until `set_next` publishes one.
next: AtomicPtr<Block<T>>,
/// The `BLOCK_SIZE` message cells owned by this block.
slots: [Slot<T>; BLOCK_SIZE],
}
impl<T> Block<T> {
    /// Heap-allocates a block whose bytes are all zero and leaks it as a raw
    /// pointer; ownership is reclaimed later through `Box::from_raw`.
    fn new() -> *mut Self {
        // SAFETY: relies on the all-zero bit pattern being a valid `Block`:
        // a null `next`, zero slot states and `MaybeUninit` payloads.
        let zeroed: Self = unsafe { MaybeUninit::<Self>::zeroed().assume_init() };
        Box::into_raw(Box::new(zeroed))
    }

    /// Spins (with backoff) until a successor block has been published, then
    /// returns it.
    fn wait_next(&self) -> *mut Block<T> {
        let backoff = Backoff::new();
        let mut successor = self.next.load(Acquire);
        while successor.is_null() {
            backoff.spin();
            successor = self.next.load(Acquire);
        }
        successor
    }

    /// Publishes `next` as this block's successor. In debug builds, asserts
    /// that no successor had been set before.
    fn set_next(&self, next: *mut Block<T>) {
        let previous = self.next.swap(next, Release);
        debug_assert!(previous.is_null());
    }

    /// Returns the successor block if one has been published, without waiting.
    fn get_next(&self) -> Option<*mut Block<T>> {
        match self.next.load(Acquire) {
            ptr if ptr.is_null() => None,
            ptr => Some(ptr),
        }
    }

    /// Tries to free the block, scanning slots from `start` up to (but not
    /// including) the last slot.
    ///
    /// For each scanned slot that has not been marked `READ`, the `DESTROY`
    /// bit is set; if the slot is still unread after that, the function
    /// returns and leaves the rest of the teardown to that slot's reader.
    /// The block is freed only when every scanned slot was already read.
    fn destroy(this: *mut Block<T>, start: usize) {
        // The final slot is deliberately excluded from the scan.
        let scan_end = BLOCK_SIZE - 1;
        let mut i = start;
        while i < scan_end {
            // SAFETY: `i < BLOCK_SIZE - 1`, so the index is inside `slots`;
            // `this` is assumed to still point at a live block.
            let slot = unsafe { (*this).slots.get_unchecked(i) };
            let unread = slot.state.load(Acquire) & READ == 0;
            if unread && slot.state.fetch_or(DESTROY, AcqRel) & READ == 0 {
                return;
            }
            i += 1;
        }
        // SAFETY: `this` came from `Box::into_raw` in `Block::new`, and every
        // scanned slot has been read, so nothing else will touch the block.
        unsafe { drop(Box::from_raw(this)) };
    }
}
/// One end of the queue: a running index together with the block it currently
/// points into.
#[derive(Debug)]
struct Cursor<T> {
/// Running position. `index & BLOCK_SIZE` is the offset inside `block`; the
/// top bit carries a flag (`CLOSED_FLAG` on the tail, `CROSSED_FLAG` on the head).
index: AtomicUsize,
/// Block that `index` currently refers to.
block: AtomicPtr<Block<T>>,
}
impl<T> Cursor<T> {
    /// Builds a cursor positioned at index 0 of `block`.
    fn from(block: *mut Block<T>) -> Self {
        let index = AtomicUsize::new(0);
        let block = AtomicPtr::new(block);
        Cursor { index, block }
    }

    /// Returns a reference to slot `index` of `block` without a bounds check
    /// in release builds.
    ///
    /// The caller picks the lifetime `'a`; it must not outlive the block.
    #[inline]
    fn slot<'a>(block: *mut Block<T>, index: usize) -> &'a Slot<T> {
        debug_assert!(index < BLOCK_SIZE);
        // SAFETY: callers pass a pointer to a live block and an in-range
        // index (checked above in debug builds).
        let slots = unsafe { &(*block).slots };
        unsafe { slots.get_unchecked(index) }
    }
}
/// Queue state shared by all handles: a tail cursor advanced by `send` and a
/// head cursor advanced by `try_recv`, each wrapped in `CachePadded`.
#[derive(Debug)]
struct Channel<T> {
/// Where the next message will be written.
tail: CachePadded<Cursor<T>>,
/// Where the next message will be read from.
head: CachePadded<Cursor<T>>,
}
impl<T> Drop for Channel<T> {
    /// Drops every message still queued, then frees the block the head cursor
    /// is left pointing at.
    fn drop(&mut self) {
        // Drain until the queue reports empty (`Ok(None)`) or closed (`Err`).
        loop {
            match self.try_recv() {
                Ok(Some(pending)) => drop(pending),
                Ok(None) | Err(_) => break,
            }
        }
        let last_block = self.head.block.load(Acquire);
        // SAFETY: the pointer originates from `Box::into_raw` in `Block::new`;
        // `&mut self` means no other thread can be using the channel.
        unsafe { drop(Box::from_raw(last_block)) };
    }
}
impl<T> Channel<T> {
    /// Creates an empty, open channel whose head and tail cursors both start
    /// at index 0 of one freshly allocated block.
    fn new() -> Channel<T> {
        let block = Block::<T>::new();
        Channel {
            tail: CachePadded::new(Cursor::from(block)),
            head: CachePadded::new(Cursor::from(block)),
        }
    }

    /// Appends `msg` to the queue.
    ///
    /// # Errors
    ///
    /// Returns an `io::ErrorKind::BrokenPipe` error if the channel has been
    /// closed; in that case `msg` is dropped.
    #[inline]
    fn send(&self, msg: T) -> io::Result<()> {
        let backoff = Backoff::new();
        let mut tail = self.tail.index.load(Acquire);
        let mut block = self.tail.block.load(Acquire);
        loop {
            // `BLOCK_SIZE` is an all-ones mask, so this is the in-block offset.
            let index = tail & BLOCK_SIZE;
            if tail & CLOSED_FLAG == CLOSED_FLAG {
                return Err(io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "channel is closed",
                ));
            }
            // Offset `BLOCK_SIZE` is the sentinel that is visible while another
            // sender is installing the next block: wait and reload.
            if index == BLOCK_SIZE {
                backoff.snooze();
                tail = self.tail.index.load(Acquire);
                block = self.tail.block.load(Acquire);
                continue;
            }
            // The failure ordering must be `Acquire`: the observed index is
            // paired with a reload of `tail.block` below, and only an acquire
            // read of the index guarantees that the block pointer published
            // (with `Release`) before that index is visible to the reload.
            match self
                .tail
                .index
                .compare_exchange_weak(tail, tail + 1, SeqCst, Acquire)
            {
                Ok(_) => {
                    let slot = Cursor::slot(block, index);
                    // Claiming the last slot makes this sender responsible for
                    // installing the next block and stepping the index over
                    // the sentinel offset.
                    if index + 1 == BLOCK_SIZE {
                        let next = Block::new();
                        self.tail.block.store(next, Release);
                        self.tail.index.fetch_add(1, Release);
                        // SAFETY: `block` is the block this sender just claimed
                        // a slot in; that slot is still unread, so the block
                        // has not been freed.
                        unsafe { (*block).set_next(next) };
                    }
                    // SAFETY: the successful CAS gave this thread exclusive
                    // ownership of the slot until `WRITE` is published.
                    unsafe { slot.message.get().write(MaybeUninit::new(msg)) };
                    slot.state.fetch_or(WRITE, Release);
                    return Ok(());
                }
                Err(t) => {
                    tail = t;
                    block = self.tail.block.load(Acquire);
                    backoff.spin();
                }
            }
        }
    }

    /// Attempts to take the oldest message without blocking on an empty queue.
    ///
    /// Returns `Ok(Some(msg))` when a message was taken and `Ok(None)` when
    /// the queue is empty but still open.
    ///
    /// # Errors
    ///
    /// Returns an `io::ErrorKind::BrokenPipe` error when the queue is empty
    /// and the channel has been closed.
    #[inline]
    fn try_recv(&self) -> io::Result<Option<T>> {
        let backoff = Backoff::new();
        let mut head = self.head.index.load(Acquire);
        let mut block = self.head.block.load(Acquire);
        loop {
            let index = head & BLOCK_SIZE;
            // Sentinel offset: another receiver is moving the head to the next
            // block. Wait and reload.
            if index == BLOCK_SIZE {
                backoff.snooze();
                head = self.head.index.load(Acquire);
                block = self.head.block.load(Acquire);
                continue;
            }
            let mut new_head = head + 1;
            // Without `CROSSED_FLAG` the head may have caught up with the
            // tail, so the tail has to be consulted for emptiness.
            if head & CROSSED_FLAG == 0 {
                fence(SeqCst);
                let tail = self.tail.index.load(Relaxed);
                if head == tail & !CLOSED_FLAG {
                    if tail & CLOSED_FLAG != 0 {
                        return Err(io::Error::new(
                            io::ErrorKind::BrokenPipe,
                            "channel is closed",
                        ));
                    }
                    return Ok(None);
                }
                // Head and tail are in different blocks: record that so later
                // receives can skip the emptiness check.
                if head / ROUND != (tail & !CLOSED_FLAG) / ROUND {
                    new_head |= CROSSED_FLAG;
                }
            }
            // `Acquire` on failure for the same reason as in `send`: the
            // observed index is paired with a reload of `head.block`, which
            // must not be older than the block published before that index.
            match self
                .head
                .index
                .compare_exchange_weak(head, new_head, SeqCst, Acquire)
            {
                // SAFETY (whole arm): the successful CAS gives this thread the
                // exclusive right to read slot `index` of `block`. The block
                // is freed only through `Block::destroy`, which waits for this
                // slot's `READ` bit or is run by this thread itself.
                Ok(_) => unsafe {
                    // Taking the last slot makes this receiver responsible for
                    // moving the head cursor to the next block.
                    if index + 1 == BLOCK_SIZE {
                        let next_block = (*block).wait_next();
                        if (*next_block).get_next().is_some() {
                            new_head |= CROSSED_FLAG;
                        } else {
                            new_head &= !CROSSED_FLAG;
                        }
                        self.head.block.store(next_block, Release);
                        // `+ 1` steps over the sentinel offset.
                        self.head.index.store(new_head + 1, Release);
                    }
                    let slot = Cursor::slot(block, index);
                    // The sender may have claimed the slot but not finished
                    // writing yet.
                    while slot.state.load(Acquire) & WRITE == 0 {
                        backoff.spin();
                    }
                    let msg = slot.message.get().read().assume_init();
                    if index + 1 == BLOCK_SIZE {
                        // The reader of the last slot starts block teardown.
                        Block::destroy(block, 0);
                    }
                    // Otherwise mark the slot read; if teardown was waiting on
                    // this slot, continue it from the following slot.
                    else if slot.state.fetch_or(READ, AcqRel) & DESTROY != 0 {
                        Block::destroy(block, index + 1);
                    }
                    return Ok(Some(msg));
                },
                Err(h) => {
                    head = h;
                    block = self.head.block.load(Acquire);
                    backoff.spin();
                }
            }
        }
    }

    /// Returns `true` when head and tail indices are equal once their flag
    /// bits are masked off. The result is a snapshot and may be stale
    /// immediately under concurrent use.
    #[inline]
    fn is_empty(&self) -> bool {
        let head = self.head.index.load(SeqCst);
        let tail = self.tail.index.load(SeqCst);
        head & !CROSSED_FLAG == tail & !CLOSED_FLAG
    }

    /// Marks the channel closed by setting `CLOSED_FLAG` on the tail index.
    /// Subsequent `send` calls fail; already queued messages can still be
    /// received.
    #[inline]
    fn close(&self) {
        self.tail.index.fetch_or(CLOSED_FLAG, AcqRel);
    }
}
/// Counts of live handles, shared by every `Sender` and `Receiver` of one channel.
pub(crate) struct Counters {
/// Number of live `Receiver` handles; when the last one is dropped the channel is closed.
pub(crate) receivers: AtomicUsize,
/// Number of live `Sender` handles; when the last one is dropped the channel is closed.
pub(crate) senders: AtomicUsize,
}
/// Sending handle of the channel. Cloning it registers another sender in `cnts`.
pub struct Sender<T> {
/// The shared queue.
chan: Arc<Channel<T>>,
/// Handle counters shared with every other `Sender` and `Receiver`.
pub(crate) cnts: Arc<Counters>,
}
impl<T: fmt::Debug> fmt::Debug for Sender<T> {
    /// Prints the `Debug` representation of the underlying channel using
    /// default formatting options.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let channel = &self.chan;
        f.write_fmt(format_args!("{:?}", channel))
    }
}
impl<T> Sender<T> {
#[inline]
pub fn send(&self, item: T) -> io::Result<()> {
self.chan.send(item)
}
#[inline]
pub fn is_empty(&self) -> bool {
self.chan.is_empty()
}
pub fn close(&self) {
self.chan.close()
}
pub fn receiver(&self) -> Receiver<T> {
self.cnts.receivers.fetch_add(1, SeqCst);
Receiver {
chan: Arc::clone(&self.chan),
cnts: self.cnts.clone(),
}
}
}
impl<T> Clone for Sender<T> {
    /// Returns another sender for the same channel, bumping the shared sender
    /// count first.
    fn clone(&self) -> Self {
        self.cnts.senders.fetch_add(1, SeqCst);
        let chan = Arc::clone(&self.chan);
        let cnts = Arc::clone(&self.cnts);
        Sender { chan, cnts }
    }
}
impl<T> Drop for Sender<T> {
    /// Unregisters this sender; the last sender to go away closes the channel.
    fn drop(&mut self) {
        // `fetch_sub` returns the count *before* the decrement.
        let was_last = self.cnts.senders.fetch_sub(1, SeqCst) == 1;
        if was_last {
            self.chan.close();
        }
    }
}
// SAFETY: a `Sender` moved to another thread can push `T` values that a
// receiver on yet another thread takes out, so the handle may only cross
// threads when the payload itself is `Send`. Without the `T: Send` bound, safe
// code could smuggle non-`Send` values (e.g. `Rc`) between threads.
unsafe impl<T: Send> Send for Sender<T> {}
/// Receiving handle of the channel. Cloning it registers another receiver in `cnts`.
pub struct Receiver<T> {
/// The shared queue.
chan: Arc<Channel<T>>,
/// Handle counters shared with every other `Sender` and `Receiver`.
pub(crate) cnts: Arc<Counters>,
}
impl<T> Receiver<T> {
#[inline]
pub fn try_recv(&self) -> io::Result<Option<T>> {
self.chan.try_recv()
}
pub fn sender(&self) -> Sender<T> {
self.cnts.senders.fetch_add(1, SeqCst);
Sender {
chan: Arc::clone(&self.chan),
cnts: self.cnts.clone(),
}
}
}
impl<T> Clone for Receiver<T> {
    /// Returns another receiver for the same channel, bumping the shared
    /// receiver count first.
    fn clone(&self) -> Self {
        self.cnts.receivers.fetch_add(1, SeqCst);
        let chan = Arc::clone(&self.chan);
        let cnts = Arc::clone(&self.cnts);
        Receiver { chan, cnts }
    }
}
impl<T> Drop for Receiver<T> {
    /// Unregisters this receiver; the last receiver to go away closes the
    /// channel.
    fn drop(&mut self) {
        // `fetch_sub` returns the count *before* the decrement.
        let was_last = self.cnts.receivers.fetch_sub(1, SeqCst) == 1;
        if was_last {
            self.chan.close();
        }
    }
}
// SAFETY: a `Receiver` moved to another thread takes out `T` values that were
// queued by senders on other threads, so the handle may only cross threads
// when the payload itself is `Send`. Without the `T: Send` bound, safe code
// could smuggle non-`Send` values (e.g. `Rc`) between threads.
unsafe impl<T: Send> Send for Receiver<T> {}
/// Creates a new channel and returns its first sender/receiver pair.
///
/// Both handles share one queue and one set of counters, which start at one
/// sender and one receiver.
pub fn new<T>() -> (Sender<T>, Receiver<T>) {
    let chan = Arc::new(Channel::new());
    let cnts = Arc::new(Counters {
        receivers: AtomicUsize::new(1),
        senders: AtomicUsize::new(1),
    });
    let tx = Sender {
        chan: Arc::clone(&chan),
        cnts: Arc::clone(&cnts),
    };
    let rx = Receiver {
        chan: Arc::clone(&chan),
        cnts,
    };
    (tx, rx)
}