use async_lock::RwLock;
use std::{future::Future, sync::Arc};
use std::{
pin::Pin,
task::{Context, Poll},
};
/// A cell whose value is produced on demand by an async initializer and then
/// shared as an `Arc<T>`.
///
/// The slot is `None` until a value has been stored; all access goes through
/// an async `RwLock`.
#[derive(Debug)]
pub(crate) struct AsyncLazy<T> {
    /// `None` while empty, `Some(value)` once a value has been stored.
    value: RwLock<Option<Arc<T>>>,
}
impl<T> AsyncLazy<T> {
pub(crate) fn new() -> Self {
Self {
value: RwLock::new(None),
}
}
pub(crate) async fn get_or_init<F, Fut>(&self, init: F) -> Arc<T>
where
F: FnOnce() -> Fut,
Fut: Future<Output = T>,
{
{
let guard = self.value.read().await;
if let Some(ref value) = *guard {
return value.clone();
}
}
let mut guard = self.value.write().await;
if let Some(ref value) = *guard {
return value.clone();
}
let value = Arc::new(init().await);
*guard = Some(value.clone());
value
}
pub(crate) fn try_get(&self) -> Option<Arc<T>> {
self.value.try_read().and_then(|guard| guard.clone())
}
#[allow(dead_code)] pub(crate) async fn get(&self) -> Arc<T> {
loop {
{
let guard = self.value.read().await;
if let Some(ref value) = *guard {
return value.clone();
}
}
YieldOnce(false).await;
}
}
}
/// Future that is `Pending` on its first poll and `Ready` on later polls.
///
/// The flag records whether the first poll has already happened. Constructing
/// it with `true` makes it ready immediately.
#[allow(dead_code)]
struct YieldOnce(bool);

impl Future for YieldOnce {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // Mark the future as polled and learn whether it had been polled before.
        let polled_before = std::mem::replace(&mut self.0, true);
        if polled_before {
            return Poll::Ready(());
        }
        // First poll: wake the task right away so it is polled again, then
        // report `Pending` for this round.
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}
impl<T> Default for AsyncLazy<T> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::sync::Notify;

    /// The initializer runs on the first call only; a second call returns the
    /// value stored by the first and does not run its own initializer.
    #[tokio::test]
    async fn initializes_once() {
        let cell = AsyncLazy::new();
        let init_calls = Arc::new(AtomicUsize::new(0));

        let calls = Arc::clone(&init_calls);
        let first = cell
            .get_or_init(|| async move {
                calls.fetch_add(1, Ordering::SeqCst);
                42
            })
            .await;
        assert_eq!(*first, 42);
        assert_eq!(init_calls.load(Ordering::SeqCst), 1);

        let calls = Arc::clone(&init_calls);
        let second = cell
            .get_or_init(|| async move {
                calls.fetch_add(1, Ordering::SeqCst);
                100
            })
            .await;
        assert_eq!(*second, 42);
        assert_eq!(init_calls.load(Ordering::SeqCst), 1);
    }

    /// Ten tasks race on `get_or_init`; exactly one initializer body runs and
    /// every task observes its result.
    #[tokio::test]
    async fn concurrent_access_single_init() {
        let cell = Arc::new(AsyncLazy::new());
        let init_calls = Arc::new(AtomicUsize::new(0));
        let release = Arc::new(Notify::new());

        let mut tasks = Vec::new();
        for _ in 0..10 {
            let cell = Arc::clone(&cell);
            let calls = Arc::clone(&init_calls);
            let release = Arc::clone(&release);
            tasks.push(tokio::spawn(async move {
                cell.get_or_init(|| async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    // Stay inside the initializer until the test releases it.
                    release.notified().await;
                    "initialized"
                })
                .await
            }));
        }

        // Give the spawned tasks a chance to run before releasing the gate.
        tokio::task::yield_now().await;
        release.notify_one();

        for task in tasks {
            let outcome = task.await.unwrap();
            assert_eq!(*outcome, "initialized");
        }
        assert_eq!(init_calls.load(Ordering::SeqCst), 1);
    }

    /// `try_get` reports `None` on a cell that has never been initialized.
    #[tokio::test]
    async fn try_get_returns_none_before_init() {
        let cell: AsyncLazy<i32> = AsyncLazy::new();
        assert!(cell.try_get().is_none());
    }

    /// `try_get` returns the stored value once initialization has completed.
    #[tokio::test]
    async fn try_get_returns_value_after_init() {
        let cell = AsyncLazy::new();
        cell.get_or_init(|| async { 42 }).await;
        assert_eq!(*cell.try_get().unwrap(), 42);
    }

    /// `get` called while an initializer is still pending resolves to the
    /// value that initializer eventually stores.
    #[tokio::test]
    async fn get_waits_for_initialization() {
        let cell = Arc::new(AsyncLazy::new());
        let init_cell = Arc::clone(&cell);
        let release = Arc::new(Notify::new());
        let init_release = Arc::clone(&release);

        let init_task = tokio::spawn(async move {
            init_cell
                .get_or_init(|| async move {
                    init_release.notified().await;
                    42
                })
                .await
        });
        tokio::task::yield_now().await;

        let waiting_cell = Arc::clone(&cell);
        let get_task = tokio::spawn(async move { waiting_cell.get().await });
        tokio::task::yield_now().await;

        release.notify_one();

        let observed = get_task.await.unwrap();
        assert_eq!(*observed, 42);
        init_task.await.unwrap();
    }

    /// `get` called before any initializer has started still resolves once a
    /// later `get_or_init` stores a value.
    #[tokio::test]
    async fn get_waits_when_called_before_init_starts() {
        let cell = Arc::new(AsyncLazy::new());
        let waiting_cell = Arc::clone(&cell);
        let init_cell = Arc::clone(&cell);

        let get_task = tokio::spawn(async move { waiting_cell.get().await });
        tokio::task::yield_now().await;

        let init_task =
            tokio::spawn(async move { init_cell.get_or_init(|| async { 99 }).await });

        let from_get = get_task.await.unwrap();
        let from_init = init_task.await.unwrap();
        assert_eq!(*from_get, 99);
        assert_eq!(*from_init, 99);
    }
}