kaspa_utils/sync/
rwlock.rs1use super::semaphore::Semaphore;
2use std::sync::Arc;
3
4pub struct RfRwLock {
11 ll_sem: Semaphore,
13}
14
15impl Default for RfRwLock {
16 fn default() -> Self {
17 Self::new()
18 }
19}
20
21impl RfRwLock {
22 pub fn new() -> Self {
23 Self { ll_sem: Semaphore::new(Semaphore::MAX_PERMITS) }
24 }
25
26 pub async fn read(&self) -> RfRwLockReadGuard<'_> {
27 self.ll_sem.acquire(1).await;
28 RfRwLockReadGuard(self)
29 }
30
31 pub fn blocking_read(&self) -> RfRwLockReadGuard<'_> {
32 self.ll_sem.blocking_acquire(1);
33 RfRwLockReadGuard(self)
34 }
35
36 pub async fn read_owned(self: Arc<Self>) -> RfRwLockOwnedReadGuard {
37 self.ll_sem.acquire(1).await;
38 RfRwLockOwnedReadGuard(self)
39 }
40
41 pub async fn write(&self) -> RfRwLockWriteGuard<'_> {
42 self.ll_sem.acquire(Semaphore::MAX_PERMITS).await;
46 RfRwLockWriteGuard(self)
47 }
48
49 pub fn blocking_write(&self) -> RfRwLockWriteGuard<'_> {
50 self.ll_sem.blocking_acquire(Semaphore::MAX_PERMITS);
51 RfRwLockWriteGuard(self)
52 }
53
54 pub async fn write_owned(self: Arc<Self>) -> RfRwLockOwnedWriteGuard {
55 self.ll_sem.acquire(Semaphore::MAX_PERMITS).await;
56 RfRwLockOwnedWriteGuard(self)
57 }
58
59 fn release_read(&self) {
60 self.ll_sem.release(1);
61 }
62
63 fn release_write(&self) {
64 self.ll_sem.release(Semaphore::MAX_PERMITS);
65 }
66
67 fn blocking_yield_writer(&self) {
68 self.ll_sem.blocking_yield(Semaphore::MAX_PERMITS);
69 }
70}
71
72pub struct RfRwLockReadGuard<'a>(&'a RfRwLock);
73
74impl Drop for RfRwLockReadGuard<'_> {
75 fn drop(&mut self) {
76 self.0.release_read();
77 }
78}
79
80pub struct RfRwLockOwnedReadGuard(Arc<RfRwLock>);
81
82impl Drop for RfRwLockOwnedReadGuard {
83 fn drop(&mut self) {
84 self.0.release_read();
85 }
86}
87
88pub struct RfRwLockWriteGuard<'a>(&'a RfRwLock);
89
90impl Drop for RfRwLockWriteGuard<'_> {
91 fn drop(&mut self) {
92 self.0.release_write();
93 }
94}
95
96impl RfRwLockWriteGuard<'_> {
97 pub fn blocking_yield(&mut self) {
100 self.0.blocking_yield_writer();
101 }
102}
103
104pub struct RfRwLockOwnedWriteGuard(Arc<RfRwLock>);
105
106impl Drop for RfRwLockOwnedWriteGuard {
107 fn drop(&mut self) {
108 self.0.release_write();
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::atomic::{AtomicBool, Ordering::SeqCst},
        time::Duration,
    };
    use tokio::{sync::oneshot, time::sleep, time::timeout};

    const ACQUIRE_TIMEOUT: Duration = Duration::from_secs(5);

    /// A reader waiting on the lock must be admitted while a writer on another thread keeps
    /// calling `blocking_yield`, i.e. before that writer finally drops its guard.
    #[tokio::test]
    async fn test_writer_reentrance() {
        for i in 0..16 {
            let l = Arc::new(RfRwLock::new());
            let (tx, rx) = oneshot::channel();
            let l_clone = l.clone();
            let h = std::thread::spawn(move || {
                let mut write = l_clone.blocking_write();
                tx.send(()).unwrap();
                for _ in 0..10 {
                    std::thread::sleep(Duration::from_millis(2));
                    write.blocking_yield();
                }
            });
            rx.await.unwrap();
            // The writer sleeps at least 10 x 2ms while holding the lock. An 18ms bound means the
            // reader has to get in during one of the yields, not after the final release.
            let read = timeout(Duration::from_millis(18), l.read()).await.unwrap_or_else(|_| panic!("failed at iteration {i}"));
            drop(read);
            timeout(Duration::from_millis(100), tokio::task::spawn_blocking(move || h.join())).await.unwrap().unwrap().unwrap();
        }
    }

    /// New readers are admitted while a writer is queued behind existing readers. A reader
    /// arriving after the writer got the lock must wait for the writer's release.
    #[tokio::test]
    async fn test_readers_preferred() {
        let l = Arc::new(RfRwLock::new());
        let read1 = l.read().await;
        let read2 = l.read().await;
        let read3 = l.read().await;

        let (tx, rx) = oneshot::channel();
        let (tx_back, rx_back) = oneshot::channel();
        let l_clone = l.clone();
        let h = tokio::spawn(async move {
            let fut = l_clone.write();
            tx.send(()).unwrap();
            let _write = fut.await;
            println!("writer acquired");
            rx_back.await.unwrap();
            println!("releasing writer");
        });

        rx.await.unwrap();
        // Presumably the writer is now queued: with the default current-thread test runtime,
        // this task only resumes after the writer task has polled `fut` and returned Pending.

        let read4 = timeout(ACQUIRE_TIMEOUT, l.read()).await.unwrap();
        let read5 = timeout(ACQUIRE_TIMEOUT, l.read()).await.unwrap();

        drop(read1);
        drop(read2);
        drop(read3);
        drop(read4);
        drop(read5);
        println!("dropped all readers");

        let f = Arc::new(AtomicBool::new(false));
        let f_clone = f.clone();
        let l_clone = l.clone();
        let late_reader = tokio::spawn(async move {
            let _read = l_clone.read().await;
            assert!(f_clone.load(SeqCst), "reader acquired before writer release");
            println!("late reader acquired");
        });

        sleep(Duration::from_secs(1)).await;
        f.store(true, SeqCst);
        tx_back.send(()).unwrap();
        timeout(ACQUIRE_TIMEOUT, h).await.unwrap().unwrap();
        // A panic inside a spawned task only surfaces through its JoinHandle. Without joining
        // here, the late reader's assertion could never fail this test.
        timeout(ACQUIRE_TIMEOUT, late_reader).await.unwrap().unwrap();
    }
}