use std::{
fmt,
hash::Hash,
ops,
sync::{Arc, RwLock, RwLockWriteGuard},
};
use readlock::{SharedReadGuard, SharedReadLock};
use crate::{state::ObservableState, ObservableReadGuard, Subscriber};
/// An observable value whose clones all share one `Arc<RwLock<_>>`-protected state.
///
/// Subscribers are created with [`Observable::subscribe`] /
/// [`Observable::subscribe_reset`]. When the last `Observable` handle is dropped,
/// the shared state is closed (see the `Drop` impl).
#[derive(Debug)]
pub struct Observable<T> {
// The shared, lock-protected state. `subscribe` hands clones of this `Arc` to
// subscribers, so its strong count covers observables and subscribers alike.
state: Arc<RwLock<ObservableState<T>>>,
// Never read; only its strong count matters. It counts the `Observable`
// handles (this one and its clones), separately from subscribers.
_num_clones: Arc<()>,
}
impl<T> Observable<T> {
    /// Creates a new `Observable` holding `value` as its initial value.
    ///
    /// The returned handle is the only clone and no subscribers exist yet.
    #[must_use]
    pub fn new(value: T) -> Self {
        Self {
            state: Arc::new(RwLock::new(ObservableState::new(value))),
            _num_clones: Arc::new(()),
        }
    }

    /// Creates a subscriber that starts at the value's current version.
    ///
    /// The current version is read under a read lock. Presumably the subscriber
    /// then only reports updates made after this call (confirm against
    /// `Subscriber`).
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn subscribe(&self) -> Subscriber<T> {
        let version = self.state.read().unwrap().version();
        Subscriber::new(SharedReadLock::from_inner(Arc::clone(&self.state)), version)
    }

    /// Creates a subscriber whose starting version is `0`.
    ///
    /// Unlike [`subscribe`](Self::subscribe) this does not read the current
    /// version. Presumably the subscriber therefore treats the current value as
    /// not yet seen (confirm against `Subscriber`).
    pub fn subscribe_reset(&self) -> Subscriber<T> {
        Subscriber::new(SharedReadLock::from_inner(Arc::clone(&self.state)), 0)
    }

    /// Returns a clone of the current value.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn get(&self) -> T
    where
        T: Clone,
    {
        self.state.read().unwrap().get().clone()
    }

