use crate::{LatchContext, LatchError};
use noxu_sync::Mutex;
use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::thread;
/// An exclusive latch built on a `noxu_sync::Mutex<()>` that also records
/// which thread currently holds it.
///
/// The recorded owner is what lets `acquire`/`try_acquire` detect an attempt
/// by the holding thread to take the latch a second time.
pub struct ExclusiveLatch {
/// Supplies the latch name (used in diagnostics) and the acquisition timeout.
context: LatchContext,
/// The underlying lock; it protects no data and exists only for mutual exclusion.
inner: Mutex<()>,
/// Identifier of the holding thread as returned by `thread_id()`, or 0 when
/// no holder is recorded.
owner: AtomicU64,
}
impl ExclusiveLatch {
    /// Creates an unlocked latch that takes its name and timeout from `context`.
    pub fn new(context: LatchContext) -> Self {
        ExclusiveLatch {
            context,
            inner: Mutex::new(()),
            owner: AtomicU64::new(0),
        }
    }

    /// Creates an unlocked latch whose context is built with `LatchContext::new(name)`.
    pub fn named(name: impl Into<String>) -> Self {
        Self::new(LatchContext::new(name))
    }

    /// Acquires the latch, waiting at most `context.timeout` for it to become free.
    ///
    /// The returned guard keeps the latch held until it is dropped.
    ///
    /// # Errors
    ///
    /// Returns `LatchError::Timeout` when the lock could not be taken within
    /// the configured timeout.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread is already recorded as the owner; the
    /// latch is not reentrant.
    pub fn acquire(&self) -> Result<ExclusiveLatchGuard<'_>, LatchError> {
        let current = thread_id();
        self.panic_if_reentrant(current);

        let timeout = self.context.timeout;
        let guard = self.inner.try_lock_for(timeout).ok_or_else(|| {
            LatchError::Timeout(format!(
                "Latch acquisition timed out after {}ms: {}",
                timeout.as_millis(),
                self.context.name
            ))
        })?;

