// page_lock/rw_lock.rs

1use super::*;
2
3#[derive(Default)]
4struct RefCounter {
5    count: usize,
6    wakers: Vec<(*mut PollState, Waker)>,
7}
8
9unsafe impl Send for RefCounter {}
10unsafe impl Sync for RefCounter {}
11
12#[derive(Default)]
13pub struct RwLock<T> {
14    mutex: Mutex<T>,
15    readers: Map<T, RefCounter>,
16}
17
18impl<T: Eq + Hash + Copy + Unpin> RwLock<T> {
19    #[inline]
20    pub fn new() -> RwLock<T> {
21        RwLock {
22            mutex: Mutex::new(),
23            readers: new_map(),
24        }
25    }
26
27    pub async fn read(&self, num: T) -> ReadGuard<'_, T> {
28        self.mutex.until_unlocked(num).await;
29        self.readers
30            .lock()
31            .unwrap()
32            .entry(num)
33            .or_insert_with(RefCounter::default)
34            .count += 1;
35
36        ReadGuard {
37            num,
38            readers: &self.readers,
39        }
40    }
41
42    pub async fn write(&self, num: T) -> WriteGuard<'_, T> {
43        let guard = self.mutex.lock(num).await;
44        UntilAllReaderDropped {
45            num,
46            readers: &self.readers,
47            state: PollState::Init.into(),
48        }
49        .await;
50        guard
51    }
52}
53
54pub struct ReadGuard<'a, T: Eq + Hash + Copy> {
55    num: T,
56    readers: &'a Map<T, RefCounter>,
57}
58
59// impl<T: ?Sized> !Send for ReadGuard<'_, T> {}
60// impl<T: ?Sized + Sync> Sync for ReadGuard<'_, T> {}
61
62impl<'a, T: Eq + Hash + Copy> Drop for ReadGuard<'a, T> {
63    fn drop(&mut self) {
64        unsafe {
65            let mut map = self.readers.lock().unwrap();
66            let mut rc = map.get_mut(&self.num).unwrap_unchecked();
67            rc.count -= 1;
68            if rc.count == 0 {
69                let rc = map.remove(&self.num).unwrap_unchecked();
70                for (state, waker) in rc.wakers {
71                    *state = PollState::Ready;
72                    waker.wake();
73                }
74            }
75        }
76    }
77}
78
79struct UntilAllReaderDropped<'a, T> {
80    num: T,
81    state: UnsafeCell<PollState>,
82    readers: &'a Map<T, RefCounter>,
83}
84
85impl<T: Unpin + Hash + Eq + Copy> Future for UntilAllReaderDropped<'_, T> {
86    type Output = ();
87    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
88        let this = self.get_mut();
89        ret_fut!(*this.state.get(), {
90            let mut map = this.readers.lock().unwrap();
91            match map.get_mut(&this.num) {
92                Some(rc) => rc.wakers.push(( this.state.get(), cx.waker().clone())),
93                None => return Poll::Ready(()),
94            }
95        });
96    }
97}