use std::time::Duration;
use tokio::time::sleep;
use zeph_common::SessionId;
use zeph_db::{DbPool, query, query_scalar, sql};
use crate::error::MemoryError;
/// Lifetime of a freshly acquired or refreshed lock, in seconds (two minutes).
const LOCK_TTL_SECS: i64 = 2 * 60;
/// Extra acquisition attempts made by `try_acquire` after the first one fails.
const MAX_RETRIES: u32 = 3;
/// Initial back-off between acquisition attempts, in milliseconds; doubled after each failure.
const BASE_BACKOFF_MS: u64 = 50;
/// Session-scoped locks on named entities, stored in the `entity_advisory_locks` table.
///
/// Each row records which session holds an entity and until when. Another session may
/// take the lock over once its `expires_at` lies in the past; the owning session may
/// re-acquire (refresh) its own lock at any time.
pub struct EntityLockManager {
    /// Pool used for every lock query.
    pool: DbPool,
    /// Identity written into, and matched against, the `session_id` column.
    session_id: SessionId,
}

impl EntityLockManager {
    /// Creates a manager that acquires and releases locks on behalf of `session_id`.
    #[must_use]
    pub fn new(pool: DbPool, session_id: impl Into<SessionId>) -> Self {
        Self {
            pool,
            session_id: session_id.into(),
        }
    }

    /// Tries to acquire (or refresh) the lock on `entity_name` for this session.
    ///
    /// Makes up to `MAX_RETRIES + 1` attempts, sleeping between them with an exponential
    /// back-off that starts at `BASE_BACKOFF_MS` and doubles each time. A successful
    /// attempt sets the lock to expire `LOCK_TTL_SECS` seconds from now.
    ///
    /// Returns `Ok(true)` if this session now holds the lock, or `Ok(false)` if another
    /// session still held an unexpired lock on the final attempt.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError`] if a database query fails; retrying stops at the first error.
    pub async fn try_acquire(&self, entity_name: &str) -> Result<bool, MemoryError> {
        for attempt in 0..MAX_RETRIES {
            if self.try_acquire_once(entity_name).await? {
                return Ok(true);
            }
            // BASE_BACKOFF_MS * 2^attempt
            sleep(Duration::from_millis(BASE_BACKOFF_MS << attempt)).await;
        }
        // Last attempt: no sleep afterwards, its outcome is the final answer.
        self.try_acquire_once(entity_name).await
    }

    /// Performs one atomic acquisition attempt without retrying.
    ///
    /// The upsert inserts a new lock row, or overwrites an existing one only when it has
    /// expired or already belongs to this session. When the conflict `WHERE` clause
    /// rejects the update no row is returned; `fetch_optional` then yields `None`, which
    /// is mapped to `false`.
    async fn try_acquire_once(&self, entity_name: &str) -> Result<bool, MemoryError> {
        let acquired: bool = query_scalar(sql!(
            "INSERT INTO entity_advisory_locks (entity_name, session_id, acquired_at, expires_at)
            VALUES (?, ?, datetime('now'), datetime('now', ? || ' seconds'))
            ON CONFLICT(entity_name) DO UPDATE SET
                session_id = excluded.session_id,
                acquired_at = excluded.acquired_at,
                expires_at = excluded.expires_at
            WHERE
                -- reclaim if expired
                entity_advisory_locks.expires_at < datetime('now')
                OR
                -- refresh if same session
                entity_advisory_locks.session_id = excluded.session_id
            RETURNING (session_id = ?) AS acquired"
        ))
        .bind(entity_name)
        .bind(self.session_id.as_str())
        .bind(LOCK_TTL_SECS.to_string())
        .bind(self.session_id.as_str())
        .fetch_optional(self.pool())
        .await?
        .unwrap_or(false);
        Ok(acquired)
    }

