1use std::{collections::HashMap, hash::Hash, marker::PhantomData};
2
3use crate::{DerefLt, Guard};
4
5use super::{GuardLt, MutexProvider, Result};
6use async_trait::async_trait;
7use std::sync::Arc;
8use tokio::sync::{Mutex, RwLock};
9
10#[derive(Debug)]
11pub struct LocalMutexProvider<T, K> {
12 map: tokio::sync::RwLock<HashMap<K, Arc<Mutex<Option<T>>>>>,
13}
14
15impl<T, K> LocalMutexProvider<T, K> {
16 pub fn new() -> LocalMutexProvider<T, K> {
17 LocalMutexProvider {
18 map: RwLock::new(HashMap::new()),
19 }
20 }
21}
22
23pub struct LocalMutex<T> {
24 mutex: Arc<Mutex<Option<T>>>,
25}
26
27#[async_trait]
28impl<T> super::Mutex<T> for LocalMutex<T>
29where
30 T: Send + Sync + 'static,
31{
32 type Guard = LocalGuardCtor<T>;
33 async fn lock(&self) -> Result<LocalGuard<'_, T>> {
34 let guard = self.mutex.lock().await;
35 Ok(LocalGuard { guard })
36 }
37}
38
39pub struct LocalGuardCtor<T>(PhantomData<T>);
40
41impl<'a, T> GuardLt<'a, T> for LocalGuardCtor<T>
42where
43 T: Send + Sync + 'static,
44{
45 type Guard = LocalGuard<'a, T>;
46}
47
48pub struct LocalGuard<'a, T> {
49 guard: tokio::sync::MutexGuard<'a, Option<T>>,
50}
51
52pub struct LocalDerefCtor<T>(PhantomData<T>);
53
54impl<'a, T> DerefLt<'a, T> for LocalDerefCtor<T>
55where
56 T: Send + Sync + 'static,
57{
58 type Deref = &'a Option<T>;
59}
60
61#[async_trait]
62impl<'a, T> Guard<T> for LocalGuard<'a, T>
63where
64 T: Send + Sync + 'static,
65{
66 type D = LocalDerefCtor<T>;
67 async fn store(&mut self, data: T) -> Result<()> {
68 *self.guard = Some(data);
69 Ok(())
70 }
71
72 async fn load<'s>(&'s self) -> Result<&'s Option<T>> {
73 Ok(&*self.guard)
74 }
75
76 async fn clear(&mut self) -> Result<()> {
77 *self.guard = None;
78 Ok(())
79 }
80}
81
82#[async_trait]
83impl<T, K> MutexProvider<T, K> for LocalMutexProvider<T, K>
84where
85 T: Send + Sync + 'static,
86 K: Hash + Eq + Send + Sync,
87{
88 type Mutex = LocalMutex<T>;
89 async fn get(&self, key: K) -> Result<Self::Mutex> {
90 let mutex = {
91 let map_readguard = self.map.read().await;
92 if let Some(lock) = map_readguard.get(&key) {
93 lock.clone()
94 } else {
95 drop(map_readguard);
96 let mutex = Arc::new(Mutex::new(None));
97 let mut writeguard = self.map.write().await;
98 writeguard.insert(key, mutex.clone());
99 mutex
100 }
101 };
102 Ok(LocalMutex { mutex })
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::LocalMutexProvider;
    use crate::spec::{check_empty, check_val};

    /// Runs the shared `crate::spec::check_empty` check against a fresh provider.
    #[tokio::test]
    async fn local_provider_check_empty() {
        check_empty(LocalMutexProvider::new()).await;
    }

    /// Runs the shared `crate::spec::check_val` check against a fresh provider.
    #[tokio::test]
    async fn local_provider_check_val() {
        check_val(LocalMutexProvider::new()).await;
    }
}