1#[cfg(debug_assertions)]
2use std::thread;
3use std::time::Instant;
4
5use crate::timeout::{BlockResult, TimedOut};
6use crate::util::sync::atomic::{AtomicU64, Ordering};
7use crate::util::sync::park::{Park, ParkChoice, ParkResult};
8use crate::Timeout;
9
/// ID that is never carried by an occupied header.
///
/// `occupy` asserts it never hands this value out, and removal reports
/// `may_reuse: false` once a slot's next ID would become this value.
pub const RESERVED_ID: u32 = !0;

/// Maximum number of simultaneous read locks on one element.
///
/// Kept one below `0x7FFF_FFFF`, the reader-field value that encodes a write lock,
/// so an overflow of readers can never be mistaken for a writer.
pub const MAX_CONCURRENT_READS: u32 = 0x7FFF_FFFE;
16
17pub struct Header {
21 state: Park<AtomicU64>,
35}
36
37impl Header {
38 pub fn new() -> Self {
39 let state = Self::unoccupied_bits(0);
40
41 debug_assert!(state == 0, "initial state was not zeroed");
42
43 Self {
44 state: Park::new(AtomicU64::new(state)),
45 }
46 }
47
48 pub unsafe fn lock_read(&self, id: u32, timeout: Timeout) -> BlockResult<bool> {
51 debug_assert!(id != RESERVED_ID, "attempted to read lock the reserved ID");
52
53 let current = self.state.load(Ordering::Relaxed);
54 self.lock_read_with_current(id, current, timeout)
55 }
56
57 unsafe fn lock_read_with_current(
58 &self,
59 id: u32,
60 current: u64,
61 timeout: Timeout,
62 ) -> BlockResult<bool> {
63 if Self::id_from_bits(current) != id {
64 return Ok(false);
65 }
66
67 if Self::is_write_locked(current) {
68 return self.lock_read_slow(id, timeout);
69 }
70
71 match self.compare_exchange_weak(current, Self::increment_readers(current)) {
72 Ok(_) => Ok(true),
73 Err(actual) => self.lock_read_with_current(id, actual, timeout),
74 }
75 }
76
77 #[inline(never)]
78 unsafe fn lock_read_slow(&self, id: u32, timeout: Timeout) -> BlockResult<bool> {
79 enum Response {
80 Matched(u64),
81 Mismatch,
82 }
83
84 let timeout_optional = match timeout {
85 Timeout::DontBlock => return BlockResult::Err(TimedOut),
86 Timeout::BlockIndefinitely => None,
87 Timeout::BlockUntil(deadline) => Some(deadline),
88 };
89
90 let result = self.block(timeout_optional, || {
91 let current = self.state.load(Ordering::Relaxed);
92
93 if Self::id_from_bits(current) != id {
94 return BlockChoice::DontBlock(Response::Mismatch);
95 }
96
97 if Self::is_write_locked(current) {
98 return BlockChoice::Block(current);
99 }
100
101 BlockChoice::DontBlock(Response::Matched(current))
102 });
103
104 match result {
105 Ok(Response::Matched(current)) => {
106 match self.compare_exchange_weak(current, Self::increment_readers(current)) {
107 Ok(_) => Ok(true),
108 Err(actual) => self.lock_read_with_current(id, actual, timeout),
109 }
110 }
111 Ok(Response::Mismatch) => Ok(false),
112 Err(err) => Err(err),
113 }
114 }
115
116 pub unsafe fn unlock_read(&self, id: u32) {
119 let current = self.state.load(Ordering::Relaxed);
120 self.unlock_read_with_current(id, current)
121 }
122
123 unsafe fn unlock_read_with_current(&self, id: u32, current: u64) {
124 debug_assert!(
125 Self::readers_from_bits(current) != 0,
126 "attempted to read unlock already unlocked header"
127 );
128
129 debug_assert!(
130 !Self::is_write_locked(current),
131 "attempted to read unlock write locked header"
132 );
133
134 debug_assert!(
135 Self::id_from_bits(current) == id,
136 "attempted to read unlock with ID 0x{:x} but it was actually 0x{:x}",
137 id,
138 Self::id_from_bits(current)
139 );
140
141 let must_unpark =
142 Self::has_thread_blocking(current) && Self::readers_from_bits(current) == 1;
143
144 let new = if must_unpark {
145 Self::unmark_thread_blocking(Self::decrement_readers(current))
146 } else {
147 Self::decrement_readers(current)
148 };
149
150 match self.compare_exchange_weak(current, new) {
151 Ok(_) => {
152 if must_unpark {
153 Park::unpark(&self.state)
154 }
155 }
156 Err(actual) => self.unlock_read_with_current(id, actual),
157 }
158 }
159
160 pub unsafe fn lock_write(&self, id: u32, timeout: Timeout) -> BlockResult<bool> {
163 debug_assert!(id != RESERVED_ID, "attempted to write lock the reserved ID");
164
165 self.transition(
166 Self::occupied_unlocked_bits(id),
167 Self::write_locked_bits(id),
168 timeout,
169 )
170 }
171
172 pub unsafe fn unlock_write(&self, id: u32) {
175 let new = Self::occupied_unlocked_bits(id);
176 let old = self.state.swap(new, Ordering::AcqRel);
177
178 debug_assert!(
179 Self::id_from_bits(old) == id,
180 "attempted to write unlock with ID 0x{:x} but it was actually 0x{:x}",
181 id,
182 Self::id_from_bits(old)
183 );
184
185 debug_assert!(
186 Self::is_write_locked(old),
187 "attempted to write unlock header that was not write locked"
188 );
189
190 if Self::has_thread_blocking(old) {
191 Park::unpark(&self.state)
192 }
193 }
194
195 pub unsafe fn occupy(&self) -> u32 {
198 let old = self
199 .state
200 .fetch_or(Self::thread_notification_mask(), Ordering::AcqRel);
201
202 debug_assert!(
203 !Self::is_occupied(old),
204 "attempted to occupy occupied header"
205 );
206
207 debug_assert!(
208 Self::id_from_bits(old) != RESERVED_ID,
209 "attempted to occupy header with the reserved ID"
210 );
211
212 Self::id_from_bits(old)
213 }
214
215 pub unsafe fn remove(&self, id: u32, timeout: Timeout) -> BlockResult<RemoveResult> {
218 debug_assert!(id != RESERVED_ID, "attempted to remove the reserved ID");
219
220 let next_id = id + 1;
221
222 let matched = self.transition(
223 Self::occupied_unlocked_bits(id),
224 Self::unoccupied_bits(next_id),
225 timeout,
226 )?;
227
228 if matched {
229 Ok(RemoveResult::Matched {
230 may_reuse: next_id != RESERVED_ID,
231 })
232 } else {
233 Ok(RemoveResult::DidntMatch)
234 }
235 }
236
237 pub unsafe fn remove_locked(&self, id: u32) -> bool {
241 let next_id = id + 1;
242
243 let new = Self::unoccupied_bits(next_id);
244 let old = self.state.swap(new, Ordering::AcqRel);
245
246 debug_assert!(
247 Self::id_from_bits(old) == id,
248 "attempted to write unlock with ID 0x{:x} but it was actually 0x{:x}",
249 id,
250 Self::id_from_bits(old)
251 );
252
253 debug_assert!(
254 Self::is_write_locked(old),
255 "attempted to write unlock header that was not write locked"
256 );
257
258 if Self::has_thread_blocking(old) {
259 Park::unpark(&self.state)
260 }
261
262 next_id != RESERVED_ID
263 }
264
265 unsafe fn transition(&self, expected: u64, new: u64, timeout: Timeout) -> BlockResult<bool> {
269 match self.compare_exchange_weak(expected, new) {
270 Ok(_) => Ok(true),
271 Err(actual) => {
272 if Self::id_from_bits(actual) == Self::id_from_bits(expected) {
273 if Self::readers_from_bits(actual) > 0 {
274 self.transition_slow(expected, new, timeout)
275 } else {
276 self.transition(expected, new, timeout)
277 }
278 } else {
279 Ok(false)
280 }
281 }
282 }
283 }
284
285 #[inline(never)]
286 unsafe fn transition_slow(
287 &self,
288 expected: u64,
289 new: u64,
290 timeout: Timeout,
291 ) -> BlockResult<bool> {
292 let timeout = match timeout {
293 Timeout::DontBlock => return BlockResult::Err(TimedOut),
294 Timeout::BlockIndefinitely => None,
295 Timeout::BlockUntil(deadline) => Some(deadline),
296 };
297
298 self.block(timeout, move || {
299 match self.compare_exchange(expected, new) {
300 Ok(_) => BlockChoice::DontBlock(true),
301 Err(actual) => {
302 if Self::id_from_bits(actual) == Self::id_from_bits(expected) {
303 BlockChoice::Block(actual)
304 } else {
305 BlockChoice::DontBlock(false)
306 }
307 }
308 }
309 })
310 }
311
312 unsafe fn block<T, F>(&self, timeout: Option<Instant>, f: F) -> BlockResult<T>
319 where
320 F: Fn() -> BlockChoice<T>,
321 {
322 match Park::park(&self.state, timeout, || {
323 self.block_result_to_park_result(&f)
324 }) {
325 ParkResult::Waited => self.block(timeout, f),
326 ParkResult::TimedOut => Err(TimedOut),
327 ParkResult::DidntPark(result) => Ok(result),
328 }
329 }
330
331 fn block_result_to_park_result<T, F>(&self, f: &F) -> ParkChoice<T>
332 where
333 F: Fn() -> BlockChoice<T>,
334 {
335 match f() {
336 BlockChoice::Block(expected_state) => {
337 let new_state = Self::mark_thread_blocking(expected_state);
338
339 if self.compare_exchange(expected_state, new_state).is_ok() {
340 ParkChoice::Park
341 } else {
342 self.block_result_to_park_result(f)
343 }
344 }
345 BlockChoice::DontBlock(result) => ParkChoice::DontPark(result),
346 }
347 }
348
349 pub fn needs_drop(&mut self) -> bool {
351 Self::is_occupied(self.state.load_directly())
352 }
353
354 pub fn id(&mut self) -> u32 {
356 let state = self.state.load_directly();
357 Self::id_from_bits(state)
358 }
359
360 pub fn id_if_occupied(&mut self) -> Option<u32> {
361 let state = self.state.load_directly();
362
363 if Self::is_occupied(state) {
364 Some(Self::id_from_bits(state))
365 } else {
366 None
367 }
368 }
369
370 pub fn reset(&mut self) -> Option<u32> {
372 let state = self.state.load_directly();
373
374 debug_assert!(
375 Self::readers_from_bits(state) == 0,
376 "header had readers (0x{:x}) when being reset",
377 Self::readers_from_bits(state),
378 );
379
380 if Self::is_occupied(state) {
381 let id = Self::id_from_bits(state);
382
383 debug_assert!(
384 !Self::has_thread_blocking(state),
385 "header had thread blocking when being reset"
386 );
387
388 self.state.store_directly(Self::unoccupied_bits(id));
389
390 Some(id)
391 } else {
392 None
393 }
394 }
395
396 fn compare_exchange(&self, expected: u64, new: u64) -> Result<u64, u64> {
397 self.state
398 .compare_exchange(expected, new, Ordering::Release, Ordering::Relaxed)
399 }
400
401 fn compare_exchange_weak(&self, expected: u64, new: u64) -> Result<u64, u64> {
402 self.state
403 .compare_exchange_weak(expected, new, Ordering::Release, Ordering::Relaxed)
404 }
405
406 fn unoccupied_bits(id: u32) -> u64 {
407 id as u64
408 }
409
410 fn occupied_unlocked_bits(id: u32) -> u64 {
411 Self::thread_notification_mask() | Self::unoccupied_bits(id)
412 }
413
414 fn thread_notification_mask() -> u64 {
415 1u64 << 63
416 }
417
418 fn is_occupied(state: u64) -> bool {
419 state >> 32 != 0
420 }
421
422 fn write_locked_bits(id: u32) -> u64 {
423 (id as u64) | ((u32::MAX as u64) << 32)
424 }
425
426 fn id_from_bits(bits: u64) -> u32 {
427 bits as u32
428 }
429
430 fn readers_from_bits(bits: u64) -> u32 {
431 (bits >> 32) as u32 & !(1u32 << 31)
432 }
433
434 fn is_write_locked(bits: u64) -> bool {
435 Self::readers_from_bits(bits) == !(1u32 << 31)
436 }
437
438 fn has_thread_blocking(bits: u64) -> bool {
439 debug_assert!(
440 Self::is_occupied(bits),
441 "cannot check thread blocking status when unoccupied"
442 );
443
444 bits & Self::thread_notification_mask() == 0
445 }
446
447 fn mark_thread_blocking(bits: u64) -> u64 {
448 debug_assert!(
449 Self::readers_from_bits(bits) > 0,
450 "cannot block when unlocked"
451 );
452
453 debug_assert!(
454 Self::id_from_bits(bits) != RESERVED_ID,
455 "cannot block on the reserved ID"
456 );
457
458 bits & !Self::thread_notification_mask()
459 }
460
461 fn unmark_thread_blocking(bits: u64) -> u64 {
462 bits | Self::thread_notification_mask()
463 }
464
465 fn increment_readers(bits: u64) -> u64 {
466 if Self::readers_from_bits(bits) == MAX_CONCURRENT_READS {
467 Self::too_many_readers();
468 }
469
470 debug_assert!(
471 !Self::is_write_locked(bits),
472 "cannot add reader when write locked"
473 );
474
475 debug_assert!(
476 Self::id_from_bits(bits) != RESERVED_ID,
477 "cannot lock when empty"
478 );
479
480 bits + (1 << 32)
481 }
482
483 #[inline(never)]
484 fn too_many_readers() -> ! {
485 panic!("too many concurrent readers on RwStore element")
486 }
487
488 fn decrement_readers(bits: u64) -> u64 {
489 debug_assert!(Self::readers_from_bits(bits) != 0, "no readers to remove");
490 bits - (1 << 32)
491 }
492}
493
494#[cfg(debug_assertions)]
495impl Drop for Header {
496 fn drop(&mut self) {
497 if !thread::panicking() {
498 let state = self.state.load_directly();
499
500 debug_assert!(
501 Self::readers_from_bits(state) == 0,
502 "header had readers (0x{:x}) when being dropped",
503 Self::readers_from_bits(state),
504 );
505
506 debug_assert!(
507 !Self::is_occupied(state) || !Self::has_thread_blocking(state),
508 "header had thread blocking when being dropped"
509 );
510 }
511 }
512}
513
/// Non-timeout outcome of removing an element from its header.
#[derive(Debug, Copy, Clone)]
#[derive(Eq, PartialEq, Hash)]
pub enum RemoveResult {
    /// The header held the requested ID and the element was removed.
    ///
    /// `may_reuse` is `false` when the slot's next ID would be `RESERVED_ID`.
    Matched { may_reuse: bool },
    /// The header did not hold the requested ID; nothing was removed.
    DidntMatch,
}
519
/// Decision returned by the closure driving `Header::block`.
enum BlockChoice<T> {
    /// Park the thread; the payload is the state word the marker CAS starts from.
    Block(u64),
    /// Stop waiting and finish with this value.
    DontBlock(T),
}
524
#[cfg(test)]
mod test {
    use crate::header::{Header, RemoveResult};
    use crate::timeout::TimedOut;
    use crate::timeout::Timeout::DontBlock;

