use std::{
sync::Arc,
sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering::*},
thread,
};
use arc_swap::ArcSwapOption;
use crate::{
error::{RecvError, SendError},
waiter::{
RecvWaiter, RecvWaiterGuard, RecvWaiterList, SelectWaiter, UNSELECTED,
abort_select_waiters, drain_select_waiters, new_recv_waiter_list, push_select_waiter,
wake_all_recv_waiters, wake_select_all,
},
};
/// Creates a new watch channel and returns its `(Sender, Receiver)` pair.
///
/// The channel starts without a value (the shared slot is created with
/// `ArcSwapOption::empty()`), at version `0`, and with one sender and one
/// receiver counted. Both halves hold `Arc` handles to the same version
/// counter, value slot, waiter lists and liveness counters; the receiver
/// additionally gets its own, unarmed cursor.
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    // Shared state: built once here, then split between the two halves.
    let version = Arc::new(AtomicUsize::new(0));
    let value = Arc::new(ArcSwapOption::empty());
    let recv_waiters = new_recv_waiter_list();
    let select_waiters = Arc::new(AtomicPtr::new(std::ptr::null_mut()));
    let receiver_count = Arc::new(AtomicUsize::new(1));
    let sender_count = Arc::new(AtomicUsize::new(1));

    // The address of the version counter doubles as the channel id in logs.
    log_debug!("watch::channel: created chan={:p}", Arc::as_ptr(&version));

    // The receiver takes clones of the handles ...
    let rx = Receiver {
        version: Arc::clone(&version),
        value: Arc::clone(&value),
        recv_waiters: Arc::clone(&recv_waiters),
        select_waiters: Arc::clone(&select_waiters),
        cursor_version: AtomicUsize::new(0),
        cursor_armed: AtomicBool::new(false),
        receiver_count: Arc::clone(&receiver_count),
        sender_count: Arc::clone(&sender_count),
    };
    // ... and the sender takes ownership of the originals.
    let tx = Sender {
        version,
        value,
        recv_waiters,
        select_waiters,
        receiver_count,
        sender_count,
    };

    (tx, rx)
}
/// Sending half of the watch channel created by `channel`.
///
/// Every field is an `Arc` handle to state that `channel` also hands to the
/// `Receiver`; cloning a `Sender` clones these handles.
pub struct Sender<T> {
/// Version counter, incremented once per `send` / `mark_changed`.
version: Arc<AtomicUsize>,
/// Slot holding the most recently sent value; empty until the first `send`.
value: Arc<ArcSwapOption<T>>,
/// Waiter list passed to `wake_all_recv_waiters` on every notification.
recv_waiters: RecvWaiterList,
/// Select-waiter pointer passed to `wake_select_all` on every notification.
select_waiters: Arc<AtomicPtr<SelectWaiter>>,
/// Number of live receivers; zero makes `send` / `mark_changed` fail.
receiver_count: Arc<AtomicUsize>,
/// Number of live senders; incremented in `Clone`, decremented in `Drop`.
sender_count: Arc<AtomicUsize>,
}
impl<T> Sender<T> {
    /// Replaces the channel's current value with `value` and notifies
    /// receivers.
    ///
    /// The new value is stored before the version counter is bumped, and
    /// waiters are woken only after the bump.
    ///
    /// # Errors
    ///
    /// Returns `SendError(value)`, handing the value back unchanged, when the
    /// receiver count is zero at the time of the check.
    pub fn send(&self, value: T) -> Result<(), SendError<T>> {
        if self.is_closed() {
            return Err(SendError(value));
        }
        self.value.store(Some(Arc::new(value)));
        self.bump_version_and_notify();
        Ok(())
    }

    /// Bumps the version and notifies receivers without touching the stored
    /// value.
    ///
    /// # Errors
    ///
    /// Returns `SendError(())` when the receiver count is zero at the time of
    /// the check.
    pub fn mark_changed(&self) -> Result<(), SendError<()>> {
        if self.is_closed() {
            return Err(SendError(()));
        }
        self.bump_version_and_notify();
        Ok(())
    }

