1use super::*;
2
/// Per-key bookkeeping for readers of a [`RwLock`] key.
#[derive(Default)]
struct RefCounter {
    /// Number of live `ReadGuard`s for this key (incremented in `RwLock::read`,
    /// decremented in `ReadGuard::drop`).
    count: usize,
    /// Writers waiting for `count` to reach zero: a raw pointer to each waiting
    /// future's poll state, plus the waker to notify once it is set to `Ready`.
    wakers: Vec<(*mut PollState, Waker)>,
}

// SAFETY: the raw `*mut PollState` pointers make this type `!Send`/`!Sync` by
// default. NOTE(review): soundness relies on those pointers only being
// dereferenced while the `readers` map lock is held and while the pointee (the
// waiting future's `state`) is still alive — confirm both hold for all users.
unsafe impl Send for RefCounter {}
// SAFETY: see above. NOTE(review): if `RefCounter` is only ever reached through
// a `Mutex`, `Send` alone suffices and this impl may be unnecessary — verify.
unsafe impl Sync for RefCounter {}
11
/// Async reader–writer lock keyed by `T`: each key is meant to be held by any
/// number of readers or by a single writer.
///
/// NOTE(review): `#[derive(Default)]` adds a `T: Default` bound to the generated
/// `Default` impl, which `RwLock::new` does not require.
#[derive(Default)]
pub struct RwLock<T> {
    /// Per-key exclusive lock taken by writers (presumably a keyed mutex,
    /// given `lock(num)` / `until_unlocked(num)` — confirm).
    mutex: Mutex<T>,
    /// Live reader counts and waiting writers, per key.
    readers: Map<T, RefCounter>,
}
17
18impl<T: Eq + Hash + Copy + Unpin> RwLock<T> {
19 #[inline]
20 pub fn new() -> RwLock<T> {
21 RwLock {
22 mutex: Mutex::new(),
23 readers: new_map(),
24 }
25 }
26
27 pub async fn read(&self, num: T) -> ReadGuard<'_, T> {
28 self.mutex.until_unlocked(num).await;
29 self.readers
30 .lock()
31 .unwrap()
32 .entry(num)
33 .or_insert_with(RefCounter::default)
34 .count += 1;
35
36 ReadGuard {
37 num,
38 readers: &self.readers,
39 }
40 }
41
42 pub async fn write(&self, num: T) -> WriteGuard<'_, T> {
43 let guard = self.mutex.lock(num).await;
44 UntilAllReaderDropped {
45 num,
46 readers: &self.readers,
47 state: PollState::Init.into(),
48 }
49 .await;
50 guard
51 }
52}
53
54pub struct ReadGuard<'a, T: Eq + Hash + Copy> {
55 num: T,
56 readers: &'a Map<T, RefCounter>,
57}
58
59impl<'a, T: Eq + Hash + Copy> Drop for ReadGuard<'a, T> {
63 fn drop(&mut self) {
64 unsafe {
65 let mut map = self.readers.lock().unwrap();
66 let mut rc = map.get_mut(&self.num).unwrap_unchecked();
67 rc.count -= 1;
68 if rc.count == 0 {
69 let rc = map.remove(&self.num).unwrap_unchecked();
70 for (state, waker) in rc.wakers {
71 *state = PollState::Ready;
72 waker.wake();
73 }
74 }
75 }
76 }
77}
78
79struct UntilAllReaderDropped<'a, T> {
80 num: T,
81 state: UnsafeCell<PollState>,
82 readers: &'a Map<T, RefCounter>,
83}
84
/// Resolves once no reader entry for `num` remains in `readers`.
///
/// The state machine lives in `ret_fut!` (defined outside this file). The
/// comments below assume it runs the block while `state` is `PollState::Init`
/// and reports `Poll::Ready(())` once `state` becomes `PollState::Ready`.
/// TODO: confirm against the macro definition.
impl<T: Unpin + Hash + Eq + Copy> Future for UntilAllReaderDropped<'_, T> {
    type Output = ();
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Unpin` lets us take a plain `&mut Self` out of the pin.
        let this = self.get_mut();
        // NOTE(review): `*this.state.get()` is read here without the `readers`
        // lock, while `ReadGuard::drop` writes it under that lock, possibly from
        // another thread. That is a potential data race; an atomic state would
        // avoid it.
        ret_fut!(*this.state.get(), {
            let mut map = this.readers.lock().unwrap();
            match map.get_mut(&this.num) {
                // Readers still active: register our state slot and waker so the
                // last reader's drop can mark us `Ready` and wake this task.
                // NOTE(review): if `ret_fut!` runs this block on every pending
                // poll, each poll appends another entry. Confirm.
                Some(rc) => rc.wakers.push(( this.state.get(), cx.waker().clone())),
                // No reader entry for this key: nothing to wait for.
                None => return Poll::Ready(()),
            }
        });
    }
}