use std::{
borrow::{Borrow, BorrowMut},
cell::UnsafeCell,
ops::{Deref, DerefMut},
time::{Duration, Instant},
};
use crate::{AsyncEvent, CanceledError, TimeoutError};
/// A lock that guards a value of type `T` and is acquired asynchronously.
///
/// Access to the value is handed out through [`AsyncLockRef`] guards returned
/// by the locking methods.
pub struct AsyncLock<T> {
/// The protected value; it is only reached through the raw pointer returned
/// by `UnsafeCell::get` while a guard is held.
value: UnsafeCell<T>,
/// Lock state: the event is set while the lock is free and reset while a
/// guard is outstanding.
unlocked: AsyncEvent,
}
// Compile-time check that `AsyncLock<()>` implements neither `Send` nor `Sync`.
static_assertions::const_assert!(impls::impls!(AsyncLock<()>: !Send & !Sync));
impl<T> AsyncLock<T> {
    /// Creates a new lock wrapping `value`; the lock starts out free.
    pub fn new(value: T) -> Self {
        // The event doubles as the lock state: "set" means nobody holds it.
        let unlocked = AsyncEvent::new();
        unlocked.set();
        Self {
            value: UnsafeCell::new(value),
            unlocked,
        }
    }

    /// Attempts to take the lock without waiting.
    ///
    /// Returns `Some(guard)` when the lock is currently free and `None` when a
    /// guard is already outstanding.
    pub fn try_lock(&self) -> Option<AsyncLockRef<'_, T>> {
        if !self.unlocked.is_set() {
            return None;
        }
        // Resetting the event marks the lock as held until the guard drops.
        self.unlocked.reset();
        Some(AsyncLockRef { parent: self })
    }

    /// Waits for the lock, giving up when the wait on the underlying event
    /// fails (presumably once `deadline` elapses; `None` presumably means no
    /// deadline -- confirm against `AsyncEvent::wait_with_deadline`).
    ///
    /// # Errors
    ///
    /// Propagates the `TimeoutError` returned by
    /// `AsyncEvent::wait_with_deadline`.
    pub async fn lock_with_deadline(
        &self,
        deadline: Option<Instant>,
    ) -> Result<AsyncLockRef<'_, T>, TimeoutError> {
        loop {
            self.unlocked.wait_with_deadline(deadline).await?;
            // Another task may have grabbed the lock between the event firing
            // and this task resuming. In that case keep waiting against the
            // same deadline instead of reporting a timeout that has not
            // actually happened, mirroring the retry loop in `lock`.
            if let Some(guard) = self.try_lock() {
                return Ok(guard);
            }
        }
    }

    /// Like [`AsyncLock::lock_with_deadline`], with the deadline expressed as a
    /// duration measured from `crate::clock_now()`.
    ///
    /// # Errors
    ///
    /// Same as [`AsyncLock::lock_with_deadline`].
    pub async fn lock_with_timeout(
        &self,
        timeout: Option<Duration>,
    ) -> Result<AsyncLockRef<'_, T>, TimeoutError> {
        let deadline = timeout.map(|timeout| crate::clock_now() + timeout);
        self.lock_with_deadline(deadline).await
    }

    /// Waits until the lock is free and takes it.
    ///
    /// # Errors
    ///
    /// Propagates the `CanceledError` returned by `AsyncEvent::wait`.
    pub async fn lock(&self) -> Result<AsyncLockRef<'_, T>, CanceledError> {
        loop {
            // `try_lock` performs the same "is set, then reset" step, so a
            // wake-up that finds the lock already taken simply waits again.
            if let Some(guard) = self.try_lock() {
                return Ok(guard);
            }
            self.unlocked.wait().await?;
        }
    }
}
impl<T: Default> Default for AsyncLock<T> {
fn default() -> Self {
AsyncLock::new(T::default())
}
}
/// Guard handed out by the locking methods of [`AsyncLock`].
///
/// It dereferences to the protected value, and dropping it signals the
/// parent's `unlocked` event again.
pub struct AsyncLockRef<'a, T> {
/// The lock this guard was obtained from.
parent: &'a AsyncLock<T>,
}
impl<T> Deref for AsyncLockRef<'_, T> {
    type Target = T;

    /// Gives shared access to the value stored in the parent lock.
    fn deref(&self) -> &Self::Target {
        let cell = &self.parent.value;
        // SAFETY (review note): guards are only constructed right after the
        // parent's `unlocked` event has been reset, and the event is set again
        // only when the guard drops, so this guard is meant to be the sole
        // path to the cell while it is alive.
        unsafe { &*cell.get() }
    }
}
impl<T> DerefMut for AsyncLockRef<'_, T> {
    /// Gives exclusive access to the value stored in the parent lock.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let cell = &self.parent.value;
        // SAFETY (review note): guards are only constructed right after the
        // parent's `unlocked` event has been reset, and the event is set again
        // only when the guard drops, so this guard is meant to be the sole
        // path to the cell; `&mut self` additionally makes this borrow unique
        // for the guard itself.
        unsafe { &mut *cell.get() }
    }
}
impl<T> Drop for AsyncLockRef<'_, T> {
    /// Releases the lock when the guard goes out of scope.
    fn drop(&mut self) {
        let lock = self.parent;
        // Signal the `unlocked` event so the lock can be taken again;
        // `set_wake_one` presumably wakes a single waiter (per its name).
        lock.unlocked.set_wake_one();
    }
}
impl<T> Borrow<T> for AsyncLockRef<'_, T> {
fn borrow(&self) -> &T {
self.deref()
}
}
impl<T> BorrowMut<T> for AsyncLockRef<'_, T> {
    /// Mutably borrows the protected value; equivalent to going through
    /// `DerefMut`.
    fn borrow_mut(&mut self) -> &mut T {
        &mut **self
    }
}
// Tests for `AsyncLock`, run through the crate's own `#[crate::test]` attribute.
#[cfg(test)]
mod test {
use std::{rc::Rc, time::Duration};
use crate::{
AsyncEvent, AsyncLock,
operations::{self, spawn_task},
};
// Holds the lock while a spawned task tries to take it; the final value of 101
// is only reached if the spawned task's increment runs after the holder has
// written 100 and dropped its guard.
#[crate::test]
async fn async_lock_test() {
let l: Rc<AsyncLock<usize>> = Default::default();
let mut l_ref = l.lock().await.unwrap();
let task = {
let l = l.clone();
operations::spawn_task(async move {
let mut l_ref = l.lock().await.unwrap();
*l_ref += 1;
})
};
// Yield repeatedly so the spawned task gets a chance to run while the lock
// is still held here.
for _ in 0..100 {
operations::yield_io().await;
}
*l_ref = 100;
drop(l_ref);
task.await.unwrap();
assert_eq!(*l.lock().await.unwrap(), 101);
}
// Checks that `lock_with_timeout` fails while another task holds the lock and
// succeeds once that task has released it.
#[crate::test]
async fn async_timeout_test() {
let l = Rc::new(AsyncLock::new(0));
let l2 = l.clone();
// `ready` tells this test that the other task holds the lock; `done` tells
// the other task to release it.
let ready = Rc::new(AsyncEvent::new());
let ready2 = ready.clone();
let done = Rc::new(AsyncEvent::new());
let done2 = done.clone();
let other = spawn_task(async move {
let guard = l2.lock().await.unwrap();
ready2.set();
done2.wait().await.unwrap();
drop(guard);
});
ready.wait().await.unwrap();
// The lock is held by `other`, so this attempt is expected to fail.
let wait = l.lock_with_timeout(Some(Duration::from_millis(1))).await;
assert!(wait.is_err());
done.set();
other.await.unwrap();
// `other` has finished and dropped its guard, so locking must now succeed.
let wait = l
.lock_with_timeout(Some(Duration::from_millis(1)))
.await
.unwrap();
assert_eq!(*wait, 0);
}
// Drives two futures that both take the lock via `futures::join!`; the first
// sleeps while holding its guard, and both increments must be reflected in the
// final value.
#[crate::test]
async fn lock_from_parallel_futures() {
let l1 = Rc::new(AsyncLock::new(0));
let l2 = l1.clone();
let l3 = l2.clone();
let fut1 = async move {
let mut guard = l1.lock().await.unwrap();
operations::sleep(std::time::Duration::from_millis(100))
.await
.unwrap();
*guard += 2;
drop(guard);
};
let fut2 = async move {
let mut guard = l2.lock().await.unwrap();
*guard += 1;
drop(guard);
};
futures::join!(fut1, fut2);
assert_eq!(*l3.lock().await.unwrap(), 3);
}
}