a_mutex/
local.rs

1use std::{collections::HashMap, hash::Hash, marker::PhantomData};
2
3use crate::{DerefLt, Guard};
4
5use super::{GuardLt, MutexProvider, Result};
6use async_trait::async_trait;
7use std::sync::Arc;
8use tokio::sync::{Mutex, RwLock};
9
10#[derive(Debug)]
11pub struct LocalMutexProvider<T, K> {
12    map: tokio::sync::RwLock<HashMap<K, Arc<Mutex<Option<T>>>>>,
13}
14
15impl<T, K> LocalMutexProvider<T, K> {
16    pub fn new() -> LocalMutexProvider<T, K> {
17        LocalMutexProvider {
18            map: RwLock::new(HashMap::new()),
19        }
20    }
21}
22
23pub struct LocalMutex<T> {
24    mutex: Arc<Mutex<Option<T>>>,
25}
26
27#[async_trait]
28impl<T> super::Mutex<T> for LocalMutex<T>
29where
30    T: Send + Sync + 'static,
31{
32    type Guard = LocalGuardCtor<T>;
33    async fn lock(&self) -> Result<LocalGuard<'_, T>> {
34        let guard = self.mutex.lock().await;
35        Ok(LocalGuard { guard })
36    }
37}
38
/// Type-level constructor for [`LocalGuard`].
///
/// Its `GuardLt<'a, T>` impl maps each borrow lifetime `'a` to `LocalGuard<'a, T>`.
/// It is a zero-sized marker and is never instantiated in this module.
pub struct LocalGuardCtor<T> {
    _marker: PhantomData<T>,
}
40
41impl<'a, T> GuardLt<'a, T> for LocalGuardCtor<T>
42where
43    T: Send + Sync + 'static,
44{
45    type Guard = LocalGuard<'a, T>;
46}
47
48pub struct LocalGuard<'a, T> {
49    guard: tokio::sync::MutexGuard<'a, Option<T>>,
50}
51
/// Type-level constructor for the value borrowed out of a [`LocalGuard`].
///
/// Its `DerefLt<'a, T>` impl maps each borrow lifetime `'a` to `&'a Option<T>`.
/// It is a zero-sized marker and is never instantiated in this module.
pub struct LocalDerefCtor<T> {
    _marker: PhantomData<T>,
}
53
54impl<'a, T> DerefLt<'a, T> for LocalDerefCtor<T>
55where
56    T: Send + Sync + 'static,
57{
58    type Deref = &'a Option<T>;
59}
60
61#[async_trait]
62impl<'a, T> Guard<T> for LocalGuard<'a, T>
63where
64    T: Send + Sync + 'static,
65{
66    type D = LocalDerefCtor<T>;
67    async fn store(&mut self, data: T) -> Result<()> {
68        *self.guard = Some(data);
69        Ok(())
70    }
71
72    async fn load<'s>(&'s self) -> Result<&'s Option<T>> {
73        Ok(&*self.guard)
74    }
75
76    async fn clear(&mut self) -> Result<()> {
77        *self.guard = None;
78        Ok(())
79    }
80}
81
82#[async_trait]
83impl<T, K> MutexProvider<T, K> for LocalMutexProvider<T, K>
84where
85    T: Send + Sync + 'static,
86    K: Hash + Eq + Send + Sync,
87{
88    type Mutex = LocalMutex<T>;
89    async fn get(&self, key: K) -> Result<Self::Mutex> {
90        let mutex = {
91            let map_readguard = self.map.read().await;
92            if let Some(lock) = map_readguard.get(&key) {
93                lock.clone()
94            } else {
95                drop(map_readguard);
96                let mutex = Arc::new(Mutex::new(None));
97                let mut writeguard = self.map.write().await;
98                writeguard.insert(key, mutex.clone());
99                mutex
100            }
101        };
102        Ok(LocalMutex { mutex })
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::LocalMutexProvider;
    use crate::spec;

    /// Runs the shared provider conformance checks (empty slot, then stored value)
    /// against a fresh `LocalMutexProvider` each time.
    #[tokio::test]
    async fn test() {
        spec::check_empty(LocalMutexProvider::new()).await;
        spec::check_val(LocalMutexProvider::new()).await;
    }
}
117}