#![cfg_attr(docsrs, feature(doc_cfg))]
use displaydoc::Display;
use redis::Client;
use std::error::Error;
use std::future::Future;
use std::time::Duration;
use thiserror::Error;
use uuid::Uuid;
/// Error type produced by Redis operations; an alias for [`redis::RedisError`].
pub type RedisError = redis::RedisError;
/// Result of a Redis operation, with [`RedisError`] as the error type.
pub type RedisResult<T> = Result<T, RedisError>;
// The `sync` module is compiled only when the `sync` cargo feature is enabled; on docs.rs
// builds it is additionally tagged with that feature requirement.
// NOTE(review): presumably the blocking counterpart of the async API below — confirm.
#[cfg(feature = "sync")]
#[cfg_attr(docsrs, doc(cfg(feature = "sync")))]
pub mod sync;
// `single` is a private module; all of its public items are re-exported from the crate root.
mod single;
pub use single::*;
/// Handle for acquiring and releasing a lock over a set of named resources through Redis.
///
/// Only the [`Client`] is stored; every operation obtains its own multiplexed connection
/// from it.
pub struct MultiResourceLock {
/// Redis client used to obtain a connection for each operation.
client: Client,
}
/// Debug output that names the struct's field but redacts its value.
impl std::fmt::Debug for MultiResourceLock {
    /// Writes `MultiResourceLock { client: ".." }`.
    ///
    /// The client itself is never printed; only a `".."` placeholder is emitted. The label
    /// matches the real field name (`client`) so the output is not misleading.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MultiResourceLock")
            .field("client", &"..")
            .finish()
    }
}
/// Loads the bundled Lua library (`functions.lua`) into the Redis server.
///
/// Opens a multiplexed async connection from `client` and sends
/// `FUNCTION LOAD REPLACE <library source>`, so a previously loaded copy of the library is
/// replaced rather than causing an error.
///
/// # Errors
///
/// Returns the boxed error if the connection cannot be established or the command fails.
#[inline]
pub async fn setup(client: &Client) -> Result<(), Box<dyn Error>> {
    // The library source is embedded into the binary at compile time.
    const LIBRARY_SOURCE: &str = include_str!("functions.lua");

    let mut connection = client.get_multiplexed_async_connection().await?;

    let mut load_library = redis::cmd("FUNCTION");
    load_library.arg("LOAD").arg("REPLACE").arg(LIBRARY_SOURCE);
    load_library.exec_async(&mut connection).await?;

    Ok(())
}
/// Lock expiration used by the `*_default` methods: one hour.
pub const DEFAULT_EXPIRATION: Duration = Duration::from_secs(60 * 60);
/// Maximum time the `*_default` waiting methods keep retrying: one minute.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60);
/// Pause between acquisition attempts in the `*_default` waiting methods: one second.
pub const DEFAULT_SLEEP: Duration = Duration::from_secs(1);
impl MultiResourceLock {
    /// Wraps `client` in a new lock handle.
    ///
    /// No connection is opened here; each operation obtains its own connection from the
    /// client when it runs.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    #[inline]
    pub fn new(client: Client) -> RedisResult<Self> {
        Ok(Self { client })
    }

    /// Calls [`acquire`](Self::acquire) with [`DEFAULT_EXPIRATION`], [`DEFAULT_TIMEOUT`] and
    /// [`DEFAULT_SLEEP`].
    ///
    /// # Errors
    ///
    /// Propagates any Redis error from the underlying attempts.
    #[inline]
    pub async fn acquire_default(&mut self, resources: &[String]) -> RedisResult<Option<String>> {
        self.acquire(
            resources,
            DEFAULT_EXPIRATION,
            DEFAULT_TIMEOUT,
            DEFAULT_SLEEP,
        )
        .await
    }

    /// Repeatedly calls [`try_acquire`](Self::try_acquire) until it yields a lock id or
    /// `timeout` has elapsed.
    ///
    /// * `resources` - names of the resources to lock together.
    /// * `expiration` - expiration forwarded to every attempt.
    /// * `timeout` - upper bound on the time spent waiting between attempts.
    /// * `sleep` - pause between two consecutive attempts.
    ///
    /// At least one attempt is always made, even with a zero `timeout`. The pause between
    /// attempts is capped at the time remaining, so the last attempt happens at the deadline
    /// and the call does not idle past it.
    ///
    /// Returns `Ok(Some(lock_id))` on success and `Ok(None)` if the deadline passed first.
    ///
    /// # Errors
    ///
    /// Returns the first Redis error reported by an attempt; no retry follows an error.
    #[inline]
    pub async fn acquire(
        &mut self,
        resources: &[String],
        expiration: Duration,
        timeout: Duration,
        sleep: Duration,
    ) -> RedisResult<Option<String>> {
        let started = std::time::Instant::now();
        loop {
            if let Some(lock_id) = self.try_acquire(resources, expiration).await? {
                return Ok(Some(lock_id));
            }
            // Attempt first, consult the clock second: a zero timeout still gets one try.
            let remaining = timeout.saturating_sub(started.elapsed());
            if remaining.is_zero() {
                return Ok(None);
            }
            // Never sleep past the deadline; wake up in time for one final attempt.
            tokio::time::sleep(sleep.min(remaining)).await;
        }
    }

