use crate::inner::OrderedInner;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};

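/// An async reader-writer lock that serves requests in the order they were
/// made: every call to [`OrderedRwLock::write`] or [`OrderedRwLock::read`]
/// takes a ticket up front, and the returned future resolves only when that
/// ticket is reached. Writers get exclusive access; consecutive read requests
/// may hold the lock concurrently.
///
/// A minimal usage sketch (the import path is assumed, so the block is marked
/// `ignore` and is not compiled as a doc test):
///
/// ```ignore
/// let lock = OrderedRwLock::new(0);
///
/// // Exclusive write access, granted in request order.
/// {
///     let mut guard = lock.write().await;
///     *guard += 1;
/// }
///
/// // Readers that are next in line may share access.
/// let a = lock.read().await;
/// let b = lock.read().await;
/// assert_eq!(*a, *b);
/// ```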
#[derive(Debug)]
pub struct OrderedRwLock<T: ?Sized> {
    readers: AtomicUsize,
    inner: OrderedInner<T>,
}

impl<T> OrderedRwLock<T> {
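    /// Creates a new lock wrapping `data`.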
    #[inline]
    pub const fn new(data: T) -> OrderedRwLock<T> {
        OrderedRwLock {
            readers: AtomicUsize::new(0),
            inner: OrderedInner::new(data),
        }
    }
}

impl<T: ?Sized> OrderedRwLock<T> {
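    /// Returns a future that resolves to an exclusive write guard.
    /// Guards are granted in the order the futures were created.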
    #[inline]
    pub fn write(&self) -> OrderedRwLockWriteGuardFuture<T> {
        OrderedRwLockWriteGuardFuture {
            mutex: self,
            id: self.inner.generate_id(),
            is_realized: false,
        }
    }

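    /// Like [`OrderedRwLock::write`], but the returned future and guard hold
    /// an `Arc` to the lock instead of a borrow, so they are not tied to
    /// `&self`.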
    #[inline]
    pub fn write_owned(self: &Arc<Self>) -> OrderedRwLockWriteOwnedGuardFuture<T> {
        OrderedRwLockWriteOwnedGuardFuture {
            mutex: self.clone(),
            id: self.inner.generate_id(),
            is_realized: false,
        }
    }

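    /// Returns a future that resolves to a shared read guard.
    /// Requests are served in order; consecutive read requests may hold the
    /// lock at the same time.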
    #[inline]
    pub fn read(&self) -> OrderedRwLockReadGuardFuture<T> {
        OrderedRwLockReadGuardFuture {
            mutex: self,
            id: self.inner.generate_id(),
            is_realized: false,
        }
    }

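    /// Like [`OrderedRwLock::read`], but the returned future and guard hold
    /// an `Arc` to the lock instead of a borrow.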
    #[inline]
    pub fn read_owned(self: &Arc<Self>) -> OrderedRwLockReadOwnedGuardFuture<T> {
        OrderedRwLockReadOwnedGuardFuture {
            mutex: self.clone(),
            id: self.inner.generate_id(),
            is_realized: false,
        }
    }

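    /// Releases one reader and notifies the lock so the next queued request
    /// can proceed.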
    #[inline]
    fn unlock_reader(&self) {
        self.readers.fetch_sub(1, Ordering::Release);
        self.inner.unlock()
    }

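    /// Records one more active reader.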
    #[inline]
    fn add_reader(&self) {
        self.readers.fetch_add(1, Ordering::Release);
    }

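    /// Returns `true` if the request with ticket `id` may take a read lock
    /// now: its ticket must equal the currently served ticket plus the number
    /// of active readers, which lets consecutive readers share the lock.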
    #[inline]
    pub fn try_acquire_reader(&self, id: usize) -> bool {
        id == self.inner.current.load(Ordering::Acquire) + self.readers.load(Ordering::Acquire)
    }
}

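/// Exclusive write guard returned by [`OrderedRwLock::write`];
/// the lock is released when the guard is dropped.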
#[derive(Debug)]
pub struct OrderedRwLockWriteGuard<'a, T: ?Sized> {
    mutex: &'a OrderedRwLock<T>,
}

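/// Future returned by [`OrderedRwLock::write`], resolving to an
/// [`OrderedRwLockWriteGuard`].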
#[derive(Debug)]
pub struct OrderedRwLockWriteGuardFuture<'a, T: ?Sized> {
    mutex: &'a OrderedRwLock<T>,
    id: usize,
    is_realized: bool,
}

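/// Exclusive write guard returned by [`OrderedRwLock::write_owned`];
/// holds an `Arc` to the lock and releases it when dropped.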
#[derive(Debug)]
pub struct OrderedRwLockWriteOwnedGuard<T: ?Sized> {
    mutex: Arc<OrderedRwLock<T>>,
}

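/// Future returned by [`OrderedRwLock::write_owned`], resolving to an
/// [`OrderedRwLockWriteOwnedGuard`].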
#[derive(Debug)]
pub struct OrderedRwLockWriteOwnedGuardFuture<T: ?Sized> {
    mutex: Arc<OrderedRwLock<T>>,
    id: usize,
    is_realized: bool,
}

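/// Shared read guard returned by [`OrderedRwLock::read`];
/// the reader is released when the guard is dropped.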
#[derive(Debug)]
pub struct OrderedRwLockReadGuard<'a, T: ?Sized> {
    mutex: &'a OrderedRwLock<T>,
}

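/// Future returned by [`OrderedRwLock::read`], resolving to an
/// [`OrderedRwLockReadGuard`].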
#[derive(Debug)]
pub struct OrderedRwLockReadGuardFuture<'a, T: ?Sized> {
    mutex: &'a OrderedRwLock<T>,
    id: usize,
    is_realized: bool,
}

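/// Shared read guard returned by [`OrderedRwLock::read_owned`];
/// holds an `Arc` to the lock and releases the reader when dropped.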
#[derive(Debug)]
pub struct OrderedRwLockReadOwnedGuard<T: ?Sized> {
    mutex: Arc<OrderedRwLock<T>>,
}

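/// Future returned by [`OrderedRwLock::read_owned`], resolving to an
/// [`OrderedRwLockReadOwnedGuard`].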
#[derive(Debug)]
pub struct OrderedRwLockReadOwnedGuardFuture<T: ?Sized> {
    mutex: Arc<OrderedRwLock<T>>,
    id: usize,
    is_realized: bool,
}

impl<'a, T: ?Sized> Future for OrderedRwLockWriteGuardFuture<'a, T> {
    type Output = OrderedRwLockWriteGuard<'a, T>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.mutex.inner.try_acquire(self.id) {
            self.is_realized = true;
            Poll::Ready(OrderedRwLockWriteGuard { mutex: self.mutex })
        } else {
            // Not our turn yet: store the waker so this future can be woken
            // when the lock is released.
            self.mutex.inner.store_waker(cx.waker());
            Poll::Pending
        }
    }
}

impl<T: ?Sized> Future for OrderedRwLockWriteOwnedGuardFuture<T> {
    type Output = OrderedRwLockWriteOwnedGuard<T>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.mutex.inner.try_acquire(self.id) {
            self.is_realized = true;
            Poll::Ready(OrderedRwLockWriteOwnedGuard {
                mutex: self.mutex.clone(),
            })
        } else {
            self.mutex.inner.store_waker(cx.waker());
            Poll::Pending
        }
    }
}

impl<'a, T: ?Sized> Future for OrderedRwLockReadGuardFuture<'a, T> {
    type Output = OrderedRwLockReadGuard<'a, T>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.mutex.try_acquire_reader(self.id) {
            self.is_realized = true;
            // Count this reader before handing out the guard so that the next
            // queued reader can also acquire the lock.
            self.mutex.add_reader();
            Poll::Ready(OrderedRwLockReadGuard { mutex: self.mutex })
        } else {
            self.mutex.inner.store_waker(cx.waker());
            Poll::Pending
        }
    }
}

impl<T: ?Sized> Future for OrderedRwLockReadOwnedGuardFuture<T> {
    type Output = OrderedRwLockReadOwnedGuard<T>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.mutex.try_acquire_reader(self.id) {
            self.is_realized = true;
            self.mutex.add_reader();
            Poll::Ready(OrderedRwLockReadOwnedGuard {
                mutex: self.mutex.clone(),
            })
        } else {
            self.mutex.inner.store_waker(cx.waker());
            Poll::Pending
        }
    }
}

crate::impl_send_sync_rwlock!(
    OrderedRwLock,
    OrderedRwLockReadGuard,
    OrderedRwLockReadOwnedGuard,
    OrderedRwLockWriteGuard,
    OrderedRwLockWriteOwnedGuard
);

crate::impl_deref_mut!(OrderedRwLockWriteGuard, 'a);
crate::impl_deref_mut!(OrderedRwLockWriteOwnedGuard);
crate::impl_deref!(OrderedRwLockReadGuard, 'a);
crate::impl_deref!(OrderedRwLockReadOwnedGuard);

crate::impl_drop_guard!(OrderedRwLockWriteGuard, 'a, unlock);
crate::impl_drop_guard!(OrderedRwLockWriteOwnedGuard, unlock);
crate::impl_drop_guard_self!(OrderedRwLockReadGuard, 'a, unlock_reader);
crate::impl_drop_guard_self!(OrderedRwLockReadOwnedGuard, unlock_reader);

crate::impl_drop_guard_future!(OrderedRwLockWriteGuardFuture, 'a, unlock);
crate::impl_drop_guard_future!(OrderedRwLockWriteOwnedGuardFuture, unlock);
crate::impl_drop_guard_future!(OrderedRwLockReadGuardFuture, 'a, unlock);
crate::impl_drop_guard_future!(OrderedRwLockReadOwnedGuardFuture, unlock);

#[cfg(test)]
mod tests {
    use crate::rwlock_ordered::{
        OrderedRwLock, OrderedRwLockReadGuard, OrderedRwLockWriteGuard,
        OrderedRwLockWriteOwnedGuard,
    };
    use futures::executor::block_on;
    use futures::{FutureExt, StreamExt, TryStreamExt};
    use std::ops::AddAssign;
    use std::sync::atomic::AtomicUsize;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_mutex() {
        let c = OrderedRwLock::new(0);

        futures::stream::iter(0..10000)
            .for_each_concurrent(None, |_| async {
                let mut co: OrderedRwLockWriteGuard<i32> = c.write().await;
                *co += 1;
            })
            .await;

        let co = c.write().await;
        assert_eq!(*co, 10000)
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_mutex_delay() {
        let expected_result = 100;
        let c = OrderedRwLock::new(0);

        futures::stream::iter(0..expected_result)
            .then(|i| c.write().map(move |co| (i, co)))
            .for_each_concurrent(None, |(i, mut co)| async move {
                sleep(Duration::from_millis(expected_result - i)).await;
                *co += 1;
            })
            .await;

        let co = c.write().await;
        assert_eq!(*co, expected_result)
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_owned_mutex() {
        let c = Arc::new(OrderedRwLock::new(0));

        futures::stream::iter(0..10000)
            .for_each_concurrent(None, |_| async {
                let mut co: OrderedRwLockWriteOwnedGuard<i32> = c.write_owned().await;
                *co += 1;
            })
            .await;

        let co = c.write_owned().await;
        assert_eq!(*co, 10000)
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_container() {
        let c = OrderedRwLock::new(String::from("lol"));

        let mut co: OrderedRwLockWriteGuard<String> = c.write().await;
        co.add_assign("lol");

        assert_eq!(*co, "lollol");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_overflow() {
        let mut c = OrderedRwLock::new(String::from("lol"));

        c.inner.state = AtomicUsize::new(usize::MAX);
        c.inner.current = AtomicUsize::new(usize::MAX);

        let mut co: OrderedRwLockWriteGuard<String> = c.write().await;
        co.add_assign("lol");

        assert_eq!(*co, "lollol");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_timeout() {
        let c = OrderedRwLock::new(String::from("lol"));

        let co: OrderedRwLockWriteGuard<String> = c.write().await;

        futures::stream::iter(0..10000i32)
            .then(|_| tokio::time::timeout(Duration::from_nanos(1), c.write()))
            .try_for_each_concurrent(None, |_c| futures::future::ok(()))
            .await
            .expect_err("timeout must occur");

        drop(co);

        let mut co: OrderedRwLockWriteGuard<String> = c.write().await;
        co.add_assign("lol");

        assert_eq!(*co, "lollol");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_concurrent_reading() {
        let c = OrderedRwLock::new(String::from("lol"));

        let co: OrderedRwLockReadGuard<String> = c.read().await;

        futures::stream::iter(0..10000i32)
            .then(|_| c.read())
            .inspect(|c| assert_eq!(*co, **c))
            .for_each_concurrent(None, |_c| futures::future::ready(()))
            .await;

        assert!(matches!(
            tokio::time::timeout(Duration::from_millis(1), c.write()).await,
            Err(_)
        ));

        let co2: OrderedRwLockReadGuard<String> = c.read().await;
        assert_eq!(*co, *co2);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_concurrent_reading_writing() {
        let c = OrderedRwLock::new(String::from("lol"));

        let co: OrderedRwLockReadGuard<String> = c.read().await;
        let co2: OrderedRwLockReadGuard<String> = c.read().await;
        assert_eq!(*co, *co2);

        drop(co);
        drop(co2);

        let mut co: OrderedRwLockWriteGuard<String> = c.write().await;

        assert!(matches!(
            tokio::time::timeout(Duration::from_millis(1), c.read()).await,
            Err(_)
        ));

        *co += "lol";

        drop(co);

        let co: OrderedRwLockReadGuard<String> = c.read().await;
        let co2: OrderedRwLockReadGuard<String> = c.read().await;
        assert_eq!(*co, "lollol");
        assert_eq!(*co, *co2);
    }

    #[test]
    fn multithreading_test() {
        let num = 100;
        let mutex = Arc::new(OrderedRwLock::new(0));
        let ths: Vec<_> = (0..num)
            .map(|i| {
                let mutex = mutex.clone();
                std::thread::spawn(move || {
                    block_on(async {
                        if i % 2 == 0 {
                            let mut lock = mutex.write().await;
                            *lock += 1;
                            drop(lock);
                        } else {
                            let _lock = mutex.read().await;
                        }
                    })
                })
            })
            .collect();

        for thread in ths {
            thread.join().unwrap();
        }

        block_on(async {
            let lock = mutex.read().await;
            assert_eq!(num / 2, *lock)
        })
    }
}