Skip to main content

rslock/
lock.rs

1use std::io;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use futures::future::join_all;
6use rand::{rng, Rng, RngCore};
7use redis::aio::MultiplexedConnection;
8use redis::Value::Okay;
9use redis::{Client, IntoConnectionInfo, RedisError, RedisResult, Value};
10
11use crate::resource::{LockResource, ToLockResource};
12
/// Maximum number of attempts per lock/extend call unless changed with `set_retry`.
const DEFAULT_RETRY_COUNT: u32 = 3;
/// Upper bound of the random pause between attempts unless changed with `set_retry`.
const DEFAULT_RETRY_DELAY: Duration = Duration::from_millis(200);
/// Fraction of the TTL set aside for clock drift between servers when computing how
/// long an acquired lock stays valid (the caller adds a further 2 ms on top).
const CLOCK_DRIFT_FACTOR: f32 = 0.01;

/// Lua script: delete `KEYS[1]` only if it still holds `ARGV[1]`.
/// Returns the result of `DEL` on a match and `0` otherwise.
const UNLOCK_SCRIPT: &str = r#"
if redis.call("GET", KEYS[1]) == ARGV[1] then
  return redis.call("DEL", KEYS[1])
else
  return 0
end
"#;

/// Lua script: if `KEYS[1]` holds `ARGV[1]`, re-`SET` it with an expiry of `ARGV[2]`
/// milliseconds and return `1`; return `0` when the stored value differs.
const EXTEND_SCRIPT: &str = r#"
if redis.call("get", KEYS[1]) ~= ARGV[1] then
  return 0
else
  if redis.call("set", KEYS[1], ARGV[1], "PX", ARGV[2]) ~= nil then
    return 1
  else
    return 0
  end
