use std::io;
use std::sync::Arc;
use std::time::{Duration, Instant};
use futures::future::join_all;
use rand::{rng, Rng, RngCore};
use redis::aio::MultiplexedConnection;
use redis::Value::Okay;
use redis::{Client, IntoConnectionInfo, RedisError, RedisResult, Value};
use crate::resource::{LockResource, ToLockResource};
/// Initial value of `LockManager::retry_count`, which bounds the number of attempts per operation.
const DEFAULT_RETRY_COUNT: u32 = 3;
/// Initial value of `LockManager::retry_delay`, the exclusive upper bound of the random pause
/// taken between attempts.
const DEFAULT_RETRY_DELAY: Duration = Duration::from_millis(200);
/// Fraction of the requested TTL that is subtracted as a clock-drift allowance when the
/// remaining validity of a lock is computed.
const CLOCK_DRIFT_FACTOR: f32 = 0.01;
/// Lua script used to release a lock: deletes `KEYS[1]` only when its current value equals
/// `ARGV[1]` and returns the `DEL` result; returns `0` when the value does not match.
const UNLOCK_SCRIPT: &str = r#"
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
else
return 0
end
"#;
/// Lua script used to extend a lock: when `KEYS[1]` currently holds `ARGV[1]`, it is set again
/// with a `PX` expiry of `ARGV[2]` and the script returns `1`; in every other case it returns `0`.
const EXTEND_SCRIPT: &str = r#"
if redis.call("get", KEYS[1]) ~= ARGV[1] then
return 0
else
if redis.call("set", KEYS[1], ARGV[1], "PX", ARGV[2]) ~= nil then
return 1
else
return 0
end
end
"#;
/// Errors reported by the lock manager and its connections.
#[derive(Debug, thiserror::Error)]
pub enum LockError {
/// Wraps an [`io::Error`]; also produced through the generated `From<io::Error>` conversion.
#[error("IO error: {0}")]
Io(#[from] io::Error),
/// Wraps an error reported by the `redis` crate.
#[error("Redis error: {0}")]
Redis(#[from] redis::RedisError),
/// Every allowed attempt finished without reaching a quorum of servers.
#[error("Resource is unavailable")]
Unavailable,
/// The clock-drift allowance plus the time spent on an attempt used up the whole TTL.
#[error("TTL exceeded")]
TtlExceeded,
/// A `Duration` could not be represented as a millisecond count of the required integer type.
#[error("TTL too large")]
TtlTooLarge,
/// No server produced a successful reply to a key query.
#[error("Redis connection failed for all servers")]
RedisConnectionFailed,
/// The connection slot was still empty after a connection attempt.
#[error("Redis connection failed.")]
RedisFailedToEstablishConnection,
/// The value stored under the lock key differs from the lock's own value.
#[error("Redis key mismatch: expected value does not match actual value")]
RedisKeyMismatch,
/// The queried server reported no value for the lock key.
#[error("Redis key not found")]
RedisKeyNotFound,
/// NOTE(review): not constructed anywhere in the code visible in this file.
#[error("A mutex was poisoned")]
MutexPoisoned,
}
/// Async mutex used throughout this module (alias for `tokio::sync::Mutex`).
type Mutex<T> = tokio::sync::Mutex<T>;
/// Guard type matching [`Mutex`] (alias for `tokio::sync::MutexGuard`).
type MutexGuard<'a, K> = tokio::sync::MutexGuard<'a, K>;
/// Entry point for acquiring, extending, and releasing locks on a set of Redis servers.
///
/// Clones share the same server list through the inner `Arc`; the retry settings are copied.
#[derive(Debug, Clone)]
pub struct LockManager {
/// Shared, mutex-protected server list.
lock_manager_inner: Arc<Mutex<LockManagerInner>>,
/// Upper bound on the number of attempts made per lock/extend operation.
retry_count: u32,
/// Exclusive upper bound of the random pause between attempts.
retry_delay: Duration,
}
/// State shared by all clones of a [`LockManager`].
#[derive(Debug, Clone)]
struct LockManagerInner {
/// One restorable connection per configured Redis server.
pub servers: Vec<RestorableConnection>,
}
impl LockManagerInner {
    /// Returns the number of servers that must agree for an operation to
    /// count as successful: half of the configured servers (rounded down)
    /// plus one.
    fn get_quorum(&self) -> u32 {
        let server_count = self.servers.len() as u32;
        server_count / 2 + 1
    }
}
/// A Redis client paired with a lazily created connection that can be replaced after a failure.
#[derive(Debug, Clone)]
struct RestorableConnection {
/// Client used to open (and re-open) the connection.
client: Client,
/// Cached connection; `None` until the first successful connect. Shared between clones.
con: Arc<Mutex<Option<MultiplexedConnection>>>,
}
impl RestorableConnection {
    /// Wraps `client` without connecting; the connection slot starts empty.
    pub fn new(client: Client) -> Self {
        let con = Arc::new(tokio::sync::Mutex::new(None));
        Self { client, con }
    }

    /// Returns a clone of the cached connection, opening and caching one
    /// first when the slot is still empty.
    ///
    /// # Errors
    /// Returns [`LockError::Redis`] when a new connection cannot be opened.
    pub async fn get_connection(&mut self) -> Result<MultiplexedConnection, LockError> {
        let mut slot = self.con.lock().await;
        if let Some(cached) = slot.as_ref() {
            return Ok(cached.clone());
        }
        let fresh = self
            .client
            .get_multiplexed_async_connection()
            .await
            .map_err(LockError::Redis)?;
        *slot = Some(fresh.clone());
        Ok(fresh)
    }

    /// Replaces the cached connection when `error` is reported as
    /// unrecoverable; any other error leaves the cached connection untouched.
    ///
    /// # Errors
    /// Returns [`LockError::Redis`] when the replacement connection cannot be
    /// opened.
    pub async fn recover(&mut self, error: RedisError) -> Result<(), LockError> {
        if !error.is_unrecoverable_error() {
            return Ok(());
        }
        // The slot stays locked while reconnecting, as in the original flow.
        let mut slot = self.con.lock().await;
        let replacement = self
            .client
            .get_multiplexed_async_connection()
            .await
            .map_err(LockError::Redis)?;
        *slot = Some(replacement);
        Ok(())
    }
}
impl RestorableConnection {
async fn lock(&mut self, resource: &LockResource<'_>, val: &[u8], ttl: usize) -> bool {
let mut con = match self.get_connection().await {
Err(_) => return false,
Ok(val) => val,
};
let result: RedisResult<Value> = redis::cmd("SET")
.arg(resource)
.arg(val)
.arg("NX")
.arg("PX")
.arg(ttl)
.query_async(&mut con)
.await;
match result {
Ok(Okay) => true,
Ok(_) => false,
Err(e) => {
let _ = self.recover(e).await;
false
}
}
}
async fn extend(&mut self, resource: &LockResource<'_>, val: &[u8], ttl: usize) -> bool {
let mut con = match self.get_connection().await {
Err(_) => return false,
Ok(val) => val,
};
let script = redis::Script::new(EXTEND_SCRIPT);
let result: RedisResult<i32> = script
.key(resource)
.arg(val)
.arg(ttl)
.invoke_async(&mut con)
.await;
match result {
Ok(val) => val == 1,
Err(e) => {
let _ = self.recover(e).await;
false
}
}
}
async fn unlock(&mut self, resource: impl ToLockResource<'_>, val: &[u8]) -> bool {
let resource = resource.to_lock_resource();
let mut con = match self.get_connection().await {
Err(_) => return false,
Ok(val) => val,
};
let script = redis::Script::new(UNLOCK_SCRIPT);
let result: RedisResult<i32> = script.key(resource).arg(val).invoke_async(&mut con).await;
match result {
Ok(val) => val == 1,
Err(e) => {
let _ = self.recover(e).await;
false
}
}
}
async fn query(&mut self, resource: &[u8]) -> RedisResult<Option<Vec<u8>>> {
let mut con = match self.get_connection().await {
Ok(con) => con,
Err(_e) => return Ok(None),
};
let result: RedisResult<Option<Vec<u8>>> =
redis::cmd("GET").arg(resource).query_async(&mut con).await;
result
}
}
/// A lock obtained from (or extended by) a [`LockManager`].
#[derive(Debug)]
pub struct Lock {
/// Key of the locked resource.
pub resource: Vec<u8>,
/// Value stored under the key; used to check ownership on unlock and extend.
pub val: Vec<u8>,
/// Remaining validity in milliseconds, computed when the lock was granted
/// (TTL minus drift allowance minus time spent acquiring).
pub validity_time: usize,
/// Manager that produced this lock.
pub lock_manager: LockManager,
}
/// Wrapper around a [`Lock`]. A `Drop` implementation that unlocks it is compiled only when the
/// `tokio-comp` feature is disabled.
#[derive(Debug)]
pub struct LockGuard {
/// The wrapped lock.
pub lock: Lock,
}
/// Selects which per-server command `exec_or_retry` issues.
enum Operation {
/// Take a new lock (`SET ... NX PX`).
Lock,
/// Extend a lock that is already held (runs `EXTEND_SCRIPT`).
Extend,
}
#[cfg(not(feature = "tokio-comp"))]
impl Drop for LockGuard {
    /// Releases the wrapped lock by blocking the current thread on
    /// [`LockManager::unlock`] until it completes.
    fn drop(&mut self) {
        let held = &self.lock;
        let release = held.lock_manager.unlock(held);
        futures::executor::block_on(release);
    }
}
impl LockManager {
/// Builds a manager with one Redis client per entry of `uris` and the
/// default retry settings.
///
/// # Panics
/// Panics when `Client::open` rejects any of the URIs.
pub fn new<T: IntoConnectionInfo>(uris: Vec<T>) -> LockManager {
    let mut clients: Vec<Client> = Vec::with_capacity(uris.len());
    for uri in uris {
        clients.push(Client::open(uri).unwrap());
    }
    Self::from_clients(clients)
}
/// Builds a manager from already constructed clients, using
/// `DEFAULT_RETRY_COUNT` and `DEFAULT_RETRY_DELAY`. No connection is opened
/// here.
pub fn from_clients(clients: Vec<Client>) -> LockManager {
    let mut servers: Vec<RestorableConnection> = Vec::with_capacity(clients.len());
    servers.extend(clients.into_iter().map(RestorableConnection::new));
    let inner = LockManagerInner { servers };
    LockManager {
        lock_manager_inner: Arc::new(Mutex::new(inner)),
        retry_count: DEFAULT_RETRY_COUNT,
        retry_delay: DEFAULT_RETRY_DELAY,
    }
}
/// Returns 20 random bytes to be used as a lock value.
///
/// The result is always `Ok`; the `io::Result` wrapper is part of the
/// existing signature.
pub fn get_unique_lock_id(&self) -> io::Result<Vec<u8>> {
    let mut id = vec![0u8; 20];
    rng().fill_bytes(id.as_mut_slice());
    Ok(id)
}
/// Overrides the retry settings: `count` bounds the number of attempts and
/// `delay` is the exclusive upper bound of the random pause between them.
pub fn set_retry(&mut self, count: u32, delay: Duration) {
    (self.retry_count, self.retry_delay) = (count, delay);
}
/// Waits for and returns the guard protecting the shared server list.
async fn lock_inner(&self) -> MutexGuard<'_, LockManagerInner> {
    let guard = self.lock_manager_inner.lock().await;
    guard
}
/// Runs `function` (lock or extend) against every server and retries until a
/// quorum of servers succeeds or the attempts are used up.
///
/// Each attempt:
/// 1. issues the command on all servers concurrently and counts successes;
/// 2. computes the remaining validity as `ttl - drift - elapsed`, where
///    `drift` is `ttl * CLOCK_DRIFT_FACTOR + 2` ms;
/// 3. on quorum returns a [`Lock`]; otherwise asks every server to release
///    the key and, if attempts remain, sleeps a random time below
///    `retry_delay` before trying again.
///
/// # Errors
/// * [`LockError::TtlExceeded`] when drift plus elapsed time is not smaller
///   than `ttl` (returned immediately, without the release step).
/// * [`LockError::TtlTooLarge`] when `retry_delay` does not fit into `u64`
///   milliseconds.
/// * [`LockError::Unavailable`] when no attempt reached a quorum.
async fn exec_or_retry(
    &self,
    resource: impl ToLockResource<'_>,
    value: &[u8],
    ttl: usize,
    function: Operation,
) -> Result<Lock, LockError> {
    let mut current_try = 1;
    let resource = &resource.to_lock_resource();
    loop {
        let start_time = Instant::now();
        // Work on a snapshot so the shared mutex is not held across network I/O.
        let l = self.lock_inner().await;
        let mut servers = l.servers.clone();
        drop(l);
        let n = match function {
            Operation::Lock => {
                join_all(servers.iter_mut().map(|c| c.lock(resource, value, ttl))).await
            }
            Operation::Extend => {
                join_all(servers.iter_mut().map(|c| c.extend(resource, value, ttl))).await
            }
        }
        .into_iter()
        .filter(|&succeeded| succeeded)
        .count();
        let drift = (ttl as f32 * CLOCK_DRIFT_FACTOR) as usize + 2;
        let elapsed_ms = start_time.elapsed().as_millis() as usize;
        if ttl <= drift + elapsed_ms {
            return Err(LockError::TtlExceeded);
        }
        // The guard above guarantees this subtraction cannot underflow and
        // that the result is at least 1.
        let validity_time = ttl - drift - elapsed_ms;
        let l = self.lock_inner().await;
        if n >= l.get_quorum() as usize {
            return Ok(Lock {
                lock_manager: self.clone(),
                resource: resource.to_vec(),
                val: value.to_vec(),
                validity_time,
            });
        }
        // No quorum: release whatever was set on the minority of servers.
        let mut servers = l.servers.clone();
        drop(l);
        join_all(
            servers
                .iter_mut()
                .map(|client| client.unlock(resource, value)),
        )
        .await;
        if current_try >= self.retry_count {
            break;
        }
        current_try += 1;
        let retry_delay: u64 = self
            .retry_delay
            .as_millis()
            .try_into()
            .map_err(|_| LockError::TtlTooLarge)?;
        // `random_range` panics on an empty range, so a zero delay must not
        // be passed to it.
        let pause_ms = if retry_delay == 0 {
            0
        } else {
            rng().random_range(0..retry_delay)
        };
        tokio::time::sleep(Duration::from_millis(pause_ms)).await;
    }
    Err(LockError::Unavailable)
}
/// Queries every server for `resource` and returns the first successful
/// reply in server order.
///
/// Note that a successful reply may itself be `None` (key absent on that
/// server, or that server could not be connected to), in which case `None`
/// is returned even if a later server holds the key.
///
/// # Errors
/// Returns [`LockError::RedisConnectionFailed`] when no server produced a
/// successful reply, which includes the case of an empty server list.
pub async fn query_redis_for_key_value(
    &self,
    resource: &[u8],
) -> Result<Option<Vec<u8>>, LockError> {
    let inner = self.lock_inner().await;
    let mut servers = inner.servers.clone();
    drop(inner);
    let replies = join_all(servers.iter_mut().map(|server| server.query(resource))).await;
    replies
        .into_iter()
        .find_map(Result::ok)
        .ok_or(LockError::RedisConnectionFailed)
}
/// Asks every server to release `lock`; the per-server outcomes are
/// discarded.
pub async fn unlock(&self, lock: &Lock) {
    let inner = self.lock_inner().await;
    let mut servers = inner.servers.clone();
    drop(inner);
    let releases = servers
        .iter_mut()
        .map(|server| server.unlock(&*lock.resource, &lock.val));
    join_all(releases).await;
}
/// Tries to lock `resource` for `ttl`, using a freshly generated random lock
/// value and the configured retry settings.
///
/// # Errors
/// * [`LockError::TtlTooLarge`] when `ttl` in milliseconds does not fit into
///   `usize`.
/// * Any error produced by `exec_or_retry`.
pub async fn lock(
    &self,
    resource: impl ToLockResource<'_>,
    ttl: Duration,
) -> Result<Lock, LockError> {
    let resource = resource.to_lock_resource();
    let val = self.get_unique_lock_id().map_err(LockError::Io)?;
    let ttl = ttl
        .as_millis()
        .try_into()
        .map_err(|_| LockError::TtlTooLarge)?;
    // `val` is only borrowed; the former `val.clone()` allocated a copy for nothing.
    self.exec_or_retry(&resource, &val, ttl, Operation::Lock)
        .await
}
/// Like [`LockManager::acquire_no_guard`], but wraps the resulting lock in a
/// [`LockGuard`].
#[cfg(feature = "async-std-comp")]
pub async fn acquire(
    &self,
    resource: impl ToLockResource<'_>,
    ttl: Duration,
) -> Result<LockGuard, LockError> {
    self.acquire_no_guard(resource, ttl)
        .await
        .map(|lock| LockGuard { lock })
}
/// Calls [`LockManager::lock`] repeatedly until it succeeds.
///
/// Only [`LockError::TtlTooLarge`] ends the loop with an error; every other
/// failure triggers another call, so this future does not complete while the
/// resource stays unobtainable.
pub async fn acquire_no_guard(
    &self,
    resource: impl ToLockResource<'_>,
    ttl: Duration,
) -> Result<Lock, LockError> {
    let resource = &resource.to_lock_resource();
    loop {
        let attempt = self.lock(resource, ttl).await;
        if matches!(attempt, Ok(_) | Err(LockError::TtlTooLarge)) {
            return attempt;
        }
    }
}
/// Tries to extend `lock` so that it lasts `ttl` from now, returning a new
/// [`Lock`] describing the extended lock.
///
/// # Errors
/// * [`LockError::TtlTooLarge`] when `ttl` in milliseconds does not fit into
///   `usize`.
/// * Any error produced by `exec_or_retry`.
pub async fn extend(&self, lock: &Lock, ttl: Duration) -> Result<Lock, LockError> {
    let ttl_ms = match usize::try_from(ttl.as_millis()) {
        Ok(ms) => ms,
        Err(_) => return Err(LockError::TtlTooLarge),
    };
    let extended = self.exec_or_retry(&*lock.resource, &lock.val, ttl_ms, Operation::Extend);
    extended.await
}
/// Compares the value stored for `lock.resource` with `lock.val`.
///
/// Returns `Ok(false)` when the stored value matches. As written, this
/// function never returns `Ok(true)`: a missing key and a foreign value are
/// both reported as errors.
///
/// # Errors
/// * [`LockError::RedisKeyNotFound`] when the query yields no value.
/// * [`LockError::RedisKeyMismatch`] when the stored value differs.
/// * Any error from [`LockManager::query_redis_for_key_value`].
pub async fn is_freed(&self, lock: &Lock) -> Result<bool, LockError> {
    let stored = self
        .query_redis_for_key_value(&lock.resource)
        .await?
        .ok_or(LockError::RedisKeyNotFound)?;
    if stored == lock.val {
        Ok(false)
    } else {
        Err(LockError::RedisKeyMismatch)
    }
}
/// Acquires a lock on `resource`, runs `routine` while holding it, and
/// releases the lock once the routine completes.
///
/// While the routine is running, the lock is extended by `ttl` each time the
/// renewal threshold elapses. The threshold is the lock's validity minus
/// 500 ms; when the validity is shorter than 500 ms, half of the validity is
/// used instead so the subtraction cannot underflow.
///
/// # Errors
/// Returns the error from acquiring the lock or from a failed extension. In
/// the latter case the routine is dropped without being completed.
#[cfg(feature = "tokio-comp")]
pub async fn using<R>(
    &self,
    resource: &[u8],
    ttl: Duration,
    routine: impl AsyncFnOnce() -> R,
) -> Result<R, LockError> {
    /// Milliseconds to wait before extending a lock with the given validity.
    fn renewal_threshold(validity_ms: usize) -> u64 {
        let validity_ms = validity_ms as u64;
        validity_ms
            .checked_sub(500)
            .unwrap_or(validity_ms / 2)
    }

    let mut lock = self.acquire_no_guard(resource, ttl).await?;
    let mut threshold = renewal_threshold(lock.validity_time);
    let routine = routine();
    futures::pin_mut!(routine);
    loop {
        match tokio::time::timeout(Duration::from_millis(threshold), &mut routine).await {
            Ok(result) => {
                self.unlock(&lock).await;
                return Ok(result);
            }
            Err(_) => {
                // The routine is still running: renew before the lock expires.
                lock = self.extend(&lock, ttl).await?;
                threshold = renewal_threshold(lock.validity_time);
            }
        }
    }
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use testcontainers::{
core::{IntoContainerPort, WaitFor},
runners::AsyncRunner,
ContainerAsync, GenericImage,
};
use tokio::time::Duration;
use super::*;
type Containers = Vec<ContainerAsync<GenericImage>>;
/// Starts three Redis 7 containers, waits until each answers `PING`, and
/// returns the containers together with their `redis://localhost:<port>`
/// addresses. The containers must be kept alive by the caller.
async fn create_clients() -> (Containers, Vec<String>) {
    const INSTANCE_COUNT: usize = 3;
    let mut containers = Vec::with_capacity(INSTANCE_COUNT);
    let mut addresses = Vec::with_capacity(INSTANCE_COUNT);
    for _ in 0..INSTANCE_COUNT {
        let image = GenericImage::new("redis", "7")
            .with_exposed_port(6379.tcp())
            .with_wait_for(WaitFor::message_on_stdout("Ready to accept connections"));
        let container = image
            .start()
            .await
            .expect("Failed to start Redis container");
        let port = container
            .get_host_port_ipv4(6379)
            .await
            .expect("Failed to get port");
        addresses.push(format!("redis://localhost:{}", port));
        containers.push(container);
    }
    ensure_redis_readiness(&addresses)
        .await
        .expect("Redis instances are not ready");
    (containers, addresses)
}
/// Pings every address until it answers, making up to 120 attempts per
/// address with a one-second pause after each failed attempt.
///
/// Returns an error when an address cannot be parsed or never answers.
async fn ensure_redis_readiness(
    addresses: &[String],
) -> Result<(), Box<dyn std::error::Error>> {
    const MAX_ATTEMPTS: u32 = 120;
    for address in addresses {
        let client = Client::open(address.as_str())?;
        let mut ready = false;
        for _ in 0..MAX_ATTEMPTS {
            match client.get_multiplexed_async_connection().await {
                Err(e) => eprintln!("Failed to connect to Redis {}: {:?}", address, e),
                Ok(mut con) => match redis::cmd("PING").query_async::<String>(&mut con).await {
                    Ok(response) => {
                        eprintln!("Redis {} is ready: {}", address, response);
                        ready = true;
                        break;
                    }
                    Err(e) => eprintln!("Redis {} is not ready: {:?}", address, e),
                },
            }
            tokio::time::sleep(Duration::from_secs(1)).await;
        }
        if !ready {
            return Err(format!("Redis {} did not become ready after retries", address).into());
        }
    }
    Ok(())
}
/// Compile-time probe: instantiating it proves `T` is `Sized + Send + Sync + Unpin`.
fn is_normal<T>()
where
    T: Sized + Send + Sync + Unpin,
{
}

/// The public types must satisfy the auto-trait bounds checked by `is_normal`.
#[test]
fn test_is_normal() {
    is_normal::<LockManager>();
    is_normal::<LockError>();
    is_normal::<Lock>();
    is_normal::<LockGuard>();
}
/// A generated lock id is 20 bytes long.
#[tokio::test]
async fn test_lock_get_unique_id() -> Result<()> {
    let manager = LockManager::new(Vec::<String>::new());
    let id = manager.get_unique_lock_id()?;
    assert_eq!(id.len(), 20);
    Ok(())
}
/// Two consecutively generated ids are both 20 bytes long and differ.
#[tokio::test]
async fn test_lock_get_unique_id_uniqueness() -> Result<()> {
    let manager = LockManager::new(Vec::<String>::new());
    let first = manager.get_unique_lock_id()?;
    let second = manager.get_unique_lock_id()?;
    for id in [&first, &second] {
        assert_eq!(20, id.len());
    }
    assert_ne!(first, second);
    Ok(())
}
/// A manager built from three addresses has three servers and a quorum of two.
#[tokio::test]
async fn test_lock_valid_instance() {
    let (_containers, addresses) = create_clients().await;
    let manager = LockManager::new(addresses.clone());
    let inner = manager.lock_inner().await;
    let server_count = inner.servers.len();
    let quorum = inner.get_quorum();
    assert_eq!(3, server_count);
    assert_eq!(2, quorum);
}
/// Unlocking a key that was never set on the server reports failure.
#[tokio::test]
async fn test_lock_direct_unlock_fails() -> Result<()> {
    let (_containers, addresses) = create_clients().await;
    let manager = LockManager::new(addresses.clone());
    let key = manager.get_unique_lock_id()?;
    let val = manager.get_unique_lock_id()?;
    let mut inner = manager.lock_inner().await;
    let released = inner.servers[0].unlock(&key, &val).await;
    assert!(!released);
    Ok(())
}
/// Unlocking succeeds when the key holds exactly the expected value.
#[tokio::test]
async fn test_lock_direct_unlock_succeeds() -> Result<()> {
    let (_containers, addresses) = create_clients().await;
    let manager = LockManager::new(addresses.clone());
    let key = manager.get_unique_lock_id()?;
    let val = manager.get_unique_lock_id()?;
    let mut inner = manager.lock_inner().await;
    let mut con = inner.servers[0].get_connection().await?;
    // Seed the key directly so the unlock script finds a matching value.
    redis::cmd("SET")
        .arg(&*key)
        .arg(&*val)
        .exec_async(&mut con)
        .await?;
    let released = inner.servers[0].unlock(&key, &val).await;
    assert!(released);
    Ok(())
}
/// Locking directly on one server succeeds once the key is known to be absent.
#[tokio::test]
async fn test_lock_direct_lock_succeeds() -> Result<()> {
    let (_containers, addresses) = create_clients().await;
    let manager = LockManager::new(addresses.clone());
    let key = manager.get_unique_lock_id()?;
    let resource = key.to_lock_resource();
    let val = manager.get_unique_lock_id()?;
    let mut inner = manager.lock_inner().await;
    let mut con = inner.servers[0].get_connection().await?;
    redis::cmd("DEL").arg(&*key).exec_async(&mut con).await?;
    let locked = inner.servers[0].lock(&resource, &val, 10_000).await;
    assert!(locked);
    Ok(())
}
/// `LockManager::unlock` runs to completion for a hand-built lock whose key
/// was seeded on the first server.
#[tokio::test]
async fn test_lock_unlock() -> Result<()> {
    let (_containers, addresses) = create_clients().await;
    let manager = LockManager::new(addresses.clone());
    let key = manager.get_unique_lock_id()?;
    let val = manager.get_unique_lock_id()?;
    let mut inner = manager.lock_inner().await;
    let mut con = inner.servers[0].get_connection().await?;
    // Release the guard before `unlock`, which needs it itself.
    drop(inner);
    let _: () = redis::cmd("SET")
        .arg(&*key)
        .arg(&*val)
        .query_async(&mut con)
        .await?;
    let lock = Lock {
        resource: key,
        val,
        validity_time: 0,
        lock_manager: manager.clone(),
    };
    manager.unlock(&lock).await;
    Ok(())
}
/// A successful lock carries the requested key, a 20-byte value, and a
/// positive validity time.
#[tokio::test]
async fn test_lock_lock() -> Result<()> {
    let (_containers, addresses) = create_clients().await;
    let manager = LockManager::new(addresses.clone());
    let key = manager.get_unique_lock_id()?;
    let lock = manager
        .lock(&key, Duration::from_millis(10_000))
        .await
        .unwrap_or_else(|e| panic!("{:?}", e));
    assert_eq!(key, lock.resource);
    assert_eq!(20, lock.val.len());
    assert!(
        lock.validity_time > 0,
        "validity time: {}",
        lock.validity_time
    );
    Ok(())
}
/// A second manager cannot take a held lock, but can once it is released.
#[tokio::test]
async fn test_lock_lock_unlock() -> Result<()> {
    let (_containers, addresses) = create_clients().await;
    let owner = LockManager::new(addresses.clone());
    let contender = LockManager::new(addresses.clone());
    let key = owner.get_unique_lock_id()?;
    let ttl = Duration::from_millis(10_000);

    let lock = owner.lock(&key, ttl).await.unwrap();
    assert!(
        lock.validity_time > 0,
        "validity time: {}",
        lock.validity_time
    );

    if contender.lock(&key, ttl).await.is_ok() {
        panic!("Lock acquired, even though it should be locked")
    }

    owner.unlock(&lock).await;

    let relocked = contender
        .lock(&key, ttl)
        .await
        .unwrap_or_else(|_| panic!("Lock couldn't be acquired"));
    assert!(relocked.validity_time > 0);
    Ok(())
}
/// Without `tokio-comp`, dropping the guard releases the lock so another
/// manager can take it afterwards.
#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
#[tokio::test]
async fn test_lock_lock_unlock_raii() -> Result<()> {
    let (_containers, addresses) = create_clients().await;
    let owner = LockManager::new(addresses.clone());
    let contender = LockManager::new(addresses.clone());
    let key = owner.get_unique_lock_id()?;
    let ttl = Duration::from_millis(10_000);
    // The guard lives only inside this async block and is dropped at its end.
    async {
        let guard = owner.acquire(&key, ttl).await.unwrap();
        let held = &guard.lock;
        assert!(
            held.validity_time > 0,
            "validity time: {}",
            held.validity_time
        );
        if contender.lock(&key, ttl).await.is_ok() {
            panic!("Lock acquired, even though it should be locked")
        }
    }
    .await;
    let relocked = contender
        .lock(&key, ttl)
        .await
        .unwrap_or_else(|_| panic!("Lock couldn't be acquired"));
    assert!(relocked.validity_time > 0);
    Ok(())
}
/// With `tokio-comp` enabled no `Drop` implementation is compiled for
/// `LockGuard`, so dropping the guard must leave the lock in place: a second
/// manager must fail to take it both while the guard is alive and after it
/// has been dropped.
#[cfg(feature = "tokio-comp")]
#[tokio::test]
async fn test_lock_raii_does_not_unlock_with_tokio_enabled() -> Result<()> {
    let (_containers, addresses) = create_clients().await;
    let rl1 = LockManager::new(addresses.clone());
    let rl2 = LockManager::new(addresses.clone());
    let key = rl1.get_unique_lock_id()?;
    async {
        let lock_guard = rl1
            .acquire(&key, Duration::from_millis(10_000))
            .await
            .expect("LockManage rl1 should be able to acquire lock");
        let lock = &lock_guard.lock;
        assert!(
            lock.validity_time > 0,
            "validity time: {}",
            lock.validity_time
        );
        // Poll until the stored value matches the lock value, up to five times.
        let mut retries = 5;
        let mut redis_key_verified = false;
        while retries > 0 {
            match rl1.query_redis_for_key_value(&key).await {
                Ok(Some(redis_val)) if redis_val == lock.val => {
                    redis_key_verified = true;
                    break;
                }
                Ok(Some(redis_val)) => {
                    println!(
                        "Redis key value mismatch. Expected: {:?}, Found: {:?}. Retrying...",
                        lock.val, redis_val
                    );
                }
                Ok(None) => println!("Redis key not found. Retrying..."),
                Err(e) => println!("Failed to query Redis key: {:?}. Retrying...", e),
            }
            retries -= 1;
            tokio::time::sleep(Duration::from_millis(1000)).await;
        }
        if rl2.lock(&key, Duration::from_millis(10_000)).await.is_ok() {
            panic!("Lock acquired, even though it should be locked")
        }
        assert!(redis_key_verified);
    }
    .await;
    // The guard is gone now; the lock must still be held. The previous
    // message here claimed the opposite of what this branch detects.
    if rl2.lock(&key, Duration::from_millis(10_000)).await.is_ok() {
        panic!("Lock acquired, even though dropping the guard should not have released it");
    }
    Ok(())
}
/// After extending a held lock, a second manager still gets
/// `LockError::Unavailable` for the same key.
#[cfg(feature = "async-std-comp")]
#[tokio::test]
async fn test_lock_extend_lock() -> Result<()> {
    let (_containers, addresses) = create_clients().await;
    let owner = LockManager::new(addresses.clone());
    let contender = LockManager::new(addresses.clone());
    let key = owner.get_unique_lock_id()?;
    let ttl = Duration::from_millis(10_000);
    let pause = tokio::time::Duration::from_millis(500);
    async {
        let guard = owner.acquire(&key, ttl).await.unwrap();
        tokio::time::sleep(pause).await;
        owner.extend(&guard.lock, ttl).await.unwrap();
        tokio::time::sleep(pause).await;
        match contender.lock(&key, ttl).await {
            Err(LockError::Unavailable) => (),
            Ok(_) => panic!("Expected an error when extending the lock but didn't receive one"),
            Err(_) => panic!("Unexpected error when extending lock"),
        }
    }
    .await;
    Ok(())
}
/// Once an extended short lock has expired, another manager can take the
/// key and the original owner can no longer extend it.
#[cfg(feature = "async-std-comp")]
#[tokio::test]
async fn test_lock_extend_lock_releases() -> Result<()> {
    let (_containers, addresses) = create_clients().await;
    let owner = LockManager::new(addresses.clone());
    let contender = LockManager::new(addresses.clone());
    let key = owner.get_unique_lock_id()?;
    let short_ttl = Duration::from_millis(500);
    let long_ttl = Duration::from_millis(10_000);
    async {
        let guard = owner.acquire(&key, short_ttl).await.unwrap();
        owner.extend(&guard.lock, short_ttl).await.unwrap();
        // Sleep past the extended TTL so the key expires.
        tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
        if contender.lock(&key, long_ttl).await.is_err() {
            panic!("Unexpected error when trying to claim free lock after extend expired")
        }
        match owner.extend(&guard.lock, long_ttl).await {
            Err(LockError::Unavailable) => (),
            Ok(_) => panic!("Did not expect OK() when re-extending rl1"),
            Err(_) => panic!("Expected lockError::Unavailable when re-extending rl1"),
        }
    }
    .await;
    Ok(())
}
/// A 1 ms TTL cannot cover the drift allowance, so locking reports
/// `LockError::TtlExceeded`.
#[tokio::test]
async fn test_lock_with_short_ttl_and_retries() -> Result<()> {
    let (_containers, addresses) = create_clients().await;
    let mut manager = LockManager::new(addresses.clone());
    manager.set_retry(10, Duration::from_millis(10));
    let key = manager.get_unique_lock_id()?;
    let lock_result = manager.lock(&key, Duration::from_millis(1)).await;
    if !matches!(lock_result, Err(LockError::TtlExceeded)) {
        panic!("Expected LockError::TtlExceeded, but got {:?}", lock_result)
    }
    Ok(())
}
#[tokio::test]
async fn test_lock_ttl_duration_conversion_error() {
let (_containers, addresses) = create_clients().await;
let rl = LockManager::new(addresses.clone());
let key = rl.get_unique_lock_id().unwrap();
let ttl = Duration::from_secs(u64::MAX);
match rl.lock(&key, ttl).await {
Ok(_) => panic!("Expected LockError::TtlTooLarge"),
Err(_) => (), }
}
/// A lock and its manager can be sent through a channel and used from a
/// spawned task.
#[tokio::test]
#[cfg(feature = "tokio-comp")]
async fn test_lock_send_lock_manager() {
    let (_containers, addresses) = create_clients().await;
    let manager = LockManager::new(addresses.clone());
    let ttl = std::time::Duration::from_millis(10_000);
    let lock = manager.lock(b"resource", ttl).await.unwrap();
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);
    tx.send(("some info", lock, manager)).await.unwrap();
    let worker = tokio::spawn(async move {
        if let Some((_entry, received_lock, received_manager)) = rx.recv().await {
            received_manager.unlock(&received_lock).await;
        }
    });
    let _ = worker.await;
}
/// A lock released from a spawned task, and one released locally, are both
/// reported as gone by `is_freed` (which signals this via `RedisKeyNotFound`).
#[tokio::test]
#[cfg(feature = "tokio-comp")]
async fn test_lock_state_in_multiple_threads() {
    let (_containers, addresses) = create_clients().await;
    let manager = LockManager::new(addresses.clone());
    let ttl = std::time::Duration::from_millis(10_000);

    // First lock: released inside a spawned task.
    let first = Arc::new(manager.lock(b"resource_1", ttl).await.unwrap());
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);
    tx.send(("some info", first.clone(), manager.clone()))
        .await
        .unwrap();
    let worker = tokio::spawn(async move {
        if let Some((_entry, shared_lock, task_manager)) = rx.recv().await {
            task_manager.unlock(&shared_lock).await;
        }
    });
    let _ = worker.await;
    match manager.is_freed(&first).await {
        Err(LockError::RedisKeyNotFound) => {
            assert!(true, "RedisKeyNotFound is expected if key is missing")
        }
        Ok(freed) => assert!(freed, "Lock should be freed after unlock"),
        Err(e) => panic!("Unexpected error: {:?}", e),
    };

    // Second lock: released on the current task.
    let second = manager.lock(b"resource_2", ttl).await.unwrap();
    manager.unlock(&second).await;
    match manager.is_freed(&second).await {
        Err(LockError::RedisKeyNotFound) => {
            assert!(true, "RedisKeyNotFound is expected if key is missing")
        }
        Ok(freed) => assert!(freed, "Lock should be freed after unlock"),
        Err(e) => panic!("Unexpected error: {:?}", e),
    };
}
/// The value stored on the first server equals the value carried by the lock.
#[tokio::test]
async fn test_redis_value_matches_lock_value() {
    let (_containers, addresses) = create_clients().await;
    let manager = LockManager::new(addresses.clone());
    let ttl = std::time::Duration::from_millis(10_000);
    let lock = manager.lock(b"resource_1", ttl).await.unwrap();
    let mut inner = manager.lock_inner().await;
    let mut con = inner.servers[0].get_connection().await.unwrap();
    let stored: Option<Vec<u8>> = redis::cmd("GET")
        .arg(&lock.resource)
        .query_async(&mut con)
        .await
        .unwrap();
    let expected = Some(lock.val.as_slice());
    eprintln!(
        "Debug: Expected value in Redis: {:?}, Actual value in Redis: {:?}",
        expected,
        stored.as_deref()
    );
    assert_eq!(
        stored.as_deref(),
        expected,
        "Redis value should match lock value"
    );
}
/// Right after acquisition `is_freed` must answer `Ok(false)`.
#[tokio::test]
async fn test_is_not_freed_after_lock() {
    let (_containers, addresses) = create_clients().await;
    let manager = LockManager::new(addresses.clone());
    let ttl = std::time::Duration::from_millis(10_000);
    let lock = manager.lock(b"resource_1", ttl).await.unwrap();
    let status = manager.is_freed(&lock).await;
    match status {
        Err(LockError::RedisKeyMismatch) => {
            panic!("Redis key mismatch should not occur for a valid lock")
        }
        Err(LockError::RedisKeyNotFound) => {
            panic!("Redis key not found should not occur for a valid lock")
        }
        Err(e) => panic!("Unexpected error: {:?}", e),
        Ok(freed) => assert!(!freed, "Lock should not be freed after it is acquired"),
    };
}
/// After an explicit unlock, `is_freed` reports the key as missing (or freed).
#[tokio::test]
async fn test_is_freed_after_manual_unlock() {
    let (_containers, addresses) = create_clients().await;
    let manager = LockManager::new(addresses.clone());
    let ttl = std::time::Duration::from_millis(10_000);
    let lock = manager.lock(b"resource_2", ttl).await.unwrap();
    manager.unlock(&lock).await;
    let status = manager.is_freed(&lock).await;
    match status {
        Err(LockError::RedisKeyNotFound) => {
            assert!(true, "RedisKeyNotFound is expected if key is missing")
        }
        Ok(freed) => assert!(freed, "Lock should be freed after unlock"),
        Err(e) => panic!("Unexpected error: {:?}", e),
    };
}
/// Deleting the key on the first server makes `is_freed` report it as
/// missing (or freed).
#[tokio::test]
async fn test_is_freed_when_key_missing_in_redis() {
    let (_containers, addresses) = create_clients().await;
    let manager = LockManager::new(addresses.clone());
    let ttl = std::time::Duration::from_millis(10_000);
    let lock = manager.lock(b"resource_3", ttl).await.unwrap();
    let mut inner = manager.lock_inner().await;
    let mut con = inner.servers[0].get_connection().await.unwrap();
    // Release the guard before `is_freed`, which needs it itself.
    drop(inner);
    redis::cmd("DEL")
        .arg(&lock.resource)
        .query_async::<()>(&mut con)
        .await
        .unwrap();
    let status = manager.is_freed(&lock).await;
    match status {
        Err(LockError::RedisKeyNotFound) => assert!(
            true,
            "RedisKeyNotFound is expected when key is missing in Redis"
        ),
        Ok(freed) => assert!(
            freed,
            "Lock should be marked as freed when key is missing in Redis"
        ),
        Err(e) => panic!("Unexpected error: {:?}", e),
    };
}
/// With an empty server list, locking must fail with `Unavailable`; should a
/// lock be returned anyway, `is_freed` must report `RedisConnectionFailed`.
#[tokio::test]
async fn test_is_freed_handles_redis_connection_failure() {
    let (_containers, _) = create_clients().await;
    let manager = LockManager::new(Vec::<String>::new());
    let ttl = std::time::Duration::from_millis(10_000);
    let lock = match manager.lock(b"resource_4", ttl).await {
        Ok(lock) => lock,
        Err(LockError::Unavailable) => {
            assert!(true);
            return;
        }
        Err(e) => panic!("Unexpected error while acquiring lock: {:?}", e),
    };
    match manager.is_freed(&lock).await {
        Err(LockError::RedisConnectionFailed) => assert!(true, "Expected RedisConnectionFailed when all Redis connections fail"),
        Ok(freed) => panic!("Expected failure due to Redis connection, but got Ok with freed status: {}", freed),
        Err(e) => panic!("Unexpected error: {:?}", e),
    }
}
/// Same scenario as the connection-failure test above, for `resource_5`:
/// an empty server list yields `Unavailable` on lock, or
/// `RedisConnectionFailed` from `is_freed` if a lock were returned.
#[tokio::test]
async fn test_redis_connection_failed() {
    let (_containers, _) = create_clients().await;
    let manager = LockManager::new(Vec::<String>::new());
    let ttl = std::time::Duration::from_millis(10_000);
    let lock = match manager.lock(b"resource_5", ttl).await {
        Ok(lock) => lock,
        Err(LockError::Unavailable) => {
            assert!(true);
            return;
        }
        Err(e) => panic!("Unexpected error while acquiring lock: {:?}", e),
    };
    match manager.is_freed(&lock).await {
        Err(LockError::RedisConnectionFailed) => assert!(
            true,
            "Expected RedisConnectionFailed when all Redis connections fail"
        ),
        Ok(_) => panic!("Expected RedisConnectionFailed, but got Ok"),
        Err(e) => panic!("Unexpected error: {:?}", e),
    }
}
/// Overwriting the key on the first server with a foreign value makes
/// `is_freed` report `RedisKeyMismatch`.
#[tokio::test]
async fn test_redis_key_mismatch() {
    let (_containers, addresses) = create_clients().await;
    let manager = LockManager::new(addresses.clone());
    let ttl = std::time::Duration::from_millis(10_000);
    let lock = manager.lock(b"resource_6", ttl).await.unwrap();
    let mut inner = manager.lock_inner().await;
    let mut con = inner.servers[0].get_connection().await.unwrap();
    // Release the guard before `is_freed`, which needs it itself.
    drop(inner);
    let foreign_value: Vec<u8> = vec![1, 2, 3, 4, 5];
    redis::cmd("SET")
        .arg(&lock.resource)
        .arg(foreign_value)
        .query_async::<()>(&mut con)
        .await
        .unwrap();
    let status = manager.is_freed(&lock).await;
    match status {
        Ok(_) => panic!("Expected RedisKeyMismatch, but got Ok"),
        Err(LockError::RedisKeyMismatch) => assert!(
            true,
            "Expected RedisKeyMismatch when key value does not match the lock value"
        ),
        Err(e) => panic!("Unexpected error: {:?}", e),
    }
}
/// Deleting the key on the first server makes `is_freed` report
/// `RedisKeyNotFound`.
#[tokio::test]
async fn test_redis_key_not_found() {
    let (_containers, addresses) = create_clients().await;
    let manager = LockManager::new(addresses.clone());
    let ttl = std::time::Duration::from_millis(10_000);
    let lock = manager.lock(b"resource_7", ttl).await.unwrap();
    let mut inner = manager.lock_inner().await;
    let mut con = inner.servers[0].get_connection().await.unwrap();
    // Release the guard before `is_freed`, which needs it itself.
    drop(inner);
    redis::cmd("DEL")
        .arg(&lock.resource)
        .query_async::<()>(&mut con)
        .await
        .unwrap();
    let status = manager.is_freed(&lock).await;
    match status {
        Ok(_) => panic!("Expected RedisKeyNotFound, but got Ok"),
        Err(LockError::RedisKeyNotFound) => assert!(
            true,
            "Expected RedisKeyNotFound when key is missing in Redis"
        ),
        Err(e) => panic!("Unexpected error: {:?}", e),
    }
}
/// `from_clients` with three clients yields three servers and a quorum of two.
#[tokio::test]
async fn test_lock_manager_from_clients_valid_instance() {
    let (_containers, addresses) = create_clients().await;
    let mut clients: Vec<Client> = Vec::with_capacity(addresses.len());
    for uri in &addresses {
        clients.push(Client::open(uri.as_str()).unwrap());
    }
    let manager = LockManager::from_clients(clients);
    let inner = manager.lock_inner().await;
    assert_eq!(inner.servers.len(), 3);
    assert_eq!(inner.get_quorum(), 2);
}
/// With one of the three clients removed, two servers remain and the quorum
/// is still two.
#[tokio::test]
async fn test_lock_manager_from_clients_partial_quorum() {
    let (_containers, addresses) = create_clients().await;
    let mut clients: Vec<Client> = Vec::with_capacity(addresses.len());
    for uri in &addresses {
        clients.push(Client::open(uri.as_str()).unwrap());
    }
    clients.pop();
    let manager = LockManager::from_clients(clients);
    let inner = manager.lock_inner().await;
    assert_eq!(inner.servers.len(), 2);
    assert_eq!(inner.get_quorum(), 2);
}
}