Skip to main content

reinhardt_tasks/
locking.rs

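//! Distributed task locking mechanism
//!
//! This module provides locking primitives for distributed task systems,
//! preventing multiple workers from executing the same task simultaneously.
//! It ships an in-memory implementation ([`MemoryTaskLock`]) for single-process
//! testing and, behind the `redis-backend` feature, a Redis-based one
//! ([`RedisTaskLock`]).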
5
6use crate::{TaskId, TaskResult};
7use async_trait::async_trait;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13/// Opaque token returned by a successful lock acquisition.
14///
15/// The token proves lock ownership and must be presented when releasing
16/// or extending the lock. This prevents workers from accidentally
17/// releasing locks they do not own.
18///
19/// # Examples
20///
21/// ```rust
22/// use reinhardt_tasks::LockToken;
23///
24/// let token = LockToken::generate();
25/// assert!(!token.as_str().is_empty());
26/// ```
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct LockToken(String);
29
30impl LockToken {
31	/// Generate a new unique lock token
32	pub fn generate() -> Self {
33		Self(Uuid::new_v4().to_string())
34	}
35
36	/// Get the string representation of the token
37	pub fn as_str(&self) -> &str {
38		&self.0
39	}
40}
41
42/// Distributed lock trait for task synchronization
43///
44/// # Examples
45///
46/// ```rust,no_run
47/// use reinhardt_tasks::{TaskLock, TaskId, LockToken};
48/// use async_trait::async_trait;
49/// use std::time::Duration;
50///
51/// struct MyLock;
52///
53/// #[async_trait]
54/// impl TaskLock for MyLock {
55///     async fn acquire(&self, task_id: TaskId, ttl: Duration) -> reinhardt_tasks::TaskResult<Option<LockToken>> {
56///         // Acquire lock implementation
57///         Ok(Some(LockToken::generate()))
58///     }
59///
60///     async fn release(&self, task_id: TaskId, token: &LockToken) -> reinhardt_tasks::TaskResult<bool> {
61///         // Release lock implementation
62///         Ok(true)
63///     }
64///
65///     async fn is_locked(&self, task_id: TaskId) -> reinhardt_tasks::TaskResult<bool> {
66///         // Check lock status
67///         Ok(false)
68///     }
69/// }
70/// ```
71#[async_trait]
72pub trait TaskLock: Send + Sync {
73	/// Acquire a lock for a task
74	///
75	/// Returns `Some(LockToken)` if lock was acquired, `None` if already locked
76	/// by another worker.
77	async fn acquire(&self, task_id: TaskId, ttl: Duration) -> TaskResult<Option<LockToken>>;
78
79	/// Release a lock for a task
80	///
81	/// Returns `true` if the lock was released, `false` if the token does not
82	/// match (i.e. the caller does not own the lock).
83	async fn release(&self, task_id: TaskId, token: &LockToken) -> TaskResult<bool>;
84
85	/// Check if a task is locked
86	async fn is_locked(&self, task_id: TaskId) -> TaskResult<bool>;
87
88	/// Extend the TTL of an existing lock
89	///
90	/// Implementors should override this with a backend-specific atomic operation
91	/// to avoid race conditions where another worker could steal the lock between
92	/// release and re-acquire.
93	async fn extend(&self, task_id: TaskId, token: &LockToken, ttl: Duration) -> TaskResult<bool> {
94		// Default: check-then-release-then-acquire is non-atomic.
95		// Concrete implementations should override with atomic operations.
96		if self.is_locked(task_id).await? {
97			let released = self.release(task_id, token).await?;
98			if !released {
99				// Token did not match — caller does not own the lock
100				return Ok(false);
101			}
102			self.acquire(task_id, ttl).await.map(|t| t.is_some())
103		} else {
104			Ok(false)
105		}
106	}
107}
108
109/// In-memory task lock for single-process testing
110///
111/// # Examples
112///
113/// ```rust
114/// use reinhardt_tasks::{MemoryTaskLock, TaskLock, TaskId};
115/// use std::time::Duration;
116///
117/// # async fn example() -> reinhardt_tasks::TaskResult<()> {
118/// let lock = MemoryTaskLock::new();
119/// let task_id = TaskId::new();
120///
121/// // Acquire lock
122/// let token = lock.acquire(task_id, Duration::from_secs(60)).await?;
123/// assert!(token.is_some());
124///
125/// // Check if locked
126/// let is_locked = lock.is_locked(task_id).await?;
127/// assert!(is_locked);
128///
129/// // Release lock
130/// let released = lock.release(task_id, &token.unwrap()).await?;
131/// assert!(released);
132/// # Ok(())
133/// # }
134/// ```
135pub struct MemoryTaskLock {
136	/// Map of task ID to (expiry timestamp in ms, token string)
137	locks: Arc<RwLock<std::collections::HashMap<TaskId, (i128, String)>>>,
138}
139
140impl MemoryTaskLock {
141	/// Create a new in-memory task lock
142	///
143	/// # Examples
144	///
145	/// ```rust
146	/// use reinhardt_tasks::MemoryTaskLock;
147	///
148	/// let lock = MemoryTaskLock::new();
149	/// ```
150	pub fn new() -> Self {
151		Self {
152			locks: Arc::new(RwLock::new(std::collections::HashMap::new())),
153		}
154	}
155
156	/// Clean up expired locks
157	async fn cleanup_expired(&self) {
158		let mut locks = self.locks.write().await;
159		let now = chrono::Utc::now().timestamp_millis() as i128;
160		locks.retain(|_, (expiry, _)| *expiry > now);
161	}
162}
163
164impl Default for MemoryTaskLock {
165	fn default() -> Self {
166		Self::new()
167	}
168}
169
170#[async_trait]
171impl TaskLock for MemoryTaskLock {
172	async fn acquire(&self, task_id: TaskId, ttl: Duration) -> TaskResult<Option<LockToken>> {
173		// Zero TTL would create a lock that expires immediately, causing
174		// inconsistency between acquire (returns Some) and is_locked (returns false).
175		if ttl.is_zero() {
176			return Ok(None);
177		}
178
179		self.cleanup_expired().await;
180
181		let mut locks = self.locks.write().await;
182		let now = chrono::Utc::now().timestamp_millis() as i128;
183		let expiry = now + ttl.as_millis() as i128;
184
185		if let Some(&(existing_expiry, _)) = locks.get(&task_id)
186			&& existing_expiry > now
187		{
188			return Ok(None);
189		}
190
191		let token = LockToken::generate();
192		locks.insert(task_id, (expiry, token.as_str().to_string()));
193		Ok(Some(token))
194	}
195
196	async fn release(&self, task_id: TaskId, token: &LockToken) -> TaskResult<bool> {
197		let mut locks = self.locks.write().await;
198		if let Some((_, stored_token)) = locks.get(&task_id)
199			&& stored_token == token.as_str()
200		{
201			locks.remove(&task_id);
202			return Ok(true);
203		}
204		Ok(false)
205	}
206
207	async fn is_locked(&self, task_id: TaskId) -> TaskResult<bool> {
208		self.cleanup_expired().await;
209
210		let locks = self.locks.read().await;
211		let now = chrono::Utc::now().timestamp_millis() as i128;
212
213		Ok(locks
214			.get(&task_id)
215			.map(|(expiry, _)| *expiry > now)
216			.unwrap_or(false))
217	}
218
219	/// Atomically extend the TTL of an existing lock.
220	///
221	/// Unlike the default trait implementation which releases then re-acquires,
222	/// this holds the write lock throughout the operation to prevent another
223	/// worker from stealing the lock in between.
224	async fn extend(&self, task_id: TaskId, token: &LockToken, ttl: Duration) -> TaskResult<bool> {
225		let mut locks = self.locks.write().await;
226		let now = chrono::Utc::now().timestamp_millis() as i128;
227
228		if let Some((expiry, stored_token)) = locks.get_mut(&task_id)
229			&& *expiry > now
230			&& stored_token.as_str() == token.as_str()
231		{
232			// Lock is still valid and owned by caller; atomically update its expiry
233			*expiry = now + ttl.as_millis() as i128;
234			return Ok(true);
235		}
236
237		Ok(false)
238	}
239}
240
#[cfg(feature = "redis-backend")]
/// Distributed [`TaskLock`] backed by Redis.
///
/// Acquisition is one atomic `SET key value PX ms NX`. Release and extension
/// run Lua scripts that compare the stored token first, so only the current
/// owner can delete or prolong a lock.
///
/// # Examples
///
/// ```no_run
/// use reinhardt_tasks::{RedisTaskLock, TaskLock, TaskId};
/// use std::time::Duration;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let lock = RedisTaskLock::new("redis://127.0.0.1/").await?;
/// let task_id = TaskId::new();
///
/// // Try to take the distributed lock for 30 seconds
/// if let Some(token) = lock.acquire(task_id, Duration::from_secs(30)).await? {
///     // ... run the task ...
///     lock.release(task_id, &token).await?;
/// }
/// # Ok(())
/// # }
/// ```
pub struct RedisTaskLock {
	/// Shared connection manager; each operation works on its own clone.
	connection: Arc<redis::aio::ConnectionManager>,
	/// String prepended to every lock key.
	key_prefix: String,
}
271
#[cfg(feature = "redis-backend")]
impl RedisTaskLock {
	/// Connect to Redis at `redis_url`, storing locks under the default
	/// `reinhardt:locks:` key prefix.
	///
	/// # Errors
	///
	/// Propagates the `redis::RedisError` from opening the client or from
	/// establishing the connection manager.
	///
	/// # Examples
	///
	/// ```no_run
	/// use reinhardt_tasks::RedisTaskLock;
	///
	/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
	/// let lock = RedisTaskLock::new("redis://localhost/").await?;
	/// # Ok(())
	/// # }
	/// ```
	pub async fn new(redis_url: &str) -> Result<Self, redis::RedisError> {
		Self::with_prefix(redis_url, "reinhardt:locks:".to_string()).await
	}

	/// Connect to Redis at `redis_url`, storing locks under `key_prefix`.
	///
	/// # Errors
	///
	/// Propagates the `redis::RedisError` from opening the client or from
	/// establishing the connection manager.
	///
	/// # Examples
	///
	/// ```no_run
	/// use reinhardt_tasks::RedisTaskLock;
	///
	/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
	/// let lock = RedisTaskLock::with_prefix(
	///     "redis://localhost/",
	///     "myapp:locks:".to_string()
	/// ).await?;
	/// # Ok(())
	/// # }
	/// ```
	pub async fn with_prefix(
		redis_url: &str,
		key_prefix: String,
	) -> Result<Self, redis::RedisError> {
		let client = redis::Client::open(redis_url)?;
		let manager = redis::aio::ConnectionManager::new(client).await?;

		Ok(RedisTaskLock {
			connection: Arc::new(manager),
			key_prefix,
		})
	}

	/// Redis key holding the lock for `task_id`: `<prefix>task:<task_id>`.
	fn lock_key(&self, task_id: TaskId) -> String {
		format!("{prefix}task:{task_id}", prefix = self.key_prefix)
	}
}
328
#[cfg(feature = "redis-backend")]
/// Convert a `Duration` to whole milliseconds as `i64` for Redis `PX` /
/// `PEXPIRE`, rejecting TTLs that are too short or too large.
///
/// Redis works in whole milliseconds and `Duration::as_millis()` truncates.
/// Any TTL below 1 ms, not just zero, would therefore be sent as `0`. `SET ... PX 0` is
/// rejected by Redis, and `PEXPIRE key 0` deletes the key. That would make
/// `extend` report success while actually destroying the lock. Overflow is
/// possible because `Duration::as_millis()` returns `u128` but Redis expects
/// `i64`.
///
/// # Errors
///
/// Returns `TaskError::ExecutionFailed` in either of these cases:
/// - `ttl` is shorter than 1 ms;
/// - its millisecond count exceeds `i64::MAX`.
fn validate_ttl_ms(ttl: Duration) -> TaskResult<i64> {
	use crate::TaskError;

	let millis = ttl.as_millis();
	if millis == 0 {
		return Err(TaskError::ExecutionFailed(format!(
			"TTL must be at least 1 millisecond, got {:?}",
			ttl
		)));
	}

	i64::try_from(millis).map_err(|_| {
		TaskError::ExecutionFailed(format!("TTL overflow: {} ms exceeds i64::MAX", millis))
	})
}
351
#[cfg(feature = "redis-backend")]
#[async_trait]
impl TaskLock for RedisTaskLock {
	/// Try to take the lock with a single atomic `SET ... PX ... NX`.
	///
	/// A nil reply means the key already exists (someone else holds the lock),
	/// which maps to `Ok(None)`.
	async fn acquire(&self, task_id: TaskId, ttl: Duration) -> TaskResult<Option<LockToken>> {
		use crate::TaskError;

		let ttl_ms = validate_ttl_ms(ttl)?;
		let key = self.lock_key(task_id);
		let mut conn = (*self.connection).clone();
		let token = LockToken::generate();

		// SET <key> <token> PX <ttl_ms> NX succeeds only when the key is absent.
		let reply: Option<String> = redis::cmd("SET")
			.arg(&key)
			.arg(token.as_str())
			.arg("PX")
			.arg(ttl_ms)
			.arg("NX")
			.query_async(&mut conn)
			.await
			.map_err(|err| TaskError::ExecutionFailed(format!("Failed to acquire lock: {}", err)))?;

		Ok(reply.map(|_| token))
	}

	/// Delete the lock key, but only if it still stores `token`.
	async fn release(&self, task_id: TaskId, token: &LockToken) -> TaskResult<bool> {
		use crate::TaskError;

		let key = self.lock_key(task_id);
		let mut conn = (*self.connection).clone();

		// Compare-and-delete runs server-side so it is atomic.
		let compare_and_delete = redis::Script::new(
			"if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end",
		);

		let deleted: i32 = compare_and_delete
			.key(&key)
			.arg(token.as_str())
			.invoke_async(&mut conn)
			.await
			.map_err(|err| TaskError::ExecutionFailed(format!("Failed to release lock: {}", err)))?;

		Ok(deleted == 1)
	}

	/// A task counts as locked while its key exists in Redis.
	async fn is_locked(&self, task_id: TaskId) -> TaskResult<bool> {
		use crate::TaskError;
		use redis::AsyncCommands;

		let key = self.lock_key(task_id);
		let mut conn = (*self.connection).clone();

		let present: redis::RedisResult<bool> = conn.exists(&key).await;
		present.map_err(|err| TaskError::ExecutionFailed(format!("Failed to check lock: {}", err)))
	}

	/// Atomically extend the TTL using a Lua script with millisecond precision.
	///
	/// The script re-checks ownership before calling `pexpire`, so only the
	/// holder of `token` can prolong the lock.
	async fn extend(&self, task_id: TaskId, token: &LockToken, ttl: Duration) -> TaskResult<bool> {
		use crate::TaskError;

		let ttl_ms = validate_ttl_ms(ttl)?;
		let key = self.lock_key(task_id);
		let mut conn = (*self.connection).clone();

		// Compare-and-pexpire runs server-side so it is atomic.
		let compare_and_pexpire = redis::Script::new(
			"if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('pexpire', KEYS[1], ARGV[2]) else return 0 end",
		);

		let updated: i32 = compare_and_pexpire
			.key(&key)
			.arg(token.as_str())
			.arg(ttl_ms)
			.invoke_async(&mut conn)
			.await
			.map_err(|err| TaskError::ExecutionFailed(format!("Failed to extend lock: {}", err)))?;

		Ok(updated == 1)
	}
}
454
#[cfg(test)]
mod tests {
	use super::*;
	use rstest::rstest;
	use std::time::Duration;

	/// An empty in-memory lock table plus a task id that has never been locked.
	fn fresh() -> (MemoryTaskLock, TaskId) {
		(MemoryTaskLock::new(), TaskId::new())
	}

	/// Attempt to lock `task_id` for `ttl`. Only the `TaskResult` layer is
	/// unwrapped, so the caller can still see whether the lock was granted.
	async fn try_acquire(lock: &MemoryTaskLock, task_id: TaskId, ttl: Duration) -> Option<LockToken> {
		lock.acquire(task_id, ttl).await.unwrap()
	}

	#[rstest]
	#[tokio::test]
	async fn test_memory_lock_acquire() {
		// Arrange
		let (lock, task_id) = fresh();

		// Act
		let granted = try_acquire(&lock, task_id, Duration::from_secs(60)).await;

		// Assert
		assert!(granted.is_some());
	}

	#[rstest]
	#[tokio::test]
	async fn test_memory_lock_already_locked() {
		// Arrange - a first caller takes the lock
		let (lock, task_id) = fresh();
		let _first = try_acquire(&lock, task_id, Duration::from_secs(60)).await;

		// Act - a second caller asks for the same task
		let second = try_acquire(&lock, task_id, Duration::from_secs(60)).await;

		// Assert
		assert!(second.is_none());
	}

	#[rstest]
	#[tokio::test]
	async fn test_memory_lock_release() {
		// Arrange
		let (lock, task_id) = fresh();
		let token = try_acquire(&lock, task_id, Duration::from_secs(60))
			.await
			.expect("a fresh task must be lockable");

		// Act
		let released = lock.release(task_id, &token).await.unwrap();

		// Assert
		assert!(released);
		assert!(!lock.is_locked(task_id).await.unwrap());
	}

	#[rstest]
	#[tokio::test]
	async fn test_memory_lock_release_wrong_token() {
		// Arrange
		let (lock, task_id) = fresh();
		let _owner = try_acquire(&lock, task_id, Duration::from_secs(60)).await;
		let impostor = LockToken::generate();

		// Act
		let released = lock.release(task_id, &impostor).await.unwrap();

		// Assert - a foreign token must not free the lock
		assert!(!released);
		assert!(lock.is_locked(task_id).await.unwrap());
	}

	#[rstest]
	#[tokio::test]
	async fn test_memory_lock_expiry() {
		// Arrange - a lock that lives for 50ms
		let (lock, task_id) = fresh();
		let _token = try_acquire(&lock, task_id, Duration::from_millis(50)).await;

		// Act - wait well past the TTL
		tokio::time::sleep(Duration::from_millis(100)).await;

		// Assert
		assert!(!lock.is_locked(task_id).await.unwrap());
	}

	#[rstest]
	#[tokio::test]
	async fn test_memory_lock_extend() {
		// Arrange
		let (lock, task_id) = fresh();
		let token = try_acquire(&lock, task_id, Duration::from_secs(60))
			.await
			.expect("a fresh task must be lockable");

		// Act
		let extended = lock
			.extend(task_id, &token, Duration::from_secs(120))
			.await
			.unwrap();

		// Assert
		assert!(extended);
		assert!(lock.is_locked(task_id).await.unwrap());
	}

	#[rstest]
	#[tokio::test]
	async fn test_memory_lock_extend_returns_false_for_unlocked_task() {
		// Arrange - a token that was never issued for this task
		let (lock, task_id) = fresh();
		let token = LockToken::generate();

		// Act - extend without acquiring first
		let extended = lock
			.extend(task_id, &token, Duration::from_secs(120))
			.await
			.unwrap();

		// Assert
		assert!(!extended);
	}

	#[rstest]
	#[tokio::test]
	async fn test_memory_lock_extend_returns_false_for_expired_lock() {
		// Arrange - a 50ms lock that is allowed to lapse
		let (lock, task_id) = fresh();
		let token = try_acquire(&lock, task_id, Duration::from_millis(50))
			.await
			.expect("a fresh task must be lockable");
		tokio::time::sleep(Duration::from_millis(100)).await;

		// Act - extend an expired lock
		let extended = lock
			.extend(task_id, &token, Duration::from_secs(120))
			.await
			.unwrap();

		// Assert
		assert!(!extended);
	}

	#[rstest]
	#[tokio::test]
	async fn test_memory_lock_extend_returns_false_for_wrong_token() {
		// Arrange
		let (lock, task_id) = fresh();
		let _owner = try_acquire(&lock, task_id, Duration::from_secs(60)).await;
		let impostor = LockToken::generate();

		// Act - extend with wrong token
		let extended = lock
			.extend(task_id, &impostor, Duration::from_secs(120))
			.await
			.unwrap();

		// Assert
		assert!(!extended);
	}

	#[rstest]
	#[tokio::test]
	async fn test_memory_lock_extend_is_atomic() {
		// Arrange - hold a short-lived lock
		let (lock, task_id) = fresh();
		let token = try_acquire(&lock, task_id, Duration::from_millis(200))
			.await
			.expect("a fresh task must be lockable");

		// Act
		let extended = lock
			.extend(task_id, &token, Duration::from_secs(60))
			.await
			.unwrap();

		// Assert - the extension succeeded and the lock was never released,
		// so a competing acquire is still refused
		assert!(extended);
		let competitor = try_acquire(&lock, task_id, Duration::from_secs(60)).await;
		assert!(competitor.is_none());
	}
}