use crate::{LatchContext, LatchError};
use noxu_sync::RwLock;
use std::cell::Cell;
use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::thread;
// Per-thread count of shared (read) holds taken through `SharedLatch`.
// The counter is not keyed by latch: `acquire_shared` raises it after every
// successful shared acquisition and dropping a `SharedLatchReadGuard` lowers
// it, so it reflects read holds on *any* latch held by the current thread.
thread_local! {
static READ_HOLD_COUNT: Cell<u32> = const { Cell::new(0) };
}
/// Raises the current thread's read-hold count by one.
///
/// The addition saturates at `u32::MAX` instead of wrapping.
fn increment_read_hold() {
    READ_HOLD_COUNT.with(|count| {
        let raised = count.get().saturating_add(1);
        count.set(raised);
    });
}
/// Lowers the current thread's read-hold count by one.
///
/// The subtraction saturates at zero, so an unmatched call leaves the
/// counter at `0` rather than wrapping around.
fn decrement_read_hold() {
    READ_HOLD_COUNT.with(|count| {
        let lowered = count.get().saturating_sub(1);
        count.set(lowered);
    });
}
fn read_hold_count() -> u32 {
READ_HOLD_COUNT.with(|c| c.get())
}
/// A named reader/writer latch with an acquisition timeout.
///
/// Blocking is delegated to an inner `RwLock<()>`. On top of that the latch
/// records which thread holds it exclusively and can optionally route every
/// acquisition through the exclusive path.
pub struct SharedLatch {
/// Supplies the latch name used in messages and the acquisition timeout.
context: LatchContext,
/// When `true`, `acquire_shared` takes the latch exclusively instead.
exclusive_only: bool,
/// The underlying lock; it guards no data of its own.
inner: RwLock<()>,
/// `thread_id()` of the thread holding the latch exclusively, or `0` when
/// no exclusive guard is alive.
exclusive_owner: AtomicU64,
}
impl SharedLatch {
    /// Builds a latch from an existing context.
    ///
    /// `exclusive_only` makes [`acquire_shared`](Self::acquire_shared) behave
    /// like [`acquire_exclusive`](Self::acquire_exclusive).
    pub fn new(context: LatchContext, exclusive_only: bool) -> Self {
        Self {
            context,
            exclusive_only,
            inner: RwLock::new(()),
            // 0 means "no exclusive owner".
            exclusive_owner: AtomicU64::new(0),
        }
    }

    /// Builds a latch whose context is created from `name` via
    /// `LatchContext::new`.
    pub fn named(name: impl Into<String>, exclusive_only: bool) -> Self {
        let context = LatchContext::new(name);
        Self::new(context, exclusive_only)
    }

    /// Reports whether shared requests are served as exclusive ones.
    pub fn is_exclusive_only(&self) -> bool {
        self.exclusive_only
    }

