use crate::art::{async_timeout, future::to, stream};
use crate::channel::{unbounded, UnboundedReceiver, UnboundedSender};
use async_lock::{Mutex, MutexGuard};
use futures::{stream::FuturesOrdered, Future, FutureExt};
use std::{fmt, time::Duration};
use tracing::warn;
#[cfg(not(async_executor_impl = "tokio"))]
use async_std::prelude::StreamExt;
#[cfg(async_executor_impl = "tokio")]
use tokio_stream::StreamExt;
/// A mutex-protected value whose writers notify subscribed tasks of changes.
///
/// Mutations made through [`SubscribableMutex::modify`], [`SubscribableMutex::set`]
/// and [`SubscribableMutex::compare_and_set`] send a `()` to every channel handed
/// out by [`SubscribableMutex::subscribe`] after the value has been updated.
#[derive(Default)]
pub struct SubscribableMutex<T: ?Sized> {
    /// Sending halves of all subscriber channels; entries whose send fails are pruned.
    subscribers: Mutex<Vec<UnboundedSender<()>>>,
    /// The protected value.
    mutex: Mutex<T>,
}
impl<T> SubscribableMutex<T> {
    /// Creates a new `SubscribableMutex` wrapping `t`, with no subscribers.
    pub fn new(t: T) -> Self {
        Self {
            mutex: Mutex::new(t),
            subscribers: Mutex::default(),
        }
    }

