use std::{
fmt,
hash::Hash,
ops,
sync::{Arc, Weak},
};
use readlock::{SharedReadGuard, SharedReadLock};
#[cfg(feature = "async-lock")]
use readlock_tokio::{
SharedReadGuard as SharedAsyncReadGuard, SharedReadLock as SharedAsyncReadLock,
};
#[cfg(feature = "async-lock")]
use crate::AsyncLock;
use crate::{lock::Lock, state::ObservableState, ObservableReadGuard, Subscriber, SyncLock};
/// An observable value whose handle can be cloned and shared.
///
/// Every clone refers to the same lock-protected state. The lock type `L`
/// (defaulting to `SyncLock`) selects which `RwLock` wraps that state: the
/// sync constructor uses `std::sync::RwLock`, the async one `tokio::sync::RwLock`.
pub struct SharedObservable<T, L: Lock = SyncLock> {
    /// The shared, lock-protected observable state. Subscribers created from
    /// this observable also hold clones of this `Arc`, so its strong count
    /// covers both observable handles and subscribers.
    state: Arc<L::RwLock<ObservableState<T>>>,
    /// Cloned only by `SharedObservable` handles (never by subscribers); its
    /// strong count is the number of live observable handles and is used to
    /// detect the last one being dropped.
    ///
    /// Declared after `state`, so when a handle is dropped `state` is released
    /// first and this counter second.
    _num_clones: Arc<()>,
}
impl<T> SharedObservable<T> {
    /// Creates a new observable holding `value`, protected by a
    /// `std::sync::RwLock`.
    #[must_use]
    pub fn new(value: T) -> Self {
        let initial = ObservableState::new(value);
        Self::from_inner(Arc::new(std::sync::RwLock::new(initial)))
    }

    /// Creates a subscriber seeded with the state's current version
    /// (presumably so it only reports updates made after this call — confirm
    /// against `Subscriber`).
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn subscribe(&self) -> Subscriber<T> {
        let current_version = self.state.read().unwrap().version();
        let read_lock = SharedReadLock::from_inner(Arc::clone(&self.state));
        Subscriber::new(read_lock, current_version)
    }

    /// Creates a subscriber seeded with version `0` instead of the current
    /// version (presumably so the current value counts as unseen — confirm
    /// against `Subscriber`).
    pub fn subscribe_reset(&self) -> Subscriber<T> {
        let read_lock = SharedReadLock::from_inner(Arc::clone(&self.state));
        Subscriber::new(read_lock, 0)
    }

    /// Returns a clone of the current value.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn get(&self) -> T
    where
        T: Clone,
    {
        let guard = self.state.read().unwrap();
        guard.get().clone()
    }

