use std::{
collections::VecDeque,
fmt,
sync::{
Arc, Condvar, Mutex,
atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering, fence},
},
time::Duration,
};
use crossbeam_queue::ArrayQueue;
use tokio::sync::Notify;
use crate::stream::{BoxStream, NotUsed, Source};
use crate::{StreamError, StreamResult};
/// Value held in `ChannelShared::closed` while the channel still accepts elements.
const CHANNEL_OPEN: u8 = 0;
/// Value held in `ChannelShared::closed` once `close` has succeeded; nothing in
/// this file ever stores `CHANNEL_OPEN` back afterwards.
const CHANNEL_CLOSED: u8 = 1;
/// Timeout used for each consumer park (`Condvar::wait_timeout`), so the consumer
/// re-checks buffer/close state at least this often even without an explicit wake.
const PARK_BACKSTOP: Duration = Duration::from_millis(10);
/// Maximum number of elements popped from the shared buffer in one `drain_batch`.
const CONSUMER_DRAIN_BATCH: usize = 256;
/// Maximum number of slots granted to a single waiting producer per grant, and
/// the maximum claimed from the shared pool by one `handoff_available_slots` call.
const PRODUCER_WAKE_BATCH: usize = 256;
/// Producer handle for a bounded, multi-producer, single-consumer channel.
///
/// All handles (including clones) share one `ChannelShared`; each handle owns
/// its own `ProducerLocal`, so slot reservations and wait registration are
/// tracked per handle.
pub struct Channel<T> {
    /// State shared by every producer handle and the consuming stream.
    shared: Arc<ChannelShared<T>>,
    /// Per-handle state: slots reserved for this handle and its waiter flags.
    local: Arc<ProducerLocal>,
}
/// Error returned by `Channel::send`; hands the unsent element back to the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SendError<T> {
    /// The channel was closed, so the element was not enqueued.
    Closed(T),
}
/// Error returned by `Channel::try_send`; hands the rejected element back.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrySendError<T> {
    /// No slot was available right now (neither reserved for this handle nor
    /// in the shared pool); retrying later may succeed.
    Full(T),
    /// The channel is closed; retrying will never succeed.
    Closed(T),
}
/// State shared between all producer handles and the single consumer stream.
struct ChannelShared<T> {
    /// Bounded lock-free queue holding accepted, not-yet-consumed elements.
    buffer: ArrayQueue<T>,
    /// Configured capacity; only used here for a debug consistency assertion.
    capacity: usize,
    /// Slots in the shared pool that no producer has claimed or reserved.
    available_slots: AtomicUsize,
    /// `CHANNEL_OPEN` or `CHANNEL_CLOSED`.
    closed: AtomicU8,
    /// Number of `try_send_value` calls that have passed their first closed
    /// check and not yet called `finish_send`; after close the consumer only
    /// terminates once this reaches zero.
    in_flight_senders: AtomicUsize,
    /// `true` while some stream holds the (single) consumer role.
    consumer_active: AtomicBool,
    /// Mutex the consumer holds while deciding to park and while waiting.
    consumer_park: Mutex<()>,
    /// Condition variable the consumer parks on.
    consumer_available: Condvar,
    /// Set by the consumer just before parking; wakers skip the lock when clear.
    consumer_parked: AtomicBool,
    /// FIFO of producer handles waiting for space. May contain stale entries
    /// whose `queued` flag is already `false`; those are skipped when popped.
    producer_waiters: Mutex<VecDeque<Arc<ProducerLocal>>>,
    /// Number of producer handles whose `queued` flag is currently `true`.
    space_waiters: AtomicUsize,
    /// Wakes futures returned by `Channel::closed`.
    closed_notified: Notify,
}
/// Per-producer-handle state, shared with the waiter queue via `Arc`.
struct ProducerLocal {
    /// Slots granted directly to this handle by `grant_slots_to_waiters`;
    /// consumed before the shared pool by `try_send_value`.
    reserved_slots: AtomicUsize,
    /// `true` while this handle is counted in `space_waiters`.
    queued: AtomicBool,
    /// Cleared when the owning `Channel` handle is dropped; grants skip
    /// inactive handles.
    active: AtomicBool,
    /// Notified when slots are granted to this handle or the channel closes.
    available: Notify,
}
/// Blocking iterator that consumes a channel; the stream side of the channel.
struct ChannelStream<T> {
    /// Shared channel state.
    shared: Arc<ChannelShared<T>>,
    /// Elements already popped by `drain_batch` but not yet yielded.
    pending: VecDeque<T>,
    /// `true` while this stream still holds the consumer role.
    active: bool,
}
impl<T> Channel<T> {
    /// Creates a bounded channel that buffers at most `capacity` elements.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    #[must_use]
    pub fn bounded(capacity: usize) -> Self {
        assert!(capacity > 0, "channel capacity must be greater than zero");
        let shared = ChannelShared::new(capacity);
        let local = ProducerLocal::new();
        Self { shared, local }
    }

