#[cfg(all(not(feature = "std"), target_has_atomic = "ptr"))]
use core::cell::UnsafeCell;
#[cfg(all(not(feature = "std"), target_has_atomic = "ptr"))]
use core::mem::MaybeUninit;
/// A write-once cache for `Copy` values, with a backend chosen at compile time:
///
/// - `std`: delegates to `std::sync::OnceLock`.
/// - `no_std` with atomic support: a hand-rolled spin-synchronized once-cell
///   (`state` + `value` below).
/// - `no_std` without atomic support: stores nothing; `get_or_init` simply
///   recomputes on every call.
pub struct OnceCache<T: Copy> {
    // std backend: OnceLock provides all once-init synchronization.
    #[cfg(feature = "std")]
    inner: std::sync::OnceLock<T>,
    // no_std atomic backend: 0 = uninit, 1 = initializing, 2 = ready
    // (the UNINIT/INITING/READY consts on the impl block).
    #[cfg(all(not(feature = "std"), target_has_atomic = "ptr"))]
    state: core::sync::atomic::AtomicU8,
    // Slot guarded by `state`; only valid to read after `state` is READY.
    #[cfg(all(not(feature = "std"), target_has_atomic = "ptr"))]
    value: UnsafeCell<MaybeUninit<T>>,
    // Fallback backend: no storage. NOTE(review): the raw-pointer PhantomData
    // suppresses the auto Send/Sync impls — presumably deliberate so the manual
    // `unsafe impl Sync` below is the only source of Sync; confirm intent.
    #[cfg(all(not(feature = "std"), not(target_has_atomic = "ptr")))]
    _marker: core::marker::PhantomData<*const T>,
}
// SAFETY: sharing `&OnceCache<T>` lets `get_or_init` copy a `T` that was
// produced and stored by one thread out to any other thread, which is a
// cross-thread transfer of `T`. `T: Send` is therefore required in addition to
// `T: Sync` (the previous `Copy + Sync` bound was unsound for `!Send` types);
// this mirrors std's `impl<T: Send + Sync> Sync for OnceLock<T>`, so the bound
// is also consistent with the `std` backend. The `state` atomic serializes all
// writes to `value`.
#[cfg(all(not(feature = "std"), target_has_atomic = "ptr"))]
#[allow(unsafe_code)]
unsafe impl<T: Copy + Send + Sync> Sync for OnceCache<T> {}
// SAFETY: this configuration stores no `T` at all (`get_or_init` recomputes on
// every call), so sharing `&OnceCache<T>` across threads is trivially safe.
// The same `Send + Sync` bound is used so the public API is identical across
// all three backends.
#[cfg(all(not(feature = "std"), not(target_has_atomic = "ptr")))]
#[allow(unsafe_code)]
unsafe impl<T: Copy + Send + Sync> Sync for OnceCache<T> {}
impl<T: Copy> OnceCache<T> {
    /// State: value not yet initialized; the slot is free to claim.
    #[cfg(all(not(feature = "std"), target_has_atomic = "ptr"))]
    const UNINIT: u8 = 0;
    /// State: a thread has claimed the slot and is running the initializer.
    #[cfg(all(not(feature = "std"), target_has_atomic = "ptr"))]
    const INITING: u8 = 1;
    /// State: the value has been written and may be read freely.
    #[cfg(all(not(feature = "std"), target_has_atomic = "ptr"))]
    const READY: u8 = 2;