    /// Acquires a read lock and returns a guard giving access to the value.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn read(&self) -> ObservableReadGuard<'_, T> {
        let guard = self.state.read().unwrap();
        ObservableReadGuard::new(SharedReadGuard::from_inner(guard))
    }

    /// Acquires the write lock and returns a guard that can modify the value.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn write(&self) -> ObservableWriteGuard<'_, T> {
        let guard = self.state.write().unwrap();
        ObservableWriteGuard::new(guard)
    }

    /// Stores `value` through `ObservableState::set` and returns what it
    /// returns (presumably the previous value).
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn set(&self, value: T) -> T {
        let mut state = self.state.write().unwrap();
        state.set(value)
    }

    /// Stores `value` only if it differs (by `PartialEq`) from the current
    /// one; forwards the result of `ObservableState::set_if_not_eq`.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn set_if_not_eq(&self, value: T) -> Option<T>
    where
        T: PartialEq,
    {
        let mut state = self.state.write().unwrap();
        state.set_if_not_eq(value)
    }

    /// Stores `value` only if its hash differs from the current value's;
    /// forwards the result of `ObservableState::set_if_hash_not_eq`.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn set_if_hash_not_eq(&self, value: T) -> Option<T>
    where
        T: Hash,
    {
        let mut state = self.state.write().unwrap();
        state.set_if_hash_not_eq(value)
    }

    /// Replaces the value with `T::default()` via [`Self::set`] and returns
    /// what `set` returns.
    pub fn take(&self) -> T
    where
        T: Default,
    {
        self.set(Default::default())
    }

    /// Runs `f` on the value while holding the write lock.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn update(&self, f: impl FnOnce(&mut T)) {
        let mut state = self.state.write().unwrap();
        state.update(f);
    }

    /// Runs `f` on the value while holding the write lock; `f`'s boolean
    /// result is forwarded to `ObservableState::update_if` (presumably to
    /// decide whether the change counts as an update).
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn update_if(&self, f: impl FnOnce(&mut T) -> bool) {
        let mut state = self.state.write().unwrap();
        state.update_if(f);
    }
}
#[cfg(feature = "async-lock")]
impl<T: Send + Sync + 'static> SharedObservable<T, AsyncLock> {
    /// Creates a new observable holding `value`, protected by a
    /// `tokio::sync::RwLock`.
    #[must_use]
    pub fn new_async(value: T) -> Self {
        let initial = ObservableState::new(value);
        Self::from_inner(Arc::new(tokio::sync::RwLock::new(initial)))
    }

    /// Creates a subscriber seeded with the state's current version (the
    /// read lock is awaited to obtain it).
    pub async fn subscribe(&self) -> Subscriber<T, AsyncLock> {
        let current_version = self.state.read().await.version();
        let read_lock = SharedAsyncReadLock::from_inner(Arc::clone(&self.state));
        Subscriber::new_async(read_lock, current_version)
    }

    /// Creates a subscriber seeded with version `0`; no lock is taken, so
    /// this is not `async`.
    pub fn subscribe_reset(&self) -> Subscriber<T, AsyncLock> {
        let read_lock = SharedAsyncReadLock::from_inner(Arc::clone(&self.state));
        Subscriber::new_async(read_lock, 0)
    }

    /// Returns a clone of the current value.
    pub async fn get(&self) -> T
    where
        T: Clone,
    {
        let guard = self.state.read().await;
        guard.get().clone()
    }

    /// Awaits a read lock and returns a guard giving access to the value.
    pub async fn read(&self) -> ObservableReadGuard<'_, T, AsyncLock> {
        let guard = self.state.read().await;
        ObservableReadGuard::new(SharedAsyncReadGuard::from_inner(guard))
    }

    /// Awaits the write lock and returns a guard that can modify the value.
    pub async fn write(&self) -> ObservableWriteGuard<'_, T, AsyncLock> {
        let guard = self.state.write().await;
        ObservableWriteGuard::new(guard)
    }

    /// Stores `value` through `ObservableState::set` and returns what it
    /// returns (presumably the previous value).
    pub async fn set(&self, value: T) -> T {
        let mut state = self.state.write().await;
        state.set(value)
    }

    /// Stores `value` only if it differs (by `PartialEq`) from the current one.
    pub async fn set_if_not_eq(&self, value: T) -> Option<T>
    where
        T: PartialEq,
    {
        let mut state = self.state.write().await;
        state.set_if_not_eq(value)
    }

    /// Stores `value` only if its hash differs from the current value's.
    pub async fn set_if_hash_not_eq(&self, value: T) -> Option<T>
    where
        T: Hash,
    {
        let mut state = self.state.write().await;
        state.set_if_hash_not_eq(value)
    }

    /// Replaces the value with `T::default()` via [`Self::set`].
    pub async fn take(&self) -> T
    where
        T: Default,
    {
        self.set(Default::default()).await
    }

    /// Runs `f` on the value while holding the write lock.
    pub async fn update(&self, f: impl FnOnce(&mut T)) {
        let mut state = self.state.write().await;
        state.update(f);
    }

    /// Runs `f` on the value while holding the write lock, forwarding its
    /// boolean result to `ObservableState::update_if`.
    pub async fn update_if(&self, f: impl FnOnce(&mut T) -> bool) {
        let mut state = self.state.write().await;
        state.update_if(f);
    }
}
impl<T, L: Lock> SharedObservable<T, L> {
    /// Wraps an already-constructed shared state in a new handle with a fresh
    /// observable counter (so the new handle starts with an observable count
    /// of one).
    pub(crate) fn from_inner(state: Arc<L::RwLock<ObservableState<T>>>) -> Self {
        Self { state, _num_clones: Arc::new(()) }
    }

