// redlock_async/redlock.rs

1use std::fs::File;
2use std::io::{self, Read};
3use std::time::{Duration, Instant};
4
5use futures::future::join_all;
6use rand::{thread_rng, Rng};
7use redis::Value::Okay;
8use redis::{Client, IntoConnectionInfo, RedisResult, Value};
9
/// How many acquisition attempts `RedLock::lock` makes unless reconfigured.
const DEFAULT_RETRY_COUNT: u32 = 3;
/// Default exclusive upper bound, in milliseconds, of the random pause after a
/// failed acquisition attempt.
const DEFAULT_RETRY_DELAY: u32 = 200;
/// Fraction of the requested TTL set aside for clock drift between instances.
const CLOCK_DRIFT_FACTOR: f32 = 0.01;
/// Lua script run on each instance to release a lock: deletes `KEYS[1]` only
/// when it still holds `ARGV[1]` (so a client never removes someone else's
/// lock) and returns `0` otherwise.
const UNLOCK_SCRIPT: &str = r"if redis.call('get',KEYS[1]) == ARGV[1] then
                                return redis.call('del',KEYS[1])
                              else
                                return 0
                              end";
18
19#[derive(Debug)]
20pub enum RedLockError {
21    Io(io::Error),
22    Redis(redis::RedisError),
23    Unavailable,
24}
25
/// The lock manager.
///
/// Implements the necessary functionality to acquire and release locks
/// and handles the Redis connections.
#[derive(Debug, Clone)]
pub struct RedLock {
    /// List of all Redis clients
    pub servers: Vec<Client>,
    /// Minimum number of instances that must accept a lock (`N/2 + 1`, computed in `new`).
    quorum: u32,
    /// Number of acquisition attempts `lock` makes before giving up.
    retry_count: u32,
    /// Exclusive upper bound, in milliseconds, of the random pause between attempts.
    retry_delay: u32,
}
38
39pub struct Lock<'a> {
40    /// The resource to lock. Will be used as the key in Redis.
41    pub resource: Vec<u8>,
42    /// The value for this lock.
43    pub val: Vec<u8>,
44    /// Time the lock is still valid.
45    /// Should only be slightly smaller than the requested TTL.
46    pub validity_time: usize,
47    /// Used to limit the lifetime of a lock to its lock manager.
48    pub lock_manager: &'a RedLock,
49}
50
51pub struct RedLockGuard<'a> {
52    pub lock: Lock<'a>,
53}
54
/// Releases the guarded lock when the guard goes out of scope.
///
/// `RedLock::unlock` is async, so it is driven to completion here with
/// `futures::executor::block_on`, blocking the current thread until every
/// instance has been contacted. Unlock failures are not reported.
impl Drop for RedLockGuard<'_> {
    fn drop(&mut self) {
        // NOTE(review): `block_on` parks the calling thread. If the guard is dropped on a
        // single-threaded async runtime whose reactor is also needed to drive the Redis
        // connection (e.g. a current-thread tokio runtime), this may deadlock; confirm
        // which runtime backs the redis async connection.
        futures::executor::block_on(self.lock.lock_manager.unlock(&self.lock));
    }
}
60
61impl RedLock {
62    /// Create a new lock manager instance, defined by the given Redis connection uris.
63    /// Quorum is defined to be N/2+1, with N being the number of given Redis instances.
64    ///
65    /// Sample URI: `"redis://127.0.0.1:6379"`
66    pub fn new<T: AsRef<str> + IntoConnectionInfo>(uris: Vec<T>) -> RedLock {
67        let quorum = (uris.len() as u32) / 2 + 1;
68
69        let servers: Vec<Client> = uris
70            .into_iter()
71            .map(|uri| Client::open(uri).unwrap())
72            .collect();
73
74        RedLock {
75            servers,
76            quorum,
77            retry_count: DEFAULT_RETRY_COUNT,
78            retry_delay: DEFAULT_RETRY_DELAY,
79        }
80    }
81
82    /// Get 20 random bytes from `/dev/urandom`.
83    pub fn get_unique_lock_id(&self) -> io::Result<Vec<u8>> {
84        let file = File::open("/dev/urandom")?;
85        let mut buf = Vec::with_capacity(20);
86        match file.take(20).read_to_end(&mut buf) {
87            Ok(20) => Ok(buf),
88            Ok(_) => Err(io::Error::new(
89                io::ErrorKind::Other,
90                "Can't read enough random bytes",
91            )),
92            Err(e) => Err(e),
93        }
94    }
95
96    /// Set retry count and retry delay.
97    ///
98    /// Retry count defaults to `3`.
99    /// Retry delay defaults to `200`.
100    pub fn set_retry(&mut self, count: u32, delay: u32) {
101        self.retry_count = count;
102        self.retry_delay = delay;
103    }
104
105    async fn lock_instance(
106        &self,
107        client: &redis::Client,
108        resource: &[u8],
109        val: &[u8],
110        ttl: usize,
111    ) -> bool {
112        let mut con = match client.get_async_connection().await {
113            Err(_) => return false,
114            Ok(val) => val,
115        };
116        let result: RedisResult<Value> = redis::cmd("SET")
117            .arg(resource)
118            .arg(val)
119            .arg("nx")
120            .arg("px")
121            .arg(ttl)
122            .query_async(&mut con)
123            .await;
124
125        match result {
126            Ok(Okay) => true,
127            Ok(_) | Err(_) => false,
128        }
129    }
130
131    async fn unlock_instance(&self, client: &redis::Client, resource: &[u8], val: &[u8]) -> bool {
132        let mut con = match client.get_async_connection().await {
133            Err(_) => return false,
134            Ok(val) => val,
135        };
136        let script = redis::Script::new(UNLOCK_SCRIPT);
137        let result: RedisResult<i32> = script.key(resource).arg(val).invoke_async(&mut con).await;
138        match result {
139            Ok(val) => val == 1,
140            Err(_) => false,
141        }
142    }
143
144    /// Unlock the given lock.
145    ///
146    /// Unlock is best effort. It will simply try to contact all instances
147    /// and remove the key.
148    pub async fn unlock(&self, lock: &Lock<'_>) {
149        join_all(
150            self.servers
151                .iter()
152                .map(|client| self.unlock_instance(client, &lock.resource, &lock.val)),
153        )
154        .await;
155    }
156
157    /// Acquire the lock for the given resource and the requested TTL.
158    ///
159    /// If it succeeds, a `Lock` instance is returned,
160    /// including the value and the validity time
161    ///
162    /// If it fails. `None` is returned.
163    /// A user should retry after a short wait time.
164    pub async fn lock(&self, resource: &[u8], ttl: usize) -> Result<Lock<'_>, RedLockError> {
165        let val = self.get_unique_lock_id().unwrap();
166
167        for _ in 0..self.retry_count {
168            let start_time = Instant::now();
169            let n = join_all(
170                self.servers
171                    .iter()
172                    .map(|client| self.lock_instance(client, resource, &val, ttl)),
173            )
174            .await
175            .into_iter()
176            .fold(0, |count, locked| if locked { count + 1 } else { count });
177
178            let drift = (ttl as f32 * CLOCK_DRIFT_FACTOR) as usize + 2;
179            let elapsed = start_time.elapsed();
180            let validity_time = ttl
181                - drift
182                - elapsed.as_secs() as usize * 1000
183                - elapsed.subsec_nanos() as usize / 1_000_000;
184
185            if n >= self.quorum && validity_time > 0 {
186                return Ok(Lock {
187                    lock_manager: self,
188                    resource: resource.to_vec(),
189                    val,
190                    validity_time,
191                });
192            } else {
193                join_all(
194                    self.servers
195                        .iter()
196                        .map(|client| self.unlock_instance(client, resource, &val)),
197                )
198                .await;
199            }
200
201            let n = thread_rng().gen_range(0..self.retry_delay);
202            tokio::time::sleep(Duration::from_millis(u64::from(n))).await
203        }
204
205        Err(RedLockError::Unavailable)
206    }
207
208    pub async fn acquire(&self, resource: &[u8], ttl: usize) -> RedLockGuard<'_> {
209        loop {
210            if let Ok(lock) = self.lock(resource, ttl).await {
211                return RedLockGuard { lock };
212            }
213        }
214    }
215}
216
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use once_cell::sync::Lazy;
    use testcontainers::clients::Cli;
    use testcontainers::images::redis::Redis;
    use testcontainers::{Container, Docker};

    use super::*;

    /// Docker client shared by all tests; the containers below borrow from it.
    static DOCKER: Lazy<Cli> = Lazy::new(Cli::default);
    /// Three Redis containers, started on first access.
    static CONTAINERS: Lazy<Vec<Container<Cli, Redis>>> = Lazy::new(|| {
        (0..3)
            .map(|_| DOCKER.run(Redis::default().with_tag("6-alpine")))
            .collect()
    });
    /// Redis URIs under test: the comma-separated `ADDRESSES` environment variable
    /// if set, otherwise the locally started containers.
    static ADDRESSES: Lazy<Vec<String>> = Lazy::new(|| match std::env::var("ADDRESSES") {
        Ok(addresses) => addresses.split(',').map(String::from).collect(),
        Err(_) => CONTAINERS
            .iter()
            .map(|c| format!("redis://localhost:{}", c.get_host_port(6379).unwrap()))
            .collect(),
    });

    #[test]
    fn test_redlock_get_unique_id() -> Result<()> {
        let rl = RedLock::new(Vec::<String>::new());
        assert_eq!(rl.get_unique_lock_id()?.len(), 20);
        Ok(())
    }

    #[test]
    fn test_redlock_get_unique_id_uniqueness() -> Result<()> {
        let rl = RedLock::new(Vec::<String>::new());

        let id1 = rl.get_unique_lock_id()?;
        let id2 = rl.get_unique_lock_id()?;

        assert_eq!(20, id1.len());
        assert_eq!(20, id2.len());
        assert_ne!(id1, id2);
        Ok(())
    }

    #[test]
    fn test_redlock_valid_instance() {
        println!("{}", ADDRESSES.join(","));
        let rl = RedLock::new(ADDRESSES.clone());
        assert_eq!(3, rl.servers.len());
        assert_eq!(2, rl.quorum);
    }

    #[tokio::test]
    async fn test_redlock_direct_unlock_fails() -> Result<()> {
        println!("{}", ADDRESSES.join(","));
        let rl = RedLock::new(ADDRESSES.clone());
        let key = rl.get_unique_lock_id()?;

        let val = rl.get_unique_lock_id()?;
        assert!(!rl.unlock_instance(&rl.servers[0], &key, &val).await);
        Ok(())
    }

    #[tokio::test]
    async fn test_redlock_direct_unlock_succeeds() -> Result<()> {
        println!("{}", ADDRESSES.join(","));
        let rl = RedLock::new(ADDRESSES.clone());
        let key = rl.get_unique_lock_id()?;

        let val = rl.get_unique_lock_id()?;
        let mut con = rl.servers[0].get_connection()?;
        redis::cmd("SET").arg(&*key).arg(&*val).execute(&mut con);

        assert!(rl.unlock_instance(&rl.servers[0], &key, &val).await);
        Ok(())
    }

    #[tokio::test]
    async fn test_redlock_direct_lock_succeeds() -> Result<()> {
        println!("{}", ADDRESSES.join(","));
        let rl = RedLock::new(ADDRESSES.clone());
        let key = rl.get_unique_lock_id()?;

        let val = rl.get_unique_lock_id()?;
        let mut con = rl.servers[0].get_connection()?;

        redis::cmd("DEL").arg(&*key).execute(&mut con);
        assert!(rl.lock_instance(&rl.servers[0], &*key, &*val, 1000).await);
        Ok(())
    }

    #[tokio::test]
    async fn test_redlock_unlock() -> Result<()> {
        println!("{}", ADDRESSES.join(","));
        let rl = RedLock::new(ADDRESSES.clone());
        let key = rl.get_unique_lock_id()?;

        let val = rl.get_unique_lock_id()?;
        let mut con = rl.servers[0].get_connection()?;
        let _: () = redis::cmd("SET").arg(&*key).arg(&*val).query(&mut con)?;

        let lock = Lock {
            lock_manager: &rl,
            resource: key,
            val,
            validity_time: 0,
        };
        rl.unlock(&lock).await;

        // The key held the lock's value, so unlocking must have deleted it.
        let remaining: Option<Vec<u8>> = redis::cmd("GET").arg(&*lock.resource).query(&mut con)?;
        assert_eq!(None, remaining);
        Ok(())
    }

    #[tokio::test]
    async fn test_redlock_lock() -> Result<()> {
        println!("{}", ADDRESSES.join(","));
        let rl = RedLock::new(ADDRESSES.clone());

        let key = rl.get_unique_lock_id()?;
        match rl.lock(&key, 1000).await {
            Ok(lock) => {
                assert_eq!(key, lock.resource);
                assert_eq!(20, lock.val.len());
                assert!(
                    lock.validity_time > 900,
                    "validity time: {}",
                    lock.validity_time
                );
            }
            Err(_) => panic!("Lock failed"),
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_redlock_lock_unlock() -> Result<()> {
        println!("{}", ADDRESSES.join(","));
        let rl = RedLock::new(ADDRESSES.clone());
        let rl2 = RedLock::new(ADDRESSES.clone());

        let key = rl.get_unique_lock_id()?;

        let lock = rl.lock(&key, 1000).await.unwrap();
        assert!(
            lock.validity_time > 900,
            "validity time: {}",
            lock.validity_time
        );

        if let Ok(_l) = rl2.lock(&key, 1000).await {
            panic!("Lock acquired, even though it should be locked")
        }

        rl.unlock(&lock).await;

        match rl2.lock(&key, 1000).await {
            Ok(l) => assert!(l.validity_time > 900),
            Err(_) => panic!("Lock couldn't be acquired"),
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_redlock_lock_unlock_raii() -> Result<()> {
        println!("{}", ADDRESSES.join(","));
        let rl = RedLock::new(ADDRESSES.clone());
        let rl2 = RedLock::new(ADDRESSES.clone());

        let key = rl.get_unique_lock_id()?;
        async {
            let lock_guard = rl.acquire(&key, 1000).await;
            let lock = &lock_guard.lock;
            assert!(
                lock.validity_time > 900,
                "validity time: {}",
                lock.validity_time
            );

            if let Ok(_l) = rl2.lock(&key, 1000).await {
                panic!("Lock acquired, even though it should be locked")
            }
        }
        .await;

        match rl2.lock(&key, 1000).await {
            Ok(l) => assert!(l.validity_time > 900),
            Err(_) => panic!("Lock couldn't be acquired"),
        }
        Ok(())
    }
}
408        Ok(())
409    }
410}