#![doc(html_root_url = "https://docs.rs/async-double-checked-cell/0.1.0")]
#![warn(missing_debug_implementations)]
use std::cell::UnsafeCell;
use std::future::Future;
use std::panic::RefUnwindSafe;
use std::sync::atomic::{AtomicBool, Ordering};
use futures_util::future::ready;
use futures_util::FutureExt;
use futures_util::lock::Mutex;
use unreachable::UncheckedOptionExt;
use void::ResultVoidExt;
/// A cell that is initialized at most once, where the initial value is
/// produced by a future.
///
/// Uses double-checked locking: once initialized, reads only consult an
/// atomic flag; initialization attempts are serialized through an async mutex.
#[derive(Debug)]
pub struct DoubleCheckedCell<T> {
    // The stored value; `None` until an initializer succeeds.
    value: UnsafeCell<Option<T>>,
    // Becomes `true` once `value` holds `Some`; never reset afterwards.
    initialized: AtomicBool,
    // Serializes initializers; protects no data of its own.
    lock: Mutex<()>,
}
impl<T> Default for DoubleCheckedCell<T> {
fn default() -> DoubleCheckedCell<T> {
DoubleCheckedCell::new()
}
}
impl<T> DoubleCheckedCell<T> {
    /// Creates a new, uninitialized cell.
    pub fn new() -> DoubleCheckedCell<T> {
        DoubleCheckedCell {
            value: UnsafeCell::new(None),
            initialized: AtomicBool::new(false),
            lock: Mutex::new(()),
        }
    }

    /// Borrows the value if the cell has been initialized, otherwise returns
    /// `None`.
    ///
    /// This never initializes the cell. If the cell is not yet initialized,
    /// this waits for the internal lock, so it observes the outcome of any
    /// initializer that is currently running.
    pub async fn get(&self) -> Option<&T> {
        // An initializer that always fails can never store a value, so the
        // only `Ok` outcome is a cell that was already initialized.
        self.get_or_try_init(ready(Err(()))).await.ok()
    }

    /// Borrows the value, initializing the cell with `init` if it is empty.
    ///
    /// `init` is awaited only if the cell is still uninitialized after the
    /// internal lock has been acquired; otherwise it is dropped unpolled.
    /// The lock is held while `init` runs, so `init` must not access this
    /// same cell (it would wait on the lock forever).
    pub async fn get_or_init<Fut>(&self, init: Fut) -> &T
    where
        Fut: Future<Output = T>,
    {
        // `Void` as the error type makes the `Err` case uninhabited.
        self.get_or_try_init(init.map(Ok)).await.void_unwrap()
    }

    /// Borrows the value, initializing the cell with the fallible future
    /// `init` if it is empty.
    ///
    /// `init` is awaited only if the cell is still uninitialized after the
    /// internal lock has been acquired. The lock is held while `init` runs,
    /// so `init` must not access this same cell.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `init`. The cell then stays
    /// uninitialized, so a later call may retry. If the returned future is
    /// dropped before `init` completes, the cell also stays uninitialized.
    pub async fn get_or_try_init<Fut, E>(&self, init: Fut) -> Result<&T, E>
    where
        Fut: Future<Output = Result<T, E>>,
    {
        // Fast path. The Acquire load pairs with the Release store below, so
        // observing `true` also makes the earlier write to `value` visible.
        if !self.initialized.load(Ordering::Acquire) {
            let _guard = self.lock.lock().await;
            // Relaxed is enough here: acquiring the lock already orders this
            // load after any initialization done by a previous lock holder.
            if !self.initialized.load(Ordering::Relaxed) {
                let result = init.await?;
                // SAFETY: we hold `lock` and `initialized` is false, so no
                // other task is writing `value`, and no shared reference to
                // it exists (references are only created once `initialized`
                // is true).
                unsafe { *self.value.get() = Some(result) };
                self.initialized.store(true, Ordering::Release);
            }
        }

        // SAFETY: `initialized` is true here, either observed with Acquire
        // ordering (directly or through the lock) or set by this task. After
        // that point `value` is never written again, so shared borrows are
        // sound.
        let value = unsafe { &*self.value.get() };
        // SAFETY: `value` is only set to `Some` before `initialized` becomes
        // true, and cells built via `From` start out as `Some`.
        Ok(unsafe { value.as_ref().unchecked_unwrap() })
    }