    /// Returns `true` when no receiver is currently counted as alive.
    pub fn is_closed(&self) -> bool {
        self.receiver_count.load(Acquire) == 0
    }

    /// Increments the shared version counter and then wakes every waiter.
    fn bump_version_and_notify(&self) {
        // `fetch_add` hands back the value it replaced, so the "old" version
        // used for logging is exactly the one this call bumped (a separate
        // load could observe another sender's bump instead). The counter
        // itself wraps on overflow, so the new value is derived with
        // `wrapping_add`; a plain `+ 1` would panic in overflow-checked
        // builds once the counter wraps around.
        #[cfg_attr(not(feature = "debug-logs"), allow(unused_variables))]
        let old_version = self.version.fetch_add(1, Release);
        #[cfg(feature = "debug-logs")]
        log_debug!(
            "watch::bump_version: chan={:p}, version={} -> {}",
            Arc::as_ptr(&self.version),
            old_version,
            old_version.wrapping_add(1),
        );
        self.wake_all();
    }

    /// Wakes both kinds of waiters: first the ones in `recv_waiters`, then
    /// the ones reachable from `select_waiters`.
    fn wake_all(&self) {
        wake_all_recv_waiters(&self.recv_waiters);
        wake_select_all(&self.select_waiters);
    }
}
impl<T> Drop for Sender<T> {
    /// Decrements the shared sender count. The sender that takes the count
    /// from one to zero wakes all waiters so they can re-check the channel
    /// state and observe the disconnect.
    fn drop(&mut self) {
        let was_last_sender = self.sender_count.fetch_sub(1, AcqRel) == 1;
        if was_last_sender {
            self.wake_all();
        }
    }
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
self.sender_count.fetch_add(1, Relaxed);
Sender {
version: Arc::clone(&self.version),
value: Arc::clone(&self.value),
recv_waiters: Arc::clone(&self.recv_waiters),
select_waiters: Arc::clone(&self.select_waiters),
receiver_count: Arc::clone(&self.receiver_count),
sender_count: Arc::clone(&self.sender_count),
}
}
}
impl<T: Send + 'static> crate::SelectableSender for Sender<T> {
type Input = T;
fn is_ready(&self) -> bool {
true
}
fn register_select(&self, _case_id: usize, _selected: Arc<AtomicUsize>) {}
fn abort_select(&self, _selected: &Arc<AtomicUsize>) {}
fn complete_send(&self, value: T) -> Result<(), crate::SendError<T>> {
self.send(value)
}
}
/// Receiving half of the watch channel created by `channel`.
///
/// The `Arc` fields are shared with the `Sender` and with other receiver
/// clones; `cursor_version` and `cursor_armed` are private to each receiver.
pub struct Receiver<T> {
/// Shared version counter; compared against a baseline to detect changes.
version: Arc<AtomicUsize>,
/// Shared slot read by `borrow` / `borrow_arc`.
value: Arc<ArcSwapOption<T>>,
/// Shared list that `await_change_from` registers a waiter with before parking.
recv_waiters: RecvWaiterList,
/// Shared select-waiter pointer used by `register_select` / `abort_select`
/// (starts out null; presumably the head of a linked list — confirm in `waiter`).
select_waiters: Arc<AtomicPtr<SelectWaiter>>,
/// This receiver's baseline version, written by `arm_cursor` and `await_change_from`.
cursor_version: AtomicUsize,
/// Whether `cursor_version` currently holds a baseline captured by `arm_cursor`.
cursor_armed: AtomicBool,
/// Number of live receivers; incremented in `Clone`, decremented in `Drop`.
receiver_count: Arc<AtomicUsize>,
/// Number of live senders; zero means the channel is disconnected.
sender_count: Arc<AtomicUsize>,
}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
self.receiver_count.fetch_add(1, Relaxed);
Receiver {
version: Arc::clone(&self.version),
value: Arc::clone(&self.value),
recv_waiters: Arc::clone(&self.recv_waiters),
select_waiters: Arc::clone(&self.select_waiters),
cursor_version: AtomicUsize::new(self.cursor_version.load(Relaxed)),
cursor_armed: AtomicBool::new(self.cursor_armed.load(Relaxed)),
receiver_count: Arc::clone(&self.receiver_count),
sender_count: Arc::clone(&self.sender_count),
}
}
}
impl<T> Receiver<T> {
    /// Reads the current channel version and, if the cursor is not armed yet,
    /// records that version as this receiver's baseline and arms the cursor.
    ///
    /// Always returns the version that was just read, whether or not the
    /// cursor was already armed.
    fn arm_cursor(&self) -> usize {
        let current = self.version.load(Acquire);
        let already_armed = self.cursor_armed.load(Relaxed);
        if already_armed {
            return current;
        }
        self.cursor_version.store(current, Relaxed);
        self.cursor_armed.store(true, Relaxed);
        current
    }

