use crate::{TaskId, TaskResult};
use async_trait::async_trait;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Opaque proof of lock ownership returned by a successful lock acquisition.
///
/// Internally this is the string form of a freshly generated random (v4) UUID.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LockToken(String);

impl LockToken {
    /// Creates a new, unique token from a random v4 UUID.
    pub fn generate() -> Self {
        let id = Uuid::new_v4();
        LockToken(id.to_string())
    }

    /// Borrows the token's string representation, as stored by the lock backends.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// Mutual-exclusion primitive keyed by [`TaskId`].
///
/// A successful [`acquire`](TaskLock::acquire) yields a [`LockToken`]. That token must be
/// presented to [`release`](TaskLock::release) and [`extend`](TaskLock::extend), so that only
/// the current holder can modify the lock.
#[async_trait]
pub trait TaskLock: Send + Sync {
    /// Tries to take the lock for `task_id` for the duration `ttl`.
    ///
    /// Returns `Ok(Some(token))` on success and `Ok(None)` when the lock could not be taken,
    /// for example because it is already held.
    async fn acquire(&self, task_id: TaskId, ttl: Duration) -> TaskResult<Option<LockToken>>;

    /// Releases the lock for `task_id` if it is held with `token`.
    ///
    /// Returns `Ok(true)` if the lock was removed and `Ok(false)` otherwise.
    async fn release(&self, task_id: TaskId, token: &LockToken) -> TaskResult<bool>;

    /// Reports whether `task_id` is currently locked by any holder.
    async fn is_locked(&self, task_id: TaskId) -> TaskResult<bool>;

    /// Resets the lifetime of the lock for `task_id` to `ttl`, if it is still held with `token`.
    ///
    /// Returns `Ok(true)` when the lock was extended and `Ok(false)` otherwise.
    ///
    /// The default implementation never extends: it leaves the lock untouched and returns
    /// `Ok(false)`. Extension cannot be built safely from the other trait methods.
    /// - A release-then-acquire sequence lets another worker grab the lock in between.
    /// - The re-acquire would mint a new token the caller never receives, leaving `token`
    ///   stale while still reporting success.
    ///
    /// Backends that support extension must override this with an atomic
    /// check-token-and-update operation.
    async fn extend(
        &self,
        _task_id: TaskId,
        _token: &LockToken,
        _ttl: Duration,
    ) -> TaskResult<bool> {
        Ok(false)
    }
}
/// In-process [`TaskLock`] backed by a shared map from task id to `(expiry, token)`.
///
/// Expiry values are Unix timestamps in milliseconds taken from `chrono::Utc::now()`.
#[derive(Default)]
pub struct MemoryTaskLock {
    // task id -> (expiry as Unix milliseconds, token string)
    locks: Arc<RwLock<std::collections::HashMap<TaskId, (i128, String)>>>,
}

impl MemoryTaskLock {
    /// Creates an empty lock table.
    pub fn new() -> Self {
        Self::default()
    }

    /// Removes every entry whose expiry is not strictly in the future.
    async fn cleanup_expired(&self) {
        let mut table = self.locks.write().await;
        let cutoff = chrono::Utc::now().timestamp_millis() as i128;
        table.retain(|_, entry| entry.0 > cutoff);
    }
}
#[async_trait]
impl TaskLock for MemoryTaskLock {
    /// Takes the lock for `task_id` unless an unexpired entry already exists.
    ///
    /// Returns `Ok(None)` when the lock is held or when `ttl` is zero. A TTL with a
    /// sub-millisecond remainder is rounded up to the next millisecond. Truncating it instead
    /// would let e.g. a 500 µs TTL hand out a token for an entry that is already expired.
    async fn acquire(&self, task_id: TaskId, ttl: Duration) -> TaskResult<Option<LockToken>> {
        if ttl.is_zero() {
            return Ok(None);
        }
        // Drop expired entries from the whole map before taking the insert guard.
        self.cleanup_expired().await;
        let mut locks = self.locks.write().await;
        let now = chrono::Utc::now().timestamp_millis() as i128;
        // Round up so any non-zero TTL yields an expiry strictly after `now`. The cast is
        // lossless: `Duration::MAX` is about 1.8e22 ms, far below `i128::MAX`.
        let expiry = now + ttl.as_nanos().div_ceil(1_000_000) as i128;
        // Re-check under this guard; the cleanup above ran under a separate one.
        if let Some(&(existing_expiry, _)) = locks.get(&task_id)
            && existing_expiry > now
        {
            return Ok(None);
        }
        let token = LockToken::generate();
        locks.insert(task_id, (expiry, token.as_str().to_string()));
        Ok(Some(token))
    }

