use std::cell::UnsafeCell;
use std::fmt;
use std::future::Future;
use std::isize;
use std::marker::PhantomData;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::process;
use std::ptr;
use std::sync::atomic::{self, AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use crossbeam_utils::{Backoff, CachePadded};
use futures_core::stream::Stream;
use slab::Slab;
/// Creates a bounded multi-producer multi-consumer channel.
///
/// The channel holds at most `cap` messages at a time; `send` waits while
/// it is full and `recv` waits while it is empty.
///
/// # Panics
///
/// Panics if `cap` is zero (enforced by `Channel::with_capacity`).
#[cfg(feature = "unstable")]
#[cfg_attr(feature = "docs", doc(cfg(unstable)))]
pub fn channel<T>(cap: usize) -> (Sender<T>, Receiver<T>) {
    let channel = Arc::new(Channel::with_capacity(cap));
    let receiver = Receiver {
        channel: Arc::clone(&channel),
        opt_key: None,
    };
    let sender = Sender { channel };
    (sender, receiver)
}
#[cfg(feature = "unstable")]
#[cfg_attr(feature = "docs", doc(cfg(unstable)))]
/// The sending half of the channel. Cheap to clone (bumps a counter on the
/// shared channel); dropping the last clone disconnects the channel.
pub struct Sender<T> {
    // Shared channel state; `sender_count` inside tracks live clones.
    channel: Arc<Channel<T>>,
}
impl<T> Sender<T> {
    /// Sends a message into the channel, waiting while the channel is full.
    ///
    /// NOTE(review): if the channel is disconnected, the future below stores
    /// the message back and returns `Pending` without registering a waker,
    /// so it never completes — confirm this "wait forever after disconnect"
    /// behavior is intended.
    pub async fn send(&self, msg: T) {
        /// Future driving a single send operation.
        struct SendFuture<'a, T> {
            sender: &'a Sender<T>,
            // The message to send; taken out at the start of each poll and
            // put back on every `Pending` path.
            msg: Option<T>,
            // Key of our entry in the `sends` registry while parked.
            opt_key: Option<usize>,
        }
        impl<T> Unpin for SendFuture<'_, T> {}
        impl<T> Future for SendFuture<'_, T> {
            type Output = ();
            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                let msg = self.msg.take().unwrap();
                let poll = match self.sender.channel.push(msg) {
                    Ok(()) => Poll::Ready(()),
                    Err(PushError::Disconnected(msg)) => {
                        self.msg = Some(msg);
                        Poll::Pending
                    }
                    Err(PushError::Full(msg)) => {
                        // Register (or refresh) our waker, then try once more
                        // to close the race with a `pop` that happened after
                        // the first failed push but before registration.
                        match self.opt_key {
                            None => self.opt_key = Some(self.sender.channel.sends.register(cx)),
                            Some(key) => self.sender.channel.sends.reregister(key, cx),
                        }
                        match self.sender.channel.push(msg) {
                            Ok(()) => Poll::Ready(()),
                            Err(PushError::Disconnected(msg)) | Err(PushError::Full(msg)) => {
                                self.msg = Some(msg);
                                Poll::Pending
                            }
                        }
                    }
                };
                if poll.is_ready() {
                    // Completed: drop our registry entry. `completed == true`
                    // means no wakeup needs to be forwarded to another sender.
                    if let Some(key) = self.opt_key.take() {
                        self.sender.channel.sends.unregister(key, true);
                    }
                }
                poll
            }
        }
        impl<T> Drop for SendFuture<'_, T> {
            fn drop(&mut self) {
                // Cancelled mid-wait: unregister with `completed == false` so
                // a notification that already consumed our waker is forwarded
                // to another blocked sender instead of being lost.
                if let Some(key) = self.opt_key {
                    self.sender.channel.sends.unregister(key, false);
                }
            }
        }
        SendFuture {
            sender: self,
            msg: Some(msg),
            opt_key: None,
        }
        .await
    }
    /// Returns the channel's capacity.
    pub fn capacity(&self) -> usize {
        self.channel.cap
    }
    /// Returns `true` if the channel currently holds no messages.
    pub fn is_empty(&self) -> bool {
        self.channel.is_empty()
    }
    /// Returns `true` if the channel is at capacity.
    pub fn is_full(&self) -> bool {
        self.channel.is_full()
    }
    /// Returns the number of messages currently in the channel.
    pub fn len(&self) -> usize {
        self.channel.len()
    }
}
impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        // Decrement the live-sender count; whoever drops the last `Sender`
        // disconnects the channel and wakes everything that is parked.
        let prev = self.channel.sender_count.fetch_sub(1, Ordering::AcqRel);
        if prev == 1 {
            self.channel.disconnect();
        }
    }
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Sender<T> {
let count = self.channel.sender_count.fetch_add(1, Ordering::Relaxed);
if count > isize::MAX as usize {
process::abort();
}
Sender {
channel: self.channel.clone(),
}
}
}
impl<T> fmt::Debug for Sender<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Opaque representation; `pad` honors width/alignment flags.
        let repr = "Sender { .. }";
        f.pad(repr)
    }
}
#[cfg(feature = "unstable")]
#[cfg_attr(feature = "docs", doc(cfg(unstable)))]
/// The receiving half of the channel. Cloneable, and usable both via
/// `recv()` and as a `Stream`; dropping the last clone disconnects the
/// channel.
pub struct Receiver<T> {
    // Shared channel state.
    channel: Arc<Channel<T>>,
    // Key of this receiver's entry in the `streams` registry while it is
    // parked via `Stream::poll_next`; `None` when not registered.
    opt_key: Option<usize>,
}
impl<T> Receiver<T> {
    /// Receives a message, waiting while the channel is empty.
    ///
    /// Returns `None` once the channel is disconnected and fully drained.
    pub async fn recv(&self) -> Option<T> {
        /// One in-flight `recv` operation.
        struct RecvOp<'a, T> {
            channel: &'a Channel<T>,
            // Our key in the `recvs` registry while parked.
            opt_key: Option<usize>,
        }