    /// Returns the number of live `SharedObservable` handles sharing this
    /// state, including `self`.
    ///
    /// Like all `Arc` counts, this may be outdated as soon as it is returned
    /// if other threads clone or drop handles concurrently.
    #[must_use]
    pub fn observable_count(&self) -> usize {
        Arc::strong_count(&self._num_clones)
    }

    /// Returns the number of strong references to the state that are not
    /// held by `SharedObservable` handles (e.g. subscribers, which hold a
    /// clone of the state `Arc`).
    ///
    /// This is an approximation under concurrency and never underflows.
    #[must_use]
    pub fn subscriber_count(&self) -> usize {
        // The two counters are read one after another and can change in
        // between. Also, when another handle is dropped its `state` field is
        // released before its `_num_clones` field (declaration order), so
        // `strong_count` can briefly be smaller than `observable_count`.
        // A plain `-` would then overflow (a panic in debug builds, a huge
        // value in release); saturate to 0 instead.
        self.strong_count().saturating_sub(self.observable_count())
    }

    /// Returns the strong reference count of the shared state: observable
    /// handles plus anything else holding the state `Arc` (such as
    /// subscribers).
    #[must_use]
    pub fn strong_count(&self) -> usize {
        Arc::strong_count(&self.state)
    }

    /// Returns the weak reference count of the shared state (e.g. the
    /// [`WeakObservable`]s created by [`Self::downgrade`]).
    #[must_use]
    pub fn weak_count(&self) -> usize {
        Arc::weak_count(&self.state)
    }

