1use displaydoc::Display;
2use futures::future::FutureExt;
3use futures::stream::{FuturesUnordered, StreamExt};
4use std::future::Future;
5use std::panic::AssertUnwindSafe;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use thiserror::Error;
9use tokio::sync::Mutex;
10use tracing::trace;
11
/// Default lock expiry used by `LockAcrossOptions::default`.
///
/// NOTE(review): 10 ms is shorter than `DEFAULT_RETRY_DELAY` and than most guarded
/// sections; unless callers raise it, the Redis key may expire while the locked future
/// is still running. Confirm this default is intended.
pub const DEFAULT_TTL: Duration = Duration::from_millis(10);
/// Default pause between two failed acquisition attempts.
pub const DEFAULT_RETRY_DELAY: Duration = Duration::from_millis(100);
/// Default overall time budget for acquiring a lock (10 seconds).
pub const DEFAULT_DURATION: Duration = Duration::from_millis(10_000);
/// Namespace prepended to every resource name to form its Redis key.
const LOCK_KEY_PREFIX: &str = "redlock:";
/// Fraction of the TTL set aside for clock drift between instances; lock acquisition
/// adds a further fixed 2 ms on top of this.
const CLOCK_DRIFT_FACTOR: f64 = 1e-2;
22
/// A lock held on a quorum of Redis instances.
///
/// Releasing deletes `resource` only on instances where it still stores `value`, so a
/// client never frees a lock that has since been taken by someone else.
///
/// NOTE(review): the "todo" referenced by the `expect` reason is not visible here;
/// presumably it concerns putting `validity_time` to use.
#[expect(dead_code, reason = "see todo")]
struct Lock {
    /// Random per-acquisition token stored as the key's value.
    value: String,
    /// Full Redis key (including the `redlock:` prefix) that was locked.
    resource: String,
    /// Estimated remaining validity at acquisition time: TTL minus the time spent
    /// acquiring and a clock-drift allowance. Not read anywhere in this file yet.
    validity_time: Duration,
}
34
/// Timing knobs for `lock_across`.
///
/// Every field is a plain [`Duration`], so the options are cheap to copy and can be
/// reused across several calls (hence `Clone`/`Copy`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LockAcrossOptions {
    /// Key expiry sent to Redis as `SET ... PX <ttl>` (truncated to whole milliseconds).
    /// Redis drops the key on its own after this long, released or not.
    pub ttl: Duration,
    /// How long to sleep between two failed acquisition attempts.
    pub retry: Duration,
    /// Total time budget for acquiring the lock before giving up.
    pub duration: Duration,
}
45impl Default for LockAcrossOptions {
46 #[inline]
47 fn default() -> Self {
48 Self {
49 ttl: DEFAULT_TTL,
50 retry: DEFAULT_RETRY_DELAY,
51 duration: DEFAULT_DURATION,
52 }
53 }
54}
55
56#[inline]
104pub async fn lock_across<C, F>(
105 connections: &[Arc<Mutex<C>>],
106 resource: &str,
107 f: F,
108 options: LockAcrossOptions,
109) -> Result<F::Output, LockAcrossError>
110where
111 C: redis::aio::ConnectionLike + Send + 'static,
112 F: Future + 'static,
113 F::Output: 'static,
114{
115 trace!("acquiring lock");
116 let lock = acquire_lock(connections, resource, options)
117 .await
118 .map_err(LockAcrossError::Acquire)?;
119 trace!("acquired lock");
120
121 trace!("executing function");
123 let output = AssertUnwindSafe(f)
124 .catch_unwind()
125 .await
126 .map_err(LockAcrossError::Panic);
127 trace!("executed function");
128
129 trace!("releasing lock");
131 release_lock(connections, &lock)
132 .await
133 .map_err(LockAcrossError::Release)?;
134 trace!("released lock");
135 output
137}
138
139#[derive(Debug, Error, Display)]
140pub enum LockAcrossError {
141 Acquire(AcquireLockError),
143 Panic(Box<dyn std::any::Any + std::marker::Send>),
145 Release(ReleaseLockError),
147}
148
149#[derive(Debug, Error, Display)]
150pub enum AcquireLockError {
151 Failed(String),
153}
154
155#[expect(
157 clippy::as_conversions,
158 clippy::cast_possible_truncation,
159 clippy::cast_precision_loss,
160 clippy::float_arithmetic,
161 clippy::cast_sign_loss,
162 clippy::arithmetic_side_effects,
163 clippy::integer_division_remainder_used,
164 clippy::integer_division,
165 reason = "I can't be bothered to fix these right now."
166)]
167async fn acquire_lock<C: redis::aio::ConnectionLike>(
168 connections: &[Arc<Mutex<C>>],
169 resource: &str,
170 options: LockAcrossOptions,
171) -> Result<Lock, AcquireLockError> {
172 let value = uuid::Uuid::new_v4().to_string();
173 let resource_key = format!("{LOCK_KEY_PREFIX}{resource}");
174 let quorum = (connections.len() / 2) + 1;
175 let ttl_millis = options.ttl.as_millis() as u64;
176
177 let outer_start = Instant::now();
178 let mut attempts = 0u64;
179
180 while outer_start.elapsed() < options.duration {
181 attempts += 1;
182 trace!("Attempting to acquire lock (attempt {attempts})");
183
184 let mut futures = FuturesUnordered::new();
185 for conn in connections {
186 let cconn = Arc::clone(conn);
187 let cresource_key = resource_key.clone();
188 let cvalue = value.clone();
189 futures.push(async move {
190 let mut guard = cconn.lock().await;
191 try_acquire_lock(&mut *guard, &cresource_key, &cvalue, ttl_millis).await
192 });
193 }
194
195 let start = Instant::now();
196 let mut successful_locks = 0;
197 let mut failed_locks = 0;
198
199 while let Some(result) = futures.next().await {
200 if let Ok(true) = result {
201 successful_locks += 1;
202 if successful_locks >= quorum {
204 break;
205 }
206 } else {
207 failed_locks += 1;
208 if failed_locks > connections.len() - quorum {
210 break;
211 }
212 }
213 }
214
215 let drift = (ttl_millis as f64 * CLOCK_DRIFT_FACTOR + 2.0f64) as u64;
216 let elapsed = start.elapsed().as_millis() as u64;
217 let validity_time = ttl_millis.saturating_sub(elapsed).saturating_sub(drift);
218
219 if successful_locks >= quorum && validity_time > 0 {
220 trace!("Lock acquired successfully");
221 return Ok(Lock {
222 resource: resource_key,
223 value,
224 validity_time: Duration::from_millis(validity_time),
225 });
226 }
227
228 trace!("Failed to acquire lock, waiting before next attempt");
229 tokio::time::sleep(options.retry).await;
230 }
231
232 Err(AcquireLockError::Failed(format!(
233 "Failed to acquire lock after {:?} and {attempts} attempts",
234 options.duration
235 )))
236}
237
238async fn try_acquire_lock<C: redis::aio::ConnectionLike>(
240 conn: &mut C,
241 resource: &str,
242 value: &str,
243 ttl: u64,
244) -> Result<bool, redis::RedisError> {
245 let result: Option<String> = redis::cmd("SET")
246 .arg(resource)
247 .arg(value)
248 .arg("NX")
249 .arg("PX")
250 .arg(ttl)
251 .query_async(conn)
252 .await?;
253
254 Ok(result.is_some())
255}
256
257#[derive(Debug, Error, Display)]
258pub enum ReleaseLockError {
259 Join(tokio::task::JoinError),
261 Release(redis::RedisError),
263}
264
265async fn release_lock<C: redis::aio::ConnectionLike + Send + 'static>(
267 connections: &[Arc<Mutex<C>>],
268 lock: &Lock,
269) -> Result<(), ReleaseLockError> {
270 let futures = connections.iter().map(|conn| {
271 let cconn = Arc::clone(conn);
272 let cresource = lock.resource.clone();
273 let cvalue = lock.value.clone();
274 tokio::spawn(async move {
275 let mut guard = cconn.lock().await;
276 release_lock_on_connection(&mut *guard, &cresource, &cvalue).await
277 })
278 });
279
280 let results = futures::future::join_all(futures).await;
281 for result in results {
282 result
283 .map_err(ReleaseLockError::Join)?
284 .map_err(ReleaseLockError::Release)?;
285 }
286 Ok(())
287}
288
289async fn release_lock_on_connection<C: redis::aio::ConnectionLike>(
291 conn: &mut C,
292 resource: &str,
293 value: &str,
294) -> Result<(), redis::RedisError> {
295 let script = r#"
296 if redis.call("get", KEYS[1]) == ARGV[1] then
297 return redis.call("del", KEYS[1])
298 else
299 return 0
300 end
301 "#;
302
303 let _: () = redis::cmd("EVAL")
304 .arg(script)
305 .arg(1i32)
306 .arg(resource)
307 .arg(value)
308 .query_async(conn)
309 .await?;
310
311 Ok(())
312}