1pub mod error;
117pub mod execs;
118pub mod options;
119
120use crate::error::Error;
121use crate::error::Error::IdNotFound;
122use crate::execs::*;
123use crate::options::Options;
124
125use anyhow::Result;
126use redis::aio::{ConnectionManager, ConnectionManagerConfig};
127use tokio::sync::oneshot;
128use tokio::time::sleep;
129use tokio::{select, spawn};
130
131#[derive(Clone)]
132pub struct Locker {
133 client: redis::Client,
134 conn_manager: ConnectionManager,
135}
136
137impl Locker {
138 pub async fn from_redis_url(url: &str) -> Result<Self> {
139 let client = redis::Client::open(url)?;
140 let cfg = ConnectionManagerConfig::default().set_max_delay(2000);
141 let async_conn_manager = ConnectionManager::new_with_config(client.clone(), cfg).await?;
142 Ok(Self {
143 client,
144 conn_manager: async_conn_manager,
145 })
146 }
147
148 pub async fn acquire(&mut self, lock_key: &str) -> Result<Lock> {
149 self.acquire_with_options(&Options::default(), lock_key)
150 .await
151 }
152
153 pub async fn acquire_with_options(&mut self, opts: &Options, lock_key: &str) -> Result<Lock> {
154 let lock_id = lock(
155 &mut self.conn_manager,
156 lock_key,
157 opts.ttl,
158 opts.retry,
159 opts.timeout,
160 )
161 .await?;
162
163 let mut conn = self.conn_manager.clone();
164 let opts = opts.clone();
165 let lock_key_c1 = lock_key.to_owned();
166 let lock_id_c1 = lock_id.clone();
167 let (stop_tx, mut stop_rx) = oneshot::channel();
168
169 spawn(async move {
170 loop {
171 select! {
172 _ = &mut stop_rx => break,
173 _ = sleep(opts.extend) => {
174 if let Err(e) = extend(
175 &mut conn,
176 &lock_key_c1,
177 &lock_id_c1,
178 opts.ttl,
179 )
180 .await
181 {
182 if let Some(e) = e.downcast_ref::<Error>() && matches!(e, IdNotFound) {
183 break;
184 }
185 } else {
186 }
187 },
188 }
189 }
190 });
191
192 let cli = self.client.clone();
193 let lock_key_c2 = lock_key.to_owned();
194 let lock_id_c2 = lock_id.clone();
195
196 Ok(Lock {
197 release_fn: Some(Box::new(move || -> Result<()> {
198 let _ = stop_tx.send(());
199 let mut conn = cli.get_connection()?;
200 unlock_sync(&mut conn, &lock_key_c2, &lock_id_c2)
201 })),
202 })
203 }
204}
205
206pub struct Lock {
207 pub release_fn: Option<Box<dyn FnOnce() -> Result<()> + Send + 'static>>,
208}
209
210impl Lock {
211 pub fn release(mut self) -> Result<()> {
212 self.call_release()
213 }
214
215 fn call_release(&mut self) -> Result<()> {
216 match self.release_fn.take() {
217 Some(release_fn) => release_fn(),
218 None => Ok(()),
219 }
220 }
221}
222
223impl Drop for Lock {
224 fn drop(&mut self) {
225 let _ = self.call_release();
226 }
227}
228
#[cfg(test)]
mod test {
    // Integration tests: they need a Redis server reachable at `REDIS_URL`.
    use super::*;
    use std::time::Duration;

    const REDIS_URL: &str = "redis://127.0.0.1:6379/0";

    /// Asserts that an acquisition attempt failed with [`Error::Timeout`].
    fn assert_timeout(result: Result<Lock>, msg: &str) {
        let Err(e) = result else {
            panic!("{msg}: expected an error, but the lock was acquired");
        };
        assert_eq!(e.downcast_ref::<Error>(), Some(&Error::Timeout), "{msg}");
    }

    #[tokio::test]
    async fn test_lock_exclusive() {
        let mut locker = Locker::from_redis_url(REDIS_URL).await.unwrap();
        let lock_key = "test:test_lock_exclusive_key";

        let first = locker.acquire(lock_key).await;
        assert!(first.is_ok(), "Should acquire a lock");

        assert_timeout(
            locker.acquire(lock_key).await,
            "Should get a timed out error when acquiring another lock",
        );

        assert!(first.unwrap().release().is_ok(), "Should release a lock");

        assert!(
            locker.acquire(lock_key).await.is_ok(),
            "Should acquire a lock after another lock is released"
        );
    }

    #[tokio::test]
    async fn test_lock_drop() {
        let mut locker = Locker::from_redis_url(REDIS_URL).await.unwrap();
        let lock_key = "test:test_lock_drop_key";

        {
            let held = locker.acquire(lock_key).await;
            assert!(held.is_ok(), "Should acquire a lock in a scope");

            assert_timeout(
                locker.acquire(lock_key).await,
                "Should get an timed out error when acquiring another lock in a scope",
            );
        }

        assert!(
            locker.acquire(lock_key).await.is_ok(),
            "Should acquire a lock out of the prev scope"
        );
    }

    #[tokio::test]
    async fn test_lock_passive_release() {
        let mut locker = Locker::from_redis_url(REDIS_URL).await.unwrap();
        let lock_key = "test:test_lock_passive_release_key";

        let opts = Options::new()
            .ttl(Duration::from_secs(2))
            .extend(Duration::from_secs(3));
        let held = locker.acquire_with_options(&opts, lock_key).await;
        assert!(
            held.is_ok(),
            "Should acquire a lock with customized lifetime and extend_interval, extend_interval greater than lifetime"
        );

        sleep(Duration::from_secs(3)).await;
        assert!(
            locker.acquire(lock_key).await.is_ok(),
            "Should passively release a lock when the lifetime is reached"
        );
    }

    #[tokio::test]
    async fn test_lock_extend() {
        let mut locker = Locker::from_redis_url(REDIS_URL).await.unwrap();
        let lock_key = "test:test_lock_extend_key";
        let opts = Options::new()
            .ttl(Duration::from_secs(3))
            .extend(Duration::from_secs(1));
        // Kept alive until the end of the test so the keep-alive task keeps running.
        let held = locker.acquire_with_options(&opts, lock_key).await;
        assert!(
            held.is_ok(),
            "Should acquire a lock with customized lifetime and extend_interval, extend_interval smaller than lifetime"
        );

        sleep(Duration::from_secs(5)).await;
        assert_timeout(
            locker.acquire(lock_key).await,
            "Should extend lock lifetime automatically",
        );
    }
}