use alloc::boxed::Box;
use core::{
fmt,
ops::Deref,
ptr,
sync::atomic::{AtomicPtr, Ordering},
};
/// A lazily initialized value whose initializer may run more than once under contention.
///
/// The value lives in a heap allocation referenced by `inner`; a null pointer means no value
/// has been published yet. Threads that all see the empty state may each call `f`, but only
/// one result is published and the others are dropped (see `force`).
pub struct RacyLock<T, F = fn() -> T> {
    /// Null until a value is published; afterwards holds a pointer obtained from
    /// `Box::into_raw`, which is freed when the lock is dropped.
    inner: AtomicPtr<T>,
    /// Initializer; may be invoked by several threads concurrently through a shared reference.
    f: F,
}
// Loom model showing why `RacyLock`'s `Sync` impl has to require `T: Sync`.
//
// `BadLock` copies `RacyLock`'s initialization logic, but its `Sync` impl only bounds `F`.
// That lets two threads each obtain `&RefCell<u32>` to the same published value and call
// `borrow_mut` on it concurrently; the test is marked `#[should_panic]` and expects that
// misuse to end in a panic.
#[cfg(all(loom, test))]
mod unsound_demo {
    use alloc::boxed::Box;
    use core::{
        cell::RefCell,
        ptr,
        sync::atomic::{AtomicPtr, Ordering},
    };
    // NOTE(review): `AtomicPtr`/`Ordering` above come from `core`, not `loom::sync::atomic`,
    // so loom presumably does not model these atomics' orderings — confirm this is intended.
    use loom::{hint, model::Builder, sync::Arc, thread};

    /// Deliberately unsound copy of `RacyLock`: same layout and `force` logic, weaker `Sync`.
    struct BadLock<T, F: Fn() -> T> {
        // Null until a boxed value has been published.
        inner: AtomicPtr<T>,
        // Initializer; may run on several threads that race on first access.
        f: F,
    }

    impl<T, F: Fn() -> T> BadLock<T, F> {
        /// Creates a lock with no published value.
        pub const fn new(f: F) -> Self {
            Self {
                inner: AtomicPtr::new(ptr::null_mut()),
                f,
            }
        }

        /// Returns the published value, initializing it if needed. A thread that loses the
        /// `compare_exchange` frees its own allocation and returns the winner's pointer.
        pub fn force(&self) -> &T {
            let mut p = self.inner.load(Ordering::Acquire);
            if p.is_null() {
                let v = (self.f)();
                p = Box::into_raw(Box::new(v));
                if let Err(old) = self.inner.compare_exchange(
                    ptr::null_mut(),
                    p,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                ) {
                    // `p` was never published, so this thread still uniquely owns it.
                    drop(unsafe { Box::from_raw(p) });
                    p = old;
                }
            }
            // `p` is non-null here: it was loaded non-null, freshly published, or is the
            // winner's pointer returned by the failed exchange against null.
            unsafe { &*p }
        }
    }

    impl<T, F: Fn() -> T> Drop for BadLock<T, F> {
        // `&mut self` gives exclusive access, so the published box (if any) can be freed.
        fn drop(&mut self) {
            let p = *self.inner.get_mut();
            if !p.is_null() {
                drop(unsafe { Box::from_raw(p) });
            }
        }
    }

    // INTENTIONALLY UNSOUND: neither impl bounds `T`, so e.g. `BadLock<RefCell<_>, _>` is
    // `Sync` and its `&RefCell` can be shared across threads. The test below exploits this.
    unsafe impl<T, F: Fn() -> T + Sync> Sync for BadLock<T, F> {}
    unsafe impl<T, F: Fn() -> T + Send> Send for BadLock<T, F> {}