    /// Creates a header and occupies it, returning the header with its new ID.
    fn occupied() -> (Header, u32) {
        let header = Header::new();
        // SAFETY: `header` was created just above, so it is unoccupied and not
        // shared with any other thread.
        let id = unsafe { header.occupy() };
        (header, id)
    }

    #[test]
    fn reset_initially_returns_none() {
        assert_eq!(Header::new().reset(), None);
    }

    #[test]
    fn reset_returns_the_tracked_id() {
        let (mut header, id) = occupied();
        assert_eq!(header.reset(), Some(id));

        let (mut header, first_id) = occupied();
        let second_id = unsafe {
            header.remove(first_id, DontBlock).unwrap();
            header.occupy()
        };
        assert_eq!(header.reset(), Some(second_id));
    }

    #[test]
    fn reset_returns_none_after_double_invocation() {
        let (mut header, _) = occupied();
        header.reset();
        assert_eq!(header.reset(), None);
    }

    #[test]
    fn needs_drop_is_false_initially() {
        assert!(!Header::new().needs_drop());
    }

    #[test]
    fn needs_drop_is_true_after_occupation() {
        let (mut header, _) = occupied();
        assert!(header.needs_drop());
    }

    #[test]
    fn needs_drop_is_false_after_removal() {
        let (mut header, id) = occupied();
        unsafe { header.remove(id, DontBlock) }.unwrap();
        assert!(!header.needs_drop());
    }