    /// Creates an empty, uninitialized cache.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            #[cfg(feature = "std")]
            inner: std::sync::OnceLock::new(),
            // Use the named constant rather than a bare 0 so the initial state
            // stays in sync with the state machine above.
            #[cfg(all(not(feature = "std"), target_has_atomic = "ptr"))]
            state: core::sync::atomic::AtomicU8::new(Self::UNINIT),
            #[cfg(all(not(feature = "std"), target_has_atomic = "ptr"))]
            value: UnsafeCell::new(MaybeUninit::uninit()),
            #[cfg(all(not(feature = "std"), not(target_has_atomic = "ptr")))]
            _marker: core::marker::PhantomData,
        }
    }

    /// Returns the cached value, initializing it with `f` on first use.
    ///
    /// At most one thread runs `f` per successful initialization; concurrent
    /// callers spin until the value is published. If `f` panics, the cache is
    /// reset to uninitialized so a later (or concurrently spinning) caller can
    /// retry with its own closure — matching `OnceLock::get_or_init` semantics.
    /// On targets without atomics (and without `std`) nothing is cached and
    /// every call runs `f`.
    #[inline]
    pub fn get_or_init(&self, f: impl FnOnce() -> T) -> T {
        #[cfg(feature = "std")]
        {
            *self.inner.get_or_init(f)
        }
        #[cfg(all(not(feature = "std"), target_has_atomic = "ptr"))]
        {
            use core::sync::atomic::{AtomicU8, Ordering};

            /// Drop guard: if the initializer unwinds while we hold the
            /// INITING claim, put the state back to UNINIT so spinning threads
            /// do not deadlock and one of them can retry.
            struct ResetOnDrop<'a>(&'a AtomicU8);
            impl Drop for ResetOnDrop<'_> {
                fn drop(&mut self) {
                    // 0 == OnceCache::UNINIT (the generic impl's associated
                    // const is not nameable from this nested impl).
                    self.0.store(0, Ordering::Release);
                }
            }

            // `f` is FnOnce but lives in a retry loop, so thread it through an
            // Option; it is taken exactly once, on the claim-winning path.
            let mut f = Some(f);
            loop {
                let state = self.state.load(Ordering::Acquire);
                if state == Self::READY {
                    // SAFETY: READY is only published (Release) after `value`
                    // has been fully written, and this Acquire load
                    // synchronizes with that store.
                    #[allow(unsafe_code)]
                    return unsafe { (*self.value.get()).assume_init() };
                }
                if state == Self::UNINIT
                    && self
                        .state
                        .compare_exchange(
                            Self::UNINIT,
                            Self::INITING,
                            Ordering::AcqRel,
                            Ordering::Acquire,
                        )
                        .is_ok()
                {
                    // We won the claim; arm the panic guard before calling `f`.
                    let guard = ResetOnDrop(&self.state);
                    let value = (f.take().expect("initializer already consumed"))();
                    // SAFETY: we hold the INITING claim, so no other thread
                    // reads or writes `value` until READY is published below.
                    #[allow(unsafe_code)]
                    unsafe {
                        (*self.value.get()).write(value);
                    }
                    // Disarm the guard first; nothing between here and the
                    // store can panic, so READY is published exactly once.
                    core::mem::forget(guard);
                    self.state.store(Self::READY, Ordering::Release);
                    return value;
                }
                // Another thread holds INITING (or just stole the claim):
                // wait for it to either publish READY or reset on panic.
                core::hint::spin_loop();
            }
        }
        #[cfg(all(not(feature = "std"), not(target_has_atomic = "ptr")))]
        {
            f()
        }
    }
}
impl<T: Copy> Default for OnceCache<T> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_once_cache_basic() {
        static CACHE: OnceCache<u64> = OnceCache::new();
        let mut call_count = 0;
        let value = CACHE.get_or_init(|| {
            call_count += 1;
            42u64
        });
        assert_eq!(value, 42);
        let value2 = CACHE.get_or_init(|| {
            call_count += 1;
            99u64
        });
        // With std or atomics the cache memoizes: the second closure never
        // runs and the first value is returned again.
        #[cfg(any(feature = "std", target_has_atomic = "ptr"))]
        {
            assert_eq!(value2, 42);
            assert_eq!(call_count, 1);
        }
        // The no-atomics fallback caches nothing: every call runs its closure,
        // so the second call yields the second closure's value. (Previously
        // `value2 == 42` was asserted unconditionally, which fails here.)
        #[cfg(all(not(feature = "std"), not(target_has_atomic = "ptr")))]
        {
            assert_eq!(value2, 99);
            assert_eq!(call_count, 2);
        }
    }

    #[test]
    fn test_once_cache_default() {
        // `Default` must behave exactly like `new()`; holds on every backend
        // since the first call always returns its own closure's value.
        let cache: OnceCache<u32> = OnceCache::default();
        let value = cache.get_or_init(|| 123);
        assert_eq!(value, 123);
    }

    #[cfg(feature = "std")]
    #[allow(clippy::std_instead_of_core, clippy::std_instead_of_alloc)]
    mod threading_tests {
        use std::{
            sync::atomic::{AtomicUsize, Ordering},
            thread,
            vec::Vec,
        };

        use super::*;

        #[test]
        fn test_once_cache_concurrent_init() {
            static CALL_COUNT: AtomicUsize = AtomicUsize::new(0);
            static CACHE: OnceCache<u64> = OnceCache::new();
            // Hammer the cache from 10 threads x 100 calls; exactly one
            // initializer run must be observed and every read must agree.
            let handles: Vec<thread::JoinHandle<()>> = (0..10)
                .map(|_| {
                    thread::spawn(|| {
                        for _ in 0..100 {
                            let value = CACHE.get_or_init(|| {
                                CALL_COUNT.fetch_add(1, Ordering::SeqCst);
                                42u64
                            });
                            assert_eq!(value, 42);
                        }
                    })
                })
                .collect();
            for handle in handles {
                handle.join().unwrap();
            }
            assert_eq!(CALL_COUNT.load(Ordering::SeqCst), 1);
        }
    }
}