use std::{
sync::Arc,
sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering},
thread,
};
use arc_swap::ArcSwapOption;
use crate::{
error::{RecvError, SendError},
waiter::{
RecvWaiter, RecvWaiterGuard, RecvWaiterList, SelectWaiter, UNSELECTED,
abort_select_waiters, drain_select_waiters, new_recv_waiter_list, push_select_waiter,
wake_all_recv_waiters, wake_select_all,
},
};
/// Creates a watch channel whose slot starts out empty.
///
/// The [`Sender`] publishes values (or bare change notifications via
/// [`Sender::mark_changed`]); the [`Receiver`] reads the latest value and can
/// block until the version counter moves. Both halves are cloneable and every
/// clone shares the same slot, version counter and waiter lists.
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    let version = Arc::new(AtomicUsize::new(0));
    let value = Arc::new(ArcSwapOption::empty());
    let recv_waiters = new_recv_waiter_list();
    let select_waiters = Arc::new(AtomicPtr::new(std::ptr::null_mut()));
    let receiver_count = Arc::new(AtomicUsize::new(1));
    // Receivers also hold `version`, so its strong count cannot distinguish
    // "all senders gone" from "several receivers still alive"; count senders
    // explicitly instead.
    let sender_count = Arc::new(AtomicUsize::new(1));
    log_debug!("watch::channel: created chan={:p}", Arc::as_ptr(&version));
    (
        Sender {
            version: Arc::clone(&version),
            value: Arc::clone(&value),
            recv_waiters: Arc::clone(&recv_waiters),
            select_waiters: Arc::clone(&select_waiters),
            receiver_count: Arc::clone(&receiver_count),
            sender_count: Arc::clone(&sender_count),
        },
        Receiver {
            version,
            value,
            recv_waiters,
            select_waiters,
            // Wait state is shared (via Arc) by every clone of this receiver.
            wait_version: Arc::new(AtomicUsize::new(0)),
            wait_armed: Arc::new(AtomicBool::new(true)),
            receiver_count,
            sender_count,
        },
    )
}

/// An owned snapshot of the channel's value, returned by [`Receiver::borrow`].
///
/// Dereferences to `Option<T>`; `None` means nothing has been sent yet.
pub struct Ref<T> {
    snapshot: Option<T>,
}

impl<T> std::ops::Deref for Ref<T> {
    type Target = Option<T>;

    fn deref(&self) -> &Self::Target {
        &self.snapshot
    }
}

/// The publishing half of a watch channel.
pub struct Sender<T> {
    version: Arc<AtomicUsize>,
    value: Arc<ArcSwapOption<T>>,
    recv_waiters: RecvWaiterList,
    select_waiters: Arc<AtomicPtr<SelectWaiter>>,
    receiver_count: Arc<AtomicUsize>,
    /// Number of live `Sender` handles; receivers treat zero as disconnected.
    sender_count: Arc<AtomicUsize>,
}

impl<T> Sender<T> {
    /// Stores `value` as the latest value, bumps the version counter and wakes
    /// every blocked receiver and pending select.
    ///
    /// # Errors
    /// Returns the value inside [`SendError`] if every receiver has already
    /// been dropped (checked before anything is stored).
    pub fn send(&self, value: T) -> Result<(), SendError<T>> {
        if self.receiver_count.load(Ordering::Acquire) == 0 {
            return Err(SendError(value));
        }
        #[cfg(feature = "debug-logs")]
        let chan_id = Arc::as_ptr(&self.version);
        #[cfg(feature = "debug-logs")]
        let old_version = self.version.load(Ordering::SeqCst);
        // Publish the value before bumping the version so a receiver that sees
        // the new version also sees the new value.
        self.value.store(Some(Arc::new(value)));
        #[cfg_attr(not(feature = "debug-logs"), allow(unused_variables))]
        let new_version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
        #[cfg(feature = "debug-logs")]
        log_debug!(
            "watch::send: chan={:p}, version={} -> {}",
            chan_id,
            old_version,
            new_version,
        );
        wake_all_recv_waiters(&self.recv_waiters, UNSELECTED);
        wake_select_all(&self.select_waiters);
        Ok(())
    }

