/// Redis-backed lock support (provides `RedisLockManager`); compiled only when
/// the `redis` feature is enabled.
#[cfg(feature = "redis")]
pub mod redis;
/// MongoDB-backed lock support (provides `MongoLockManager`); compiled only
/// when the `mongodb` feature is enabled.
#[cfg(feature = "mongodb")]
pub mod mongo;
use async_trait::async_trait;
use klauthed_core::id::Id;
use klauthed_core::time::Duration;
use klauthed_core::time::{Clock, SystemClock, Timestamp};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::error::DataError;
/// Marker type used as the tag parameter of [`Id`], so that lock tokens are
/// their own distinct ID type.
pub struct LockTokenTag;
/// Token identifying one particular hold on a lock.
///
/// A fresh token is generated on every successful in-memory acquire, and an
/// in-memory release only removes the table entry when the stored token
/// matches, so a stale guard cannot free a lock now held by someone else.
pub type LockToken = Id<LockTokenTag>;
/// State of the in-memory backend: lock key -> (token of the current holder,
/// timestamp at which that hold expires), protected by a standard mutex.
type LockTable = Mutex<HashMap<String, (LockToken, Timestamp)>>;
/// A source of keyed locks whose holds are bounded by a time-to-live.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait LockManager: Send + Sync {
/// Tries to take the lock named `key`, to be held for at most `ttl`.
///
/// Returns `Ok(Some(guard))` when the lock was obtained. The in-memory
/// implementation in this file returns `Ok(None)` when another, unexpired
/// holder exists; other backends presumably follow the same convention —
/// confirm against their implementations.
///
/// # Errors
///
/// Returns a [`DataError`] when the implementation cannot process the request
/// (for the in-memory manager: when `now + ttl` is not representable).
async fn acquire(&self, key: &str, ttl: Duration) -> Result<Option<LockGuard>, DataError>;
}
/// Where the lock behind a [`LockGuard`] lives; decides how the guard releases it.
enum LockBackend {
/// Lock stored in the shared in-process table.
InMemory(Arc<LockTable>),
/// Lock stored in Redis; the manager is kept so the guard can call
/// `release_token` on it.
#[cfg(feature = "redis")]
Redis(self::redis::RedisLockManager),
/// Lock stored in MongoDB; the manager is kept so the guard can call
/// `release_token` on it.
#[cfg(feature = "mongodb")]
Mongo(self::mongo::MongoLockManager),
}
/// Handle representing a held lock.
///
/// Dropping the guard frees an in-memory lock immediately. For the Redis and
/// MongoDB backends, dropping only emits a debug log and the lock is left to
/// lapse at its TTL; call [`LockGuard::release`] to free those explicitly.
pub struct LockGuard {
// Name of the lock this guard holds.
key: String,
// Token identifying this particular hold.
token: LockToken,
// Backend that owns the lock; selects the release path.
backend: LockBackend,
// Set by `release` and by `Drop` so the release logic is not run a second time.
released: bool,
}
impl LockGuard {
    /// Builds a guard for a lock stored in the in-memory `table`.
    fn in_memory(key: String, token: LockToken, table: Arc<LockTable>) -> Self {
        Self { key, token, backend: LockBackend::InMemory(table), released: false }
    }

    /// Builds a guard for a lock stored in Redis; `manager` is retained so the
    /// guard can release the token later.
    #[cfg(feature = "redis")]
    pub(crate) fn redis(
        key: String,
        token: LockToken,
        manager: self::redis::RedisLockManager,
    ) -> Self {
        Self { key, token, backend: LockBackend::Redis(manager), released: false }
    }

    /// Builds a guard for a lock stored in MongoDB; `manager` is retained so
    /// the guard can release the token later.
    #[cfg(feature = "mongodb")]
    pub(crate) fn mongo(
        key: String,
        token: LockToken,
        manager: self::mongo::MongoLockManager,
    ) -> Self {
        Self { key, token, backend: LockBackend::Mongo(manager), released: false }
    }

    /// Name of the lock this guard holds.
    pub fn key(&self) -> &str {
        &self.key
    }

