use std::{
    fmt,
    future::{poll_fn, Future},
    hash::Hash,
    ops,
    pin::Pin,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard,
    },
    task::{Context, Poll},
};

use futures_core::Stream;

use crate::state::ObservableState;
#[derive(Debug)]
pub struct Observable<T> {
state: Arc<RwLock<ObservableState<T>>>,
_num_clones: Arc<()>,
}
impl<T> Observable<T> {
pub fn new(value: T) -> Self {
Self {
state: Arc::new(RwLock::new(ObservableState::new(value))),
_num_clones: Arc::new(()),
}
}
pub fn subscribe(&self) -> Subscriber<T> {
let version = self.state.read().unwrap().version();
Subscriber::new(Arc::clone(&self.state), version)
}
pub fn read(&self) -> ObservableReadGuard<'_, T> {
ObservableReadGuard::new(self.state.read().unwrap())
}
pub fn write(&self) -> ObservableWriteGuard<'_, T> {
ObservableWriteGuard::new(self.state.write().unwrap())
}
pub fn set(&self, value: T) {
self.state.write().unwrap().set(value);
}
pub fn set_eq(&self, value: T)
where
T: Clone + PartialEq,
{
Self::update_eq(self, |inner| {
*inner = value;
});
}
pub fn set_hash(&self, value: T)
where
T: Hash,
{
Self::update_hash(self, |inner| {
*inner = value;
});
}
pub fn replace(&self, value: T) -> T {
self.state.write().unwrap().replace(value)
}
pub fn take(&self) -> T
where
T: Default,
{
self.replace(T::default())
}
pub fn update(&self, f: impl FnOnce(&mut T)) {
self.state.write().unwrap().update(f);
}
pub fn update_eq(&self, f: impl FnOnce(&mut T))
where
T: Clone + PartialEq,
{
self.state.write().unwrap().update_eq(f);
}
pub fn update_hash(&self, f: impl FnOnce(&mut T))
where
T: Hash,
{
self.state.write().unwrap().update_hash(f);
}
}
impl<T> Clone for Observable<T> {
fn clone(&self) -> Self {
Self { state: self.state.clone(), _num_clones: self._num_clones.clone() }
}
}
impl<T: Default> Default for Observable<T> {
fn default() -> Self {
Self::new(T::default())
}
}
impl<T> Drop for Observable<T> {
fn drop(&mut self) {
if Arc::strong_count(&self._num_clones) == 1 {
self.state.write().unwrap().close();
}
}
}
/// A read guard for the value of an [`Observable`]; also returned by
/// [`Subscriber`] methods.
///
/// Holds the shared state's read lock while alive, so writers block until it
/// is dropped. Dereferences to the guarded value.
#[clippy::has_significant_drop]
pub struct ObservableReadGuard<'a, T> {
    inner: RwLockReadGuard<'a, ObservableState<T>>,
}

impl<'a, T> ObservableReadGuard<'a, T> {
    /// Wraps a read guard of the shared state.
    fn new(inner: RwLockReadGuard<'a, ObservableState<T>>) -> Self {
        Self { inner }
    }
}

impl<T: fmt::Debug> fmt::Debug for ObservableReadGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Format the guarded value itself, which is what `Deref` exposes and
        // the only thing the `T: Debug` bound covers. Fully-qualified syntax
        // avoids depending on a `Debug`/`Display` trait being imported (only
        // the `fmt` module is).
        fmt::Debug::fmt(&**self, f)
    }
}

impl<T> ops::Deref for ObservableReadGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.inner.get()
    }
}
/// A write guard for the value of an [`Observable`].
///
/// Holds the shared state's write lock while alive. There is no `DerefMut`
/// impl: the value is modified through the associated functions below, which
/// forward to `ObservableState`. Because they are associated functions
/// (called as `ObservableWriteGuard::set(&mut guard, value)`), they do not
/// shadow methods of `T` reached through `Deref`.
#[clippy::has_significant_drop]
pub struct ObservableWriteGuard<'a, T> {
    inner: RwLockWriteGuard<'a, ObservableState<T>>,
}

impl<'a, T> ObservableWriteGuard<'a, T> {
    /// Wraps a write guard of the shared state.
    fn new(inner: RwLockWriteGuard<'a, ObservableState<T>>) -> Self {
        Self { inner }
    }

    /// Sets the value (delegates to `ObservableState::set`).
    pub fn set(this: &mut Self, value: T) {
        this.inner.set(value);
    }

    /// Sets the value through [`update_eq`](Self::update_eq).
    pub fn set_eq(this: &mut Self, value: T)
    where
        T: Clone + PartialEq,
    {
        Self::update_eq(this, |inner| {
            *inner = value;
        });
    }

    /// Sets the value through [`update_hash`](Self::update_hash).
    pub fn set_hash(this: &mut Self, value: T)
    where
        T: Hash,
    {
        Self::update_hash(this, |inner| {
            *inner = value;
        });
    }

    /// Replaces the value and returns what `ObservableState::replace` returns
    /// (presumably the previous value).
    pub fn replace(this: &mut Self, value: T) -> T {
        this.inner.replace(value)
    }

    /// Replaces the value with `T::default()` via [`replace`](Self::replace).
    pub fn take(this: &mut Self) -> T
    where
        T: Default,
    {
        Self::replace(this, T::default())
    }

    /// Passes `f` to `ObservableState::update`.
    pub fn update(this: &mut Self, f: impl FnOnce(&mut T)) {
        this.inner.update(f);
    }

    /// Passes `f` to `ObservableState::update_eq`.
    ///
    /// The `Clone + PartialEq` bounds suggest it only notifies subscribers if
    /// the value changed (confirm in `ObservableState`).
    pub fn update_eq(this: &mut Self, f: impl FnOnce(&mut T))
    where
        T: Clone + PartialEq,
    {
        this.inner.update_eq(f);
    }