    /// Locks the inner mutex and returns the raw guard.
    ///
    /// Changes made through the returned guard do **not** notify subscribers,
    /// so tasks waiting in [`Self::wait_until`] will not re-check their
    /// predicate until some other change triggers a notification.
    #[deprecated(
        note = "changes made through the returned guard do not notify subscribers; prefer `modify`, `set` or `compare_and_set`"
    )]
    pub async fn lock(&self) -> MutexGuard<'_, T> {
        self.mutex.lock().await
    }

    /// Sends `()` to every subscriber and prunes subscribers whose receiving
    /// half has been dropped (detected by a failed send).
    ///
    /// The subscriber list stays locked for the whole call, so a concurrent
    /// [`Self::subscribe`] lands either entirely before or entirely after
    /// this pass.
    pub async fn notify_change_subscribers(&self) {
        let mut subscribers = self.subscribers.lock().await;

        // Record which senders are still alive first and prune afterwards, so
        // a cancellation during one of the `send` awaits removes nothing.
        let mut keep = Vec::with_capacity(subscribers.len());
        for sender in subscribers.iter() {
            keep.push(sender.send(()).await.is_ok());
        }

        // `Vec::retain` visits every element exactly once, in order, so pairing
        // it with `keep` prunes in a single O(n) pass instead of one O(n)
        // `Vec::remove` per dead subscriber. The guard is held throughout, so
        // `keep` holds exactly one flag per element.
        let mut keep = keep.into_iter();
        subscribers.retain(|_| keep.next().unwrap_or(true));
    }

    /// Registers a new subscriber and returns its receiving half.
    ///
    /// The receiver gets one `()` per notification sent after its sender has
    /// been registered. Dropping the receiver unsubscribes lazily: the stale
    /// sender is removed by the next [`Self::notify_change_subscribers`].
    pub async fn subscribe(&self) -> UnboundedReceiver<()> {
        let (sender, receiver) = unbounded();
        self.subscribers.lock().await.push(sender);
        receiver
    }

    /// Runs `cb` on the value while holding the lock, then notifies all
    /// subscribers.
    ///
    /// The notification is sent after the lock is released, whether or not
    /// `cb` actually changed the value.
    pub async fn modify<F>(&self, cb: F)
    where
        F: FnOnce(&mut T),
    {
        let mut lock = self.mutex.lock().await;
        cb(&mut *lock);
        drop(lock);
        self.notify_change_subscribers().await;
    }

    /// Replaces the value with `val`, then notifies all subscribers once the
    /// lock has been released.
    pub async fn set(&self, val: T) {
        let mut lock = self.mutex.lock().await;
        *lock = val;
        drop(lock);
        self.notify_change_subscribers().await;
    }

    /// Waits until `f` returns `true` for the current value.
    ///
    /// `f` is evaluated immediately, and then again (with the lock held) after
    /// every change notification. Subscription happens while the lock from the
    /// initial check is still held, so a change made right after that check
    /// cannot be missed.
    ///
    /// # Panics
    ///
    /// Panics if receiving on the subscription channel fails.
    pub async fn wait_until<F>(&self, mut f: F)
    where
        F: FnMut(&T) -> bool,
    {
        let receiver = {
            let lock = self.mutex.lock().await;
            if f(&*lock) {
                return;
            }
            // Subscribe before releasing the lock so no writer can update the
            // value between the check above and the subscription.
            let receiver = self.subscribe().await;
            drop(lock);
            receiver
        };
        loop {
            receiver
                .recv()
                .await
                .expect("`SubscribableMutex::wait_until` was still running when it was dropped");
            let lock = self.mutex.lock().await;
            if f(&*lock) {
                return;
            }
        }
    }

    /// Body of the second future returned by [`Self::wait_until_with_trigger`].
    ///
    /// Subscribes, signals `ready_chan`, then waits for change notifications,
    /// returning after the first one for which `f` holds. `f` is not evaluated
    /// before the first notification arrives.
    ///
    /// # Panics
    ///
    /// Panics if receiving on the subscription channel fails.
    async fn wait_until_with_trigger_inner<'a, F>(
        &self,
        mut f: F,
        ready_chan: futures::channel::oneshot::Sender<()>,
    ) where
        F: FnMut(&T) -> bool + 'a,
    {
        let receiver = self.subscribe().await;
        // The ready signal is best-effort: if the receiving half is gone there
        // is nobody to tell, so only log it.
        if ready_chan.send(()).is_err() {
            warn!("unable to notify that channel is ready");
        }
        loop {
            receiver.recv().await.expect(
                "`SubscribableMutex::wait_until_with_trigger` was still running when it was dropped",
            );
            let lock = self.mutex.lock().await;
            if f(&*lock) {
                return;
            }
            drop(lock);
        }
    }

    /// Returns an ordered stream of two futures:
    ///
    /// 1. completes once the waiter has subscribed to change notifications
    ///    (a cancelled oneshot is also mapped to `()`); the caller can then
    ///    trigger the change it wants to observe;
    /// 2. completes after a change notification for which `f` returns `true`.
    ///
    /// Nothing happens until the stream is polled.
    pub fn wait_until_with_trigger<'a, F>(
        &'a self,
        f: F,
    ) -> FuturesOrdered<impl Future<Output = ()> + 'a>
    where
        F: FnMut(&T) -> bool + 'a,
    {
        let (s, r) = futures::channel::oneshot::channel::<()>();
        let mut result = FuturesOrdered::new();
        // `left_future` / `right_future` wrap both futures in the same `Either`
        // type so they can share one `FuturesOrdered`.
        let f1 = r.map(|_| ()).left_future();
        let f2 = self.wait_until_with_trigger_inner(f, s).right_future();
        result.push_back(f1);
        result.push_back(f2);
        result
    }

    /// [`Self::wait_until_with_trigger`] wrapped in the executor's
    /// `StreamExt::timeout` adapter.
    ///
    /// NOTE(review): both executors' adapters appear to bound the wait for
    /// each item separately (i.e. each of the two stages gets `timeout`);
    /// confirm against the selected executor.
    pub fn wait_timeout_until_with_trigger<'a, F>(
        &'a self,
        timeout: Duration,
        f: F,
    ) -> stream::to::Timeout<FuturesOrdered<impl Future<Output = ()> + 'a>>
    where
        F: FnMut(&T) -> bool + 'a,
    {
        self.wait_until_with_trigger(f).timeout(timeout)
    }

    /// [`Self::wait_until`] bounded by `timeout`.
    ///
    /// Returns `Ok(())` once `f` holds, or `Err` if it has not become `true`
    /// within `timeout`.
    pub async fn wait_timeout_until<F>(&self, timeout: Duration, f: F) -> to::Result<()>
    where
        F: FnMut(&T) -> bool,
    {
        async_timeout(timeout, self.wait_until(f)).await
    }
}
impl<T: PartialEq> SubscribableMutex<T> {
    /// Replaces the value with `set` only if it currently equals `compare`.
    ///
    /// Subscribers are notified (after the lock has been released) only when
    /// the replacement happened; a failed comparison sends nothing.
    pub async fn compare_and_set(&self, compare: T, set: T) {
        let swapped = {
            let mut guard = self.mutex.lock().await;
            let matches = *guard == compare;
            if matches {
                *guard = set;
            }
            matches
        };
        if swapped {
            self.notify_change_subscribers().await;
        }
    }
}
impl<T: Clone> SubscribableMutex<T> {
    /// Returns a clone of the current value; the lock is held only while cloning.
    pub async fn cloned(&self) -> T {
        let guard = self.mutex.lock().await;
        T::clone(&guard)
    }
}
impl<T: Copy> SubscribableMutex<T> {
    /// Returns a copy of the current value.
    pub async fn copied(&self) -> T {
        let guard = self.mutex.lock().await;
        *guard
    }
}
impl<T: fmt::Debug> fmt::Debug for SubscribableMutex<T> {
    /// Formats as `SubscribableMutex { data: .. }`.
    ///
    /// The inner mutex is probed with `try_lock`, so formatting never waits.
    /// If the mutex is currently held, `<locked>` is printed in place of the
    /// value.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Placeholder rendered instead of the data while the mutex is held.
        struct LockedMarker;