    /// Releases the lock for `task_id` if `token` matches the stored token.
    ///
    /// Returns `Ok(true)` only if the lock was still live. A matching but already-expired
    /// entry is removed as well, but it is reported as `Ok(false)` because the caller no
    /// longer held the lock. This mirrors the Redis backend, whose compare-and-delete script
    /// returns 0 once the key has expired.
    async fn release(&self, task_id: TaskId, token: &LockToken) -> TaskResult<bool> {
        let mut locks = self.locks.write().await;
        let now = chrono::Utc::now().timestamp_millis() as i128;
        let still_held = match locks.get(&task_id) {
            Some((expiry, stored_token)) if stored_token == token.as_str() => *expiry > now,
            _ => return Ok(false),
        };
        locks.remove(&task_id);
        Ok(still_held)
    }

    /// Returns whether an unexpired entry exists for `task_id`.
    async fn is_locked(&self, task_id: TaskId) -> TaskResult<bool> {
        self.cleanup_expired().await;
        let locks = self.locks.read().await;
        let now = chrono::Utc::now().timestamp_millis() as i128;
        Ok(locks
            .get(&task_id)
            .map(|(expiry, _)| *expiry > now)
            .unwrap_or(false))
    }

    /// Atomically resets the expiry of `task_id` to `now + ttl`.
    ///
    /// This happens only if the entry is unexpired and `token` matches; the check and the
    /// update run under one write guard. Returns `Ok(false)` for a zero `ttl`, because
    /// applying it would expire the lock on the spot while reporting success. A
    /// sub-millisecond remainder is rounded up, as in `acquire`.
    async fn extend(&self, task_id: TaskId, token: &LockToken, ttl: Duration) -> TaskResult<bool> {
        if ttl.is_zero() {
            return Ok(false);
        }
        let mut locks = self.locks.write().await;
        let now = chrono::Utc::now().timestamp_millis() as i128;
        if let Some((expiry, stored_token)) = locks.get_mut(&task_id)
            && *expiry > now
            && stored_token.as_str() == token.as_str()
        {
            *expiry = now + ttl.as_nanos().div_ceil(1_000_000) as i128;
            return Ok(true);
        }
        Ok(false)
    }
}
/// [`TaskLock`] backed by Redis: each task maps to one key holding the owner's token.
#[cfg(feature = "redis-backend")]
pub struct RedisTaskLock {
    connection: Arc<redis::aio::ConnectionManager>,
    // Prepended to every key produced by `lock_key`.
    key_prefix: String,
}

#[cfg(feature = "redis-backend")]
impl RedisTaskLock {
    /// Connects to `redis_url` using the default key prefix `"reinhardt:locks:"`.
    ///
    /// # Errors
    ///
    /// Returns the underlying [`redis::RedisError`] if the client cannot be opened from
    /// `redis_url` or the connection manager cannot be created.
    pub async fn new(redis_url: &str) -> Result<Self, redis::RedisError> {
        Self::with_prefix(redis_url, String::from("reinhardt:locks:")).await
    }

    /// Connects to `redis_url` and namespaces every lock key under `key_prefix`.
    ///
    /// # Errors
    ///
    /// Same as [`RedisTaskLock::new`].
    pub async fn with_prefix(
        redis_url: &str,
        key_prefix: String,
    ) -> Result<Self, redis::RedisError> {
        let client = redis::Client::open(redis_url)?;
        let manager = redis::aio::ConnectionManager::new(client).await?;
        let connection = Arc::new(manager);
        Ok(Self {
            connection,
            key_prefix,
        })
    }