        impl<T> Future for RecvOp<'_, T> {
            type Output = Option<T>;

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                poll_recv(&self.channel, &self.channel.recvs, &mut self.opt_key, cx)
            }
        }

        impl<T> Drop for RecvOp<'_, T> {
            fn drop(&mut self) {
                // Cancelled mid-wait: forward any wakeup meant for us to
                // another parked receiver instead of losing it.
                if let Some(key) = self.opt_key {
                    self.channel.recvs.unregister(key, false);
                }
            }
        }

        RecvOp {
            channel: &self.channel,
            opt_key: None,
        }
        .await
    }

    /// Returns the channel's capacity.
    pub fn capacity(&self) -> usize {
        self.channel.cap
    }

    /// Returns `true` if the channel currently holds no messages.
    pub fn is_empty(&self) -> bool {
        self.channel.is_empty()
    }

    /// Returns `true` if the channel is at capacity.
    pub fn is_full(&self) -> bool {
        self.channel.is_full()
    }

    /// Returns the number of messages currently in the channel.
    pub fn len(&self) -> usize {
        self.channel.len()
    }
}
impl<T> Drop for Receiver<T> {
    fn drop(&mut self) {
        // If this receiver was parked as a stream, drop its registry entry,
        // forwarding any wakeup that was meant for it (`completed == false`).
        if let Some(key) = self.opt_key {
            self.channel.streams.unregister(key, false);
        }
        // Whoever drops the last `Receiver` disconnects the channel.
        let prev = self.channel.receiver_count.fetch_sub(1, Ordering::AcqRel);
        if prev == 1 {
            self.channel.disconnect();
        }
    }
}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Receiver<T> {
let count = self.channel.receiver_count.fetch_add(1, Ordering::Relaxed);
if count > isize::MAX as usize {
process::abort();
}
Receiver {
channel: self.channel.clone(),
opt_key: None,
}
}
}
impl<T> Stream for Receiver<T> {
    type Item = T;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `Receiver` holds only `Unpin` fields (an `Arc` and an `Option`),
        // so unwrapping the pin is safe.
        let this = self.get_mut();
        // Stream polling parks in the dedicated `streams` registry, keyed by
        // the receiver's own `opt_key`.
        poll_recv(&this.channel, &this.channel.streams, &mut this.opt_key, cx)
    }
}
impl<T> fmt::Debug for Receiver<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Opaque representation; `pad` honors width/alignment flags.
        let repr = "Receiver { .. }";
        f.pad(repr)
    }
}
/// Polls one receive operation against `channel`, using `registry` (either
/// `recvs` for `recv()` futures or `streams` for `Stream` polling) to park
/// the waker while the channel is empty.
///
/// `opt_key` holds the caller's registry key across polls; the entry is
/// unregistered (as completed) as soon as the poll resolves.
fn poll_recv<T>(
    channel: &Channel<T>,
    registry: &Registry,
    opt_key: &mut Option<usize>,
    cx: &mut Context<'_>,
) -> Poll<Option<T>> {
    let poll = match channel.pop() {
        Ok(msg) => Poll::Ready(Some(msg)),
        // Disconnected and drained: the stream of messages is finished.
        Err(PopError::Disconnected) => Poll::Ready(None),
        Err(PopError::Empty) => {
            // Register (or refresh) the waker, then pop again to close the
            // race with a `push` that happened between the first failed pop
            // and registration.
            match *opt_key {
                None => *opt_key = Some(registry.register(cx)),
                Some(key) => registry.reregister(key, cx),
            }
            match channel.pop() {
                Ok(msg) => Poll::Ready(Some(msg)),
                Err(PopError::Disconnected) => Poll::Ready(None),
                Err(PopError::Empty) => Poll::Pending,
            }
        }
    };
    if poll.is_ready() {
        // Completed: drop the registry entry. `completed == true` means no
        // wakeup needs to be forwarded to another parked task.
        if let Some(key) = opt_key.take() {
            registry.unregister(key, true);
        }
    }
    poll
}
/// One buffer slot: a message plus the stamp that encodes, per lap, whether
/// the slot is currently writable or readable.
struct Slot<T> {
    // Equals the tail value when the slot is vacant on that lap, and
    // `tail + 1` once a message has been written (see `push`/`pop`).
    stamp: AtomicUsize,
    // The message; only initialized while the stamp says one is present.
    msg: UnsafeCell<T>,
}
/// Bounded lock-free MPMC queue (crossbeam-style array queue) plus the
/// waker registries that make it awaitable.
struct Channel<T> {
    // Slot index (plus lap bits) of the next message to pop.
    head: CachePadded<AtomicUsize>,
    // Slot index (plus lap bits) of the next free slot; also carries the
    // disconnect `mark_bit`.
    tail: CachePadded<AtomicUsize>,
    // Raw pointer to `cap` slots; the allocation is reclaimed in `Drop`.
    buffer: *mut Slot<T>,
    // Channel capacity (number of slots).
    cap: usize,
    // Distance between two laps in the head/tail encoding.
    one_lap: usize,
    // Bit set in `tail` once the channel is disconnected.
    mark_bit: usize,
    // Senders blocked on a full channel.
    sends: Registry,
    // `recv()` futures blocked on an empty channel.
    recvs: Registry,
    // `Stream` pollers blocked on an empty channel.
    streams: Registry,
    // Number of live `Sender` clones.
    sender_count: AtomicUsize,
    // Number of live `Receiver` clones.
    receiver_count: AtomicUsize,
    // Makes the channel logically own `T` for drop-check purposes.
    _marker: PhantomData<T>,
}
// SAFETY: messages of type `T` cross thread boundaries through the buffer,
// hence `T: Send`; shared access is synchronized by the atomic head/tail and
// per-slot stamps plus the registries' spinlocks (same scheme as crossbeam's
// array queue).
unsafe impl<T: Send> Send for Channel<T> {}
unsafe impl<T: Send> Sync for Channel<T> {}
// The channel is not self-referential, so it can be moved freely.
impl<T> Unpin for Channel<T> {}
impl<T> Channel<T> {
    /// Allocates a channel with room for `cap` messages.
    ///
    /// Head/tail encoding (crossbeam array-queue style): bits below
    /// `mark_bit` hold a slot index, `mark_bit` (set on `tail`) marks the
    /// channel as disconnected, and bits from `one_lap` upward count laps
    /// around the buffer so "full" and "empty" can be told apart.
    ///
    /// # Panics
    ///
    /// Panics if `cap` is zero.
    fn with_capacity(cap: usize) -> Self {
        assert!(cap > 0, "capacity must be positive");
        // One bit just above the largest valid index, reserved for the
        // disconnect mark; lap counting starts one bit above that.
        let mark_bit = (cap + 1).next_power_of_two();
        let one_lap = mark_bit * 2;
        // `head` is where the next pop reads, `tail` where the next push writes.
        let head = 0;
        let tail = 0;
        // Raw, uninitialized slot array; ownership is reclaimed in `Drop`
        // via `Vec::from_raw_parts`.
        let buffer = {
            let mut v = Vec::<Slot<T>>::with_capacity(cap);
            let ptr = v.as_mut_ptr();
            mem::forget(v);
            ptr
        };
        // Initialize only the stamps; `msg` stays uninitialized until a
        // push writes into the slot.
        for i in 0..cap {
            unsafe {
                let slot = buffer.add(i);
                ptr::write(&mut (*slot).stamp, AtomicUsize::new(i));
            }
        }
        Channel {
            buffer,
            cap,
            one_lap,
            mark_bit,
            head: CachePadded::new(AtomicUsize::new(head)),
            tail: CachePadded::new(AtomicUsize::new(tail)),
            sends: Registry::new(),
            recvs: Registry::new(),
            streams: Registry::new(),
            sender_count: AtomicUsize::new(1),
            receiver_count: AtomicUsize::new(1),
            _marker: PhantomData,
        }
    }
    /// Attempts to push `msg` without blocking; on failure the message is
    /// handed back inside the error.
    fn push(&self, msg: T) -> Result<(), PushError<T>> {
        let backoff = Backoff::new();
        let mut tail = self.tail.load(Ordering::Relaxed);
        loop {
            // Split the tail into slot index and lap bits.
            let index = tail & (self.mark_bit - 1);
            let lap = tail & !(self.one_lap - 1);
            let slot = unsafe { &*self.buffer.add(index) };
            let stamp = slot.stamp.load(Ordering::Acquire);
            if tail == stamp {
                // Stamp matches the tail: the slot is vacant on this lap.
                // Compute the next tail (wrapping into the next lap at the
                // end of the buffer).
                let new_tail = if index + 1 < self.cap {
                    tail + 1
                } else {
                    lap.wrapping_add(self.one_lap)
                };
                match self.tail.compare_exchange_weak(
                    tail,
                    new_tail,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // We own the slot: write the message, then publish
                        // it by advancing the stamp.
                        unsafe { slot.msg.get().write(msg) };
                        let stamp = tail + 1;
                        slot.stamp.store(stamp, Ordering::Release);
                        // Wake one parked `recv()` and all parked streams.
                        self.recvs.notify_one();
                        self.streams.notify_all();
                        return Ok(());
                    }
                    Err(t) => {
                        // Lost the CAS race; retry with the updated tail.
                        tail = t;
                        backoff.spin();
                    }
                }
            } else if stamp.wrapping_add(self.one_lap) == tail + 1 {
                // The slot still holds a message from the previous lap, so
                // the queue looks full; confirm against the head.
                atomic::fence(Ordering::SeqCst);
                let head = self.head.load(Ordering::Relaxed);
                if head.wrapping_add(self.one_lap) == tail {
                    // Head is exactly one lap behind: genuinely full (or
                    // disconnected, if the mark bit is set on the tail).
                    if tail & self.mark_bit != 0 {
                        return Err(PushError::Disconnected(msg));
                    } else {
                        return Err(PushError::Full(msg));
                    }
                }
                backoff.spin();
                tail = self.tail.load(Ordering::Relaxed);
            } else {
                // Another thread is mid-operation on this slot; wait for
                // the stamp to settle.
                backoff.snooze();
                tail = self.tail.load(Ordering::Relaxed);
            }
        }
    }
    /// Attempts to pop a message without blocking.
    fn pop(&self) -> Result<T, PopError> {
        let backoff = Backoff::new();
        let mut head = self.head.load(Ordering::Relaxed);
        loop {
            // Split the head into slot index and lap bits.
            let index = head & (self.mark_bit - 1);
            let lap = head & !(self.one_lap - 1);
            let slot = unsafe { &*self.buffer.add(index) };
            let stamp = slot.stamp.load(Ordering::Acquire);
            if head + 1 == stamp {
                // Stamp is one ahead of the head: the slot holds a message
                // written on this lap. Compute the next head.
                let new = if index + 1 < self.cap {
                    head + 1
                } else {
                    lap.wrapping_add(self.one_lap)
                };
                match self.head.compare_exchange_weak(
                    head,
                    new,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // We own the slot: move the message out, then mark
                        // the slot vacant for the next lap.
                        let msg = unsafe { slot.msg.get().read() };
                        let stamp = head.wrapping_add(self.one_lap);
                        slot.stamp.store(stamp, Ordering::Release);
                        // A slot just freed up: wake one parked sender.
                        self.sends.notify_one();
                        return Ok(msg);
                    }
                    Err(h) => {
                        // Lost the CAS race; retry with the updated head.
                        head = h;
                        backoff.spin();
                    }
                }
            } else if stamp == head {
                // The slot is vacant on this lap: the queue looks empty;
                // confirm against the tail.
                atomic::fence(Ordering::SeqCst);
                let tail = self.tail.load(Ordering::Relaxed);
                if (tail & !self.mark_bit) == head {
                    // Tail (ignoring the mark bit) equals head: truly empty
                    // (or disconnected and drained).
                    if tail & self.mark_bit != 0 {
                        return Err(PopError::Disconnected);
                    } else {
                        return Err(PopError::Empty);
                    }
                }
                backoff.spin();
                head = self.head.load(Ordering::Relaxed);
            } else {
                // Another thread is mid-operation on this slot.
                backoff.snooze();
                head = self.head.load(Ordering::Relaxed);
            }
        }
    }
    /// Returns the number of messages currently in the channel.
    fn len(&self) -> usize {
        loop {
            // Snapshot tail then head; retry if the tail moved in between so
            // the two reads form a consistent pair.
            let tail = self.tail.load(Ordering::SeqCst);
            let head = self.head.load(Ordering::SeqCst);
            if self.tail.load(Ordering::SeqCst) == tail {
                let hix = head & (self.mark_bit - 1);
                let tix = tail & (self.mark_bit - 1);
                return if hix < tix {
                    tix - hix
                } else if hix > tix {
                    self.cap - hix + tix
                } else if (tail & !self.mark_bit) == head {
                    // Same index and same lap: empty.
                    0
                } else {
                    // Same index, different lap: full.
                    self.cap
                };
            }
        }
    }
    /// Returns `true` if the channel holds no messages.
    fn is_empty(&self) -> bool {
        let head = self.head.load(Ordering::SeqCst);
        let tail = self.tail.load(Ordering::SeqCst);
        // Tail (sans mark bit) equal to head means no pending messages.
        (tail & !self.mark_bit) == head
    }
    /// Returns `true` if the channel is at capacity.
    fn is_full(&self) -> bool {
        let tail = self.tail.load(Ordering::SeqCst);
        let head = self.head.load(Ordering::SeqCst);
        // Head exactly one lap behind tail means every slot is occupied.
        head.wrapping_add(self.one_lap) == tail & !self.mark_bit
    }
    /// Marks the channel as disconnected (idempotent); only the first call
    /// wakes every parked sender, receiver, and stream.
    fn disconnect(&self) {
        let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);
        if tail & self.mark_bit == 0 {
            self.sends.notify_all();
            self.recvs.notify_all();
            self.streams.notify_all();
        }
    }
}
impl<T> Drop for Channel<T> {
    fn drop(&mut self) {
        // Index of the slot holding the oldest undelivered message.
        let hix = self.head.load(Ordering::Relaxed) & (self.mark_bit - 1);
        // Drop every message still in the queue; slots outside this range
        // either were never written or have already been moved out by `pop`.
        for i in 0..self.len() {
            let index = if hix + i < self.cap {
                hix + i
            } else {
                // Wrap around the end of the ring buffer.
                hix + i - self.cap
            };
            unsafe {
                // SAFETY: this slot holds an initialized message; dropping
                // the whole `Slot` drops it in place (the stamp is trivial).
                self.buffer.add(index).drop_in_place();
            }
        }
        unsafe {
            // SAFETY: reconstruct the Vec with length 0 so the allocation is
            // freed without dropping any (possibly uninitialized) slots.
            Vec::from_raw_parts(self.buffer, 0, self.cap);
        }
    }
}
/// Error returned by `Channel::push`; carries the rejected message back to
/// the caller so it can be retried.
enum PushError<T> {
    /// The channel is at capacity.
    Full(T),
    /// The channel is disconnected (the last sender or receiver dropped).
    Disconnected(T),
}
/// Error returned by `Channel::pop`.
enum PopError {
    /// The channel holds no messages.
    Empty,
    /// The channel is disconnected and fully drained.
    Disconnected,
}
/// Wakers of tasks blocked on a channel operation.
struct Blocked {
    // Registered entries; `None` means the entry's waker was already taken
    // by a notification but the task has not unregistered yet.
    entries: Slab<Option<Waker>>,
    // Number of entries whose waker is still `Some`.
    waker_count: usize,
}
/// A set of parked wakers guarded by a spinlock, with a lock-free
/// "is anybody waiting" fast path used by the notify methods.
struct Registry {
    blocked: Spinlock<Blocked>,
    // Mirrors `waker_count == 0`; checked before taking the spinlock.
    is_empty: AtomicBool,
}
impl Registry {
    /// Creates an empty registry.
    fn new() -> Registry {
        Registry {
            blocked: Spinlock::new(Blocked {
                entries: Slab::new(),
                waker_count: 0,
            }),
            is_empty: AtomicBool::new(true),
        }
    }
    /// Parks the current task's waker, returning the slab key identifying
    /// the entry for later `reregister`/`unregister` calls.
    fn register(&self, cx: &Context<'_>) -> usize {
        let mut blocked = self.blocked.lock();
        let w = cx.waker().clone();
        let key = blocked.entries.insert(Some(w));
        blocked.waker_count += 1;
        // First live waker: clear the "nobody is waiting" fast-path flag.
        if blocked.waker_count == 1 {
            self.is_empty.store(false, Ordering::SeqCst);
        }
        key
    }
    /// Refreshes the waker stored at `key` with the current context's waker
    /// (the task may be polled from a different context each time).
    fn reregister(&self, key: usize, cx: &Context<'_>) {
        let mut blocked = self.blocked.lock();
        // `None` means our previous waker was consumed by a notification.
        let was_none = blocked.entries[key].is_none();
        let w = cx.waker().clone();
        blocked.entries[key] = Some(w);
        if was_none {
            // The entry holds a live waker again.
            blocked.waker_count += 1;
            if blocked.waker_count == 1 {
                self.is_empty.store(false, Ordering::SeqCst);
            }
        }
    }
    /// Removes the entry at `key`. With `completed == false` (the operation
    /// was cancelled), a notification that already consumed this entry's
    /// waker is forwarded to another parked task so no wakeup is lost.
    fn unregister(&self, key: usize, completed: bool) {
        let mut blocked = self.blocked.lock();
        let mut removed = false;
        match blocked.entries.remove(key) {
            // Our waker was still parked: one fewer live waker.
            Some(_) => removed = true,
            None => {
                // Our waker was already taken by notify_one/notify_all.
                if !completed {
                    // Pass the otherwise-lost notification along by waking
                    // some other parked task.
                    if let Some((_, opt_waker)) = blocked.entries.iter_mut().next() {
                        if let Some(w) = opt_waker.take() {
                            w.wake();
                            removed = true;
                        }
                    }
                }
            }
        }
        if removed {
            blocked.waker_count -= 1;
            if blocked.waker_count == 0 {
                self.is_empty.store(true, Ordering::SeqCst);
            }
        }
    }
    /// Wakes one parked task, if any.
    #[inline]
    fn notify_one(&self) {
        // Fast path: avoid taking the lock when no waker is parked.
        if !self.is_empty.load(Ordering::SeqCst) {
            let mut blocked = self.blocked.lock();
            if let Some((_, opt_waker)) = blocked.entries.iter_mut().next() {
                // A `None` waker means this entry was already notified and
                // its task has not repolled/unregistered yet.
                if let Some(w) = opt_waker.take() {
                    w.wake();
                    blocked.waker_count -= 1;
                    if blocked.waker_count == 0 {
                        self.is_empty.store(true, Ordering::SeqCst);
                    }
                }
            }
        }
    }
    /// Wakes every parked task.
    #[inline]
    fn notify_all(&self) {
        // Fast path: avoid taking the lock when no waker is parked.
        if !self.is_empty.load(Ordering::SeqCst) {
            let mut blocked = self.blocked.lock();
            for (_, opt_waker) in blocked.entries.iter_mut() {
                if let Some(w) = opt_waker.take() {
                    w.wake();
                }
            }
            // Every waker was consumed.
            blocked.waker_count = 0;
            self.is_empty.store(true, Ordering::SeqCst);
        }
    }
}
/// A simple test-and-set spinlock protecting a value.
struct Spinlock<T> {
    // `true` while the lock is held.
    flag: AtomicBool,
    // The protected value; only accessed through a `SpinlockGuard`.
    value: UnsafeCell<T>,
}
impl<T> Spinlock<T> {
    /// Creates a new, unlocked spinlock wrapping `value`.
    fn new(value: T) -> Spinlock<T> {
        Spinlock {
            value: UnsafeCell::new(value),
            flag: AtomicBool::new(false),
        }
    }