    /// Token identifying this particular hold on the lock.
    pub fn token(&self) -> LockToken {
        self.token
    }

    /// Releases the lock explicitly, consuming the guard.
    ///
    /// The in-memory backend removes the table entry (only if it still belongs
    /// to this guard's token); the Redis and MongoDB backends call
    /// `release_token` on their manager.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the backend's `release_token`. In that
    /// case the guard is *not* marked as released, so its `Drop` still runs
    /// its normal path when the guard goes out of scope.
    pub async fn release(mut self) -> Result<(), DataError> {
        if self.released {
            return Ok(());
        }
        match &self.backend {
            LockBackend::InMemory(table) => {
                Self::release_in_memory(table, &self.key, self.token);
            }
            #[cfg(feature = "redis")]
            LockBackend::Redis(manager) => {
                manager.release_token(&self.key, self.token).await?;
            }
            #[cfg(feature = "mongodb")]
            LockBackend::Mongo(manager) => {
                manager.release_token(&self.key, self.token).await?;
            }
        }
        // Disarm `Drop` only once the backend has actually released the lock.
        // If the call above fails, or this future is dropped while awaiting
        // it, the flag stays unset and `Drop` is not silently skipped for a
        // lock that is still held until its TTL.
        self.released = true;
        Ok(())
    }