    /// Builds the Redis key for `task_id`, shaped `<key_prefix>task:<task_id>`.
    fn lock_key(&self, task_id: TaskId) -> String {
        let prefix = &self.key_prefix;
        format!("{prefix}task:{task_id}")
    }
}
/// Converts `ttl` into the whole-millisecond value sent to Redis `PX` / `PEXPIRE`.
///
/// A sub-millisecond remainder is rounded *up*, so every non-zero TTL maps to at least 1 ms.
/// Truncation would turn e.g. 500 µs into `0`, which causes two failures:
/// - `SET ... PX 0` is rejected by Redis as an invalid expire time.
/// - `PEXPIRE key 0` deletes the key outright, so `extend` would destroy the lock while
///   reporting success.
///
/// # Errors
///
/// Returns `TaskError::ExecutionFailed` if `ttl` is zero or the rounded value does not fit
/// in an `i64`.
#[cfg(feature = "redis-backend")]
fn validate_ttl_ms(ttl: Duration) -> TaskResult<i64> {
    use crate::TaskError;
    if ttl.is_zero() {
        return Err(TaskError::ExecutionFailed(
            "TTL must be greater than zero".to_string(),
        ));
    }
    let millis = ttl.as_nanos().div_ceil(1_000_000);
    i64::try_from(millis).map_err(|_| {
        TaskError::ExecutionFailed(format!("TTL overflow: {} ms exceeds i64::MAX", millis))
    })
}
#[cfg(feature = "redis-backend")]
#[async_trait]
impl TaskLock for RedisTaskLock {
    /// Issues `SET key token PX ttl NX`.
    ///
    /// A nil reply means the key already exists, so the lock is not granted.
    async fn acquire(&self, task_id: TaskId, ttl: Duration) -> TaskResult<Option<LockToken>> {
        use crate::TaskError;
        let ttl_ms = validate_ttl_ms(ttl)?;
        let key = self.lock_key(task_id);
        let mut conn = (*self.connection).clone();
        let token = LockToken::generate();
        let reply: Option<String> = redis::cmd("SET")
            .arg(&key)
            .arg(token.as_str())
            .arg("PX")
            .arg(ttl_ms)
            .arg("NX")
            .query_async(&mut conn)
            .await
            .map_err(|e| TaskError::ExecutionFailed(format!("Failed to acquire lock: {}", e)))?;
        Ok(reply.map(|_| token))
    }

    /// Deletes the key only if it still stores `token`.
    ///
    /// Returns `Ok(true)` when the script reports exactly one deleted key.
    async fn release(&self, task_id: TaskId, token: &LockToken) -> TaskResult<bool> {
        use crate::TaskError;
        let key = self.lock_key(task_id);
        let mut conn = (*self.connection).clone();
        // Compare-and-delete in one Lua script so the check and the DEL cannot interleave
        // with other clients.
        let compare_and_delete = redis::Script::new(
            "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end",
        );
        let deleted: i32 = compare_and_delete
            .key(&key)
            .arg(token.as_str())
            .invoke_async(&mut conn)
            .await
            .map_err(|e| TaskError::ExecutionFailed(format!("Failed to release lock: {}", e)))?;
        Ok(deleted == 1)
    }

    /// Reports whether the lock key exists (`EXISTS`).
    async fn is_locked(&self, task_id: TaskId) -> TaskResult<bool> {
        use crate::TaskError;
        use redis::AsyncCommands;
        let key = self.lock_key(task_id);
        let mut conn = (*self.connection).clone();
        let present: bool = conn
            .exists(&key)
            .await
            .map_err(|e| TaskError::ExecutionFailed(format!("Failed to check lock: {}", e)))?;
        Ok(present)
    }