    /// Bumps the version counter and wakes waiters without replacing the
    /// stored value.
    ///
    /// # Errors
    /// Returns `SendError(())` if every receiver has already been dropped.
    pub fn mark_changed(&self) -> Result<(), SendError<()>> {
        if self.receiver_count.load(Ordering::Acquire) == 0 {
            return Err(SendError(()));
        }
        #[cfg(feature = "debug-logs")]
        let chan_id = Arc::as_ptr(&self.version);
        #[cfg(feature = "debug-logs")]
        let old_version = self.version.load(Ordering::SeqCst);
        #[cfg_attr(not(feature = "debug-logs"), allow(unused_variables))]
        let new_version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
        #[cfg(feature = "debug-logs")]
        log_debug!(
            "watch::mark_changed: chan={:p}, version={} -> {}",
            chan_id,
            old_version,
            new_version,
        );
        wake_all_recv_waiters(&self.recv_waiters, UNSELECTED);
        wake_select_all(&self.select_waiters);
        Ok(())
    }

    /// Returns `true` once every [`Receiver`] has been dropped.
    pub fn is_closed(&self) -> bool {
        self.receiver_count.load(Ordering::Acquire) == 0
    }
}

impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        // The last sender leaving is a state change receivers must observe:
        // wake everyone so blocked `changed()` calls and pending selects
        // re-check and report `Disconnected` instead of sitting out the park
        // timeout. SeqCst pairs with the receivers' SeqCst load of the count.
        if self.sender_count.fetch_sub(1, Ordering::SeqCst) == 1 {
            wake_all_recv_waiters(&self.recv_waiters, UNSELECTED);
            wake_select_all(&self.select_waiters);
        }
    }
}

impl<T> Clone for Sender<T> {
    fn clone(&self) -> Self {
        // Relaxed is enough: we hold a live handle, so the count cannot be
        // racing towards zero while we increment it.
        self.sender_count.fetch_add(1, Ordering::Relaxed);
        Sender {
            version: Arc::clone(&self.version),
            value: Arc::clone(&self.value),
            recv_waiters: Arc::clone(&self.recv_waiters),
            select_waiters: Arc::clone(&self.select_waiters),
            receiver_count: Arc::clone(&self.receiver_count),
            sender_count: Arc::clone(&self.sender_count),
        }
    }
}

/// Watch sends never block, so the send side of a select arm is always ready
/// and select registration is a no-op.
impl<T: Send + 'static> crate::SelectableSender for Sender<T> {
    type Input = T;

    fn is_ready(&self) -> bool {
        true
    }

    fn register_select(&self, _case_id: usize, _selected: Arc<AtomicUsize>) {}

    fn abort_select(&self, _selected: &Arc<AtomicUsize>) {}

    fn complete_send(&self, value: T) -> Result<(), crate::SendError<T>> {
        self.send(value)
    }
}

/// The observing half of a watch channel.
pub struct Receiver<T> {
    version: Arc<AtomicUsize>,
    value: Arc<ArcSwapOption<T>>,
    recv_waiters: RecvWaiterList,
    select_waiters: Arc<AtomicPtr<SelectWaiter>>,
    /// Version a pending wait is measured against. Shared by all clones.
    wait_version: Arc<AtomicUsize>,
    /// Whether `wait_version` holds the baseline of a pending wait. Shared by all clones.
    wait_armed: Arc<AtomicBool>,
    receiver_count: Arc<AtomicUsize>,
    /// Number of live `Sender` handles; zero means disconnected.
    sender_count: Arc<AtomicUsize>,
}

impl<T> Clone for Receiver<T> {
    fn clone(&self) -> Self {
        self.receiver_count.fetch_add(1, Ordering::Relaxed);
        Receiver {
            version: Arc::clone(&self.version),
            value: Arc::clone(&self.value),
            recv_waiters: Arc::clone(&self.recv_waiters),
            select_waiters: Arc::clone(&self.select_waiters),
            wait_version: Arc::clone(&self.wait_version),
            wait_armed: Arc::clone(&self.wait_armed),
            receiver_count: Arc::clone(&self.receiver_count),
            sender_count: Arc::clone(&self.sender_count),
        }
    }
}

impl<T> Receiver<T> {
    /// Returns `true` once every [`Sender`] handle has been dropped.
    fn senders_dropped(&self) -> bool {
        self.sender_count.load(Ordering::SeqCst) == 0
    }