    /// Calls [`try_acquire`](Self::try_acquire) with [`DEFAULT_EXPIRATION`].
    ///
    /// # Errors
    ///
    /// Propagates any Redis error from the attempt.
    #[inline]
    pub async fn try_acquire_default(
        &mut self,
        resources: &[String],
    ) -> RedisResult<Option<String>> {
        self.try_acquire(resources, DEFAULT_EXPIRATION).await
    }

    /// Makes a single, non-waiting call to the `acquire_lock` Redis function for `resources`.
    ///
    /// A fresh UUID is generated as the lock id and the command sent is
    /// `FCALL acquire_lock 0 <lock id> <expiration in milliseconds> <resources...>`.
    ///
    /// Returns the function's reply decoded as an optional string.
    /// [`acquire`](Self::acquire) treats `Some` as success and `None` as "retry later".
    // NOTE(review): `acquire_lock` is presumably defined by the library that `setup` loads and
    // presumably echoes the lock id on success — confirm against `functions.lua`.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection cannot be obtained or the command fails.
    #[inline]
    pub async fn try_acquire(
        &mut self,
        resources: &[String],
        expiration: Duration,
    ) -> RedisResult<Option<String>> {
        let mut connection = self.client.get_multiplexed_async_connection().await?;
        let lock_id = Uuid::new_v4().to_string();
        // Arguments are passed by reference in the same wire order as before (id, expiration,
        // resources); this avoids cloning every resource name on each polling attempt.
        redis::cmd("FCALL")
            .arg("acquire_lock")
            .arg(0i32)
            .arg(&lock_id)
            .arg(expiration.as_millis().to_string())
            .arg(resources)
            .query_async(&mut connection)
            .await
    }

    /// Calls the `release_lock` Redis function for `lock_id`
    /// (`FCALL release_lock 0 <lock id>`).
    ///
    /// Returns the function's integer reply.
    // NOTE(review): presumably the number of resources that were released — confirm against
    // `functions.lua`.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection cannot be obtained or the command fails.
    #[inline]
    pub async fn release(&mut self, lock_id: &str) -> RedisResult<usize> {
        let mut connection = self.client.get_multiplexed_async_connection().await?;
        redis::cmd("FCALL")
            .arg("release_lock")
            .arg(0i32)
            .arg(lock_id)
            .query_async(&mut connection)
            .await
    }

    /// Acquires the lock, awaits `f`, releases the lock and returns `f`'s output.
    ///
    /// `resources`, `expiration`, `timeout` and `sleep` are forwarded to
    /// [`acquire`](Self::acquire). `f` is first polled only after the lock has been acquired.
    ///
    /// The lock is released only after `f` completes normally; if `f` panics, this function
    /// does not release it.
    ///
    /// # Errors
    ///
    /// * [`MapError::Acquire`] - an acquisition attempt failed with a Redis error.
    /// * [`MapError::Timeout`] - the lock was not obtained within `timeout`; `f` is never run.
    /// * [`MapError::Release`] - `f` completed but releasing failed; `f`'s output is dropped.
    #[inline]
    pub async fn map<F>(
        &mut self,
        resources: &[String],
        expiration: Duration,
        timeout: Duration,
        sleep: Duration,
        f: F,
    ) -> Result<F::Output, MapError>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        let lock_id = self
            .acquire(resources, expiration, timeout, sleep)
            .await
            .map_err(MapError::Acquire)?
            .ok_or(MapError::Timeout)?;
        let output = f.await;
        self.release(&lock_id).await.map_err(MapError::Release)?;
        Ok(output)
    }

    /// Calls [`map`](Self::map) with [`DEFAULT_EXPIRATION`], [`DEFAULT_TIMEOUT`] and
    /// [`DEFAULT_SLEEP`].
    ///
    /// # Errors
    ///
    /// Same as [`map`](Self::map).
    #[inline]
    pub async fn map_default<F>(
        &mut self,
        resources: &[String],
        f: F,
    ) -> Result<F::Output, MapError>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        self.map(
            resources,
            DEFAULT_EXPIRATION,
            DEFAULT_TIMEOUT,
            DEFAULT_SLEEP,
            f,
        )
        .await
    }
}
// NOTE: `displaydoc::Display` turns each variant's doc comment into its `Display` message.
// Keep every variant comment to a single line and treat its wording as user-facing output.
/// Error returned by [`MultiResourceLock::map`] and [`MultiResourceLock::map_default`].
#[derive(Debug, Display, Error)]
pub enum MapError {
    /// timed out waiting to acquire the lock
    Timeout,
    /// failed to acquire the lock: {0}
    Acquire(RedisError),
    /// failed to release the lock: {0}
    Release(RedisError),
}