use std::sync::{
atomic::{AtomicBool, Ordering::*},
Arc,
};
use tokio::sync::Semaphore;
/// A one-shot signal that any number of async tasks can wait on.
///
/// Cloning a `Signal` is cheap: every clone holds the same state behind an
/// [`Arc`], so triggering through any clone is observed by all of them.
#[derive(Clone, Debug)]
pub struct Signal {
    shared: Arc<Inner>,
}
/// State shared by every clone of a [`Signal`].
#[derive(Debug)]
struct Inner {
    /// Semaphore that waiters block on until `Signal::trigger` releases them.
    semaphore: Semaphore,
    /// Flipped from `false` to `true` by the first `Signal::trigger` call; never reset.
    triggered: AtomicBool,
}
impl Signal {
    /// Creates a new, untriggered signal.
    pub fn new() -> Self {
        Signal {
            shared: Arc::new(Inner {
                // Starts with zero permits: `acquire()` can only complete once the
                // semaphore is closed by `trigger`.
                semaphore: Semaphore::new(0),
                triggered: AtomicBool::new(false),
            }),
        }
    }

    /// Triggers the signal, releasing every current and future waiter.
    ///
    /// Only the first call (across all clones) has any effect; subsequent calls
    /// are no-ops.
    pub fn trigger(&self) {
        // Exactly one caller wins the `false -> true` transition.
        let first = self
            .shared
            .triggered
            .compare_exchange(false, true, AcqRel, Acquire)
            .is_ok();
        if first {
            // Closing the semaphore notifies every task parked in `acquire()` and
            // makes all later `acquire()` calls return immediately. This replaces
            // the previous `add_permits(usize::MAX >> 3)`, which silently relied on
            // tokio's internal `MAX_PERMITS` (exceeding it makes `add_permits` panic).
            self.shared.semaphore.close();
        }
    }

    /// Returns `true` once [`Signal::trigger`] has been called on this signal or
    /// any of its clones.
    pub fn is_triggered(&self) -> bool {
        self.shared.triggered.load(Acquire)
    }

    /// Waits until the signal has been triggered.
    ///
    /// Returns immediately if the signal is already triggered.
    pub async fn wait(&self) {
        if self.is_triggered() {
            return;
        }
        // The semaphore never holds permits, so this completes only when
        // `trigger` closes it; the resulting `AcquireError` is the expected
        // wake-up and is deliberately ignored. If `trigger` races in between the
        // check above and this call, `acquire()` on the closed semaphore
        // returns immediately.
        let _ = self.shared.semaphore.acquire().await;
    }
}
impl Default for Signal {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// A waiter that arrives before the trigger and one that arrives after it
    /// must both be released, and a repeated trigger must be harmless.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn signal_test() {
        let signal = Signal::new();
        assert!(!signal.is_triggered());

        let r#pub = tokio::task::spawn({
            let signal = signal.clone();
            async move {
                tokio::time::sleep(Duration::from_millis(200)).await;
                signal.trigger();
                // A second trigger must be a no-op, not a panic.
                signal.trigger();
            }
        });
        // Starts waiting before the trigger fires.
        let fast_sub = tokio::task::spawn({
            let signal = signal.clone();
            async move {
                signal.wait().await;
            }
        });
        // Starts waiting after the trigger has fired.
        let slow_sub = tokio::task::spawn({
            let signal = signal.clone();
            async move {
                tokio::time::sleep(Duration::from_millis(400)).await;
                signal.wait().await;
            }
        });

        let (pub_res, fast_res, slow_res) = tokio::time::timeout(
            Duration::from_millis(50000),
            futures::future::join3(r#pub, fast_sub, slow_sub),
        )
        .await
        .expect("tasks did not finish before the timeout");
        // The join results carry task panics; ignoring them would hide failures.
        pub_res.expect("publisher task failed");
        fast_res.expect("fast subscriber task failed");
        slow_res.expect("slow subscriber task failed");

        assert!(signal.is_triggered());
    }

    /// Once triggered, `wait` on any clone must return without blocking.
    #[tokio::test]
    async fn wait_after_trigger_returns_immediately() {
        let signal = Signal::default();
        let other = signal.clone();
        signal.trigger();
        assert!(signal.is_triggered());
        assert!(other.is_triggered());

        tokio::time::timeout(Duration::from_secs(1), signal.wait())
            .await
            .expect("wait() blocked on an already-triggered signal");
        tokio::time::timeout(Duration::from_secs(1), other.wait())
            .await
            .expect("wait() on a clone blocked on an already-triggered signal");
    }
}