Skip to main content

miden_utils_sync/
rw_lock.rs

1#[cfg(not(loom))]
2use core::{
3    hint,
4    sync::atomic::{AtomicUsize, Ordering},
5};
6
7use lock_api::RawRwLock;
8#[cfg(loom)]
9use loom::{
10    hint,
11    sync::atomic::{AtomicUsize, Ordering},
12};
13
14/// An implementation of a reader-writer lock, based on a spinlock primitive, no-std compatible
15///
16/// See [lock_api::RwLock] for usage.
17pub type RwLock<T> = lock_api::RwLock<Spinlock, T>;
18
19/// See [lock_api::RwLockReadGuard] for usage.
20pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, Spinlock, T>;
21
22/// See [lock_api::RwLockWriteGuard] for usage.
23pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, Spinlock, T>;
24
/// The underlying raw reader-writer primitive that implements [lock_api::RawRwLock]
///
/// This is fundamentally a spinlock, in that blocking operations on the lock will spin until
/// they succeed in acquiring/releasing the lock.
///
/// To achieve the ability to share the underlying data with multiple readers, or hold
/// exclusive access for one writer, the lock state is based on a "locked" count, where each
/// shared acquisition increments the count by two (keeping the lowest order bit free), and
/// acquiring exclusive access relies on the use of the lowest order bit to stop further shared
/// acquisition, and indicate that the lock is exclusively held. The difference between "being
/// exclusively acquired" and "exclusively held" is irrelevant from the perspective of a thread
/// attempting to acquire the lock, but internally the state uses `usize::MAX` as the
/// "exclusively locked" sentinel.
///
/// This mechanism gets us the following:
///
/// * Whether the lock has been acquired (shared or exclusive)
/// * Whether the lock is being exclusively acquired
/// * How many times the lock has been acquired
/// * Whether the acquisition(s) are exclusive or shared
///
/// Further implementation details, such as how we manage draining readers once an attempt to
/// exclusively acquire the lock occurs, are described below.
///
/// NOTE: This is a simple implementation, meant for use in no-std environments; there are much
/// more robust/performant implementations available when OS primitives can be used.
pub struct Spinlock {
    /// The state of the lock, primarily representing the acquisition count, but relying on
    /// the distinction between even and odd values to indicate whether or not exclusive access
    /// is being acquired.
    state: AtomicUsize,
    /// A counter used to wake a parked writer once the last shared lock is released during
    /// acquisition of an exclusive lock, or once an exclusive lock is released. The actual
    /// count is not important, and simply wraps around on overflow; what is important is that
    /// when the value changes, a parked writer will wake and resume attempting to acquire the
    /// exclusive lock.
    writer_wake_counter: AtomicUsize,
}
61
62impl Default for Spinlock {
63    #[inline(always)]
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl Spinlock {
70    #[cfg(not(loom))]
71    pub const fn new() -> Self {
72        Self {
73            state: AtomicUsize::new(0),
74            writer_wake_counter: AtomicUsize::new(0),
75        }
76    }
77
78    #[cfg(loom)]
79    pub fn new() -> Self {
80        Self {
81            state: AtomicUsize::new(0),
82            writer_wake_counter: AtomicUsize::new(0),
83        }
84    }
85}
86
87unsafe impl RawRwLock for Spinlock {
88    #[cfg(loom)]
89    const INIT: Spinlock = unimplemented!();
90
91    #[cfg(not(loom))]
92    // This is intentional on the part of the [RawRwLock] API, basically a hack to provide
93    // initial values as static items.
94    const INIT: Spinlock = Spinlock::new();
95
96    type GuardMarker = lock_api::GuardSend;
97
98    /// The operation invoked when calling `RwLock::read`, blocks the caller until acquired
99    fn lock_shared(&self) {
100        let mut s = self.state.load(Ordering::Relaxed);
101        loop {
102            // If the exclusive bit is unset, attempt to acquire a read lock
103            if s & 1 == 0 {
104                match self.state.compare_exchange_weak(
105                    s,
106                    s + 2,
107                    Ordering::Acquire,
108                    Ordering::Relaxed,
109                ) {
110                    Ok(_) => return,
111                    // Someone else beat us to the punch and acquired a lock
112                    Err(e) => s = e,
113                }
114            }
115            // If an exclusive lock is held/being acquired, loop until the lock state changes
116            // at which point, try to acquire the lock again
117            if s & 1 == 1 {
118                loop {
119                    let next = self.state.load(Ordering::Relaxed);
120                    if s == next {
121                        hint::spin_loop();
122                        continue;
123                    } else {
124                        s = next;
125                        break;
126                    }
127                }
128            }
129        }
130    }
131
132    /// The operation invoked when calling `RwLock::try_read`, returns whether or not the
133    /// lock was acquired
134    fn try_lock_shared(&self) -> bool {
135        let s = self.state.load(Ordering::Relaxed);
136        if s & 1 == 0 {
137            self.state
138                .compare_exchange_weak(s, s + 2, Ordering::Acquire, Ordering::Relaxed)
139                .is_ok()
140        } else {
141            false
142        }
143    }
144
145    /// The operation invoked when dropping a `RwLockReadGuard`
146    unsafe fn unlock_shared(&self) {
147        if self.state.fetch_sub(2, Ordering::Release) == 3 {
148            // The lock is being exclusively acquired, and we're the last shared acquisition
149            // to be released, so wake the writer by incrementing the wake counter
150            self.writer_wake_counter.fetch_add(1, Ordering::Release);
151        }
152    }
153
154    /// The operation invoked when calling `RwLock::write`, blocks the caller until acquired
155    fn lock_exclusive(&self) {
156        let mut s = self.state.load(Ordering::Relaxed);
157        loop {
158            // Attempt to acquire the lock immediately, or complete acquistion of the lock
159            // if we're continuing the loop after acquiring the exclusive bit. If another
160            // thread acquired it first, we race to be the first thread to acquire it once
161            // released, by busy looping here.
162            if s <= 1 {
163                match self.state.compare_exchange(
164                    s,
165                    usize::MAX,
166                    Ordering::Acquire,
167                    Ordering::Relaxed,
168                ) {
169                    Ok(_) => return,
170                    Err(e) => {
171                        s = e;
172                        hint::spin_loop();
173                        continue;
174                    },
175                }
176            }
177
178            // Only shared locks have been acquired, attempt to acquire the exclusive bit,
179            // which will prevent further shared locks from being acquired. It does not
180            // in and of itself grant us exclusive access however.
181            if s & 1 == 0
182                && let Err(e) =
183                    self.state.compare_exchange(s, s + 1, Ordering::Relaxed, Ordering::Relaxed)
184            {
185                // The lock state has changed before we could acquire the exclusive bit,
186                // update our view of the lock state and try again
187                s = e;
188                continue;
189            }
190
191            // We've acquired the exclusive bit, now we need to busy wait until all shared
192            // acquisitions are released.
193            let w = self.writer_wake_counter.load(Ordering::Acquire);
194            s = self.state.load(Ordering::Relaxed);
195
196            // "Park" the thread here (by busy looping), until the release of the last shared
197            // lock, which is communicated to us by it incrementing the wake counter.
198            if s >= 2 {
199                while self.writer_wake_counter.load(Ordering::Acquire) == w {
200                    hint::spin_loop();
201                }
202                s = self.state.load(Ordering::Relaxed);
203            }
204
205            // All shared locks have been released, go back to the top and try to complete
206            // acquisition of exclusive access.
207        }
208    }
209
210    /// The operation invoked when calling `RwLock::try_write`, returns whether or not the
211    /// lock was acquired
212    fn try_lock_exclusive(&self) -> bool {
213        let s = self.state.load(Ordering::Relaxed);
214        if s <= 1 {
215            self.state
216                .compare_exchange(s, usize::MAX, Ordering::Acquire, Ordering::Relaxed)
217                .is_ok()
218        } else {
219            false
220        }
221    }
222
223    /// The operation invoked when dropping a `RwLockWriteGuard`
224    unsafe fn unlock_exclusive(&self) {
225        // Infallible, as we hold an exclusive lock
226        //
227        // Note the use of `Release` ordering here, which ensures any loads of the lock state
228        // by other threads, are ordered after this store.
229        self.state.store(0, Ordering::Release);
230        // This fetch_add isn't important for signaling purposes, however it serves a key
231        // purpose, in that it imposes a memory ordering on any loads of this field that
232        // have an `Acquire` ordering, i.e. they will read the value stored here. Without
233        // a `Release` store, loads/stores of this field could be reordered relative to
234        // each other.
235        self.writer_wake_counter.fetch_add(1, Ordering::Release);
236    }
237}
238
#[cfg(all(loom, test))]
mod test {
    use alloc::vec::Vec;

    use loom::{model::Builder, sync::Arc};

    use super::{RwLock, Spinlock};

    /// Model-checks the lock under loom with concurrent readers and writers.
    ///
    /// Writers temporarily bump the guarded value to 1 and restore it to 0 before releasing,
    /// so any reader (or freshly-acquiring writer) observing a non-zero value indicates a
    /// mutual-exclusion violation.
    #[test]
    fn test_rwlock_loom() {
        const NUM_READERS: usize = 2;
        const NUM_WRITERS: usize = 2;
        const NUM_ITERATIONS: usize = 2;

        let mut builder = Builder::default();
        builder.max_duration = Some(std::time::Duration::from_secs(60));
        builder.log = true;
        builder.check(|| {
            let shared = Arc::new(RwLock::from_raw(Spinlock::new(), 0usize));

            // Readers should never observe a non-zero value
            let readers: Vec<_> = (0..NUM_READERS)
                .map(|_| {
                    let lock = shared.clone();
                    loom::thread::spawn(move || {
                        for _ in 0..NUM_ITERATIONS {
                            let value = lock.read();
                            assert_eq!(*value, 0);
                        }
                    })
                })
                .collect();

            // Writers should never observe a non-zero value once they've acquired the lock,
            // and should never observe a value > 1 while holding the lock
            let writers: Vec<_> = (0..NUM_WRITERS)
                .map(|_| {
                    let lock = shared.clone();
                    loom::thread::spawn(move || {
                        for _ in 0..NUM_ITERATIONS {
                            let mut value = lock.write();
                            assert_eq!(*value, 0);
                            *value += 1;
                            assert_eq!(*value, 1);
                            *value -= 1;
                            assert_eq!(*value, 0);
                        }
                    })
                })
                .collect();

            readers
                .into_iter()
                .chain(writers)
                .for_each(|handle| handle.join().unwrap());
        })
    }
}