    /// Two loom threads force the same lock and each take `borrow_mut` on the shared
    /// `RefCell`. Thread 1 holds its guard across a run of `spin_loop` hints so that thread 2's
    /// borrow can presumably be scheduled while the first guard is alive, triggering the
    /// `RefCell` "already borrowed" panic the test expects.
    #[test]
    #[should_panic]
    fn bad_sync_loom_allows_cross_thread_refcell_borrow_mut_panic() {
        let mut builder = Builder::default();
        // Bound how long loom spends exploring interleavings.
        builder.max_duration = Some(std::time::Duration::from_secs(10));
        builder.check(|| {
            let lock = Arc::new(BadLock::new(|| RefCell::new(0u32)));
            let l1 = lock.clone();
            let l2 = lock.clone();
            let t1 = thread::spawn(move || {
                let c1 = l1.force();
                // Guard stays alive until the end of this closure.
                let _g1 = c1.borrow_mut();
                for _ in 0..100 {
                    hint::spin_loop();
                }
            });
            let t2 = thread::spawn(move || {
                let c2 = l2.force();
                let _g2 = c2.borrow_mut();
            });
            // NOTE(review): join results are discarded; the expected panic presumably
            // surfaces through loom's model check rather than through `join` — confirm.
            let _ = t1.join();
            let _ = t2.join();
        });
    }
}
impl<T, F> RacyLock<T, F>
where
    F: Fn() -> T,
{
    /// Creates a lock with no published value; `f` produces the value on first access.
    pub const fn new(f: F) -> Self {
        Self {
            inner: AtomicPtr::new(ptr::null_mut()),
            f,
        }
    }

    /// Returns a reference to the stored value, running the initializer if none is published.
    ///
    /// Threads that observe the empty state at the same time may all run `f`. Exactly one
    /// result is published through `compare_exchange`; every losing thread frees the value
    /// it built and returns the published one instead.
    pub fn force(this: &RacyLock<T, F>) -> &T {
        let current = this.inner.load(Ordering::Acquire);
        if !current.is_null() {
            // SAFETY: a non-null pointer was stored by a successful `compare_exchange` below,
            // originates from `Box::into_raw`, and is only freed in `Drop`, which cannot run
            // while `this` is borrowed. The `Acquire` load synchronizes with the publishing
            // store, so the pointee is fully initialized.
            return unsafe { &*current };
        }

        let candidate = Box::into_raw(Box::new((this.f)()));
        let published = match this.inner.compare_exchange(
            ptr::null_mut(),
            candidate,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            Ok(_) => candidate,
            Err(winner) => {
                // SAFETY: `candidate` came from `Box::into_raw` just above and was never
                // published, so this thread is its sole owner.
                drop(unsafe { Box::from_raw(candidate) });
                winner
            }
        };

        // SAFETY: `published` is the non-null pointer now stored in `inner`; it stays valid
        // for as long as `this` is borrowed, by the same reasoning as the fast path above.
        unsafe { &*published }
    }
}
impl<T: Default> Default for RacyLock<T> {
#[inline]
fn default() -> RacyLock<T> {
RacyLock::new(T::default)
}
}
impl<T, F> fmt::Debug for RacyLock<T, F>
where
    T: fmt::Debug,
    F: Fn() -> T,
{
    /// Formats the published value as `RacyLock(<value>)`, or as `RacyLock(<uninit>)` when no
    /// value has been published yet.
    ///
    /// The initializer is never run, so formatting a lock does not change its state.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("RacyLock");
        // `Acquire` pairs with the publishing `compare_exchange` in `force`, so reading the
        // pointee below cannot race with its construction.
        let value_ptr = self.inner.load(Ordering::Acquire);
        if value_ptr.is_null() {
            tuple.field(&format_args!("<uninit>"));
        } else {
            // SAFETY: a non-null pointer was published by `force` from `Box::into_raw`, is
            // fully initialized (see the `Acquire` load above), and is only freed in `Drop`,
            // which cannot run while `&self` is alive.
            tuple.field(unsafe { &*value_ptr });
        }
        tuple.finish()
    }
}
impl<T, F> Deref for RacyLock<T, F>
where
F: Fn() -> T,
{
type Target = T;
#[inline]
fn deref(&self) -> &T {
RacyLock::force(self)
}
}
impl<T, F> Drop for RacyLock<T, F> {
    /// Frees the published value, if there is one.
    fn drop(&mut self) {
        let slot = *self.inner.get_mut();
        if slot.is_null() {
            // Never initialized: nothing was allocated.
            return;
        }
        // SAFETY: a non-null `inner` always comes from `Box::into_raw` in `force`, and
        // `&mut self` guarantees no outstanding `&T` borrows remain.
        drop(unsafe { Box::from_raw(slot) });
    }
}
// SAFETY: moving a `RacyLock` to another thread moves the owned `Box<T>` (if any) and the
// initializer `F`; both are then used and dropped on the receiving thread, so both must be
// `Send`.
unsafe impl<T: Send, F: Send> Send for RacyLock<T, F> {}
// SAFETY: sharing `&RacyLock` across threads lets every thread:
// - obtain `&T` to the published value, which requires `T: Sync`;
// - build a `T` that is later dropped by whichever thread drops the lock, which requires
//   `T: Send`;
// - call `(self.f)()` through a shared `&F`, possibly concurrently, which requires `F: Sync`.
// `F: Send` is NOT sufficient for the last point: an `Fn` closure that captures e.g. a `Cell`
// is `Send + !Sync` and would be mutated from several threads at once.
unsafe impl<T: Send + Sync, F: Sync> Sync for RacyLock<T, F> {}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;

    /// `Default` builds a lock that yields `T::default()`.
    #[test]
    fn deref_default() {
        let lazy: RacyLock<i32> = Default::default();
        assert_eq!(0, *lazy);
    }

    /// A `Copy` value can be read straight through `Deref`.
    #[test]
    fn deref_copy() {
        let lazy = RacyLock::new(|| 42);
        let value = *lazy;
        assert_eq!(42, value);
    }

    /// Method calls auto-deref to the inner value, so `clone` clones the `Vec`, not the lock.
    #[test]
    fn deref_clone() {
        let lazy = RacyLock::new(|| Vec::from([1, 2, 3]));
        let mut copy = lazy.clone();
        copy.push(4);
        assert_eq!(Vec::from([1, 2, 3, 4]), copy);
    }

    /// A `static` lock is initialized once and keeps returning the same address.
    #[test]
    fn deref_static() {
        static VEC: RacyLock<Vec<i32>> = RacyLock::new(|| Vec::from([1, 2, 3]));
        let first = &*VEC as *const Vec<i32>;
        (0..5).for_each(|_| {
            assert_eq!(*VEC, [1, 2, 3]);
            assert_eq!(first, &*VEC as *const Vec<i32>);
        });
    }

    /// Constructing a lock from a closure needs no type annotations.
    #[test]
    fn type_inference() {
        let _ = RacyLock::new(|| ());
    }

    #[test]
    fn is_sync_send() {
        fn requires_send_sync<X: Send + Sync>() {}
        requires_send_sync::<RacyLock<Vec<i32>>>();
    }

    #[test]
    fn is_send() {
        fn requires_send<X: Send>() {}
        requires_send::<RacyLock<i32>>();
    }

    #[test]
    fn is_sync() {
        fn requires_sync<X: Sync>() {}
        requires_sync::<RacyLock<i32>>();
    }
}