pub mod lua_script;
use anyhow::{Result, anyhow};
use redis::Script;
use redis::aio::ConnectionLike;
use std::any::Any;
use tokio::time::{Duration, sleep};
use tokio::{select, spawn, sync::oneshot};
pub struct Locker<T: ConnectionLike> {
conn: T,
mode: Mode,
to: u64,
rty_iv: u64,
rty_to: i64,
ext_iv: u64,
}
impl<T: ConnectionLike + Send + Clone + 'static> Locker<T> {
pub fn new(conn: T) -> Self {
Self {
conn,
mode: Mode::W,
to: 3000,
rty_iv: 100,
rty_to: 1000,
ext_iv: 1000,
}
}
pub fn mode(mut self, mode: &Mode) -> Self {
self.mode = mode.clone();
self
}
pub fn to(mut self, ms: u64) -> Self {
self.to = ms;
self
}
pub fn rty_int(mut self, ms: u64) -> Self {
self.rty_iv = ms;
self
}
pub fn rty_to(mut self, ms: i64) -> Self {
self.rty_to = ms;
self
}
pub fn ext_int(mut self, ms: u64) -> Self {
self.ext_iv = ms;
self
}
pub async fn lock(mut self, key: String) -> Result<impl Future<Output = Result<()>>> {
let id = lock(
&mut self.conn,
&self.mode,
&key,
self.to,
self.rty_iv,
self.rty_to,
)
.await?;
let mut conn_c = self.conn.clone();
let mode_c = self.mode.clone();
let key_c = key.clone();
let id_c = id.clone();
let (ext_tx, mut ext_rx) = oneshot::channel();
let ext_iv = Duration::from_millis(self.ext_iv);
let ext = spawn(async move {
let mut ext_ac = self.ext_iv;
loop {
select! {
_ = &mut ext_rx => break,
_ = sleep(ext_iv) => {
if extend(&mut conn_c, &mode_c, &key_c, &id_c, self.to).await.is_ok() {
ext_ac += self.ext_iv;
if ext_ac > self.to {
panic!("Failed to extend lock")
}
}
ext_ac = self.ext_iv;
}
}
}
});
let unlock = async move {
if ext_tx.send(()).is_err() {
panic!("Failed to stop lock extension");
}
ext.await?;
unlock(&mut self.conn, &self.mode, &key, &id).await
};
Ok(unlock)
}
pub async fn lock_exec<V: Any, F: Future<Output = Result<V>>>(
self,
key: String,
f: F,
) -> Result<V> {
let unlock = self.lock(key.to_string()).await?;
let r = f.await;
let _ = unlock.await;
r
}
}
/// Lock mode used when acquiring, extending and releasing a lock.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// Read (shared) lock; several holders may hold it concurrently.
    R,
    /// Write (exclusive) lock.
    W,
}
pub async fn lock<T: ConnectionLike>(
conn: &mut T,
mode: &Mode,
key: &str,
to: u64,
rty_iv: u64,
rty_to: i64,
) -> Result<String> {
let id = uuid::Uuid::new_v4().to_string();
let script = match mode {
Mode::R => Script::new(lua_script::R_LOCK),
Mode::W => Script::new(lua_script::W_LOCK),
};
select! {
r = async move {
loop {
if let 1 = script.key(key).arg(&id).arg(to).invoke_async(conn).await? {
break Ok(id);
}
sleep(Duration::from_millis(rty_iv)).await
}
} => r,
Some(v) = async move {
match rty_to {
0.. => {
sleep(Duration::from_millis(rty_to as u64)).await;
Some(Err(anyhow!("Timed out")))
}
_ => None,
}
} => v,
}
}
/// Refreshes a held lock by running the mode-specific extend script.
///
/// `id` is the holder id returned by [`lock`]. `to` is forwarded to the script
/// as its timeout argument.
///
/// # Errors
/// Returns the Redis error if the call fails, or `"Not found"` when the
/// script's reply is anything other than `1`.
pub async fn extend<T: ConnectionLike>(
    conn: &mut T,
    mode: &Mode,
    key: &str,
    id: &str,
    to: u64,
) -> Result<()> {
    let script = match mode {
        Mode::W => Script::new(lua_script::W_EXTEND),
        Mode::R => Script::new(lua_script::R_EXTEND),
    };
    let reply: i32 = script.key(key).arg(id).arg(to).invoke_async(conn).await?;
    if reply != 1 {
        return Err(anyhow!("Not found"));
    }
    Ok(())
}
/// Releases a lock held under `id` by running the mode-specific unlock script.
///
/// # Errors
/// Returns the Redis error if the call fails, or `"Not found"` when the
/// script's reply is anything other than `1`.
pub async fn unlock<T: ConnectionLike>(
    conn: &mut T,
    mode: &Mode,
    key: &str,
    id: &str,
) -> Result<()> {
    let script = match mode {
        Mode::W => Script::new(lua_script::W_UNLOCK),
        Mode::R => Script::new(lua_script::R_UNLOCK),
    };
    let reply: i32 = script.key(key).arg(id).invoke_async(conn).await?;
    if reply != 1 {
        return Err(anyhow!("Not found"));
    }
    Ok(())
}
#[cfg(test)]
mod test {
    use super::*;
    use std::iter;

