1use crate::redis::types::Generic;
2use serde::de::DeserializeOwned;
3use serde::Serialize;
4use std::ops::{Deref, DerefMut};
5use thiserror::Error;
6
7#[derive(Error, Debug)]
8pub enum LockError {
9 #[error("Locking failed")]
10 LockFailed,
11 #[error("Unlocking failed")]
12 UnlockFailed,
13 #[error("No connection to Redis available")]
14 NoConnection,
15 #[error("Lock expired with id #{0}")]
16 LockExpired(usize),
17 #[error("Error by Redis")]
18 Redis(#[from] redis::RedisError),
19}
20
/// Integer reply of the lock Lua scripts, decoded: `1` is success, `0` is failure.
#[derive(PartialEq, Debug)]
enum LockNum {
    /// The script replied `1`.
    Success,
    /// The script replied `0`.
    Fail,
}

impl From<i8> for LockNum {
    /// Decodes a script reply.
    ///
    /// # Panics
    ///
    /// Panics on any reply other than `0` or `1`.
    fn from(reply: i8) -> Self {
        if reply == 1 {
            LockNum::Success
        } else if reply == 0 {
            LockNum::Fail
        } else {
            panic!("Unexpected value")
        }
    }
}
36
/// Lua script that acquires the lock, or refreshes it when the caller already owns it.
///
/// Arguments: `ARGV[1]` = data key (the lock lives at `<key>:lock`), `ARGV[2]` = TTL in
/// seconds, `ARGV[3]` = owner id. When the lock key is absent or already holds the same
/// owner id, it is (re)written with `SETEX` and the script returns `1`; otherwise `0`.
const LOCK_SCRIPT: &str = r#"
local val = redis.call("get", ARGV[1] .. ":lock")
if val == false or val == ARGV[3] then
    redis.call("setex", ARGV[1] .. ":lock", ARGV[2], ARGV[3])
    return 1
end
return 0"#;
50
/// Lua script that releases the lock only if the caller still owns it.
///
/// Arguments: `ARGV[1]` = data key, `ARGV[2]` = owner id. Deletes `<key>:lock` and returns
/// `1` when its value equals the owner id; otherwise leaves it untouched and returns `0`.
const DROP_SCRIPT: &str = r#"
local current_lock = redis.call("get", ARGV[1] .. ":lock")
if current_lock == ARGV[2] then
    redis.call("del", ARGV[1] .. ":lock")
    return 1
end
return 0"#;
64
/// Lua script that allocates a fresh owner id for a mutex instance.
///
/// Argument: `ARGV[1]` = data key; the counter lives at `<key>:uuids`. `INCR` creates the
/// counter at 0 when missing and returns the incremented value atomically, so no separate
/// `GET` is needed. The integer reply is read as a `usize` by the caller.
const UUID_SCRIPT: &str = r#"
return redis.call("incr", ARGV[1] .. ":uuids")"#;
75
/// Lua script that writes the value only while the caller owns the lock.
///
/// Arguments: `ARGV[1]` = data key, `ARGV[2]` = owner id, `ARGV[3]` = serialized value.
/// If `<key>:lock` equals the owner id, the key is `SET` (without a TTL) and the script
/// returns `1`; otherwise nothing is written and it returns `0`.
const STORE_SCRIPT: &str = r#"
local current_lock = redis.call("get", ARGV[1] .. ":lock")
if current_lock == ARGV[2] then
    redis.call("set", ARGV[1], ARGV[3])
    return 1
end
return 0"#;
90
/// Lua script that reads the value only while the caller owns the lock.
///
/// Arguments: `ARGV[1]` = data key, `ARGV[2]` = owner id. Returns the value stored at the
/// key when `<key>:lock` equals the owner id, and nil otherwise (a missing key also
/// yields nil).
const LOAD_SCRIPT: &str = r#"
local current_lock = redis.call("get", ARGV[1] .. ":lock")
if current_lock == ARGV[2] then
    local val = redis.call("get", ARGV[1])
    return val
end
return nil"#;
104
/// Distributed mutex that serializes access to a [`Generic`] value through a lock key
/// (`<key>:lock`) stored in Redis.
pub struct Mutex<T> {
    /// Cached Redis connection. `lock` takes it out while polling and puts it back once the
    /// lock is held, so it can be `None` if a previous `lock` call did not complete.
    conn: Option<redis::Connection>,
    /// The wrapped value; its `key`, `client` and `cache` fields are used by the lock.
    data: Generic<T>,
    /// Owner id obtained from `UUID_SCRIPT` in `new`; written as the lock value in Redis.
    uuid: usize,
}
118
119impl<T> Mutex<T>
120where
121 T: Serialize + DeserializeOwned,
122{
123 pub fn new(data: Generic<T>) -> Self {
124 let mut conn = data
125 .client
126 .get_connection()
127 .expect("Failed to get connection to Redis");
128
129 let uuid = redis::Script::new(UUID_SCRIPT)
130 .arg(&data.key)
131 .invoke::<usize>(&mut conn)
132 .expect("Failed to get uuid");
133
134 Self {
135 data,
136 conn: Some(conn),
137 uuid,
138 }
139 }
140
141 pub fn lock(&mut self) -> Result<Guard<T>, LockError> {
206 let mut conn = match self.conn.take() {
207 Some(conn) => conn,
208 None => self
209 .client
210 .get_connection()
211 .map_err(|_| LockError::LockFailed)?,
212 };
213
214 let lock_cmd = redis::Script::new(LOCK_SCRIPT);
215
216 while LockNum::from(
217 lock_cmd
218 .arg(&self.data.key)
219 .arg(1)
220 .arg(&self.uuid.to_string())
221 .invoke::<i8>(&mut conn)
222 .expect("Failed to lock. You should not see this!"),
223 ) == LockNum::Fail
224 {
225 std::hint::spin_loop();
226 }
227
228 self.conn = Some(conn);
230 let lock = Guard::new(self)?;
231
232 Ok(lock)
233 }
234}
235
236impl<T> DerefMut for Mutex<T> {
237 fn deref_mut(&mut self) -> &mut Self::Target {
238 &mut self.data
239 }
240}
241
242impl<T> Deref for Mutex<T> {
243 type Target = Generic<T>;
244
245 fn deref(&self) -> &Self::Target {
246 &self.data
247 }
248}
249
250pub struct Guard<'a, T> {
253 lock: &'a mut Mutex<T>,
254 expanded: bool,
255}
256
257impl<'a, T> Guard<'a, T>
258where
259 T: Serialize + DeserializeOwned,
260{
261 fn new(lock: &'a mut Mutex<T>) -> Result<Self, LockError> {
262 Ok(Self {
263 lock,
264 expanded: false,
265 })
266 }
267
268 pub fn expand(&mut self) {
274 if self.expanded {
275 return;
276 }
277
278 let conn = self.lock.conn.as_mut().expect("Connection should be there");
279 let expand = redis::Cmd::expire(format!("{}:lock", &self.lock.data.key), 2);
280 expand.execute(conn);
281 self.expanded = true;
282 }
283
284 pub fn store(&mut self, value: T) -> Result<(), LockError>
288 where
289 T: Serialize,
290 {
291 let conn = self.lock.conn.as_mut().ok_or(LockError::NoConnection)?;
292 let script = redis::Script::new(STORE_SCRIPT);
293 let result: i8 = script
294 .arg(&self.lock.data.key)
295 .arg(self.lock.uuid)
296 .arg(serde_json::to_string(&value).expect("Failed to serialize value"))
297 .invoke(conn)
298 .expect("Failed to store value. You should not see this!");
299 if result == 0 {
300 return Err(LockError::LockExpired(self.lock.uuid));
301 }
302 self.lock.data.cache = Some(value);
303 Ok(())
304 }
305
306 pub fn acquire(&mut self) -> &T {
310 self.lock.data.cache = self.try_get();
311 self.lock.data.cache.as_ref().unwrap()
312 }
313
314 fn try_get(&mut self) -> Option<T> {
315 let conn = self
316 .lock
317 .conn
318 .as_mut()
319 .ok_or(LockError::NoConnection)
320 .expect("Connection should be there");
321 let script = redis::Script::new(LOAD_SCRIPT);
322 let result: Option<String> = script
323 .arg(&self.lock.data.key)
324 .arg(self.lock.uuid)
325 .invoke(conn)
326 .expect("Failed to load value. You should not see this!");
327 let result = result?;
328
329 if result == "nil" {
330 return None;
331 }
332 Some(serde_json::from_str(&result).expect("Failed to deserialize value"))
333 }
334}
335
336impl<T> Deref for Guard<'_, T>
337where
338 T: DeserializeOwned + Serialize,
339{
340 type Target = Generic<T>;
341
342 fn deref(&self) -> &Self::Target {
343 &self.lock.data
344 }
345}
346
347impl<T> DerefMut for Guard<'_, T>
348where
349 T: DeserializeOwned + Serialize,
350{
351 fn deref_mut(&mut self) -> &mut Self::Target {
352 &mut self.lock.data
353 }
354}
355
356impl<T> Drop for Guard<'_, T> {
357 fn drop(&mut self) {
358 let conn = self.lock.conn.as_mut().expect("Connection should be there");
359 let script = redis::Script::new(DROP_SCRIPT);
360 script
361 .arg(&self.lock.data.key)
362 .arg(self.lock.uuid)
363 .invoke::<()>(conn)
364 .expect("Failed to drop lock. You should not see this!");
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::Mutex;
    use crate::redis::types::Di32;
    use std::thread;

    /// Two threads contend for the same Redis-backed lock. Each stores a value while
    /// holding its guard and must observe its own write before releasing.
    ///
    /// Requires a Redis server reachable at `redis://localhost:6379`.
    #[test]
    fn test_create_lock() {
        let client = redis::Client::open("redis://localhost:6379").unwrap();
        let client2 = client.clone();

        thread::scope(|s| {
            let t1 = s.spawn(move || {
                // `client2` is owned by this closure, so it can be handed over directly.
                let value = Di32::new("test_add_locking", client2);
                let mut lock: Mutex<i32> = Mutex::new(value);
                let mut guard = lock.lock().unwrap();
                guard
                    .store(2)
                    .expect("storing while holding the lock should succeed");
                assert_eq!(*guard, 2);
            });
            {
                let value = Di32::new("test_add_locking", client.clone());
                let mut lock: Mutex<i32> = Mutex::new(value);
                let mut guard = lock.lock().unwrap();
                guard
                    .store(1)
                    .expect("storing while holding the lock should succeed");
                assert_eq!(*guard, 1);
            }
            t1.join().expect("Failed to join thread1");
        });
    }
}