    /// Arms the shared wait state at the current version and returns it.
    ///
    /// # Errors
    /// Returns [`RecvError::Disconnected`] (and disarms) if no senders remain.
    fn snapshot_wait_version(&self) -> Result<usize, RecvError> {
        if self.senders_dropped() {
            self.wait_armed.store(false, Ordering::SeqCst);
            return Err(RecvError::Disconnected);
        }
        let v = self.version.load(Ordering::SeqCst);
        self.wait_version.store(v, Ordering::SeqCst);
        self.wait_armed.store(true, Ordering::SeqCst);
        Ok(v)
    }

    /// Non-blocking check used by [`Self::await_change_from`].
    ///
    /// Returns `Some(Ok(version))` if the version has moved past `baseline`
    /// (recording it and disarming the wait state), `Some(Err(Disconnected))`
    /// if it has not moved and no senders remain, or `None` to keep waiting.
    fn poll_change(&self, baseline: usize) -> Option<Result<usize, RecvError>> {
        #[cfg(feature = "debug-logs")]
        let state_id = Arc::as_ptr(&self.version);
        // Read the sender count *before* the version. A sender bumps the
        // version before it can be dropped, and both use SeqCst. So if zero
        // senders are seen here, the version load below observes every update
        // they made, and a final send is reported rather than lost.
        let disconnected = self.senders_dropped();
        let cur = self.version.load(Ordering::SeqCst);
        if cur != baseline {
            self.wait_version.store(cur, Ordering::SeqCst);
            self.wait_armed.store(false, Ordering::SeqCst);
            #[cfg(feature = "debug-logs")]
            log_debug!(
                "watch::await_change_from: chan={:p}, observed version={} -> {}",
                state_id,
                baseline,
                cur
            );
            return Some(Ok(cur));
        }
        if disconnected {
            self.wait_armed.store(false, Ordering::SeqCst);
            #[cfg(feature = "debug-logs")]
            log_debug!(
                "watch::await_change_from: chan={:p}, disconnected while waiting",
                state_id
            );
            return Some(Err(RecvError::Disconnected));
        }
        None
    }

    /// Blocks until the version differs from `baseline` and returns the new
    /// version.
    ///
    /// # Errors
    /// Returns [`RecvError::Disconnected`] if every sender is dropped before a
    /// change is observed.
    fn await_change_from(&self, baseline: usize) -> Result<usize, RecvError> {
        #[cfg(feature = "debug-logs")]
        let state_id = Arc::as_ptr(&self.version);
        let sel = Arc::new(AtomicUsize::new(UNSELECTED));
        loop {
            // Fast path: avoid registering a waiter if we can answer already.
            if let Some(outcome) = self.poll_change(baseline) {
                return outcome;
            }
            let waiter = RecvWaiter::new(0, Arc::clone(&sel));
            let _guard = RecvWaiterGuard::register(waiter, &self.recv_waiters);
            // Re-check *after* registering. A send or last-sender drop racing
            // with registration is then either seen here or finds the waiter
            // and unparks us; previously it was only noticed after the timeout.
            if let Some(outcome) = self.poll_change(baseline) {
                return outcome;
            }
            #[cfg(feature = "debug-logs")]
            log_debug!(
                "watch::await_change_from: chan={:p}, waiting on version={}",
                state_id,
                baseline
            );
            // The timeout is a safety net; wakeups normally come from
            // send/mark_changed or the last sender's drop.
            thread::park_timeout(std::time::Duration::from_secs(1));
        }
    }

    /// Returns a shared handle to the latest value, or `None` if nothing has
    /// been sent yet.
    pub fn borrow_arc(&self) -> Option<Arc<T>> {
        self.value.load_full()
    }

    /// Returns an owned clone of the latest value wrapped in a [`Ref`].
    pub fn borrow(&self) -> Ref<T>
    where
        T: Clone,
    {
        #[cfg(feature = "debug-logs")]
        let state_id = Arc::as_ptr(&self.version);
        #[cfg(feature = "debug-logs")]
        let version = self.version.load(Ordering::SeqCst);
        let snapshot = self.borrow_arc().as_deref().cloned();
        #[cfg(feature = "debug-logs")]
        log_debug!("watch::borrow: chan={:p}, version={}", state_id, version);
        Ref { snapshot }
    }

