// entelix_persistence/redis/lock.rs

use std::sync::Arc;
use std::time::{Duration, Instant};

use async_trait::async_trait;
use redis::Script;
use redis::aio::ConnectionManager;
use tokio::time::sleep;

use crate::advisory_key::AdvisoryKey;
use crate::error::{PersistenceError, PersistenceResult};
use crate::lock::{DistributedLock, LockGuard};

/// How long `acquire` waits between successive lock attempts.
const POLL_INTERVAL: Duration = Duration::from_millis(50);

/// Compare-and-delete: removes the lock key only while it still holds the
/// caller's token, so a lock that expired and was re-acquired by another
/// owner is never released by the original holder. Runs atomically in Redis.
const RELEASE_SCRIPT: &str = r#"
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("del", KEYS[1])
else
    return 0
end
"#;

/// Compare-and-pexpire: refreshes the TTL (milliseconds in ARGV[2]) only
/// while the key still holds the caller's token; returns 0 when ownership
/// has been lost.
const EXTEND_SCRIPT: &str = r#"
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("pexpire", KEYS[1], ARGV[2])
else
    return 0
end
"#;
42
43pub struct RedisLock {
45 manager: Arc<ConnectionManager>,
46 release_script: Script,
47 extend_script: Script,
48}
49
50impl RedisLock {
51 pub(crate) fn new(manager: Arc<ConnectionManager>) -> Self {
52 Self {
53 manager,
54 release_script: Script::new(RELEASE_SCRIPT),
55 extend_script: Script::new(EXTEND_SCRIPT),
56 }
57 }
58}
59
60#[async_trait]
61impl DistributedLock for RedisLock {
62 async fn try_acquire(
63 &self,
64 key: &AdvisoryKey,
65 ttl: Duration,
66 ) -> PersistenceResult<Option<LockGuard>> {
67 let guard = LockGuard::new(*key);
68 let mut conn = (*self.manager).clone();
69 let ttl_ms = u64::try_from(ttl.as_millis()).unwrap_or(u64::MAX);
70 let result: Option<String> = redis::cmd("SET")
71 .arg(key.redis_key())
72 .arg(guard.token())
73 .arg("NX")
74 .arg("PX")
75 .arg(ttl_ms)
76 .query_async(&mut conn)
77 .await
78 .map_err(backend_err)?;
79 match result.as_deref() {
80 Some("OK") => Ok(Some(guard)),
81 _ => {
82 drop(guard); Ok(None)
85 }
86 }
87 }
88
89 async fn acquire(
90 &self,
91 key: &AdvisoryKey,
92 ttl: Duration,
93 deadline: Duration,
94 ) -> PersistenceResult<LockGuard> {
95 let start = Instant::now();
96 let mut attempts: u32 = 0;
97 loop {
98 attempts = attempts.saturating_add(1);
99 if let Some(guard) = self.try_acquire(key, ttl).await? {
100 return Ok(guard);
101 }
102 if start.elapsed() >= deadline {
103 return Err(PersistenceError::LockAcquireTimeout {
104 key: key.to_string(),
105 attempts,
106 });
107 }
108 sleep(POLL_INTERVAL).await;
109 }
110 }
111
112 async fn extend(&self, guard: &LockGuard, ttl: Duration) -> PersistenceResult<bool> {
113 let mut conn = (*self.manager).clone();
114 let ttl_ms = u64::try_from(ttl.as_millis()).unwrap_or(u64::MAX);
115 let result: i32 = self
116 .extend_script
117 .key(guard.key().redis_key())
118 .arg(guard.token())
119 .arg(ttl_ms)
120 .invoke_async(&mut conn)
121 .await
122 .map_err(backend_err)?;
123 Ok(result == 1)
124 }
125
126 async fn release(&self, mut guard: LockGuard) -> PersistenceResult<()> {
127 let mut conn = (*self.manager).clone();
128 let _: i32 = self
129 .release_script
130 .key(guard.key().redis_key())
131 .arg(guard.token())
132 .invoke_async(&mut conn)
133 .await
134 .map_err(backend_err)?;
135 guard.mark_released();
136 Ok(())
137 }
138}
139
140fn backend_err(e: redis::RedisError) -> PersistenceError {
141 PersistenceError::Backend(e.to_string())
142}