    /// Consumes the cell, returning the value if it was initialized.
    pub fn into_inner(self) -> Option<T> {
        // Owning `self` rules out outstanding borrows, and
        // `UnsafeCell::into_inner` is a safe function, so no `unsafe` is needed.
        self.value.into_inner()
    }
}
impl<T> From<T> for DoubleCheckedCell<T> {
    /// Builds a cell that is already initialized with `t`.
    fn from(t: T) -> Self {
        let value = UnsafeCell::new(Some(t));
        let initialized = AtomicBool::new(true);
        let lock = Mutex::new(());
        Self { value, initialized, lock }
    }
}
// SAFETY: through a shared `&DoubleCheckedCell<T>`, any thread can borrow
// `&T` (hence `T: Sync`). Any thread can also run the initializer and move
// the `T` it produced into the cell, where it may later be dropped or taken
// by another thread (hence `T: Send`). `value` is written only while holding
// `lock` with `initialized` false, and is borrowed only once `initialized`
// is true.
unsafe impl<T: Send + Sync> Sync for DoubleCheckedCell<T> {}
// A panic inside an initializer leaves the cell uninitialized: the value is
// stored only after the initializer future has returned.
// NOTE(review): this impl is unconditional, so it also applies to `T` that
// is not `RefUnwindSafe` (e.g. types with interior mutability). std's
// `OnceLock` requires `T: RefUnwindSafe + UnwindSafe` instead. Confirm this
// is intended.
impl<T> RefUnwindSafe for DoubleCheckedCell<T> {}
#[cfg(test)]
mod tests {
    use std::rc::Rc;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use futures_util::future::join_all;
    use super::*;

    /// Dropping the cell must drop the value it stores.
    #[tokio::test]
    async fn test_drop() {
        let rc = Rc::new(true);
        assert_eq!(Rc::strong_count(&rc), 1);
        {
            let cell = DoubleCheckedCell::new();
            cell.get_or_init(ready(rc.clone())).await;
            assert_eq!(Rc::strong_count(&rc), 2);
        }
        assert_eq!(Rc::strong_count(&rc), 1);
    }

    /// Under many concurrent callers, the initializer runs exactly once and
    /// every task sees the stored value.
    #[tokio::test(threaded_scheduler)]
    async fn test_threading() {
        let n = Arc::new(AtomicUsize::new(0));
        let cell = Arc::new(DoubleCheckedCell::new());
        let join_handles = (0..1000)
            .map(|_| {
                let n = n.clone();
                let cell = cell.clone();
                tokio::task::spawn(async move {
                    let value = cell
                        .get_or_init(async {
                            n.fetch_add(1, Ordering::Relaxed);
                            true
                        })
                        .await;
                    assert!(*value);
                })
            })
            .collect::<Vec<_>>();
        // Unwrap every join result: a panic inside a spawned task (such as a
        // failed `assert!` above) is reported as a `JoinError`, which would
        // otherwise be discarded and let the test pass.
        for result in join_all(join_handles).await {
            result.expect("spawned task panicked");
        }
        assert_eq!(n.load(Ordering::SeqCst), 1);
    }

    /// Compile-time checks: the cell is `Sync` and `Send`, and so is the
    /// future returned by `get_or_init` (that future is never polled here).
    #[test]
    fn test_sync_send() {
        fn assert_sync<T: Sync>(_: T) {}
        fn assert_send<T: Send>(_: T) {}
        assert_sync(DoubleCheckedCell::<usize>::new());
        assert_send(DoubleCheckedCell::<usize>::new());
        let cell = DoubleCheckedCell::<usize>::new();
        assert_send(cell.get_or_init(async { 1 }));
    }

    // Compile-time check that `DoubleCheckedCell<usize>` can be boxed.
    // NOTE(review): the name says "object safe", which does not apply to a
    // struct; confirm what this was meant to assert.
    struct _AssertObjectSafe(Box<DoubleCheckedCell<usize>>);
}