use crate::redis::Generic;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::ops::{Deref, DerefMut};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum LockError {
#[error("Locking failed")]
LockFailed,
#[error("Unlocking failed")]
UnlockFailed,
#[error("No connection to Redis available")]
NoConnection,
#[error("Lock expired with id #{0}")]
LockExpired(usize),
#[error("Error by Redis")]
Redis(#[from] redis::RedisError),
}
/// Outcome of a lock-script call: the script replies `1` on success and `0` on failure.
#[derive(Debug, PartialEq)]
enum LockNum {
    Success,
    Fail,
}

impl From<i8> for LockNum {
    /// Maps an integer script reply onto [`LockNum`].
    ///
    /// # Panics
    ///
    /// Panics on any reply other than `0` or `1`.
    fn from(value: i8) -> Self {
        if value == 1 {
            Self::Success
        } else if value == 0 {
            Self::Fail
        } else {
            panic!("Unexpected value")
        }
    }
}
/// Lua script that takes (or re-takes) the lock stored at `<ARGV[1]>:lock`.
///
/// ARGV[1]: base key, ARGV[2]: TTL in seconds for SETEX, ARGV[3]: owner id.
/// Returns 1 when the lock was free or already held by ARGV[3] (its TTL is then
/// refreshed), 0 when another owner holds it.
/// NOTE(review): keys are passed via ARGV rather than KEYS; Redis Cluster needs
/// script keys in KEYS for slot routing — fine on a single node.
const LOCK_SCRIPT: &str = r#"
local val = redis.call("get", ARGV[1] .. ":lock")
if val == false or val == ARGV[3] then
redis.call("setex", ARGV[1] .. ":lock", ARGV[2], ARGV[3])
return 1
end
return 0"#;
/// Lua script that deletes `<ARGV[1]>:lock` only if it still holds owner id ARGV[2].
///
/// Returns 1 if the lock entry was deleted, 0 if it was absent or owned by someone else.
const DROP_SCRIPT: &str = r#"
local current_lock = redis.call("get", ARGV[1] .. ":lock")
if current_lock == ARGV[2] then
redis.call("del", ARGV[1] .. ":lock")
return 1
end
return 0"#;
/// Lua script that increments the counter `<ARGV[1]>:uuids` and returns its new value.
///
/// `Mutex::new` uses the result as the instance's lock-owner id.
/// NOTE(review): INCR already returns the new value, so the trailing GET is redundant
/// (harmless, since the script runs atomically).
const UUID_SCRIPT: &str = r#"
redis.call("incr", ARGV[1] .. ":uuids")
local val = redis.call("get", ARGV[1] .. ":uuids")
return val"#;
/// Lua script that writes ARGV[3] to key ARGV[1] only while `<ARGV[1]>:lock` holds owner id ARGV[2].
///
/// Returns 1 if the value was written, 0 if the lock is not held by ARGV[2].
const STORE_SCRIPT: &str = r#"
local current_lock = redis.call("get", ARGV[1] .. ":lock")
if current_lock == ARGV[2] then
redis.call("set", ARGV[1], ARGV[3])
return 1
end
return 0"#;
/// Lua script that returns the value at key ARGV[1] while `<ARGV[1]>:lock` holds owner id ARGV[2].
///
/// Returns nil when the lock is not held by ARGV[2]. A missing key also yields nil
/// (GET gives Lua `false`, which Redis converts to a nil reply), so callers cannot
/// tell "lock lost" from "no value stored".
const LOAD_SCRIPT: &str = r#"
local current_lock = redis.call("get", ARGV[1] .. ":lock")
if current_lock == ARGV[2] then
local val = redis.call("get", ARGV[1])
return val
end
return nil"#;
/// A lock around a Redis-backed value, coordinated through a `<key>:lock` entry.
///
/// Each instance draws its own numeric owner id from a Redis counter in `Mutex::new`;
/// the lock is held while `<key>:lock` stores that id.
pub struct Mutex<T> {
/// Dedicated connection. `lock` takes it out while running the lock script and puts
/// it back afterwards; if it is `None`, `lock` opens a new one via `data.client`.
conn: Option<redis::Connection>,
/// The wrapped Redis value (its key, client and local cache).
data: Generic<T>,
/// This instance's lock-owner id, obtained from `UUID_SCRIPT`.
uuid: usize,
}
impl<T> Mutex<T>
where
    T: Serialize + DeserializeOwned,
{
    /// Creates a mutex around `data`.
    ///
    /// Opens a dedicated Redis connection and draws a fresh owner id from the
    /// `<key>:uuids` counter.
    ///
    /// # Panics
    ///
    /// Panics if no connection can be obtained from `data.client` or if the id
    /// script fails.
    pub fn new(data: Generic<T>) -> Self {
        let mut conn = data
            .client
            .get_connection()
            .expect("Failed to get connection to Redis");
        let uuid = redis::Script::new(UUID_SCRIPT)
            .arg(&data.key)
            .invoke::<usize>(&mut conn)
            .expect("Failed to get uuid");
        Self {
            data,
            conn: Some(conn),
            uuid,
        }
    }

    /// Blocks until this mutex holds the Redis lock on its key, then returns a guard.
    ///
    /// The lock entry is written with a 1-second TTL (`Guard::expand` can lengthen it
    /// once). There is no timeout: if another owner keeps re-taking the lock, this
    /// keeps waiting. Attempts are spaced by a short sleep rather than a busy spin,
    /// so waiting does not flood Redis with script calls.
    ///
    /// # Errors
    ///
    /// - [`LockError::LockFailed`] if no connection was cached and a new one could not
    ///   be opened.
    /// - [`LockError::Redis`] if the lock script fails. The connection is kept for
    ///   later calls.
    ///
    /// # Panics
    ///
    /// Panics if the lock script replies with something other than `0` or `1`.
    pub fn lock(&mut self) -> Result<Guard<T>, LockError> {
        // Pause between failed attempts instead of spinning against the server.
        const RETRY_DELAY: std::time::Duration = std::time::Duration::from_millis(10);

        let mut conn = match self.conn.take() {
            Some(conn) => conn,
            None => self
                .client
                .get_connection()
                .map_err(|_| LockError::LockFailed)?,
        };
        let lock_cmd = redis::Script::new(LOCK_SCRIPT);
        // The owner id never changes across attempts, so render it once.
        let owner = self.uuid.to_string();
        let outcome = loop {
            // 1-second TTL: a holder that never unlocks blocks others for at most that long.
            let reply = lock_cmd
                .arg(&self.data.key)
                .arg(1)
                .arg(&owner)
                .invoke::<i8>(&mut conn);
            match reply.map(LockNum::from) {
                Ok(LockNum::Success) => break Ok(()),
                Ok(LockNum::Fail) => std::thread::sleep(RETRY_DELAY),
                Err(err) => break Err(LockError::Redis(err)),
            }
        };
        // Put the connection back before reporting any error, so it is reused next time.
        self.conn = Some(conn);
        outcome?;
        Guard::new(self)
    }
}
impl<T> DerefMut for Mutex<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.data
}
}
impl<T> Deref for Mutex<T> {
    type Target = Generic<T>;