    /// Resets the key's TTL via `PEXPIRE`, but only if the key still stores `token`.
    async fn extend(&self, task_id: TaskId, token: &LockToken, ttl: Duration) -> TaskResult<bool> {
        use crate::TaskError;
        let ttl_ms = validate_ttl_ms(ttl)?;
        let key = self.lock_key(task_id);
        let mut conn = (*self.connection).clone();
        // Compare-and-pexpire in one Lua script for the same atomicity reason as `release`.
        let compare_and_pexpire = redis::Script::new(
            "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('pexpire', KEYS[1], ARGV[2]) else return 0 end",
        );
        let updated: i32 = compare_and_pexpire
            .key(&key)
            .arg(token.as_str())
            .arg(ttl_ms)
            .invoke_async(&mut conn)
            .await
            .map_err(|e| TaskError::ExecutionFailed(format!("Failed to extend lock: {}", e)))?;
        Ok(updated == 1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::time::Duration;

    /// TTL long enough that a lock never lapses during a test.
    const LONG_TTL: Duration = Duration::from_secs(60);
    /// TTL requested when a test extends an existing lock.
    const EXTENDED_TTL: Duration = Duration::from_secs(120);
    /// TTL short enough to lapse during `EXPIRY_WAIT`.
    const SHORT_TTL: Duration = Duration::from_millis(50);
    /// Sleep that comfortably outlasts `SHORT_TTL`.
    const EXPIRY_WAIT: Duration = Duration::from_millis(100);

    /// Acquires `task_id` on `locker`, panicking if the call errors or the lock is refused.
    async fn must_acquire(locker: &MemoryTaskLock, task_id: TaskId, ttl: Duration) -> LockToken {
        locker.acquire(task_id, ttl).await.unwrap().unwrap()
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_acquire() {
        let locker = MemoryTaskLock::new();
        let granted = locker.acquire(TaskId::new(), LONG_TTL).await.unwrap();
        assert!(granted.is_some());
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_already_locked() {
        let locker = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let _ = locker.acquire(task_id, LONG_TTL).await.unwrap();
        let second = locker.acquire(task_id, LONG_TTL).await.unwrap();
        assert!(second.is_none());
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_release() {
        let locker = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let token = must_acquire(&locker, task_id, LONG_TTL).await;
        assert!(locker.release(task_id, &token).await.unwrap());
        assert!(!locker.is_locked(task_id).await.unwrap());
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_release_wrong_token() {
        let locker = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let _ = locker.acquire(task_id, LONG_TTL).await.unwrap();
        let impostor = LockToken::generate();
        assert!(!locker.release(task_id, &impostor).await.unwrap());
        assert!(locker.is_locked(task_id).await.unwrap());
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_expiry() {
        let locker = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let _ = locker.acquire(task_id, SHORT_TTL).await.unwrap();
        tokio::time::sleep(EXPIRY_WAIT).await;
        assert!(!locker.is_locked(task_id).await.unwrap());
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_extend() {
        let locker = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let token = must_acquire(&locker, task_id, LONG_TTL).await;
        assert!(locker.extend(task_id, &token, EXTENDED_TTL).await.unwrap());
        assert!(locker.is_locked(task_id).await.unwrap());
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_extend_returns_false_for_unlocked_task() {
        let locker = MemoryTaskLock::new();
        let stray = LockToken::generate();
        let extended = locker
            .extend(TaskId::new(), &stray, EXTENDED_TTL)
            .await
            .unwrap();
        assert!(!extended);
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_extend_returns_false_for_expired_lock() {
        let locker = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let token = must_acquire(&locker, task_id, SHORT_TTL).await;
        tokio::time::sleep(EXPIRY_WAIT).await;
        assert!(!locker.extend(task_id, &token, EXTENDED_TTL).await.unwrap());
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_extend_returns_false_for_wrong_token() {
        let locker = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let _ = locker.acquire(task_id, LONG_TTL).await.unwrap();
        let impostor = LockToken::generate();
        assert!(!locker.extend(task_id, &impostor, EXTENDED_TTL).await.unwrap());
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_extend_is_atomic() {
        let locker = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let token = must_acquire(&locker, task_id, Duration::from_millis(200)).await;
        assert!(locker.extend(task_id, &token, LONG_TTL).await.unwrap());
        // After a successful extend the lock must still be held, so a rival is refused.
        let rival = locker.acquire(task_id, LONG_TTL).await.unwrap();
        assert!(rival.is_none());
    }
}