    #[test]
    fn needs_drop_is_false_after_locked_removal() {
        let (mut header, id) = occupied();
        unsafe {
            header.lock_write(id, DontBlock).unwrap();
            header.remove_locked(id);
        }
        assert!(!header.needs_drop());
    }

    #[test]
    fn lock_read_succeeds_when_id_matches() {
        let (header, id) = occupied();
        unsafe {
            assert_eq!(header.lock_read(id, DontBlock), Ok(true));
            header.unlock_read(id);
        }
    }

    #[test]
    fn lock_write_succeeds_when_id_matches() {
        let (header, id) = occupied();
        unsafe {
            assert_eq!(header.lock_write(id, DontBlock), Ok(true));
            header.unlock_write(id);
        }
    }

    #[test]
    fn lock_read_fails_when_id_doesnt_match() {
        let (header, id) = occupied();
        let other = id + 1;
        unsafe {
            assert_eq!(header.lock_read(other, DontBlock), Ok(false));
        }
    }

    #[test]
    fn lock_write_fails_when_id_doesnt_match() {
        let (header, id) = occupied();
        let other = id + 1;
        unsafe {
            assert_eq!(header.lock_write(other, DontBlock), Ok(false));
        }
    }

    #[test]
    fn double_read_lock_succeeds() {
        let (header, id) = occupied();
        unsafe {
            header.lock_read(id, DontBlock).unwrap();
            assert_eq!(header.lock_read(id, DontBlock), Ok(true));
            for _ in 0..2 {
                header.unlock_read(id);
            }
        }
    }

