1use crate::{LatchContext, LatchError};
11use noxu_sync::RwLock;
12use std::cell::Cell;
13use std::fmt;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::thread;
16
17thread_local! {
21 static READ_HOLD_COUNT: Cell<u32> = const { Cell::new(0) };
22}
23
24fn increment_read_hold() {
25 READ_HOLD_COUNT.with(|c| c.set(c.get().saturating_add(1)));
26}
27
28fn decrement_read_hold() {
29 READ_HOLD_COUNT.with(|c| c.set(c.get().saturating_sub(1)));
30}
31
32fn read_hold_count() -> u32 {
33 READ_HOLD_COUNT.with(|c| c.get())
34}
35
36pub struct SharedLatch {
42 context: LatchContext,
43 exclusive_only: bool,
44 inner: RwLock<()>,
45 exclusive_owner: AtomicU64,
47}
48
49impl SharedLatch {
50 pub fn new(context: LatchContext, exclusive_only: bool) -> Self {
52 SharedLatch {
53 context,
54 exclusive_only,
55 inner: RwLock::new(()),
56 exclusive_owner: AtomicU64::new(0),
57 }
58 }
59
60 pub fn named(name: impl Into<String>, exclusive_only: bool) -> Self {
62 Self::new(LatchContext::new(name), exclusive_only)
63 }
64
65 pub fn is_exclusive_only(&self) -> bool {
67 self.exclusive_only
68 }
69
70 pub fn acquire_exclusive(
81 &self,
82 ) -> Result<SharedLatchWriteGuard<'_>, LatchError> {
83 let current = thread_id();
84 if self.exclusive_owner.load(Ordering::Relaxed) == current {
85 panic!(
86 "Latch already held exclusively: {} (thread {:?})",
87 self.context.name,
88 thread::current().name()
89 );
90 }
91
92 if read_hold_count() > 0 {
96 panic!(
97 "Deadlock: thread holds read lock and requested write lock on latch {}",
98 self.context.name
99 );
100 }
101
102 let timeout = self.context.timeout;
103 let guard = self.inner.try_write_for(timeout).ok_or_else(|| {
104 LatchError::Timeout(format!(
105 "Latch acquisition timed out after {}ms: {}",
106 timeout.as_millis(),
107 self.context.name
108 ))
109 })?;
110 self.exclusive_owner.store(current, Ordering::Relaxed);
111 Ok(SharedLatchWriteGuard { latch: self, _guard: guard })
112 }
113
114 pub fn try_acquire_exclusive(&self) -> Option<SharedLatchWriteGuard<'_>> {
118 let current = thread_id();
119 if self.exclusive_owner.load(Ordering::Relaxed) == current {
120 panic!(
121 "Latch already held exclusively: {} (thread {:?})",
122 self.context.name,
123 thread::current().name()
124 );
125 }
126
127 self.inner.try_write().map(|guard| {
128 self.exclusive_owner.store(current, Ordering::Relaxed);
129 SharedLatchWriteGuard { latch: self, _guard: guard }
130 })
131 }
132
133 pub fn acquire_shared(&self) -> Result<SharedLatchGuard<'_>, LatchError> {
146 if self.exclusive_only {
147 Ok(SharedLatchGuard::Write(self.acquire_exclusive()?))
148 } else {
149 if read_hold_count() > 0 {
153 panic!(
154 "Latch already held in shared mode: {} (thread {:?})",
155 self.context.name,
156 thread::current().name()
157 );
158 }
159
160 let timeout = self.context.timeout;
161 let guard = self.inner.try_read_for(timeout).ok_or_else(|| {
162 LatchError::Timeout(format!(
163 "Latch acquisition timed out after {}ms: {}",
164 timeout.as_millis(),
165 self.context.name
166 ))
167 })?;
168 increment_read_hold();
169 Ok(SharedLatchGuard::Read(SharedLatchReadGuard { _guard: guard }))
170 }
171 }
172
173 pub fn is_exclusive_owner(&self) -> bool {
175 self.exclusive_owner.load(Ordering::Relaxed) == thread_id()
176 }
177
178 pub fn context(&self) -> &LatchContext {
180 &self.context
181 }
182}
183
184impl fmt::Debug for SharedLatch {
185 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186 write!(
187 f,
188 "SharedLatch({}, exclusive_only={})",
189 self.context.name, self.exclusive_only
190 )
191 }
192}
193
194pub enum SharedLatchGuard<'a> {
196 Read(SharedLatchReadGuard<'a>),
197 Write(SharedLatchWriteGuard<'a>),
198}
199
200pub struct SharedLatchReadGuard<'a> {
202 _guard: noxu_sync::RwLockReadGuard<'a, ()>,
203}
204
205impl Drop for SharedLatchReadGuard<'_> {
206 fn drop(&mut self) {
207 decrement_read_hold();
210 }
211}
212
213pub struct SharedLatchWriteGuard<'a> {
215 latch: &'a SharedLatch,
216 _guard: noxu_sync::RwLockWriteGuard<'a, ()>,
217}
218
219impl Drop for SharedLatchWriteGuard<'_> {
220 fn drop(&mut self) {
221 self.latch.exclusive_owner.store(0, Ordering::Relaxed);
222 }
223}
224
/// Returns a non-zero identifier for the calling thread, unique for the life
/// of the process.
///
/// Ids come from a process-wide counter starting at 1 and are cached per
/// thread, so two threads can never share an id (unlike a hash of `ThreadId`).
/// 0 is never issued, which lets callers use it as "no owner". Repeat calls
/// only read a thread-local instead of calling `thread::current()` and hashing.
fn thread_id() -> u64 {
    static NEXT_ID: AtomicU64 = AtomicU64::new(1);
    thread_local! {
        // Assigned on this thread's first call; uniqueness only needs the
        // atomic increment, not any particular memory ordering.
        static THIS_THREAD_ID: u64 = NEXT_ID.fetch_add(1, Ordering::Relaxed);
    }
    THIS_THREAD_ID.with(|id| *id)
}
235
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Two threads can hold the same latch in shared mode at the same time.
    #[test]
    fn test_shared_access() {
        let latch = Arc::new(SharedLatch::named("test", false));

        // The main thread keeps its shared hold while the second thread
        // acquires; if shared holds excluded each other, the spawned thread's
        // acquire would time out, `expect` would panic, and `join` would fail.
        let _guard1 = latch.acquire_shared().expect("acquire_shared");
        let latch2 = latch.clone();
        let handle = std::thread::spawn(move || {
            let _guard = latch2.acquire_shared().expect("acquire_shared");
            true
        });
        assert!(handle.join().unwrap());
    }

    /// While one thread holds the latch exclusively, another thread's
    /// non-blocking exclusive attempt fails.
    ///
    /// NOTE(review): despite the name, the probe is `try_acquire_exclusive`,
    /// not a shared acquire.
    #[test]
    fn test_exclusive_blocks_shared() {
        let latch = Arc::new(SharedLatch::named("test", false));
        let _guard = latch.acquire_exclusive().expect("acquire_exclusive");
        assert!(latch.is_exclusive_owner());

        let latch2 = latch.clone();
        // `try_acquire_exclusive` does not wait, so the thread returns promptly.
        let handle = std::thread::spawn(move || {
            latch2.try_acquire_exclusive().is_none()
        });
        assert!(handle.join().unwrap());
    }

    /// An exclusive-only latch serves `acquire_shared` with a write guard.
    #[test]
    fn test_exclusive_only_mode() {
        let latch = SharedLatch::named("bin-latch", true);
        assert!(latch.is_exclusive_only());

        let guard = latch.acquire_shared().expect("acquire_shared");
        match guard {
            SharedLatchGuard::Write(_) => {}
            SharedLatchGuard::Read(_) => {
                panic!("Expected write guard in exclusive-only mode")
            }
        }
    }

    /// Re-acquiring an exclusive hold on the same thread panics rather than
    /// waiting on itself.
    #[test]
    #[should_panic(expected = "Latch already held")]
    fn test_reentrant_exclusive_panics() {
        let latch = SharedLatch::named("test", false);
        let _guard = latch.acquire_exclusive().expect("first acquire");
        let _ = latch.acquire_exclusive();
    }

    /// Requesting an exclusive hold while holding a shared hold on the same
    /// latch panics with a "Deadlock" message.
    #[test]
    #[should_panic(expected = "Deadlock")]
    fn test_read_to_write_upgrade_panics() {
        let latch = SharedLatch::named("test-upgrade", false);
        let _rguard = latch.acquire_shared().expect("acquire_shared");
        let _ = latch.acquire_exclusive();
    }

    /// A blocking exclusive acquire returns an error once the context
    /// timeout (50ms) elapses while another thread holds the latch.
    #[test]
    fn test_exclusive_acquire_timeout() {
        use std::time::Duration;
        let ctx = crate::LatchContext::with_timeout(
            "test-timeout",
            Duration::from_millis(50),
        );
        let latch = Arc::new(SharedLatch::new(ctx, false));

        let latch2 = latch.clone();
        // The barrier ensures the spawned thread already holds the latch
        // before the main thread tries to acquire it.
        let barrier = Arc::new(std::sync::Barrier::new(2));
        let barrier2 = barrier.clone();
        let handle = std::thread::spawn(move || {
            let _g =
                latch2.acquire_exclusive().expect("acquire in spawned thread");
            barrier2.wait();
            // Hold for 200ms, well past the 50ms timeout.
            std::thread::sleep(Duration::from_millis(200));
        });

        barrier.wait();
        let result = latch.acquire_exclusive();
        assert!(result.is_err(), "expected latch timeout error, got Ok");
        let _ = handle.join();
    }

    /// A blocking shared acquire returns an error once the context timeout
    /// (50ms) elapses while another thread holds the latch exclusively.
    #[test]
    fn test_shared_acquire_timeout() {
        use std::time::Duration;
        let ctx = crate::LatchContext::with_timeout(
            "test-timeout-r",
            Duration::from_millis(50),
        );
        let latch = Arc::new(SharedLatch::new(ctx, false));

        let latch2 = latch.clone();
        let barrier = Arc::new(std::sync::Barrier::new(2));
        let barrier2 = barrier.clone();
        let handle = std::thread::spawn(move || {
            let _g =
                latch2.acquire_exclusive().expect("acquire in spawned thread");
            barrier2.wait();
            // Hold for 200ms, well past the 50ms timeout.
            std::thread::sleep(Duration::from_millis(200));
        });

        barrier.wait();
        let result = latch.acquire_shared();
        assert!(result.is_err(), "expected latch timeout error, got Ok");
        let _ = handle.join();
    }

    /// A freshly created latch has no exclusive owner.
    #[test]
    fn test_is_not_exclusive_owner_when_not_held() {
        let latch = SharedLatch::named("test-owner", false);
        assert!(!latch.is_exclusive_owner());
    }

    /// Exclusive ownership is reported only to the thread that acquired it.
    #[test]
    fn test_is_exclusive_owner_only_in_owning_thread() {
        let latch = Arc::new(SharedLatch::named("test-owner-thread", false));
        let _guard = latch.acquire_exclusive().expect("acquire_exclusive");
        assert!(latch.is_exclusive_owner());

        let latch2 = latch.clone();
        let handle = std::thread::spawn(move || {
            assert!(
                !latch2.is_exclusive_owner(),
                "non-owner should not be owner"
            );
        });
        handle.join().unwrap();
    }

    /// Dropping the write guard clears exclusive ownership.
    #[test]
    fn test_exclusive_owner_cleared_after_drop() {
        let latch = SharedLatch::named("test-drop", false);
        {
            let _guard = latch.acquire_exclusive().expect("acquire_exclusive");
            assert!(latch.is_exclusive_owner());
        }
        assert!(!latch.is_exclusive_owner());
    }

    /// The context given to `new` is exposed unchanged through `context()`.
    #[test]
    fn test_context_fields() {
        use std::time::Duration;
        let ctx = crate::LatchContext::with_timeout(
            "ctx-test",
            Duration::from_secs(3),
        );
        let latch = SharedLatch::new(ctx, false);
        assert_eq!(latch.context().name, "ctx-test");
        assert_eq!(latch.context().timeout, Duration::from_secs(3));
    }

    /// `Debug` output includes the latch name and the exclusive-only flag.
    #[test]
    fn test_debug_format() {
        let latch = SharedLatch::named("debug-test", true);
        let s = format!("{:?}", latch);
        assert!(s.contains("debug-test"));
        assert!(s.contains("exclusive_only=true"));
    }

    /// A hold taken with `try_acquire_exclusive` sets ownership, excludes
    /// another thread's non-blocking attempt, and clears ownership on drop.
    ///
    /// NOTE(review): like `test_exclusive_blocks_shared`, the probe is an
    /// exclusive try, not a shared acquire.
    #[test]
    fn test_try_acquire_exclusive_blocks_shared() {
        let latch = Arc::new(SharedLatch::named("try-excl-blocks", false));
        let guard = latch.try_acquire_exclusive();
        assert!(guard.is_some());
        assert!(latch.is_exclusive_owner());

        let latch2 = latch.clone();
        let handle = std::thread::spawn(move || {
            latch2.try_acquire_exclusive().is_none()
        });
        assert!(handle.join().unwrap());
        drop(guard);
        assert!(!latch.is_exclusive_owner());
    }

    /// Four threads each take the latch exclusively 25 times; the critical
    /// sections never overlap and all 100 run.
    #[test]
    fn test_concurrent_exclusive_serializes() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let latch = Arc::new(SharedLatch::named("concurrent-serial", false));
        let counter = Arc::new(AtomicUsize::new(0));
        // Number of threads currently inside the critical section; a
        // non-zero value on entry means two holders overlapped.
        let concurrent = Arc::new(AtomicUsize::new(0));
        let violations = Arc::new(AtomicUsize::new(0));

        let threads: Vec<_> = (0..4)
            .map(|_| {
                let latch = latch.clone();
                let counter = counter.clone();
                let concurrent = concurrent.clone();
                let violations = violations.clone();
                std::thread::spawn(move || {
                    for _ in 0..25 {
                        let _guard = latch
                            .acquire_exclusive()
                            .expect("acquire_exclusive");
                        let prev = concurrent.fetch_add(1, Ordering::SeqCst);
                        if prev != 0 {
                            violations.fetch_add(1, Ordering::SeqCst);
                        }
                        counter.fetch_add(1, Ordering::SeqCst);
                        concurrent.fetch_sub(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        for t in threads {
            t.join().unwrap();
        }
        assert_eq!(counter.load(Ordering::SeqCst), 100);
        assert_eq!(
            violations.load(Ordering::SeqCst),
            0,
            "mutual exclusion violated"
        );
    }

    /// Acquiring the same latch in shared mode twice on one thread panics.
    /// The first guard is dropped while unwinding out of the closure.
    #[test]
    fn test_shared_reacquire_panics() {
        let result = std::panic::catch_unwind(|| {
            let latch = SharedLatch::named("noxu-shared-reacquire", false);
            let _g1 = latch.acquire_shared().expect("first acquire_shared");
            let _ = latch.acquire_shared();
        });
        assert!(result.is_err(), "reentrant shared acquire should panic");
    }

    /// Shared-to-exclusive upgrade on one thread panics.
    ///
    /// NOTE(review): same scenario as `test_read_to_write_upgrade_panics`,
    /// checked via `catch_unwind`, so the panic message is not matched here.
    #[test]
    fn test_read_to_write_upgrade_panics_while_shared() {
        let result = std::panic::catch_unwind(|| {
            let latch = SharedLatch::named("rwupgrade", false);
            let _rg = latch.acquire_shared().expect("acquire_shared");
            let _ = latch.acquire_exclusive();
        });
        assert!(result.is_err(), "read-to-write upgrade should panic");
    }

    /// An unheld latch reports no exclusive owner.
    ///
    /// NOTE(review): duplicates `test_is_not_exclusive_owner_when_not_held`.
    #[test]
    fn test_shared_release_not_held_exclusive_path() {
        let latch = SharedLatch::named("noxu-not-held", false);
        assert!(!latch.is_exclusive_owner());
    }

    /// Four reader threads each acquire the latch in shared mode, bump a
    /// counter under a mutex, and hold their guard for 20ms; the main thread
    /// waits until all four have reported in.
    ///
    /// NOTE(review): reaching 4 only shows every reader acquired. Serialized
    /// readers would also reach 4 given enough timeout, so this does not by
    /// itself prove the shared holds overlapped.
    #[test]
    fn test_multiple_readers_concurrent() {
        let latch = Arc::new(SharedLatch::named("noxu-multi-read", false));
        let ready = Arc::new((
            noxu_sync::Mutex::new(0usize),
            noxu_sync::Condvar::new(),
        ));
        let mut handles = Vec::new();

        for _ in 0..4 {
            let latch2 = latch.clone();
            let ready2 = ready.clone();
            let h = std::thread::spawn(move || {
                let _g = latch2.acquire_shared().expect("acquire_shared");
                {
                    let (m, cv) = &*ready2;
                    let mut g = m.lock();
                    *g += 1;
                    cv.notify_all();
                }
                std::thread::sleep(std::time::Duration::from_millis(20));
            });
            handles.push(h);
        }

        {
            // Wait until every reader has incremented the counter.
            let (m, cv) = &*ready;
            let mut g = m.lock();
            while *g < 4 {
                cv.wait(&mut g);
            }
        }
        for h in handles {
            h.join().unwrap();
        }
    }

    /// A shared request from another thread waits while the latch is held
    /// exclusively and is granted once the exclusive guard is dropped.
    #[test]
    fn test_exclusive_blocks_then_shared_granted() {
        let latch =
            Arc::new(SharedLatch::named("noxu-excl-blocks-shared", false));

        let g = latch.acquire_exclusive().expect("acquire_exclusive");
        assert!(latch.is_exclusive_owner());

        let latch2 = latch.clone();
        let acquired = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let acquired2 = acquired.clone();
        let h = std::thread::spawn(move || {
            let _sg = latch2.acquire_shared().expect("acquire_shared");
            acquired2.store(true, std::sync::atomic::Ordering::SeqCst);
        });

        // Heuristic: give the reader 30ms to start and block. This cannot
        // distinguish "blocked" from "not yet scheduled".
        std::thread::sleep(std::time::Duration::from_millis(30));
        assert!(!acquired.load(std::sync::atomic::Ordering::SeqCst));

        // Releasing the exclusive hold lets the reader through.
        drop(g);
        h.join().unwrap();
        assert!(acquired.load(std::sync::atomic::Ordering::SeqCst));
    }

    /// `try_acquire_exclusive` fails while another thread holds the latch and
    /// succeeds after that thread has finished (its guard dropped).
    #[test]
    fn test_try_acquire_exclusive_no_wait() {
        let latch = Arc::new(SharedLatch::named("noxu-try-excl", false));
        let barrier = Arc::new(std::sync::Barrier::new(2));

        let latch2 = latch.clone();
        let barrier2 = barrier.clone();
        let h = std::thread::spawn(move || {
            let _g = latch2.acquire_exclusive().expect("acquire_exclusive");
            barrier2.wait();
            std::thread::sleep(std::time::Duration::from_millis(100));
        });

        // After the barrier the spawned thread is known to hold the latch.
        barrier.wait();
        let r = latch.try_acquire_exclusive();
        assert!(r.is_none(), "try_acquire_exclusive should fail while held");
        h.join().unwrap();

        let r2 = latch.try_acquire_exclusive();
        assert!(
            r2.is_some(),
            "try_acquire_exclusive should succeed after release"
        );
        drop(r2);
    }

    /// In exclusive-only mode, concurrent `acquire_shared` callers never
    /// overlap: 4 threads x 10 acquisitions with zero observed overlaps.
    #[test]
    fn test_exclusive_only_mode_serializes() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let latch = Arc::new(SharedLatch::named("noxu-excl-only", true));
        let counter = Arc::new(AtomicUsize::new(0));
        // Threads currently inside the critical section.
        let concurrent = Arc::new(AtomicUsize::new(0));
        let violations = Arc::new(AtomicUsize::new(0));

        let threads: Vec<_> = (0..4)
            .map(|_| {
                let latch = latch.clone();
                let counter = counter.clone();
                let concurrent = concurrent.clone();
                let violations = violations.clone();
                std::thread::spawn(move || {
                    for _ in 0..10 {
                        let _g =
                            latch.acquire_shared().expect("acquire_shared");
                        let prev = concurrent.fetch_add(1, Ordering::SeqCst);
                        if prev != 0 {
                            violations.fetch_add(1, Ordering::SeqCst);
                        }
                        counter.fetch_add(1, Ordering::SeqCst);
                        concurrent.fetch_sub(1, Ordering::SeqCst);
                    }
                })
            })
            .collect();

        for t in threads {
            t.join().unwrap();
        }
        assert_eq!(counter.load(Ordering::SeqCst), 40);
        assert_eq!(
            violations.load(Ordering::SeqCst),
            0,
            "exclusive-only must serialize"
        );
    }
}