end
"#;
34
35#[derive(Debug, thiserror::Error)]
36pub enum LockError {
37    #[error("IO error: {0}")]
38    Io(#[from] io::Error),
39
40    #[error("Redis error: {0}")]
41    Redis(#[from] redis::RedisError),
42
43    #[error("Resource is unavailable")]
44    Unavailable,
45
46    #[error("TTL exceeded")]
47    TtlExceeded,
48
49    #[error("TTL too large")]
50    TtlTooLarge,
51
52    #[error("Redis connection failed for all servers")]
53    RedisConnectionFailed,
54
55    #[error("Redis connection failed.")]
56    RedisFailedToEstablishConnection,
57
58    #[error("Redis key mismatch: expected value does not match actual value")]
59    RedisKeyMismatch,
60
61    #[error("Redis key not found")]
62    RedisKeyNotFound,
63    #[error("A mutex was poisoned")]
64    MutexPoisoned,
65}
66
67//This is in place to make it easier to swap to just an std-async io implementaiton ??
68type Mutex<T> = tokio::sync::Mutex<T>;
69type MutexGuard<'a, K> = tokio::sync::MutexGuard<'a, K>;
70
71/// The lock manager.
72///
73/// Implements the necessary functionality to acquire and release locks
74/// and handles the Redis connections.
75#[derive(Debug, Clone)]
76pub struct LockManager {
77    lock_manager_inner: Arc<Mutex<LockManagerInner>>,
78    retry_count: u32,
79    retry_delay: Duration,
80}
81
82#[derive(Debug, Clone)]
83struct LockManagerInner {
84    /// List of all Redis clients
85    pub servers: Vec<RestorableConnection>,
86}
87
88impl LockManagerInner {
89    fn get_quorum(&self) -> u32 {
90        (self.servers.len() as u32) / 2 + 1
91    }
92}
93
94#[derive(Debug, Clone)]
95struct RestorableConnection {
96    client: Client,
97    con: Arc<Mutex<Option<MultiplexedConnection>>>,
98}
99
100impl RestorableConnection {
101    pub fn new(client: Client) -> Self {
102        Self {
103            client,
104            con: Arc::new(tokio::sync::Mutex::new(None)),
105        }
106    }
107
108    pub async fn get_connection(&mut self) -> Result<MultiplexedConnection, LockError> {
109        let mut lock = self.con.lock().await;
110        if lock.is_none() {
111            *lock = Some(
112                self.client
113                    .get_multiplexed_async_connection()
114                    .await
115                    .map_err(LockError::Redis)?,
116            );
117        }
118        match (*lock).clone() {
119            Some(conn) => Ok(conn),
120            None => Err(LockError::RedisFailedToEstablishConnection),
121        }
122    }
123
124    pub async fn recover(&mut self, error: RedisError) -> Result<(), LockError> {
125        //We need to rebuild the connection
126        if !error.is_unrecoverable_error() {
127            Ok(())
128        } else {
129            let mut lock = self.con.lock().await;
130            *lock = Some(
131                self.client
132                    .get_multiplexed_async_connection()
133                    .await
134                    .map_err(LockError::Redis)?,
135            );
136            Ok(())
137        }
138    }
139}
140
141impl RestorableConnection {
142    async fn lock(&mut self, resource: &LockResource<'_>, val: &[u8], ttl: usize) -> bool {
143        let mut con = match self.get_connection().await {
144            Err(_) => return false,
145            Ok(val) => val,
146        };
147
148        let result: RedisResult<Value> = redis::cmd("SET")
149            .arg(resource)
150            .arg(val)
151            .arg("NX")
152            .arg("PX")
153            .arg(ttl)
154            .query_async(&mut con)
155            .await;
156
157        match result {
158            Ok(Okay) => true,
159            Ok(_) => false,
160            Err(e) => {
161                //We don't have to do anything special, it's up to the caller to retry or back out
162                let _ = self.recover(e).await;
163                false
164            }
165        }
166    }
167
168    async fn extend(&mut self, resource: &LockResource<'_>, val: &[u8], ttl: usize) -> bool {
169        let mut con = match self.get_connection().await {
170            Err(_) => return false,
171            Ok(val) => val,
172        };
173        let script = redis::Script::new(EXTEND_SCRIPT);
174        let result: RedisResult<i32> = script
175            .key(resource)
176            .arg(val)
177            .arg(ttl)
178            .invoke_async(&mut con)
179            .await;
180        match result {
181            Ok(val) => val == 1,
182            Err(e) => {
183                //We don't have to do anything special, it's up to the caller to retry or back out
184                let _ = self.recover(e).await;
185                false
186            }
187        }
188    }
189
190    async fn unlock(&mut self, resource: impl ToLockResource<'_>, val: &[u8]) -> bool {
191        let resource = resource.to_lock_resource();
192        let mut con = match self.get_connection().await {
193            Err(_) => return false,
194            Ok(val) => val,
195        };
196        let script = redis::Script::new(UNLOCK_SCRIPT);
197        let result: RedisResult<i32> = script.key(resource).arg(val).invoke_async(&mut con).await;
198        match result {
199            Ok(val) => val == 1,
200            Err(e) => {
201                //We don't have to do anything special, it's up to the caller to retry or back out
202                let _ = self.recover(e).await;
203                false
204            }
205        }
206    }
207
208    async fn query(&mut self, resource: &[u8]) -> RedisResult<Option<Vec<u8>>> {
209        let mut con = match self.get_connection().await {
210            Ok(con) => con,
211            Err(_e) => return Ok(None),
212        };
213        let result: RedisResult<Option<Vec<u8>>> =
214            redis::cmd("GET").arg(resource).query_async(&mut con).await;
215        result
216    }
217}
218
219/// A distributed lock that can be acquired and released across multiple Redis instances.
220///
221/// A `Lock` represents a distributed lock in Redis.
222/// The lock is associated with a resource, identified by a unique key, and a value that identifies
223/// the lock owner. The `LockManager` is responsible for managing the acquisition, release, and extension
224/// of locks.
225#[derive(Debug)]
226pub struct Lock {
227    /// The resource to lock. Will be used as the key in Redis.
228    pub resource: Vec<u8>,
229    /// The value for this lock.
230    pub val: Vec<u8>,
231    /// Time the lock is still valid.
232    /// Should only be slightly smaller than the requested TTL.
233    pub validity_time: usize,
234    /// Used to limit the lifetime of a lock to its lock manager.
235    pub lock_manager: LockManager,
236}
237
238/// Upon dropping the guard, `LockManager::unlock` will be ran synchronously on the executor.
239///
240/// This is known to block the tokio runtime if this happens inside of the context of a tokio runtime
241/// if `tokio-comp` is enabled as a feature on this crate or the `redis` crate.
242///
243/// To eliminate this risk, if the `tokio-comp` flag is enabled, the `Drop` impl will not be compiled,
244/// meaning that dropping the `LockGuard` will be a no-op.
245/// Under this circumstance, `LockManager::unlock` can be called manually using the inner `lock` at the appropriate
246/// point to release the lock taken in `Redis`.
247#[derive(Debug)]
248pub struct LockGuard {
249    pub lock: Lock,
250}
251
/// Which per-server operation `LockManager::exec_or_retry` runs on each attempt.
enum Operation {
    /// Acquire a new lock (`SET … NX PX`).
    Lock,
    /// Reset the expiry of a lock that is already held.
    Extend,
}
256
257/// Dropping this guard inside the context of a tokio runtime if `tokio-comp` is enabled
258/// will block the tokio runtime.
259/// Because of this, the guard is not compiled if `tokio-comp` is enabled.
260#[cfg(not(feature = "tokio-comp"))]
261impl Drop for LockGuard {
262    fn drop(&mut self) {
263        futures::executor::block_on(self.lock.lock_manager.unlock(&self.lock));
264    }
265}
266
267impl LockManager {
268    /// Create a new lock manager instance, defined by the given Redis connection uris.
269    ///
270    /// Sample URI: `"redis://127.0.0.1:6379"`
271    pub fn new<T: IntoConnectionInfo>(uris: Vec<T>) -> LockManager {
272        let servers: Vec<Client> = uris
273            .into_iter()
274            .map(|uri| Client::open(uri).unwrap())
275            .collect();
276
277        Self::from_clients(servers)
278    }
279
280    /// Create a new lock manager instance, defined by the given Redis clients.
281    /// Quorum is defined to be N/2+1, with N being the number of given Redis instances.
282    pub fn from_clients(clients: Vec<Client>) -> LockManager {
283        let clients: Vec<RestorableConnection> =
284            clients.into_iter().map(RestorableConnection::new).collect();
285        LockManager {
286            lock_manager_inner: Arc::new(Mutex::new(LockManagerInner { servers: clients })),
287            retry_count: DEFAULT_RETRY_COUNT,
288            retry_delay: DEFAULT_RETRY_DELAY,
289        }
290    }
291
292    /// Get 20 random bytes from the pseudorandom interface.
293    pub fn get_unique_lock_id(&self) -> io::Result<Vec<u8>> {
294        let mut buf = [0u8; 20];
295        rng().fill_bytes(&mut buf);
296        Ok(buf.to_vec())
297    }
298
299    /// Set retry count and retry delay.
300    ///
301    /// Retries will be delayed by a random amount of time between `0` and `retry_delay`.
302    ///
303    /// Retry count defaults to `3`.
304    /// Retry delay defaults to `200`.
305    pub fn set_retry(&mut self, count: u32, delay: Duration) {
306        self.retry_count = count;
307        self.retry_delay = delay;
308    }
309
310    async fn lock_inner(&self) -> MutexGuard<'_, LockManagerInner> {
311        self.lock_manager_inner.lock().await
312    }
313
314    // Can be used for creating or extending a lock
315    async fn exec_or_retry(
316        &self,
317        resource: impl ToLockResource<'_>,
318        value: &[u8],
319        ttl: usize,
320        function: Operation,
321    ) -> Result<Lock, LockError> {
322        let mut current_try = 1;
323        let resource = &resource.to_lock_resource();
324
325        loop {
326            let start_time = Instant::now();
327            let l = self.lock_inner().await;
328            let mut servers = l.servers.clone();
329            drop(l);
330
331            let n = match function {
332                Operation::Lock => {
333                    join_all(servers.iter_mut().map(|c| c.lock(resource, value, ttl))).await
334                }
335                Operation::Extend => {
336                    join_all(servers.iter_mut().map(|c| c.extend(resource, value, ttl))).await
337                }
338            }
339            .into_iter()
340            .fold(0, |count, locked| if locked { count + 1 } else { count });
341
342            let drift = (ttl as f32 * CLOCK_DRIFT_FACTOR) as usize + 2;
343            let elapsed = start_time.elapsed();
344            let elapsed_ms =
345                elapsed.as_secs() as usize * 1000 + elapsed.subsec_nanos() as usize / 1_000_000;
346            if ttl <= drift + elapsed_ms {
347                return Err(LockError::TtlExceeded);
348            }
349            let validity_time = ttl
350                - drift
351                - elapsed.as_secs() as usize * 1000
352                - elapsed.subsec_nanos() as usize / 1_000_000;
353
354            let l = self.lock_inner().await;
355            if n >= l.get_quorum() && validity_time > 0 {
356                return Ok(Lock {
357                    lock_manager: self.clone(),
358                    resource: resource.to_vec(),
359                    val: value.to_vec(),
360                    validity_time,
361                });
362            }
363
364            let mut servers = l.servers.clone();
365            drop(l);
366            join_all(
367                servers
368                    .iter_mut()
369                    .map(|client| client.unlock(resource, value)),
370            )
371            .await;
372
373            // only sleep here if we have any retries left
374            if current_try < self.retry_count {
375                current_try += 1;
376
377                let retry_delay: u64 = self
378                    .retry_delay
379                    .as_millis()
380                    .try_into()
381                    .map_err(|_| LockError::TtlTooLarge)?;
382
383                let n = rng().random_range(0..retry_delay);
384
385                tokio::time::sleep(Duration::from_millis(n)).await
386            } else {
387                break;
388            }
389        }
390
391        Err(LockError::Unavailable)
392    }
393
394    // Query Redis for a key's value and keep trying each server until a successful result is returned
395    pub async fn query_redis_for_key_value(
396        &self,
397        resource: &[u8],
398    ) -> Result<Option<Vec<u8>>, LockError> {
399        let l = self.lock_inner().await;
400        let mut servers = l.servers.clone();
401        drop(l);
402        let results = join_all(servers.iter_mut().map(|c| c.query(resource))).await;
403
404        if let Some(value) = results.into_iter().find_map(Result::ok) {
405            return Ok(value);
406        }
407        Err(LockError::RedisConnectionFailed) // All servers failed
408    }
409
410    /// Unlock the given lock.
411    ///
412    /// Unlock is best effort. It will simply try to contact all instances
413    /// and remove the key.
414    pub async fn unlock(&self, lock: &Lock) {
415        let l = self.lock_inner().await;
416        let mut servers = l.servers.clone();
417        drop(l);
418        join_all(
419            servers
420                .iter_mut()
421                .map(|client| client.unlock(&*lock.resource, &lock.val)),
422        )
423        .await;
424    }
425
426    /// Acquire the lock for the given resource and the requested TTL.
427    ///
428    /// If it succeeds, a `Lock` instance is returned,
429    /// including the value and the validity time
430    ///
431    /// If it fails. `None` is returned.
432    /// A user should retry after a short wait time.
433    ///
434    /// May return `LockError::TtlTooLarge` if `ttl` is too large.
435    pub async fn lock(
436        &self,
437        resource: impl ToLockResource<'_>,
438        ttl: Duration,
439    ) -> Result<Lock, LockError> {
440        let resource = resource.to_lock_resource();
441        let val = self.get_unique_lock_id().map_err(LockError::Io)?;
442        let ttl = ttl
443            .as_millis()
444            .try_into()
445            .map_err(|_| LockError::TtlTooLarge)?;
446
447        self.exec_or_retry(&resource, &val.clone(), ttl, Operation::Lock)
448            .await
449    }
450
451    /// Loops until the lock is acquired.
452    ///
453    /// The lock is placed in a guard that will unlock the lock when the guard is dropped.
454    ///
455    /// May return `LockError::TtlTooLarge` if `ttl` is too large.
456    #[cfg(feature = "async-std-comp")]
457    pub async fn acquire(
458        &self,
459        resource: impl ToLockResource<'_>,
460        ttl: Duration,
461    ) -> Result<LockGuard, LockError> {
462        let lock = self.acquire_no_guard(resource, ttl).await?;
463        Ok(LockGuard { lock })
464    }
465
466    /// Loops until the lock is acquired.
467    ///
468    /// Either lock's value must expire after the ttl has elapsed,
469    /// or `LockManager::unlock` must be called to allow other clients to lock the same resource.
470    ///
471    /// May return `LockError::TtlTooLarge` if `ttl` is too large.
472    pub async fn acquire_no_guard(
473        &self,
474        resource: impl ToLockResource<'_>,
475        ttl: Duration,
476    ) -> Result<Lock, LockError> {
477        let resource = &resource.to_lock_resource();
478        loop {
479            match self.lock(resource, ttl).await {
480                Ok(lock) => return Ok(lock),
481                Err(LockError::TtlTooLarge) => return Err(LockError::TtlTooLarge),
482                Err(_) => continue,
483            }
484        }
485    }
486
487    /// Extend the given lock by given time in milliseconds
488    pub async fn extend(&self, lock: &Lock, ttl: Duration) -> Result<Lock, LockError> {
489        let ttl = ttl
490            .as_millis()
491            .try_into()
492            .map_err(|_| LockError::TtlTooLarge)?;
493
494        self.exec_or_retry(&*lock.resource, &lock.val, ttl, Operation::Extend)
495            .await
496    }
497
498    /// Checks if the given lock has been freed (i.e., is no longer held).
499    ///
500    /// This method queries Redis to determine if the key associated with the lock
501    /// is still present and matches the value of this lock. If the key is missing
502    /// or the value does not match, the lock is considered freed.
503    ///
504    /// # Returns
505    ///
506    /// `Ok(true)` if the lock is considered freed (either because the key does not exist
507    /// or the value does not match), otherwise `Ok(false)`. Returns an error if a Redis
508    /// connection or query fails.
509    pub async fn is_freed(&self, lock: &Lock) -> Result<bool, LockError> {
510        match self.query_redis_for_key_value(&lock.resource).await? {
511            Some(val) => {
512                if val != lock.val {
513                    Err(LockError::RedisKeyMismatch)
514                } else {
515                    Ok(false) // Key is present and matches the lock value
516                }
517            }
518            None => Err(LockError::RedisKeyNotFound), // Key does not exist
519        }
520    }
521
522    #[cfg(feature = "tokio-comp")]
523    pub async fn using<R>(
524        &self,
525        resource: &[u8],
526        ttl: Duration,
527        routine: impl AsyncFnOnce() -> R,
528    ) -> Result<R, LockError> {
529        let mut lock = self.acquire_no_guard(resource, ttl).await?;
530        let mut threshold = lock.validity_time as u64 - 500;
531
532        let routine = routine();
533        futures::pin_mut!(routine);
534
535        loop {
536            match tokio::time::timeout(Duration::from_millis(threshold), &mut routine).await {
537                Ok(result) => {
538                    self.unlock(&lock).await;
539
540                    return Ok(result);
541                }
542
543                Err(_) => {
544                    lock = self.extend(&lock, ttl).await?;
545                    threshold = lock.validity_time as u64 - 500;
546                }
547            }
548        }
549    }
550}
551
552#[cfg(test)]
553mod tests {
554    use anyhow::Result;
555    use testcontainers::{
556        core::{IntoContainerPort, WaitFor},
557        runners::AsyncRunner,
558        ContainerAsync, GenericImage,
559    };
560    use tokio::time::Duration;
561
562    use super::*;
563
564    type Containers = Vec<ContainerAsync<GenericImage>>;
565
566    async fn create_clients() -> (Containers, Vec<String>) {
567        let mut containers = Vec::new();
568        let mut addresses = Vec::new();
569
570        for _ in 1..=3 {
571            let container = GenericImage::new("redis", "7")
572                .with_exposed_port(6379.tcp())
573                .with_wait_for(WaitFor::message_on_stdout("Ready to accept connections"))
574                .start()
575                .await
576                .expect("Failed to start Redis container");
577
578            let port = container
579                .get_host_port_ipv4(6379)
580                .await
581                .expect("Failed to get port");
582            let address = format!("redis://localhost:{}", port);
583
584            containers.push(container);
585            addresses.push(address);
586        }
587
588        // Ensure all Redis instances are ready
589        ensure_redis_readiness(&addresses)
590            .await
591            .expect("Redis instances are not ready");
592
593        (containers, addresses)
594    }
595
596    /// This function connects to each Redis instance and sends a `PING` command to verify its readiness.
597    /// If any Redis instance fails to respond, it retries up to 120 times with a 1000ms delay between attempts.
598    /// If readiness is not achieved after the retries, an error is returned.
599    ///
600    /// # Purpose
601    /// This function is particularly useful in CI environments and automated testing to ensure
602    /// that Redis containers or instances are fully initialized before running tests. This helps
603    /// prevent flaky tests caused by race conditions where Redis is not yet ready.
604    async fn ensure_redis_readiness(
605        addresses: &[String],
606    ) -> Result<(), Box<dyn std::error::Error>> {
607        for address in addresses {
608            let client = Client::open(address.as_str())?;
609            let mut retries = 120;
610
611            while retries > 0 {
612                match client.get_multiplexed_async_connection().await {
613                    Ok(mut con) => match redis::cmd("PING").query_async::<String>(&mut con).await {
614                        Ok(response) => {
615                            eprintln!("Redis {} is ready: {}", address, response);
616                            break; // Move to the next address
617                        }
618                        Err(e) => {
619                            eprintln!("Redis {} is not ready: {:?}", address, e);
620                        }
621                    },
622                    Err(e) => eprintln!("Failed to connect to Redis {}: {:?}", address, e),
623                }
624
625                // Decrement retries and wait before the next attempt
626                retries -= 1;
627                tokio::time::sleep(Duration::from_secs(1)).await;
628            }
629
630            if retries == 0 {
631                return Err(format!("Redis {} did not become ready after retries", address).into());
632            }
633        }
634
635        Ok(())
636    }
637
638    fn is_normal<T: Sized + Send + Sync + Unpin>() {}
639
640    // Test that the LockManager is Send + Sync
641    #[test]
642    fn test_is_normal() {
643        is_normal::<LockManager>();
644        is_normal::<LockError>();
645        is_normal::<Lock>();
646        is_normal::<LockGuard>();
647    }
648
649    #[tokio::test]
650    async fn test_lock_get_unique_id() -> Result<()> {
651        let rl = LockManager::new(Vec::<String>::new());
652        assert_eq!(rl.get_unique_lock_id()?.len(), 20);
653
654        Ok(())
655    }
656
657    #[tokio::test]
658    async fn test_lock_get_unique_id_uniqueness() -> Result<()> {
659        let rl = LockManager::new(Vec::<String>::new());
660
661        let id1 = rl.get_unique_lock_id()?;
662        let id2 = rl.get_unique_lock_id()?;
663
664        assert_eq!(20, id1.len());
665        assert_eq!(20, id2.len());
666        assert_ne!(id1, id2);
667
668        Ok(())
669    }
670
671    #[tokio::test]
672    async fn test_lock_valid_instance() {
673        let (_containers, addresses) = create_clients().await;
674
675        let rl = LockManager::new(addresses.clone());
676        let l = rl.lock_inner().await;
677
678        assert_eq!(3, l.servers.len());
679        assert_eq!(2, l.get_quorum());
680    }
681
682    #[tokio::test]
683    async fn test_lock_direct_unlock_fails() -> Result<()> {
684        let (_containers, addresses) = create_clients().await;
685
686        let rl = LockManager::new(addresses.clone());
687        let key = rl.get_unique_lock_id()?;
688
689        let val = rl.get_unique_lock_id()?;
690        let mut l = rl.lock_inner().await;
691        assert!(!l.servers[0].unlock(&key, &val).await);
692
693        Ok(())
694    }
695
696    #[tokio::test]
697    async fn test_lock_direct_unlock_succeeds() -> Result<()> {
698        let (_containers, addresses) = create_clients().await;
699
700        let rl = LockManager::new(addresses.clone());
701        let key = rl.get_unique_lock_id()?;
702
703        let val = rl.get_unique_lock_id()?;
704        let mut l = rl.lock_inner().await;
705        let mut con = l.servers[0].get_connection().await?;
706
707        redis::cmd("SET")
708            .arg(&*key)
709            .arg(&*val)
710            .exec_async(&mut con)
711            .await?;
712
713        assert!(l.servers[0].unlock(&key, &val).await);
714        Ok(())
715    }
716
717    #[tokio::test]
718    async fn test_lock_direct_lock_succeeds() -> Result<()> {
719        let (_containers, addresses) = create_clients().await;
720
721        let rl = LockManager::new(addresses.clone());
722        let key = rl.get_unique_lock_id()?;
723        let resource = key.to_lock_resource();
724
725        let val = rl.get_unique_lock_id()?;
726        let mut l = rl.lock_inner().await;
727        let mut con = l.servers[0].get_connection().await?;
728
729        redis::cmd("DEL").arg(&*key).exec_async(&mut con).await?;
730        assert!(l.servers[0].lock(&resource, &val, 10_000).await);
731        Ok(())
732    }
733
734    #[tokio::test]
735    async fn test_lock_unlock() -> Result<()> {
736        let (_containers, addresses) = create_clients().await;
737
738        let rl = LockManager::new(addresses.clone());
739        let key = rl.get_unique_lock_id()?;
740
741        let val = rl.get_unique_lock_id()?;
742        let mut l = rl.lock_inner().await;
743        let mut con = l.servers[0].get_connection().await?;
744        drop(l);
745        let _: () = redis::cmd("SET")
746            .arg(&*key)
747            .arg(&*val)
748            .query_async(&mut con)
749            .await?;
750
751        let lock = Lock {
752            lock_manager: rl.clone(),
753            resource: key,
754            val,
755            validity_time: 0,
756        };
757
758        rl.unlock(&lock).await;
759
760        Ok(())
761    }
762
763    #[tokio::test]
764    async fn test_lock_lock() -> Result<()> {
765        let (_containers, addresses) = create_clients().await;
766
767        let rl = LockManager::new(addresses.clone());
768
769        let key = rl.get_unique_lock_id()?;
770        match rl.lock(&key, Duration::from_millis(10_000)).await {
771            Ok(lock) => {
772                assert_eq!(key, lock.resource);
773                assert_eq!(20, lock.val.len());
774                assert!(
775                    lock.validity_time > 0,
776                    "validity time: {}",
777                    lock.validity_time
778                );
779            }
780            Err(e) => panic!("{:?}", e),
781        }
782
783        Ok(())
784    }
785
786    #[tokio::test]
787    async fn test_lock_lock_unlock() -> Result<()> {
788        let (_containers, addresses) = create_clients().await;
789
790        let rl = LockManager::new(addresses.clone());
791        let rl2 = LockManager::new(addresses.clone());
792
793        let key = rl.get_unique_lock_id()?;
794
795        let lock = rl.lock(&key, Duration::from_millis(10_000)).await.unwrap();
796        assert!(
797            lock.validity_time > 0,
798            "validity time: {}",
799            lock.validity_time
800        );
801
802        if let Ok(_l) = rl2.lock(&key, Duration::from_millis(10_000)).await {
803            panic!("Lock acquired, even though it should be locked")
804        }
805
806        rl.unlock(&lock).await;
807
808        match rl2.lock(&key, Duration::from_millis(10_000)).await {
809            Ok(l) => assert!(l.validity_time > 0),
810            Err(_) => panic!("Lock couldn't be acquired"),
811        }
812
813        Ok(())
814    }
815
816    #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
817    #[tokio::test]
818    async fn test_lock_lock_unlock_raii() -> Result<()> {
819        let (_containers, addresses) = create_clients().await;
820
821        let rl = LockManager::new(addresses.clone());
822        let rl2 = LockManager::new(addresses.clone());
823        let key = rl.get_unique_lock_id()?;
824
825        async {
826            let lock_guard = rl
827                .acquire(&key, Duration::from_millis(10_000))
828                .await
829                .unwrap();
830            let lock = &lock_guard.lock;
831            assert!(
832                lock.validity_time > 0,
833                "validity time: {}",
834                lock.validity_time
835            );
836
837            if let Ok(_l) = rl2.lock(&key, Duration::from_millis(10_000)).await {
838                panic!("Lock acquired, even though it should be locked")
839            }
840        }
841        .await;
842
843        match rl2.lock(&key, Duration::from_millis(10_000)).await {
844            Ok(l) => assert!(l.validity_time > 0),
845            Err(_) => panic!("Lock couldn't be acquired"),
846        }
847
848        Ok(())
849    }
850
851    #[cfg(feature = "tokio-comp")]
852    #[tokio::test]
853    async fn test_lock_raii_does_not_unlock_with_tokio_enabled() -> Result<()> {
854        let (_containers, addresses) = create_clients().await;
855
856        let rl1 = LockManager::new(addresses.clone());
857        let rl2 = LockManager::new(addresses.clone());
858        let key = rl1.get_unique_lock_id()?;
859
860        async {
861            //The acquire function is only enabled for `async-std-comp` ??
862            let lock_guard = rl1
863                .acquire(&key, Duration::from_millis(10_000))
864                .await
865                .expect("LockManage rl1 should be able to acquire lock");
866            let lock = &lock_guard.lock;
867            assert!(
868                lock.validity_time > 0,
869                "validity time: {}",
870                lock.validity_time
871            );
872
873            // Retry verifying the Redis key state up to 5 times with a 1000ms delay
874            let mut retries = 5;
875            let mut redis_key_verified = false;
876
877            while retries > 0 {
878                match rl1.query_redis_for_key_value(&key).await {
879                    Ok(Some(redis_val)) if redis_val == lock.val => {
880                        redis_key_verified = true;
881                        break;
882                    }
883                    Ok(Some(redis_val)) => {
884                        println!(
885                            "Redis key value mismatch. Expected: {:?}, Found: {:?}. Retrying...",
886                            lock.val, redis_val
887                        );
888                    }
889                    Ok(None) => println!("Redis key not found. Retrying..."),
890                    Err(e) => println!("Failed to query Redis key: {:?}. Retrying...", e),
891                }
892
893                retries -= 1;
894                tokio::time::sleep(Duration::from_millis(1000)).await;
895            }
896
897            // Acquire lock2 and assert it can't be acquired
898            if let Ok(_l) = rl2.lock(&key, Duration::from_millis(10_000)).await {
899                panic!("Lock acquired, even though it should be locked")
900            }
901
902            assert!(redis_key_verified);
903        }
904        .await;
905
906        if let Ok(_) = rl2.lock(&key, Duration::from_millis(10_000)).await {
907            panic!("Lock couldn't be acquired");
908        }
909
910        Ok(())
911    }
912
913    #[cfg(feature = "async-std-comp")]
914    #[tokio::test]
915    async fn test_lock_extend_lock() -> Result<()> {
916        let (_containers, addresses) = create_clients().await;
917
918        let rl1 = LockManager::new(addresses.clone());
919        let rl2 = LockManager::new(addresses.clone());
920
921        let key = rl1.get_unique_lock_id()?;
922
923        async {
924            let lock1 = rl1
925                .acquire(&key, Duration::from_millis(10_000))
926                .await
927                .unwrap();
928
929            // Wait half a second before locking again
930            tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
931
932            rl1.extend(&lock1.lock, Duration::from_millis(10_000))
933                .await
934                .unwrap();
935
936            // Wait another half a second to see if lock2 can unlock
937            tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
938
939            // Assert lock2 can't access after extended lock
940            match rl2.lock(&key, Duration::from_millis(10_000)).await {
941                Ok(_) => panic!("Expected an error when extending the lock but didn't receive one"),
942                Err(e) => match e {
943                    LockError::Unavailable => (),
944                    _ => panic!("Unexpected error when extending lock"),
945                },
946            }
947        }
948        .await;
949
950        Ok(())
951    }
952
953    #[cfg(feature = "async-std-comp")]
954    #[tokio::test]
955    async fn test_lock_extend_lock_releases() -> Result<()> {
956        let (_containers, addresses) = create_clients().await;
957
958        let rl1 = LockManager::new(addresses.clone());
959        let rl2 = LockManager::new(addresses.clone());
960
961        let key = rl1.get_unique_lock_id()?;
962
963        async {
964            // Create 500ms lock and immediately extend 500ms
965            let lock1 = rl1.acquire(&key, Duration::from_millis(500)).await.unwrap();
966            rl1.extend(&lock1.lock, Duration::from_millis(500))
967                .await
968                .unwrap();
969
970            // Wait one second for the lock to expire
971            tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
972
973            // Assert rl2 can lock with the key now
974            match rl2.lock(&key, Duration::from_millis(10_000)).await {
975                Err(_) => {
976                    panic!("Unexpected error when trying to claim free lock after extend expired")
977                }
978                _ => (),
979            }
980
981            // Also assert rl1 can't reuse lock1
982            match rl1.extend(&lock1.lock, Duration::from_millis(10_000)).await {
983                Ok(_) => panic!("Did not expect OK() when re-extending rl1"),
984                Err(e) => match e {
985                    LockError::Unavailable => (),
986                    _ => panic!("Expected lockError::Unavailable when re-extending rl1"),
987                },
988            }
989        }
990        .await;
991
992        Ok(())
993    }
994
995    #[tokio::test]
996    async fn test_lock_with_short_ttl_and_retries() -> Result<()> {
997        let (_containers, addresses) = create_clients().await;
998
999        let mut rl = LockManager::new(addresses.clone());
1000        // Set a high retry count to ensure retries happen
1001        rl.set_retry(10, Duration::from_millis(10)); // Retry 10 times with 10 milliseconds delay
1002
1003        let key = rl.get_unique_lock_id()?;
1004
1005        // Use a very short TTL
1006        let ttl = Duration::from_millis(1);
1007
1008        // Acquire lock
1009        let lock_result = rl.lock(&key, ttl).await;
1010
1011        // Check if the error returned is TtlExceeded
1012        match lock_result {
1013            Err(LockError::TtlExceeded) => (), // Test passes
1014            _ => panic!("Expected LockError::TtlExceeded, but got {:?}", lock_result),
1015        }
1016
1017        Ok(())
1018    }
1019
1020    #[tokio::test]
1021    async fn test_lock_ttl_duration_conversion_error() {
1022        let (_containers, addresses) = create_clients().await;
1023        let rl = LockManager::new(addresses.clone());
1024        let key = rl.get_unique_lock_id().unwrap();
1025
1026        // Too big Duration, fails - technical limit is from_millis(u64::MAX)
1027        let ttl = Duration::from_secs(u64::MAX);
1028        match rl.lock(&key, ttl).await {
1029            Ok(_) => panic!("Expected LockError::TtlTooLarge"),
1030            Err(_) => (), // Test passes
1031        }
1032    }
1033
1034    #[tokio::test]
1035    #[cfg(feature = "tokio-comp")]
1036    async fn test_lock_send_lock_manager() {
1037        let (_containers, addresses) = create_clients().await;
1038        let rl = LockManager::new(addresses.clone());
1039
1040        let lock = rl
1041            .lock(b"resource", std::time::Duration::from_millis(10_000))
1042            .await
1043            .unwrap();
1044
1045        // Send the lock and entry through the channel
1046        let (tx, mut rx) = tokio::sync::mpsc::channel(32);
1047        tx.send(("some info", lock, rl)).await.unwrap();
1048
1049        let j = tokio::spawn(async move {
1050            // Retrieve from channel and use
1051            if let Some((_entry, lock, rl)) = rx.recv().await {
1052                rl.unlock(&lock).await;
1053            }
1054        });
1055        let _ = j.await;
1056    }
1057
1058    #[tokio::test]
1059    #[cfg(feature = "tokio-comp")]
1060    async fn test_lock_state_in_multiple_threads() {
1061        let (_containers, addresses) = create_clients().await;
1062        let rl = LockManager::new(addresses.clone());
1063
1064        let lock1 = rl
1065            .lock(b"resource_1", std::time::Duration::from_millis(10_000))
1066            .await
1067            .unwrap();
1068
1069        let lock1 = Arc::new(lock1);
1070        // Send the lock and entry through the channel
1071        let (tx, mut rx) = tokio::sync::mpsc::channel(32);
1072        tx.send(("some info", lock1.clone(), rl.clone()))
1073            .await
1074            .unwrap();
1075
1076        let j = tokio::spawn(async move {
1077            // Retrieve from channel and use
1078            if let Some((_entry, lock1, rl)) = rx.recv().await {
1079                rl.unlock(&lock1).await;
1080            }
1081        });
1082        let _ = j.await;
1083
1084        match rl.is_freed(&lock1).await {
1085            Ok(freed) => assert!(freed, "Lock should be freed after unlock"),
1086            Err(LockError::RedisKeyNotFound) => {
1087                assert!(true, "RedisKeyNotFound is expected if key is missing")
1088            }
1089            Err(e) => panic!("Unexpected error: {:?}", e),
1090        };
1091
1092        let lock2 = rl
1093            .lock(b"resource_2", std::time::Duration::from_millis(10_000))
1094            .await
1095            .unwrap();
1096        rl.unlock(&lock2).await;
1097
1098        match rl.is_freed(&lock2).await {
1099            Ok(freed) => assert!(freed, "Lock should be freed after unlock"),
1100            Err(LockError::RedisKeyNotFound) => {
1101                assert!(true, "RedisKeyNotFound is expected if key is missing")
1102            }
1103            Err(e) => panic!("Unexpected error: {:?}", e),
1104        };
1105    }
1106
1107    #[tokio::test]
1108    async fn test_redis_value_matches_lock_value() {
1109        let (_containers, addresses) = create_clients().await;
1110        let rl = LockManager::new(addresses.clone());
1111
1112        let lock = rl
1113            .lock(b"resource_1", std::time::Duration::from_millis(10_000))
1114            .await
1115            .unwrap();
1116
1117        // Ensure Redis key is correctly set and matches the lock value
1118        let mut l = rl.lock_inner().await;
1119        let mut con = l.servers[0].get_connection().await.unwrap();
1120        let redis_val: Option<Vec<u8>> = redis::cmd("GET")
1121            .arg(&lock.resource)
1122            .query_async(&mut con)
1123            .await
1124            .unwrap();
1125
1126        eprintln!(
1127            "Debug: Expected value in Redis: {:?}, Actual value in Redis: {:?}",
1128            Some(lock.val.as_slice()),
1129            redis_val.as_deref()
1130        );
1131
1132        assert_eq!(
1133            redis_val.as_deref(),
1134            Some(lock.val.as_slice()),
1135            "Redis value should match lock value"
1136        );
1137    }
1138
1139    #[tokio::test]
1140    async fn test_is_not_freed_after_lock() {
1141        let (_containers, addresses) = create_clients().await;
1142        let rl = LockManager::new(addresses.clone());
1143
1144        let lock = rl
1145            .lock(b"resource_1", std::time::Duration::from_millis(10_000))
1146            .await
1147            .unwrap();
1148
1149        match rl.is_freed(&lock).await {
1150            Ok(freed) => assert!(!freed, "Lock should not be freed after it is acquired"),
1151            Err(LockError::RedisKeyMismatch) => {
1152                panic!("Redis key mismatch should not occur for a valid lock")
1153            }
1154            Err(LockError::RedisKeyNotFound) => {
1155                panic!("Redis key not found should not occur for a valid lock")
1156            }
1157            Err(e) => panic!("Unexpected error: {:?}", e),
1158        };
1159    }
1160
1161    #[tokio::test]
1162    async fn test_is_freed_after_manual_unlock() {
1163        let (_containers, addresses) = create_clients().await;
1164        let rl = LockManager::new(addresses.clone());
1165
1166        let lock = rl
1167            .lock(b"resource_2", std::time::Duration::from_millis(10_000))
1168            .await
1169            .unwrap();
1170
1171        rl.unlock(&lock).await;
1172
1173        match rl.is_freed(&lock).await {
1174            Ok(freed) => assert!(freed, "Lock should be freed after unlock"),
1175            Err(LockError::RedisKeyNotFound) => {
1176                assert!(true, "RedisKeyNotFound is expected if key is missing")
1177            }
1178            Err(e) => panic!("Unexpected error: {:?}", e),
1179        };
1180    }
1181
1182    #[tokio::test]
1183    async fn test_is_freed_when_key_missing_in_redis() {
1184        let (_containers, addresses) = create_clients().await;
1185        let rl = LockManager::new(addresses.clone());
1186
1187        let lock = rl
1188            .lock(b"resource_3", std::time::Duration::from_millis(10_000))
1189            .await
1190            .unwrap();
1191
1192        // Manually delete the key in Redis to simulate it being missing
1193        let mut l = rl.lock_inner().await;
1194        let mut con = l.servers[0].get_connection().await.unwrap();
1195        drop(l);
1196
1197        redis::cmd("DEL")
1198            .arg(&lock.resource)
1199            .query_async::<()>(&mut con)
1200            .await
1201            .unwrap();
1202
1203        match rl.is_freed(&lock).await {
1204            Ok(freed) => assert!(
1205                freed,
1206                "Lock should be marked as freed when key is missing in Redis"
1207            ),
1208            Err(LockError::RedisKeyNotFound) => assert!(
1209                true,
1210                "RedisKeyNotFound is expected when key is missing in Redis"
1211            ),
1212            Err(e) => panic!("Unexpected error: {:?}", e),
1213        };
1214    }
1215
1216    #[tokio::test]
1217    async fn test_is_freed_handles_redis_connection_failure() {
1218        let (_containers, _) = create_clients().await;
1219        let rl = LockManager::new(Vec::<String>::new()); // No Redis clients, simulate failure
1220
1221        let lock_result = rl
1222            .lock(b"resource_4", std::time::Duration::from_millis(10_000))
1223            .await;
1224
1225        match lock_result {
1226            Ok(lock) => {
1227                // Since there are no clients, any check with Redis will fail
1228                match rl.is_freed(&lock).await {
1229                    Ok(freed) => panic!("Expected failure due to Redis connection, but got Ok with freed status: {}", freed),
1230                    Err(LockError::RedisConnectionFailed) => assert!(true, "Expected RedisConnectionFailed when all Redis connections fail"),
1231                    Err(e) => panic!("Unexpected error: {:?}", e),
1232                }
1233            }
1234            Err(LockError::Unavailable) => {
1235                // Expected error, the test should pass in this scenario
1236                assert!(true);
1237            }
1238            Err(e) => panic!("Unexpected error while acquiring lock: {:?}", e),
1239        }
1240    }
1241
1242    #[tokio::test]
1243    async fn test_redis_connection_failed() {
1244        let (_containers, _) = create_clients().await;
1245        let rl = LockManager::new(Vec::<String>::new()); // No Redis clients, simulate failure
1246
1247        let lock_result = rl
1248            .lock(b"resource_5", std::time::Duration::from_millis(10_000))
1249            .await;
1250
1251        match lock_result {
1252            Ok(lock) => match rl.is_freed(&lock).await {
1253                Err(LockError::RedisConnectionFailed) => assert!(
1254                    true,
1255                    "Expected RedisConnectionFailed when all Redis connections fail"
1256                ),
1257                Ok(_) => panic!("Expected RedisConnectionFailed, but got Ok"),
1258                Err(e) => panic!("Unexpected error: {:?}", e),
1259            },
1260            Err(LockError::Unavailable) => {
1261                // Expected error, the test should pass in this scenario
1262                assert!(true);
1263            }
1264            Err(e) => panic!("Unexpected error while acquiring lock: {:?}", e),
1265        }
1266    }
1267
1268    #[tokio::test]
1269    async fn test_redis_key_mismatch() {
1270        let (_containers, addresses) = create_clients().await;
1271        let rl = LockManager::new(addresses.clone());
1272
1273        let lock = rl
1274            .lock(b"resource_6", std::time::Duration::from_millis(10_000))
1275            .await
1276            .unwrap();
1277
1278        // Set a different value for the same key to simulate a mismatch
1279        let mut l = rl.lock_inner().await;
1280        let mut con = l.servers[0].get_connection().await.unwrap();
1281        drop(l);
1282        let different_value: Vec<u8> = vec![1, 2, 3, 4, 5]; // Different value
1283        redis::cmd("SET")
1284            .arg(&lock.resource)
1285            .arg(different_value)
1286            .query_async::<()>(&mut con)
1287            .await
1288            .unwrap();
1289
1290        // Now check if is_freed identifies the mismatch correctly
1291        match rl.is_freed(&lock).await {
1292            Err(LockError::RedisKeyMismatch) => assert!(
1293                true,
1294                "Expected RedisKeyMismatch when key value does not match the lock value"
1295            ),
1296            Ok(_) => panic!("Expected RedisKeyMismatch, but got Ok"),
1297            Err(e) => panic!("Unexpected error: {:?}", e),
1298        }
1299    }
1300
1301    #[tokio::test]
1302    async fn test_redis_key_not_found() {
1303        let (_containers, addresses) = create_clients().await;
1304        let rl = LockManager::new(addresses.clone());
1305
1306        let lock = rl
1307            .lock(b"resource_7", std::time::Duration::from_millis(10_000))
1308            .await
1309            .unwrap();
1310
1311        // Manually delete the key in Redis to simulate it being missing
1312        let mut l = rl.lock_inner().await;
1313        let mut con = l.servers[0].get_connection().await.unwrap();
1314        drop(l);
1315        redis::cmd("DEL")
1316            .arg(&lock.resource)
1317            .query_async::<()>(&mut con)
1318            .await
1319            .unwrap();
1320
1321        match rl.is_freed(&lock).await {
1322            Err(LockError::RedisKeyNotFound) => assert!(
1323                true,
1324                "Expected RedisKeyNotFound when key is missing in Redis"
1325            ),
1326            Ok(_) => panic!("Expected RedisKeyNotFound, but got Ok"),
1327            Err(e) => panic!("Unexpected error: {:?}", e),
1328        }
1329    }
1330
1331    #[tokio::test]
1332    async fn test_lock_manager_from_clients_valid_instance() {
1333        let (_containers, addresses) = create_clients().await;
1334
1335        let clients: Vec<Client> = addresses
1336            .iter()
1337            .map(|uri| Client::open(uri.as_str()).unwrap())
1338            .collect();
1339
1340        let lock_manager = LockManager::from_clients(clients);
1341
1342        let l = lock_manager.lock_inner().await;
1343        assert_eq!(l.servers.len(), 3);
1344        assert_eq!(l.get_quorum(), 2);
1345    }
1346
1347    #[tokio::test]
1348    async fn test_lock_manager_from_clients_partial_quorum() {
1349        let (_containers, addresses) = create_clients().await;
1350        let mut clients: Vec<Client> = addresses
1351            .iter()
1352            .map(|uri| Client::open(uri.as_str()).unwrap())
1353            .collect();
1354
1355        // Remove one client to simulate fewer nodes
1356        clients.pop();
1357
1358        let lock_manager = LockManager::from_clients(clients);
1359
1360        let l = lock_manager.lock_inner().await;
1361        assert_eq!(l.servers.len(), 2);
1362        assert_eq!(l.get_quorum(), 2); // 2/2+1 still rounds to 2
1363    }
1364}