use std::hint::spin_loop;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
/// Value of [`HybridLatch::counter`] meaning "held exclusively by a writer".
///
/// Every other value is the number of shared holders, so the reader count is capped
/// at `WRITER - 1` (see [`HybridLatch::acquire_shared`]).
const WRITER: u32 = u32::MAX;

/// The three ways a [`HybridLatch`] can be held.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LatchMode {
    /// Lock-free read: take a version snapshot, read, then validate the snapshot.
    Optimistic,
    /// Pessimistic read hold; any number of shared holders may coexist.
    Shared,
    /// Write hold; excludes shared holders and invalidates optimistic snapshots.
    Exclusive,
}

/// A hybrid optimistic / shared / exclusive spin latch.
///
/// * `counter` is `0` when free, `1..WRITER` while held shared (the number of
///   readers), and `WRITER` while held exclusively.
/// * `version` is bumped every time an exclusive hold is released. Optimistic readers
///   use it to detect that a writer ran while they were reading.
///
/// All acquisition methods busy-wait (with [`spin_loop`]) rather than park the thread.
#[derive(Debug)]
pub struct HybridLatch {
    counter: AtomicU32,
    version: AtomicU64,
}

impl Default for HybridLatch {
    fn default() -> Self {
        Self::new()
    }
}

impl HybridLatch {
    /// Creates a free latch at version 0.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            counter: AtomicU32::new(0),
            version: AtomicU64::new(0),
        }
    }

    /// Takes an optimistic snapshot, spinning while a writer holds the latch.
    ///
    /// Returns the version to hand to [`validate`](Self::validate) once the optimistic
    /// reads are done.
    #[must_use]
    pub fn acquire_optimistic(&self) -> u64 {
        loop {
            let v = self.version.load(Ordering::Acquire);
            if self.counter.load(Ordering::Acquire) != WRITER {
                return v;
            }
            spin_loop();
        }
    }

    /// Returns `true` if no writer holds the latch now and none has released it since
    /// `snapshot` was taken. In that case, reads made after the snapshot are consistent.
    #[must_use]
    pub fn validate(&self, snapshot: u64) -> bool {
        // The caller's optimistic data reads are sequenced *before* this call, and an
        // Acquire load only orders accesses that come after it. Without this fence
        // those data reads may be satisfied after the checks below, and a torn read
        // would validate. It pairs with the release fence writers issue right after
        // taking the latch.
        std::sync::atomic::fence(Ordering::Acquire);
        if self.counter.load(Ordering::Acquire) == WRITER {
            return false;
        }
        self.version.load(Ordering::Acquire) == snapshot
    }

    /// Acquires a shared hold, spinning while a writer holds the latch or the reader
    /// count is saturated.
    pub fn acquire_shared(&self) {
        loop {
            let cur = self.counter.load(Ordering::Relaxed);
            // `WRITER - 1` is the largest reader count allowed, so `cur + 1` can never
            // collide with the `WRITER` marker. This test also covers `cur == WRITER`.
            if cur >= WRITER - 1 {
                spin_loop();
                continue;
            }
            if self
                .counter
                .compare_exchange_weak(cur, cur + 1, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
            {
                return;
            }
        }
    }

    /// Gives back one shared hold.
    pub fn release_shared(&self) {
        let prev = self.counter.fetch_sub(1, Ordering::Release);
        // Releasing a free latch would wrap 0 to `WRITER` (a phantom writer), and
        // releasing an exclusive one would turn `WRITER` into a bogus reader count.
        debug_assert!(
            prev != 0 && prev != WRITER,
            "release_shared called without a shared hold"
        );
    }

    /// Acquires the latch exclusively, spinning until it is completely free.
    pub fn acquire_exclusive(&self) {
        loop {
            // Only attempt the CAS when the latch looks free, so waiters spin on a
            // shared cache line instead of repeatedly claiming it for writing.
            if self.counter.load(Ordering::Relaxed) == 0
                && self
                    .counter
                    .compare_exchange_weak(0, WRITER, Ordering::Acquire, Ordering::Relaxed)
                    .is_ok()
            {
                // Order the `WRITER` store before the data writes that follow. Then any
                // optimistic reader that observes one of those writes also observes the
                // latch as taken (or a newer version) in `validate`.
                std::sync::atomic::fence(Ordering::Release);
                return;
            }
            spin_loop();
        }
    }

    /// Releases an exclusive hold and bumps the version, so snapshots taken before
    /// this write no longer validate.
    pub fn release_exclusive(&self) {
        // The version is bumped before the latch is freed. A reader that sees the
        // latch free through this Release store therefore also sees the new version.
        self.version.fetch_add(1, Ordering::Release);
        self.counter.store(0, Ordering::Release);
    }

    /// Tries to turn the caller's shared hold into an exclusive one without waiting.
    ///
    /// Succeeds only when the reader count is exactly 1, i.e. the caller is the only
    /// reader. On failure the caller still holds its shared hold.
    #[must_use]
    pub fn try_upgrade(&self) -> bool {
        let upgraded = self
            .counter
            .compare_exchange(1, WRITER, Ordering::Acquire, Ordering::Relaxed)
            .is_ok();
        if upgraded {
            // Same reasoning as in `acquire_exclusive`: publish the exclusive hold
            // before any data writes the new owner performs.
            std::sync::atomic::fence(Ordering::Release);
        }
        upgraded
    }
}
/// Which hold a [`Guard`] currently has on its latch.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GuardState {
    /// Holds nothing; this is the state after `Guard::release`.
    Unlocked,
    /// Holds nothing but a version snapshot that must be validated.
    Optimistic,
    /// Holds one shared (reader) count on the latch.
    Shared,
    /// Holds the latch exclusively.
    Exclusive,
}
/// Handle over a [`HybridLatch`] that records how it currently holds the latch.
///
/// Dropping the guard releases any shared or exclusive hold it has (the `Drop` impl
/// calls `release`). An optimistic guard holds nothing and its reads must be checked
/// with `validate`.
#[derive(Debug)]
pub struct Guard<'a> {
    /// The latch this guard operates on.
    latch: &'a HybridLatch,
    /// How the latch is currently held through this guard.
    state: GuardState,
    /// Version captured by `Guard::optimistic`; `0` for guards created shared/exclusive.
    snapshot: u64,
}
impl<'a> Guard<'a> {
    /// Starts an optimistic read of `latch`.
    ///
    /// Spins while a writer holds the latch, then records the current version. Nothing
    /// is locked: reads made under this guard must be confirmed with
    /// [`validate`](Self::validate).
    #[must_use]
    pub fn optimistic(latch: &'a HybridLatch) -> Self {
        Self {
            latch,
            state: GuardState::Optimistic,
            snapshot: latch.acquire_optimistic(),
        }
    }

