use loom::sync::Arc;
use loom::sync::atomic::{AtomicU32, Ordering};
use loom::thread;
use std::marker::PhantomData;
// Layout of the 32-bit node version word modelled here:
//   bit 0        LOCK_BIT       - a writer holds the node
//   bit 1        INSERTING_BIT  - writer has an insert in progress (dirty)
//   bit 2        SPLITTING_BIT  - writer has a split in progress (dirty)
//   bits 3..=8   insert counter, unit VINSERT_LOWBIT
//   bits 9..     split counter, unit VSPLIT_LOWBIT (below UNUSED1_BIT)
//   bit 28       UNUSED1_BIT    - cleared by both unlock masks
//   bit 29       not used by this model
//   bit 30       ROOT_BIT
//   bit 31       ISLEAF_BIT
const LOCK_BIT: u32 = 0x0000_0001;
const INSERTING_BIT: u32 = 0x0000_0002;
const SPLITTING_BIT: u32 = 0x0000_0004;
/// Either "modification in progress" bit.
const DIRTY_MASK: u32 = SPLITTING_BIT | INSERTING_BIT;
const VINSERT_LOWBIT: u32 = 0x0000_0008;
const VSPLIT_LOWBIT: u32 = 0x0000_0200;
const UNUSED1_BIT: u32 = 0x1000_0000;
const ROOT_BIT: u32 = 0x4000_0000;
const ISLEAF_BIT: u32 = 0x8000_0000;
/// Mask applied when releasing after a split: drops lock, dirty bits and the
/// insert counter (everything below `VSPLIT_LOWBIT`), plus `UNUSED1_BIT` and `ROOT_BIT`.
const SPLIT_UNLOCK_MASK: u32 = !(VSPLIT_LOWBIT - 1) & !UNUSED1_BIT & !ROOT_BIT;
/// Mask applied on an ordinary release: drops lock and dirty bits
/// (everything below `VINSERT_LOWBIT`) plus `UNUSED1_BIT`.
const UNLOCK_MASK: u32 = !(VINSERT_LOWBIT - 1) & !UNUSED1_BIT;
/// Loom-instrumented model of a tree node's 32-bit version word.
///
/// Lock state, dirty bits, insert/split counters and node flags are all
/// packed into the single atomic `value`.
struct LoomNodeVersion {
    /// Packed version word; bit layout is given by the `*_BIT` / `*_LOWBIT` constants.
    value: AtomicU32,
}
/// Guard for a locked [`LoomNodeVersion`]; releasing the lock happens in its `Drop` impl.
struct LoomLockGuard<'a> {
    /// The version word this guard holds locked.
    version: &'a LoomNodeVersion,
    /// The holder's copy of the version word: the value installed when the
    /// lock was taken, plus any dirty bits added since (e.g. by `mark_insert`).
    /// `Drop` derives the released version from this.
    locked_value: u32,
    /// A raw-pointer marker makes the guard neither `Send` nor `Sync`, so it
    /// cannot leave the thread that acquired the lock.
    _marker: PhantomData<*mut ()>,
}
impl Drop for LoomLockGuard<'_> {
    /// Releases the lock by publishing the next version word with a release store.
    ///
    /// - Split marked: add one `VSPLIT_LOWBIT` and apply `SPLIT_UNLOCK_MASK`,
    ///   which also zeroes the insert counter and clears `ROOT_BIT`.
    /// - Otherwise: add one `VINSERT_LOWBIT` only if an insert was marked
    ///   (`INSERTING_BIT << 2 == VINSERT_LOWBIT`), then apply `UNLOCK_MASK`.
    fn drop(&mut self) {
        let held = self.locked_value;
        let released = match held & SPLITTING_BIT {
            0 => {
                // Either VINSERT_LOWBIT or 0, depending on whether an insert was marked.
                let insert_bump = (held & INSERTING_BIT) << 2;
                (held + insert_bump) & UNLOCK_MASK
            }
            _ => (held + VSPLIT_LOWBIT) & SPLIT_UNLOCK_MASK,
        };
        self.version.value.store(released, Ordering::Release);
    }
}
impl LoomLockGuard<'_> {
    /// Flags the locked node as having an insert in progress.
    ///
    /// Publishes `INSERTING_BIT` in the shared version word, so optimistic
    /// readers (`stable` / `has_changed`) can detect the modification. It also
    /// records the bit in `locked_value`, so that `Drop` bumps the insert
    /// counter when the lock is released.
    fn mark_insert(&mut self) {
        let current: u32 = self.version.value.load(Ordering::Relaxed);
        self.version
            .value
            .store(current | INSERTING_BIT, Ordering::Release);
        // Writer half of the seqlock protocol. A *release* fence placed after
        // the dirty-bit store orders that store before every subsequent data
        // write. A reader that reads such a write and then issues an acquire
        // fence is therefore guaranteed to see the dirty bit (or a later
        // version) when it re-reads the word. An acquire fence here would only
        // order earlier loads and would not provide this guarantee.
        loom::sync::atomic::fence(Ordering::Release);
        self.locked_value |= INSERTING_BIT;
    }
}
impl LoomNodeVersion {
    /// Creates an unlocked version word with zeroed counters.
    ///
    /// `is_leaf` controls `ISLEAF_BIT`; no other bit is set initially.
    fn new(is_leaf: bool) -> Self {
        let initial = if is_leaf { ISLEAF_BIT } else { 0 };
        Self {
            value: AtomicU32::new(initial),
        }
    }

