read_write_store/
header.rs

#[cfg(debug_assertions)]
use std::thread;
use std::time::Instant;

use crate::timeout::{BlockResult, TimedOut};
use crate::util::sync::atomic::{AtomicU64, Ordering};
use crate::util::sync::park::{Park, ParkChoice, ParkResult};
use crate::Timeout;

pub const RESERVED_ID: u32 = u32::MAX;

/// The maximum number of read locks which can be held concurrently.
///
/// Acquiring more than this number of read locks simultaneously will panic.
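///
/// The reader count occupies 31 bits of the state, and the all-ones pattern in those bits is
/// reserved to mark the write-locked state, which is where the `(1 << 31) - 2` limit comes from.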
pub const MAX_CONCURRENT_READS: u32 = (1 << 31) - 2;

/// An unsafe read-write lock with ID matching. As well as the lock state, a header stores an ID
/// for the data it is guarding. Many operations on the header take an ID which is compared to the
/// stored ID to determine whether to proceed with the operation.
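///
/// A typical lifecycle, sketched (see the tests at the bottom of this file for concrete usage):
/// `occupy` yields an ID, that ID is passed to `lock_read`/`unlock_read` or
/// `lock_write`/`unlock_write`, and `remove` finally increments the ID and returns the header to
/// the unoccupied state.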
pub struct Header {
    // The layout of the header's state is as follows:
    // * The most significant bit is the thread notification bit. This is set when no threads are
    //   blocking on the header and is unset when there are threads blocking. The behavior of this
    //   bit is the opposite of what you might expect because the most significant 32 bits all
    //   being zero is reserved to represent a special state (see below).
    // * The next 31 most significant bits store the number of readers, or are all ones if the
    //   header is write locked. If the header is unlocked, these bits will be zero.
    // * There are never any threads blocking on an unlocked header, so it is invalid for the
    //   thread notification bit to be unset (indicating one or more waiting threads) while the
    //   reader bits are also all unset. This bit pattern instead represents the header being in
    //   an unoccupied state.
    // * The lower 32 bits store the ID if the header is occupied, or store the next ID if the
    //   header is unoccupied.
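    //
    // Some example bit patterns (illustrative values only) for a header tracking ID 0x2A:
    // * unoccupied:                     0x0000_0000_0000_002A
    // * occupied, unlocked:             0x8000_0000_0000_002A
    // * occupied, two readers:          0x8000_0002_0000_002A
    // * write locked, no waiters:       0xFFFF_FFFF_0000_002A
    // * write locked, threads waiting:  0x7FFF_FFFF_0000_002A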
    state: Park<AtomicU64>,
}

impl Header {
    pub fn new() -> Self {
        let state = Self::unoccupied_bits(0);

        debug_assert!(state == 0, "initial state was not zeroed");

        Self {
            state: Park::new(AtomicU64::new(state)),
        }
    }

    /// Locks the header for reading if its ID matches and it is not write locked. If the IDs
    /// match, the header must be occupied.
    pub unsafe fn lock_read(&self, id: u32, timeout: Timeout) -> BlockResult<bool> {
        debug_assert!(id != RESERVED_ID, "attempted to read lock the reserved ID");

        let current = self.state.load(Ordering::Relaxed);
        self.lock_read_with_current(id, current, timeout)
    }

    unsafe fn lock_read_with_current(
        &self,
        id: u32,
        current: u64,
        timeout: Timeout,
    ) -> BlockResult<bool> {
        if Self::id_from_bits(current) != id {
            return Ok(false);
        }

        if Self::is_write_locked(current) {
            return self.lock_read_slow(id, timeout);
        }

        match self.compare_exchange_weak(current, Self::increment_readers(current)) {
            Ok(_) => Ok(true),
            Err(actual) => self.lock_read_with_current(id, actual, timeout),
        }
    }

    #[inline(never)]
    unsafe fn lock_read_slow(&self, id: u32, timeout: Timeout) -> BlockResult<bool> {
        enum Response {
            Matched(u64),
            Mismatch,
        }

        let timeout_optional = match timeout {
            Timeout::DontBlock => return BlockResult::Err(TimedOut),
            Timeout::BlockIndefinitely => None,
            Timeout::BlockUntil(deadline) => Some(deadline),
        };

        let result = self.block(timeout_optional, || {
            let current = self.state.load(Ordering::Relaxed);

            if Self::id_from_bits(current) != id {
                return BlockChoice::DontBlock(Response::Mismatch);
            }

            if Self::is_write_locked(current) {
                return BlockChoice::Block(current);
            }

            BlockChoice::DontBlock(Response::Matched(current))
        });

        match result {
            Ok(Response::Matched(current)) => {
                match self.compare_exchange_weak(current, Self::increment_readers(current)) {
                    Ok(_) => Ok(true),
                    Err(actual) => self.lock_read_with_current(id, actual, timeout),
                }
            }
            Ok(Response::Mismatch) => Ok(false),
            Err(err) => Err(err),
        }
    }

    /// Read unlocks the header. The header must be currently read locked and the ID must match the
    /// header's current ID.
    pub unsafe fn unlock_read(&self, id: u32) {
        let current = self.state.load(Ordering::Relaxed);
        self.unlock_read_with_current(id, current)
    }

    unsafe fn unlock_read_with_current(&self, id: u32, current: u64) {
        debug_assert!(
            Self::readers_from_bits(current) != 0,
            "attempted to read unlock already unlocked header"
        );

        debug_assert!(
            !Self::is_write_locked(current),
            "attempted to read unlock write locked header"
        );

        debug_assert!(
            Self::id_from_bits(current) == id,
            "attempted to read unlock with ID 0x{:x} but it was actually 0x{:x}",
            id,
            Self::id_from_bits(current)
        );

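        // Only the last reader to leave needs to wake waiting threads, and it also flips the
        // notification bit back to its "no waiters" (set) position in the same exchange.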
        let must_unpark =
            Self::has_thread_blocking(current) && Self::readers_from_bits(current) == 1;

        let new = if must_unpark {
            Self::unmark_thread_blocking(Self::decrement_readers(current))
        } else {
            Self::decrement_readers(current)
        };

        match self.compare_exchange_weak(current, new) {
            Ok(_) => {
                if must_unpark {
                    Park::unpark(&self.state)
                }
            }
            Err(actual) => self.unlock_read_with_current(id, actual),
        }
    }

    /// Write locks the header if its ID matches and it is not locked. If the IDs match, the header
    /// must be occupied.
    pub unsafe fn lock_write(&self, id: u32, timeout: Timeout) -> BlockResult<bool> {
        debug_assert!(id != RESERVED_ID, "attempted to write lock the reserved ID");

        self.transition(
            Self::occupied_unlocked_bits(id),
            Self::write_locked_bits(id),
            timeout,
        )
    }

    /// Write unlocks the header. The header must be currently write locked and the ID must match
    /// this header's current ID.
    pub unsafe fn unlock_write(&self, id: u32) {
        let new = Self::occupied_unlocked_bits(id);
        let old = self.state.swap(new, Ordering::AcqRel);

        debug_assert!(
            Self::id_from_bits(old) == id,
            "attempted to write unlock with ID 0x{:x} but it was actually 0x{:x}",
            id,
            Self::id_from_bits(old)
        );

        debug_assert!(
            Self::is_write_locked(old),
            "attempted to write unlock header that was not write locked"
        );

        if Self::has_thread_blocking(old) {
            Park::unpark(&self.state)
        }
    }

    /// Moves the header from an unoccupied state into an occupied one, returning the ID of the
    /// newly occupied header. The header must be in an unoccupied state.
    pub unsafe fn occupy(&self) -> u32 {
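        // Setting the (inverted) notification bit takes the state from `unoccupied_bits(id)` to
        // `occupied_unlocked_bits(id)`, so occupation is a single atomic OR.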
        let old = self
            .state
            .fetch_or(Self::thread_notification_mask(), Ordering::AcqRel);

        debug_assert!(
            !Self::is_occupied(old),
            "attempted to occupy occupied header"
        );

        debug_assert!(
            Self::id_from_bits(old) != RESERVED_ID,
            "attempted to occupy header with the reserved ID"
        );

        Self::id_from_bits(old)
    }

    /// Increments the ID of the header and moves it into the unoccupied state if the ID matches. If
    /// the IDs match, the header must be occupied.
    pub unsafe fn remove(&self, id: u32, timeout: Timeout) -> BlockResult<RemoveResult> {
        debug_assert!(id != RESERVED_ID, "attempted to remove the reserved ID");

        let next_id = id + 1;

        let matched = self.transition(
            Self::occupied_unlocked_bits(id),
            Self::unoccupied_bits(next_id),
            timeout,
        )?;

        if matched {
            Ok(RemoveResult::Matched {
                may_reuse: next_id != RESERVED_ID,
            })
        } else {
            Ok(RemoveResult::DidntMatch)
        }
    }

    /// Write unlocks the header, increments the ID and moves it into the unoccupied state. The
    /// header's ID must match and it must be in the write locked state. Returns whether the header
    /// can be reused.
    pub unsafe fn remove_locked(&self, id: u32) -> bool {
        let next_id = id + 1;

        let new = Self::unoccupied_bits(next_id);
        let old = self.state.swap(new, Ordering::AcqRel);

        debug_assert!(
            Self::id_from_bits(old) == id,
            "attempted to write unlock with ID 0x{:x} but it was actually 0x{:x}",
            id,
            Self::id_from_bits(old)
        );

        debug_assert!(
            Self::is_write_locked(old),
            "attempted to write unlock header that was not write locked"
        );

        if Self::has_thread_blocking(old) {
            Park::unpark(&self.state)
        }

        next_id != RESERVED_ID
    }

    /// Sets the state of the header to the new state if it is currently in the expected state. If
    /// the actual state has the same ID as the expected state but does not otherwise match it, the
    /// thread will block until the transition can be made. If the ID of the actual state is
    /// different from the ID of the expected state, this will fail.
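    ///
    /// On a failed exchange where the ID still matches, a non-zero reader count means the header
    /// is genuinely locked and the thread blocks; otherwise the failure is treated as spurious
    /// (the fast path uses a weak CAS) and the transition is retried.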
    unsafe fn transition(&self, expected: u64, new: u64, timeout: Timeout) -> BlockResult<bool> {
        match self.compare_exchange_weak(expected, new) {
            Ok(_) => Ok(true),
            Err(actual) => {
                if Self::id_from_bits(actual) == Self::id_from_bits(expected) {
                    if Self::readers_from_bits(actual) > 0 {
                        self.transition_slow(expected, new, timeout)
                    } else {
                        self.transition(expected, new, timeout)
                    }
                } else {
                    Ok(false)
                }
            }
        }
    }

    #[inline(never)]
    unsafe fn transition_slow(
        &self,
        expected: u64,
        new: u64,
        timeout: Timeout,
    ) -> BlockResult<bool> {
        let timeout = match timeout {
            Timeout::DontBlock => return BlockResult::Err(TimedOut),
            Timeout::BlockIndefinitely => None,
            Timeout::BlockUntil(deadline) => Some(deadline),
        };

        self.block(timeout, move || {
            match self.compare_exchange(expected, new) {
                Ok(_) => BlockChoice::DontBlock(true),
                Err(actual) => {
                    if Self::id_from_bits(actual) == Self::id_from_bits(expected) {
                        BlockChoice::Block(actual)
                    } else {
                        BlockChoice::DontBlock(false)
                    }
                }
            }
        })
    }

    /// Performs an operation (atomically with respect to unparking) which may either return a
    /// final result or decide to block. If the operation decides to block, it must return an
    /// expected value for the current state of the header. This will be used to set the thread
    /// notification bit of the header with a CAS operation. If the CAS fails, the operation will
    /// be run again until it succeeds. If blocking is successful, upon wakeup the entire process
    /// will be run again until the operation decides not to block.
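    ///
    /// For example, `lock_read_slow` blocks with the current (write locked) state as its expected
    /// value, then re-examines the state on every wakeup, blocking again while it remains write
    /// locked.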
    unsafe fn block<T, F>(&self, timeout: Option<Instant>, f: F) -> BlockResult<T>
    where
        F: Fn() -> BlockChoice<T>,
    {
        match Park::park(&self.state, timeout, || {
            self.block_result_to_park_result(&f)
        }) {
            ParkResult::Waited => self.block(timeout, f),
            ParkResult::TimedOut => Err(TimedOut),
            ParkResult::DidntPark(result) => Ok(result),
        }
    }

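    /// Runs the operation and, if it chooses to block, marks the header as having a blocked
    /// thread with a CAS before parking. If that CAS fails, the operation is re-run.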
    fn block_result_to_park_result<T, F>(&self, f: &F) -> ParkChoice<T>
    where
        F: Fn() -> BlockChoice<T>,
    {
        match f() {
            BlockChoice::Block(expected_state) => {
                let new_state = Self::mark_thread_blocking(expected_state);

                if self.compare_exchange(expected_state, new_state).is_ok() {
                    ParkChoice::Park
                } else {
                    self.block_result_to_park_result(f)
                }
            }
            BlockChoice::DontBlock(result) => ParkChoice::DontPark(result),
        }
    }

    /// Determines whether or not the header is tracking an element.
    pub fn needs_drop(&mut self) -> bool {
        Self::is_occupied(self.state.load_directly())
    }

    /// Determines the ID the header is currently tracking.
    pub fn id(&mut self) -> u32 {
        let state = self.state.load_directly();
        Self::id_from_bits(state)
    }

    pub fn id_if_occupied(&mut self) -> Option<u32> {
        let state = self.state.load_directly();

        if Self::is_occupied(state) {
            Some(Self::id_from_bits(state))
        } else {
            None
        }
    }

    /// Puts the header into the unoccupied state, returning the header's ID if it was occupied.
    pub fn reset(&mut self) -> Option<u32> {
        let state = self.state.load_directly();

        debug_assert!(
            Self::readers_from_bits(state) == 0,
            "header had readers (0x{:x}) when being reset",
            Self::readers_from_bits(state),
        );

        if Self::is_occupied(state) {
            let id = Self::id_from_bits(state);

            debug_assert!(
                !Self::has_thread_blocking(state),
                "header had thread blocking when being reset"
            );

            self.state.store_directly(Self::unoccupied_bits(id));

            Some(id)
        } else {
            None
        }
    }

    fn compare_exchange(&self, expected: u64, new: u64) -> Result<u64, u64> {
        self.state
            .compare_exchange(expected, new, Ordering::Release, Ordering::Relaxed)
    }

    fn compare_exchange_weak(&self, expected: u64, new: u64) -> Result<u64, u64> {
        self.state
            .compare_exchange_weak(expected, new, Ordering::Release, Ordering::Relaxed)
    }

    fn unoccupied_bits(id: u32) -> u64 {
        id as u64
    }

    fn occupied_unlocked_bits(id: u32) -> u64 {
        Self::thread_notification_mask() | Self::unoccupied_bits(id)
    }

    fn thread_notification_mask() -> u64 {
        1u64 << 63
    }

    fn is_occupied(state: u64) -> bool {
        state >> 32 != 0
    }

    fn write_locked_bits(id: u32) -> u64 {
        (id as u64) | ((u32::MAX as u64) << 32)
    }

    fn id_from_bits(bits: u64) -> u32 {
        bits as u32
    }

    fn readers_from_bits(bits: u64) -> u32 {
        (bits >> 32) as u32 & !(1u32 << 31)
    }

    fn is_write_locked(bits: u64) -> bool {
        Self::readers_from_bits(bits) == !(1u32 << 31)
    }

    fn has_thread_blocking(bits: u64) -> bool {
        debug_assert!(
            Self::is_occupied(bits),
            "cannot check thread blocking status when unoccupied"
        );

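        // The bit is inverted: zero here means the notification bit is clear, i.e. at least one
        // thread is waiting on the header.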
        bits & Self::thread_notification_mask() == 0
    }

    fn mark_thread_blocking(bits: u64) -> u64 {
        debug_assert!(
            Self::readers_from_bits(bits) > 0,
            "cannot block when unlocked"
        );

        debug_assert!(
            Self::id_from_bits(bits) != RESERVED_ID,
            "cannot block on the reserved ID"
        );

        bits & !Self::thread_notification_mask()
    }

    fn unmark_thread_blocking(bits: u64) -> u64 {
        bits | Self::thread_notification_mask()
    }

    fn increment_readers(bits: u64) -> u64 {
        if Self::readers_from_bits(bits) == MAX_CONCURRENT_READS {
            Self::too_many_readers();
        }

        debug_assert!(
            !Self::is_write_locked(bits),
            "cannot add reader when write locked"
        );

        debug_assert!(
            Self::id_from_bits(bits) != RESERVED_ID,
            "cannot lock when empty"
        );

        bits + (1 << 32)
    }

    #[inline(never)]
    fn too_many_readers() -> ! {
        panic!("too many concurrent readers on RwStore element")
    }

    fn decrement_readers(bits: u64) -> u64 {
        debug_assert!(Self::readers_from_bits(bits) != 0, "no readers to remove");
        bits - (1 << 32)
    }
}

#[cfg(debug_assertions)]
impl Drop for Header {
    fn drop(&mut self) {
        if !thread::panicking() {
            let state = self.state.load_directly();

            debug_assert!(
                Self::readers_from_bits(state) == 0,
                "header had readers (0x{:x}) when being dropped",
                Self::readers_from_bits(state),
            );

            debug_assert!(
                !Self::is_occupied(state) || !Self::has_thread_blocking(state),
                "header had thread blocking when being dropped"
            );
        }
    }
}

#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum RemoveResult {
    Matched { may_reuse: bool },
    DidntMatch,
}

enum BlockChoice<T> {
    Block(u64),
    DontBlock(T),
}

#[cfg(test)]
mod test {
    use crate::header::{Header, RemoveResult};
    use crate::timeout::TimedOut;
    use crate::timeout::Timeout::DontBlock;

    #[test]
    fn reset_initially_returns_none() {
        let mut header = Header::new();
        assert_eq!(header.reset(), None);
    }

    #[test]
    fn reset_returns_the_tracked_id() {
        unsafe {
            let mut header = Header::new();
            let id = header.occupy();

            assert_eq!(header.reset(), Some(id));
        }

        unsafe {
            let mut header = Header::new();
            let id = header.occupy();
            header.remove(id, DontBlock).unwrap();
            let id = header.occupy();

            assert_eq!(header.reset(), Some(id));
        }
    }

    #[test]
    fn reset_returns_none_after_double_invocation() {
        unsafe {
            let mut header = Header::new();
            header.occupy();

            header.reset();
            assert_eq!(header.reset(), None);
        }
    }

    #[test]
    fn needs_drop_is_false_initially() {
        let mut header = Header::new();
        assert!(!header.needs_drop());
    }

    #[test]
    fn needs_drop_is_true_after_occupation() {
        unsafe {
            let mut header = Header::new();
            header.occupy();

            assert!(header.needs_drop());
        }
    }

    #[test]
    fn needs_drop_is_false_after_removal() {
        unsafe {
            let mut header = Header::new();
            let id = header.occupy();

            header.remove(id, DontBlock).unwrap();
            assert!(!header.needs_drop());
        }
    }

    #[test]
    fn needs_drop_is_false_after_locked_removal() {
        unsafe {
            let mut header = Header::new();
            let id = header.occupy();

            header.lock_write(id, DontBlock).unwrap();
            header.remove_locked(id);

            assert!(!header.needs_drop());
        }
    }

    #[test]
    fn lock_read_succeeds_when_id_matches() {
        unsafe {
            let header = Header::new();
            let id = header.occupy();

            assert_eq!(header.lock_read(id, DontBlock), Ok(true));
            header.unlock_read(id);
        }
    }

    #[test]
    fn lock_write_succeeds_when_id_matches() {
        unsafe {
            let header = Header::new();
            let id = header.occupy();

            assert_eq!(header.lock_write(id, DontBlock), Ok(true));
            header.unlock_write(id);
        }
    }

    #[test]
    fn lock_read_fails_when_id_doesnt_match() {
        unsafe {
            let header = Header::new();
            let id = header.occupy();

            assert_eq!(header.lock_read(id + 1, DontBlock), Ok(false));
        }
    }

    #[test]
    fn lock_write_fails_when_id_doesnt_match() {
        unsafe {
            let header = Header::new();
            let id = header.occupy();

            assert_eq!(header.lock_write(id + 1, DontBlock), Ok(false));
        }
    }

    #[test]
    fn double_read_lock_succeeds() {
        unsafe {
            let header = Header::new();
            let id = header.occupy();

            header.lock_read(id, DontBlock).unwrap();
            assert_eq!(header.lock_read(id, DontBlock), Ok(true));
            header.unlock_read(id);
            header.unlock_read(id);
        }
    }

    #[test]
    fn remove_succeeds_when_id_matches() {
        unsafe {
            let header = Header::new();
            let id = header.occupy();

            assert_eq!(
                header.remove(id, DontBlock),
                Ok(RemoveResult::Matched { may_reuse: true })
            );
        }
    }

    #[test]
    fn remove_fails_when_id_doesnt_match() {
        unsafe {
            let header = Header::new();
            let id = header.occupy();

            assert_eq!(
                header.remove(id + 1, DontBlock),
                Ok(RemoveResult::DidntMatch)
            );
        }
    }

    #[test]
    fn remove_fails_before_occupation() {
        unsafe {
            let header = Header::new();
            assert_eq!(header.remove(42, DontBlock), Ok(RemoveResult::DidntMatch));
        }
    }

    #[test]
    fn remove_fails_after_double_invocation() {
        unsafe {
            let header = Header::new();
            let id = header.occupy();

            header.remove(id, DontBlock).unwrap();
            assert_eq!(header.remove(id, DontBlock), Ok(RemoveResult::DidntMatch));
        }
    }

    #[test]
    fn cannot_lock_read_when_locking_write() {
        unsafe {
            let header = Header::new();
            let id = header.occupy();

            header.lock_write(id, DontBlock).unwrap();
            assert_eq!(header.lock_read(id, DontBlock), Err(TimedOut));
            header.unlock_write(id);
        }
    }

    #[test]
    fn cannot_lock_write_when_locking_read() {
        unsafe {
            let header = Header::new();
            let id = header.occupy();

            header.lock_read(id, DontBlock).unwrap();
            assert_eq!(header.lock_write(id, DontBlock), Err(TimedOut));
            header.unlock_read(id);
        }
    }

    #[test]
    fn cannot_remove_when_locking_read() {
        unsafe {
            let header = Header::new();
            let id = header.occupy();

            header.lock_read(id, DontBlock).unwrap();
            assert_eq!(header.remove(id, DontBlock), Err(TimedOut));
            header.unlock_read(id);
        }
    }

    #[test]
    fn cannot_remove_when_locking_write() {
        unsafe {
            let header = Header::new();
            let id = header.occupy();

            header.lock_write(id, DontBlock).unwrap();
            assert_eq!(header.remove(id, DontBlock), Err(TimedOut));
            header.unlock_write(id);
        }
    }
}
754}