    /// Acquires `latch` in shared mode, spinning while a writer holds it.
    #[must_use]
    pub fn shared(latch: &'a HybridLatch) -> Self {
        latch.acquire_shared();
        Self {
            latch,
            state: GuardState::Shared,
            snapshot: 0,
        }
    }

    /// Acquires `latch` exclusively, spinning until it is completely free.
    #[must_use]
    pub fn exclusive(latch: &'a HybridLatch) -> Self {
        latch.acquire_exclusive();
        Self {
            latch,
            state: GuardState::Exclusive,
            snapshot: 0,
        }
    }

    /// The hold this guard currently has.
    #[must_use]
    pub fn state(&self) -> GuardState {
        self.state
    }

    /// Returns whether data read under this guard can be trusted.
    ///
    /// Optimistic guards check their snapshot against the latch. Shared and exclusive
    /// guards are always valid. A released guard never is.
    #[must_use]
    pub fn validate(&self) -> bool {
        match self.state {
            GuardState::Optimistic => self.latch.validate(self.snapshot),
            GuardState::Shared | GuardState::Exclusive => true,
            GuardState::Unlocked => false,
        }
    }

    /// Converts an optimistic guard into a shared one if its snapshot is still current.
    ///
    /// Returns `true` on success; the guard is then `Shared`. Returns `false` if a
    /// writer released the latch after the snapshot was taken. In that case the guard
    /// stays optimistic and the shared count it briefly took is given back.
    ///
    /// Meant for optimistic guards only, which is checked in debug builds. In release
    /// builds, a guard that already holds a shared or exclusive lock reports `true`
    /// without touching the latch, and a released guard reports `false`.
    pub fn upgrade_to_shared(&mut self) -> bool {
        debug_assert_eq!(self.state, GuardState::Optimistic);
        match self.state {
            GuardState::Optimistic => {
                self.latch.acquire_shared();
                // While this shared count is held no writer can get in, so the version
                // is stable and the comparison is conclusive.
                if self.latch.version.load(Ordering::Acquire) == self.snapshot {
                    self.state = GuardState::Shared;
                    true
                } else {
                    self.latch.release_shared();
                    false
                }
            }
            // Acquiring again would leak a second shared count (`Shared`), or spin
            // forever against this guard's own exclusive hold (`Exclusive`).
            GuardState::Shared | GuardState::Exclusive => true,
            GuardState::Unlocked => false,
        }
    }

