use displaydoc::Display;
use futures::future::FutureExt;
use futures::stream::{FuturesUnordered, StreamExt};
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::sync::Mutex;
use tracing::trace;
const LOCK_KEY_PREFIX: &str = "redlock:";
pub const DEFAULT_TTL: Duration = Duration::from_millis(10);
pub const DEFAULT_RETRY_DELAY: Duration = Duration::from_millis(100);
pub const DEFAULT_DURATION: Duration = Duration::from_secs(10);
const CLOCK_DRIFT_FACTOR: f64 = 0.01;
#[expect(dead_code, reason = "see todo")]
struct Lock {
resource: String,
value: String,
validity_time: Duration,
}
#[derive(Debug)]
pub struct LockAcrossOptions {
pub ttl: Duration,
pub retry: Duration,
pub duration: Duration,
}
impl Default for LockAcrossOptions {
#[inline]
fn default() -> Self {
Self {
ttl: DEFAULT_TTL,
retry: DEFAULT_RETRY_DELAY,
duration: DEFAULT_DURATION,
}
}
}
#[inline]
pub async fn lock_across<C, F>(
connections: &[Arc<Mutex<C>>],
resource: &str,
f: F,
options: LockAcrossOptions,
) -> Result<F::Output, LockAcrossError>
where
C: redis::aio::ConnectionLike + Send + 'static,
F: Future + 'static,
F::Output: 'static,
{
trace!("acquiring lock");
let lock = acquire_lock(connections, resource, options)
.await
.map_err(LockAcrossError::Acquire)?;
trace!("acquired lock");
trace!("executing function");
let output = AssertUnwindSafe(f)
.catch_unwind()
.await
.map_err(LockAcrossError::Panic);
trace!("executed function");
trace!("releasing lock");
release_lock(connections, &lock)
.await
.map_err(LockAcrossError::Release)?;
trace!("released lock");
output
}
#[derive(Debug, Error, Display)]
pub enum LockAcrossError {
Acquire(AcquireLockError),
Panic(Box<dyn std::any::Any + std::marker::Send>),
Release(ReleaseLockError),
}
#[derive(Debug, Error, Display)]
pub enum AcquireLockError {
Failed(String),
}
#[expect(
clippy::as_conversions,
clippy::cast_possible_truncation,
clippy::cast_precision_loss,
clippy::float_arithmetic,
clippy::cast_sign_loss,
clippy::arithmetic_side_effects,
clippy::integer_division_remainder_used,
clippy::integer_division,
reason = "I can't be bothered to fix these right now."
)]
async fn acquire_lock<C: redis::aio::ConnectionLike>(
connections: &[Arc<Mutex<C>>],
resource: &str,
options: LockAcrossOptions,
) -> Result<Lock, AcquireLockError> {
let value = uuid::Uuid::new_v4().to_string();
let resource_key = format!("{LOCK_KEY_PREFIX}{resource}");
let quorum = (connections.len() / 2) + 1;
let ttl_millis = options.ttl.as_millis() as u64;
let outer_start = Instant::now();
let mut attempts = 0u64;
while outer_start.elapsed() < options.duration {
attempts += 1;
trace!("Attempting to acquire lock (attempt {attempts})");
let mut futures = FuturesUnordered::new();
for conn in connections {
let cconn = Arc::clone(conn);
let cresource_key = resource_key.clone();
let cvalue = value.clone();
futures.push(async move {
let mut guard = cconn.lock().await;
try_acquire_lock(&mut *guard, &cresource_key, &cvalue, ttl_millis).await
});
}
let start = Instant::now();
let mut successful_locks = 0;
let mut failed_locks = 0;
while let Some(result) = futures.next().await {
if let Ok(true) = result {
successful_locks += 1;
if successful_locks >= quorum {
break;
}
} else {
failed_locks += 1;
if failed_locks > connections.len() - quorum {
break;
}
}
}
let drift = (ttl_millis as f64 * CLOCK_DRIFT_FACTOR + 2.0f64) as u64;
let elapsed = start.elapsed().as_millis() as u64;
let validity_time = ttl_millis.saturating_sub(elapsed).saturating_sub(drift);
if successful_locks >= quorum && validity_time > 0 {
trace!("Lock acquired successfully");
return Ok(Lock {
resource: resource_key,
value,
validity_time: Duration::from_millis(validity_time),
});
}
trace!("Failed to acquire lock, waiting before next attempt");
tokio::time::sleep(options.retry).await;
}
Err(AcquireLockError::Failed(format!(
"Failed to acquire lock after {:?} and {attempts} attempts",
options.duration
)))
}
async fn try_acquire_lock<C: redis::aio::ConnectionLike>(
conn: &mut C,
resource: &str,
value: &str,
ttl: u64,
) -> Result<bool, redis::RedisError> {
let result: Option<String> = redis::cmd("SET")
.arg(resource)
.arg(value)
.arg("NX")
.arg("PX")
.arg(ttl)
.query_async(conn)
.await?;
Ok(result.is_some())
}
#[derive(Debug, Error, Display)]
pub enum ReleaseLockError {
Join(tokio::task::JoinError),
Release(redis::RedisError),
}
async fn release_lock<C: redis::aio::ConnectionLike + Send + 'static>(
connections: &[Arc<Mutex<C>>],
lock: &Lock,
) -> Result<(), ReleaseLockError> {
let futures = connections.iter().map(|conn| {
let cconn = Arc::clone(conn);
let cresource = lock.resource.clone();
let cvalue = lock.value.clone();
tokio::spawn(async move {
let mut guard = cconn.lock().await;
release_lock_on_connection(&mut *guard, &cresource, &cvalue).await
})
});
let results = futures::future::join_all(futures).await;
for result in results {
result
.map_err(ReleaseLockError::Join)?
.map_err(ReleaseLockError::Release)?;
}
Ok(())
}
async fn release_lock_on_connection<C: redis::aio::ConnectionLike>(
conn: &mut C,
resource: &str,
value: &str,
) -> Result<(), redis::RedisError> {
let script = r#"
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
"#;
let _: () = redis::cmd("EVAL")
.arg(script)
.arg(1i32)
.arg(resource)
.arg(value)
.query_async(conn)
.await?;
Ok(())
}