use crate::config::ScheduleError;
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A time-bounded lock on a scheduled task, held by a single owner.
///
/// Field order is part of the serialized form (derived `Serialize`/`Deserialize`),
/// so fields should not be reordered casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduleLock {
    /// Name of the task this lock guards.
    pub task_name: String,
    /// Identifier of the current holder; compared by exact string equality.
    pub owner: String,
    /// When the lock was first created; renewals do not change it.
    pub acquired_at: DateTime<Utc>,
    /// Deadline of the lease, checked by `is_expired`; renewals move it forward.
    pub expires_at: DateTime<Utc>,
    /// Number of successful renewals since creation.
    pub renewal_count: u32,
}
impl ScheduleLock {
pub fn new(task_name: String, owner: String, ttl_seconds: u64) -> Self {
let now = Utc::now();
Self {
task_name,
owner,
acquired_at: now,
expires_at: now + Duration::seconds(ttl_seconds as i64),
renewal_count: 0,
}
}
pub fn is_expired(&self) -> bool {
Utc::now() > self.expires_at
}
pub fn is_owned_by(&self, owner: &str) -> bool {
self.owner == owner
}
pub fn renew(&mut self, ttl_seconds: u64) -> Result<(), ScheduleError> {
if self.is_expired() {
return Err(ScheduleError::Invalid(
"Cannot renew expired lock".to_string(),
));
}
self.expires_at = Utc::now() + Duration::seconds(ttl_seconds as i64);
self.renewal_count += 1;
Ok(())
}
pub fn ttl(&self) -> Duration {
self.expires_at - Utc::now()
}
pub fn age(&self) -> Duration {
Utc::now() - self.acquired_at
}
}
/// In-process registry of [`ScheduleLock`]s keyed by task name.
///
/// Not internally synchronized: every mutating operation takes `&mut self`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LockManager {
    /// Locks keyed by task name; may still contain expired entries until
    /// `cleanup_expired` runs.
    pub(crate) locks: HashMap<String, ScheduleLock>,
    /// TTL in seconds used when a caller passes `None` for the TTL.
    pub(crate) default_ttl: u64,
}
impl LockManager {
pub fn new(default_ttl: u64) -> Self {
Self {
locks: HashMap::new(),
default_ttl,
}
}
pub fn try_acquire(
&mut self,
task_name: &str,
owner: &str,
ttl: Option<u64>,
) -> Result<bool, ScheduleError> {
self.cleanup_expired();
if let Some(existing_lock) = self.locks.get(task_name) {
if !existing_lock.is_expired() {
if !existing_lock.is_owned_by(owner) {
return Ok(false);
}
return Ok(true);
}
}
let ttl_seconds = ttl.unwrap_or(self.default_ttl);
let lock = ScheduleLock::new(task_name.to_string(), owner.to_string(), ttl_seconds);
self.locks.insert(task_name.to_string(), lock);
Ok(true)
}
pub fn release(&mut self, task_name: &str, owner: &str) -> Result<bool, ScheduleError> {
if let Some(lock) = self.locks.get(task_name) {
if lock.is_owned_by(owner) {
self.locks.remove(task_name);
return Ok(true);
}
}
Ok(false)
}
pub fn renew(
&mut self,
task_name: &str,
owner: &str,
ttl: Option<u64>,
) -> Result<bool, ScheduleError> {
if let Some(lock) = self.locks.get_mut(task_name) {
if lock.is_owned_by(owner) && !lock.is_expired() {
let ttl_seconds = ttl.unwrap_or(self.default_ttl);
lock.renew(ttl_seconds)?;
return Ok(true);
}
}
Ok(false)
}
pub fn is_locked(&self, task_name: &str) -> bool {
if let Some(lock) = self.locks.get(task_name) {
!lock.is_expired()
} else {
false
}
}
pub fn get_lock(&self, task_name: &str) -> Option<&ScheduleLock> {
self.locks.get(task_name)
}
pub fn cleanup_expired(&mut self) {
self.locks.retain(|_, lock| !lock.is_expired());
}
pub fn get_active_locks(&self) -> Vec<&ScheduleLock> {
self.locks
.values()
.filter(|lock| !lock.is_expired())
.collect()
}
pub fn release_all(&mut self) {
self.locks.clear();
}
}
impl Default for LockManager {
fn default() -> Self {
Self::new(300) }
}
use async_trait::async_trait;
use celers_core::lock::DistributedLockBackend;
use std::sync::Arc;
use std::time::{Duration as StdDuration, Instant};
use tokio::sync::RwLock;
/// One lease in the in-memory backend's lock table.
#[derive(Debug, Clone)]
struct LockEntry {
    /// Identifier of the holder; matched by exact string equality.
    owner: String,
    /// Monotonic timestamp of acquisition, reset on each successful renewal.
    acquired_at: Instant,
    /// Lease length measured from `acquired_at`.
    ttl: StdDuration,
}
impl LockEntry {
    /// Reports whether the lease has run out.
    ///
    /// The comparison is inclusive, so an entry with a zero TTL counts as
    /// expired immediately.
    fn is_expired(&self) -> bool {
        let held_for = self.acquired_at.elapsed();
        held_for >= self.ttl
    }
}
/// A process-local lock backend storing leases in a shared hash map.
///
/// Cloning is cheap and every clone shares the same lock table through the
/// inner `Arc`. Locks are therefore only visible within this process.
#[derive(Debug, Clone)]
pub struct InMemoryLockBackend {
    /// Leases keyed by lock key; may contain expired entries until they are purged.
    locks: Arc<RwLock<HashMap<String, LockEntry>>>,
}
impl InMemoryLockBackend {
pub fn new() -> Self {
Self {
locks: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn active_lock_count(&self) -> usize {
let locks = self.locks.read().await;
locks.values().filter(|e| !e.is_expired()).count()
}
}
impl Default for InMemoryLockBackend {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl DistributedLockBackend for InMemoryLockBackend {
    /// Acquires `key` for `owner` for `ttl_secs` seconds.
    ///
    /// Expired entries are purged first. Returns `Ok(true)` if the lock was taken,
    /// or if it is already held by `owner`; in that case the lease is *not*
    /// extended. Returns `Ok(false)` if a different owner holds a live lease.
    /// A `ttl_secs` of 0 produces an entry that is already expired.
    async fn try_acquire(
        &self,
        key: &str,
        owner: &str,
        ttl_secs: u64,
    ) -> celers_core::error::Result<bool> {
        let mut locks = self.locks.write().await;
        // Sweep stale leases while the write lock is held anyway.
        locks.retain(|_, entry| !entry.is_expired());
        if let Some(existing) = locks.get(key) {
            // Re-check: the entry may have expired since the sweep above.
            if !existing.is_expired() {
                // Re-entrant for the current holder, refused for anyone else.
                return Ok(existing.owner == owner);
            }
        }
        locks.insert(
            key.to_string(),
            LockEntry {
                owner: owner.to_string(),
                acquired_at: Instant::now(),
                ttl: StdDuration::from_secs(ttl_secs),
            },
        );
        Ok(true)
    }

    /// Releases `key` if `owner` holds a live lease on it.
    ///
    /// An expired entry owned by `owner` is removed as housekeeping, but this
    /// returns `Ok(false)`. That is consistent with `is_locked`/`owner`, which
    /// already treat the expired entry as unlocked and ownerless.
    async fn release(&self, key: &str, owner: &str) -> celers_core::error::Result<bool> {
        let mut locks = self.locks.write().await;
        let Some(entry) = locks.get(key) else {
            return Ok(false);
        };
        if entry.owner != owner {
            return Ok(false);
        }
        let was_live = !entry.is_expired();
        locks.remove(key);
        Ok(was_live)
    }

    /// Restarts the lease on `key` with a new `ttl_secs`, measured from now.
    ///
    /// Returns `Ok(false)` if the key is missing, expired, or owned by someone else.
    async fn renew(
        &self,
        key: &str,
        owner: &str,
        ttl_secs: u64,
    ) -> celers_core::error::Result<bool> {
        let mut locks = self.locks.write().await;
        match locks.get_mut(key) {
            Some(entry) if entry.owner == owner && !entry.is_expired() => {
                entry.acquired_at = Instant::now();
                entry.ttl = StdDuration::from_secs(ttl_secs);
                Ok(true)
            }
            _ => Ok(false),
        }
    }

    /// Returns whether `key` currently has an unexpired lease.
    async fn is_locked(&self, key: &str) -> celers_core::error::Result<bool> {
        let locks = self.locks.read().await;
        Ok(locks.get(key).is_some_and(|entry| !entry.is_expired()))
    }

    /// Returns the holder of `key`'s lease, or `None` if it is missing or expired.
    async fn owner(&self, key: &str) -> celers_core::error::Result<Option<String>> {
        let locks = self.locks.read().await;
        Ok(locks
            .get(key)
            .filter(|entry| !entry.is_expired())
            .map(|entry| entry.owner.clone()))
    }

    /// Releases every live lease held by `owner` and returns how many were released.
    ///
    /// Expired entries, of any owner, are purged without being counted, because
    /// they were no longer held.
    async fn release_all(&self, owner: &str) -> celers_core::error::Result<u64> {
        let mut locks = self.locks.write().await;
        let mut released: u64 = 0;
        locks.retain(|_, entry| {
            if entry.is_expired() {
                return false;
            }
            if entry.owner == owner {
                released += 1;
                return false;
            }
            true
        });
        Ok(released)
    }
}
#[cfg(test)]
mod in_memory_lock_tests {
    //! Behavioral tests for `InMemoryLockBackend`. Each async call's result is
    //! projected with `.ok()` so a failure prints the actual value.
    use super::*;

    #[tokio::test]
    async fn test_acquire_and_release() {
        let backend = InMemoryLockBackend::new();
        assert_eq!(backend.try_acquire("task1", "owner1", 300).await.ok(), Some(true));
        assert_eq!(backend.is_locked("task1").await.ok(), Some(true));
        assert_eq!(
            backend.owner("task1").await.ok().flatten().as_deref(),
            Some("owner1")
        );
        assert_eq!(backend.release("task1", "owner1").await.ok(), Some(true));
        assert_eq!(backend.is_locked("task1").await.ok(), Some(false));
    }

    #[tokio::test]
    async fn test_acquire_fails_for_different_owner() {
        let backend = InMemoryLockBackend::new();
        assert_eq!(backend.try_acquire("task1", "owner1", 300).await.ok(), Some(true));
        assert_eq!(backend.try_acquire("task1", "owner2", 300).await.ok(), Some(false));
    }

    #[tokio::test]
    async fn test_same_owner_can_reacquire() {
        let backend = InMemoryLockBackend::new();
        assert_eq!(backend.try_acquire("task1", "owner1", 300).await.ok(), Some(true));
        assert_eq!(backend.try_acquire("task1", "owner1", 300).await.ok(), Some(true));
    }

    #[tokio::test]
    async fn test_release_wrong_owner_fails() {
        let backend = InMemoryLockBackend::new();
        let _ = backend.try_acquire("task1", "owner1", 300).await;
        assert_eq!(backend.release("task1", "owner2").await.ok(), Some(false));
        assert_eq!(backend.is_locked("task1").await.ok(), Some(true));
    }

    #[tokio::test]
    async fn test_renew() {
        let backend = InMemoryLockBackend::new();
        let _ = backend.try_acquire("task1", "owner1", 300).await;
        assert_eq!(backend.renew("task1", "owner1", 600).await.ok(), Some(true));
        assert_eq!(backend.renew("task1", "owner2", 600).await.ok(), Some(false));
    }

    #[tokio::test]
    async fn test_renew_nonexistent_fails() {
        let backend = InMemoryLockBackend::new();
        assert_eq!(backend.renew("nonexistent", "owner1", 600).await.ok(), Some(false));
    }

    #[tokio::test]
    async fn test_expiry() {
        let backend = InMemoryLockBackend::new();
        // A zero TTL yields a lease that is already expired.
        let _ = backend.try_acquire("task1", "owner1", 0).await;
        assert_eq!(backend.is_locked("task1").await.ok(), Some(false));
        assert_eq!(backend.try_acquire("task1", "owner2", 300).await.ok(), Some(true));
    }

    #[tokio::test]
    async fn test_release_all() {
        let backend = InMemoryLockBackend::new();
        for (key, holder) in [("task1", "owner1"), ("task2", "owner1"), ("task3", "owner2")] {
            let _ = backend.try_acquire(key, holder, 300).await;
        }
        assert_eq!(backend.release_all("owner1").await.ok(), Some(2));
        assert_eq!(backend.is_locked("task3").await.ok(), Some(true));
        assert_eq!(backend.is_locked("task1").await.ok(), Some(false));
    }

    #[tokio::test]
    async fn test_owner_of_nonexistent() {
        let backend = InMemoryLockBackend::new();
        assert_eq!(backend.owner("nonexistent").await.ok().flatten(), None);
    }

    #[tokio::test]
    async fn test_release_nonexistent() {
        let backend = InMemoryLockBackend::new();
        assert_eq!(backend.release("nonexistent", "owner1").await.ok(), Some(false));
    }

    #[tokio::test]
    async fn test_active_lock_count() {
        let backend = InMemoryLockBackend::new();
        let _ = backend.try_acquire("task1", "owner1", 300).await;
        let _ = backend.try_acquire("task2", "owner1", 300).await;
        assert_eq!(backend.active_lock_count().await, 2);
        let _ = backend.release("task1", "owner1").await;
        assert_eq!(backend.active_lock_count().await, 1);
    }

    #[tokio::test]
    async fn test_multiple_keys_independent() {
        let backend = InMemoryLockBackend::new();
        let _ = backend.try_acquire("task_a", "owner1", 300).await;
        let _ = backend.try_acquire("task_b", "owner2", 300).await;
        assert_eq!(backend.is_locked("task_a").await.ok(), Some(true));
        assert_eq!(backend.is_locked("task_b").await.ok(), Some(true));
        assert_eq!(
            backend.owner("task_a").await.ok().flatten().as_deref(),
            Some("owner1")
        );
        assert_eq!(
            backend.owner("task_b").await.ok().flatten().as_deref(),
            Some("owner2")
        );
    }
}