1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
use std::{
    collections::HashMap,
    future::Future,
    hash::Hash,
    pin::Pin,
    sync::RwLock,
    task::{Context, Poll, Waker},
};

/// Lock table shared by the page-lock types: each key that currently has an
/// entry is locked, and its `Vec` collects the wakers of tasks waiting on it.
type Locker<Key> = RwLock<HashMap<Key, Vec<Waker>>>;

/// Future returned by `PageLocker::unlock`: resolves once `num` has no entry in
/// the lock table, i.e. once the key is observed unlocked.
///
/// Resolving does not acquire the key; another task may lock it again right away.
pub struct UnLock<'a, T> {
    /// Key being waited on.
    num:    T,
    /// `true` once the key has been observed unlocked (initialised from a snapshot
    /// by `PageLocker::unlock`); a poll then completes immediately.
    state:  bool,
    /// Shared lock table: locked keys mapped to the wakers waiting on them.
    locker: &'a Locker<T>,
}

impl<'a, T: Unpin + Eq + Hash> Future for UnLock<'a, T> {
    type Output = ();

    /// Completes when `num` is absent from the table; otherwise registers the
    /// current task's waker under `num` and returns `Pending`.
    ///
    /// # Panics
    /// Panics if the table's `RwLock` is poisoned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        if this.state {
            return Poll::Ready(());
        }
        // Re-check on every poll. Being polled is not proof that the holder released
        // the key (executors may poll spuriously, e.g. inside `join!`/`select!`), and
        // the key may already have been released before the very first poll.
        let mut table = this.locker.write().unwrap();
        match table.get_mut(&this.num) {
            None => {
                this.state = true;
                Poll::Ready(())
            }
            Some(waiters) => {
                // Register once per distinct waker so repeated polls don't pile up clones;
                // the guard's drop wakes everything stored here.
                if !waiters.iter().any(|w| w.will_wake(cx.waker())) {
                    waiters.push(cx.waker().clone());
                }
                Poll::Pending
            }
        }
    }
}

/// Guard for a locked key; dropping it removes `num`'s entry from the lock table
/// and wakes every waker that was stored under that entry.
#[must_use = "dropping the guard immediately releases the page lock"]
pub struct LockGuard<'a, T: Eq + Hash> {
    num:    T,
    locker: &'a Locker<T>,
}

impl<T: Eq + Hash> Drop for LockGuard<'_, T> {
    fn drop(&mut self) {
        // Recover the map from a poisoned lock instead of unwrapping: `drop` may run
        // while the thread is already unwinding, and a second panic there aborts.
        // The write guard is a temporary of this statement, so it is released here,
        // before any waker runs, and woken tasks don't immediately contend with us.
        let waiters = self
            .locker
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .remove(&self.num);
        // A missing entry should not happen while a guard is alive; tolerate it
        // rather than panicking inside `drop`.
        for waker in waiters.into_iter().flatten() {
            waker.wake();
        }
    }
}

#[derive(Default)]
pub struct PageLocker<T> {
    locker: Locker<T>,
}

impl<T: Eq + Hash + Copy + Unpin> PageLocker<T> {
    pub fn new() -> Self { Self { locker: RwLock::new(HashMap::new()) } }

    #[inline(always)]
    pub fn unlock(&self, num: T) -> UnLock<T> {
        UnLock { state: self.locker.read().unwrap().get(&num).is_none(), locker: &self.locker, num }
    }

    #[inline(always)]
    pub async fn lock(&self, num: T) -> LockGuard<'_, T> {
        self.unlock(num).await;
        self.locker.write().unwrap().insert(num, Vec::new());
        LockGuard { num, locker: &self.locker }
    }
}