    /// Blocks the calling thread until the channel version differs from
    /// `baseline` or every sender is gone.
    ///
    /// On a version change the cursor is moved to the observed version,
    /// disarmed, and the observed version is returned. If the version is
    /// still `baseline` and the sender count is zero, the cursor is disarmed
    /// and `RecvError::Disconnected` is returned.
    fn await_change_from(&self, baseline: usize) -> Result<usize, RecvError> {
        #[cfg(feature = "debug-logs")]
        let chan_id = Arc::as_ptr(&self.version);
        // One selection cell, shared by every waiter this call registers.
        let selection = Arc::new(AtomicUsize::new(UNSELECTED));
        loop {
            let observed = self.version.load(Acquire);
            if observed != baseline {
                self.cursor_version.store(observed, Relaxed);
                self.cursor_armed.store(false, Relaxed);
                #[cfg(feature = "debug-logs")]
                log_debug!(
                    "watch::await_change_from: chan={:p}, observed version={} -> {}",
                    chan_id,
                    baseline,
                    observed
                );
                break Ok(observed);
            }
            if self.is_closed() {
                self.cursor_armed.store(false, Relaxed);
                #[cfg(feature = "debug-logs")]
                log_debug!(
                    "watch::await_change_from: chan={:p}, disconnected while waiting",
                    chan_id
                );
                break Err(RecvError::Disconnected);
            }
            // Register before the final re-check below. The guard lives until
            // the end of this iteration, so a fresh waiter is registered on
            // every pass through the loop.
            // NOTE(review): presumably the waiter captures the current thread
            // so the wake functions can unpark it — confirm in `waiter`.
            let _registration = RecvWaiterGuard::register(
                RecvWaiter::new(0, Arc::clone(&selection)),
                &self.recv_waiters,
            );
            #[cfg(feature = "debug-logs")]
            log_debug!(
                "watch::await_change_from: chan={:p}, waiting on version={}",
                chan_id,
                baseline
            );
            // Re-check after registering: only park if nothing happened in
            // the meantime; otherwise loop around and take an exit above.
            let unchanged = self.version.load(Acquire) == baseline;
            if unchanged && !self.is_closed() {
                thread::park();
            }
        }
    }

    /// Returns a shared handle to the current value, or `None` if nothing has
    /// been sent yet.
    pub fn borrow_arc(&self) -> Option<Arc<T>> {
        self.value.load_full()
    }

    /// Returns a clone of the current value, or `None` if nothing has been
    /// sent yet.
    pub fn borrow(&self) -> Option<T>
    where
        T: Clone,
    {
        #[cfg(feature = "debug-logs")]
        let chan_id = Arc::as_ptr(&self.version);
        #[cfg(feature = "debug-logs")]
        let logged_version = self.version.load(Relaxed);
        let snapshot = self.borrow_arc().map(|held| T::clone(&held));
        #[cfg(feature = "debug-logs")]
        log_debug!("watch::borrow: chan={:p}, version={}", chan_id, logged_version);
        snapshot
    }

