redis_lock/
single.rs

1use displaydoc::Display;
2use futures::future::FutureExt;
3use futures::stream::{FuturesUnordered, StreamExt};
4use std::future::Future;
5use std::panic::AssertUnwindSafe;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use thiserror::Error;
9use tokio::sync::Mutex;
10use tracing::trace;
11
/// Prefix prepended to every resource name to form the Redis key of its lock.
const LOCK_KEY_PREFIX: &str = "redlock:";
/// Time-to-live used for the lock when none is configured.
// NOTE(review): 10 ms is short next to `DEFAULT_RETRY_DELAY`; once the drift
// allowance is subtracted at most ~8 ms of validity remain — confirm intended.
pub const DEFAULT_TTL: Duration = Duration::from_millis(10);
/// Pause between two consecutive acquisition attempts when none is configured.
pub const DEFAULT_RETRY_DELAY: Duration = Duration::from_millis(100);
/// Total time spent trying to acquire the lock when none is configured.
pub const DEFAULT_DURATION: Duration = Duration::from_secs(10);
/// Fraction of the time-to-live subtracted from the lock validity to allow for clock drift.
const CLOCK_DRIFT_FACTOR: f64 = 0.01_f64;
22
/// Metadata describing a lock that has been acquired.
// TODO Remove this `expect` once every field is read.
// NOTE(review): `validity_time` looks like the only field that is written but never read — confirm.
#[expect(dead_code, reason = "see todo")]
struct Lock {
    /// The Redis key of the locked resource (prefix included).
    resource: String,
    /// The unique value stored under the key; it identifies the owner of the lock.
    value: String,
    /// How long the lock was still considered valid at the moment it was acquired.
    validity_time: Duration,
}
34
/// Options to configure [`lock_across`].
///
/// The type is `Copy`, so a single value can be reused for any number of calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LockAcrossOptions {
    /// The time-to-live for the lock; it is sent to Redis as the key expiry in milliseconds.
    pub ttl: Duration,
    /// The delay between retries when attempting to acquire the lock.
    pub retry: Duration,
    /// The maximum duration to attempt to acquire the lock.
    pub duration: Duration,
}
45impl Default for LockAcrossOptions {
46    #[inline]
47    fn default() -> Self {
48        Self {
49            ttl: DEFAULT_TTL,
50            retry: DEFAULT_RETRY_DELAY,
51            duration: DEFAULT_DURATION,
52        }
53    }
54}
55
56/// Executes a function while locking on a single resource using the
57/// [RedLock algorithm](https://redis.io/docs/latest/develop/use/patterns/distributed-locks/).
58///
59/// This is much more efficient than [`crate::MultiResourceLock`] when you only need to lock a single
60/// resource. Ideally you should architect your application so you never need [`crate::MultiResourceLock`].
61///
62/// - `connections` is used to acquire mutable references on connections to acquire the lock and
63///   then used to acquire mutable references on connections to release the lock.
64/// - `resource` is the name of the resource to lock.
65/// - `options` the options to configure acquisition.
66///
67/// ```no_run
68/// # use tokio::{task, sync::Mutex};
69/// # use std::sync::Arc;
70/// # use redis_lock::{lock_across, LockAcrossOptions};
71/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
72/// # tokio::runtime::Runtime::new()?.block_on(async {
73/// # let client: redis::Client = todo!();
74/// // Get connection.
75/// let connection = Arc::new(Mutex::new(client.get_multiplexed_async_connection().await?));
76/// // Set state.
77/// let x: usize = 0;
78/// let ptr = &mut x as *mut usize as usize;
79/// // Execute racy functions with lock.
80/// const N: usize = 100_000;
81/// let futures = (0..N).map(|_|{
82///     let cconnection = connection.clone();
83///     task::spawn(async move {
84///         lock_across(
85///             &[cconnection],
86///             "resource",
87///             async move {
88///                 unsafe { *(ptr as *mut usize) += 1 };
89///             },
90///             LockAcrossOptions::default(),
91///         );
92///     })
93/// }).collect::<Vec<_>>();
94/// for future in futures {
95///     future.await?;
96/// }
97/// // Assert state.
98/// assert_eq!(x, N);
99/// # Ok(())
100/// # })
101/// # }
102/// ```
103#[inline]
104pub async fn lock_across<C, F>(
105    connections: &[Arc<Mutex<C>>],
106    resource: &str,
107    f: F,
108    options: LockAcrossOptions,
109) -> Result<F::Output, LockAcrossError>
110where
111    C: redis::aio::ConnectionLike + Send + 'static,
112    F: Future + 'static,
113    F::Output: 'static,
114{
115    trace!("acquiring lock");
116    let lock = acquire_lock(connections, resource, options)
117        .await
118        .map_err(LockAcrossError::Acquire)?;
119    trace!("acquired lock");
120
121    // Execute the provided function
122    trace!("executing function");
123    let output = AssertUnwindSafe(f)
124        .catch_unwind()
125        .await
126        .map_err(LockAcrossError::Panic);
127    trace!("executed function");
128
129    // Release the lock
130    trace!("releasing lock");
131    release_lock(connections, &lock)
132        .await
133        .map_err(LockAcrossError::Release)?;
134    trace!("released lock");
135    // We propagate panic errors after releasing the lock.
136    output
137}
138
139#[derive(Debug, Error, Display)]
140pub enum LockAcrossError {
141    /// Failed to acquire lock: {0}
142    Acquire(AcquireLockError),
143    /// The function paniced.
144    Panic(Box<dyn std::any::Any + std::marker::Send>),
145    /// Failed to release lock: {0}
146    Release(ReleaseLockError),
147}
148
/// Error returned when the lock could not be acquired.
// The doc comment of each variant doubles as its `Display` message (derived by
// `displaydoc`), so editing one changes what users see.
#[derive(Debug, Error, Display)]
pub enum AcquireLockError {
    /// Failed to acquire lock: {0}
    Failed(String),
}
154
155/// Attempts to acquire a lock on multiple connections.
156#[expect(
157    clippy::as_conversions,
158    clippy::cast_possible_truncation,
159    clippy::cast_precision_loss,
160    clippy::float_arithmetic,
161    clippy::cast_sign_loss,
162    clippy::arithmetic_side_effects,
163    clippy::integer_division_remainder_used,
164    clippy::integer_division,
165    reason = "I can't be bothered to fix these right now."
166)]
167async fn acquire_lock<C: redis::aio::ConnectionLike>(
168    connections: &[Arc<Mutex<C>>],
169    resource: &str,
170    options: LockAcrossOptions,
171) -> Result<Lock, AcquireLockError> {
172    let value = uuid::Uuid::new_v4().to_string();
173    let resource_key = format!("{LOCK_KEY_PREFIX}{resource}");
174    let quorum = (connections.len() / 2) + 1;
175    let ttl_millis = options.ttl.as_millis() as u64;
176
177    let outer_start = Instant::now();
178    let mut attempts = 0u64;
179
180    while outer_start.elapsed() < options.duration {
181        attempts += 1;
182        trace!("Attempting to acquire lock (attempt {attempts})");
183
184        let mut futures = FuturesUnordered::new();
185        for conn in connections {
186            let cconn = Arc::clone(conn);
187            let cresource_key = resource_key.clone();
188            let cvalue = value.clone();
189            futures.push(async move {
190                let mut guard = cconn.lock().await;
191                try_acquire_lock(&mut *guard, &cresource_key, &cvalue, ttl_millis).await
192            });
193        }
194
195        let start = Instant::now();
196        let mut successful_locks = 0;
197        let mut failed_locks = 0;
198
199        while let Some(result) = futures.next().await {
200            if let Ok(true) = result {
201                successful_locks += 1;
202                // If quorum is achieved, break early
203                if successful_locks >= quorum {
204                    break;
205                }
206            } else {
207                failed_locks += 1;
208                // If quorum is unachievable, break early
209                if failed_locks > connections.len() - quorum {
210                    break;
211                }
212            }
213        }
214
215        let drift = (ttl_millis as f64 * CLOCK_DRIFT_FACTOR + 2.0f64) as u64;
216        let elapsed = start.elapsed().as_millis() as u64;
217        let validity_time = ttl_millis.saturating_sub(elapsed).saturating_sub(drift);
218
219        if successful_locks >= quorum && validity_time > 0 {
220            trace!("Lock acquired successfully");
221            return Ok(Lock {
222                resource: resource_key,
223                value,
224                validity_time: Duration::from_millis(validity_time),
225            });
226        }
227
228        trace!("Failed to acquire lock, waiting before next attempt");
229        tokio::time::sleep(options.retry).await;
230    }
231
232    Err(AcquireLockError::Failed(format!(
233        "Failed to acquire lock after {:?} and {attempts} attempts",
234        options.duration
235    )))
236}
237
238/// Attempts to acquire a lock on a connection.
239async fn try_acquire_lock<C: redis::aio::ConnectionLike>(
240    conn: &mut C,
241    resource: &str,
242    value: &str,
243    ttl: u64,
244) -> Result<bool, redis::RedisError> {
245    let result: Option<String> = redis::cmd("SET")
246        .arg(resource)
247        .arg(value)
248        .arg("NX")
249        .arg("PX")
250        .arg(ttl)
251        .query_async(conn)
252        .await?;
253
254    Ok(result.is_some())
255}
256
/// Error returned when the lock could not be released.
// The doc comment of each variant doubles as its `Display` message (derived by
// `displaydoc`), so editing one changes what users see.
#[derive(Debug, Error, Display)]
pub enum ReleaseLockError {
    /// Failed to join task: {0}
    Join(tokio::task::JoinError),
    /// Failed to release lock: {0}
    Release(redis::RedisError),
}
264
265/// Releases a lock on multiple connections.
266async fn release_lock<C: redis::aio::ConnectionLike + Send + 'static>(
267    connections: &[Arc<Mutex<C>>],
268    lock: &Lock,
269) -> Result<(), ReleaseLockError> {
270    let futures = connections.iter().map(|conn| {
271        let cconn = Arc::clone(conn);
272        let cresource = lock.resource.clone();
273        let cvalue = lock.value.clone();
274        tokio::spawn(async move {
275            let mut guard = cconn.lock().await;
276            release_lock_on_connection(&mut *guard, &cresource, &cvalue).await
277        })
278    });
279
280    let results = futures::future::join_all(futures).await;
281    for result in results {
282        result
283            .map_err(ReleaseLockError::Join)?
284            .map_err(ReleaseLockError::Release)?;
285    }
286    Ok(())
287}
288
289/// Releases a lock on a connection.
290async fn release_lock_on_connection<C: redis::aio::ConnectionLike>(
291    conn: &mut C,
292    resource: &str,
293    value: &str,
294) -> Result<(), redis::RedisError> {
295    let script = r#"
296        if redis.call("get", KEYS[1]) == ARGV[1] then
297            return redis.call("del", KEYS[1])
298        else
299            return 0
300        end
301    "#;
302
303    let _: () = redis::cmd("EVAL")
304        .arg(script)
305        .arg(1i32)
306        .arg(resource)
307        .arg(value)
308        .query_async(conn)
309        .await?;
310
311    Ok(())
312}