1use crate::inner::Inner;
2use std::fmt::Debug;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::sync::Arc;
7use std::task::{Context, Poll};
8
9#[derive(Debug)]
12pub struct RwLock<T: ?Sized> {
13 readers: AtomicUsize,
14 inner: Inner<T>,
15}
16
17impl<T> RwLock<T> {
18 #[inline]
20 pub const fn new(data: T) -> RwLock<T> {
21 RwLock {
22 readers: AtomicUsize::new(0),
23 inner: Inner::new(data),
24 }
25 }
26}
27
28impl<T: ?Sized> RwLock<T> {
29 #[inline]
47 pub fn write(&self) -> RwLockWriteGuardFuture<T> {
48 RwLockWriteGuardFuture {
49 mutex: &self,
50 is_realized: false,
51 }
52 }
53
54 #[inline]
73 pub fn write_owned(self: &Arc<Self>) -> RwLockWriteOwnedGuardFuture<T> {
74 RwLockWriteOwnedGuardFuture {
75 mutex: self.clone(),
76 is_realized: false,
77 }
78 }
79
80 #[inline]
98 pub fn read(&self) -> RwLockReadGuardFuture<T> {
99 RwLockReadGuardFuture {
100 mutex: &self,
101 is_realized: false,
102 }
103 }
104
105 #[inline]
124 pub fn read_owned(self: &Arc<Self>) -> RwLockReadOwnedGuardFuture<T> {
125 RwLockReadOwnedGuardFuture {
126 mutex: self.clone(),
127 is_realized: false,
128 }
129 }
130
131 #[inline]
132 fn unlock_reader(&self) {
133 if self.readers.fetch_sub(1, Ordering::Release) == 1 {
134 self.inner.unlock()
135 }
136 }
137
138 #[inline]
139 fn add_reader(&self) {
140 self.readers.fetch_add(1, Ordering::Release);
141 }
142
143 #[inline]
144 fn try_acquire_reader(&self) -> bool {
145 self.readers.load(Ordering::Acquire) > 0 || self.inner.try_acquire()
146 }
147}
148
149#[derive(Debug)]
153pub struct RwLockWriteGuard<'a, T: ?Sized> {
154 mutex: &'a RwLock<T>,
155}
156
157#[derive(Debug)]
158pub struct RwLockWriteGuardFuture<'a, T: ?Sized> {
159 mutex: &'a RwLock<T>,
160 is_realized: bool,
161}
162
163#[derive(Debug)]
168pub struct RwLockWriteOwnedGuard<T: ?Sized> {
169 mutex: Arc<RwLock<T>>,
170}
171
172#[derive(Debug)]
173pub struct RwLockWriteOwnedGuardFuture<T: ?Sized> {
174 mutex: Arc<RwLock<T>>,
175 is_realized: bool,
176}
177
178#[derive(Debug)]
182pub struct RwLockReadGuard<'a, T: ?Sized> {
183 mutex: &'a RwLock<T>,
184}
185
186#[derive(Debug)]
187pub struct RwLockReadGuardFuture<'a, T: ?Sized> {
188 mutex: &'a RwLock<T>,
189 is_realized: bool,
190}
191
192#[derive(Debug)]
197pub struct RwLockReadOwnedGuard<T: ?Sized> {
198 mutex: Arc<RwLock<T>>,
199}
200
201#[derive(Debug)]
202pub struct RwLockReadOwnedGuardFuture<T: ?Sized> {
203 mutex: Arc<RwLock<T>>,
204 is_realized: bool,
205}
206
207impl<'a, T: ?Sized> Future for RwLockWriteGuardFuture<'a, T> {
208 type Output = RwLockWriteGuard<'a, T>;
209
210 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
211 if self.mutex.inner.try_acquire() {
212 self.is_realized = true;
213 Poll::Ready(RwLockWriteGuard { mutex: self.mutex })
214 } else {
215 self.mutex.inner.store_waker(cx.waker());
216 Poll::Pending
217 }
218 }
219}
220
221impl<T: ?Sized> Future for RwLockWriteOwnedGuardFuture<T> {
222 type Output = RwLockWriteOwnedGuard<T>;
223
224 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
225 if self.mutex.inner.try_acquire() {
226 self.is_realized = true;
227 Poll::Ready(RwLockWriteOwnedGuard {
228 mutex: self.mutex.clone(),
229 })
230 } else {
231 self.mutex.inner.store_waker(cx.waker());
232 Poll::Pending
233 }
234 }
235}
236
237impl<'a, T: ?Sized> Future for RwLockReadGuardFuture<'a, T> {
238 type Output = RwLockReadGuard<'a, T>;
239
240 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
241 if self.mutex.try_acquire_reader() {
242 self.is_realized = true;
243 self.mutex.add_reader();
244 Poll::Ready(RwLockReadGuard { mutex: self.mutex })
245 } else {
246 self.mutex.inner.store_waker(cx.waker());
247 Poll::Pending
248 }
249 }
250}
251
252impl<T: ?Sized> Future for RwLockReadOwnedGuardFuture<T> {
253 type Output = RwLockReadOwnedGuard<T>;
254
255 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
256 if self.mutex.try_acquire_reader() {
257 self.is_realized = true;
258 self.mutex.add_reader();
259 Poll::Ready(RwLockReadOwnedGuard {
260 mutex: self.mutex.clone(),
261 })
262 } else {
263 self.mutex.inner.store_waker(cx.waker());
264 Poll::Pending
265 }
266 }
267}
268
// Send/Sync impls for the lock and its four guard types are generated by a crate-level
// macro defined outside this file.
// NOTE(review): the exact bounds (presumably on `T: Send` / `T: Sync`) live in the macro —
// confirm them there; the future types are not listed here.
crate::impl_send_sync_rwlock!(
    RwLock,
    RwLockReadGuard,
    RwLockReadOwnedGuard,
    RwLockWriteGuard,
    RwLockWriteOwnedGuard
);
276
// Write guards get mutable dereferencing (the tests below mutate through `*guard`);
// read guards get `impl_deref!`, presumably shared (`Deref`-only) access — confirm in the macro.
crate::impl_deref_mut!(RwLockWriteGuard, 'a);
crate::impl_deref_mut!(RwLockWriteOwnedGuard);
crate::impl_deref!(RwLockReadGuard, 'a);
crate::impl_deref!(RwLockReadOwnedGuard);
281
// Drop impls that release the lock when a guard goes out of scope.
// Write guards name `unlock` (presumably `Inner::unlock` on `self.mutex.inner`); read guards
// name `unlock_reader`, the `RwLock` method that gives back one read slot and releases
// `inner` when the last reader leaves. NOTE(review): verify the expansion in the macro definitions.
crate::impl_drop_guard!(RwLockWriteGuard, 'a, unlock);
crate::impl_drop_guard!(RwLockWriteOwnedGuard, unlock);
crate::impl_drop_guard_self!(RwLockReadGuard, 'a, unlock_reader);
crate::impl_drop_guard_self!(RwLockReadOwnedGuard, unlock_reader);
286
// Drop impls for the acquisition futures, presumably keyed on their `is_realized` flag.
// NOTE(review): a future that never produced a guard does not own the lock — confirm the
// macro does not call `unlock` on a lock this future never acquired (that would release
// another holder's lock, e.g. after a `tokio::time::timeout` drops a pending future).
crate::impl_drop_guard_future!(RwLockWriteGuardFuture, 'a, unlock);
crate::impl_drop_guard_future!(RwLockWriteOwnedGuardFuture, unlock);
crate::impl_drop_guard_future!(RwLockReadGuardFuture, 'a, unlock);
crate::impl_drop_guard_future!(RwLockReadOwnedGuardFuture, unlock);
291
#[cfg(test)]
mod tests {
    use crate::rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard, RwLockWriteOwnedGuard};
    use futures::executor::block_on;
    use futures::{FutureExt, StreamExt, TryStreamExt};
    use std::ops::AddAssign;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    /// Many concurrent writers each increment once; no increment may be lost.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_mutex() {
        let c = RwLock::new(0);

        futures::stream::iter(0..10000)
            .for_each_concurrent(None, |_| async {
                let mut co: RwLockWriteGuard<i32> = c.write().await;
                *co += 1;
            })
            .await;

        let co = c.write().await;
        assert_eq!(*co, 10000)
    }

    /// Write guards acquired in sequence and released after staggered delays still
    /// serialize every increment.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_mutex_delay() {
        let expected_result = 100;
        let c = RwLock::new(0);

        futures::stream::iter(0..expected_result)
            .then(|i| c.write().map(move |co| (i, co)))
            .for_each_concurrent(None, |(i, mut co)| async move {
                sleep(Duration::from_millis(expected_result - i)).await;
                *co += 1;
            })
            .await;

        let co = c.write().await;
        assert_eq!(*co, expected_result)
    }

    /// Same as `test_mutex`, through the `Arc`-owned write guard.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_owned_mutex() {
        let c = Arc::new(RwLock::new(0));

        futures::stream::iter(0..10000)
            .for_each_concurrent(None, |_| async {
                let mut co: RwLockWriteOwnedGuard<i32> = c.write_owned().await;
                *co += 1;
            })
            .await;

        let co = c.write_owned().await;
        assert_eq!(*co, 10000)
    }

    /// A write guard exposes the wrapped value's own methods mutably.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_container() {
        let c = RwLock::new(String::from("lol"));

        let mut co: RwLockWriteGuard<String> = c.write().await;
        co.add_assign("lol");

        assert_eq!(*co, "lollol");
    }

    /// Writers that time out while the lock is held must not corrupt it.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_timeout() {
        let c = RwLock::new(String::from("lol"));

        let co: RwLockWriteGuard<String> = c.write().await;

        futures::stream::iter(0..10000i32)
            .then(|_| tokio::time::timeout(Duration::from_nanos(1), c.write()))
            .try_for_each_concurrent(None, |_c| futures::future::ok(()))
            .await
            .expect_err("timeout must be");

        drop(co);

        let mut co: RwLockWriteGuard<String> = c.write().await;
        co.add_assign("lol");

        assert_eq!(*co, "lollol");
    }

    /// Many readers can coexist, and a writer is kept out while any of them is alive.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_concurrent_reading() {
        let c = RwLock::new(String::from("lol"));

        let co: RwLockReadGuard<String> = c.read().await;

        futures::stream::iter(0..10000i32)
            .then(|_| c.read())
            .inspect(|c| assert_eq!(*co, **c))
            .for_each_concurrent(None, |_c| futures::future::ready(()))
            .await;

        assert!(tokio::time::timeout(Duration::from_millis(1), c.write())
            .await
            .is_err());

        let co2: RwLockReadGuard<String> = c.read().await;
        assert_eq!(*co, *co2);
    }

    /// Readers release the lock to a writer, the writer excludes readers, and readers
    /// see the writer's change afterwards.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_concurrent_reading_writing() {
        let c = RwLock::new(String::from("lol"));

        let co: RwLockReadGuard<String> = c.read().await;
        let co2: RwLockReadGuard<String> = c.read().await;
        assert_eq!(*co, *co2);

        drop(co);
        drop(co2);

        let mut co: RwLockWriteGuard<String> = c.write().await;

        assert!(tokio::time::timeout(Duration::from_millis(1), c.read())
            .await
            .is_err());

        *co += "lol";

        drop(co);

        let co: RwLockReadGuard<String> = c.read().await;
        let co2: RwLockReadGuard<String> = c.read().await;
        assert_eq!(*co, "lollol");
        assert_eq!(*co, *co2);
    }

    /// Regression: releasing one of several readers must not hand the lock to a
    /// writer; only the last reader's release may do that.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_writer_waits_for_last_reader() {
        let c = RwLock::new(0);

        let first: RwLockReadGuard<i32> = c.read().await;
        let second: RwLockReadGuard<i32> = c.read().await;

        drop(first);
        assert!(tokio::time::timeout(Duration::from_millis(1), c.write())
            .await
            .is_err());

        drop(second);
        let mut w: RwLockWriteGuard<i32> = c.write().await;
        *w += 1;
        assert_eq!(*w, 1);
    }

    /// Mixed readers and writers on plain OS threads, each driving its own executor.
    #[test]
    fn multithreading_test() {
        let num = 100;
        let mutex = Arc::new(RwLock::new(0));
        let ths: Vec<_> = (0..num)
            .map(|i| {
                let mutex = mutex.clone();
                std::thread::spawn(move || {
                    block_on(async {
                        if i % 2 == 0 {
                            let mut lock = mutex.write().await;
                            *lock += 1;
                            drop(lock)
                        } else {
                            let lock1 = mutex.read().await;
                            let lock2 = mutex.read().await;
                            assert_eq!(*lock1, *lock2);
                            drop(lock1);
                            drop(lock2);
                        }
                    })
                })
            })
            .collect();

        for thread in ths {
            thread.join().unwrap();
        }

        block_on(async {
            let lock = mutex.read().await;
            assert_eq!(num / 2, *lock)
        })
    }
}
461}