run_down/
rundown_ref.rs

// Copyright 2019 Brian Gianforcaro

use crate::{flags::to_flags, flags::RundownFlags, guard::RundownGuard};
use lazy_init::Lazy;
use rsevents::{Awaitable, ManualResetEvent, State};
use std::{result::Result, sync::atomic::AtomicU64, sync::atomic::Ordering};

/// The set of errors returned by methods in the run-down crate.
#[derive(Debug, PartialEq)]
pub enum RundownError {
    /// Rundown is already in progress on this shared object.
    RundownInProgress,
}

/// Tracks the status of run-down protection for an object.
/// This type is embedded in the object needing run-down protection.
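///
/// # Example
///
/// A minimal usage sketch. The `run_down` import path is an assumption
/// about how the crate re-exports this type, so the example is marked
/// `ignore`.
///
/// ```ignore
/// use run_down::RundownRef;
///
/// let rundown = RundownRef::new();
///
/// // Worker threads acquire protection before touching the object.
/// if let Ok(guard) = rundown.try_acquire() {
///     // ... access the protected object ...
///     drop(guard); // Releases the reference count.
/// }
///
/// // Shutdown path: block until all outstanding guards are dropped.
/// rundown.wait_for_rundown();
/// ```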
#[derive(Default)]
pub struct RundownRef {
    /// The reference count tracking the threads that currently hold an
    /// outstanding run-down protection request on this object.
    ///
    /// The reference count holds two parts: the actual count in the lower bits
    /// and the flags bit in the most significant bit of the u64. The flags and
    /// reference count interpretation logic is encapsulated in the RundownFlags
    /// type. It has the logic to correctly mask and fetch the required bits.
    ///
    /// We need to bit-pack the flags with the reference count, as we need a single
    /// atomic type that we can use to implement the interlocked operations which
    /// provide the thread safety guaranteed by this type.
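    ///
    /// Conceptually the packing looks like this; the exact mask and
    /// shift logic lives in the `RundownFlags` type:
    ///
    /// ```text
    /// bit 63            bits 62..=0
    /// [rundown flag] | [reference count]
    /// ```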
    ref_count: AtomicU64,

    /// The event used to signal the thread waiting for rundown that
    /// rundown is now complete.
    ///
    /// The event is lazily initialized to avoid allocating the event
    /// unless there is an active reference count when rundown starts.
    event: Lazy<ManualResetEvent>,
}

/// Common atomic ordering used for all of our compare-exchange, load,
/// and store operations.
const ORDERING_VAL: Ordering = Ordering::SeqCst;

impl RundownRef {
    /// Initializes a new [`RundownRef`].
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Re-initializes this instance so it can be used again.
    /// It is only valid to call `re_init` once the object has been
    /// completely run down via the `wait_for_rundown` method.
    ///
    /// # Important
    ///
    /// The moment this method returns, new rundown protection
    /// requests can succeed. You must perform all re-initialization
    /// of the shared object the run-down protection is guarding
    /// before you call this method.
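    ///
    /// A minimal sketch of the re-use cycle (marked `ignore` since the
    /// surrounding setup is assumed):
    ///
    /// ```ignore
    /// rundown.wait_for_rundown();
    /// // Re-initialize the guarded object first...
    /// rundown.re_init();
    /// // ...after which try_acquire can succeed again.
    /// ```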
    pub fn re_init(&self) {
        let current = self.load_flags();

        // Validate that the object is in the correct state.
        //
        // TODO: Ideally we should have another bit to represent
        // rundown being complete vs run-down in progress. It would
        // give us a clearer state transition.
        //
        if current.is_pre_rundown() || current.is_ref_active() {
            panic!("Attempt to re-init before rundown is complete");
        }

        // Reset the event if it was previously lazily created so it
        // can be used again in the future. If the event doesn't exist
        // yet, then there is nothing to do.
        if let Some(event) = self.event.get() {
            event.reset();
        }

        // Zero the reference count to make the object ready for use.
        //
        // Note: Once this store completes, new instances of run-down
        // protection can be acquired immediately. All validation and
        // re-initialization needs to occur before this point.
        self.ref_count.store(0, ORDERING_VAL);
    }

    /// Attempts to acquire rundown protection on this [`RundownRef`],
    /// returning the [`RundownGuard`] which holds the reference count,
    /// or an error if the object is already being run down.
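    ///
    /// A minimal sketch of the expected call pattern (marked `ignore`
    /// since the surrounding setup is assumed):
    ///
    /// ```ignore
    /// match rundown.try_acquire() {
    ///     // Safe to use the protected object while `guard` is alive.
    ///     Ok(guard) => { /* ... access the object ... */ }
    ///     // The object is being run down; back off.
    ///     Err(RundownError::RundownInProgress) => {}
    /// }
    /// ```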
    pub fn try_acquire(&self) -> Result<RundownGuard<'_>, RundownError> {
        let mut current = self.load_flags();

        loop {
            if current.is_rundown_in_progress() {
                return Err(RundownError::RundownInProgress);
            }

            let new_bits_with_ref = current.add_ref();

            match self.compare_exchange(current.bits(), new_bits_with_ref) {
                Ok(_) => return Ok(RundownGuard::new(self)),
                Err(new_current) => current = to_flags(new_current),
            }
        }
    }

    /// Releases previously acquired rundown protection.
    pub fn release(&self) {
        let mut current = self.load_flags();

        loop {
            let bits_with_decrement = current.dec_ref();

            match self.compare_exchange(current.bits(), bits_with_decrement) {
                Ok(_) => {
                    current = to_flags(bits_with_decrement);
                    break;
                }
                Err(new_current) => current = to_flags(new_current),
            }
        }

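        // If we released the final reference while rundown is in
        // progress, signal the waiter that rundown is now complete.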
        if current.is_ref_zero() && current.is_rundown_in_progress() {
            let event = self.event.get().expect("Must have been set");
            event.set();
        }
    }

    /// Blocks thread execution until there are no outstanding reference
    /// counts taken on the [`RundownRef`], and the internal representation
    /// has been marked with [`RundownFlags::RUNDOWN_IN_PROGRESS`] to signal
    /// that no other thread can safely acquire a reference count afterwards.
    ///
    /// # Important
    ///
    /// - This method is not thread-safe; it must only be called by one thread.
    ///
    /// - This method is, however, idempotent; it can be called multiple times.
    ///
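    /// A minimal sketch of the shutdown sequence (marked `ignore` since
    /// the surrounding setup is assumed):
    ///
    /// ```ignore
    /// // Stop handing out new guards, then wait for existing ones to drop.
    /// rundown.wait_for_rundown();
    /// // No thread holds run-down protection now; safe to tear down.
    /// ```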
    pub fn wait_for_rundown(&self) {
        let mut current = self.load_flags();

        loop {
            // If there are outstanding protection reference-counts,
            // then create the event. At this point it appears that
            // other threads need to release their protection for
            // this thread to complete the rundown.
            if current.is_ref_active() {
                self.event
                    .get_or_create(|| ManualResetEvent::new(State::Unset));
            }

            // Turn on the rundown bit to inform all other threads
            // that rundown is in progress.
            let bits_with_rundown = current.set_rundown_in_progress();

            match self.compare_exchange(current.bits(), bits_with_rundown) {
                Ok(_) => {
                    current = to_flags(bits_with_rundown);
                    break;
                }
                Err(new_current) => current = to_flags(new_current),
            }
        }

        if current.is_ref_active() {
            let event = self.event.get().expect("Must have been set");
            event.wait();
        }
    }

    /// Loads the current flags atomically, for use at the start of all
    /// atomic compare-and-exchange loops in this implementation.
    #[inline]
    fn load_flags(&self) -> RundownFlags {
        to_flags(self.ref_count.load(ORDERING_VAL))
    }

    /// Readability wrapper around atomic compare exchange.
    #[inline]
    fn compare_exchange(&self, current: u64, new: u64) -> Result<u64, u64> {
        self.ref_count
            .compare_exchange(current, new, ORDERING_VAL, ORDERING_VAL)
    }
}

#[cfg(test)]
use std::sync::Arc;
#[cfg(test)]
use std::thread;

//-------------------------------------------------------------------
// Test: test_wait_when_protected
//
// Description:
//  Test that wait_for_rundown blocks while run-down protection is
//  held, and is signaled once the final reference is released.
//
// Notes:
//  This test needs access to the reference count directly to work.
//
#[test]
#[allow(clippy::result_unwrap_used)]
fn test_wait_when_protected() {
    let rundown = Arc::new(RundownRef::new());

    // Acquire protection.
    let guard = rundown.try_acquire().unwrap();

    // Launch a thread to wait for rundown.
    let rundown_clone = Arc::clone(&rundown);
    let waiter = thread::spawn(move || {
        rundown_clone.wait_for_rundown();
    });

    // Spin until the rundown bit is set; once it is set we know the
    // waiter is about to wait on the event that the drop below will
    // signal.
    while rundown.load_flags().is_pre_rundown() {
        thread::yield_now();
    }

    // Release protection, the waiter should be signaled.
    std::mem::drop(guard);

    waiter.join().unwrap();

    // Verify re-init works after the event is used.
    // TODO: Split out into an independent test.
    rundown.re_init();
}
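
//-------------------------------------------------------------------
// Test: test_acquire_fails_during_rundown
//
// Description:
//  A supplementary sketch exercising the documented state machine:
//  try_acquire must fail once rundown completes, and succeed again
//  after re_init.
//
#[test]
fn test_acquire_fails_during_rundown() {
    let rundown = RundownRef::new();

    // With no guards outstanding, rundown completes immediately.
    rundown.wait_for_rundown();

    // New protection requests must now fail.
    assert_eq!(
        rundown.try_acquire().err(),
        Some(RundownError::RundownInProgress)
    );

    // After re-initialization, protection can be acquired again.
    rundown.re_init();
    assert!(rundown.try_acquire().is_ok());
}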