    /// Locks the state for reading and returns a guard to the current value.
    ///
    /// The read lock is held until the guard is dropped. Writers on this
    /// observable (`set`, `update`, `write`, …) block in the meantime.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn read(&self) -> ObservableReadGuard<'_, T> {
        ObservableReadGuard::new(SharedReadGuard::from_inner(self.state.read().unwrap()))
    }

    /// Locks the state for writing and returns a guard that can modify it.
    ///
    /// The write lock is held until the guard is dropped. Every other access
    /// to this observable or its clones (including `get` and `subscribe`)
    /// blocks in the meantime.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn write(&self) -> ObservableWriteGuard<'_, T> {
        ObservableWriteGuard::new(self.state.write().unwrap())
    }

    /// Replaces the value with `value`.
    ///
    /// Returns whatever `ObservableState::set` returns, presumably the previous
    /// value.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn set(&self, value: T) -> T {
        self.state.write().unwrap().set(value)
    }

    /// Replaces the value only if `value != current` (by `PartialEq`).
    ///
    /// The comparison and the write happen under one write lock. Returns
    /// `ObservableState::set_if_not_eq`'s result, presumably `Some(previous)`
    /// when the value was replaced and `None` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn set_if_not_eq(&self, value: T) -> Option<T>
    where
        T: PartialEq,
    {
        self.state.write().unwrap().set_if_not_eq(value)
    }

    /// Replaces the value only if its hash differs from the current value's hash.
    ///
    /// The comparison and the write happen under one write lock. Returns
    /// `ObservableState::set_if_hash_not_eq`'s result, presumably
    /// `Some(previous)` when replaced and `None` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn set_if_hash_not_eq(&self, value: T) -> Option<T>
    where
        T: Hash,
    {
        self.state.write().unwrap().set_if_hash_not_eq(value)
    }

    /// Replaces the value with `T::default()` and returns the result of [`set`](Self::set).
    ///
    /// `T::default()` is constructed before the lock is taken.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn take(&self) -> T
    where
        T: Default,
    {
        self.set(T::default())
    }

    /// Mutates the value in place with `f` while holding the write lock.
    ///
    /// `f` runs with the lock held. It must not access this observable (or any
    /// clone of it), or it will deadlock or panic on the `std` `RwLock`.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn update(&self, f: impl FnOnce(&mut T)) {
        self.state.write().unwrap().update(f);
    }

    /// Like [`update`](Self::update), but `f` returns a `bool`.
    ///
    /// Presumably returning `false` suppresses notifying subscribers (confirm
    /// against `ObservableState::update_if`).
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn update_if(&self, f: impl FnOnce(&mut T) -> bool) {
        self.state.write().unwrap().update_if(f);
    }

    /// Returns how many `Observable` handles (this one and its clones) are alive.
    #[must_use]
    pub fn observable_count(&self) -> usize {
        Arc::strong_count(&self._num_clones)
    }

    /// Returns the number of strong references to the shared state that are
    /// not `Observable` handles (i.e. subscribers).
    ///
    /// This is a best-effort snapshot. The two counts are read separately, so
    /// concurrent clones/drops on other threads can skew the result.
    #[must_use]
    pub fn subscriber_count(&self) -> usize {
        // `ref_count` and `observable_count` are independent, unsynchronized
        // reads. `Clone` bumps `state` before `_num_clones`, and field drop
        // releases `state` before `_num_clones`. A concurrent clone or drop
        // between the two reads can therefore make the second one larger
        // than the first. A plain `-` would then underflow (panic in debug,
        // wrap to a huge count in release), so clamp at zero instead.
        self.ref_count().saturating_sub(self.observable_count())
    }

    /// Returns the total number of strong references to the shared state.
    ///
    /// This covers the observables plus anything else holding a clone of the
    /// state `Arc`, such as subscribers.
    #[must_use]
    pub fn ref_count(&self) -> usize {
        Arc::strong_count(&self.state)
    }
}
impl<T> Clone for Observable<T> {
    /// Creates another handle to the same shared state.
    ///
    /// Both reference counts are bumped, so the new handle shows up in
    /// `observable_count` as well as `ref_count`. This is written by hand
    /// rather than derived so that `T: Clone` is not required.
    fn clone(&self) -> Self {
        let state = Arc::clone(&self.state);
        let num_clones = Arc::clone(&self._num_clones);
        Self { state, _num_clones: num_clones }
    }
}
impl<T: Default> Default for Observable<T> {
fn default() -> Self {
Self::new(T::default())
}
}
impl<T> Drop for Observable<T> {
    /// Closes the shared state when the last `Observable` handle goes away.
    ///
    /// Handles that are not the last one return without touching the lock.
    ///
    /// NOTE(review): the count check below and the decrement of `_num_clones`
    /// (which happens when the fields are dropped, after this fn returns) are
    /// not atomic together. If the last two handles are dropped concurrently on
    /// different threads, both may see a count of 2 and neither calls `close()`.
    /// `Arc::into_inner` on a swapped-out `_num_clones` would make this atomic
    /// (requires Rust 1.70 — confirm the crate's MSRV before changing).
    fn drop(&mut self) {
        let is_last_handle = Arc::strong_count(&self._num_clones) == 1;
        if !is_last_handle {
            return;
        }
        // The write lock keeps readers from observing the state until it has
        // been closed.
        let mut state = self.state.write().unwrap();
        state.close();
    }
}
/// Write guard returned by [`Observable::write`].
///
/// It holds the write lock on the shared state for as long as it is alive.
/// It dereferences to the current value, and its associated functions
/// (`ObservableWriteGuard::set(&mut guard, v)`, …) modify it.
#[must_use]
#[clippy::has_significant_drop]
pub struct ObservableWriteGuard<'a, T> {
// The held write lock; released when this guard is dropped.
inner: RwLockWriteGuard<'a, ObservableState<T>>,
}
impl<'a, T> ObservableWriteGuard<'a, T> {
fn new(inner: RwLockWriteGuard<'a, ObservableState<T>>) -> Self {
Self { inner }
}
pub fn set(this: &mut Self, value: T) -> T {
this.inner.set(value)
}
pub fn set_if_not_eq(this: &mut Self, value: T) -> Option<T>
where
T: PartialEq,
{
this.inner.set_if_not_eq(value)
}
pub fn set_if_hash_not_eq(this: &mut Self, value: T) -> Option<T>
where
T: Hash,
{
this.inner.set_if_hash_not_eq(value)
}
pub fn take(this: &mut Self) -> T
where
T: Default,
{
Self::set(this, T::default())
}
pub fn update(this: &mut Self, f: impl FnOnce(&mut T)) {
this.inner.update(f);
}
pub fn update_if(this: &mut Self, f: impl FnOnce(&mut T) -> bool) {
this.inner.update_if(f);
}
}
impl<T: fmt::Debug> fmt::Debug for ObservableWriteGuard<'_, T> {
    /// Formats the guard by delegating to the `Debug` impl of the held lock guard.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.inner, f)
    }
}
impl<T> ops::Deref for ObservableWriteGuard<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.inner.get()
}
}