    /// Blocks until the version changes after this call and returns the new
    /// version.
    ///
    /// # Errors
    /// Returns [`RecvError::Disconnected`] if no senders remain at the call,
    /// or if they all drop before a change is observed.
    pub fn changed(&self) -> Result<usize, RecvError> {
        #[cfg(feature = "debug-logs")]
        let state_id = Arc::as_ptr(&self.version);
        let current_version = self.snapshot_wait_version()?;
        #[cfg(feature = "debug-logs")]
        log_debug!(
            "watch::changed: chan={:p}, current_version={}",
            state_id,
            current_version
        );
        self.await_change_from(current_version)
    }

    /// Select readiness probe.
    ///
    /// Ready when all senders are gone or the version differs from the armed
    /// baseline. If the shared wait state is disarmed, it is re-armed at the
    /// current version and `false` is returned.
    pub(crate) fn is_ready(&self) -> bool {
        if self.senders_dropped() {
            return true;
        }
        let cur = self.version.load(Ordering::SeqCst);
        if !self.wait_armed.load(Ordering::SeqCst) {
            self.wait_version.store(cur, Ordering::SeqCst);
            self.wait_armed.store(true, Ordering::SeqCst);
            return false;
        }
        cur != self.wait_version.load(Ordering::SeqCst)
    }

    /// Registers a select waiter for `case_id`, arming the wait state if
    /// needed.
    ///
    /// Skips registration when the arm is already ready (disconnected or
    /// version moved).
    pub(crate) fn register_select(&self, case_id: usize, selected: Arc<AtomicUsize>) {
        log_trace!(
            "watch::register_select: chan={:p}, case_id={}",
            Arc::as_ptr(&self.version),
            case_id
        );
        let cur = self.version.load(Ordering::SeqCst);
        if !self.wait_armed.load(Ordering::SeqCst) {
            self.wait_version.store(cur, Ordering::SeqCst);
            self.wait_armed.store(true, Ordering::SeqCst);
        }
        if self.senders_dropped() || cur != self.wait_version.load(Ordering::SeqCst) {
            return;
        }
        let ptr = SelectWaiter::alloc(case_id, selected);
        push_select_waiter(ptr, &self.select_waiters);
    }

    /// Removes select waiters associated with `selected`.
    pub(crate) fn abort_select(&self, selected: &Arc<AtomicUsize>) {
        log_trace!("watch::abort_select: chan={:p}", Arc::as_ptr(&self.version));
        abort_select_waiters(&self.select_waiters, selected);
    }