        impl fmt::Debug for LockedMarker {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("<locked>")
            }
        }

        let maybe_guard = self.mutex.try_lock();
        let mut debug = f.debug_struct("SubscribableMutex");
        match maybe_guard.as_deref() {
            Some(data) => debug.field("data", data),
            None => debug.field("data", &LockedMarker),
        };
        debug.finish()
    }
}
#[cfg(test)]
mod tests {
// Unit tests for `SubscribableMutex`. Each test runs on a two-worker
// multi-threaded tokio runtime when built with `async_executor_impl = "tokio"`,
// and under `async_std::test` otherwise.
use super::SubscribableMutex;
use crate::art::{async_sleep, async_spawn, async_timeout};
use std::{sync::Arc, time::Duration};
/// `wait_timeout_until` returns `Ok(())` once the predicate holds: a spawned
/// task sets the values 0..=10, sleeping 100ms before each, so 10 is reached
/// after roughly 1.1s, inside the 2s timeout.
#[cfg_attr(
async_executor_impl = "tokio",
tokio::test(flavor = "multi_thread", worker_threads = 2)
)]
#[cfg_attr(not(async_executor_impl = "tokio"), async_std::test)]
async fn test_wait_timeout_until() {
let mutex: Arc<SubscribableMutex<usize>> = Arc::default();
{
let mutex = Arc::clone(&mutex);
async_spawn(async move {
for i in 0..=10 {
async_sleep(Duration::from_millis(100)).await;
mutex.set(i).await;
}
});
}
let result = mutex
.wait_timeout_until(Duration::from_secs(2), |s| *s == 10)
.await;
assert_eq!(result, Ok(()));
assert_eq!(mutex.copied().await, 10);
}
/// `wait_timeout_until` returns an error when the predicate never holds: the
/// spawned task only sets 0..10 (never 10), so the 2s timeout fires. The final
/// check assumes the writer (about 1s of sleeps) has finished by then, leaving 9.
#[cfg_attr(
async_executor_impl = "tokio",
tokio::test(flavor = "multi_thread", worker_threads = 2)
)]
#[cfg_attr(not(async_executor_impl = "tokio"), async_std::test)]
async fn test_wait_timeout_until_fail() {
let mutex: Arc<SubscribableMutex<usize>> = Arc::default();
{
let mutex = Arc::clone(&mutex);
async_spawn(async move {
for i in 0..10 {
async_sleep(Duration::from_millis(100)).await;
mutex.set(i).await;
}
});
}
let result = mutex
.wait_timeout_until(Duration::from_secs(2), |s| *s == 10)
.await;
assert!(result.is_err());
assert_eq!(mutex.copied().await, 9);
}
/// `compare_and_set` swaps and notifies only when the current value matches
/// `compare`: 5 -> 10 succeeds with one notification; comparing against 5
/// again leaves 10 in place and sends nothing.
#[cfg_attr(
async_executor_impl = "tokio",
tokio::test(flavor = "multi_thread", worker_threads = 2)
)]
#[cfg_attr(not(async_executor_impl = "tokio"), async_std::test)]
async fn test_compare_and_set() {
let mutex = SubscribableMutex::new(5usize);
let subscriber = mutex.subscribe().await;
assert_eq!(mutex.copied().await, 5);
mutex.compare_and_set(5, 10).await;
assert_eq!(mutex.copied().await, 10);
assert!(subscriber.try_recv().is_ok());
mutex.compare_and_set(5, 20).await;
assert_eq!(mutex.copied().await, 10);
assert!(subscriber.try_recv().is_err());
}
/// Each `set` delivers one notification to a live subscriber, and a dropped
/// receiver is pruned from `subscribers` by the next `notify_change_subscribers`.
#[cfg_attr(
async_executor_impl = "tokio",
tokio::test(flavor = "multi_thread", worker_threads = 2)
)]
#[cfg_attr(not(async_executor_impl = "tokio"), async_std::test)]
async fn test_subscriber() {
let mutex = SubscribableMutex::new(5usize);
let subscriber = mutex.subscribe().await;
assert!(subscriber.try_recv().is_err());
mutex.set(10).await;
assert_eq!(subscriber.try_recv(), Ok(()));
mutex.set(20).await;
// `set` has already awaited the send, so the notification is queued; the
// short timeout only guards against a hang.
assert_eq!(
async_timeout(Duration::from_millis(10), subscriber.recv()).await,
Ok(Ok(()))
);
assert_eq!(mutex.subscribers.lock().await.len(), 1);
drop(subscriber);
// The send to the dropped receiver fails, so this call removes its sender.
mutex.notify_change_subscribers().await;
assert_eq!(mutex.subscribers.lock().await.len(), 0);
}
}