use std::future::Future;
use std::mem::replace;
use std::ops::Deref;
use thiserror::Error;
use tokio::sync::{RwLock, RwLockReadGuard};
use tokio::task::{JoinError, JoinHandle};
/// Errors raised by [`RwTaskLock`] itself, as opposed to errors produced by the user's task.
///
/// The lock's error type parameter `E` is required to implement `From<RwTaskLockError>`, so
/// these can be reported through the same `Result` as task failures.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum RwTaskLockError {
/// The spawned task could not be joined (tokio reports a `JoinError` when the task panicked
/// or was cancelled).
#[error(transparent)]
JoinError(#[from] JoinError),
/// A failure was already reported by this lock, so no value is available any more.
#[error("Attempting to access value not available due to a previously reported error.")]
CalledAfterError,
}
/// Internal lifecycle state of an [`RwTaskLock`].
enum RwTaskLockState<T, E> {
/// The value is still being produced by a spawned task; its result has not been collected yet.
Pending(JoinHandle<Result<T, E>>),
/// The value is available.
Ready(T),
/// A failure was reported; no code path leaves this state once it is reached.
/// Also written as a short-lived placeholder when the current state is moved out with
/// `mem::replace` while the write lock is held.
Error,
}
/// Guard returned by [`RwTaskLock::read`]; dereferences to the resolved value `T`.
///
/// It wraps a shared read guard over the lock's internal state, so while it is alive
/// [`RwTaskLock::update`] (which takes the write lock) has to wait.
pub struct RwTaskLockReadGuard<'a, T, E> {
// Read guard over a state that is always `Ready` when this struct is constructed.
guard: RwLockReadGuard<'a, RwTaskLockState<T, E>>,
}
impl<T, E> Deref for RwTaskLockReadGuard<'_, T, E> {
    type Target = T;

    /// Borrows the resolved value held behind the shared lock.
    ///
    /// # Panics
    ///
    /// Panics if the guarded state is not `Ready`. Guards are only created over a `Ready`
    /// state, so reaching the panic indicates an internal bug.
    fn deref(&self) -> &T {
        if let RwTaskLockState::Ready(inner) = &*self.guard {
            inner
        } else {
            unreachable!("Read guard is only constructed for Ready state")
        }
    }
}
/// An async read-write lock whose value may still be computed by a background tokio task.
///
/// There are two constructors:
/// - [`RwTaskLock::from_value`]: the value is available immediately.
/// - [`RwTaskLock::from_task`]: the value is produced by a spawned future.
///
/// [`RwTaskLock::read`] waits for a pending task the first time it is needed.
/// [`RwTaskLock::update`] schedules a new task that derives the next value from the current
/// one. After a failure has been reported, later accesses fail with
/// [`RwTaskLockError::CalledAfterError`] (converted into `E` for `read`).
///
/// `E` must be convertible from [`RwTaskLockError`] so lock-level errors can be returned
/// through the same error type as task failures.
pub struct RwTaskLock<T, E>
where
T: Send + Sync + 'static,
E: Send + Sync + 'static + From<RwTaskLockError>,
{
// Current lifecycle state: pending task, ready value, or failed.
state: RwLock<RwTaskLockState<T, E>>,
}
impl<T, E> RwTaskLock<T, E>
where
T: Send + Sync + 'static,
E: Send + Sync + 'static + From<RwTaskLockError>,
{
pub fn from_value(val: T) -> Self {
Self {
state: RwLock::new(RwTaskLockState::Ready(val)),
}
}
pub fn from_task<Fut>(fut: Fut) -> Self
where
Fut: Future<Output = Result<T, E>> + Send + 'static,
{
let task = tokio::spawn(fut);
Self {
state: RwLock::new(RwTaskLockState::Pending(task)),
}
}
pub async fn read(&self) -> Result<RwTaskLockReadGuard<'_, T, E>, E> {
{
let state = self.state.read().await;
match &*state {
RwTaskLockState::Ready(_) => {
return Ok(RwTaskLockReadGuard { guard: state });
},
RwTaskLockState::Error => return Err(E::from(RwTaskLockError::CalledAfterError)),
RwTaskLockState::Pending(_) => {},
}
}
let mut state = self.state.write().await;
match replace(&mut *state, RwTaskLockState::Error) {
RwTaskLockState::Ready(v) => {
*state = RwTaskLockState::Ready(v);
},
RwTaskLockState::Error => {
return Err(E::from(RwTaskLockError::CalledAfterError));
},
RwTaskLockState::Pending(jh) => {
match jh.await.map_err(RwTaskLockError::JoinError)? {
Ok(v) => {
*state = RwTaskLockState::Ready(v);
},
Err(e) => {
*state = RwTaskLockState::Error;
return Err(e);
},
};
},
};
Ok(RwTaskLockReadGuard {
guard: state.downgrade(),
})
}
pub async fn update<Fut, Updater>(&self, updater: Updater) -> Result<(), RwTaskLockError>
where
Updater: FnOnce(T) -> Fut + Send + 'static,
Fut: Future<Output = Result<T, E>> + Send + 'static,
{
use RwTaskLockState::*;
let mut state_lg = self.state.write().await;
let state = replace(&mut *state_lg, RwTaskLockState::Error);
match state {
Pending(jh) => {
let new_task = tokio::spawn(async move {
let current = jh.await.map_err(RwTaskLockError::JoinError)??;
updater(current).await
});
*state_lg = Pending(new_task);
Ok(())
},
Ready(v) => {
*state_lg = Pending(tokio::spawn(updater(v)));
Ok(())
},
Error => {
*state_lg = Error;
Err(RwTaskLockError::CalledAfterError)
},
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Task error type distinct from [`RwTaskLockError`], so tests can tell the task's own
    /// failure apart from errors raised later by the lock.
    #[derive(Debug)]
    enum TestError {
        Task(&'static str),
        Lock(RwTaskLockError),
    }

    impl From<RwTaskLockError> for TestError {
        fn from(err: RwTaskLockError) -> Self {
            TestError::Lock(err)
        }
    }

    #[tokio::test]
    async fn test_from_value() {
        let lock: RwTaskLock<_, RwTaskLockError> = RwTaskLock::from_value(7);
        let guard = lock.read().await.unwrap();
        assert_eq!(*guard, 7);
        let guard2 = lock.read().await.unwrap();
        assert_eq!(*guard2, 7);
    }

    #[tokio::test]
    async fn test_from_future_success() {
        let lock = RwTaskLock::from_task(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            Ok::<_, RwTaskLockError>(999)
        });
        let guard = lock.read().await.unwrap();
        assert_eq!(*guard, 999);
        let guard2 = lock.read().await.unwrap();
        assert_eq!(*guard2, 999);
    }

    #[tokio::test]
    async fn test_from_future_error() {
        let lock = RwTaskLock::<u8, TestError>::from_task(async { Err(TestError::Task("boom")) });
        // The first read surfaces the task's own error...
        let first = lock.read().await;
        assert!(matches!(first, Err(TestError::Task("boom"))));
        // ...and every later read reports that the value is gone.
        let second = lock.read().await;
        assert!(matches!(second, Err(TestError::Lock(RwTaskLockError::CalledAfterError))));
    }

    #[tokio::test]
    async fn test_task_panic_reports_join_error() {
        let lock = RwTaskLock::<u8, RwTaskLockError>::from_task(async {
            let missing: Option<u8> = None;
            Ok(missing.expect("deliberate panic inside the producing task"))
        });
        let first = lock.read().await;
        assert!(matches!(first, Err(RwTaskLockError::JoinError(ref err)) if err.is_panic()));
        let second = lock.read().await;
        assert!(matches!(second, Err(RwTaskLockError::CalledAfterError)));
    }

    #[tokio::test]
    async fn test_concurrent_read() {
        let lock = Arc::new(RwTaskLock::from_task(async {
            tokio::time::sleep(Duration::from_millis(30)).await;
            Ok::<_, RwTaskLockError>("concurrent".to_string())
        }));
        let lock1 = Arc::clone(&lock);
        let lock2 = Arc::clone(&lock);
        let (a, b) = tokio::join!(lock1.read(), lock2.read());
        assert_eq!(*a.unwrap(), "concurrent");
        assert_eq!(*b.unwrap(), "concurrent");
    }

    #[tokio::test]
    async fn test_error_then_retrieval() {
        let lock = RwTaskLock::<u8, RwTaskLockError>::from_task(async { Err(RwTaskLockError::CalledAfterError) });
        let _ = lock.read().await;
        let result = lock.read().await;
        assert!(matches!(result, Err(RwTaskLockError::CalledAfterError)));
    }

    #[tokio::test]
    async fn test_update_from_ready() {
        let lock = RwTaskLock::from_value(100);
        lock.update(|v| async move { Ok::<_, RwTaskLockError>(v + 1) }).await.unwrap();
        let guard = lock.read().await.unwrap();
        assert_eq!(*guard, 101);
    }

    #[tokio::test]
    async fn test_update_chained_pending() {
        let lock = Arc::new(RwTaskLock::from_task(async {
            tokio::time::sleep(Duration::from_millis(20)).await;
            Ok::<_, RwTaskLockError>(5)
        }));
        let lock2 = Arc::clone(&lock);
        lock2.update(|v| async move { Ok::<_, RwTaskLockError>(v * 3) }).await.unwrap();
        let guard = lock.read().await.unwrap();
        assert_eq!(*guard, 15);
    }

    #[tokio::test]
    async fn test_update_after_failed_task_propagates_task_error() {
        let lock = RwTaskLock::<i32, TestError>::from_task(async { Err(TestError::Task("first")) });
        let ran = Arc::new(AtomicBool::new(false));
        let ran_in_updater = Arc::clone(&ran);
        // `update` only schedules work, so it succeeds even though the pending task will fail.
        lock.update(move |v| {
            ran_in_updater.store(true, Ordering::SeqCst);
            async move { Ok::<_, TestError>(v + 1) }
        })
        .await
        .unwrap();
        let result = lock.read().await;
        assert!(matches!(result, Err(TestError::Task("first"))));
        // There was never a value to hand to the updater.
        assert!(!ran.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_update_error_state() {
        let lock = RwTaskLock::<i32, RwTaskLockError>::from_task(async { Err(RwTaskLockError::CalledAfterError) });
        let _ = lock.read().await;
        let result = lock.update(|v| async move { Ok::<_, RwTaskLockError>(v + 1) }).await;
        assert!(matches!(result, Err(RwTaskLockError::CalledAfterError)));
    }

    #[tokio::test]
    async fn test_update_to_error() {
        let lock = RwTaskLock::from_value(123);
        lock.update(|_v| async move { Err(RwTaskLockError::CalledAfterError) })
            .await
            .unwrap();
        let result = lock.read().await;
        assert!(matches!(result, Err(RwTaskLockError::CalledAfterError)));
    }

    #[tokio::test]
    async fn test_multiple_updates() {
        let lock = RwTaskLock::from_value(1);
        lock.update(|v| async move { Ok::<_, RwTaskLockError>(v + 10) }).await.unwrap();
        lock.update(|v| async move { Ok::<_, RwTaskLockError>(v * 2) }).await.unwrap();
        let guard = lock.read().await.unwrap();
        assert_eq!(*guard, 22);
    }
}