1use crate::{TaskId, TaskResult};
7use async_trait::async_trait;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct LockToken(String);
29
30impl LockToken {
31 pub fn generate() -> Self {
33 Self(Uuid::new_v4().to_string())
34 }
35
36 pub fn as_str(&self) -> &str {
38 &self.0
39 }
40}
41
42#[async_trait]
72pub trait TaskLock: Send + Sync {
73 async fn acquire(&self, task_id: TaskId, ttl: Duration) -> TaskResult<Option<LockToken>>;
78
79 async fn release(&self, task_id: TaskId, token: &LockToken) -> TaskResult<bool>;
84
85 async fn is_locked(&self, task_id: TaskId) -> TaskResult<bool>;
87
88 async fn extend(&self, task_id: TaskId, token: &LockToken, ttl: Duration) -> TaskResult<bool> {
94 if self.is_locked(task_id).await? {
97 let released = self.release(task_id, token).await?;
98 if !released {
99 return Ok(false);
101 }
102 self.acquire(task_id, ttl).await.map(|t| t.is_some())
103 } else {
104 Ok(false)
105 }
106 }
107}
108
109pub struct MemoryTaskLock {
136 locks: Arc<RwLock<std::collections::HashMap<TaskId, (i128, String)>>>,
138}
139
140impl MemoryTaskLock {
141 pub fn new() -> Self {
151 Self {
152 locks: Arc::new(RwLock::new(std::collections::HashMap::new())),
153 }
154 }
155
156 async fn cleanup_expired(&self) {
158 let mut locks = self.locks.write().await;
159 let now = chrono::Utc::now().timestamp_millis() as i128;
160 locks.retain(|_, (expiry, _)| *expiry > now);
161 }
162}
163
164impl Default for MemoryTaskLock {
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
170#[async_trait]
171impl TaskLock for MemoryTaskLock {
172 async fn acquire(&self, task_id: TaskId, ttl: Duration) -> TaskResult<Option<LockToken>> {
173 if ttl.is_zero() {
176 return Ok(None);
177 }
178
179 self.cleanup_expired().await;
180
181 let mut locks = self.locks.write().await;
182 let now = chrono::Utc::now().timestamp_millis() as i128;
183 let expiry = now + ttl.as_millis() as i128;
184
185 if let Some(&(existing_expiry, _)) = locks.get(&task_id)
186 && existing_expiry > now
187 {
188 return Ok(None);
189 }
190
191 let token = LockToken::generate();
192 locks.insert(task_id, (expiry, token.as_str().to_string()));
193 Ok(Some(token))
194 }
195
196 async fn release(&self, task_id: TaskId, token: &LockToken) -> TaskResult<bool> {
197 let mut locks = self.locks.write().await;
198 if let Some((_, stored_token)) = locks.get(&task_id)
199 && stored_token == token.as_str()
200 {
201 locks.remove(&task_id);
202 return Ok(true);
203 }
204 Ok(false)
205 }
206
207 async fn is_locked(&self, task_id: TaskId) -> TaskResult<bool> {
208 self.cleanup_expired().await;
209
210 let locks = self.locks.read().await;
211 let now = chrono::Utc::now().timestamp_millis() as i128;
212
213 Ok(locks
214 .get(&task_id)
215 .map(|(expiry, _)| *expiry > now)
216 .unwrap_or(false))
217 }
218
219 async fn extend(&self, task_id: TaskId, token: &LockToken, ttl: Duration) -> TaskResult<bool> {
225 let mut locks = self.locks.write().await;
226 let now = chrono::Utc::now().timestamp_millis() as i128;
227
228 if let Some((expiry, stored_token)) = locks.get_mut(&task_id)
229 && *expiry > now
230 && stored_token.as_str() == token.as_str()
231 {
232 *expiry = now + ttl.as_millis() as i128;
234 return Ok(true);
235 }
236
237 Ok(false)
238 }
239}
240
#[cfg(feature = "redis-backend")]
/// [`TaskLock`] backed by Redis.
///
/// Each lock is a Redis key whose value is the owner's token and whose Redis
/// TTL is the lock's expiry.
pub struct RedisTaskLock {
    /// Prepended verbatim to `task:{task_id}` to form every lock key.
    key_prefix: String,
    /// Shared connection manager; each command works on its own clone.
    connection: Arc<redis::aio::ConnectionManager>,
}
271
#[cfg(feature = "redis-backend")]
impl RedisTaskLock {
    /// Connects to `redis_url` and uses the default key prefix `reinhardt:locks:`.
    ///
    /// # Errors
    ///
    /// Propagates the [`redis::RedisError`] raised while opening the client or
    /// establishing the connection manager.
    pub async fn new(redis_url: &str) -> Result<Self, redis::RedisError> {
        Self::with_prefix(redis_url, "reinhardt:locks:".to_string()).await
    }

    /// Connects to `redis_url` and namespaces every lock key with `key_prefix`.
    ///
    /// The prefix is used verbatim (no separator is inserted), so include any
    /// trailing delimiter you want in it.
    ///
    /// # Errors
    ///
    /// Propagates the [`redis::RedisError`] raised while opening the client or
    /// establishing the connection manager.
    pub async fn with_prefix(
        redis_url: &str,
        key_prefix: String,
    ) -> Result<Self, redis::RedisError> {
        let client = redis::Client::open(redis_url)?;
        let manager = redis::aio::ConnectionManager::new(client).await?;
        Ok(Self {
            connection: Arc::new(manager),
            key_prefix,
        })
    }

    /// Builds the Redis key holding the lock for `task_id`.
    fn lock_key(&self, task_id: TaskId) -> String {
        format!("{}task:{}", self.key_prefix, task_id)
    }
}
328
#[cfg(feature = "redis-backend")]
/// Converts a lock TTL into the whole-millisecond value sent to Redis
/// (`SET ... PX` and `PEXPIRE`).
///
/// A non-zero TTL shorter than one millisecond is rounded up to `1`.
/// `Duration::as_millis` would truncate it to `0`, which causes two problems:
/// - `SET ... PX` rejects a zero expiry;
/// - `PEXPIRE` deletes the key outright for a non-positive timeout, so
///   `extend` would report success while actually dropping the lock.
///
/// # Errors
///
/// Returns `TaskError::ExecutionFailed` when `ttl` is zero, or when its
/// millisecond count does not fit in an `i64`.
fn validate_ttl_ms(ttl: Duration) -> TaskResult<i64> {
    use crate::TaskError;

    if ttl.is_zero() {
        return Err(TaskError::ExecutionFailed(
            "TTL must be greater than zero".to_string(),
        ));
    }

    let millis = ttl.as_millis().max(1);
    i64::try_from(millis).map_err(|_| {
        TaskError::ExecutionFailed(format!(
            "TTL overflow: {} ms exceeds i64::MAX",
            millis
        ))
    })
}
351
#[cfg(feature = "redis-backend")]
#[async_trait]
impl TaskLock for RedisTaskLock {
    /// Issues `SET key token PX ttl NX`.
    ///
    /// Yields the freshly minted token only when Redis replies that the key was
    /// set. A nil reply (the key already exists) maps to `Ok(None)`.
    async fn acquire(&self, task_id: TaskId, ttl: Duration) -> TaskResult<Option<LockToken>> {
        use crate::TaskError;

        let ttl_ms = validate_ttl_ms(ttl)?;
        let key = self.lock_key(task_id);
        let token = LockToken::generate();
        let mut conn = (*self.connection).clone();

        let reply: Result<Option<String>, redis::RedisError> = redis::cmd("SET")
            .arg(&key)
            .arg(token.as_str())
            .arg("PX")
            .arg(ttl_ms)
            .arg("NX")
            .query_async(&mut conn)
            .await;

        reply
            .map(|set| set.map(|_| token))
            .map_err(|e| TaskError::ExecutionFailed(format!("Failed to acquire lock: {}", e)))
    }

    /// Deletes the lock key only if it still stores `token`.
    ///
    /// Returns `Ok(true)` when exactly one key was deleted.
    async fn release(&self, task_id: TaskId, token: &LockToken) -> TaskResult<bool> {
        use crate::TaskError;

        // Compare-and-delete runs as one Lua script so the ownership check and
        // the DEL cannot be interleaved with another client's commands.
        const RELEASE_LUA: &str = "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end";

        let key = self.lock_key(task_id);
        let mut conn = (*self.connection).clone();
        let script = redis::Script::new(RELEASE_LUA);

        let deleted: Result<i32, redis::RedisError> = script
            .key(&key)
            .arg(token.as_str())
            .invoke_async(&mut conn)
            .await;

        deleted
            .map(|count| count == 1)
            .map_err(|e| TaskError::ExecutionFailed(format!("Failed to release lock: {}", e)))
    }

    /// Reports whether the lock key currently exists (Redis `EXISTS`).
    async fn is_locked(&self, task_id: TaskId) -> TaskResult<bool> {
        use crate::TaskError;
        use redis::AsyncCommands;

        let key = self.lock_key(task_id);
        let mut conn = (*self.connection).clone();

        let exists: Result<bool, redis::RedisError> = conn.exists(&key).await;
        exists.map_err(|e| TaskError::ExecutionFailed(format!("Failed to check lock: {}", e)))
    }

    /// Resets the lock key's TTL to `ttl`, but only if the key still stores
    /// `token`.
    ///
    /// Returns `Ok(true)` when `PEXPIRE` reported success.
    async fn extend(&self, task_id: TaskId, token: &LockToken, ttl: Duration) -> TaskResult<bool> {
        use crate::TaskError;

        // Compare-and-PEXPIRE in one Lua script: ownership is checked and the
        // TTL refreshed without another client slipping in between.
        const EXTEND_LUA: &str = "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('pexpire', KEYS[1], ARGV[2]) else return 0 end";

        let ttl_ms = validate_ttl_ms(ttl)?;
        let key = self.lock_key(task_id);
        let mut conn = (*self.connection).clone();
        let script = redis::Script::new(EXTEND_LUA);

        let refreshed: Result<i32, redis::RedisError> = script
            .key(&key)
            .arg(token.as_str())
            .arg(ttl_ms)
            .invoke_async(&mut conn)
            .await;

        refreshed
            .map(|status| status == 1)
            .map_err(|e| TaskError::ExecutionFailed(format!("Failed to extend lock: {}", e)))
    }
}
454
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::time::Duration;

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_acquire() {
        // Arrange
        let lock = MemoryTaskLock::new();
        let task_id = TaskId::new();

        // Act
        let token = lock
            .acquire(task_id, Duration::from_secs(60))
            .await
            .unwrap();

        // Assert
        assert!(token.is_some());
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_already_locked() {
        // Arrange
        let lock = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let first = lock
            .acquire(task_id, Duration::from_secs(60))
            .await
            .unwrap();
        // Without this, the test would pass vacuously if the first acquire failed.
        assert!(first.is_some());

        // Act
        let token = lock
            .acquire(task_id, Duration::from_secs(60))
            .await
            .unwrap();

        // Assert
        assert!(token.is_none());
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_release() {
        // Arrange
        let lock = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let token = lock
            .acquire(task_id, Duration::from_secs(60))
            .await
            .unwrap()
            .unwrap();

        // Act
        let released = lock.release(task_id, &token).await.unwrap();

        // Assert
        assert!(released);
        let is_locked = lock.is_locked(task_id).await.unwrap();
        assert!(!is_locked);
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_release_wrong_token() {
        // Arrange
        let lock = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let acquired = lock
            .acquire(task_id, Duration::from_secs(60))
            .await
            .unwrap();
        assert!(acquired.is_some());
        let wrong_token = LockToken::generate();

        // Act
        let released = lock.release(task_id, &wrong_token).await.unwrap();

        // Assert
        assert!(!released);
        let is_locked = lock.is_locked(task_id).await.unwrap();
        assert!(is_locked);
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_expiry() {
        // Arrange
        let lock = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let acquired = lock
            .acquire(task_id, Duration::from_millis(50))
            .await
            .unwrap();
        assert!(acquired.is_some());

        // Act
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Assert
        let is_locked = lock.is_locked(task_id).await.unwrap();
        assert!(!is_locked);
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_extend() {
        // Arrange
        let lock = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let token = lock
            .acquire(task_id, Duration::from_secs(60))
            .await
            .unwrap()
            .unwrap();

        // Act
        let extended = lock
            .extend(task_id, &token, Duration::from_secs(120))
            .await
            .unwrap();

        // Assert
        assert!(extended);
        let is_locked = lock.is_locked(task_id).await.unwrap();
        assert!(is_locked);
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_extend_returns_false_for_unlocked_task() {
        // Arrange
        let lock = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let token = LockToken::generate();

        // Act
        let extended = lock
            .extend(task_id, &token, Duration::from_secs(120))
            .await
            .unwrap();

        // Assert
        assert!(!extended);
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_extend_returns_false_for_expired_lock() {
        // Arrange
        let lock = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let token = lock
            .acquire(task_id, Duration::from_millis(50))
            .await
            .unwrap()
            .unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Act
        let extended = lock
            .extend(task_id, &token, Duration::from_secs(120))
            .await
            .unwrap();

        // Assert
        assert!(!extended);
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_extend_returns_false_for_wrong_token() {
        // Arrange
        let lock = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let acquired = lock
            .acquire(task_id, Duration::from_secs(60))
            .await
            .unwrap();
        assert!(acquired.is_some());
        let wrong_token = LockToken::generate();

        // Act
        let extended = lock
            .extend(task_id, &wrong_token, Duration::from_secs(120))
            .await
            .unwrap();

        // Assert
        assert!(!extended);
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_extend_is_atomic() {
        // Arrange
        let lock = Arc::new(MemoryTaskLock::new());
        let task_id = TaskId::new();
        let token = lock
            .acquire(task_id, Duration::from_millis(200))
            .await
            .unwrap()
            .unwrap();

        // Act
        let extended = lock
            .extend(task_id, &token, Duration::from_secs(60))
            .await
            .unwrap();

        // Assert
        assert!(extended);
        let second_acquire = lock
            .acquire(task_id, Duration::from_secs(60))
            .await
            .unwrap();
        assert!(second_acquire.is_none());
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_zero_ttl_is_rejected() {
        // Arrange
        let lock = MemoryTaskLock::new();
        let task_id = TaskId::new();

        // Act
        let token = lock.acquire(task_id, Duration::ZERO).await.unwrap();

        // Assert
        assert!(token.is_none());
        assert!(!lock.is_locked(task_id).await.unwrap());
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_reacquire_after_release() {
        // Arrange
        let lock = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let token = lock
            .acquire(task_id, Duration::from_secs(60))
            .await
            .unwrap()
            .unwrap();
        assert!(lock.release(task_id, &token).await.unwrap());

        // Act
        let reacquired = lock
            .acquire(task_id, Duration::from_secs(60))
            .await
            .unwrap();

        // Assert
        assert!(reacquired.is_some());
        assert_ne!(reacquired.unwrap(), token);
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_reacquire_after_expiry() {
        // Arrange
        let lock = MemoryTaskLock::new();
        let task_id = TaskId::new();
        let acquired = lock
            .acquire(task_id, Duration::from_millis(50))
            .await
            .unwrap();
        assert!(acquired.is_some());
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Act
        let reacquired = lock
            .acquire(task_id, Duration::from_secs(60))
            .await
            .unwrap();

        // Assert
        assert!(reacquired.is_some());
    }

    #[rstest]
    #[tokio::test]
    async fn test_memory_lock_default_starts_unlocked() {
        // Arrange
        let lock = MemoryTaskLock::default();
        let task_id = TaskId::new();

        // Act
        let is_locked = lock.is_locked(task_id).await.unwrap();

        // Assert
        assert!(!is_locked);
    }

    #[rstest]
    fn test_lock_token_generate_is_unique() {
        // Act
        let first = LockToken::generate();
        let second = LockToken::generate();

        // Assert
        assert_ne!(first, second);
        assert!(!first.as_str().is_empty());
    }
}