    /// Pushes the expiry of this session's lock on `entity_name` back by `extra_secs`.
    ///
    /// The extension is measured from the later of the current expiry and now. A lock
    /// that has lapsed but was not yet reclaimed by another session therefore becomes
    /// valid again for `extra_secs`, instead of being "extended" to a time that is still
    /// in the past.
    ///
    /// Returns `Ok(true)` if this session owned a row for `entity_name` and it was
    /// updated, `Ok(false)` otherwise (no lock, or held by another session).
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError`] if the update query fails.
    pub async fn extend_lock(
        &self,
        entity_name: &str,
        extra_secs: i64,
    ) -> Result<bool, MemoryError> {
        let affected = query(sql!(
            "UPDATE entity_advisory_locks
            SET expires_at = datetime(
                CASE WHEN expires_at < datetime('now') THEN datetime('now') ELSE expires_at END,
                ? || ' seconds'
            )
            WHERE entity_name = ? AND session_id = ?"
        ))
        .bind(extra_secs.to_string())
        .bind(entity_name)
        .bind(self.session_id.as_str())
        .execute(self.pool())
        .await?
        .rows_affected();
        Ok(affected > 0)
    }

    /// Releases this session's lock on `entity_name`.
    ///
    /// Does nothing if the lock does not exist or is held by another session.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError`] if the delete query fails.
    pub async fn release(&self, entity_name: &str) -> Result<(), MemoryError> {
        query(sql!(
            "DELETE FROM entity_advisory_locks
            WHERE entity_name = ? AND session_id = ?"
        ))
        .bind(entity_name)
        .bind(self.session_id.as_str())
        .execute(self.pool())
        .await?;
        Ok(())
    }

    /// Releases every lock held by this session; other sessions' locks are untouched.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError`] if the delete query fails.
    pub async fn release_all(&self) -> Result<(), MemoryError> {
        query(sql!(
            "DELETE FROM entity_advisory_locks WHERE session_id = ?"
        ))
        .bind(self.session_id.as_str())
        .execute(self.pool())
        .await?;
        Ok(())
    }

    /// Borrows the underlying database pool.
    fn pool(&self) -> &DbPool {
        &self.pool
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::DbStore;

    /// Builds a manager backed by a fresh single-connection in-memory store.
    async fn make_lock_manager(session_id: &str) -> EntityLockManager {
        let store = DbStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory store");
        EntityLockManager::new(store.pool().clone(), session_id)
    }

    /// Builds two managers for different sessions that share the same pool, so they
    /// contend for the same lock rows.
    async fn make_shared_managers(
        session_a: &str,
        session_b: &str,
    ) -> (EntityLockManager, EntityLockManager) {
        let store = DbStore::with_pool_size(":memory:", 2)
            .await
            .expect("in-memory store");
        let pool = store.pool().clone();
        (
            EntityLockManager::new(pool.clone(), session_a),
            EntityLockManager::new(pool, session_b),
        )
    }

    #[tokio::test]
    async fn try_acquire_succeeds_on_first_call() {
        let mgr = make_lock_manager("session-a").await;
        let acquired = mgr.try_acquire("entity::Foo").await.expect("try_acquire");
        assert!(acquired);
    }

    #[tokio::test]
    async fn try_acquire_same_session_refresh_succeeds() {
        let mgr = make_lock_manager("session-a").await;
        assert!(mgr.try_acquire("entity::Foo").await.expect("first"));
        assert!(mgr.try_acquire("entity::Foo").await.expect("second"));
    }

    #[tokio::test]
    async fn try_acquire_fails_when_held_by_different_session() {
        let (a, b) = make_shared_managers("session-a", "session-b").await;
        assert!(a.try_acquire("entity::Foo").await.expect("a acquires"));
        let acquired = b.try_acquire("entity::Foo").await.expect("b tries");
        assert!(
            !acquired,
            "session-b should not acquire a lock held by session-a"
        );
    }

    #[tokio::test]
    async fn expired_lock_is_reclaimed_by_new_session() {
        let store = DbStore::with_pool_size(":memory:", 2)
            .await
            .expect("in-memory store");
        let pool = store.pool().clone();
        let b = EntityLockManager::new(pool.clone(), "session-b");
        zeph_db::query(zeph_db::sql!(
            "INSERT INTO entity_advisory_locks (entity_name, session_id, acquired_at, expires_at)
            VALUES ('entity::Bar', 'session-a', datetime('now', '-200 seconds'), datetime('now', '-80 seconds'))"
        ))
        .execute(&pool)
        .await
        .expect("insert expired lock");
        let acquired = b.try_acquire("entity::Bar").await.expect("try_acquire");
        assert!(acquired, "session-b should reclaim an expired lock");
    }

    #[tokio::test]
    async fn release_clears_the_lock() {
        let (a, b) = make_shared_managers("session-a", "session-b").await;
        assert!(a.try_acquire("entity::Baz").await.expect("acquire"));
        a.release("entity::Baz").await.expect("release");
        let acquired = b.try_acquire("entity::Baz").await.expect("b reacquire");
        assert!(acquired);
    }

    #[tokio::test]
    async fn release_is_noop_for_wrong_session() {
        let (a, b) = make_shared_managers("session-a", "session-b").await;
        assert!(a.try_acquire("entity::Qux").await.expect("a acquires"));
        b.release("entity::Qux").await.expect("release noop");
        let acquired = b.try_acquire("entity::Qux").await.expect("b tries");
        assert!(!acquired);
    }

    #[tokio::test]
    async fn release_all_removes_all_session_locks() {
        // A second session is required: the owning session could re-acquire via the
        // same-session refresh path even if `release_all` did nothing.
        let (a, b) = make_shared_managers("session-a", "session-b").await;
        assert!(a.try_acquire("entity::One").await.expect("one"));
        assert!(a.try_acquire("entity::Two").await.expect("two"));
        a.release_all().await.expect("release_all");
        assert!(b.try_acquire("entity::One").await.expect("b one"));
        assert!(b.try_acquire("entity::Two").await.expect("b two"));
    }

    #[tokio::test]
    async fn release_all_keeps_other_session_locks() {
        let (a, b) = make_shared_managers("session-a", "session-b").await;
        assert!(b.try_acquire("entity::Kept").await.expect("b acquires"));
        a.release_all().await.expect("release_all");
        let acquired = a.try_acquire("entity::Kept").await.expect("a tries");
        assert!(
            !acquired,
            "release_all must only remove locks owned by the calling session"
        );
    }

    #[tokio::test]
    async fn extend_lock_returns_true_for_owner() {
        let mgr = make_lock_manager("session-a").await;
        assert!(mgr.try_acquire("entity::Ext").await.expect("acquire"));
        let extended = mgr.extend_lock("entity::Ext", 60).await.expect("extend");
        assert!(extended);
    }

    #[tokio::test]
    async fn extend_lock_returns_false_for_non_owner() {
        let (a, b) = make_shared_managers("session-a", "session-b").await;
        assert!(a.try_acquire("entity::Ext2").await.expect("a acquires"));
        let extended = b.extend_lock("entity::Ext2", 60).await.expect("b extend");
        assert!(
            !extended,
            "non-owner session should not be able to extend lock"
        );
    }
}