#![doc = include_str!("../README.md")]
#![cfg_attr(docsrs, feature(doc_cfg))]
pub mod payload;
use payload::AtomicPrimitive;
#[doc(inline)]
pub use payload::Payload;
use std::{
future::Future,
pin::Pin,
sync::{
Arc,
atomic::{AtomicU8, AtomicU32, Ordering},
},
task::{Context, Poll},
};
use tokio::sync::{Notify, futures::OwnedNotified};
/// Future returned by [`SoftCycleController::listener`].
///
/// When polled, it resolves to `Ok(payload)` once the borrowed controller has been
/// notified, reading the payload currently stored in that controller.
pub struct SoftCycleListener<'a, T>
where
    T: Payload,
{
    /// Handle obtained from the controller's shared `Notify` when the listener was created.
    notify: OwnedNotified,
    /// Controller whose status and payload are consulted on every poll.
    controller: &'a SoftCycleController<T>,
}
impl<T: Payload> Future for SoftCycleListener<'_, T> {
    /// This implementation only ever produces `Ok(payload)`.
    type Output = Result<T, ()>;

    /// Resolves once the controller has been notified.
    ///
    /// Two paths lead to completion:
    /// - `notify_waiters` was called after this listener was created. `OwnedNotified` is
    ///   guaranteed to observe that even if it had not been polled yet.
    /// - The controller was already in the "notified" state (a notification that happened
    ///   before the listener existed).
    ///
    /// Either way, the payload currently stored in the controller is returned.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: `notify` is treated as structurally pinned. `SoftCycleListener` never moves
        // the field out, never hands out `&mut OwnedNotified`, and (as far as this file
        // shows) has no `Drop` impl that could move it. `OwnedNotified` is `!Unpin`, so the
        // listener is `!Unpin` too. Once `self` is pinned, the inner future stays pinned.
        let notify = unsafe { self.as_mut().map_unchecked_mut(|s| &mut s.notify) };
        let woken = notify.poll(cx).is_ready();

        if woken || self.controller.is_notified() {
            Poll::Ready(Ok(self.controller.read_payload()))
        } else {
            Poll::Pending
        }
    }
}
// Values stored in `SoftCycleController::status`. The transitions performed by the
// controller's methods are:
//   NOT_NOTIFIED --try_notify--> STORING_PAYLOAD --> NOTIFIED
//   NOTIFIED --try_clear--> CLEARING --> NOT_NOTIFIED

