1use rand::Rng;
2use std::marker::PhantomData;
3use std::mem::take;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
6
7use crate::{DerefLt, Empty, Guard};
8
9use super::{GuardLt, MutexProvider, Result};
10use async_trait::async_trait;
11use bb8_redis::redis::Script;
12use bb8_redis::{bb8::Pool, RedisConnectionManager};
13use rand::thread_rng;
14use redis::AsyncCommands;
15pub use redis::{FromRedisValue, RedisResult, RedisWrite, ToRedisArgs, Value};
16use tokio::sync::oneshot::Sender;
17use tokio::sync::{RwLock, RwLockReadGuard};
18use tracing::{error, trace, warn};
19
/// Lease length in milliseconds: added to the current wall-clock time to compute the
/// absolute expiry of a lock when it is acquired or renewed.
const LOCK_LEASE_TIMEOUT_MILLIS: u64 = 10_000;
/// Period, in milliseconds, of the background task's lease-renewal interval.
const LOCK_REFRESH_INTERVAL_MILLIS: u64 = 1_000;
/// Period, in milliseconds, between attempts to acquire a lock in `lock`.
const LOCK_POLL_INTERVAL_MILLIS: u64 = 100;
/// Safety margin in milliseconds subtracted from the lease expiry: the renewal task
/// panics if the lease has not been renewed by the time this margin is reached.
const RENEWAL_PANIC_BUFFER_MILLIS: u64 = 1_000;
29
/// Factory for [`RedisMutex`] values that share one Redis connection pool.
#[derive(Debug, Clone)]
pub struct RedisMutexProvider {
    /// Connection pool; a clone of it is handed to every mutex this provider creates.
    pool: Pool<RedisConnectionManager>,
    /// Identifier embedded in every Redis key built by this provider.
    provider_id: String,
}
35
36impl RedisMutexProvider {
37 pub fn new(provider_id: String, pool: Pool<RedisConnectionManager>) -> RedisMutexProvider {
38 RedisMutexProvider { pool, provider_id }
39 }
40}
41
/// Handle to a single Redis-backed lock.
///
/// The lock is represented in Redis by `key`, whose value is `mutex_id` while this handle
/// owns the lock.
#[derive(Clone, Debug)]
pub struct RedisMutex {
    /// Connection pool used for every Redis command issued by this mutex.
    pool: Pool<RedisConnectionManager>,
    /// Redis key that represents the lock.
    key: String,
    /// Value written under `key` on acquisition and compared against it on renew/drop,
    /// so that only the owner can extend or release the lock.
    mutex_id: u64,
}
48
/// Lua script that tries to take the lock.
///
/// Sets `KEYS[1]` to `ARGV[1]` only if the key does not exist (`NX`), with an absolute
/// expiry of `ARGV[2]` (`PXAT`, Unix time in milliseconds). Returns `1` when the key was
/// set and `0` otherwise.
const ACQUIRE_LOCK_SCRIPT: &str = "\
    local got_lock = redis.call('SET', KEYS[1], ARGV[1], 'NX', 'PXAT', ARGV[2])
    if got_lock then
        return 1
    end
    return 0
";
59
/// Lua script that extends the lease of a lock that is still owned by the caller.
///
/// If the value stored at `KEYS[1]` equals `ARGV[1]`, moves the key's expiry to the
/// absolute time `ARGV[2]` (`PEXPIREAT`, Unix time in milliseconds) and returns `1`;
/// otherwise leaves the key untouched and returns `0`.
const RENEW_LOCK_SCRIPT: &str = "\
    if redis.call('GET', KEYS[1]) == ARGV[1] then
        redis.call('PEXPIREAT', KEYS[1], ARGV[2])
        return 1
    end
    return 0
";
70
/// Lua script that releases a lock that is still owned by the caller.
///
/// If the value stored at `KEYS[1]` equals `ARGV[1]`, deletes the key and returns `1`;
/// otherwise leaves the key untouched and returns `0`.
const DROP_LOCK_SCRIPT: &str = "\
    if redis.call('GET', KEYS[1]) == ARGV[1] then
        redis.call('DEL', KEYS[1])
        return 1
    end
    return 0
