dtypes/redis/mutex.rs

1use crate::redis::types::Generic;
2use serde::de::DeserializeOwned;
3use serde::Serialize;
4use std::ops::{Deref, DerefMut};
5use thiserror::Error;
6
7#[derive(Error, Debug)]
8pub enum LockError {
9    #[error("Locking failed")]
10    LockFailed,
11    #[error("Unlocking failed")]
12    UnlockFailed,
13    #[error("No connection to Redis available")]
14    NoConnection,
15    #[error("Lock expired with id #{0}")]
16    LockExpired(usize),
17    #[error("Error by Redis")]
18    Redis(#[from] redis::RedisError),
19}
20
/// Outcome of the lock script: `1` means the lock was obtained, `0` means it was not.
#[derive(PartialEq, Debug)]
enum LockNum {
    Success,
    Fail,
}

impl From<i8> for LockNum {
    /// Translates a raw script reply into a [`LockNum`].
    ///
    /// # Panics
    ///
    /// Panics when `value` is neither `0` nor `1`.
    fn from(value: i8) -> Self {
        if value == 1 {
            LockNum::Success
        } else if value == 0 {
            LockNum::Fail
        } else {
            panic!("Unexpected value")
        }
    }
}
36
/// The lock script.
///
/// Acquires (or renews) the lock of a value in Redis, so that only one instance can access it at
/// a time. The lock lives under `<key>:lock`, holds the id of its owner and expires after the
/// given timeout. If the lock is already held by the same id, only its timeout is refreshed.
///
/// Takes 3 Arguments:
/// 1. The key of the value to lock,
/// 2. The timeout in seconds,
/// 3. The id of the lock owner (stored as the value of `<key>:lock`).
///
/// Returns `1` if the lock was acquired or renewed, `0` if it is held by another owner.
///
/// NOTE(review): the key is passed via `ARGV` instead of `KEYS`; Redis expects every key a
/// script touches to be declared in `KEYS`, which matters for Redis Cluster.
const LOCK_SCRIPT: &str = r#"
local val = redis.call("get", ARGV[1] .. ":lock")
if val == false or val == ARGV[3] then
    redis.call("setex", ARGV[1] .. ":lock", ARGV[2], ARGV[3])
    return 1
end
return 0"#;
50
/// The drop script.
///
/// Releases the lock of a value by deleting `<key>:lock`, but only if the lock is still held by
/// the given id. The stored value itself is left untouched, and a lock that has meanwhile been
/// acquired by another instance is not released.
///
/// Takes 2 Arguments:
/// 1. The key of the locked value,
/// 2. The id of the lock owner to check against.
///
/// Returns `1` if the lock was released, `0` otherwise.
const DROP_SCRIPT: &str = r#"
local current_lock = redis.call("get", ARGV[1] .. ":lock")
if current_lock == ARGV[2] then
    redis.call("del", ARGV[1] .. ":lock")
    return 1
end
return 0"#;
64
/// The uuid script.
///
/// Generates the id that identifies the owner of a lock. It is a simple counter stored under
/// `<key>:uuids`; because `INCR` is atomic and returns the incremented value, every call gets a
/// number that was not handed out before (as long as the counter key is not reset).
///
/// Takes 1 Argument:
/// 1. The key of the value to lock.
///
/// Returns the new counter value (the first call on a fresh key returns `1`).
const UUID_SCRIPT: &str = r#"
return redis.call("incr", ARGV[1] .. ":uuids")"#;
75
/// The store script.
///
/// Writes a value to Redis, but only while the caller still owns the value's lock.
///
/// Expects 3 arguments:
/// 1. the key under which the value is stored,
/// 2. the id (uuid) of the lock owner,
/// 3. the value to write.
///
/// Replies `1` after writing the value, or `0` when `<key>:lock` does not hold the given id.
const STORE_SCRIPT: &str = concat!(
    "\n",
    "local current_lock = redis.call(\"get\", ARGV[1] .. \":lock\")\n",
    "if current_lock == ARGV[2] then\n",
    "    redis.call(\"set\", ARGV[1], ARGV[3])\n",
    "    return 1\n",
    "end\n",
    "return 0"
);
90
/// The load script.
///
/// Reads a value from Redis, but only while the caller still owns the value's lock.
///
/// Expects 2 arguments:
/// 1. the key under which the value is stored,
/// 2. the id (uuid) of the lock owner.
///
/// Replies with the stored value, or nil when the lock is not held by the given id
/// (or the key holds no value).
const LOAD_SCRIPT: &str = concat!(
    "\n",
    "local current_lock = redis.call(\"get\", ARGV[1] .. \":lock\")\n",
    "if current_lock == ARGV[2] then\n",
    "    local val = redis.call(\"get\", ARGV[1])\n",
    "    return val\n",
    "end\n",
    "return nil"
);
104
105/// The RedisMutex struct.
106///
107/// It is used to lock a value in Redis, so that only one instance can access it at a time.
108/// You have to use RedisGeneric as the data type.
109/// It is a wrapper around the data type you want to store like the Mutex in std.
110///
111/// The lock is released when the guard is dropped or it expires.
112/// The default expiration time is 1000ms. If you need more time, use the [Guard::expand()] function.
113pub struct Mutex<T> {
114    conn: Option<redis::Connection>,
115    data: Generic<T>,
116    uuid: usize,
117}
118
119impl<T> Mutex<T>
120where
121    T: Serialize + DeserializeOwned,
122{
123    pub fn new(data: Generic<T>) -> Self {
124        let mut conn = data
125            .client
126            .get_connection()
127            .expect("Failed to get connection to Redis");
128
129        let uuid = redis::Script::new(UUID_SCRIPT)
130            .arg(&data.key)
131            .invoke::<usize>(&mut conn)
132            .expect("Failed to get uuid");
133
134        Self {
135            data,
136            conn: Some(conn),
137            uuid,
138        }
139    }
140
141    /// Locks the value in Redis.
142    /// This function blocks until the lock is acquired.
143    /// It returns a guard that can be used to access the value.
144    /// The guard will unlock the value when it is dropped.
145    ///
146    /// Beware that the value is not locked in the Rust sense and can be set by other instances,
147    /// if they skip the locking process and its LOCK_SCRIPT.
148    ///
149    /// If you try to lock a value that is already locked by another instance in the same scope,
150    /// this function will block until the lock is released, which will be happen after the lock
151    /// expires (1000ms).
152    /// If you need to extend this time, you can use the [Guard::expand()] function.
153    ///
154    /// # Example
155    /// ```
156    /// use dtypes::redis::types::Di32 as i32;
157    /// use dtypes::redis::sync::Mutex;
158    /// use std::thread::scope;
159    ///
160    /// let client = redis::Client::open("redis://localhost:6379").unwrap();
161    /// let client2 = client.clone();
162    ///
163    /// scope(|s| {
164    ///    let t1 = s.spawn(move || {
165    ///         let mut i32 = i32::new("test_add_example1", client2);
166    ///         let mut lock = Mutex::new(i32);
167    ///         let mut guard = lock.lock().unwrap();
168    ///         guard.store(2).expect("TODO: panic message");
169    ///         assert_eq!(*guard, 2);
170    ///     });
171    ///     {  
172    ///         let mut i32 = i32::new("test_add_example1", client);
173    ///         let mut lock = Mutex::new(i32);
174    ///         let mut guard = lock.lock().unwrap();
175    ///         guard.store(1).expect("Failed to store value");
176    ///         assert_eq!(*guard, 1);
177    ///     }
178    ///     t1.join().expect("Failed to join thread1");
179    /// });
180    /// ```
181    ///
182    /// It does not allow any deadlocks, because the lock will automatically release after some time.
183    /// So you have to check for errors, if you want to handle them.
184    ///
185    /// Beware: Your CPU can anytime switch to another thread, so you have to check for errors!
186    /// But if you are brave enough, you can drop the result and hope for the best.
187    ///
188    /// # Example
189    /// ```
190    /// use std::thread::sleep;
191    /// use dtypes::redis::types::Di32 as i32;
192    /// use dtypes::redis::sync::Mutex;
193    ///
194    /// let client = redis::Client::open("redis://localhost:6379").unwrap();
195    /// let mut i32 = i32::new("test_add_example2", client.clone());
196    /// i32.store(1);
197    /// assert_eq!(i32.acquire(), &1);
198    /// let mut lock = Mutex::new(i32);
199    ///
200    /// let mut guard = lock.lock().unwrap();
201    /// sleep(std::time::Duration::from_millis(1500));
202    /// let res = guard.store(3);
203    /// assert!(res.is_err(), "{:?}", res);
204    /// ```
205    pub fn lock(&mut self) -> Result<Guard<T>, LockError> {
206        let mut conn = match self.conn.take() {
207            Some(conn) => conn,
208            None => self
209                .client
210                .get_connection()
211                .map_err(|_| LockError::LockFailed)?,
212        };
213
214        let lock_cmd = redis::Script::new(LOCK_SCRIPT);
215
216        while LockNum::from(
217            lock_cmd
218                .arg(&self.data.key)
219                .arg(1)
220                .arg(&self.uuid.to_string())
221                .invoke::<i8>(&mut conn)
222                .expect("Failed to lock. You should not see this!"),
223        ) == LockNum::Fail
224        {
225            std::hint::spin_loop();
226        }
227
228        // store the connection for later use
229        self.conn = Some(conn);
230        let lock = Guard::new(self)?;
231
232        Ok(lock)
233    }
234}
235
236impl<T> DerefMut for Mutex<T> {
237    fn deref_mut(&mut self) -> &mut Self::Target {
238        &mut self.data
239    }
240}
241
242impl<T> Deref for Mutex<T> {
243    type Target = Generic<T>;
244
245    fn deref(&self) -> &Self::Target {
246        &self.data
247    }
248}
249
250/// The guard struct for the Mutex.
251/// It is used to access the value and not for you to initialize it by your own.
252pub struct Guard<'a, T> {
253    lock: &'a mut Mutex<T>,
254    expanded: bool,
255}
256
257impl<'a, T> Guard<'a, T>
258where
259    T: Serialize + DeserializeOwned,
260{
261    fn new(lock: &'a mut Mutex<T>) -> Result<Self, LockError> {
262        Ok(Self {
263            lock,
264            expanded: false,
265        })
266    }
267
268    /// Expands the lock time by 2000ms from the point on its called.
269    /// This is useful if you need to access the value for a longer time.
270    ///
271    /// But use it with caution, because it can lead to deadlocks.
272    /// To avoid deadlocks, we only allow one extension per lock.
273    pub fn expand(&mut self) {
274        if self.expanded {
275            return;
276        }
277
278        let conn = self.lock.conn.as_mut().expect("Connection should be there");
279        let expand = redis::Cmd::expire(format!("{}:lock", &self.lock.data.key), 2);
280        expand.execute(conn);
281        self.expanded = true;
282    }
283
284    /// Stores the value in Redis.
285    /// This function blocks until the value is stored.
286    /// Disables the store operation of the guarded value.
287    pub fn store(&mut self, value: T) -> Result<(), LockError>
288    where
289        T: Serialize,
290    {
291        let conn = self.lock.conn.as_mut().ok_or(LockError::NoConnection)?;
292        let script = redis::Script::new(STORE_SCRIPT);
293        let result: i8 = script
294            .arg(&self.lock.data.key)
295            .arg(self.lock.uuid)
296            .arg(serde_json::to_string(&value).expect("Failed to serialize value"))
297            .invoke(conn)
298            .expect("Failed to store value. You should not see this!");
299        if result == 0 {
300            return Err(LockError::LockExpired(self.lock.uuid));
301        }
302        self.lock.data.cache = Some(value);
303        Ok(())
304    }
305
306    /// Loads the value from Redis.
307    /// This function blocks until the value is loaded.
308    /// Shadows the load operation of the guarded value.
309    pub fn acquire(&mut self) -> &T {
310        self.lock.data.cache = self.try_get();
311        self.lock.data.cache.as_ref().unwrap()
312    }
313
314    fn try_get(&mut self) -> Option<T> {
315        let conn = self
316            .lock
317            .conn
318            .as_mut()
319            .ok_or(LockError::NoConnection)
320            .expect("Connection should be there");
321        let script = redis::Script::new(LOAD_SCRIPT);
322        let result: Option<String> = script
323            .arg(&self.lock.data.key)
324            .arg(self.lock.uuid)
325            .invoke(conn)
326            .expect("Failed to load value. You should not see this!");
327        let result = result?;
328
329        if result == "nil" {
330            return None;
331        }
332        Some(serde_json::from_str(&result).expect("Failed to deserialize value"))
333    }
334}
335
336impl<T> Deref for Guard<'_, T>
337where
338    T: DeserializeOwned + Serialize,
339{
340    type Target = Generic<T>;
341
342    fn deref(&self) -> &Self::Target {
343        &self.lock.data
344    }
345}
346
347impl<T> DerefMut for Guard<'_, T>
348where
349    T: DeserializeOwned + Serialize,
350{
351    fn deref_mut(&mut self) -> &mut Self::Target {
352        &mut self.lock.data
353    }
354}
355
356impl<T> Drop for Guard<'_, T> {
357    fn drop(&mut self) {
358        let conn = self.lock.conn.as_mut().expect("Connection should be there");
359        let script = redis::Script::new(DROP_SCRIPT);
360        script
361            .arg(&self.lock.data.key)
362            .arg(self.lock.uuid)
363            .invoke::<()>(conn)
364            .expect("Failed to drop lock. You should not see this!");
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::Mutex;
    use crate::redis::types::Di32;
    use std::thread;

    /// Two contenders lock the same key: each must acquire the lock, store its own value and see
    /// that value while holding the guard. Requires a Redis server on `localhost:6379`.
    #[test]
    fn test_create_lock() {
        let client = redis::Client::open("redis://localhost:6379").unwrap();
        let client2 = client.clone();

        thread::scope(|s| {
            let t1 = s.spawn(move || {
                let i32_2 = Di32::new("test_add_locking", client2);
                let mut lock2: Mutex<i32> = Mutex::new(i32_2);
                let mut guard = lock2.lock().unwrap();
                guard.store(2).expect("Failed to store value in thread 1");
                assert_eq!(*guard, 2);
            });
            {
                let i32 = Di32::new("test_add_locking", client);
                let mut lock: Mutex<i32> = Mutex::new(i32);
                let mut guard = lock.lock().unwrap();
                guard.store(1).expect("Failed to store value in main thread");
                assert_eq!(*guard, 1);
            }
            t1.join().expect("Failed to join thread1");
        });
    }
}