use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::Notify;
#[derive(Clone, Default)]
pub struct Broadcast {
notify: Arc<Notify>,
}
impl Broadcast {
pub fn new() -> Self {
#[cfg(feature = "tracing")]
tracing::trace!("creating new broadcast shutdown");
Self::default()
}
pub fn subscribe(&self) -> Subscriber {
#[cfg(feature = "tracing")]
tracing::trace!("creating new broadcast subscriber");
Subscriber {
notify: self.notify.clone(),
}
}
pub fn shutdown(&self) {
#[cfg(feature = "tracing")]
tracing::info!("broadcasting shutdown to all subscribers");
self.notify.notify_waiters();
}
}
#[derive(Clone)]
pub struct Subscriber {
notify: Arc<Notify>,
}
impl Subscriber {
pub async fn recv(&self) {
#[cfg(feature = "tracing")]
tracing::debug!("subscriber waiting for broadcast shutdown");
self.notify.notified().await;
#[cfg(feature = "tracing")]
tracing::debug!("subscriber received broadcast shutdown");
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use tokio::time::timeout;
#[tokio::test]
async fn test_broadcast_shutdown() {
let broadcast = Broadcast::new();
let sub1 = broadcast.subscribe();
let sub2 = broadcast.subscribe();
let task1 = tokio::spawn(async move {
sub1.recv().await;
});
let task2 = tokio::spawn(async move {
sub2.recv().await;
});
tokio::time::sleep(Duration::from_millis(10)).await;
broadcast.shutdown();
assert!(timeout(Duration::from_millis(100), task1).await.is_ok());
assert!(timeout(Duration::from_millis(100), task2).await.is_ok());
}
#[tokio::test]
async fn test_late_subscriber() {
let broadcast = Broadcast::new();
broadcast.shutdown();
let sub = broadcast.subscribe();
let task = tokio::spawn(async move {
sub.recv().await;
});
assert!(timeout(Duration::from_millis(50), task).await.is_err());
}
}