    #[test]
    fn remove_succeeds_when_id_matches() {
        let (header, id) = occupied();
        let outcome = unsafe { header.remove(id, DontBlock) };
        assert_eq!(outcome, Ok(RemoveResult::Matched { may_reuse: true }));
    }

    #[test]
    fn remove_fails_when_id_doesnt_match() {
        let (header, id) = occupied();
        let outcome = unsafe { header.remove(id + 1, DontBlock) };
        assert_eq!(outcome, Ok(RemoveResult::DidntMatch));
    }

    #[test]
    fn remove_fails_before_occupation() {
        let header = Header::new();
        let outcome = unsafe { header.remove(42, DontBlock) };
        assert_eq!(outcome, Ok(RemoveResult::DidntMatch));
    }

    #[test]
    fn remove_fails_after_double_invocation() {
        let (header, id) = occupied();
        unsafe {
            header.remove(id, DontBlock).unwrap();
            assert_eq!(header.remove(id, DontBlock), Ok(RemoveResult::DidntMatch));
        }
    }

    #[test]
    fn cannot_lock_read_when_locking_write() {
        let (header, id) = occupied();
        unsafe {
            header.lock_write(id, DontBlock).unwrap();
            assert_eq!(header.lock_read(id, DontBlock), Err(TimedOut));
            header.unlock_write(id);
        }
    }

    #[test]
    fn cannot_lock_write_when_locking_read() {
        let (header, id) = occupied();
        unsafe {
            header.lock_read(id, DontBlock).unwrap();
            assert_eq!(header.lock_write(id, DontBlock), Err(TimedOut));
            header.unlock_read(id);
        }
    }

    #[test]
    fn cannot_remove_when_locking_read() {
        let (header, id) = occupied();
        unsafe {
            header.lock_read(id, DontBlock).unwrap();
            assert_eq!(header.remove(id, DontBlock), Err(TimedOut));
            header.unlock_read(id);
        }
    }

    #[test]
    fn cannot_remove_when_locking_write() {
        let (header, id) = occupied();
        unsafe {
            header.lock_write(id, DontBlock).unwrap();
            assert_eq!(header.remove(id, DontBlock), Err(TimedOut));
            header.unlock_write(id);
        }
    }
}