a_mutex/
redis.rs

1use rand::Rng;
2use std::marker::PhantomData;
3use std::mem::take;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
6
7use crate::{DerefLt, Empty, Guard};
8
9use super::{GuardLt, MutexProvider, Result};
10use async_trait::async_trait;
11use bb8_redis::redis::Script;
12use bb8_redis::{bb8::Pool, RedisConnectionManager};
13use rand::thread_rng;
14use redis::AsyncCommands;
15pub use redis::{FromRedisValue, RedisResult, RedisWrite, ToRedisArgs, Value};
16use tokio::sync::oneshot::Sender;
17use tokio::sync::{RwLock, RwLockReadGuard};
18use tracing::{error, trace, warn};
19
/// Lifetime of a lock lease in milliseconds; the lease lapses unless it is
/// renewed within this window.
const LOCK_LEASE_TIMEOUT_MILLIS: u64 = 10_000;
/// Period, in milliseconds, between attempts to renew a held lease.
const LOCK_REFRESH_INTERVAL_MILLIS: u64 = 1_000;
/// Delay, in milliseconds, between successive attempts to acquire the lock.
const LOCK_POLL_INTERVAL_MILLIS: u64 = 100;
/// Safety margin, in milliseconds, ahead of lease expiration: if the lease has
/// not been renewed by the time only this much of it remains, the guard's
/// refresher task panics.
const RENEWAL_PANIC_BUFFER_MILLIS: u64 = 1_000;
29
30#[derive(Debug, Clone)]
31pub struct RedisMutexProvider {
32    pool: Pool<RedisConnectionManager>,
33    provider_id: String,
34}
35
36impl RedisMutexProvider {
37    pub fn new(provider_id: String, pool: Pool<RedisConnectionManager>) -> RedisMutexProvider {
38        RedisMutexProvider { pool, provider_id }
39    }
40}
41
42#[derive(Clone, Debug)]
43pub struct RedisMutex {
44    pool: Pool<RedisConnectionManager>,
45    key: String,
46    mutex_id: u64,
47}
48
// Sets the lock key only if it does not exist yet (`NX`) and gives it an
// absolute expiration (`PXAT`). Returns 1 when the key was set, 0 otherwise.
//
// KEYS[1] = lock key
// ARGV[1] = mutex id
// ARGV[2] = lock expiration time in unix millis (an absolute timestamp, as
//           required by `PXAT`; it is NOT a relative timeout)
const ACQUIRE_LOCK_SCRIPT: &str = "\
  local got_lock = redis.call('SET', KEYS[1], ARGV[1], 'NX', 'PXAT', ARGV[2])
  if got_lock then
      return 1
  end
  return 0
";

// Moves the expiration of the lock key forward, but only while the key still
// holds this mutex id. Returns 1 when the lease was extended, 0 otherwise.
//
// KEYS[1] = lock key
// ARGV[1] = mutex id
// ARGV[2] = new lock expiration time in unix millis
const RENEW_LOCK_SCRIPT: &str = "\
  if redis.call('GET', KEYS[1]) == ARGV[1] then
      redis.call('PEXPIREAT', KEYS[1],  ARGV[2])
      return 1
  end
  return 0
";

// Deletes the lock key, but only while it still holds this mutex id.
// Returns 1 when the key was deleted, 0 otherwise.
//
// KEYS[1] = lock key
// ARGV[1] = mutex id
const DROP_LOCK_SCRIPT: &str = "\
  if redis.call('GET', KEYS[1]) == ARGV[1] then
      redis.call('DEL', KEYS[1])
      return 1
  end
  return 0