    /// Creates a [`WeakObservable`] that neither keeps the state alive nor
    /// counts as an observable handle.
    pub fn downgrade(&self) -> WeakObservable<T, L> {
        WeakObservable {
            state: Arc::downgrade(&self.state),
            _num_clones: Arc::downgrade(&self._num_clones),
        }
    }
}
impl<T, L: Lock> Clone for SharedObservable<T, L> {
fn clone(&self) -> Self {
Self { state: self.state.clone(), _num_clones: self._num_clones.clone() }
}
}
/// Debug-formats both fields; available whenever the lock wrapping the
/// state is itself `Debug`.
impl<T, L: Lock> fmt::Debug for SharedObservable<T, L>
where
    L::RwLock<ObservableState<T>>: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("SharedObservable");
        builder.field("state", &self.state);
        builder.field("_num_clones", &self._num_clones);
        builder.finish()
    }
}
/// Creates an observable holding `T::default()`, using the lock type's own
/// `RwLock` constructor.
impl<T, L> Default for SharedObservable<T, L>
where
    T: Default,
    L: Lock,
{
    fn default() -> Self {
        let initial = T::default();
        Self::from_inner(Arc::new(L::new_rwlock(ObservableState::new(initial))))
    }
}
/// Closes the shared state when the last `SharedObservable` handle goes away.
///
/// Subscribers hold only the state `Arc`, not `_num_clones`, so they do not
/// keep the state open.
impl<T, L: Lock> Drop for SharedObservable<T, L> {
    fn drop(&mut self) {
        // `_num_clones` is only cloned by observable handles (via `clone`
        // and `WeakObservable::upgrade`), so a strong count of one means this
        // is the final handle.
        //
        // NOTE(review): this check and the later release of `_num_clones`
        // (after `drop` returns) are not atomic. If the last two handles are
        // dropped concurrently on different threads, both may observe a count
        // of 2 and neither closes the state. `Arc::into_inner` (Rust 1.70+)
        // could make this exact if the MSRV allows.
        let others_remain = Arc::strong_count(&self._num_clones) != 1;
        if others_remain {
            return;
        }
        L::read_noblock(&self.state).close();
    }
}
/// A weak handle to a [`SharedObservable`], created by
/// `SharedObservable::downgrade`.
///
/// It keeps neither the state nor the observable alive; `upgrade` returns a
/// `SharedObservable` again only while at least one such handle still exists.
pub struct WeakObservable<T, L: Lock = SyncLock> {
    /// Weak reference to the shared, lock-protected state.
    state: Weak<L::RwLock<ObservableState<T>>>,
    /// Weak reference to the observable-handle counter; it fails to upgrade
    /// once every `SharedObservable` handle is gone.
    _num_clones: Weak<()>,
}
impl<T, L: Lock> WeakObservable<T, L> {
    /// Attempts to turn this weak handle back into a [`SharedObservable`].
    ///
    /// Returns `None` if the state has been deallocated, or if no
    /// `SharedObservable` handle is alive anymore, even when subscribers still
    /// keep the state itself alive.
    pub fn upgrade(&self) -> Option<SharedObservable<T, L>> {
        // Upgrade the state first; if the counter then fails to upgrade, the
        // temporary state reference is simply dropped again.
        let state = self.state.upgrade()?;
        self._num_clones
            .upgrade()
            .map(|_num_clones| SharedObservable { state, _num_clones })
    }
}
impl<T, L: Lock> Clone for WeakObservable<T, L> {
fn clone(&self) -> Self {
Self { state: self.state.clone(), _num_clones: self._num_clones.clone() }
}
}
/// Prints `WeakObservable { .. }` without inspecting the (possibly dead)
/// referenced state, so no bounds on `T` or the lock are needed.
impl<T, L: Lock> fmt::Debug for WeakObservable<T, L> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("WeakObservable");
        builder.finish_non_exhaustive()
    }
}
/// A guard holding the write lock on a [`SharedObservable`]'s state.
///
/// It dereferences to the current value. The associated functions
/// (`ObservableWriteGuard::set`, `update`, ...) modify the value while the
/// lock is held. The lock is released when the guard is dropped.
#[must_use]
#[clippy::has_significant_drop]
pub struct ObservableWriteGuard<'a, T: 'a, L: Lock = SyncLock> {
    /// The underlying write guard of the lock type `L`.
    inner: L::RwLockWriteGuard<'a, ObservableState<T>>,
}
impl<'a, T: 'a, L: Lock> ObservableWriteGuard<'a, T, L> {
fn new(inner: L::RwLockWriteGuard<'a, ObservableState<T>>) -> Self {
Self { inner }
}
pub fn set(this: &mut Self, value: T) -> T {
this.inner.set(value)
}
pub fn set_if_not_eq(this: &mut Self, value: T) -> Option<T>
where
T: PartialEq,
{
this.inner.set_if_not_eq(value)
}
pub fn set_if_hash_not_eq(this: &mut Self, value: T) -> Option<T>
where
T: Hash,
{
this.inner.set_if_hash_not_eq(value)
}
pub fn take(this: &mut Self) -> T
where
T: Default,
{
Self::set(this, T::default())
}
pub fn update(this: &mut Self, f: impl FnOnce(&mut T)) {
this.inner.update(f);
}
pub fn update_if(this: &mut Self, f: impl FnOnce(&mut T) -> bool) {
this.inner.update_if(f);
}
}
impl<T: fmt::Debug> fmt::Debug for ObservableWriteGuard<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.inner.fmt(f)
}
}
impl<T, L: Lock> ops::Deref for ObservableWriteGuard<'_, T, L> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.inner.get()
}
}