    /// Returns a [`Source`] that consumes this channel when materialized.
    ///
    /// Materialization fails if another stream is already consuming the
    /// channel, because only one consumer may be active at a time.
    #[must_use]
    pub fn source(&self) -> Source<T>
    where
        T: Send + 'static,
    {
        let shared_state = Arc::clone(&self.shared);
        Source::from_materialized_factory(move |_materializer| {
            let stream = ChannelShared::new_stream(Arc::clone(&shared_state))?;
            Ok((stream, NotUsed))
        })
    }

    /// Sends `value`, waiting asynchronously while the channel is full.
    ///
    /// # Errors
    ///
    /// Returns [`SendError::Closed`] carrying the value if the channel is, or
    /// becomes, closed before the value is enqueued.
    pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
        let mut candidate = value;
        loop {
            candidate = match self.try_send(candidate) {
                Ok(()) => return Ok(()),
                Err(TrySendError::Closed(rejected)) => return Err(SendError::Closed(rejected)),
                // Full: keep the value and wait for space (or close) before retrying.
                Err(TrySendError::Full(rejected)) => rejected,
            };
            self.shared.wait_for_space_or_close(&self.local).await;
        }
    }

    /// Attempts to enqueue `value` without waiting.
    ///
    /// # Errors
    ///
    /// [`TrySendError::Full`] when no slot is available, [`TrySendError::Closed`]
    /// when the channel is closed; both return the value.
    pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
        self.shared.try_send_value(&self.local, value)
    }

    /// Closes the channel. Already-buffered elements are still delivered to the
    /// consumer; repeated calls have no further effect.
    pub fn close(&self) {
        self.shared.close();
    }

    /// Reports whether the channel has been closed.
    #[must_use]
    pub fn is_closed(&self) -> bool {
        self.shared.is_closed()
    }

    /// Completes once the channel is closed (immediately if it already is).
    pub async fn closed(&self) {
        while !self.is_closed() {
            let notified = self.shared.closed_notified.notified();
            let mut notified = std::pin::pin!(notified);
            // Register interest before the re-check so a concurrent close that
            // calls `notify_waiters` cannot slip between check and await.
            notified.as_mut().enable();
            if self.is_closed() {
                break;
            }
            notified.as_mut().await;
        }
    }
}
impl<T> Clone for Channel<T> {
    /// Creates another producer handle for the same channel.
    ///
    /// The new handle shares the buffer and consumer with `self` but starts with
    /// fresh per-handle state (no reserved slots, not queued).
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.shared);
        let fresh_local = ProducerLocal::new();
        Self {
            shared,
            local: fresh_local,
        }
    }
}
impl<T> Drop for Channel<T> {
fn drop(&mut self) {
self.local.active.store(false, Ordering::Release);
self.shared.cancel_waiter(&self.local);
self.local.available.notify_waiters();
let reserved = self.local.reserved_slots.swap(0, Ordering::AcqRel);
if reserved > 0 && !self.shared.is_closed() {
self.shared.release_slots(reserved);
}
}
}
impl ProducerLocal {
    /// Allocates state for a new producer handle.
    ///
    /// The handle starts with no reserved slots, is not queued, and is active.
    fn new() -> Arc<Self> {
        let state = Self {
            reserved_slots: AtomicUsize::new(0),
            queued: AtomicBool::new(false),
            active: AtomicBool::new(true),
            available: Notify::new(),
        };
        Arc::new(state)
    }
}
impl<T: Send + 'static> Source<T, NotUsed> {
    /// Creates a [`Source`] whose materialized value is the [`Channel`] feeding it.
    ///
    /// Each materialization builds a brand-new bounded channel and claims its
    /// consumer role for the returned stream. The producer handle is returned
    /// as the materialized value.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero. The check runs eagerly here, not at
    /// materialization time.
    #[must_use]
    pub fn channel(capacity: usize) -> Source<T, Channel<T>> {
        assert!(capacity > 0, "channel capacity must be greater than zero");
        Source::from_materialized_factory(move |_materializer| {
            let fresh = Channel::bounded(capacity);
            let consumer_shared = Arc::clone(&fresh.shared);
            let stream = ChannelShared::new_stream(consumer_shared)?;
            Ok((stream, fresh))
        })
    }
}
impl<T> ChannelShared<T> {
    /// Allocates shared state with `capacity` buffer slots, all initially in
    /// the shared pool (`available_slots == capacity`) and the channel open.
    fn new(capacity: usize) -> Arc<Self> {
        Arc::new(Self {
            buffer: ArrayQueue::new(capacity),
            capacity,
            available_slots: AtomicUsize::new(capacity),
            closed: AtomicU8::new(CHANNEL_OPEN),
            in_flight_senders: AtomicUsize::new(0),
            consumer_active: AtomicBool::new(false),
            consumer_park: Mutex::new(()),
            consumer_available: Condvar::new(),
            consumer_parked: AtomicBool::new(false),
            producer_waiters: Mutex::new(VecDeque::new()),
            space_waiters: AtomicUsize::new(0),
            closed_notified: Notify::new(),
        })
    }
    /// Claims the consumer role and wraps `shared` in a boxed `ChannelStream`.
    ///
    /// Errors with `StreamError::Failed` if another stream already holds the
    /// consumer role.
    fn new_stream(shared: Arc<Self>) -> StreamResult<BoxStream<T>>
    where
        T: Send + 'static,
    {
        shared.acquire_consumer()?;
        Ok(Box::new(ChannelStream {
            shared,
            pending: VecDeque::new(),
            active: true,
        }))
    }
    /// Atomically flips `consumer_active` from `false` to `true`; fails if it
    /// was already `true` (only one consumer may be active).
    fn acquire_consumer(&self) -> StreamResult<()> {
        self.consumer_active
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .map(|_| ())
            .map_err(|_| {
                StreamError::Failed("channel source already has an active consumer".into())
            })
    }
    /// Gives up the consumer role so another stream may claim it.
    fn release_consumer(&self) {
        self.consumer_active.store(false, Ordering::Release);
    }
    /// Reports whether `close` has succeeded.
    fn is_closed(&self) -> bool {
        self.closed.load(Ordering::Acquire) == CHANNEL_CLOSED
    }
    /// Transitions the channel from open to closed.
    ///
    /// Returns `true` only for the call that performed the transition. That
    /// call wakes every queued producer, every `closed()` future, and the
    /// consumer; later calls return `false` and do nothing.
    fn close(&self) -> bool {
        let closed = self
            .closed
            .compare_exchange(
                CHANNEL_OPEN,
                CHANNEL_CLOSED,
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .is_ok();
        if closed {
            self.wake_all_senders();
            self.closed_notified.notify_waiters();
            self.wake_consumer();
        }
        closed
    }
    /// Wakes the consumer if it has announced that it is parking.
    fn wake_consumer(&self) {
        // Pairs with the SeqCst fence the consumer executes after setting
        // `consumer_parked`: either we see the flag, or the consumer's
        // re-check sees the state change we made before calling this.
        fence(Ordering::SeqCst);
        if self.consumer_parked.load(Ordering::Relaxed) {
            // The consumer holds `consumer_park` from setting the flag until
            // `wait_timeout` releases it, so taking the lock here keeps the
            // notification from landing between its re-check and its wait.
            let _guard = self
                .consumer_park
                .lock()
                .unwrap_or_else(|poison| poison.into_inner());
            self.consumer_available.notify_one();
        }
    }
    /// Empties the waiter queue, waking each handle still marked `queued` and
    /// decrementing `space_waiters` for it. Stale entries are simply dropped.
    fn wake_all_senders(&self) {
        fence(Ordering::SeqCst);
        let mut waiters = self
            .producer_waiters
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        while let Some(local) = waiters.pop_front() {
            if local.queued.swap(false, Ordering::AcqRel) {
                self.space_waiters.fetch_sub(1, Ordering::AcqRel);
                local.available.notify_waiters();
            }
        }
    }
    /// Waits until this handle might be able to send: it has reserved slots,
    /// the shared pool is non-empty, or the channel is closed.
    ///
    /// Does not itself take a slot; the caller must retry `try_send_value`,
    /// which may still report `Full` if another producer got there first.
    ///
    /// If this future is dropped while awaiting, `queued` stays set and the
    /// queue entry remains, so a later grant may still reserve slots for this
    /// handle.
    async fn wait_for_space_or_close(&self, local: &Arc<ProducerLocal>) {
        // Register for notification before any check so a grant or close that
        // happens after the checks below is not missed.
        let notified = local.available.notified();
        let mut notified = std::pin::pin!(notified);
        notified.as_mut().enable();
        let guard = self
            .producer_waiters
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        if local.reserved_slots.load(Ordering::Acquire) > 0
            || self.available_slots.load(Ordering::Acquire) > 0
            || self.is_closed()
        {
            drop(guard);
            return;
        }
        // Enqueue at most once: a handle still marked `queued` (e.g. from a
        // cancelled earlier wait) already has an entry in the queue.
        if !local.queued.swap(true, Ordering::AcqRel) {
            self.space_waiters.fetch_add(1, Ordering::AcqRel);
            let mut waiters = guard;
            waiters.push_back(Arc::clone(local));
            drop(waiters);
        } else {
            drop(guard);
        }
        // Re-check after publishing `space_waiters`: a releaser that read
        // `space_waiters == 0` may have returned slots to the shared pool
        // instead of granting them to us.
        fence(Ordering::SeqCst);
        if local.reserved_slots.load(Ordering::Acquire) > 0
            || self.available_slots.load(Ordering::Acquire) > 0
            || self.is_closed()
        {
            self.cancel_waiter(local);
            return;
        }
        notified.as_mut().await;
        self.cancel_waiter(local);
    }
    /// Clears `queued` and, if it was set, decrements `space_waiters`.
    ///
    /// The queue entry itself is left in place; it becomes stale and is
    /// skipped when later popped.
    fn cancel_waiter(&self, local: &ProducerLocal) {
        if local.queued.swap(false, Ordering::AcqRel) {
            self.space_waiters.fetch_sub(1, Ordering::AcqRel);
        }
    }
    /// Takes one slot from the shared pool if any remain; returns whether it did.
    fn try_acquire_global_slot(&self) -> bool {
        let mut slots = self.available_slots.load(Ordering::Acquire);
        loop {
            if slots == 0 {
                return false;
            }
            match self.available_slots.compare_exchange_weak(
                slots,
                slots - 1,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return true,
                Err(actual) => slots = actual,
            }
        }
    }
    /// Takes one slot reserved for `local` if any remain; returns whether it did.
    fn try_acquire_local_slot(&self, local: &ProducerLocal) -> bool {
        let mut slots = local.reserved_slots.load(Ordering::Acquire);
        loop {
            if slots == 0 {
                return false;
            }
            match local.reserved_slots.compare_exchange_weak(
                slots,
                slots - 1,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return true,
                Err(actual) => slots = actual,
            }
        }
    }
    /// Returns `count` slots to the channel.
    ///
    /// The slots are first offered to queued producers; whatever is not
    /// granted goes back to the shared pool.
    fn release_slots(&self, count: usize) {
        if count == 0 {
            return;
        }
        let remaining = self.grant_slots_to_waiters(count);
        if remaining > 0 {
            let previous = self.available_slots.fetch_add(remaining, Ordering::AcqRel);
            debug_assert!(
                previous + remaining <= self.capacity,
                "channel available slot count exceeded capacity"
            );
        }
    }
    /// If producers are queued, moves up to `PRODUCER_WAKE_BATCH` slots from the
    /// shared pool to them and returns any that were not granted.
    ///
    /// Called by the consumer before parking, as a backstop for waiters that
    /// raced with slots being returned to the pool.
    fn handoff_available_slots(&self) {
        if self.space_waiters.load(Ordering::Acquire) == 0 {
            return;
        }
        let mut slots = self.available_slots.load(Ordering::Acquire);
        loop {
            if slots == 0 {
                return;
            }
            let claimed = slots.min(PRODUCER_WAKE_BATCH);
            match self.available_slots.compare_exchange_weak(
                slots,
                slots - claimed,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    let remaining = self.grant_slots_to_waiters(claimed);
                    if remaining > 0 {
                        self.available_slots.fetch_add(remaining, Ordering::AcqRel);
                    }
                    return;
                }
                Err(actual) => slots = actual,
            }
        }
    }
    /// Distributes up to `slots` to queued producers in FIFO order, at most
    /// `PRODUCER_WAKE_BATCH` per producer, and returns the slots left over.
    ///
    /// The whole distribution runs while holding `producer_waiters`. Stale
    /// entries are skipped. Inactive handles, or any waiter once the channel
    /// is closed, are woken without receiving slots.
    fn grant_slots_to_waiters(&self, mut slots: usize) -> usize {
        if slots == 0 || self.space_waiters.load(Ordering::Acquire) == 0 {
            return slots;
        }
        // Pairs with the SeqCst fence in `wait_for_space_or_close`.
        fence(Ordering::SeqCst);
        let mut waiters = self
            .producer_waiters
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        while slots > 0 {
            let Some(local) = waiters.pop_front() else {
                break;
            };
            // A stale entry (already cancelled or already granted) is skipped.
            if !local.queued.swap(false, Ordering::AcqRel) {
                continue;
            }
            self.space_waiters.fetch_sub(1, Ordering::AcqRel);
            if !local.active.load(Ordering::Acquire) || self.is_closed() {
                local.available.notify_waiters();
                continue;
            }
            let grant = slots.min(PRODUCER_WAKE_BATCH);
            local.reserved_slots.fetch_add(grant, Ordering::AcqRel);
            slots -= grant;
            local.available.notify_waiters();
        }
        slots
    }
    /// Ends an in-flight send. The last sender to finish after close wakes the
    /// consumer so it can observe `in_flight_senders == 0` and terminate.
    fn finish_send(&self) {
        let previous = self.in_flight_senders.fetch_sub(1, Ordering::AcqRel);
        debug_assert!(previous > 0, "channel in-flight sender underflow");
        if previous == 1 && self.is_closed() {
            self.wake_consumer();
        }
    }
    /// Non-blocking send on behalf of the handle owning `local`.
    ///
    /// The sender registers in `in_flight_senders` and then re-checks `closed`.
    /// A close can therefore only complete the stream after every send that
    /// saw the channel open has either pushed its element or backed out.
    /// Slots reserved for `local` are used before the shared pool.
    ///
    /// NOTE(review): the increment/re-check here and the consumer's
    /// closed-then-`in_flight_senders` loads use Acquire/AcqRel rather than
    /// SeqCst; confirm this is sufficient on weakly-ordered targets.
    fn try_send_value(&self, local: &ProducerLocal, value: T) -> Result<(), TrySendError<T>> {
        if self.is_closed() {
            return Err(TrySendError::Closed(value));
        }
        self.in_flight_senders.fetch_add(1, Ordering::AcqRel);
        if self.is_closed() {
            self.finish_send();
            return Err(TrySendError::Closed(value));
        }
        let used_local_slot = self.try_acquire_local_slot(local);
        if !used_local_slot && !self.try_acquire_global_slot() {
            self.finish_send();
            return if self.is_closed() {
                Err(TrySendError::Closed(value))
            } else {
                Err(TrySendError::Full(value))
            };
        }
        match self.buffer.push(value) {
            Ok(()) => {
                self.finish_send();
                self.wake_consumer();
                Ok(())
            }
            Err(value) => {
                // Defensive path: a slot was held but the queue rejected the
                // push. Give the slot back to wherever it came from.
                if used_local_slot {
                    local.reserved_slots.fetch_add(1, Ordering::AcqRel);
                } else {
                    self.release_slots(1);
                }
                self.finish_send();
                if self.is_closed() {
                    Err(TrySendError::Closed(value))
                } else {
                    Err(TrySendError::Full(value))
                }
            }
        }
    }
}
impl<T> Iterator for ChannelStream<T> {
    type Item = StreamResult<T>;
    /// Yields the next element, blocking the calling thread until one is
    /// available or the channel has finished.
    ///
    /// Never yields `Err`. Returns `None` once the channel is closed, the
    /// buffer is drained, and no send is in flight; at that point the consumer
    /// role is released.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Serve elements left over from the previous batch first.
            if let Some(item) = self.pending.pop_front() {
                return Some(Ok(item));
            }
            if let Some(item) = self.drain_batch() {
                return Some(Ok(item));
            }
            if self.shared.is_closed() {
                // Drain again after observing `closed`: elements pushed before
                // the close must still be delivered.
                if let Some(item) = self.drain_batch() {
                    return Some(Ok(item));
                }
                // Terminate only when no sender sits between its closed-check
                // and its push; otherwise fall through and park.
                if self.shared.in_flight_senders.load(Ordering::Acquire) == 0 {
                    self.finish();
                    return None;
                }
            }
            let shared = &*self.shared;
            // Hold `consumer_park` from announcing the park until `wait_timeout`
            // releases it, so `wake_consumer` cannot notify in between.
            let guard = shared
                .consumer_park
                .lock()
                .unwrap_or_else(|poison| poison.into_inner());
            shared.consumer_parked.store(true, Ordering::Relaxed);
            // Pairs with the fence in `wake_consumer`: either a producer sees
            // `consumer_parked`, or this re-check sees its push/close.
            fence(Ordering::SeqCst);
            if !shared.buffer.is_empty()
                || (shared.is_closed() && shared.in_flight_senders.load(Ordering::Acquire) == 0)
            {
                shared.consumer_parked.store(false, Ordering::Relaxed);
                drop(guard);
                continue;
            }
            // Backstop for producers that queued while slots went to the pool.
            if !shared.is_closed() && shared.space_waiters.load(Ordering::Acquire) > 0 {
                shared.handoff_available_slots();
            }
            // Bounded wait: re-check state at least every `PARK_BACKSTOP`.
            let (guard, _timeout) = shared
                .consumer_available
                .wait_timeout(guard, PARK_BACKSTOP)
                .unwrap_or_else(|poison| poison.into_inner());
            shared.consumer_parked.store(false, Ordering::Relaxed);
            drop(guard);
        }
    }
}
impl<T> ChannelStream<T> {
    /// Pops up to `CONSUMER_DRAIN_BATCH` elements from the shared buffer.
    ///
    /// The first element is returned directly; the rest are appended to
    /// `pending` in pop order. The slots of all popped elements are returned
    /// with one `release_slots` call. Returns `None`, releasing nothing, when
    /// the buffer is empty.
    fn drain_batch(&mut self) -> Option<T> {
        let head = self.shared.buffer.pop()?;
        let buffer = &self.shared.buffer;
        let queued_before = self.pending.len();
        // Stops at the first empty pop or after `CONSUMER_DRAIN_BATCH - 1`
        // extra elements, whichever comes first.
        self.pending
            .extend(std::iter::from_fn(|| buffer.pop()).take(CONSUMER_DRAIN_BATCH - 1));
        let drained = 1 + (self.pending.len() - queued_before);
        self.shared.release_slots(drained);
        Some(head)
    }