    /// Completes a select arm.
    ///
    /// Waits from the armed baseline if the wait state is armed; otherwise it
    /// waits from the current version.
    ///
    /// # Errors
    /// Returns [`RecvError::Disconnected`] when no senders remain and no
    /// change is pending.
    pub fn complete_changed(&self) -> Result<usize, RecvError> {
        let baseline = if self.wait_armed.load(Ordering::SeqCst) {
            self.wait_version.load(Ordering::SeqCst)
        } else {
            self.snapshot_wait_version()?
        };
        self.await_change_from(baseline)
    }
}
impl<T> Drop for Receiver<T> {
    /// Decrements the live-receiver count. The receiver that brings it to zero
    /// passes the channel's select-waiter list to `drain_select_waiters`
    /// (presumably to release any waiters still registered — confirm in `waiter`).
    fn drop(&mut self) {
        let previous = self.receiver_count.fetch_sub(1, Ordering::AcqRel);
        let was_last_receiver = previous == 1;
        if was_last_receiver {
            drain_select_waiters(&self.select_waiters);
        }
    }
}
impl<T> crate::SelectableReceiver for Receiver<T> {
type Output = usize;
fn is_ready(&self) -> bool {
self.is_ready()
}
fn register_select(
&self,
case_id: usize,
selected: std::sync::Arc<std::sync::atomic::AtomicUsize>,
) {
self.register_select(case_id, selected)
}
fn abort_select(&self, selected: &std::sync::Arc<std::sync::atomic::AtomicUsize>) {
self.abort_select(selected)
}
fn complete(&self) -> Result<Self::Output, crate::RecvError> {
self.complete_changed()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Short sleep helper giving a spawned waiter time to block first.
    fn ms(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    #[test]
    fn watch_channel_broadcasts_changes() {
        let (sender, first) = channel::<i32>();
        assert!(first.borrow().is_none());

        sender.send(42).unwrap();
        assert_eq!(*first.borrow(), Some(42));

        let second = first.clone();
        sender.send(100).unwrap();
        for receiver in [&first, &second] {
            assert_eq!(*receiver.borrow(), Some(100));
        }
    }

    #[test]
    fn watch_changed_waits_for_update() {
        let (sender, receiver) = channel::<i32>();
        sender.send(1).unwrap();
        let waiter = thread::spawn(move || {
            let observed = receiver.changed().unwrap();
            assert_eq!(observed, 2);
            receiver.borrow().unwrap()
        });
        thread::sleep(ms(10));
        sender.send(2).unwrap();
        assert_eq!(waiter.join().unwrap(), 2);
    }

    #[test]
    fn select_state_is_shared_across_clones() {
        let (sender, original) = channel::<i32>();
        sender.send(1).unwrap();
        let sibling = original.clone();

        assert!(original.is_ready());
        assert!(sibling.is_ready());

        // Completing on one handle disarms the wait state shared with the other.
        assert_eq!(original.complete_changed(), Ok(1));
        assert!(!sibling.is_ready());

        sender.send(2).unwrap();
        assert!(original.is_ready());
        assert!(sibling.is_ready());
    }

    #[test]
    fn send_returns_err_when_no_receivers() {
        let (sender, receiver) = channel::<i32>();
        assert!(!sender.is_closed());
        drop(receiver);
        assert!(sender.is_closed());
        assert_eq!(sender.send(42), Err(SendError(42)));
    }

    #[test]
    fn mark_changed_wakes_without_updating_value() {
        let (sender, receiver) = channel::<&str>();
        sender.send("initial").unwrap();
        let watcher = receiver.clone();
        let waiter = thread::spawn(move || {
            watcher.changed().unwrap();
            let latest = watcher.borrow_arc();
            latest.map(|shared| *shared)
        });
        thread::sleep(ms(10));
        sender.mark_changed().unwrap();
        assert_eq!(waiter.join().unwrap(), Some("initial"));
        assert_eq!(*receiver.borrow(), Some("initial"));

        drop(receiver);
        assert_eq!(sender.mark_changed(), Err(SendError(())));
    }

    #[test]
    fn multiple_receivers_all_wake_on_change() {
        let (sender, rx1) = channel::<i32>();
        sender.send(0).unwrap();
        let rx2 = rx1.clone();
        let rx3 = rx1.clone();
        let waiters: Vec<_> = vec![rx1, rx2, rx3]
            .into_iter()
            .map(|receiver| thread::spawn(move || receiver.changed().unwrap()))
            .collect();
        thread::sleep(ms(20));
        sender.send(1).unwrap();
        for waiter in waiters {
            assert_eq!(waiter.join().unwrap(), 2);
        }
    }

    #[test]
    fn changed_returns_err_when_sender_drops() {
        let (sender, receiver) = channel::<i32>();
        sender.send(1).unwrap();
        let waiter = thread::spawn(move || receiver.changed());
        thread::sleep(ms(20));
        drop(sender);
        assert_eq!(waiter.join().unwrap(), Err(RecvError::Disconnected));
    }

    #[test]
    fn borrow_reflects_latest_send_immediately() {
        let (sender, receiver) = channel::<i32>();
        assert!(receiver.borrow_arc().is_none());
        for expected in 7..=8 {
            sender.send(expected).unwrap();
            assert_eq!(*receiver.borrow_arc().unwrap(), expected);
        }
    }

    #[test]
    fn rapid_sends_borrow_shows_latest() {
        let (sender, receiver) = channel::<u32>();
        (0..100u32).for_each(|n| sender.send(n).unwrap());
        assert_eq!(*receiver.borrow_arc().unwrap(), 99);
    }

    #[test]
    fn select_recv_arm_fires_on_change() {
        use crate::select;
        let (sender, receiver) = channel::<i32>();
        thread::spawn(move || {
            thread::sleep(ms(15));
            sender.send(1).unwrap();
        });
        select! {
            recv(receiver) -> version => assert_eq!(version, Ok(1)),
            default(Duration::from_millis(200)) => panic!("timeout"),
        }
    }

    #[test]
    fn select_recv_arm_fires_for_initial_value() {
        use crate::select;
        let (sender, receiver) = channel::<i32>();
        sender.send(42).unwrap();
        select! {
            recv(receiver) -> version => {
                assert_eq!(version, Ok(1));
                assert_eq!(*receiver.borrow_arc().unwrap(), 42);
            },
            default(Duration::from_millis(50)) => panic!("initial value not delivered"),
        }
    }
}