        // Record ownership only once the lock is really held.
        self.owner.store(current, Ordering::Relaxed);
        Ok(ExclusiveLatchGuard { latch: self, _guard: guard })
    }

    /// Tries to acquire the latch without waiting.
    ///
    /// Returns `None` when the latch is currently held by another thread.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread is already recorded as the owner; the
    /// latch is not reentrant.
    pub fn try_acquire(&self) -> Option<ExclusiveLatchGuard<'_>> {
        let current = thread_id();
        self.panic_if_reentrant(current);

        self.inner.try_lock().map(|guard| {
            self.owner.store(current, Ordering::Relaxed);
            ExclusiveLatchGuard { latch: self, _guard: guard }
        })
    }

    /// Returns whether the underlying mutex is currently locked, by any thread.
    pub fn is_locked(&self) -> bool {
        self.inner.is_locked()
    }

    /// Returns whether the calling thread is the recorded owner of the latch.
    pub fn is_owner(&self) -> bool {
        self.owner.load(Ordering::Relaxed) == thread_id()
    }

    /// Returns the context (name and timeout) this latch was created with.
    pub fn context(&self) -> &LatchContext {
        &self.context
    }

    /// Clears the recorded owner and force-unlocks the latch if the calling
    /// thread is the recorded owner; does nothing otherwise.
    pub fn release_if_owner(&self) {
        if self.is_owner() {
            self.owner.store(0, Ordering::Relaxed);
            // SAFETY: `owner` equals this thread's id only between a successful
            // lock in `acquire`/`try_acquire` and the guard's `Drop`, so the
            // calling thread is the one that locked `inner`.
            // NOTE(review): this presumably expects the guard to have been
            // leaked (e.g. via `mem::forget`); if the guard is still alive, its
            // own drop will unlock `inner` a second time -- confirm with callers.
            unsafe { self.inner.force_unlock() };
        }
    }

    /// Panics when `current` (the caller's thread id) is already the recorded
    /// owner; shared reentrancy check for `acquire` and `try_acquire`.
    fn panic_if_reentrant(&self, current: u64) {
        if self.owner.load(Ordering::Relaxed) == current {
            panic!(
                "Latch already held: {} (thread {:?})",
                self.context.name,
                thread::current().name()
            );
        }
    }
}
/// Compact diagnostic rendering: the latch name plus its current lock state.
impl fmt::Debug for ExclusiveLatch {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = &self.context.name;
        // Sampled at formatting time; the state may change right afterwards.
        let locked = self.is_locked();
        write!(f, "ExclusiveLatch({}, locked={})", name, locked)
    }
}
/// Guard handed out by `ExclusiveLatch::acquire` and `ExclusiveLatch::try_acquire`.
///
/// The latch stays held for as long as this value is alive.
pub struct ExclusiveLatchGuard<'a> {
/// The latch this guard belongs to; its recorded owner is reset to 0 when
/// the guard is dropped.
latch: &'a ExclusiveLatch,
/// Guard of the latch's inner mutex; never read, kept only so that it is
/// dropped together with this struct.
_guard: noxu_sync::MutexGuard<'a, ()>,
}
/// Clears the latch's recorded owner when the guard goes away.
impl Drop for ExclusiveLatchGuard<'_> {
fn drop(&mut self) {
// Reset the owner to 0 ("no owner"). Rust drops a struct's fields only after
// this body has returned, so `_guard` is dropped after the owner is cleared.
self.latch.owner.store(0, Ordering::Relaxed);
}
}
/// Returns a non-zero identifier for the calling thread that is unique within
/// the process and stable for the lifetime of the thread.
///
/// Identifiers come from a global counter and are cached per thread, so two
/// live threads can never share one (a hash of `ThreadId` could, in principle,
/// collide) and repeated calls cost a thread-local read instead of a hash.
/// Zero is never returned; callers use it to mean "no owner".
fn thread_id() -> u64 {
    // Starts at 1 so that 0 stays reserved for "no owner".
    static NEXT_ID: AtomicU64 = AtomicU64::new(1);
    thread_local! {
        // Assigned lazily the first time a thread asks for its id.
        static THREAD_ID: u64 = NEXT_ID.fetch_add(1, Ordering::Relaxed);
    }
    THREAD_ID.with(|id| *id)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// A guard locks the latch for its scope and unlocks it when dropped.
    #[test]
    fn test_acquire_release() {
        let latch = ExclusiveLatch::named("test");
        assert!(!latch.is_locked());
        {
            let _guard = latch.acquire().expect("acquire");
            assert!(latch.is_locked());
            assert!(latch.is_owner());
        }
        assert!(!latch.is_locked());
    }

    /// `try_acquire` succeeds on a free latch and fails from another thread
    /// while the first guard is alive.
    #[test]
    fn test_try_acquire() {
        let latch = Arc::new(ExclusiveLatch::named("test"));
        let guard = latch.try_acquire();
        assert!(guard.is_some());
        let latch2 = latch.clone();
        let handle = std::thread::spawn(move || latch2.try_acquire().is_none());
        assert!(handle.join().unwrap());
    }

    /// Acquiring twice from the same thread panics with the documented message.
    #[test]
    #[should_panic(expected = "Latch already held")]
    fn test_reentrant_panics() {
        let latch = ExclusiveLatch::named("test");
        let _guard = latch.acquire().expect("first acquire");
        let _ = latch.acquire();
    }

    /// After the guard has been dropped the caller is no longer the owner, so
    /// `release_if_owner` must leave the (already unlocked) latch alone.
    #[test]
    fn test_release_if_owner() {
        let latch = ExclusiveLatch::named("test");
        {
            let _guard = latch.acquire().expect("acquire");
            assert!(latch.is_owner());
        }
        latch.release_if_owner();
        assert!(!latch.is_locked());
    }

    /// `release_if_owner` unlocks a latch whose guard was leaked by its owner.
    #[test]
    fn test_release_if_owner_after_forgotten_guard() {
        let latch = ExclusiveLatch::named("test-forgotten-guard");
        // Leak the guard: the latch stays locked and owned, with no guard left
        // that could unlock it.
        std::mem::forget(latch.acquire().expect("acquire"));
        assert!(latch.is_locked());
        assert!(latch.is_owner());

        latch.release_if_owner();
        assert!(!latch.is_owner());
        assert!(!latch.is_locked());

        // The latch must be usable again afterwards.
        assert!(latch.try_acquire().is_some());
    }

    /// `acquire` gives up with an error once the configured timeout elapses.
    #[test]
    fn test_acquire_timeout() {
        use std::time::Duration;
        let ctx = crate::LatchContext::with_timeout(
            "test-timeout",
            Duration::from_millis(50),
        );
        let latch = Arc::new(ExclusiveLatch::new(ctx));
        let latch2 = latch.clone();
        let barrier = Arc::new(std::sync::Barrier::new(2));
        let barrier2 = barrier.clone();
        let handle = std::thread::spawn(move || {
            let _g = latch2.acquire().expect("acquire in spawned thread");
            // Tell the main thread the latch is held, then keep holding it
            // for longer than the 50ms timeout.
            barrier2.wait();
            std::thread::sleep(Duration::from_millis(200));
        });
        barrier.wait();
        let result = latch.acquire();
        assert!(result.is_err(), "expected latch timeout error, got Ok");
        let _ = handle.join();
    }

    /// The context passed to `new` is exposed unchanged through `context()`.
    #[test]
    fn test_context_name_and_timeout() {
        use std::time::Duration;
        let ctx = crate::LatchContext::with_timeout(
            "my-latch",
            Duration::from_secs(1),
        );
        let latch = ExclusiveLatch::new(ctx);
        assert_eq!(latch.context().name, "my-latch");
        assert_eq!(latch.context().timeout, Duration::from_secs(1));
    }

    /// A fresh latch is neither locked nor owned.
    #[test]
    fn test_is_not_owner_when_not_held() {
        let latch = ExclusiveLatch::named("test-owner");
        assert!(!latch.is_owner());
        assert!(!latch.is_locked());
    }

    /// Ownership is per thread: other threads see the latch locked but not owned.
    #[test]
    fn test_is_owner_only_in_owning_thread() {
        let latch = Arc::new(ExclusiveLatch::named("test-owner-thread"));
        let _guard = latch.acquire().expect("acquire");
        assert!(latch.is_owner());
        let latch2 = latch.clone();
        let handle = std::thread::spawn(move || {
            assert!(!latch2.is_owner(), "non-owner thread should not be owner");
            assert!(latch2.is_locked(), "latch should be locked");
        });
        handle.join().unwrap();
    }

    /// Four threads contend for the latch; at most one may be inside the
    /// critical section at any time.
    #[test]
    fn test_concurrent_acquire_serializes() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let latch = Arc::new(ExclusiveLatch::named("serial-test"));
        let counter = Arc::new(AtomicUsize::new(0));
        let concurrent = Arc::new(AtomicUsize::new(0));
        let violations = Arc::new(AtomicUsize::new(0));
        let threads: Vec<_> = (0..4)
            .map(|_| {
                let latch = latch.clone();
                let counter = counter.clone();
                let concurrent = concurrent.clone();
                let violations = violations.clone();
                std::thread::spawn(move || {
                    for _ in 0..25 {
                        let _guard = latch.acquire().expect("acquire");
                        // A non-zero previous value means another thread is
                        // inside the critical section at the same time.
                        let prev = concurrent.fetch_add(1, Ordering::SeqCst);
                        if prev != 0 {
                            violations.fetch_add(1, Ordering::SeqCst);
                        }
                        counter.fetch_add(1, Ordering::SeqCst);
                        concurrent.fetch_sub(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();
        for t in threads {
            t.join().unwrap();
        }
        assert_eq!(counter.load(Ordering::SeqCst), 100);
        assert_eq!(
            violations.load(Ordering::SeqCst),
            0,
            "mutual exclusion violated"
        );
    }

    /// `try_acquire` panics, like `acquire`, when the caller already holds the latch.
    #[test]
    fn test_try_acquire_reentrant_panics() {
        let result = std::panic::catch_unwind(|| {
            let latch = ExclusiveLatch::named("try-reentrant");
            let _guard = latch.acquire();
            let _guard2 = latch.try_acquire();
        });
        assert!(result.is_err(), "expected panic on reentrant try_acquire");
    }

    /// The `Debug` output contains the latch name and its lock state.
    #[test]
    fn test_debug_format() {
        let latch = ExclusiveLatch::named("debug-test");
        let s = format!("{:?}", latch);
        assert!(s.contains("debug-test"));
        assert!(s.contains("locked=false"));
    }

    /// Same scenario as `test_reentrant_panics`, observed through `catch_unwind`.
    #[test]
    fn test_acquire_reacquire_panics() {
        let result = std::panic::catch_unwind(|| {
            let latch = ExclusiveLatch::named("noxu-reacquire");
            let _g1 = latch.acquire().expect("first acquire");
            let _ = latch.acquire();
        });
        assert!(result.is_err(), "reentrant acquire should panic");
    }

    /// Despite its name, this checks that `release_if_owner` on a latch that is
    /// not held is a silent no-op; no panic is expected.
    #[test]
    fn test_release_not_held_panics() {
        let latch = ExclusiveLatch::named("noxu-not-held");
        assert!(!latch.is_locked());
        latch.release_if_owner();
        assert!(!latch.is_locked());
    }

    /// `try_acquire` returns immediately with `None` while another thread holds
    /// the latch, and succeeds once that thread has released it.
    #[test]
    fn test_try_acquire_no_wait() {
        use std::sync::atomic::{AtomicBool, Ordering};
        let latch = Arc::new(ExclusiveLatch::named("noxu-no-wait"));
        let barrier = Arc::new(std::sync::Barrier::new(2));
        let released = Arc::new(AtomicBool::new(false));
        let latch2 = latch.clone();
        let barrier2 = barrier.clone();
        let released2 = released.clone();
        let h = std::thread::spawn(move || {
            let _g = latch2.acquire();
            // The barrier tells the main thread that the latch is now held.
            barrier2.wait();
            // Keep holding the latch until the main thread is done probing it.
            while !released2.load(Ordering::SeqCst) {
                std::thread::yield_now();
            }
        });
        barrier.wait();
        assert!(!latch.is_owner(), "main thread should not be owner");
        let r = latch.try_acquire();
        assert!(
            r.is_none(),
            "try_acquire should fail while other thread holds it"
        );
        assert!(latch.is_locked());
        released.store(true, Ordering::SeqCst);
        h.join().unwrap();
        let g = latch.try_acquire();
        assert!(g.is_some(), "try_acquire should succeed after release");
        assert!(latch.is_locked());
        drop(g);
        assert!(!latch.is_locked());
    }

    /// A second thread's `acquire` blocks until the first guard is dropped.
    #[test]
    fn test_wait_blocks_until_released() {
        use std::sync::atomic::{AtomicBool, Ordering};
        let latch = Arc::new(ExclusiveLatch::named("noxu-wait"));
        let acquired = Arc::new(AtomicBool::new(false));
        let g = latch.acquire().expect("acquire");
        assert!(latch.is_locked());
        let latch2 = latch.clone();
        let acquired2 = acquired.clone();
        let h = std::thread::spawn(move || {
            let _g2 = latch2.acquire().expect("acquire in spawned thread");
            acquired2.store(true, Ordering::SeqCst);
        });
        // Give the spawned thread time to reach `acquire`; it must still be
        // waiting because the main thread holds the latch.
        std::thread::sleep(std::time::Duration::from_millis(30));
        assert!(!acquired.load(Ordering::SeqCst));
        drop(g);
        h.join().unwrap();
        assert!(acquired.load(Ordering::SeqCst));
    }

    /// Several queued waiters are all granted the latch once it is released.
    #[test]
    fn test_multiple_waiters_sequential_grant() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        const N: usize = 5;
        let latch = Arc::new(ExclusiveLatch::named("noxu-multi-wait"));
        let order = Arc::new(AtomicUsize::new(0));
        let g = latch.acquire().expect("acquire");
        let mut handles = Vec::new();
        for i in 0..N {
            let latch2 = latch.clone();
            let order2 = order.clone();
            let h = std::thread::spawn(move || {
                // Stagger the waiters so they reach `acquire` at different times.
                std::thread::sleep(std::time::Duration::from_millis(
                    5 * (i as u64 + 1),
                ));
                let _g = latch2.acquire().expect("acquire in spawned thread");
                order2.fetch_add(1, Ordering::SeqCst);
            });
            handles.push(h);
        }
        // Hold the latch long enough for every waiter to have started waiting.
        std::thread::sleep(std::time::Duration::from_millis(80));
        drop(g);
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(
            order.load(Ordering::SeqCst),
            N,
            "all waiters should have been granted"
        );
    }
}