    /// Passes `f` to `ObservableState::update_hash`.
    ///
    /// The `Hash` bound suggests change detection by hash comparison (confirm
    /// in `ObservableState`).
    pub fn update_hash(this: &mut Self, f: impl FnOnce(&mut T))
    where
        T: Hash,
    {
        this.inner.update_hash(f);
    }
}

impl<T: fmt::Debug> fmt::Debug for ObservableWriteGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Format the guarded value itself (what `Deref` exposes), using
        // fully-qualified syntax so no formatting trait needs to be imported.
        fmt::Debug::fmt(&**self, f)
    }
}

impl<T> ops::Deref for ObservableWriteGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.inner.get()
    }
}
/// A subscriber to the updates of an [`Observable`].
///
/// Remembers the last version of the shared state it has seen, which lets it
/// tell whether a newer value is available.
#[derive(Debug)]
pub struct Subscriber<T> {
    state: Arc<RwLock<ObservableState<T>>>,
    observed_version: u64,
}

// Methods that lock the shared state `unwrap()` the lock result and therefore
// panic if the lock is poisoned.
impl<T> Subscriber<T> {
    /// Creates a subscriber over `read_lock` that treats `version` as already
    /// observed.
    pub(crate) fn new(read_lock: Arc<RwLock<ObservableState<T>>>, version: u64) -> Self {
        Self { observed_version: version, state: read_lock }
    }

    /// Returns a future resolving to a clone of the next value, or to `None`
    /// once the shared state reports version 0 (closed).
    // Deliberately named like `StreamExt::next`; this is not an `Iterator`, so
    // clippy's suggestion to implement that trait does not apply.
    #[allow(clippy::should_implement_trait)]
    pub fn next(&mut self) -> Next<'_, T>
    where
        T: Clone,
    {
        Next::new(self)
    }

    /// Returns a clone of the current value without waiting and marks the
    /// current version as observed.
    pub fn next_now(&mut self) -> T
    where
        T: Clone,
    {
        T::clone(&self.next_ref_now())
    }

    /// Returns a clone of the current value without changing which version
    /// counts as observed.
    pub fn get(&self) -> T
    where
        T: Clone,
    {
        T::clone(&self.read())
    }

    /// Waits for a version newer than the observed one and returns a read
    /// guard to the value, or `None` once the shared state reports version 0.
    ///
    /// The lock is re-acquired after the update is detected, so the guard
    /// reflects the value current at that moment (possibly newer still).
    pub async fn next_ref(&mut self) -> Option<ObservableReadGuard<'_, T>> {
        // The guard produced while polling cannot escape the `poll_fn`
        // closure (it borrows `self` only for one call), so it is dropped here
        // and a fresh one is taken below.
        poll_fn(|cx| self.poll_next_ref(cx).map(|update| update.map(drop))).await?;
        Some(self.next_ref_now())
    }

    /// Read-locks the shared state without waiting, marks its current version
    /// as observed and returns the guard.
    pub fn next_ref_now(&mut self) -> ObservableReadGuard<'_, T> {
        let guard = self.state.read().unwrap();
        self.observed_version = guard.version();
        ObservableReadGuard::new(guard)
    }

    /// Read-locks the shared state and returns a guard to the value, without
    /// touching the observed version.
    pub fn read(&self) -> ObservableReadGuard<'_, T> {
        ObservableReadGuard::new(self.state.read().unwrap())
    }

    /// Polls for a version newer than the observed one.
    ///
    /// Version 0 yields `Ready(None)`; a newer version is recorded and yields
    /// its guard; otherwise the waker is registered with the state.
    fn poll_next_ref(&mut self, cx: &mut Context<'_>) -> Poll<Option<ObservableReadGuard<'_, T>>> {
        let guard = self.state.read().unwrap();
        let current = guard.version();

        if current == 0 {
            return Poll::Ready(None);
        }

        if current > self.observed_version {
            self.observed_version = current;
            return Poll::Ready(Some(ObservableReadGuard::new(guard)));
        }

        // Registered while the read lock is still held, so a writer (which
        // needs the write lock) cannot slip a change in between the version
        // check and the registration.
        guard.add_waker(cx.waker().clone());
        Poll::Pending
    }
}
impl<T: Clone> Stream for Subscriber<T> {
    type Item = T;

    /// Yields a clone of each newly observed value, and `None` once the
    /// shared state reports version 0.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        // `Subscriber` only holds an `Arc` and a `u64`, so it is `Unpin` and
        // can be unwrapped from the pin directly.
        let this = Pin::into_inner(self);
        this.poll_next_ref(cx).map(opt_guard_to_owned)
    }
}
/// Future returned by [`Subscriber::next`].
///
/// Resolves to a clone of the next observed value, or to `None` once the
/// shared state reports version 0 (closed).
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct Next<'a, T> {
    subscriber: &'a mut Subscriber<T>,
}

impl<'a, T> Next<'a, T> {
    /// Creates the future; nothing happens until it is polled.
    fn new(subscriber: &'a mut Subscriber<T>) -> Self {
        Self { subscriber }
    }
}

impl<T: Clone> Future for Next<'_, T> {
    type Output = Option<T>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Next` only holds a `&mut`, so it is `Unpin` and `Pin` derefs
        // mutably.
        self.subscriber.poll_next_ref(cx).map(opt_guard_to_owned)
    }
}
/// Turns an optional read guard into an owned clone of the guarded value.
///
/// The guard, and with it the read lock, is released when this returns.
fn opt_guard_to_owned<T: Clone>(value: Option<ObservableReadGuard<'_, T>>) -> Option<T> {
    let guard = value?;
    Some(T::clone(&guard))
}