";
80
81impl RedisMutex {
82 async fn try_acquire_lock(&self) -> Result<Option<Duration>> {
86 let exp = SystemTime::now().duration_since(UNIX_EPOCH).unwrap()
87 + Duration::from_millis(LOCK_LEASE_TIMEOUT_MILLIS);
88 Ok(
89 if Script::new(ACQUIRE_LOCK_SCRIPT)
90 .key(self.key.as_str())
91 .arg(self.mutex_id)
92 .arg(exp.as_millis() as i64)
93 .invoke_async::<_, i32>(&mut *self.pool.get().await?)
94 .await?
95 == 1
96 {
97 Some(exp)
98 } else {
99 None
100 },
101 )
102 }
103
104 async fn try_renew_lock(&self) -> Result<Option<Duration>> {
109 let new_exp = SystemTime::now().duration_since(UNIX_EPOCH).unwrap()
110 + Duration::from_millis(LOCK_LEASE_TIMEOUT_MILLIS);
111 Ok(
112 if Script::new(RENEW_LOCK_SCRIPT)
113 .key(self.key.as_str())
114 .arg(self.mutex_id)
115 .arg(new_exp.as_millis() as i64)
116 .invoke_async::<_, i32>(&mut *self.pool.get().await?)
117 .await?
118 == 1
119 {
120 Some(new_exp)
121 } else {
122 None
123 },
124 )
125 }
126
127 async fn drop_lock(&self) -> Result<bool> {
131 Ok(Script::new(DROP_LOCK_SCRIPT)
132 .key(self.key.as_str())
133 .arg(self.mutex_id)
134 .invoke_async::<_, i32>(&mut *self.pool.get().await?)
135 .await?
136 == 1)
137 }
138}
139
140#[async_trait]
141impl<T> super::Mutex<T> for RedisMutex
142where
143 T: Send + FromRedisValue + ToRedisArgs + Sync + 'static,
144{
145 type Guard = RedisGuardCtor<T>;
146 async fn lock(&self) -> Result<RedisGuard<'_, T>> {
147 let mut interval =
148 tokio::time::interval(core::time::Duration::from_millis(LOCK_POLL_INTERVAL_MILLIS));
149 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
150 let expires_at;
151 loop {
152 tokio::select! {
153 _ = interval.tick() => {
154 if let Some(exp) = self.try_acquire_lock().await? {
155 expires_at = exp;
156 break;
157 }
158 }
159 }
160 }
161 Ok(RedisGuard::new(&self, expires_at))
162 }
163}
164
/// Zero-sized marker type used as the `Guard` associated type of the `Mutex` impl for
/// [`RedisMutex`]; its `GuardLt` impl names the concrete guard type for a given lifetime.
pub struct RedisGuardCtor<T>(PhantomData<T>);
166
/// Maps the marker type to the borrowed guard type: for lifetime `'a` the guard is
/// `RedisGuard<'a, T>`.
impl<'a, T> GuardLt<'a, T> for RedisGuardCtor<T>
where
    T: FromRedisValue + ToRedisArgs + Send + Sync + 'static,
{
    type Guard = RedisGuard<'a, T>;
}
173
/// Guard representing ownership of a [`RedisMutex`] lock.
///
/// Dropping the guard signals the background task created in `new` to release the lock.
pub struct RedisGuard<'a, T> {
    /// The mutex this guard was obtained from.
    mutex: &'a RedisMutex,
    /// Sender used once, on drop, to tell the background task to stop renewing and
    /// release the lock.
    drop_tx: Option<Sender<()>>,
    /// Set to `true` once `data` reflects the stored value (after a store, load or clear).
    loaded: AtomicBool,
    /// Locally cached copy of the value associated with the lock.
    data: RwLock<Option<T>>,
    /// Type marker for `T`.
    _pd: PhantomData<T>,
}
181
182impl<'a, T> RedisGuard<'a, T> {
183 fn new(mutex: &'a RedisMutex, exp_at: Duration) -> RedisGuard<'a, T> {
184 trace!(key = %mutex.key, mutex_id = %mutex.mutex_id, expires_at = ?exp_at, "acquired lock");
185 let (drop_tx, mut drop_rx) = tokio::sync::oneshot::channel();
186 let mutex_clone = mutex.clone();
187 let _ = tokio::spawn(async move {
188 let mutex = mutex_clone;
189 let mut renewal_interval = tokio::time::interval(core::time::Duration::from_millis(
190 LOCK_REFRESH_INTERVAL_MILLIS,
191 ));
192 renewal_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
193
194 let panic_timeout = tokio::time::sleep(
195 exp_at
196 - SystemTime::now().duration_since(UNIX_EPOCH).unwrap()
197 - Duration::from_millis(RENEWAL_PANIC_BUFFER_MILLIS),
198 );
199 tokio::pin!(panic_timeout);
200 loop {
201 tokio::select! {
202 _ = &mut drop_rx => {
203 break;
204 }
205 _ = renewal_interval.tick() => {
206 match mutex.try_renew_lock().await {
207 Ok(Some(new_exp)) => {
208 trace!(key = %mutex.key, mutex_id = %mutex.mutex_id, expires_at = ?new_exp, "renewed lock lease");
209 panic_timeout.as_mut().reset(tokio::time::Instant::from_std(Instant::now() + new_exp
210 - SystemTime::now().duration_since(UNIX_EPOCH).unwrap()
211 - Duration::from_millis(RENEWAL_PANIC_BUFFER_MILLIS)));
212 },
213 Ok(None) => {
214 panic!("failed to renew mutex because it had a different owner: {}", mutex.key);
215 },
216 Err(e) => {
217 error!(key = %mutex.key, mutex_id = %mutex.mutex_id, "failed to renew lease on lock, scheduling retry: {}", e);
218 continue;
219 },
220 }
221 }
222 _ = &mut panic_timeout => {
223 panic!("failed to renew mutex before lease expiration: {}", mutex.key);
224 }
225 }
226 }
227 match mutex.drop_lock().await {
228 Ok(false) => {
229 warn!(key = %mutex.key, mutex_id = %mutex.mutex_id, "lock already had different owner while attempting to drop");
230 }
231 Err(e) => {
232 error!(key = %mutex.key, mutex_id = %mutex.mutex_id, "failed to drop lock: {}", e);
233 }
234 _ => {
235 trace!(key = %mutex.key, mutex_id = %mutex.mutex_id, "successfully dropped lock");
236 }
237 }
238 });
239 RedisGuard {
240 mutex,
241 loaded: AtomicBool::new(false),
242 drop_tx: Some(drop_tx),
243 data: RwLock::new(None),
244 _pd: Default::default(),
245 }
246 }
247}
248
249impl<'a, T> Drop for RedisGuard<'a, T> {
250 fn drop(&mut self) {
251 if let Some(tx) = take(&mut self.drop_tx) {
252 let _ = tx.send(());
253 trace!(key = %self.mutex.key, mutex_id = %self.mutex.mutex_id, "guard dropped");
254 }
255 }
256}
257
/// Derives the Redis key under which a lock's associated value is stored: the lock key
/// followed by the suffix `_data`.
fn format_data_key(key: &str) -> String {
    [key, "_data"].concat()
}
261
/// Zero-sized marker type used as the `D` associated type of the `Guard` impl for
/// [`RedisGuard`]; its `DerefLt` impl names the type returned when the value is loaded.
pub struct RedisDerefCtor<T>(PhantomData<T>);
263
/// Maps the marker type to the borrowed view of the cached value: for lifetime `'a` it is
/// a read guard over `Option<T>`.
impl<'a, T> DerefLt<'a, T> for RedisDerefCtor<T>
where
    T: Send + Sync + 'static,
{
    type Deref = RwLockReadGuard<'a, Option<T>>;
}
270
271#[async_trait]
272impl<'a, T> Guard<T> for RedisGuard<'a, T>
273where
274 T: FromRedisValue + ToRedisArgs + Send + Sync + 'static,
275{
276 type D = RedisDerefCtor<T>;
277 async fn store(&mut self, data: T) -> Result<()> {
278 let mut con = self.mutex.pool.get().await?;
279 con.set(format_data_key(&self.mutex.key), &data).await?;
280 let mut guard = self.data.write().await;
281 *guard = Some(data);
282 self.loaded.store(true, Ordering::Relaxed);
283 Ok(())
284 }
285 async fn load<'s>(&'s self) -> Result<RwLockReadGuard<'s, Option<T>>> {
286 if !self.loaded.load(std::sync::atomic::Ordering::Relaxed) {
287 let mut con = self.mutex.pool.get().await?;
288 let val: Option<T> = con.get(format_data_key(&self.mutex.key)).await?;
289 let mut guard = self.data.write().await;
290 *guard = val;
291 self.loaded.store(true, Ordering::Relaxed);
292 }
293 return Ok(self.data.read().await);
294 }
295 async fn clear(&mut self) -> Result<()> {
296 let mut con = self.mutex.pool.get().await?;
297 con.del(format_data_key(&self.mutex.key)).await?;
298 let mut guard = self.data.write().await;
299 *guard = None;
300 self.loaded.store(true, Ordering::Relaxed);
301 Ok(())
302 }
303}
304
305#[async_trait]
306impl<T, K> MutexProvider<T, K> for RedisMutexProvider
307where
308 T: FromRedisValue + ToRedisArgs + Send + Sync + 'static,
309 K: AsRef<str> + Send,
310{
311 type Mutex = RedisMutex;
312 async fn get(&self, key: K) -> Result<Self::Mutex>
313 where
314 K: 'async_trait,
315 {
316 let key = format!("amutex_{}_{}", self.provider_id, key.as_ref());
317 let mutex_id = thread_rng().gen::<u64>();
318 Ok(RedisMutex {
319 pool: self.pool.clone(),
320 key,
321 mutex_id,
322 })
323 }
324}
325
326impl ToRedisArgs for Empty {
327 fn write_redis_args<W>(&self, _out: &mut W)
328 where
329 W: ?Sized + RedisWrite,
330 {
331 }
332}
333
334impl FromRedisValue for Empty {
335 fn from_redis_value(_v: &Value) -> RedisResult<Self> {
336 return Ok(Empty);
337 }
338}
339
#[cfg(test)]
mod tests {
    use bb8_redis::{bb8::Pool, RedisConnectionManager};
    use testcontainers::{clients::Cli, images::generic::GenericImage};

    use crate::spec::{check_empty, check_val};

    use super::RedisMutexProvider;

    /// Runs the shared `check_empty` and `check_val` specs against a provider backed by a
    /// Redis 7.0 container started through testcontainers.
    #[tokio::test]
    async fn test() {
        let cli = Cli::default();
        // Port Redis listens on inside the container; the mapped host port is looked up below.
        let port = 6379;
        // NOTE(review): no readiness condition is configured on the image here — confirm
        // the container is accepting connections by the time the pool is built.
        let container = cli.run(GenericImage::new("redis", "7.0").with_exposed_port(port));
        let host_port = container.get_host_port_ipv4(port);
        let uri = format!("redis://localhost:{host_port}");
        let redis_connection_manager = RedisConnectionManager::new(uri.as_str()).unwrap();
        let pool = Pool::builder()
            .build(redis_connection_manager)
            .await
            .unwrap();
        // Each spec gets its own provider id, so the two runs use different key prefixes.
        check_empty(RedisMutexProvider::new("testing".to_string(), pool.clone())).await;
        check_val(RedisMutexProvider::new("testing_vals".to_string(), pool)).await;
    }
}