    /// Connects to the Redis server named by the `REDIS_URL` environment variable.
    ///
    /// The URL, including any password, comes from the environment so that
    /// credentials are not committed alongside the source.
    async fn connect() -> redis::aio::ConnectionManager {
        let url = std::env::var("REDIS_URL")
            .expect("REDIS_URL must point at a Redis server to run the lock tests");
        let cli = redis::Client::open(url.as_str()).unwrap();
        let cfg = redis::aio::ConnectionManagerConfig::new().set_max_delay(1000);
        redis::aio::ConnectionManager::new_with_config(cli, cfg)
            .await
            .unwrap()
    }

    /// Readers share the lock, and a writer excludes both readers and other
    /// writers. The 5 s sleeps exceed the default 3000 ms lock timeout, so the
    /// background extension must keep the locks alive.
    #[tokio::test]
    #[ignore = "requires a running Redis server; set REDIS_URL and run with --ignored"]
    async fn test_lock_exclusive() {
        let con = connect().await;
        let key = String::from("test:lock_key");
        let r_locks: Vec<_> = iter::repeat_with(|| {
            let con = con.clone();
            let key = key.clone();
            spawn(async move { Locker::new(con).mode(&Mode::R).lock(key).await })
        })
        .take(10)
        .collect();
        let mut r_unlocks = Vec::new();
        for r_lock in r_locks {
            r_unlocks.push(r_lock.await.unwrap().unwrap());
        }
        sleep(Duration::from_secs(5)).await;
        assert!(
            Locker::new(con.clone())
                .mode(&Mode::W)
                .lock(key.clone())
                .await
                .is_err()
        );
        for r_unlock in r_unlocks {
            r_unlock.await.unwrap();
        }
        let w_unlock = Locker::new(con.clone())
            .mode(&Mode::W)
            .lock(key.clone())
            .await
            .unwrap();
        sleep(Duration::from_secs(5)).await;
        assert!(
            Locker::new(con.clone())
                .mode(&Mode::W)
                .lock(key.clone())
                .await
                .is_err()
        );
        assert!(
            Locker::new(con.clone())
                .mode(&Mode::R)
                .lock(key.clone())
                .await
                .is_err()
        );
        w_unlock.await.unwrap();
    }

    /// `lock_exec` returns the closure's value after holding the lock for longer
    /// than the default lock timeout.
    #[tokio::test]
    #[ignore = "requires a running Redis server; set REDIS_URL and run with --ignored"]
    async fn test_lock_exec() {
        let con = connect().await;
        let key = String::from("test:lock_key_exec");
        let r = Locker::new(con)
            .mode(&Mode::W)
            .lock_exec(key, async {
                sleep(Duration::from_secs(5)).await;
                Ok(1)
            })
            .await
            .unwrap();
        assert_eq!(r, 1);
    }
}