1use std::fs::File;
2use std::io::{self, Read};
3use std::time::{Duration, Instant};
4
5use futures::future::join_all;
6use rand::{thread_rng, Rng};
7use redis::Value::Okay;
8use redis::{Client, IntoConnectionInfo, RedisResult, Value};
9
/// Default number of acquisition attempts made by `RedLock::lock`
/// (overridable with `RedLock::set_retry`).
const DEFAULT_RETRY_COUNT: u32 = 3;
/// Default upper bound, in milliseconds, of the random pause between
/// acquisition attempts.
const DEFAULT_RETRY_DELAY: u32 = 200;
/// Fraction of the TTL subtracted (together with a fixed 2 ms) from a lock's
/// validity time as an allowance for clock drift between servers.
const CLOCK_DRIFT_FACTOR: f32 = 0.01;
/// Lua script that deletes `KEYS[1]` only if its current value equals
/// `ARGV[1]`, so a client can only release a lock that still holds its own
/// value. Returns the result of `DEL` on a match, otherwise 0.
const UNLOCK_SCRIPT: &str = r"if redis.call('get',KEYS[1]) == ARGV[1] then
 return redis.call('del',KEYS[1])
 else
 return 0
 end";
18
19#[derive(Debug)]
20pub enum RedLockError {
21 Io(io::Error),
22 Redis(redis::RedisError),
23 Unavailable,
24}
25
/// Lock manager that acquires a lock on a majority (quorum) of independent
/// Redis servers.
#[derive(Debug, Clone)]
pub struct RedLock {
    /// One client per server, in the order the URIs were passed to `new`.
    pub servers: Vec<Client>,
    /// Number of servers that must grant a lock; `new` sets it to
    /// `servers.len() / 2 + 1`.
    quorum: u32,
    /// Number of acquisition attempts `lock` makes before giving up.
    retry_count: u32,
    /// Upper bound, in milliseconds, of the random pause between attempts.
    retry_delay: u32,
}
38
39pub struct Lock<'a> {
40 pub resource: Vec<u8>,
42 pub val: Vec<u8>,
44 pub validity_time: usize,
47 pub lock_manager: &'a RedLock,
49}
50
51pub struct RedLockGuard<'a> {
52 pub lock: Lock<'a>,
53}
54
55impl Drop for RedLockGuard<'_> {
56 fn drop(&mut self) {
57 futures::executor::block_on(self.lock.lock_manager.unlock(&self.lock));
58 }
59}
60
61impl RedLock {
62 pub fn new<T: AsRef<str> + IntoConnectionInfo>(uris: Vec<T>) -> RedLock {
67 let quorum = (uris.len() as u32) / 2 + 1;
68
69 let servers: Vec<Client> = uris
70 .into_iter()
71 .map(|uri| Client::open(uri).unwrap())
72 .collect();
73
74 RedLock {
75 servers,
76 quorum,
77 retry_count: DEFAULT_RETRY_COUNT,
78 retry_delay: DEFAULT_RETRY_DELAY,
79 }
80 }
81
82 pub fn get_unique_lock_id(&self) -> io::Result<Vec<u8>> {
84 let file = File::open("/dev/urandom")?;
85 let mut buf = Vec::with_capacity(20);
86 match file.take(20).read_to_end(&mut buf) {
87 Ok(20) => Ok(buf),
88 Ok(_) => Err(io::Error::new(
89 io::ErrorKind::Other,
90 "Can't read enough random bytes",
91 )),
92 Err(e) => Err(e),
93 }
94 }
95
96 pub fn set_retry(&mut self, count: u32, delay: u32) {
101 self.retry_count = count;
102 self.retry_delay = delay;
103 }
104
105 async fn lock_instance(
106 &self,
107 client: &redis::Client,
108 resource: &[u8],
109 val: &[u8],
110 ttl: usize,
111 ) -> bool {
112 let mut con = match client.get_async_connection().await {
113 Err(_) => return false,
114 Ok(val) => val,
115 };
116 let result: RedisResult<Value> = redis::cmd("SET")
117 .arg(resource)
118 .arg(val)
119 .arg("nx")
120 .arg("px")
121 .arg(ttl)
122 .query_async(&mut con)
123 .await;
124
125 match result {
126 Ok(Okay) => true,
127 Ok(_) | Err(_) => false,
128 }
129 }
130
131 async fn unlock_instance(&self, client: &redis::Client, resource: &[u8], val: &[u8]) -> bool {
132 let mut con = match client.get_async_connection().await {
133 Err(_) => return false,
134 Ok(val) => val,
135 };
136 let script = redis::Script::new(UNLOCK_SCRIPT);
137 let result: RedisResult<i32> = script.key(resource).arg(val).invoke_async(&mut con).await;
138 match result {
139 Ok(val) => val == 1,
140 Err(_) => false,
141 }
142 }
143
144 pub async fn unlock(&self, lock: &Lock<'_>) {
149 join_all(
150 self.servers
151 .iter()
152 .map(|client| self.unlock_instance(client, &lock.resource, &lock.val)),
153 )
154 .await;
155 }
156
157 pub async fn lock(&self, resource: &[u8], ttl: usize) -> Result<Lock<'_>, RedLockError> {
165 let val = self.get_unique_lock_id().unwrap();
166
167 for _ in 0..self.retry_count {
168 let start_time = Instant::now();
169 let n = join_all(
170 self.servers
171 .iter()
172 .map(|client| self.lock_instance(client, resource, &val, ttl)),
173 )
174 .await
175 .into_iter()
176 .fold(0, |count, locked| if locked { count + 1 } else { count });
177
178 let drift = (ttl as f32 * CLOCK_DRIFT_FACTOR) as usize + 2;
179 let elapsed = start_time.elapsed();
180 let validity_time = ttl
181 - drift
182 - elapsed.as_secs() as usize * 1000
183 - elapsed.subsec_nanos() as usize / 1_000_000;
184
185 if n >= self.quorum && validity_time > 0 {
186 return Ok(Lock {
187 lock_manager: self,
188 resource: resource.to_vec(),
189 val,
190 validity_time,
191 });
192 } else {
193 join_all(
194 self.servers
195 .iter()
196 .map(|client| self.unlock_instance(client, resource, &val)),
197 )
198 .await;
199 }
200
201 let n = thread_rng().gen_range(0..self.retry_delay);
202 tokio::time::sleep(Duration::from_millis(u64::from(n))).await
203 }
204
205 Err(RedLockError::Unavailable)
206 }
207
208 pub async fn acquire(&self, resource: &[u8], ttl: usize) -> RedLockGuard<'_> {
209 loop {
210 if let Ok(lock) = self.lock(resource, ttl).await {
211 return RedLockGuard { lock };
212 }
213 }
214 }
215}
216
217#[cfg(test)]
218mod tests {
219 use anyhow::Result;
220 use once_cell::sync::Lazy;
221 use testcontainers::clients::Cli;
222 use testcontainers::images::redis::Redis;
223 use testcontainers::{Container, Docker};
224
225 use super::*;
226
227 static DOCKER: Lazy<Cli> = Lazy::new(Cli::default);
228 static CONTAINERS: Lazy<Vec<Container<Cli, Redis>>> = Lazy::new(|| {
229 (0..3)
230 .map(|_| DOCKER.run(Redis::default().with_tag("6-alpine")))
231 .collect()
232 });
233 static ADDRESSES: Lazy<Vec<String>> = Lazy::new(|| match std::env::var("ADDRESSES") {
234 Ok(addresses) => addresses.split(',').map(String::from).collect(),
235 Err(_) => CONTAINERS
236 .iter()
237 .map(|c| format!("redis://localhost:{}", c.get_host_port(6379).unwrap()))
238 .collect(),
239 });
240
241 #[test]
242 fn test_redlock_get_unique_id() -> Result<()> {
243 let rl = RedLock::new(Vec::<String>::new());
244 assert_eq!(rl.get_unique_lock_id()?.len(), 20);
245 Ok(())
246 }
247
248 #[test]
249 fn test_redlock_get_unique_id_uniqueness() -> Result<()> {
250 let rl = RedLock::new(Vec::<String>::new());
251
252 let id1 = rl.get_unique_lock_id()?;
253 let id2 = rl.get_unique_lock_id()?;
254
255 assert_eq!(20, id1.len());
256 assert_eq!(20, id2.len());
257 assert_ne!(id1, id2);
258 Ok(())
259 }
260
261 #[test]
262 fn test_redlock_valid_instance() {
263 println!("{}", ADDRESSES.join(","));
264 let rl = RedLock::new(ADDRESSES.clone());
265 assert_eq!(3, rl.servers.len());
266 assert_eq!(2, rl.quorum);
267 }
268
269 #[tokio::test]
270 async fn test_redlock_direct_unlock_fails() -> Result<()> {
271 println!("{}", ADDRESSES.join(","));
272 let rl = RedLock::new(ADDRESSES.clone());
273 let key = rl.get_unique_lock_id()?;
274
275 let val = rl.get_unique_lock_id()?;
276 assert!(!rl.unlock_instance(&rl.servers[0], &key, &val).await);
277 Ok(())
278 }
279
280 #[tokio::test]
281 async fn test_redlock_direct_unlock_succeeds() -> Result<()> {
282 println!("{}", ADDRESSES.join(","));
283 let rl = RedLock::new(ADDRESSES.clone());
284 let key = rl.get_unique_lock_id()?;
285
286 let val = rl.get_unique_lock_id()?;
287 let mut con = rl.servers[0].get_connection()?;
288 redis::cmd("SET").arg(&*key).arg(&*val).execute(&mut con);
289
290 assert!(rl.unlock_instance(&rl.servers[0], &key, &val).await);
291 Ok(())
292 }
293
294 #[tokio::test]
295 async fn test_redlock_direct_lock_succeeds() -> Result<()> {
296 println!("{}", ADDRESSES.join(","));
297 let rl = RedLock::new(ADDRESSES.clone());
298 let key = rl.get_unique_lock_id()?;
299
300 let val = rl.get_unique_lock_id()?;
301 let mut con = rl.servers[0].get_connection()?;
302
303 redis::cmd("DEL").arg(&*key).execute(&mut con);
304 assert!(rl.lock_instance(&rl.servers[0], &*key, &*val, 1000).await);
305 Ok(())
306 }
307
308 #[tokio::test]
309 async fn test_redlock_unlock() -> Result<()> {
310 println!("{}", ADDRESSES.join(","));
311 let rl = RedLock::new(ADDRESSES.clone());
312 let key = rl.get_unique_lock_id()?;
313
314 let val = rl.get_unique_lock_id()?;
315 let mut con = rl.servers[0].get_connection()?;
316 let _: () = redis::cmd("SET")
317 .arg(&*key)
318 .arg(&*val)
319 .query(&mut con)
320 .unwrap();
321
322 let lock = Lock {
323 lock_manager: &rl,
324 resource: key,
325 val,
326 validity_time: 0,
327 };
328 rl.unlock(&lock).await;
329 Ok(())
330 }
331
332 #[tokio::test]
333 async fn test_redlock_lock() -> Result<()> {
334 println!("{}", ADDRESSES.join(","));
335 let rl = RedLock::new(ADDRESSES.clone());
336
337 let key = rl.get_unique_lock_id()?;
338 match rl.lock(&key, 1000).await {
339 Ok(lock) => {
340 assert_eq!(key, lock.resource);
341 assert_eq!(20, lock.val.len());
342 assert!(lock.validity_time > 900);
343 assert!(
344 lock.validity_time > 900,
345 "validity time: {}",
346 lock.validity_time
347 );
348 }
349 Err(_) => panic!("Lock failed"),
350 }
351 Ok(())
352 }
353
354 #[tokio::test]
355 async fn test_redlock_lock_unlock() -> Result<()> {
356 println!("{}", ADDRESSES.join(","));
357 let rl = RedLock::new(ADDRESSES.clone());
358 let rl2 = RedLock::new(ADDRESSES.clone());
359
360 let key = rl.get_unique_lock_id()?;
361
362 let lock = rl.lock(&key, 1000).await.unwrap();
363 assert!(
364 lock.validity_time > 900,
365 "validity time: {}",
366 lock.validity_time
367 );
368
369 if let Ok(_l) = rl2.lock(&key, 1000).await {
370 panic!("Lock acquired, even though it should be locked")
371 }
372
373 rl.unlock(&lock).await;
374
375 match rl2.lock(&key, 1000).await {
376 Ok(l) => assert!(l.validity_time > 900),
377 Err(_) => panic!("Lock couldn't be acquired"),
378 }
379 Ok(())
380 }
381
382 #[tokio::test]
383 async fn test_redlock_lock_unlock_raii() -> Result<()> {
384 println!("{}", ADDRESSES.join(","));
385 let rl = RedLock::new(ADDRESSES.clone());
386 let rl2 = RedLock::new(ADDRESSES.clone());
387
388 let key = rl.get_unique_lock_id()?;
389 async {
390 let lock_guard = rl.acquire(&key, 1000).await;
391 let lock = &lock_guard.lock;
392 assert!(
393 lock.validity_time > 900,
394 "validity time: {}",
395 lock.validity_time
396 );
397
398 if let Ok(_l) = rl2.lock(&key, 1000).await {
399 panic!("Lock acquired, even though it should be locked")
400 }
401 }
402 .await;
403
404 match rl2.lock(&key, 1000).await {
405 Ok(l) => assert!(l.validity_time > 900),
406 Err(_) => panic!("Lock couldn't be acquired"),
407 }
408 Ok(())
409 }
410}