// kaspa_utils/sync/rwlock.rs

use super::semaphore::Semaphore;
use std::sync::Arc;

/// Readers-first Reader-writer Lock. If the lock is acquired by readers, then additional readers
/// will always be able to acquire the lock as well even if a writer is already in the queue. Note
/// that this makes it safe to make recursive read calls.
///
/// We currently only use this lock over an empty tuple, however it can easily contain data by
/// using `UnsafeCell<T>` and passing it to the various guards with or without mutable access
pub struct RfRwLock {
    // The low-level "non-fair" semaphore used to prioritize readers: a reader takes a single
    // permit while a writer takes all of them, so a late reader can always slip in ahead of a
    // queued writer
    ll_sem: Semaphore,
}
14
15impl Default for RfRwLock {
16    fn default() -> Self {
17        Self::new()
18    }
19}
20
impl RfRwLock {
    /// Creates a new unlocked instance. The underlying semaphore starts with all
    /// `MAX_PERMITS` permits available
    pub fn new() -> Self {
        Self { ll_sem: Semaphore::new(Semaphore::MAX_PERMITS) }
    }

    /// Acquires shared read access asynchronously. Each reader takes a single permit,
    /// so many readers can hold the lock concurrently
    pub async fn read(&self) -> RfRwLockReadGuard<'_> {
        self.ll_sem.acquire(1).await;
        RfRwLockReadGuard(self)
    }

    /// Blocking (synchronous) version of [`read`](Self::read)
    pub fn blocking_read(&self) -> RfRwLockReadGuard<'_> {
        self.ll_sem.blocking_acquire(1);
        RfRwLockReadGuard(self)
    }

    /// Like [`read`](Self::read), but the returned guard holds an `Arc` to the lock
    /// rather than a borrow of it
    pub async fn read_owned(self: Arc<Self>) -> RfRwLockOwnedReadGuard {
        self.ll_sem.acquire(1).await;
        RfRwLockOwnedReadGuard(self)
    }

    /// Acquires exclusive write access asynchronously
    pub async fn write(&self) -> RfRwLockWriteGuard<'_> {
        // Writes acquire all possible permits, hence they ensure exclusiveness. On the other hand, this allows
        // late readers to get in front of them since readers request only a single permit and the semaphore is
        // non-fair
        self.ll_sem.acquire(Semaphore::MAX_PERMITS).await;
        RfRwLockWriteGuard(self)
    }

    /// Blocking (synchronous) version of [`write`](Self::write)
    pub fn blocking_write(&self) -> RfRwLockWriteGuard<'_> {
        self.ll_sem.blocking_acquire(Semaphore::MAX_PERMITS);
        RfRwLockWriteGuard(self)
    }

    /// Like [`write`](Self::write), but the returned guard holds an `Arc` to the lock
    /// rather than a borrow of it
    pub async fn write_owned(self: Arc<Self>) -> RfRwLockOwnedWriteGuard {
        self.ll_sem.acquire(Semaphore::MAX_PERMITS).await;
        RfRwLockOwnedWriteGuard(self)
    }

    // Returns the single permit held by a reader. Called when a read guard is dropped
    fn release_read(&self) {
        self.ll_sem.release(1);
    }

    // Returns all permits held by a writer. Called when a write guard is dropped
    fn release_write(&self) {
        self.ll_sem.release(Semaphore::MAX_PERMITS);
    }

    // Releases and re-acquires all writer permits, giving pending readers/writers a chance
    // to capture the lock in between (see `RfRwLockWriteGuard::blocking_yield`)
    fn blocking_yield_writer(&self) {
        self.ll_sem.blocking_yield(Semaphore::MAX_PERMITS);
    }
}
71
72pub struct RfRwLockReadGuard<'a>(&'a RfRwLock);
73
74impl Drop for RfRwLockReadGuard<'_> {
75    fn drop(&mut self) {
76        self.0.release_read();
77    }
78}
79
80pub struct RfRwLockOwnedReadGuard(Arc<RfRwLock>);
81
82impl Drop for RfRwLockOwnedReadGuard {
83    fn drop(&mut self) {
84        self.0.release_read();
85    }
86}
87
88pub struct RfRwLockWriteGuard<'a>(&'a RfRwLock);
89
90impl Drop for RfRwLockWriteGuard<'_> {
91    fn drop(&mut self) {
92        self.0.release_write();
93    }
94}
95
96impl RfRwLockWriteGuard<'_> {
97    /// Releases and recaptures the write lock. Makes sure that other pending readers/writers get a
98    /// chance to capture the lock before this thread does so.
99    pub fn blocking_yield(&mut self) {
100        self.0.blocking_yield_writer();
101    }
102}
103
104pub struct RfRwLockOwnedWriteGuard(Arc<RfRwLock>);
105
106impl Drop for RfRwLockOwnedWriteGuard {
107    fn drop(&mut self) {
108        self.0.release_write();
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::atomic::{AtomicBool, Ordering::SeqCst},
        time::Duration,
    };
    use tokio::{sync::oneshot, time::sleep, time::timeout};

    // Generous upper bound for acquisitions that are expected to succeed promptly
    const ACQUIRE_TIMEOUT: Duration = Duration::from_secs(5);

    /// Verifies that a reader can capture the lock while a writer repeatedly yields it
    /// via `blocking_yield`
    #[tokio::test]
    async fn test_writer_reentrance() {
        for i in 0..16 {
            let l = Arc::new(RfRwLock::new());
            let (tx, rx) = oneshot::channel();
            let l_clone = l.clone();
            // Writer thread: grabs the write lock, signals readiness, then yields it 10 times
            // with a short sleep between yields
            let h = std::thread::spawn(move || {
                let mut write = l_clone.blocking_write();
                tx.send(()).unwrap();
                for _ in 0..10 {
                    std::thread::sleep(Duration::from_millis(2));
                    write.blocking_yield();
                }
            });
            rx.await.unwrap();
            // Make sure the reader acquires the lock during writer yields. We give the test a few chances to acquire
            // in order to make sure it passes also in slow CI environments where the OS thread-scheduler might take its time
            let read = timeout(Duration::from_millis(18), l.read()).await.unwrap_or_else(|_| panic!("failed at iteration {i}"));
            drop(read);
            // The writer thread must also complete once its yields are done
            timeout(Duration::from_millis(100), tokio::task::spawn_blocking(move || h.join())).await.unwrap().unwrap().unwrap();
        }
    }

    /// Verifies the readers-first policy: readers arriving while a writer is queued still
    /// acquire before it, whereas a reader arriving after the writer acquired must wait
    /// for the writer's release
    #[tokio::test]
    async fn test_readers_preferred() {
        let l = Arc::new(RfRwLock::new());
        let read1 = l.read().await;
        let read2 = l.read().await;
        let read3 = l.read().await;

        let (tx, rx) = oneshot::channel();
        let (tx_back, rx_back) = oneshot::channel();
        let l_clone = l.clone();
        // Writer task: announces its pending write request, then holds the write lock
        // until signaled back
        let h = tokio::spawn(async move {
            let fut = l_clone.write();
            tx.send(()).unwrap();
            let _write = fut.await;
            println!("writer acquired");
            rx_back.await.unwrap();
            println!("releasing writer");
        });

        // Wait for the writer to request writing before registering more readers
        rx.await.unwrap();

        // These readers arrive after the writer queued, yet must still acquire (readers-first)
        let read4 = timeout(ACQUIRE_TIMEOUT, l.read()).await.unwrap();
        let read5 = timeout(ACQUIRE_TIMEOUT, l.read()).await.unwrap();

        drop(read1);
        drop(read2);
        drop(read3);
        drop(read4);
        drop(read5);
        println!("dropped all readers");

        // Flag flipped only right before the writer is released; a reader that requests the
        // lock while the writer holds it must observe the flag set when it finally acquires
        let f = Arc::new(AtomicBool::new(false));
        let f_clone = f.clone();
        let l_clone = l.clone();
        tokio::spawn(async move {
            let _read = l_clone.read().await;
            assert!(f_clone.load(SeqCst), "reader acquired before writer release");
            println!("late reader acquired");
        });

        sleep(Duration::from_secs(1)).await;
        f.store(true, SeqCst);
        tx_back.send(()).unwrap();
        timeout(ACQUIRE_TIMEOUT, h).await.unwrap().unwrap();
    }
}