    /// Returns whether `LOCK_BIT` is set in a relaxed snapshot of the word.
    fn is_locked(&self) -> bool {
        (self.value.load(Ordering::Relaxed) & LOCK_BIT) != 0
    }

    /// Spins (yielding to loom) until no dirty bit is set, then returns that snapshot.
    ///
    /// The acquire fence after the successful load orders the caller's
    /// subsequent data reads after this version read. Only the dirty bits are
    /// waited on, so the returned value may still have `LOCK_BIT` set.
    fn stable(&self) -> u32 {
        loop {
            let value: u32 = self.value.load(Ordering::Relaxed);
            if (value & DIRTY_MASK) == 0 {
                loom::sync::atomic::fence(Ordering::Acquire);
                return value;
            }
            loom::thread::yield_now();
        }
    }

    /// Acquires the lock, spinning (with yields) until it succeeds.
    ///
    /// Waits while `LOCK_BIT` or either dirty bit is set, then attempts to
    /// install `LOCK_BIT` with an acquire CAS. It retries on CAS failure,
    /// including spurious failures of the weak CAS.
    fn lock(&self) -> LoomLockGuard<'_> {
        loop {
            let value: u32 = self.value.load(Ordering::Relaxed);
            if (value & (LOCK_BIT | DIRTY_MASK)) != 0 {
                loom::thread::yield_now();
                continue;
            }
            let locked = value | LOCK_BIT;
            match self.value.compare_exchange_weak(
                value,
                locked,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => {
                    return LoomLockGuard {
                        version: self,
                        locked_value: locked,
                        _marker: PhantomData,
                    };
                }
                Err(_) => {
                    loom::thread::yield_now();
                    continue;
                }
            }
        }
    }

    /// Makes one attempt to acquire the lock.
    ///
    /// Returns `None` if `LOCK_BIT` or a dirty bit is set, or if the strong
    /// CAS loses a race with another thread.
    fn try_lock(&self) -> Option<LoomLockGuard<'_>> {
        let value: u32 = self.value.load(Ordering::Relaxed);
        if (value & (LOCK_BIT | DIRTY_MASK)) != 0 {
            return None;
        }
        let locked = value | LOCK_BIT;
        match self
            .value
            .compare_exchange(value, locked, Ordering::Acquire, Ordering::Relaxed)
        {
            Ok(_) => Some(LoomLockGuard {
                version: self,
                locked_value: locked,
                _marker: PhantomData,
            }),
            Err(_) => None,
        }
    }

    /// Returns whether the word has changed since `old` in a way that
    /// invalidates an optimistic read started from `old`.
    ///
    /// Only a difference in `LOCK_BIT` is tolerated. Any other differing bit
    /// counts as a change. This includes a newly set `INSERTING_BIT` or
    /// `SPLITTING_BIT`, which means a writer is modifying the node right now.
    fn has_changed(&self, old: u32) -> bool {
        // Reader half of the seqlock protocol. The acquire fence must come
        // *before* the re-read so that the caller's preceding data reads are
        // ordered before it. An acquire load would only order later operations.
        loom::sync::atomic::fence(Ordering::Acquire);
        (old ^ self.value.load(Ordering::Relaxed)) > LOCK_BIT
    }
}
#[test]
fn test_loom_mutual_exclusion() {
    loom::model(|| {
        let version = Arc::new(LoomNodeVersion::new(true));
        let counter = Arc::new(AtomicU32::new(0));

        // Each worker does a non-atomic read-then-write on the counter while
        // holding the node lock. A lost update would mean both workers were
        // inside the critical section at the same time.
        let workers: Vec<_> = (0..2)
            .map(|_| {
                let node = Arc::clone(&version);
                let shared = Arc::clone(&counter);
                thread::spawn(move || {
                    let _held = node.lock();
                    let seen: u32 = shared.load(Ordering::Relaxed);
                    shared.store(seen + 1, Ordering::Relaxed);
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    });
}
/// `try_lock` must fail while another holder owns the lock. It must succeed
/// once the lock is free, and racing `lock`/`try_lock` must leave the node unlocked.
#[test]
fn test_loom_try_lock_fails_when_held() {
    loom::model(|| {
        let version = Arc::new(LoomNodeVersion::new(true));

        // Deterministic phase: the main thread keeps the lock for the whole
        // lifetime of the contender, so the contender's single attempt must fail.
        {
            let held = version.lock();
            let contender = Arc::clone(&version);
            let attempt = thread::spawn(move || contender.try_lock().is_none());
            assert!(
                attempt.join().unwrap(),
                "try_lock succeeded while the lock was held"
            );
            drop(held);
        }

        // Once released, an uncontended attempt must succeed. The returned
        // guard is a temporary, so it unlocks again at the end of the statement.
        assert!(
            version.try_lock().is_some(),
            "try_lock failed on an unlocked node"
        );

        // Racy phase: lock and try_lock race. Whatever the interleaving,
        // every guard is dropped, so the node must end unlocked.
        let v1 = Arc::clone(&version);
        let t1 = thread::spawn(move || {
            let _guard = v1.lock();
            loom::thread::yield_now();
        });
        let v2 = Arc::clone(&version);
        let t2 = thread::spawn(move || {
            let _result = v2.try_lock();
        });
        t1.join().unwrap();
        t2.join().unwrap();
        assert!(!version.is_locked());
    });
}
#[test]
fn test_loom_version_visibility() {
    loom::model(|| {
        let version = Arc::new(LoomNodeVersion::new(true));
        let before = version.stable();

        // A completed lock / mark_insert / unlock cycle on another thread must
        // be observable as a version change after joining it.
        let inserter = {
            let node = Arc::clone(&version);
            thread::spawn(move || {
                let mut guard = node.lock();
                guard.mark_insert();
            })
        };
        inserter.join().unwrap();

        assert!(version.has_changed(before));
    });
}
#[test]
fn test_loom_stable_waits_for_dirty() {
    loom::model(|| {
        let version = Arc::new(LoomNodeVersion::new(true));

        // Writer: lock, set INSERTING_BIT via mark_insert, unlock on guard drop.
        let inserter = {
            let node = Arc::clone(&version);
            thread::spawn(move || {
                let mut guard = node.lock();
                guard.mark_insert();
            })
        };

        // Reader: `stable` must never return a snapshot with a dirty bit set.
        let reader = {
            let node = Arc::clone(&version);
            thread::spawn(move || {
                let snapshot = node.stable();
                assert_eq!(snapshot & DIRTY_MASK, 0);
            })
        };

        inserter.join().unwrap();
        reader.join().unwrap();
    });
}