    /// Exposes the wrapped [`Generic`] value. Reading it does not take the Redis lock.
    fn deref(&self) -> &Generic<T> {
        let Self { data, .. } = self;
        data
    }
}
/// Evidence that a [`Mutex`]'s Redis lock is held; the lock is released when it is dropped.
#[must_use = "the Redis lock is released as soon as the guard is dropped"]
pub struct Guard<'a, T> {
    /// The mutex whose lock this guard holds.
    lock: &'a mut Mutex<T>,
    /// Whether `expand` has already extended the lock's TTL.
    expanded: bool,
}
impl<'a, T> Guard<'a, T>
where
    T: Serialize + DeserializeOwned,
{
    /// Wraps a mutex whose Redis lock was just acquired by `Mutex::lock`.
    fn new(lock: &'a mut Mutex<T>) -> Result<Self, LockError> {
        Ok(Self {
            lock,
            expanded: false,
        })
    }

    /// Sets the lock's TTL to 2 seconds. Only the first call on a guard does anything.
    ///
    /// The TTL is changed only while `<key>:lock` still holds this mutex's id. A guard
    /// whose lock already expired therefore cannot prolong a lock that another mutex
    /// now owns.
    ///
    /// # Panics
    ///
    /// Panics if the guard has no connection or the Redis call fails.
    pub fn expand(&mut self) {
        // Ownership check and EXPIRE run in one script, so they happen atomically.
        const EXPAND_SCRIPT: &str = r#"
if redis.call("get", ARGV[1] .. ":lock") == ARGV[2] then
    return redis.call("expire", ARGV[1] .. ":lock", ARGV[3])
end
return 0"#;

        if self.expanded {
            return;
        }
        let conn = self.lock.conn.as_mut().expect("Connection should be there");
        redis::Script::new(EXPAND_SCRIPT)
            .arg(&self.lock.data.key)
            .arg(self.lock.uuid)
            .arg(2)
            .invoke::<()>(conn)
            .expect("Failed to expand lock. You should not see this!");
        self.expanded = true;
    }

    /// Serializes `value` to JSON and writes it under the key, if this mutex still
    /// holds the lock. On success the local cache is updated as well.
    ///
    /// # Errors
    ///
    /// - [`LockError::NoConnection`] if the guard has no connection.
    /// - [`LockError::LockExpired`] (carrying this mutex's id) if the lock is no
    ///   longer held.
    /// - [`LockError::Redis`] if the store script fails.
    ///
    /// # Panics
    ///
    /// Panics if `value` cannot be serialized to JSON.
    pub fn store(&mut self, value: T) -> Result<(), LockError> {
        let conn = self.lock.conn.as_mut().ok_or(LockError::NoConnection)?;
        let json = serde_json::to_string(&value).expect("Failed to serialize value");
        let written: i8 = redis::Script::new(STORE_SCRIPT)
            .arg(&self.lock.data.key)
            .arg(self.lock.uuid)
            .arg(json)
            .invoke(conn)?;
        if written == 0 {
            return Err(LockError::LockExpired(self.lock.uuid));
        }
        self.lock.data.cache = Some(value);
        Ok(())
    }

    /// Reloads the value from Redis into the local cache and returns it.
    ///
    /// # Panics
    ///
    /// Panics in three cases:
    /// - Redis reports an error.
    /// - The stored value is not valid JSON for `T`.
    /// - No value comes back, because nothing is stored or the lock is no longer held.
    ///
    /// [`Guard::try_acquire`] is the non-panicking variant.
    pub fn acquire(&mut self) -> &T {
        self.try_acquire()
            .expect("Failed to load value. You should not see this!")
            .expect("No value loaded: the key is unset or the lock is no longer held")
    }

    /// Reloads the value from Redis into the local cache.
    ///
    /// Returns `Ok(None)`, and clears the cache, when Redis returns no value. That
    /// happens when nothing is stored under the key or when `<key>:lock` no longer
    /// holds this mutex's id; the load script does not distinguish the two.
    ///
    /// # Errors
    ///
    /// - [`LockError::NoConnection`] if the guard has no connection.
    /// - [`LockError::Redis`] if the load script fails.
    ///
    /// # Panics
    ///
    /// Panics if the stored value is not valid JSON for `T`.
    pub fn try_acquire(&mut self) -> Result<Option<&T>, LockError> {
        self.lock.data.cache = self.try_get()?;
        Ok(self.lock.data.cache.as_ref())
    }

    /// Runs the load script and deserializes the JSON it returns, if any.
    fn try_get(&mut self) -> Result<Option<T>, LockError> {
        let conn = self.lock.conn.as_mut().ok_or(LockError::NoConnection)?;
        // A Lua nil/false reply arrives as `None` here, never as the text "nil".
        let raw: Option<String> = redis::Script::new(LOAD_SCRIPT)
            .arg(&self.lock.data.key)
            .arg(self.lock.uuid)
            .invoke(conn)?;
        Ok(raw.map(|json| serde_json::from_str(&json).expect("Failed to deserialize value")))
    }
}
impl<T> Deref for Guard<'_, T>
where
    T: DeserializeOwned + Serialize,
{
    type Target = Generic<T>;

    /// Exposes the locked mutex's wrapped value; its cache is refreshed by `acquire`/`store`.
    fn deref(&self) -> &Generic<T> {
        let Self { lock, .. } = self;
        &lock.data
    }
}
impl<T> DerefMut for Guard<'_, T>
where
T: DeserializeOwned + Serialize,
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.lock.data
}
}
impl<T> Drop for Guard<'_, T> {
    /// Releases the Redis lock if `<key>:lock` still holds this mutex's id.
    ///
    /// Release is best-effort: a missing connection or Redis error is ignored rather
    /// than panicking. A panic inside `drop` while already unwinding aborts the
    /// process, and the lock entry expires on its own via its TTL anyway.
    fn drop(&mut self) {
        if let Some(conn) = self.lock.conn.as_mut() {
            let _ = redis::Script::new(DROP_SCRIPT)
                .arg(&self.lock.data.key)
                .arg(self.lock.uuid)
                .invoke::<()>(conn);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::Mutex;
    use crate::redis::Di32;
    use std::thread;

    /// Two threads take the same Redis-backed lock in turn. Each stores a value while
    /// holding the lock and reads back its own write.
    ///
    /// Requires a Redis server listening on `localhost:6379`.
    #[test]
    fn test_create_lock() {
        let client = redis::Client::open("redis://localhost:6379").unwrap();
        let client2 = client.clone();
        thread::scope(|s| {
            let t1 = s.spawn(move || {
                // `client2` is moved into this thread, so no further clone is needed.
                let value = Di32::new("test_add_locking", client2);
                let mut lock2: Mutex<i32> = Mutex::new(value);
                let mut guard = lock2.lock().unwrap();
                guard
                    .store(2)
                    .expect("Failed to store 2 while holding the lock");
                assert_eq!(*guard, 2);
            });
            {
                let value = Di32::new("test_add_locking", client.clone());
                let mut lock: Mutex<i32> = Mutex::new(value);
                let mut guard = lock.lock().unwrap();
                guard
                    .store(1)
                    .expect("Failed to store 1 while holding the lock");
                assert_eq!(*guard, 1);
            }
            t1.join().expect("Failed to join thread1");
        });
    }
}