    /// Takes the latch exclusively, waiting up to the context's timeout.
    ///
    /// # Errors
    ///
    /// Returns `LatchError::Timeout` when the write lock could not be
    /// obtained within the configured timeout.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread already holds this latch exclusively, or
    /// if the calling thread's read-hold count is non-zero.
    // NOTE(review): the read-hold count is per thread, not per latch, so a
    // shared hold on a *different* latch also triggers the "Deadlock" panic —
    // confirm this is intended.
    pub fn acquire_exclusive(
        &self,
    ) -> Result<SharedLatchWriteGuard<'_>, LatchError> {
        let me = thread_id();
        self.panic_if_owned_by(me);
        if read_hold_count() != 0 {
            panic!(
                "Deadlock: thread holds read lock and requested write lock on latch {}",
                self.context.name
            );
        }
        match self.inner.try_write_for(self.context.timeout) {
            Some(write_guard) => {
                // Record ownership only once the lock is actually held.
                self.exclusive_owner.store(me, Ordering::Relaxed);
                Ok(SharedLatchWriteGuard { latch: self, _guard: write_guard })
            }
            None => Err(self.timeout_error()),
        }
    }

    /// Attempts to take the latch exclusively without waiting.
    ///
    /// Returns `None` when the write lock is not immediately available.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread already holds this latch exclusively.
    pub fn try_acquire_exclusive(&self) -> Option<SharedLatchWriteGuard<'_>> {
        let me = thread_id();
        self.panic_if_owned_by(me);
        let write_guard = self.inner.try_write()?;
        self.exclusive_owner.store(me, Ordering::Relaxed);
        Some(SharedLatchWriteGuard { latch: self, _guard: write_guard })
    }

    /// Takes the latch in shared mode, waiting up to the context's timeout.
    ///
    /// On an exclusive-only latch this defers to
    /// [`acquire_exclusive`](Self::acquire_exclusive) and yields
    /// `SharedLatchGuard::Write`; otherwise it yields `SharedLatchGuard::Read`.
    ///
    /// # Errors
    ///
    /// Returns `LatchError::Timeout` when the lock could not be obtained
    /// within the configured timeout.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread's read-hold count is already non-zero
    /// (non-exclusive-only latches), or under the conditions documented on
    /// `acquire_exclusive` (exclusive-only latches).
    pub fn acquire_shared(&self) -> Result<SharedLatchGuard<'_>, LatchError> {
        if self.exclusive_only {
            return self.acquire_exclusive().map(SharedLatchGuard::Write);
        }
        if read_hold_count() != 0 {
            panic!(
                "Latch already held in shared mode: {} (thread {:?})",
                self.context.name,
                thread::current().name()
            );
        }
        match self.inner.try_read_for(self.context.timeout) {
            Some(read_guard) => {
                // Count the hold only after the lock was really obtained; the
                // read guard's `Drop` undoes this.
                increment_read_hold();
                Ok(SharedLatchGuard::Read(SharedLatchReadGuard {
                    _guard: read_guard,
                }))
            }
            None => Err(self.timeout_error()),
        }
    }

    /// Reports whether the calling thread is the recorded exclusive owner.
    pub fn is_exclusive_owner(&self) -> bool {
        let owner = self.exclusive_owner.load(Ordering::Relaxed);
        owner == thread_id()
    }

    /// Borrows the context this latch was created with.
    pub fn context(&self) -> &LatchContext {
        &self.context
    }

    /// Panics when `thread` is already recorded as the exclusive owner.
    fn panic_if_owned_by(&self, thread: u64) {
        if self.exclusive_owner.load(Ordering::Relaxed) != thread {
            return;
        }
        panic!(
            "Latch already held exclusively: {} (thread {:?})",
            self.context.name,
            thread::current().name()
        );
    }

    /// Builds the timeout error reported by the blocking acquire methods.
    fn timeout_error(&self) -> LatchError {
        LatchError::Timeout(format!(
            "Latch acquisition timed out after {}ms: {}",
            self.context.timeout.as_millis(),
            self.context.name
        ))
    }
}
/// Debug output shows only the latch name and the `exclusive_only` flag.
impl fmt::Debug for SharedLatch {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = &self.context.name;
        let exclusive_only = self.exclusive_only;
        f.write_fmt(format_args!(
            "SharedLatch({}, exclusive_only={})",
            name, exclusive_only
        ))
    }
}
/// Guard returned by `SharedLatch::acquire_shared`.
pub enum SharedLatchGuard<'a> {
/// A shared hold; produced when the latch is not exclusive-only.
Read(SharedLatchReadGuard<'a>),
/// An exclusive hold; produced when the latch is exclusive-only.
Write(SharedLatchWriteGuard<'a>),
}
/// Shared hold on a `SharedLatch`.
///
/// Dropping it lowers the calling thread's read-hold count.
pub struct SharedLatchReadGuard<'a> {
/// Inner read guard, stored only so it lives exactly as long as this wrapper.
_guard: noxu_sync::RwLockReadGuard<'a, ()>,
}
impl Drop for SharedLatchReadGuard<'_> {
// Undoes the `increment_read_hold()` performed when the guard was created.
// This body runs before the `_guard` field itself is dropped.
// NOTE(review): the counter is thread-local, so this assumes the guard is
// dropped on the thread that acquired it — confirm the guard is not `Send`.
fn drop(&mut self) {
decrement_read_hold();
}
}
/// Exclusive hold on a `SharedLatch`.
///
/// Dropping it resets the latch's recorded exclusive owner to `0`.
pub struct SharedLatchWriteGuard<'a> {
/// Latch whose `exclusive_owner` is cleared when this guard is dropped.
latch: &'a SharedLatch,
/// Inner write guard, stored only so it lives exactly as long as this wrapper.
_guard: noxu_sync::RwLockWriteGuard<'a, ()>,
}
impl Drop for SharedLatchWriteGuard<'_> {
// Marks the latch as having no exclusive owner (0).
// `Drop::drop` runs before the struct's fields are dropped, so the owner is
// cleared while `_guard` is still alive (presumably still holding the write
// lock — verify against `noxu_sync`).
fn drop(&mut self) {
self.latch.exclusive_owner.store(0, Ordering::Relaxed);
}
}
/// Returns a non-zero identifier that is unique to the calling thread.
///
/// The identifier is drawn once per thread from a process-wide counter that
/// starts at 1 and is then cached in thread-local storage. Unlike hashing
/// `ThreadId`, two live threads can never receive the same value, so one
/// thread cannot be mistaken for the exclusive owner of a latch held by
/// another. The value `0` is never returned (short of the counter wrapping
/// after 2^64 threads), which keeps it free to mean "no owner".
fn thread_id() -> u64 {
    static NEXT_ID: AtomicU64 = AtomicU64::new(1);
    thread_local! {
        // Lazily initialised on first use in each thread, then reused.
        static THREAD_ID: u64 = NEXT_ID.fetch_add(1, Ordering::Relaxed);
    }
    THREAD_ID.with(|id| *id)
}
// Unit tests for `SharedLatch`: shared/exclusive acquisition, the panics on
// re-entrant use, timeouts, owner tracking and mutual exclusion under load.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
// Two different threads can hold the latch in shared mode at the same time.
#[test]
fn test_shared_access() {
let latch = Arc::new(SharedLatch::named("test", false));
let _guard1 = latch.acquire_shared().expect("acquire_shared");
let latch2 = latch.clone();
let handle = std::thread::spawn(move || {
let _guard = latch2.acquire_shared().expect("acquire_shared");
true
});
assert!(handle.join().unwrap());
}
// While one thread holds the latch exclusively, a non-blocking exclusive
// attempt from another thread fails.
#[test]
fn test_exclusive_blocks_shared() {
let latch = Arc::new(SharedLatch::named("test", false));
let _guard = latch.acquire_exclusive().expect("acquire_exclusive");
assert!(latch.is_exclusive_owner());
let latch2 = latch.clone();
let handle = std::thread::spawn(move || {
latch2.try_acquire_exclusive().is_none()
});
assert!(handle.join().unwrap());
}
// On an exclusive-only latch, `acquire_shared` hands back a write guard.
#[test]
fn test_exclusive_only_mode() {
let latch = SharedLatch::named("bin-latch", true);
assert!(latch.is_exclusive_only());
let guard = latch.acquire_shared().expect("acquire_shared");
match guard {
SharedLatchGuard::Write(_) => {} SharedLatchGuard::Read(_) => {
panic!("Expected write guard in exclusive-only mode")
}
}
}
// A second exclusive acquire on the owning thread must panic.
#[test]
#[should_panic(expected = "Latch already held")]
fn test_reentrant_exclusive_panics() {
let latch = SharedLatch::named("test", false);
let _guard = latch.acquire_exclusive().expect("first acquire");
let _ = latch.acquire_exclusive(); }
// Requesting exclusive access while holding a shared guard must panic with
// the "Deadlock" message.
#[test]
#[should_panic(expected = "Deadlock")]
fn test_read_to_write_upgrade_panics() {
let latch = SharedLatch::named("test-upgrade", false);
let _rguard = latch.acquire_shared().expect("acquire_shared");
let _ = latch.acquire_exclusive();
}
// With a 50ms timeout, an exclusive acquire fails while another thread keeps
// the latch for 200ms. The barrier ensures the holder already owns the latch
// before the main thread tries.
#[test]
fn test_exclusive_acquire_timeout() {
use std::time::Duration;
let ctx = crate::LatchContext::with_timeout(
"test-timeout",
Duration::from_millis(50),
);
let latch = Arc::new(SharedLatch::new(ctx, false));
let latch2 = latch.clone();
let barrier = Arc::new(std::sync::Barrier::new(2));
let barrier2 = barrier.clone();
let handle = std::thread::spawn(move || {
let _g =
latch2.acquire_exclusive().expect("acquire in spawned thread");
barrier2.wait(); std::thread::sleep(Duration::from_millis(200));
});
barrier.wait(); let result = latch.acquire_exclusive();
assert!(result.is_err(), "expected latch timeout error, got Ok");
let _ = handle.join();
}
// Same scenario as above, but the timed-out request is a shared acquire.
#[test]
fn test_shared_acquire_timeout() {
use std::time::Duration;
let ctx = crate::LatchContext::with_timeout(
"test-timeout-r",
Duration::from_millis(50),
);
let latch = Arc::new(SharedLatch::new(ctx, false));
let latch2 = latch.clone();
let barrier = Arc::new(std::sync::Barrier::new(2));
let barrier2 = barrier.clone();
let handle = std::thread::spawn(move || {
let _g =
latch2.acquire_exclusive().expect("acquire in spawned thread");
barrier2.wait();
std::thread::sleep(Duration::from_millis(200));
});
barrier.wait();
let result = latch.acquire_shared();
assert!(result.is_err(), "expected latch timeout error, got Ok");
let _ = handle.join();
}
// A fresh latch has no exclusive owner.
#[test]
fn test_is_not_exclusive_owner_when_not_held() {
let latch = SharedLatch::named("test-owner", false);
assert!(!latch.is_exclusive_owner());
}
// `is_exclusive_owner` is true only on the thread that took the latch.
#[test]
fn test_is_exclusive_owner_only_in_owning_thread() {
let latch = Arc::new(SharedLatch::named("test-owner-thread", false));
let _guard = latch.acquire_exclusive().expect("acquire_exclusive");
assert!(latch.is_exclusive_owner());
let latch2 = latch.clone();
let handle = std::thread::spawn(move || {
assert!(
!latch2.is_exclusive_owner(),
"non-owner should not be owner"
);
});
handle.join().unwrap();
}
// Dropping the write guard clears the recorded owner.
#[test]
fn test_exclusive_owner_cleared_after_drop() {
let latch = SharedLatch::named("test-drop", false);
{
let _guard = latch.acquire_exclusive().expect("acquire_exclusive");
assert!(latch.is_exclusive_owner());
}
assert!(!latch.is_exclusive_owner());
}
// `context()` exposes the name and timeout the latch was built with.
#[test]
fn test_context_fields() {
use std::time::Duration;
let ctx = crate::LatchContext::with_timeout(
"ctx-test",
Duration::from_secs(3),
);
let latch = SharedLatch::new(ctx, false);
assert_eq!(latch.context().name, "ctx-test");
assert_eq!(latch.context().timeout, Duration::from_secs(3));
}
// The `Debug` output contains the latch name and the exclusive-only flag.
#[test]
fn test_debug_format() {
let latch = SharedLatch::named("debug-test", true);
let s = format!("{:?}", latch);
assert!(s.contains("debug-test"));
assert!(s.contains("exclusive_only=true"));
}
// A guard from `try_acquire_exclusive` records ownership, blocks another
// thread's try-acquire, and clears ownership when dropped.
#[test]
fn test_try_acquire_exclusive_blocks_shared() {
let latch = Arc::new(SharedLatch::named("try-excl-blocks", false));
let guard = latch.try_acquire_exclusive();
assert!(guard.is_some());
assert!(latch.is_exclusive_owner());
let latch2 = latch.clone();
let handle = std::thread::spawn(move || {
latch2.try_acquire_exclusive().is_none()
});
assert!(handle.join().unwrap());
drop(guard);
assert!(!latch.is_exclusive_owner());
}
// Four threads each take the latch exclusively 25 times; `concurrent` must
// never be observed non-zero on entry, and all 100 increments must land.
#[test]
fn test_concurrent_exclusive_serializes() {
use std::sync::atomic::{AtomicUsize, Ordering};
let latch = Arc::new(SharedLatch::named("concurrent-serial", false));
let counter = Arc::new(AtomicUsize::new(0));
let concurrent = Arc::new(AtomicUsize::new(0));
let violations = Arc::new(AtomicUsize::new(0));
let threads: Vec<_> = (0..4)
.map(|_| {
let latch = latch.clone();
let counter = counter.clone();
let concurrent = concurrent.clone();
let violations = violations.clone();
std::thread::spawn(move || {
for _ in 0..25 {
let _guard = latch
.acquire_exclusive()
.expect("acquire_exclusive");
let prev = concurrent.fetch_add(1, Ordering::SeqCst);
if prev != 0 {
violations.fetch_add(1, Ordering::SeqCst);
}
counter.fetch_add(1, Ordering::SeqCst);
concurrent.fetch_sub(1, Ordering::SeqCst);
}
})
})
.collect();
for t in threads {
t.join().unwrap();
}
assert_eq!(counter.load(Ordering::SeqCst), 100);
assert_eq!(
violations.load(Ordering::SeqCst),
0,
"mutual exclusion violated"
);
}
// A second shared acquire on the same thread must panic (checked with
// `catch_unwind` rather than `should_panic`).
#[test]
fn test_shared_reacquire_panics() {
let result = std::panic::catch_unwind(|| {
let latch = SharedLatch::named("noxu-shared-reacquire", false);
let _g1 = latch.acquire_shared().expect("first acquire_shared");
let _ = latch.acquire_shared();
});
assert!(result.is_err(), "reentrant shared acquire should panic");
}
// Read-to-write upgrade must panic (checked with `catch_unwind`).
#[test]
fn test_read_to_write_upgrade_panics_while_shared() {
let result = std::panic::catch_unwind(|| {
let latch = SharedLatch::named("rwupgrade", false);
let _rg = latch.acquire_shared().expect("acquire_shared"); let _ = latch.acquire_exclusive(); });
assert!(result.is_err(), "read-to-write upgrade should panic");
}
// A latch that was never acquired reports no exclusive owner.
#[test]
fn test_shared_release_not_held_exclusive_path() {
let latch = SharedLatch::named("noxu-not-held", false);
assert!(!latch.is_exclusive_owner());
}
// Four threads take the latch in shared mode and report in through a
// mutex/condvar pair; the main thread waits until all four have done so.
#[test]
fn test_multiple_readers_concurrent() {
let latch = Arc::new(SharedLatch::named("noxu-multi-read", false));
let ready = Arc::new((
noxu_sync::Mutex::new(0usize),
noxu_sync::Condvar::new(),
));
let mut handles = Vec::new();
for _ in 0..4 {
let latch2 = latch.clone();
let ready2 = ready.clone();
let h = std::thread::spawn(move || {
let _g = latch2.acquire_shared().expect("acquire_shared");
{
let (m, cv) = &*ready2;
let mut g = m.lock();
*g += 1;
cv.notify_all();
}
std::thread::sleep(std::time::Duration::from_millis(20));
});
handles.push(h);
}
{
let (m, cv) = &*ready;
let mut g = m.lock();
while *g < 4 {
cv.wait(&mut g);
}
}
for h in handles {
h.join().unwrap();
}
}
// A shared request from another thread is not granted while the exclusive
// guard is alive (checked after a 30ms sleep) and is granted once it drops.
#[test]
fn test_exclusive_blocks_then_shared_granted() {
let latch =
Arc::new(SharedLatch::named("noxu-excl-blocks-shared", false));
let g = latch.acquire_exclusive().expect("acquire_exclusive");
assert!(latch.is_exclusive_owner());
let latch2 = latch.clone();
let acquired = Arc::new(std::sync::atomic::AtomicBool::new(false));
let acquired2 = acquired.clone();
let h = std::thread::spawn(move || {
let _sg = latch2.acquire_shared().expect("acquire_shared");
acquired2.store(true, std::sync::atomic::Ordering::SeqCst);
});
std::thread::sleep(std::time::Duration::from_millis(30));
assert!(!acquired.load(std::sync::atomic::Ordering::SeqCst));
drop(g); h.join().unwrap();
assert!(acquired.load(std::sync::atomic::Ordering::SeqCst));
}
// `try_acquire_exclusive` returns `None` while another thread holds the
// latch and succeeds after that thread has finished.
#[test]
fn test_try_acquire_exclusive_no_wait() {
let latch = Arc::new(SharedLatch::named("noxu-try-excl", false));
let barrier = Arc::new(std::sync::Barrier::new(2));
let latch2 = latch.clone();
let barrier2 = barrier.clone();
let h = std::thread::spawn(move || {
let _g = latch2.acquire_exclusive().expect("acquire_exclusive");
barrier2.wait();
std::thread::sleep(std::time::Duration::from_millis(100));
});
barrier.wait();
let r = latch.try_acquire_exclusive();
assert!(r.is_none(), "try_acquire_exclusive should fail while held");
h.join().unwrap();
let r2 = latch.try_acquire_exclusive();
assert!(
r2.is_some(),
"try_acquire_exclusive should succeed after release"
);
drop(r2);
}
// On an exclusive-only latch, `acquire_shared` calls from four threads are
// mutually exclusive: no overlap is observed and all 40 increments land.
#[test]
fn test_exclusive_only_mode_serializes() {
use std::sync::atomic::{AtomicUsize, Ordering};
let latch = Arc::new(SharedLatch::named("noxu-excl-only", true));
let counter = Arc::new(AtomicUsize::new(0));
let concurrent = Arc::new(AtomicUsize::new(0));
let violations = Arc::new(AtomicUsize::new(0));
let threads: Vec<_> = (0..4)
.map(|_| {
let latch = latch.clone();
let counter = counter.clone();
let concurrent = concurrent.clone();
let violations = violations.clone();
std::thread::spawn(move || {
for _ in 0..10 {
let _g =
latch.acquire_shared().expect("acquire_shared"); let prev = concurrent.fetch_add(1, Ordering::SeqCst);
if prev != 0 {
violations.fetch_add(1, Ordering::SeqCst);
}
counter.fetch_add(1, Ordering::SeqCst);
concurrent.fetch_sub(1, Ordering::SeqCst);
}
})
})
.collect();
for t in threads {
t.join().unwrap();
}
assert_eq!(counter.load(Ordering::SeqCst), 40);
assert_eq!(
violations.load(Ordering::SeqCst),
0,
"exclusive-only must serialize"
);
}
}