    /// Returns `true` when no sender is currently counted as alive.
    pub fn is_closed(&self) -> bool {
        self.sender_count.load(Acquire) == 0
    }

    /// Blocks until the channel version moves past the version read at the
    /// start of this call.
    ///
    /// # Errors
    ///
    /// Returns `RecvError::Disconnected` if no sender is alive when the call
    /// starts, or if all senders go away before a change is observed.
    pub fn changed(&self) -> Result<(), RecvError> {
        #[cfg(feature = "debug-logs")]
        let chan_id = Arc::as_ptr(&self.version);
        if self.is_closed() {
            self.cursor_armed.store(false, Relaxed);
            return Err(RecvError::Disconnected);
        }
        let baseline = self.arm_cursor();
        #[cfg(feature = "debug-logs")]
        log_debug!(
            "watch::changed: chan={:p}, current_version={}",
            chan_id,
            baseline
        );
        match self.await_change_from(baseline) {
            Ok(_) => Ok(()),
            Err(err) => Err(err),
        }
    }

    /// Reports whether a receive would complete without waiting: either all
    /// senders are gone, or the current version differs from the cursor.
    ///
    /// Side effect: arms the cursor if it was not armed yet.
    pub(crate) fn is_ready(&self) -> bool {
        if self.is_closed() {
            return true;
        }
        let current = self.arm_cursor();
        current != self.cursor_version.load(Relaxed)
    }

    /// Registers a select waiter for `case_id`, unless the channel is already
    /// disconnected or has already changed relative to the cursor.
    pub(crate) fn register_select(&self, case_id: usize, selected: Arc<AtomicUsize>) {
        log_debug!(
            "watch::register_select: chan={:p}, case_id={}",
            Arc::as_ptr(&self.version),
            case_id
        );
        let current = self.arm_cursor();
        // NOTE(review): a change landing between this check and the push is
        // presumably caught by the caller re-checking readiness — confirm.
        let still_pending = self.sender_count.load(Acquire) != 0
            && current == self.cursor_version.load(Relaxed);
        if still_pending {
            let waiter = SelectWaiter::alloc(case_id, selected);
            push_select_waiter(waiter, &self.select_waiters);
        }
    }

    /// Removes the select waiters associated with `selected` by delegating to
    /// `abort_select_waiters`.
    pub(crate) fn abort_select(&self, selected: &Arc<AtomicUsize>) {
        log_debug!("watch::abort_select: chan={:p}", Arc::as_ptr(&self.version));
        abort_select_waiters(&self.select_waiters, selected);
    }