    /// Makes this guard hold the latch exclusively.
    ///
    /// * `Optimistic` / `Unlocked`: spins until the latch is free, then takes it.
    /// * `Shared`: upgrades in place when this is the only reader. Otherwise it first
    ///   gives up the shared hold, then waits for the latch like any other writer.
    /// * `Exclusive`: nothing to do.
    ///
    /// The optimistic snapshot is not re-checked. A writer may also run between giving
    /// up a shared hold and taking the exclusive one. Data read before the upgrade can
    /// therefore be stale and should be re-read under the exclusive hold.
    pub fn upgrade_to_exclusive(&mut self) {
        match self.state {
            // Neither state holds anything on the latch, so there is nothing to give
            // up first; the caller still expects to own the latch afterwards.
            GuardState::Optimistic | GuardState::Unlocked => self.latch.acquire_exclusive(),
            GuardState::Shared => {
                if !self.latch.try_upgrade() {
                    // Other readers are present. Waiting for them while keeping this
                    // count would deadlock against another upgrader doing the same, so
                    // release it and queue for the exclusive hold instead.
                    self.latch.release_shared();
                    self.latch.acquire_exclusive();
                }
            }
            GuardState::Exclusive => {}
        }
        self.state = GuardState::Exclusive;
    }

    /// Gives back whatever this guard holds and marks it `Unlocked`.
    ///
    /// Calling it again (for example from `Drop`) does nothing.
    pub fn release(&mut self) {
        match self.state {
            GuardState::Optimistic | GuardState::Unlocked => {}
            GuardState::Shared => self.latch.release_shared(),
            GuardState::Exclusive => self.latch.release_exclusive(),
        }
        self.state = GuardState::Unlocked;
    }
}
/// Gives back any shared or exclusive hold when the guard goes out of scope.
impl Drop for Guard<'_> {
    fn drop(&mut self) {
        Self::release(self);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Releasing an exclusive guard advances the version by exactly one.
    #[test]
    fn exclusive_bumps_version_on_release() {
        let latch = HybridLatch::new();
        let before = latch.version.load(Ordering::Relaxed);
        drop(Guard::exclusive(&latch));
        let after = latch.version.load(Ordering::Relaxed);
        assert_eq!(after, before + 1);
    }

    /// An idle latch keeps an optimistic snapshot valid; a completed write breaks it.
    #[test]
    fn optimistic_validates_while_idle_invalidates_after_exclusive() {
        let latch = HybridLatch::new();
        let reader = Guard::optimistic(&latch);
        assert!(reader.validate());
        drop(Guard::exclusive(&latch));
        assert!(!reader.validate());
    }

    /// Each shared guard adds one to the reader count and removes it on drop.
    #[test]
    fn shared_lock_counts_up_and_down() {
        let latch = HybridLatch::new();
        let readers = [Guard::shared(&latch), Guard::shared(&latch)];
        assert_eq!(latch.counter.load(Ordering::Relaxed), 2);
        drop(readers);
        assert_eq!(latch.counter.load(Ordering::Relaxed), 0);
    }

    /// A lone reader can be promoted straight to the writer marker.
    #[test]
    fn try_upgrade_succeeds_when_sole_reader() {
        let latch = HybridLatch::new();
        latch.acquire_shared();
        assert!(latch.try_upgrade());
        assert_eq!(latch.counter.load(Ordering::Relaxed), WRITER);
        latch.release_exclusive();
    }

    /// With a second reader present the in-place upgrade must refuse.
    #[test]
    fn try_upgrade_fails_when_contended() {
        let latch = HybridLatch::new();
        (0..2).for_each(|_| latch.acquire_shared());
        assert!(!latch.try_upgrade());
        (0..2).for_each(|_| latch.release_shared());
    }

    /// Optimistic readers never see the value change between two reads while their
    /// snapshot stays valid, even with a writer running concurrently.
    #[test]
    fn concurrent_readers_writer_never_tear() {
        const READERS: usize = 4;
        const READS_PER_READER: usize = 500;
        const WRITES: u64 = 200;

        let latch = Arc::new(HybridLatch::new());
        let value = Arc::new(AtomicU64::new(0));
        let torn = Arc::new(AtomicU64::new(0));

        let readers: Vec<_> = (0..READERS)
            .map(|_| {
                let latch = Arc::clone(&latch);
                let value = Arc::clone(&value);
                let torn = Arc::clone(&torn);
                thread::spawn(move || {
                    for _ in 0..READS_PER_READER {
                        // Retry until one optimistic attempt validates.
                        loop {
                            let guard = Guard::optimistic(&latch);
                            let first = value.load(Ordering::Relaxed);
                            if !guard.validate() {
                                continue;
                            }
                            let second = value.load(Ordering::Relaxed);
                            if guard.validate() && first != second {
                                torn.fetch_add(1, Ordering::Relaxed);
                            }
                            break;
                        }
                    }
                })
            })
            .collect();

        let writer = {
            let latch = Arc::clone(&latch);
            let value = Arc::clone(&value);
            thread::spawn(move || {
                for _ in 0..WRITES {
                    let _guard = Guard::exclusive(&latch);
                    let current = value.load(Ordering::Relaxed);
                    spin_loop();
                    value.store(current + 1, Ordering::Relaxed);
                }
            })
        };

        readers.into_iter().for_each(|handle| handle.join().unwrap());
        writer.join().unwrap();
        assert_eq!(torn.load(Ordering::Relaxed), 0);
        assert_eq!(value.load(Ordering::Relaxed), WRITES);
    }
}