use std::{
hash::{Hash, Hasher},
mem, ops,
};
use readlock::Shared;
use tokio::sync::broadcast::{self, Sender};
use tokio_stream::wrappers::BroadcastStream;
use crate::subscriber::Subscriber;
/// A value that can be observed for changes through [`Subscriber`]s.
///
/// Writes go through the `Observable` itself; each subscriber is handed a
/// read lock on the same underlying value (see `Observable::subscribe`)
/// together with a stream of change notifications.
#[derive(Debug)]
pub struct Observable<T> {
    /// The observed value, shared with subscribers via read locks.
    value: Shared<T>,
    /// Broadcasts a payload-less `()` notification when the value is updated.
    sender: Sender<()>,
}
// Methods are associated functions taking `this` (in the style of `Arc::clone`)
// so they cannot shadow methods of `T` that are reachable through `Deref`.
impl<T> Observable<T> {
    /// Creates a new `Observable` holding `value`, with no subscribers yet.
    pub fn new(value: T) -> Self {
        // Notifications carry no payload (`()`); the value itself reaches
        // subscribers through the read lock handed out by `subscribe`.
        let (sender, _) = broadcast::channel(1);
        Self { value: Shared::new(value), sender }
    }

    /// Creates a new [`Subscriber`] for this observable.
    ///
    /// The subscriber receives a read lock on the shared value and a stream of
    /// the update notifications sent after this call.
    ///
    /// NOTE(review): the channel capacity is 1, so a slow receiver may see a
    /// `Lagged` item from `BroadcastStream`; how `Subscriber` treats that is
    /// defined elsewhere — confirm.
    pub fn subscribe(this: &Self) -> Subscriber<T> {
        let rx = this.sender.subscribe();
        Subscriber::new(Shared::get_read_lock(&this.value), BroadcastStream::new(rx))
    }

    /// Returns a reference to the current value.
    pub fn get(this: &Self) -> &T {
        &this.value
    }

    /// Stores `value` and notifies all subscribers, even if it equals the old
    /// value.
    pub fn set(this: &mut Self, value: T) {
        // The write guard is a temporary, released at the end of this
        // statement, i.e. before subscribers are notified.
        *Shared::lock(&mut this.value) = value;
        Self::broadcast_update(this);
    }

    /// Stores `value`, notifying subscribers only if it differs from the
    /// previous value according to `PartialEq`.
    ///
    /// The new value is stored even when it compares equal, so fields that a
    /// custom `PartialEq` ignores are still updated.
    pub fn set_eq(this: &mut Self, value: T)
    where
        T: PartialEq,
    {
        // Comparing before the write avoids cloning the old value (as
        // `update_eq` has to). `new != old` keeps the historical operand order.
        let changed = value != *this.value;
        *Shared::lock(&mut this.value) = value;
        if changed {
            Self::broadcast_update(this);
        }
    }

    /// Stores `value`, notifying subscribers only if its hash differs from the
    /// hash of the previous value.
    ///
    /// A hash collision between the old and new value suppresses the
    /// notification.
    pub fn set_hash(this: &mut Self, value: T)
    where
        T: Hash,
    {
        Self::update_hash(this, |inner| {
            *inner = value;
        });
    }

    /// Stores `value`, notifies all subscribers, and returns the previous value.
    pub fn replace(this: &mut Self, value: T) -> T {
        // Guard is dropped at the end of this `let`, before the broadcast.
        let result = mem::replace(&mut *Shared::lock(&mut this.value), value);
        Self::broadcast_update(this);
        result
    }

    /// Takes the current value, leaving `T::default()` in its place, and
    /// notifies all subscribers.
    pub fn take(this: &mut Self) -> T
    where
        T: Default,
    {
        Self::replace(this, T::default())
    }

    /// Mutates the value in place with `f`, then notifies all subscribers
    /// unconditionally.
    ///
    /// The write lock is held while `f` runs.
    pub fn update(this: &mut Self, f: impl FnOnce(&mut T)) {
        f(&mut *Shared::lock(&mut this.value));
        Self::broadcast_update(this);
    }

    /// Mutates the value in place with `f`, notifying subscribers only if the
    /// result differs (by `PartialEq`) from the value before the call.
    ///
    /// Clones the value once per call to keep a snapshot for the comparison.
    pub fn update_eq(this: &mut Self, f: impl FnOnce(&mut T))
    where
        T: Clone + PartialEq,
    {
        // Explicitly clone the inner `T`, not the `Shared` handle.
        let prev = (*this.value).clone();
        f(&mut *Shared::lock(&mut this.value));
        if *this.value != prev {
            Self::broadcast_update(this);
        }
    }

    /// Mutates the value in place with `f`, notifying subscribers only if the
    /// value's hash changed.
    ///
    /// A hash collision between the old and new value suppresses the
    /// notification.
    pub fn update_hash(this: &mut Self, f: impl FnOnce(&mut T))
    where
        T: Hash,
    {
        let prev_hash = Self::hash_of(&this.value);
        f(&mut *Shared::lock(&mut this.value));
        if Self::hash_of(&this.value) != prev_hash {
            Self::broadcast_update(this);
        }
    }

    /// Hashes `value` with a fresh `DefaultHasher`.
    ///
    /// Hashers from `DefaultHasher::new` all use the same keys, so the two
    /// hashes taken in `update_hash` are comparable.
    fn hash_of(value: &T) -> u64
    where
        T: Hash,
    {
        use std::collections::hash_map::DefaultHasher;

        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    /// Sends a change notification to every current subscriber.
    fn broadcast_update(this: &Self) {
        // `send` fails only when there are no active receivers; treat that as
        // "notified zero receivers" rather than an error.
        let _num_receivers = this.sender.send(()).unwrap_or(0);
        #[cfg(feature = "tracing")]
        if _num_receivers > 0 {
            tracing::debug!("New observable value broadcast to {_num_receivers} receivers");
        }
    }
}
impl<T: Default> Default for Observable<T> {
fn default() -> Self {
Self::new(T::default())
}
}
impl<T> ops::Deref for Observable<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.value
}
}