    /// Removes `key` from `table`, but only when the stored holder token equals
    /// `token`; an entry belonging to a newer holder is left untouched.
    ///
    /// A poisoned mutex is recovered rather than propagated as a panic.
    fn release_in_memory(table: &LockTable, key: &str, token: LockToken) {
        let mut entries = table.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
        let held_by_us = entries.get(key).is_some_and(|(holder, _)| *holder == token);
        if held_by_us {
            entries.remove(key);
        }
    }
}
impl Drop for LockGuard {
    /// Best-effort cleanup for a guard that was not released explicitly.
    ///
    /// In-memory locks are freed on the spot. Redis and MongoDB locks are not
    /// touched here: only a debug event is recorded and the lock is left to
    /// expire via its TTL.
    fn drop(&mut self) {
        // Flip the flag and look at its previous value in one step.
        let was_released = std::mem::replace(&mut self.released, true);
        if was_released {
            return;
        }
        match &self.backend {
            LockBackend::InMemory(table) => Self::release_in_memory(table, &self.key, self.token),
            #[cfg(feature = "redis")]
            LockBackend::Redis(_) => {
                tracing::debug!(
                    key = %self.key,
                    "redis lock guard dropped without explicit release; relying on TTL expiry"
                );
            }
            #[cfg(feature = "mongodb")]
            LockBackend::Mongo(_) => {
                tracing::debug!(
                    key = %self.key,
                    "mongodb lock guard dropped without explicit release; relying on TTL expiry"
                );
            }
        }
    }
}
/// [`LockManager`] that keeps its locks in a mutex-protected hash map inside
/// the current process and judges expiry against an injected [`Clock`].
pub struct InMemoryLockManager {
// Lock table; shared (via `Arc`) with every guard this manager hands out.
table: Arc<LockTable>,
// Time source used to compute and check expiry timestamps.
clock: Arc<dyn Clock>,
}
impl InMemoryLockManager {
pub fn new(clock: Arc<dyn Clock>) -> Self {
Self { table: Arc::new(Mutex::new(HashMap::new())), clock }
}
}
impl Default for InMemoryLockManager {
fn default() -> Self {
Self::new(Arc::new(SystemClock))
}
}
#[async_trait]
impl LockManager for InMemoryLockManager {
    /// Tries to take `key` in the in-memory table for at most `ttl`.
    ///
    /// Returns `Ok(None)` while an existing holder's expiry is still in the
    /// future (`now < expiry`). Otherwise the entry is created or overwritten
    /// with a fresh token and a guard is returned.
    ///
    /// # Errors
    ///
    /// Returns `DataError::LockHeld` carrying an "invalid TTL" message when
    /// `now + ttl` cannot be represented.
    async fn acquire(&self, key: &str, ttl: Duration) -> Result<Option<LockGuard>, DataError> {
        let now = self.clock.now();
        let expires_at = match now.checked_add(ttl) {
            Some(deadline) => deadline,
            None => return Err(DataError::LockHeld(format!("invalid TTL for lock '{key}'"))),
        };

        // Keep the mutex guard inside this scope so it is gone before the
        // `LockGuard` is built.
        let token = {
            let mut table = self.table.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
            if let Some((_, holder_expiry)) = table.get(key) {
                if now < *holder_expiry {
                    // Somebody else still holds an unexpired lock.
                    return Ok(None);
                }
            }
            let token = LockToken::new();
            table.insert(key.to_owned(), (token, expires_at));
            token
        };

        let guard = LockGuard::in_memory(key.to_owned(), token, Arc::clone(&self.table));
        Ok(Some(guard))
    }
}
// Tests for the in-memory lock manager, driven by a manually advanced `FixedClock`.
#[cfg(test)]
mod tests {
use super::*;
use klauthed_core::time::FixedClock;
// Builds an in-memory manager that reads time from the given fixed clock.
fn manager_with(clock: Arc<FixedClock>) -> InMemoryLockManager {
InMemoryLockManager::new(clock)
}
// While the first guard is alive and unexpired, a second acquire yields `None`.
#[tokio::test]
async fn second_acquire_is_blocked_while_held() {
let clock = Arc::new(FixedClock::at_unix_millis(0));
let locks = manager_with(clock);
let first =
locks.acquire("k", Duration::seconds(30)).await.unwrap().expect("first acquire wins");
assert_eq!(first.key(), "k");
assert!(locks.acquire("k", Duration::seconds(30)).await.unwrap().is_none());
}
// Dropping the guard (end of the inner scope) frees the lock without any clock movement.
#[tokio::test]
async fn lock_releases_on_drop() {
let clock = Arc::new(FixedClock::at_unix_millis(0));
let locks = manager_with(clock);
{
let _guard = locks.acquire("k", Duration::seconds(30)).await.unwrap().unwrap();
assert!(locks.acquire("k", Duration::seconds(30)).await.unwrap().is_none());
}
assert!(locks.acquire("k", Duration::seconds(30)).await.unwrap().is_some());
}
// `release()` frees the lock so that it can be acquired again immediately.
#[tokio::test]
async fn explicit_release_frees_the_lock() {
let clock = Arc::new(FixedClock::at_unix_millis(0));
let locks = manager_with(clock);
let guard = locks.acquire("k", Duration::seconds(30)).await.unwrap().unwrap();
guard.release().await.unwrap();
assert!(locks.acquire("k", Duration::seconds(30)).await.unwrap().is_some());
}
// The guard is leaked with `mem::forget` so neither `Drop` nor `release` runs;
// only TTL expiry can free the lock: still held at t=5s, free at t=11s (TTL 10s).
#[tokio::test]
async fn lock_expires_after_ttl() {
let clock = Arc::new(FixedClock::at_unix_millis(0));
let locks = manager_with(Arc::clone(&clock));
let guard = locks.acquire("k", Duration::seconds(10)).await.unwrap().unwrap();
std::mem::forget(guard);
clock.advance(Duration::seconds(5));
assert!(locks.acquire("k", Duration::seconds(10)).await.unwrap().is_none());
clock.advance(Duration::seconds(6));
assert!(locks.acquire("k", Duration::seconds(10)).await.unwrap().is_some());
}
// After `stale` has expired and `fresh` has taken over the key, dropping `stale`
// must not remove `fresh`'s entry (its token does not match); dropping `fresh` does.
#[tokio::test]
async fn stale_guard_release_does_not_steal_new_holder() {
let clock = Arc::new(FixedClock::at_unix_millis(0));
let locks = manager_with(Arc::clone(&clock));
let stale = locks.acquire("k", Duration::seconds(10)).await.unwrap().unwrap();
clock.advance(Duration::seconds(11));
let fresh = locks.acquire("k", Duration::seconds(10)).await.unwrap().unwrap();
drop(stale);
assert!(locks.acquire("k", Duration::seconds(10)).await.unwrap().is_none());
drop(fresh);
assert!(locks.acquire("k", Duration::seconds(10)).await.unwrap().is_some());
}
}