use parking_lot::RwLock;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::Notify;
use tracing::warn;
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum Error {
#[error("BroadcastOnce dropped")]
Dropped,
}
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Sending half of a once-only broadcast.
///
/// A single value is published with [`BroadcastOnce::broadcast`], after which every
/// receiver obtains its own clone of it. If the sender is dropped without
/// broadcasting, receivers instead observe [`Error::Dropped`].
#[derive(Debug)]
pub struct BroadcastOnce<T: Send + Sync> {
    /// Result slot and wake-up signal shared with every receiver.
    shared: Arc<Shared<T>>,
}
impl<T> Default for BroadcastOnce<T>
where
    T: Send + Sync,
{
    /// Creates a sender that has not published anything yet.
    ///
    /// Receivers obtained from it see `None` from `peek` until a value is
    /// broadcast or this sender is dropped.
    fn default() -> Self {
        let shared = Shared {
            data: Default::default(),
            notify: Default::default(),
        };
        Self {
            shared: Arc::new(shared),
        }
    }
}
#[derive(Debug)]
struct Shared<T> {
data: RwLock<Option<Result<T>>>,
notify: Notify,
}
impl<T: Clone + Send + Sync> BroadcastOnce<T> {
    /// Creates a receiver that shares this sender's result slot.
    pub fn receiver(&self) -> BroadcastOnceReceiver<T> {
        let shared = Arc::clone(&self.shared);
        BroadcastOnceReceiver { shared }
    }

    /// Publishes `r` to every receiver created from this sender (and their clones),
    /// waking any that are parked in [`BroadcastOnceReceiver::receive`].
    ///
    /// The sender is consumed, so at most one value is ever published. When `self`
    /// is dropped at the end of this call, its `Drop` impl finds the slot already
    /// filled and does nothing.
    ///
    /// # Panics
    ///
    /// Panics with "double publish" if the result slot is already filled.
    pub fn broadcast(self, r: T) {
        let shared = &*self.shared;
        let mut slot = shared.data.write();
        assert!(slot.is_none(), "double publish");
        *slot = Some(Ok(r));
        // Notified while the write guard is still held; receivers re-read the slot
        // via `peek` after waking.
        shared.notify.notify_waiters();
    }
}
impl<T> Drop for BroadcastOnce<T>
where
T: Send + Sync,
{
fn drop(&mut self) {
let mut data = self.shared.data.write();
if data.is_none() {
warn!("BroadcastOnce dropped without producing");
*data = Some(Err(Error::Dropped));
self.shared.notify.notify_waiters();
}
}
}
/// Receiving half of a [`BroadcastOnce`].
///
/// Cloning only clones the inner `Arc`, and every clone observes the same
/// published result.
#[derive(Clone, Debug)]
pub struct BroadcastOnceReceiver<T> {
    /// Result slot and wake-up signal shared with the sender.
    shared: Arc<Shared<T>>,
}
impl<T: Clone + Send + Sync> BroadcastOnceReceiver<T> {
    /// Returns the published result without waiting.
    ///
    /// Yields `None` while the sender has neither broadcast nor been dropped,
    /// otherwise a clone of the stored `Ok(value)` or `Err(Error::Dropped)`.
    pub fn peek(&self) -> Option<Result<T>> {
        let guard = self.shared.data.read();
        (*guard).as_ref().cloned()
    }

    /// Waits until the sender broadcasts or is dropped, then returns a clone of
    /// the outcome.
    ///
    /// May be called any number of times; once a result is stored, later calls
    /// return it immediately.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Dropped`] if the sender was dropped without broadcasting.
    ///
    /// # Panics
    ///
    /// Panics if woken while no result is stored. The sender's `broadcast` and
    /// `Drop` both fill the slot before calling `notify_waiters`, so this signals
    /// a broken invariant.
    pub async fn receive(&self) -> Result<T> {
        // Create the `Notified` future *before* checking the slot: tokio documents
        // that a `Notified` future receives `notify_waiters` wakeups from the moment
        // it is created, so a broadcast racing with the check below is not lost.
        let notified = self.shared.notify.notified();
        match self.peek() {
            Some(outcome) => outcome,
            None => {
                notified.await;
                self.peek().expect("just got notified")
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Broadcasting with no receivers attached must neither panic nor block.
    #[tokio::test]
    async fn broadcast_without_receivers() {
        let broadcast: BroadcastOnce<usize> = Default::default();
        broadcast.broadcast(1);
    }

    /// A receiver sees nothing before the broadcast, then the same value repeatedly.
    #[tokio::test]
    async fn receiver_observes_value_after_broadcast() {
        let broadcast: BroadcastOnce<usize> = Default::default();
        let receiver = broadcast.receiver();
        assert!(receiver.peek().is_none());

        // Nothing has been published yet, so `receive` must still be pending.
        tokio::time::timeout(Duration::from_millis(1), receiver.receive())
            .await
            .unwrap_err();

        broadcast.broadcast(2);

        // Both `peek` and `receive` are repeatable and keep returning the value.
        assert_eq!(receiver.peek().unwrap(), Ok(2));
        assert_eq!(receiver.peek().unwrap(), Ok(2));
        assert_eq!(receiver.receive().await, Ok(2));
        assert_eq!(receiver.receive().await, Ok(2));
    }

    /// Every receiver, including clones of a receiver, gets the published value.
    #[tokio::test]
    async fn multiple_receivers() {
        let broadcast: BroadcastOnce<usize> = Default::default();
        let r1 = broadcast.receiver();
        let r2 = broadcast.receiver();
        let r3 = r1.clone();
        broadcast.broadcast(4);
        assert_eq!(r1.receive().await, Ok(4));
        assert_eq!(r2.receive().await, Ok(4));
        assert_eq!(r3.receive().await, Ok(4));
    }

    /// Dropping the sender without broadcasting reports `Error::Dropped` to all receivers.
    #[tokio::test]
    async fn dropped_sender() {
        let broadcast: BroadcastOnce<usize> = Default::default();
        let r1 = broadcast.receiver();
        let r2 = broadcast.receiver();
        assert!(r1.peek().is_none());
        std::mem::drop(broadcast);
        assert_eq!(r1.receive().await, Err(Error::Dropped));
        assert_eq!(r2.receive().await, Err(Error::Dropped));
    }

    /// A receiver already waiting in `receive` is woken by `broadcast`.
    #[tokio::test]
    async fn wakes_pending_receiver_on_broadcast() {
        let broadcast: BroadcastOnce<usize> = Default::default();
        let receiver = broadcast.receiver();
        let pending = tokio::spawn(async move { receiver.receive().await });
        // Yield so the spawned task gets a chance to park on the notification
        // first; the assertion below holds whichever order the runtime picks.
        tokio::task::yield_now().await;
        broadcast.broadcast(5);
        assert_eq!(pending.await.unwrap(), Ok(5));
    }

    /// A receiver already waiting in `receive` is woken when the sender is dropped.
    #[tokio::test]
    async fn wakes_pending_receiver_on_drop() {
        let broadcast: BroadcastOnce<usize> = Default::default();
        let receiver = broadcast.receiver();
        let pending = tokio::spawn(async move { receiver.receive().await });
        // See `wakes_pending_receiver_on_broadcast` for why we yield here.
        tokio::task::yield_now().await;
        std::mem::drop(broadcast);
        assert_eq!(pending.await.unwrap(), Err(Error::Dropped));
    }
}