use std::{
hash::{Hash, Hasher},
mem, ops,
pin::Pin,
task::{Context, Poll},
};
use futures_core::Stream;
use tokio::sync::broadcast::{self, Sender};
use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
/// A value that can broadcast updates of itself to any number of subscribers.
///
/// Read access to the wrapped value goes through `Deref`. Writes go through the
/// associated functions on `Observable` (`set`, `replace`, `update`, ...), which
/// send a clone of the new value to every live [`Subscriber`] (the `update_eq` /
/// `update_hash` variants only when the value actually changed).
#[derive(Debug)]
pub struct Observable<T> {
    /// The current value.
    value: T,
    /// Broadcast channel used to push updated values to subscribers.
    sender: Sender<T>,
}
impl<T: Clone + Send + 'static> Observable<T> {
pub fn new(value: T) -> Self {
let (sender, _) = broadcast::channel(1);
Self { value, sender }
}
pub fn subscribe(this: &Self) -> Subscriber<T> {
let rx = this.sender.subscribe();
Subscriber::new(BroadcastStream::new(rx))
}
pub fn get(this: &Self) -> &T {
&this.value
}
pub fn set(this: &mut Self, value: T) {
Self::replace(this, value);
}
pub fn replace(this: &mut Self, value: T) -> T {
let result = mem::replace(&mut this.value, value);
Self::broadcast_update(this);
result
}
pub fn update(this: &mut Self, f: impl FnOnce(&mut T)) {
f(&mut this.value);
Self::broadcast_update(this);
}
pub fn update_eq(this: &mut Self, f: impl FnOnce(&mut T))
where
T: PartialEq,
{
let prev = this.value.clone();
f(&mut this.value);
if this.value != prev {
Self::broadcast_update(this);
}
}
pub fn update_hash(this: &mut Self, f: impl FnOnce(&mut T))
where
T: Hash,
{
use std::collections::hash_map::DefaultHasher;
let mut hasher = DefaultHasher::new();
this.value.hash(&mut hasher);
let prev_hash = hasher.finish();
f(&mut this.value);
let mut hasher = DefaultHasher::new();
this.value.hash(&mut hasher);
let new_hash = hasher.finish();
if prev_hash != new_hash {
Self::broadcast_update(this);
}
}
fn broadcast_update(this: &Self) {
if this.sender.receiver_count() != 0 {
let _num_receivers = this.sender.send(this.value.clone()).unwrap_or(0);
#[cfg(feature = "tracing")]
tracing::debug!("New observable value broadcast to {_num_receivers} receivers");
}
}
}
impl<T> ops::Deref for Observable<T> {
    type Target = T;

    /// Exposes the wrapped value read-only, so `*observable` and method calls
    /// on `T` work directly on the observable.
    fn deref(&self) -> &T {
        let Self { value, .. } = self;
        value
    }
}
/// A stream of updates from an [`Observable`], obtained via `Observable::subscribe`.
///
/// Yields broadcast values in the order they were sent. If this subscriber
/// falls behind, the missed values are skipped rather than surfaced as errors.
#[derive(Debug)]
pub struct Subscriber<T> {
    /// The broadcast receiver, wrapped as a `Stream`.
    inner: BroadcastStream<T>,
}
impl<T> Subscriber<T> {
fn new(inner: BroadcastStream<T>) -> Self {
Self { inner }
}
}
impl<T: Clone + Send + 'static> Stream for Subscriber<T> {
    type Item = T;

    /// Polls for the next broadcast value, transparently skipping lag notifications.
    ///
    /// A `Lagged` error means this subscriber missed older values. Under tokio's
    /// broadcast semantics the receiver is then moved to the oldest value still
    /// retained, so the loop simply polls again to yield it. `Ready(None)` is
    /// returned once the underlying stream ends.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            match Pin::new(&mut self.inner).poll_next(cx) {
                Poll::Ready(Some(Ok(value))) => return Poll::Ready(Some(value)),
                Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(_)))) => {
                    // Missed some values; poll again for the next available one.
                }
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}