    /// Acquires the lock, spinning with exponential backoff until it is free.
    fn lock(&self) -> SpinlockGuard<'_, T> {
        let backoff = Backoff::new();
        loop {
            // `swap` returns the previous value; `false` means we took the lock.
            if !self.flag.swap(true, Ordering::Acquire) {
                return SpinlockGuard { parent: self };
            }
            backoff.snooze();
        }
    }
}
/// RAII guard for `Spinlock`; grants access to the protected value and
/// releases the lock when dropped.
struct SpinlockGuard<'a, T> {
    parent: &'a Spinlock<T>,
}
impl<T> Drop for SpinlockGuard<'_, T> {
    fn drop(&mut self) {
        // Release the lock; `Release` ordering publishes all writes made
        // while the lock was held to the next acquirer.
        self.parent.flag.store(false, Ordering::Release);
    }
}
impl<T> Deref for SpinlockGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: the guard's existence proves the lock is held, so no other
        // thread can be accessing the value concurrently.
        let ptr = self.parent.value.get();
        unsafe { &*ptr }
    }
}
impl<T> DerefMut for SpinlockGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: holding the guard (and `&mut self`) guarantees exclusive
        // access to the protected value.
        let ptr = self.parent.value.get();
        unsafe { &mut *ptr }
    }
}