/// Idle: no notification is active; `try_notify` may claim the controller.
const STATUS_NOT_NOTIFIED: u8 = 0;
/// Transient: `try_notify` has claimed the controller and is writing the payload.
const STATUS_STORING_PAYLOAD: u8 = STATUS_NOT_NOTIFIED + 1;
/// Notified: the payload is published; listeners complete and `try_clear` may run.
const STATUS_NOTIFIED: u8 = STATUS_STORING_PAYLOAD + 1;
/// Transient: `try_clear` has claimed the controller and is returning it to idle.
const STATUS_CLEARING: u8 = STATUS_NOTIFIED + 1;
/// Controller that can be notified with a payload, wakes its listeners, and can later be
/// cleared and notified again.
///
/// State is tracked with atomics. Waking is delegated to a shared tokio `Notify`.
/// `T` defaults to `()` for signal-only use.
pub struct SoftCycleController<T = ()>
where
    T: Payload,
{
    /// Shared wake-up primitive. Listeners hold owned `Notified` futures derived from it.
    notify: Arc<Notify>,
    /// Sequence number that the next successful `try_notify` will return.
    next_notify_sequence: AtomicU32,
    /// One of the `STATUS_*` constants.
    status: AtomicU8,
    /// Atomic storage for the most recently notified payload.
    payload: <T as Payload>::UnderlyingAtomic,
}
impl<T: Payload> SoftCycleController<T> {
    /// Creates a controller in the "not notified" state.
    ///
    /// The sequence counter starts at `0`, so the first successful
    /// [`try_notify`](Self::try_notify) returns `Ok(0)`. The payload slot is initialised with
    /// `UnderlyingAtomic::new_default()`.
    // NOTE(review): no `Default` impl is visible in this file; the lint is silenced rather
    // than providing one.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        Self {
            notify: Arc::new(Notify::new()),
            next_notify_sequence: AtomicU32::new(0),
            status: AtomicU8::new(STATUS_NOT_NOTIFIED),
            payload: <T as Payload>::UnderlyingAtomic::new_default(),
        }
    }

    /// Stores `payload` and signals listeners, provided the controller is idle.
    ///
    /// On success the controller moves from "not notified" to "notified". Every listener
    /// future created before this call is woken, and the notification's sequence number is
    /// returned. Sequence numbers start at `0` and increase by one per successful notify,
    /// wrapping around after `u32::MAX`.
    ///
    /// # Errors
    ///
    /// Returns `Err(payload)`, handing the payload back unchanged, when the controller is not
    /// in the "not notified" state. This happens when it is already notified, or when another
    /// `try_notify`/`try_clear` is mid-transition.
    #[must_use = "Caller must check if the operation was successful"]
    pub fn try_notify(&self, payload: T) -> Result<u32, T> {
        match self.status.compare_exchange(
            STATUS_NOT_NOTIFIED,
            STATUS_STORING_PAYLOAD,
            Ordering::AcqRel,
            Ordering::Relaxed,
        ) {
            Ok(_) => {
                // Winning the CAS makes this the only writer of the counter and the payload
                // until NOTIFIED is published below.
                let sequence_number = self.next_notify_sequence.fetch_add(1, Ordering::AcqRel);
                self.payload.store(payload.into());
                // This Release pairs with the Acquire load in `is_notified`. A listener that
                // observes NOTIFIED therefore also observes the payload written above.
                self.status.store(STATUS_NOTIFIED, Ordering::Release);
                self.notify.notify_waiters();
                Ok(sequence_number)
            }
            Err(_) => Err(payload),
        }
    }

    /// Returns the controller from "notified" to "not notified".
    ///
    /// On success, returns the sequence number of the notification that was cleared. The
    /// stored payload is left untouched.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if the controller is not currently in the "notified" state.
    // The unit error is part of the public signature and carries no information beyond
    // "not currently notified", so the lint is silenced.
    #[allow(clippy::result_unit_err)]
    pub fn try_clear(&self) -> Result<u32, ()> {
        match self.status.compare_exchange(
            STATUS_NOTIFIED,
            STATUS_CLEARING,
            Ordering::AcqRel,
            Ordering::Relaxed,
        ) {
            Ok(_) => {
                // While the status is CLEARING, `try_notify` cannot win its CAS, so the
                // counter is stable. A successful clear implies at least one prior notify.
                // `wrapping_sub` therefore mirrors the wrapping `fetch_add` in `try_notify`:
                // after the notify that returned `u32::MAX`, the counter reads 0, and the
                // cleared sequence must be `u32::MAX`. `saturating_sub` would yield 0 there.
                let sequence_number = self
                    .next_notify_sequence
                    .load(Ordering::Acquire)
                    .wrapping_sub(1);
                self.status.store(STATUS_NOT_NOTIFIED, Ordering::Release);
                Ok(sequence_number)
            }
            Err(_) => Err(()),
        }
    }

    /// Creates a future that resolves with the stored payload once the controller is
    /// notified.
    ///
    /// The future completes on the first poll if the controller is already notified.
    /// Otherwise it completes after the next successful `try_notify`, including one that
    /// happens between this call and the first poll.
    #[must_use = "Caller must await the listener to receive the signal"]
    pub fn listener(&self) -> SoftCycleListener<'_, T> {
        SoftCycleListener {
            notify: self.notify.clone().notified_owned(),
            controller: self,
        }
    }

    /// Returns `true` when the status is exactly NOTIFIED.
    ///
    /// The Acquire load pairs with the Release store in `try_notify`.
    fn is_notified(&self) -> bool {
        self.status.load(Ordering::Acquire) == STATUS_NOTIFIED
    }

    /// Reads the currently stored payload and converts it back into `T`.
    fn read_payload(&self) -> T {
        let inner = self.payload.load();
        T::from(inner)
    }
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for `SoftCycleController` and `SoftCycleListener`.
    //!
    //! Tests are grouped by name prefix (`guarantee_a_*` … `guarantee_d_*`). The remaining
    //! tests are concurrency stress and regression checks.
    use std::time::Duration;
    use std::time::Instant;

    use super::*;

    /// Sequence numbers start at 0 and advance by one per successful notify.
    #[tokio::test]
    async fn guarantee_a_try_notify_sn_from_zero() {
        let ctrl = SoftCycleController::<u32>::new();
        assert_eq!(ctrl.try_notify(10), Ok(0));
        assert_eq!(ctrl.try_clear(), Ok(0));
        assert_eq!(ctrl.try_notify(20), Ok(1));
        assert_eq!(ctrl.try_clear(), Ok(1));
        assert_eq!(ctrl.try_notify(30), Ok(2));
    }

    /// `try_clear` only succeeds while the controller is notified.
    #[tokio::test]
    async fn guarantee_a_try_clear_fails_when_not_notified() {
        let ctrl = SoftCycleController::<u32>::new();
        assert_eq!(ctrl.try_clear(), Err(()));
        assert_eq!(ctrl.try_notify(1), Ok(0));
        assert_eq!(ctrl.try_clear(), Ok(0));
        assert_eq!(ctrl.try_clear(), Err(()));
    }

    /// `try_clear` reports the sequence number of the notification it cleared.
    #[tokio::test]
    async fn guarantee_a_sn_sequence_notify_clear_interleaved() {
        let ctrl = SoftCycleController::<u32>::new();
        assert_eq!(ctrl.try_notify(100), Ok(0));
        assert_eq!(ctrl.try_clear(), Ok(0));
        assert_eq!(ctrl.try_notify(200), Ok(1));
        assert_eq!(ctrl.try_clear(), Ok(1));
        assert_eq!(ctrl.try_notify(300), Ok(2));
    }

    /// Clearing must not wait for outstanding listeners.
    #[tokio::test]
    async fn guarantee_b_try_clear_nonblocking_many_listeners() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        let mut listener_handles = Vec::new();
        for _ in 0..100 {
            let c = ctrl.clone();
            listener_handles.push(tokio::spawn(async move { c.listener().await }));
        }
        assert_eq!(ctrl.try_notify(1), Ok(0));
        let deadline = Duration::from_millis(100);
        let clear_done = tokio::time::timeout(deadline, async {
            let _ = ctrl.try_clear();
        });
        clear_done.await.expect("try_clear must not block");
        assert_eq!(ctrl.try_clear(), Err(()));
        ctrl.try_notify(2).ok();
        let mut got = 0;
        for h in listener_handles {
            if let Ok(Ok(v)) = tokio::time::timeout(Duration::from_secs(2), h).await {
                assert!(v == Ok(1) || v == Ok(2));
                got += 1;
            }
        }
        assert!(got > 0, "at least one listener should get a value");
    }

    /// Notifying must not wait for pending listeners to run.
    #[tokio::test]
    async fn guarantee_b_try_notify_nonblocking_many_listeners() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        for _ in 0..50 {
            let c = ctrl.clone();
            tokio::spawn(async move {
                let _ = c.listener().await;
            });
        }
        tokio::time::sleep(Duration::from_millis(20)).await;
        let start = Instant::now();
        let res = ctrl.try_notify(1);
        assert!(res.is_ok(), "try_notify must succeed");
        assert!(
            start.elapsed() < Duration::from_millis(50),
            "try_notify must not block"
        );
    }

    /// A listener created after the notification still completes.
    #[tokio::test]
    async fn guarantee_c_listener_created_while_notified_completes() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        assert_eq!(ctrl.try_notify(42), Ok(0));
        let v = ctrl.listener().await;
        assert_eq!(v, Ok(42));
    }

    /// A listener waiting before the notification is woken by it.
    #[tokio::test]
    async fn guarantee_c_listener_created_before_notify_completes_after_notify() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        let c = ctrl.clone();
        let listener_task = tokio::spawn(async move { c.listener().await });
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(ctrl.try_notify(7), Ok(0));
        let r = tokio::time::timeout(Duration::from_secs(1), listener_task)
            .await
            .expect("listener must complete within timeout")
            .expect("task must not panic");
        assert_eq!(r, Ok(7));
    }

    /// Across several notify/clear rounds, a listener yields exactly one of the payloads.
    #[tokio::test]
    async fn guarantee_d_listener_returns_one_payload_after_multi_round() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        let c = ctrl.clone();
        let listener_task = tokio::spawn(async move { c.listener().await });
        assert!(ctrl.try_notify(1).is_ok());
        let _ = ctrl.try_clear();
        assert!(ctrl.try_notify(2).is_ok());
        let r = listener_task.await.unwrap();
        let allowed = [Ok(1), Ok(2)];
        assert!(
            allowed.contains(&r),
            "listener must return one of the payloads, got {:?}",
            r
        );
    }

    /// A repeatedly listening reader only ever sees payloads that were actually notified.
    #[tokio::test]
    async fn concurrent_multi_round_collects_subset_of_payloads() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        let reader = ctrl.clone();
        let reader_handle = tokio::spawn(async move {
            let mut seen = Vec::new();
            for _ in 0..20 {
                if let Ok(x) = reader.listener().await {
                    seen.push(x);
                }
            }
            seen
        });
        for i in 0..10u32 {
            assert!(ctrl.try_notify(i).is_ok());
            tokio::time::sleep(Duration::from_millis(2)).await;
            let _ = ctrl.try_clear();
            tokio::time::sleep(Duration::from_millis(2)).await;
        }
        ctrl.try_notify(99).ok();
        let collected = reader_handle.await.unwrap();
        assert!(!collected.is_empty());
        assert!(collected.iter().all(|&x| (0..10).contains(&x) || x == 99));
    }

    /// Many clear/notify cycles interleaved with a busy reader.
    // NOTE(review): the writer performs 400 iterations with a 30 ms sleep each, so this test
    // takes roughly 12 s of wall-clock time.
    #[tokio::test]
    async fn stress_many_cycles_and_listeners() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        let writer = ctrl.clone();
        let writer_handle = tokio::spawn(async move {
            for i in 1..=400u32 {
                let _ = writer.try_clear();
                assert!(writer.try_notify(i).is_ok(), "notify failed");
                tokio::time::sleep(Duration::from_millis(30)).await;
            }
        });
        let reader = ctrl.clone();
        let reader_handle = tokio::spawn(async move {
            for _ in 0..3000 {
                if let Ok(v) = reader.listener().await {
                    assert!(0 < v && v <= 400);
                    tokio::time::sleep(Duration::from_millis(3)).await;
                }
            }
        });
        let _ = tokio::join!(writer_handle, reader_handle);
    }

    /// Concurrent notify/clear across tasks keeps sequence numbers unique and contiguous.
    /// Every cleared sequence must be one that some notify returned.
    #[tokio::test]
    async fn regression_concurrent_try_notify_try_clear_sequence_numbers_unique_and_consistent() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        let mut handles = Vec::new();
        for _ in 0..8 {
            let c = ctrl.clone();
            handles.push(tokio::spawn(async move {
                let mut my_notify = Vec::new();
                let mut my_clear = Vec::new();
                for i in 0..20u32 {
                    if let Ok(seq) = c.try_notify(i) {
                        my_notify.push(seq);
                    }
                    if let Ok(seq) = c.try_clear() {
                        my_clear.push(seq);
                    }
                }
                (my_notify, my_clear)
            }));
        }
        let mut notify_seqs: Vec<u32> = Vec::new();
        let mut clear_seqs: Vec<u32> = Vec::new();
        for h in handles {
            let (n, cl) = h.await.unwrap();
            notify_seqs.extend(n);
            clear_seqs.extend(cl);
        }
        let unique: std::collections::HashSet<u32> = notify_seqs.iter().copied().collect();
        assert_eq!(
            unique.len(),
            notify_seqs.len(),
            "every try_notify Ok(seq) must be unique"
        );
        let n = u32::try_from(notify_seqs.len()).expect("notify count fits in u32");
        for seq in 0..n {
            assert!(
                unique.contains(&seq),
                "sequence numbers must be contiguous from 0, missing {}",
                seq
            );
        }
        for &cleared in &clear_seqs {
            assert!(
                unique.contains(&cleared),
                "try_clear returned seq {} which was not returned by try_notify",
                cleared
            );
        }
    }

    /// A clearer racing a notifier only reports sequences from prior notifies.
    #[tokio::test]
    async fn regression_concurrent_try_clear_returns_notification_sequence() {
        let ctrl = Arc::new(SoftCycleController::<u32>::new());
        let ctrl2 = ctrl.clone();
        let notifier = tokio::spawn(async move {
            for i in 0u32..50 {
                if ctrl2.try_notify(100 + i).is_ok() {
                    tokio::time::sleep(Duration::from_millis(1)).await;
                }
            }
        });
        let clearer = tokio::spawn(async move {
            let mut cleared = Vec::new();
            for _ in 0..60 {
                if let Ok(seq) = ctrl.try_clear() {
                    cleared.push(seq);
                }
                tokio::time::sleep(Duration::from_millis(1)).await;
            }
            cleared
        });
        let _ = notifier.await;
        let cleared = clearer.await.unwrap();
        for &s in &cleared {
            assert!(s < 50, "cleared seq {} must be from a prior notify", s);
        }
    }
}
#[cfg(feature = "global_instance")]
#[cfg_attr(docsrs, doc(cfg(feature = "global_instance")))]
mod global;
#[cfg(feature = "global_instance")]
#[cfg_attr(docsrs, doc(cfg(feature = "global_instance")))]
pub use global::*;