1use rand::Rng;
16use std::time::{Duration, Instant};
17use tokio::{sync::RwLock, time::sleep};
18use tracing::info;
19
/// Async reader/writer lock with timeout-bounded acquisition.
///
/// Either any number of readers or exactly one writer may hold the lock at a time.
/// `Default` yields an unlocked instance (no holders, no writer, empty `id`/`source`).
#[derive(Debug, Default)]
pub struct LRWMutex {
    /// `id` passed by the most recent acquisition attempt. It is overwritten even when
    /// that attempt fails, and `lock`/`r_lock` reuse it.
    id: RwLock<String>,
    /// `source` passed by the most recent acquisition attempt (same update rules as `id`).
    source: RwLock<String>,
    /// `true` while a writer holds the lock.
    is_write: RwLock<bool>,
    /// Number of current holders: `1` while a writer holds the lock, otherwise the reader count.
    /// (Misspelling of "reference" is kept because the impl refers to this field by name.)
    refrence: RwLock<usize>,
}
27
28impl LRWMutex {
29 pub async fn lock(&self) -> bool {
30 let is_write = true;
31 let id = self.id.read().await.clone();
32 let source = self.source.read().await.clone();
33 let timeout = Duration::from_secs(10000);
34 self.look_loop(
35 &id, &source, &timeout, is_write,
37 )
38 .await
39 }
40
41 pub async fn get_lock(&self, id: &str, source: &str, timeout: &Duration) -> bool {
42 let is_write = true;
43 self.look_loop(id, source, timeout, is_write).await
44 }
45
46 pub async fn r_lock(&self) -> bool {
47 let is_write: bool = false;
48 let id = self.id.read().await.clone();
49 let source = self.source.read().await.clone();
50 let timeout = Duration::from_secs(10000);
51 self.look_loop(
52 &id, &source, &timeout, is_write,
54 )
55 .await
56 }
57
58 pub async fn get_r_lock(&self, id: &str, source: &str, timeout: &Duration) -> bool {
59 let is_write = false;
60 self.look_loop(id, source, timeout, is_write).await
61 }
62
63 async fn inner_lock(&self, id: &str, source: &str, is_write: bool) -> bool {
64 *self.id.write().await = id.to_string();
65 *self.source.write().await = source.to_string();
66
67 let mut locked = false;
68 if is_write {
69 if *self.refrence.read().await == 0 && !*self.is_write.read().await {
70 *self.refrence.write().await = 1;
71 *self.is_write.write().await = true;
72 locked = true;
73 }
74 } else if !*self.is_write.read().await {
75 *self.refrence.write().await += 1;
76 locked = true;
77 }
78
79 locked
80 }
81
82 async fn look_loop(&self, id: &str, source: &str, timeout: &Duration, is_write: bool) -> bool {
83 let start = Instant::now();
84 loop {
85 if self.inner_lock(id, source, is_write).await {
86 return true;
87 } else {
88 if Instant::now().duration_since(start) > *timeout {
89 return false;
90 }
91 let sleep_time: u64;
92 {
93 let mut rng = rand::rng();
94 sleep_time = rng.random_range(10..=50);
95 }
96 sleep(Duration::from_millis(sleep_time)).await;
97 }
98 }
99 }
100
101 pub async fn un_lock(&self) {
102 let is_write = true;
103 if !self.unlock(is_write).await {
104 info!("Trying to un_lock() while no Lock() is active")
105 }
106 }
107
108 pub async fn un_r_lock(&self) {
109 let is_write = false;
110 if !self.unlock(is_write).await {
111 info!("Trying to un_r_lock() while no Lock() is active")
112 }
113 }
114
115 async fn unlock(&self, is_write: bool) -> bool {
116 let mut unlocked = false;
117 if is_write {
118 if *self.is_write.read().await && *self.refrence.read().await == 1 {
119 *self.refrence.write().await = 0;
120 *self.is_write.write().await = false;
121 unlocked = true;
122 }
123 } else if !*self.is_write.read().await && *self.refrence.read().await > 0 {
124 *self.refrence.write().await -= 1;
125 unlocked = true;
126 }
127
128 unlocked
129 }
130
131 pub async fn force_un_lock(&self) {
132 *self.refrence.write().await = 0;
133 *self.is_write.write().await = false;
134 }
135}
136
#[cfg(test)]
mod test {
    use std::{sync::Arc, time::Duration};

    use std::io::Result;
    use tokio::time::sleep;

    use super::LRWMutex;

    /// Exercises write/read acquisition and release on a single task.
    #[tokio::test]
    async fn test_lock_unlock() -> Result<()> {
        let l_rw_lock = LRWMutex::default();
        let id = "foo";
        let source = "dandan";
        let timeout = Duration::from_secs(5);
        assert!(l_rw_lock.get_lock(id, source, &timeout).await);
        l_rw_lock.un_lock().await;

        // The lock is free again, so the parameterless `lock` must succeed.
        assert!(l_rw_lock.lock().await);

        // A writer holds the lock, so a reader must time out.
        assert!(!l_rw_lock.get_r_lock(id, source, &timeout).await);
        l_rw_lock.un_lock().await;
        assert!(l_rw_lock.get_r_lock(id, source, &timeout).await);
        l_rw_lock.un_r_lock().await;

        Ok(())
    }

    /// Readers are reference-counted: a writer is refused until every reader has released.
    #[tokio::test]
    async fn readers_block_writer_until_all_released() -> Result<()> {
        let l_rw_lock = LRWMutex::default();
        let no_wait = Duration::ZERO;

        assert!(l_rw_lock.get_r_lock("r1", "src", &no_wait).await);
        assert!(l_rw_lock.get_r_lock("r2", "src", &no_wait).await);
        assert!(!l_rw_lock.get_lock("w", "src", &no_wait).await);

        l_rw_lock.un_r_lock().await;
        assert!(!l_rw_lock.get_lock("w", "src", &no_wait).await);

        l_rw_lock.un_r_lock().await;
        assert!(l_rw_lock.get_lock("w", "src", &no_wait).await);
        l_rw_lock.un_lock().await;

        Ok(())
    }

    /// A writer holding the lock for 5s makes a concurrent reader time out, after which it succeeds.
    #[tokio::test]
    async fn multi_thread_test() -> Result<()> {
        let l_rw_lock = Arc::new(LRWMutex::default());
        let id = "foo";
        let source = "dandan";

        let one_fn = async {
            let one = Arc::clone(&l_rw_lock);
            let timeout = Duration::from_secs(1);
            assert!(one.get_lock(id, source, &timeout).await);
            sleep(Duration::from_secs(5)).await;
            one.un_lock().await;
        };

        let two_fn = async {
            let two = Arc::clone(&l_rw_lock);
            let timeout = Duration::from_secs(2);
            assert!(!two.get_r_lock(id, source, &timeout).await);
            sleep(Duration::from_secs(5)).await;
            assert!(two.get_r_lock(id, source, &timeout).await);
            two.un_r_lock().await;
        };

        tokio::join!(one_fn, two_fn);

        Ok(())
    }
}