use std::collections::BTreeSet;
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::{watch, Notify};
use tokio::task::JoinHandle;
use tokio_stream::wrappers::WatchStream;
use tokio_stream::Stream;
/// Drop guard representing one registered subscriber.
///
/// The guard holds a shared handle to a subscriber counter and decrements that
/// counter by one when it is dropped. The guard itself never increments the
/// counter: whoever constructs it is expected to have done so beforehand (as
/// `PowerBroadcast::add_subscriber` does).
///
/// The type is `#[must_use]` because discarding a freshly created guard undoes
/// the registration straight away, which is almost certainly not intended.
#[must_use = "dropping the guard immediately decrements the subscriber count again"]
#[derive(Debug)]
pub struct SubscriberGuard {
    /// Counter shared with the creator of this guard.
    count: Arc<AtomicUsize>,
}

impl Drop for SubscriberGuard {
    /// Decrements the shared subscriber counter by one.
    fn drop(&mut self) {
        // Relaxed: the decrement is atomic, but it imposes no ordering on any
        // surrounding memory accesses.
        self.count.fetch_sub(1, Ordering::Relaxed);
    }
}
/// Cloneable read-side handle onto a watch channel of `T` values.
///
/// Besides the channel receiver, the handle carries a shared subscriber
/// counter, a shared notifier and an immutable set of valid identifiers.
/// Cloning the handle clones the receiver and bumps the reference counts of
/// the shared parts; it does not copy the identifier set.
#[derive(Clone)]
pub struct PowerBroadcast<T>
where
    T: Clone + Default + Send + Sync + 'static,
{
    /// Receiving half of the watch channel that carries the latest value.
    rx: watch::Receiver<T>,
    /// Shared count of registered subscribers.
    subscriber_count: Arc<AtomicUsize>,
    /// Shared notifier signalled whenever a subscriber is added.
    wake: Arc<Notify>,
    /// Identifiers that `validate_ids` accepts.
    valid_ids: Arc<BTreeSet<usize>>,
}
impl<T: Clone + Default + Send + Sync + 'static> PowerBroadcast<T> {
pub fn add_subscriber(&self) -> SubscriberGuard {
self.subscriber_count.fetch_add(1, Ordering::Relaxed);
self.wake.notify_one();
SubscriberGuard {
count: self.subscriber_count.clone(),
}
}
pub async fn wait_for_fresh(&self) -> Option<T> {
let mut rx = self.rx.clone();
rx.borrow_and_update();
if rx.changed().await.is_ok() {
Some(rx.borrow().clone())
} else {
None
}
}
pub fn stream(&self) -> impl Stream<Item = T> {
WatchStream::from_changes(self.rx.clone())
}
pub fn validate_ids(&self, ids: &[usize]) -> Result<(), Vec<usize>> {
let unknown: Vec<usize> = ids
.iter()
.filter(|id| !self.valid_ids.contains(id))
.copied()
.collect();
if unknown.is_empty() {
Ok(())
} else {
Err(unknown)
}
}
pub fn valid_ids(&self) -> &BTreeSet<usize> {
&self.valid_ids
}
}
/// Owner of a spawned background task together with the broadcast handle
/// through which that task's published values are observed.
pub struct PowerPoller<T>
where
    T: Clone + Default + Send + Sync + 'static,
{
    /// Prototype handle; `broadcast()` hands out clones of it.
    broadcast: PowerBroadcast<T>,
    /// Join handle of the spawned task. It is stored but never read in the
    /// visible code. Per tokio's documentation, dropping a `JoinHandle`
    /// detaches the task rather than cancelling it.
    _handle: JoinHandle<()>,
}
impl<T: Clone + Default + Send + Sync + 'static> PowerPoller<T> {
pub fn start<F, Fut>(valid_ids: BTreeSet<usize>, spawn_task: F) -> Self
where
F: FnOnce(watch::Sender<T>, Arc<AtomicUsize>, Arc<Notify>) -> Fut,
Fut: Future<Output = ()> + Send + 'static,
{
let (tx, rx) = watch::channel(T::default());
let subscriber_count = Arc::new(AtomicUsize::new(0));
let wake = Arc::new(Notify::new());
let handle = tokio::spawn(spawn_task(tx, subscriber_count.clone(), wake.clone()));
Self {
broadcast: PowerBroadcast {
rx,
subscriber_count,
wake,
valid_ids: Arc::new(valid_ids),
},
_handle: handle,
}
}
pub fn broadcast(&self) -> PowerBroadcast<T> {
self.broadcast.clone()
}
}