    /// Releases the consumer role once; later calls do nothing.
    fn finish(&mut self) {
        if !self.active {
            return;
        }
        self.active = false;
        self.shared.release_consumer();
    }
}
impl<T> Drop for ChannelStream<T> {
    /// A stream dropped while still holding the consumer role closes the
    /// channel, which wakes blocked producers, and then releases the role.
    fn drop(&mut self) {
        if !self.active {
            return;
        }
        self.shared.close();
        self.shared.release_consumer();
        self.active = false;
    }
}
impl<T> fmt::Debug for Channel<T> {
    /// Shows only whether the channel is closed; internal state is elided.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let closed = self.is_closed();
        let mut builder = f.debug_struct("Channel");
        builder.field("closed", &closed);
        builder.finish_non_exhaustive()
    }
}
impl<T> fmt::Display for SendError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SendError::Closed(_) => f.write_str("channel is closed"),
}
}
}
impl<T> fmt::Display for TrySendError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TrySendError::Full(_) => f.write_str("channel is full"),
TrySendError::Closed(_) => f.write_str("channel is closed"),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream::Materializer;
    use futures::executor::block_on;
    use std::{
        collections::HashMap,
        sync::{
            Arc,
            atomic::{AtomicBool, AtomicUsize, Ordering},
        },
        thread,
        time::{Duration, Instant},
    };
    /// Capacity 1: a second `try_send` is `Full` until the consumer pops. After
    /// `close`, sends fail with `Closed`, but the buffered element is still
    /// delivered before `None`.
    #[test]
    fn try_send_reports_full_and_closed() {
        let channel = Channel::bounded(1);
        let mut stream = materialize_channel(&channel);
        assert_eq!(channel.try_send(1), Ok(()));
        assert_eq!(channel.try_send(2), Err(TrySendError::Full(2)));
        assert_eq!(stream.next(), Some(Ok(1)));
        assert_eq!(channel.try_send(3), Ok(()));
        channel.close();
        assert_eq!(channel.try_send(4), Err(TrySendError::Closed(4)));
        assert_eq!(stream.next(), Some(Ok(3)));
        assert_eq!(stream.next(), None);
    }
    /// 8 producers each `send` 256 tagged sequence numbers through a capacity-32
    /// channel. Every element must arrive, and each producer's sequence must
    /// stay in order.
    #[test]
    fn send_many_producers_preserves_per_producer_order() {
        const PRODUCERS: usize = 8;
        const PER_PRODUCER: usize = 256;
        let channel = Channel::bounded(32);
        let stream = materialize_channel(&channel);
        let consumer = thread::spawn(move || {
            let mut collected = Vec::new();
            for item in stream {
                collected.push(item.expect("channel has no failure terminal"));
            }
            collected
        });
        let mut handles = Vec::new();
        for producer in 0..PRODUCERS {
            let channel = channel.clone();
            handles.push(thread::spawn(move || {
                for seq in 0..PER_PRODUCER {
                    block_on(channel.send((producer, seq))).unwrap();
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
        channel.close();
        let collected = consumer.join().unwrap();
        assert_eq!(collected.len(), PRODUCERS * PER_PRODUCER);
        let mut by_producer: HashMap<usize, Vec<usize>> = HashMap::new();
        for (producer, seq) in collected {
            by_producer.entry(producer).or_default().push(seq);
        }
        for producer in 0..PRODUCERS {
            assert_eq!(
                by_producer.remove(&producer).unwrap(),
                (0..PER_PRODUCER).collect::<Vec<_>>()
            );
        }
    }
    /// With capacity equal to the total element count, concurrent `try_send`
    /// never reports `Full`. A stream materialized after `close` still yields
    /// every element.
    #[test]
    fn try_send_under_contention_counts_all_accepted_elements() {
        const PRODUCERS: usize = 8;
        const PER_PRODUCER: usize = 128;
        let total = PRODUCERS * PER_PRODUCER;
        let channel = Channel::bounded(total);
        let mut handles = Vec::new();
        for producer in 0..PRODUCERS {
            let channel = channel.clone();
            handles.push(thread::spawn(move || {
                for seq in 0..PER_PRODUCER {
                    channel.try_send((producer, seq)).unwrap();
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
        channel.close();
        let stream = materialize_channel(&channel);
        let mut count = 0;
        for item in stream {
            item.unwrap();
            count += 1;
        }
        assert_eq!(count, total);
    }
    /// A `send` on a full channel registers as a space waiter and only completes
    /// after the consumer pops the blocking element.
    #[test]
    fn send_backpressure_parks_and_resumes_on_consume() {
        let channel = Channel::bounded(1);
        let mut stream = materialize_channel(&channel);
        block_on(channel.send(1)).unwrap();
        let completed = Arc::new(AtomicBool::new(false));
        let send_completed = Arc::clone(&completed);
        let sender = {
            let channel = channel.clone();
            thread::spawn(move || {
                let result = block_on(channel.send(2));
                send_completed.store(true, Ordering::SeqCst);
                result
            })
        };
        // Wait until the sender is actually queued before consuming.
        wait_until(Duration::from_secs(1), || {
            channel.shared.space_waiters.load(Ordering::SeqCst) >= 1
        });
        assert!(!completed.load(Ordering::SeqCst));
        assert_eq!(stream.next(), Some(Ok(1)));
        assert_eq!(sender.join().unwrap(), Ok(()));
        channel.close();
        assert_eq!(stream.next(), Some(Ok(2)));
        assert_eq!(stream.next(), None);
    }
    /// Elements buffered before `close` are yielded in order, followed by
    /// `None` on every subsequent call.
    #[test]
    fn close_drains_buffer_before_completion() {
        let channel = Channel::bounded(3);
        let mut stream = materialize_channel(&channel);
        assert_eq!(channel.try_send(1), Ok(()));
        assert_eq!(channel.try_send(2), Ok(()));
        assert_eq!(channel.try_send(3), Ok(()));
        channel.close();
        assert_eq!(stream.next(), Some(Ok(1)));
        assert_eq!(stream.next(), Some(Ok(2)));
        assert_eq!(stream.next(), Some(Ok(3)));
        assert_eq!(stream.next(), None);
        assert_eq!(stream.next(), None);
    }
    /// Races `close` against 8 active producers on a small channel, for 20
    /// rounds. Every `send` that returned `Ok` must reach the consumer.
    #[test]
    fn concurrent_close_vs_send_never_loses_accepted_elements() {
        const ROUNDS: usize = 20;
        const PRODUCERS: usize = 8;
        const PER_PRODUCER: usize = 200;
        for _ in 0..ROUNDS {
            let channel = Channel::bounded(4);
            let stream = materialize_channel(&channel);
            let consumer = thread::spawn(move || {
                let mut count = 0_usize;
                for item in stream {
                    item.unwrap();
                    count += 1;
                }
                count
            });
            let mut handles = Vec::new();
            let started = Arc::new(AtomicUsize::new(0));
            for producer in 0..PRODUCERS {
                let channel = channel.clone();
                let started = Arc::clone(&started);
                handles.push(thread::spawn(move || {
                    let mut accepted = 0_usize;
                    started.fetch_add(1, Ordering::SeqCst);
                    for seq in 0..PER_PRODUCER {
                        if block_on(channel.send((producer, seq))).is_ok() {
                            accepted += 1;
                        } else {
                            break;
                        }
                    }
                    accepted
                }));
            }
            wait_until(Duration::from_secs(1), || {
                started.load(Ordering::SeqCst) == PRODUCERS
            });
            channel.close();
            let accepted: usize = handles
                .into_iter()
                .map(|handle| handle.join().unwrap())
                .sum();
            let delivered = consumer.join().unwrap();
            assert_eq!(delivered, accepted);
            assert_eq!(
                channel.try_send((usize::MAX, usize::MAX)),
                Err(TrySendError::Closed((usize::MAX, usize::MAX)))
            );
        }
    }
    /// `closed()` completes after `close`. The flag is set before `block_on`, so
    /// `close` may also land before the waiter starts awaiting; both orders must
    /// return.
    #[test]
    fn closed_future_wakes_on_close() {
        let channel = Channel::<u64>::bounded(1);
        let waiting = Arc::new(AtomicBool::new(false));
        let waiter_started = Arc::clone(&waiting);
        let waiter = {
            let channel = channel.clone();
            thread::spawn(move || {
                waiter_started.store(true, Ordering::SeqCst);
                block_on(channel.closed());
            })
        };
        wait_until(Duration::from_secs(1), || waiting.load(Ordering::SeqCst));
        channel.close();
        waiter.join().unwrap();
    }
    /// Dropping the stream closes the channel and fails a blocked `send` with
    /// `Closed`. An unrelated channel created afterwards works normally.
    #[test]
    fn consumer_drop_closes_channel_and_wakes_blocked_producers() {
        let channel = Channel::bounded(1);
        let stream = materialize_channel(&channel);
        block_on(channel.send(1)).unwrap();
        let sender = {
            let channel = channel.clone();
            thread::spawn(move || block_on(channel.send(2)))
        };
        wait_until(Duration::from_secs(1), || {
            channel.shared.space_waiters.load(Ordering::SeqCst) >= 1
        });
        drop(stream);
        assert_eq!(sender.join().unwrap(), Err(SendError::Closed(2)));
        assert!(channel.is_closed());
        assert_eq!(channel.try_send(3), Err(TrySendError::Closed(3)));
        let replacement = Channel::bounded(1);
        let mut replacement_stream = materialize_channel(&replacement);
        block_on(replacement.send(10)).unwrap();
        replacement.close();
        assert_eq!(replacement_stream.next(), Some(Ok(10)));
        assert_eq!(replacement_stream.next(), None);
    }
    /// A second materialization fails while the first stream is still alive.
    #[test]
    fn single_active_consumer_is_enforced() {
        let materializer = Materializer::new();
        let channel = Channel::<i32>::bounded(1);
        let source = channel.source();
        let (_first, _) = Arc::clone(&source.factory).create(&materializer).unwrap();
        let second = Arc::clone(&source.factory).create(&materializer);
        assert!(
            matches!(second, Err(StreamError::Failed(message)) if message.contains("active consumer"))
        );
    }
    /// Capacity 1 forces a hand-off on every element; order must be preserved.
    #[test]
    fn capacity_one_ping_pong_preserves_order() {
        const ITEMS: usize = 500;
        let channel = Channel::bounded(1);
        let stream = materialize_channel(&channel);
        let consumer = thread::spawn(move || {
            let mut got = Vec::new();
            for item in stream {
                got.push(item.unwrap());
            }
            got
        });
        for item in 0..ITEMS {
            block_on(channel.send(item)).unwrap();
        }
        channel.close();
        assert_eq!(consumer.join().unwrap(), (0..ITEMS).collect::<Vec<_>>());
    }
    /// `Channel::bounded(0)` panics with the documented message.
    #[test]
    #[should_panic(expected = "channel capacity must be greater than zero")]
    fn zero_capacity_is_unsupported() {
        let _ = Channel::<i32>::bounded(0);
    }
    /// Materializes `channel.source()` with a fresh `Materializer` and returns
    /// the stream; panics if materialization fails.
    fn materialize_channel<T: Send + 'static>(channel: &Channel<T>) -> BoxStream<T> {
        let materializer = Materializer::new();
        let (stream, _) = channel.source().factory.create(&materializer).unwrap();
        stream
    }
    /// Polls `condition` about every millisecond until it holds or `timeout`
    /// elapses, then asserts it one final time.
    fn wait_until(timeout: Duration, condition: impl Fn() -> bool) {
        let deadline = Instant::now() + timeout;
        while Instant::now() < deadline {
            if condition() {
                return;
            }
            thread::yield_now();
            thread::sleep(Duration::from_millis(1));
        }
        assert!(condition(), "condition was not met within {timeout:?}");
    }
}