use async_trait::async_trait;
use klauthed_core::time::Duration;
use redis::aio::ConnectionManager;
use redis::{ExistenceCheck, SetExpiry, SetOptions};
use crate::error::DataError;
use crate::locks::{LockGuard, LockManager, LockToken};
/// Lua script implementing compare-and-delete for lock release.
///
/// `KEYS[1]` is the lock key and `ARGV[1]` is the caller's lock token (stringified).
/// If the key's current value equals the token, the key is deleted and the `DEL` reply
/// (number of keys removed) is returned; otherwise the key is left untouched and `0`
/// is returned. Doing the GET and DEL inside one script means they run as a single
/// server-side step (Redis executes scripts atomically). A holder whose lock already
/// expired and was re-acquired by someone else therefore cannot delete the newer
/// holder's key.
const RELEASE_SCRIPT: &str = "\
if redis.call('GET', KEYS[1]) == ARGV[1] then
return redis.call('DEL', KEYS[1])
else
return 0
end";
/// `LockManager` implementation that stores each lock as a single Redis string key.
///
/// The key's value is the holder's [`LockToken`]. Acquisition writes it with
/// `SET … NX PX`, and [`RedisLockManager::release_token`] removes the key only when
/// it still carries the caller's token.
#[derive(Clone)]
pub struct RedisLockManager {
    /// Connection handle; every operation works on its own clone of it.
    conn: ConnectionManager,
}

impl RedisLockManager {
    /// Creates a lock manager that issues all of its commands over `conn`.
    pub fn new(conn: ConnectionManager) -> Self {
        RedisLockManager { conn }
    }

    /// Deletes `key` only if its value is still `token`.
    ///
    /// Returns `Ok(true)` when the key was removed. Returns `Ok(false)` when the key
    /// was absent or held a different token (for example, the lock expired and was
    /// taken by another holder); in that case nothing is modified.
    ///
    /// # Errors
    ///
    /// Propagates any Redis error raised while running the compare-and-delete
    /// script, converted into [`DataError`].
    pub async fn release_token(&self, key: &str, token: LockToken) -> Result<bool, DataError> {
        let script = redis::Script::new(RELEASE_SCRIPT);
        let mut connection = self.conn.clone();
        // The script replies with the DEL count on a token match, 0 otherwise.
        let removed: i64 = script
            .key(key)
            .arg(token.to_string())
            .invoke_async(&mut connection)
            .await?;
        Ok(removed == 1)
    }
}
#[async_trait]
impl LockManager for RedisLockManager {
    /// Makes a single, non-blocking attempt to take the lock at `key` for `ttl`.
    ///
    /// Issues `SET key <token> NX PX <ttl_ms>` with a freshly generated
    /// [`LockToken`]. The key is written only if it does not already exist, and Redis
    /// expires it after `ttl` even if the holder never releases it.
    ///
    /// # Returns
    ///
    /// `Ok(Some(guard))` when this call created the key, or `Ok(None)` when the key
    /// already existed (the lock is currently held).
    ///
    /// # Errors
    ///
    /// - `DataError::LockHeld` if `ttl` is negative, truncates to zero whole
    ///   milliseconds, or is too large to be sent as a Redis `PX` argument.
    /// - Any Redis error from the `SET` command, converted into [`DataError`].
    async fn acquire(&self, key: &str, ttl: Duration) -> Result<Option<LockGuard>, DataError> {
        // NOTE(review): invalid TTLs are reported as `LockHeld`, which a caller could
        // misread as "contended, retry later". A dedicated invalid-argument variant
        // would be clearer if `DataError` has (or can gain) one.
        let ttl_ms: u64 = match ttl.whole_milliseconds() {
            ms if ms < 0 => {
                return Err(DataError::LockHeld(format!(
                    "invalid (non-positive) TTL for lock '{key}'"
                )));
            }
            0 => {
                return Err(DataError::LockHeld(format!("invalid (zero) TTL for lock '{key}'")));
            }
            // Redis parses the PX argument as a signed 64-bit integer, so reject anything
            // beyond i64::MAX here rather than sending a value the server will refuse.
            // (Values close to that bound can still be refused server-side once added
            // to the current time.)
            ms => i64::try_from(ms)
                .ok()
                .and_then(|ms| u64::try_from(ms).ok())
                .ok_or_else(|| {
                    DataError::LockHeld(format!("invalid (too large) TTL for lock '{key}'"))
                })?,
        };
        let token = LockToken::new();
        let options = SetOptions::default()
            .conditional_set(ExistenceCheck::NX)
            .with_expiration(SetExpiry::PX(ttl_ms));
        let mut conn = self.conn.clone();
        // `SET … NX` replies with a status string when the key was written and nil
        // when it already existed, hence `Option<String>`.
        let outcome: Option<String> = redis::cmd("SET")
            .arg(key)
            .arg(token.to_string())
            .arg(&options)
            .query_async(&mut conn)
            .await?;
        Ok(outcome.map(|_| LockGuard::redis(key.to_owned(), token, self.clone())))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Connects to the Redis named by `REDIS_URL`, falling back to `redis://127.0.0.1/`.
    async fn live_manager() -> RedisLockManager {
        let redis_url = match std::env::var("REDIS_URL") {
            Ok(url) => url,
            Err(_) => "redis://127.0.0.1/".to_owned(),
        };
        let client = redis::Client::open(redis_url).expect("open redis client");
        let manager = ConnectionManager::new(client).await.expect("connect redis");
        RedisLockManager::new(manager)
    }

    /// Builds a lock key unique to this run so tests never collide with each other.
    fn unique_key() -> String {
        format!("klauthed:test:lock:{}", LockToken::new())
    }

    /// A held lock rejects a second acquirer until it is explicitly released.
    #[tokio::test]
    #[ignore = "requires a live Redis at REDIS_URL"]
    async fn acquire_blocks_until_released() {
        let locks = live_manager().await;
        let key = unique_key();

        let guard =
            locks.acquire(&key, Duration::seconds(30)).await.unwrap().expect("first acquire wins");
        // While `guard` is held, a competing acquire must come back empty.
        assert!(locks.acquire(&key, Duration::seconds(30)).await.unwrap().is_none());

        guard.release().await.unwrap();
        // After release the key is free again.
        assert!(locks.acquire(&key, Duration::seconds(30)).await.unwrap().is_some());
    }

    /// Releasing with an expired holder's token must not delete a newer holder's lock.
    #[tokio::test]
    #[ignore = "requires a live Redis at REDIS_URL"]
    async fn stale_token_release_does_not_steal() {
        let locks = live_manager().await;
        let key = unique_key();

        let stale = locks.acquire(&key, Duration::milliseconds(50)).await.unwrap().unwrap();
        let stale_token = stale.token();

        // Sleep well past the 50 ms TTL so Redis expires the first holder's key.
        tokio::time::sleep(std::time::Duration::from_millis(120)).await;

        let _fresh = locks
            .acquire(&key, Duration::seconds(30))
            .await
            .unwrap()
            .expect("fresh acquire after expiry");

        let freed = locks.release_token(&key, stale_token).await.unwrap();
        assert!(!freed);

        // NOTE(review): presumably forgotten so the stale guard's drop handling never
        // runs for an already-expired lock; confirm against `LockGuard`'s Drop impl.
        std::mem::forget(stale);
    }
}