";
80
81impl RedisMutex {
82    /// Attempts to acquire the lock. If successful, returns the lock lease
83    /// expiration time as a unix timestamp in milliseconds. If the lock
84    /// is already locked, returns None so that clients know to retry.
85    async fn try_acquire_lock(&self) -> Result<Option<Duration>> {
86        let exp = SystemTime::now().duration_since(UNIX_EPOCH).unwrap()
87            + Duration::from_millis(LOCK_LEASE_TIMEOUT_MILLIS);
88        Ok(
89            if Script::new(ACQUIRE_LOCK_SCRIPT)
90                .key(self.key.as_str())
91                .arg(self.mutex_id)
92                .arg(exp.as_millis() as i64)
93                .invoke_async::<_, i32>(&mut *self.pool.get().await?)
94                .await?
95                == 1
96            {
97                Some(exp)
98            } else {
99                None
100            },
101        )
102    }
103
104    /// Attempts to renew the the lock lease. If successful, returns the new
105    /// lock lease expiration time as a unix timestamp in milliseconds. If the lock
106    /// is no longer held by this instance, returns None (clients should generally
107    /// panic in this case)
108    async fn try_renew_lock(&self) -> Result<Option<Duration>> {
109        let new_exp = SystemTime::now().duration_since(UNIX_EPOCH).unwrap()
110            + Duration::from_millis(LOCK_LEASE_TIMEOUT_MILLIS);
111        Ok(
112            if Script::new(RENEW_LOCK_SCRIPT)
113                .key(self.key.as_str())
114                .arg(self.mutex_id)
115                .arg(new_exp.as_millis() as i64)
116                .invoke_async::<_, i32>(&mut *self.pool.get().await?)
117                .await?
118                == 1
119            {
120                Some(new_exp)
121            } else {
122                None
123            },
124        )
125    }
126
127    /// Attempts to drop the lock. Returns true if the lock was held by this
128    /// owner and was dropped. Returns false if the lock was not held by
129    /// this owner and nothing happened.
130    async fn drop_lock(&self) -> Result<bool> {
131        Ok(Script::new(DROP_LOCK_SCRIPT)
132            .key(self.key.as_str())
133            .arg(self.mutex_id)
134            .invoke_async::<_, i32>(&mut *self.pool.get().await?)
135            .await?
136            == 1)
137    }
138}
139
140#[async_trait]
141impl<T> super::Mutex<T> for RedisMutex
142where
143    T: Send + FromRedisValue + ToRedisArgs + Sync + 'static,
144{
145    type Guard = RedisGuardCtor<T>;
146    async fn lock(&self) -> Result<RedisGuard<'_, T>> {
147        let mut interval =
148            tokio::time::interval(core::time::Duration::from_millis(LOCK_POLL_INTERVAL_MILLIS));
149        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
150        let expires_at;
151        loop {
152            tokio::select! {
153                _ = interval.tick() => {
154                    if let Some(exp) = self.try_acquire_lock().await? {
155                        expires_at = exp;
156                        break;
157                    }
158                }
159            }
160        }
161        Ok(RedisGuard::new(&self, expires_at))
162    }
163}
164
165pub struct RedisGuardCtor<T>(PhantomData<T>);
166
167impl<'a, T> GuardLt<'a, T> for RedisGuardCtor<T>
168where
169    T: FromRedisValue + ToRedisArgs + Send + Sync + 'static,
170{
171    type Guard = RedisGuard<'a, T>;
172}
173
174pub struct RedisGuard<'a, T> {
175    mutex: &'a RedisMutex,
176    drop_tx: Option<Sender<()>>,
177    loaded: AtomicBool,
178    data: RwLock<Option<T>>,
179    _pd: PhantomData<T>,
180}
181
182impl<'a, T> RedisGuard<'a, T> {
183    fn new(mutex: &'a RedisMutex, exp_at: Duration) -> RedisGuard<'a, T> {
184        trace!(key = %mutex.key, mutex_id = %mutex.mutex_id, expires_at = ?exp_at, "acquired lock");
185        let (drop_tx, mut drop_rx) = tokio::sync::oneshot::channel();
186        let mutex_clone = mutex.clone();
187        let _ = tokio::spawn(async move {
188            let mutex = mutex_clone;
189            let mut renewal_interval = tokio::time::interval(core::time::Duration::from_millis(
190                LOCK_REFRESH_INTERVAL_MILLIS,
191            ));
192            renewal_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
193
194            let panic_timeout = tokio::time::sleep(
195                exp_at
196                    - SystemTime::now().duration_since(UNIX_EPOCH).unwrap()
197                    - Duration::from_millis(RENEWAL_PANIC_BUFFER_MILLIS),
198            );
199            tokio::pin!(panic_timeout);
200            loop {
201                tokio::select! {
202                    _ = &mut drop_rx => {
203                        break;
204                    }
205                    _ = renewal_interval.tick() => {
206                        match mutex.try_renew_lock().await {
207                            Ok(Some(new_exp)) => {
208                                trace!(key = %mutex.key, mutex_id = %mutex.mutex_id, expires_at = ?new_exp, "renewed lock lease");
209                                panic_timeout.as_mut().reset(tokio::time::Instant::from_std(Instant::now() + new_exp
210                                    - SystemTime::now().duration_since(UNIX_EPOCH).unwrap()
211                                    - Duration::from_millis(RENEWAL_PANIC_BUFFER_MILLIS)));
212                            },
213                            Ok(None) => {
214                              panic!("failed to renew mutex because it had a different owner: {}", mutex.key);
215                            },
216                            Err(e) => {
217                                error!(key = %mutex.key, mutex_id = %mutex.mutex_id, "failed to renew lease on lock, scheduling retry: {}", e);
218                                continue;
219                            },
220                        }
221                    }
222                    _ = &mut panic_timeout => {
223                        panic!("failed to renew mutex before lease expiration: {}", mutex.key);
224                    }
225                }
226            }
227            match mutex.drop_lock().await {
228                Ok(false) => {
229                    warn!(key = %mutex.key, mutex_id = %mutex.mutex_id, "lock already had different owner while attempting to drop");
230                }
231                Err(e) => {
232                    error!(key = %mutex.key, mutex_id = %mutex.mutex_id, "failed to drop lock: {}", e);
233                }
234                _ => {
235                    trace!(key = %mutex.key, mutex_id = %mutex.mutex_id, "successfully dropped lock");
236                }
237            }
238        });
239        RedisGuard {
240            mutex,
241            loaded: AtomicBool::new(false),
242            drop_tx: Some(drop_tx),
243            data: RwLock::new(None),
244            _pd: Default::default(),
245        }
246    }
247}
248
249impl<'a, T> Drop for RedisGuard<'a, T> {
250    fn drop(&mut self) {
251        if let Some(tx) = take(&mut self.drop_tx) {
252            let _ = tx.send(());
253            trace!(key = %self.mutex.key, mutex_id = %self.mutex.mutex_id, "guard dropped");
254        }
255    }
256}
257
/// Derives the Redis key that stores a lock's associated data by appending
/// the `_data` suffix to the lock key.
fn format_data_key(key: &str) -> String {
    [key, "_data"].concat()
}
261
262pub struct RedisDerefCtor<T>(PhantomData<T>);
263
264impl<'a, T> DerefLt<'a, T> for RedisDerefCtor<T>
265where
266    T: Send + Sync + 'static,
267{
268    type Deref = RwLockReadGuard<'a, Option<T>>;
269}
270
271#[async_trait]
272impl<'a, T> Guard<T> for RedisGuard<'a, T>
273where
274    T: FromRedisValue + ToRedisArgs + Send + Sync + 'static,
275{
276    type D = RedisDerefCtor<T>;
277    async fn store(&mut self, data: T) -> Result<()> {
278        let mut con = self.mutex.pool.get().await?;
279        con.set(format_data_key(&self.mutex.key), &data).await?;
280        let mut guard = self.data.write().await;
281        *guard = Some(data);
282        self.loaded.store(true, Ordering::Relaxed);
283        Ok(())
284    }
285    async fn load<'s>(&'s self) -> Result<RwLockReadGuard<'s, Option<T>>> {
286        if !self.loaded.load(std::sync::atomic::Ordering::Relaxed) {
287            let mut con = self.mutex.pool.get().await?;
288            let val: Option<T> = con.get(format_data_key(&self.mutex.key)).await?;
289            let mut guard = self.data.write().await;
290            *guard = val;
291            self.loaded.store(true, Ordering::Relaxed);
292        }
293        return Ok(self.data.read().await);
294    }
295    async fn clear(&mut self) -> Result<()> {
296        let mut con = self.mutex.pool.get().await?;
297        con.del(format_data_key(&self.mutex.key)).await?;
298        let mut guard = self.data.write().await;
299        *guard = None;
300        self.loaded.store(true, Ordering::Relaxed);
301        Ok(())
302    }
303}
304
305#[async_trait]
306impl<T, K> MutexProvider<T, K> for RedisMutexProvider
307where
308    T: FromRedisValue + ToRedisArgs + Send + Sync + 'static,
309    K: AsRef<str> + Send,
310{
311    type Mutex = RedisMutex;
312    async fn get(&self, key: K) -> Result<Self::Mutex>
313    where
314        K: 'async_trait,
315    {
316        let key = format!("amutex_{}_{}", self.provider_id, key.as_ref());
317        let mutex_id = thread_rng().gen::<u64>();
318        Ok(RedisMutex {
319            pool: self.pool.clone(),
320            key,
321            mutex_id,
322        })
323    }
324}
325
326impl ToRedisArgs for Empty {
327    fn write_redis_args<W>(&self, _out: &mut W)
328    where
329        W: ?Sized + RedisWrite,
330    {
331    }
332}
333
334impl FromRedisValue for Empty {
335    fn from_redis_value(_v: &Value) -> RedisResult<Self> {
336        return Ok(Empty);
337    }
338}
339
#[cfg(test)]
mod tests {
    use bb8_redis::{bb8::Pool, RedisConnectionManager};
    use testcontainers::{clients::Cli, images::generic::GenericImage};

    use crate::spec::{check_empty, check_val};

    use super::RedisMutexProvider;

    /// Starts a Redis 7.0 container and runs the shared `check_empty` and
    /// `check_val` specs against providers backed by it.
    #[tokio::test]
    async fn test() {
        let port = 6379;
        let docker = Cli::default();
        let redis_image = GenericImage::new("redis", "7.0").with_exposed_port(port);
        // Bound to a name so the container handle lives until the end of the test.
        let container = docker.run(redis_image);
        let host_port = container.get_host_port_ipv4(port);
        let uri = format!("redis://localhost:{host_port}");
        let manager = RedisConnectionManager::new(uri.as_str()).unwrap();
        let pool = Pool::builder().build(manager).await.unwrap();

        let empty_provider = RedisMutexProvider::new("testing".to_string(), pool.clone());
        check_empty(empty_provider).await;

        let value_provider = RedisMutexProvider::new("testing_vals".to_string(), pool);
        check_val(value_provider).await;
    }
}