    /// Completes a selected receive by waiting for a change relative to the
    /// stored cursor version.
    pub(crate) fn complete_recv(&self) -> Result<(), RecvError> {
        let baseline = self.cursor_version.load(Relaxed);
        self.await_change_from(baseline).map(drop)
    }
}
impl<T> Drop for Receiver<T> {
    /// Decrements the shared receiver count. The receiver that takes the
    /// count from one to zero drains the select-waiter list.
    fn drop(&mut self) {
        let receivers_before = self.receiver_count.fetch_sub(1, AcqRel);
        if receivers_before != 1 {
            return;
        }
        drain_select_waiters(&self.select_waiters);
    }
}
// NOTE(review): `impl_selectable_receiver!` is defined outside this file. It
// presumably implements the crate's selectable-receiver trait for
// `Receiver<T>` (with `()` as the received type) by delegating to `is_ready`,
// `register_select`, `abort_select` and `complete_recv` — confirm against the
// macro definition.
impl_selectable_receiver!([T] Receiver<T>, ());
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// A sent value is visible through `borrow` on the original receiver and
    /// on a clone created afterwards.
    #[test]
    fn watch_channel_broadcasts_changes() {
        let (sender, receiver) = channel::<i32>();
        assert!(receiver.borrow().is_none());

        sender.send(42).unwrap();
        assert_eq!(receiver.borrow(), Some(42));

        let second = receiver.clone();
        sender.send(100).unwrap();
        assert_eq!(receiver.borrow(), Some(100));
        assert_eq!(second.borrow(), Some(100));
    }

    /// `changed` blocks until a send that happens after the call started.
    #[test]
    fn watch_changed_waits_for_update() {
        let (sender, receiver) = channel::<i32>();
        sender.send(1).unwrap();

        let waiter = thread::spawn(move || {
            receiver.changed().unwrap();
            receiver.borrow().unwrap()
        });

        thread::sleep(Duration::from_millis(10));
        sender.send(2).unwrap();
        assert_eq!(waiter.join().unwrap(), 2);
    }

    /// Completing a receive on one receiver does not consume the pending
    /// change seen by its clone.
    #[test]
    fn clones_have_independent_cursors() {
        let (sender, original) = channel::<i32>();
        sender.send(1).unwrap();
        let cloned = original.clone();

        assert!(!original.is_ready());
        assert!(!cloned.is_ready());

        sender.send(2).unwrap();
        assert!(original.is_ready());
        assert!(cloned.is_ready());

        assert_eq!(original.complete_recv(), Ok(()));
        assert!(!original.is_ready());
        assert!(cloned.is_ready());

        sender.send(3).unwrap();
        assert!(original.is_ready());
        assert_eq!(original.complete_recv(), Ok(()));
    }

    /// With no receiver left, `send` hands the value back in the error.
    #[test]
    fn send_returns_err_when_no_receivers() {
        let (sender, receiver) = channel::<i32>();
        assert!(!sender.is_closed());

        drop(receiver);
        assert!(sender.is_closed());
        assert_eq!(sender.send(42), Err(SendError(42)));
    }

    /// `mark_changed` fails without receivers even if nothing was ever sent.
    #[test]
    fn mark_changed_returns_err_when_no_receivers_no_value() {
        let (sender, receiver) = channel::<i32>();
        drop(receiver);
        assert_eq!(sender.mark_changed(), Err(SendError(())));
    }

    /// `mark_changed` wakes a blocked receiver but leaves the value as is.
    #[test]
    fn mark_changed_wakes_without_updating_value() {
        let (sender, receiver) = channel::<&str>();
        sender.send("initial").unwrap();

        let blocked = receiver.clone();
        let waiter = thread::spawn(move || {
            blocked.changed().unwrap();
            blocked.borrow_arc().map(|held| *held)
        });

        thread::sleep(Duration::from_millis(10));
        sender.mark_changed().unwrap();

        assert_eq!(waiter.join().unwrap(), Some("initial"));
        assert_eq!(receiver.borrow(), Some("initial"));

        drop(receiver);
        assert_eq!(sender.mark_changed(), Err(SendError(())));
    }

    /// One send wakes every receiver blocked in `changed`.
    #[test]
    fn multiple_receivers_all_wake_on_change() {
        let (sender, first) = channel::<i32>();
        sender.send(0).unwrap();
        let second = first.clone();
        let third = first.clone();

        let first_waiter = thread::spawn(move || first.changed());
        let second_waiter = thread::spawn(move || second.changed());
        let third_waiter = thread::spawn(move || third.changed());

        thread::sleep(Duration::from_millis(20));
        sender.send(1).unwrap();

        assert_eq!(first_waiter.join().unwrap(), Ok(()));
        assert_eq!(second_waiter.join().unwrap(), Ok(()));
        assert_eq!(third_waiter.join().unwrap(), Ok(()));
    }

    /// Dropping the only sender unblocks `changed` with `Disconnected`.
    #[test]
    fn changed_returns_err_when_sender_drops() {
        let (sender, receiver) = channel::<i32>();
        sender.send(1).unwrap();

        let waiter = thread::spawn(move || receiver.changed());
        thread::sleep(Duration::from_millis(20));
        drop(sender);

        assert_eq!(
            waiter.join().unwrap(),
            Err(crate::error::RecvError::Disconnected)
        );
    }

    /// Every blocked receiver clone observes the sender going away.
    #[test]
    fn multi_receiver_all_detect_sender_disconnect() {
        const N: usize = 4;
        let (sender, receiver) = channel::<i32>();
        sender.send(0).unwrap();

        let waiters: Vec<_> = (0..N)
            .map(|_| {
                let blocked = receiver.clone();
                thread::spawn(move || blocked.changed())
            })
            .collect();

        thread::sleep(Duration::from_millis(20));
        drop(sender);

        for waiter in waiters {
            assert_eq!(
                waiter.join().unwrap(),
                Err(crate::error::RecvError::Disconnected)
            );
        }
    }

    /// `borrow` returns the latest value right after each send.
    #[test]
    fn borrow_reflects_latest_send_immediately() {
        let (sender, receiver) = channel::<i32>();
        assert!(receiver.borrow().is_none());

        sender.send(7).unwrap();
        assert_eq!(receiver.borrow(), Some(7));

        sender.send(8).unwrap();
        assert_eq!(receiver.borrow(), Some(8));
    }

    /// `borrow_arc` returns the latest value right after each send.
    #[test]
    fn borrow_arc_reflects_latest_send_immediately() {
        let (sender, receiver) = channel::<i32>();
        assert!(receiver.borrow_arc().is_none());

        sender.send(7).unwrap();
        assert_eq!(*receiver.borrow_arc().unwrap(), 7);

        sender.send(8).unwrap();
        assert_eq!(*receiver.borrow_arc().unwrap(), 8);
    }

    /// The receiver reports closed only once the last sender clone is gone.
    #[test]
    fn is_closed_reflects_sender_liveness() {
        let (first_sender, receiver) = channel::<i32>();
        assert!(!receiver.is_closed());

        let second_sender = first_sender.clone();
        drop(first_sender);
        assert!(!receiver.is_closed());

        drop(second_sender);
        assert!(receiver.is_closed());
    }

    /// After a burst of sends, `borrow` shows the last value sent.
    #[test]
    fn rapid_sends_borrow_shows_latest() {
        let (sender, receiver) = channel::<u32>();
        for value in 0..100u32 {
            sender.send(value).unwrap();
        }
        assert_eq!(receiver.borrow(), Some(99));
    }

    /// A `recv` arm in `select!` fires when the value changes.
    #[test]
    fn select_recv_arm_fires_on_change() {
        use crate::select;

        let (sender, receiver) = channel::<i32>();
        sender.send(1).unwrap();

        thread::spawn(move || {
            thread::sleep(Duration::from_millis(15));
            sender.send(2).unwrap();
        });

        select! {
            recv(receiver) -> res => assert_eq!(res, Ok(())),
            default(Duration::from_millis(200)) => panic!("timeout"),
        }
    }

    /// With several senders, the most recent send determines the value.
    #[test]
    fn multiple_senders_last_write_wins() {
        let (first_sender, receiver) = channel::<i32>();
        let second_sender = first_sender.clone();

        first_sender.send(10).unwrap();
        second_sender.send(20).unwrap();
        assert_eq!(receiver.borrow(), Some(20));
    }

    /// A sender clone keeps the channel open after the original is dropped.
    #[test]
    fn sender_clone_extends_liveness() {
        let (first_sender, receiver) = channel::<i32>();
        let second_sender = first_sender.clone();

        drop(first_sender);
        assert!(!receiver.is_closed());
        assert!(!second_sender.is_closed());

        second_sender.send(42).unwrap();
        assert_eq!(receiver.borrow(), Some(42));

        drop(second_sender);
        assert!(receiver.is_closed());
        assert_eq!(
            receiver.changed(),
            Err(crate::error::RecvError::Disconnected)
        );
    }
}