redis_lock/
single.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tokio::sync::MutexGuard;
use tracing::{trace, warn};

/// The prefix for the lock key.
///
/// It is prepended to the caller-supplied resource name to form the Redis key
/// that represents the lock.
const LOCK_KEY_PREFIX: &str = "redlock:";
/// The default time-to-live for the lock (10 milliseconds).
///
/// Used as [`LockAcrossOptions::ttl`] by [`LockAcrossOptions::default`].
pub const DEFAULT_TTL: Duration = Duration::from_millis(10);
/// The default delay between retries when attempting to acquire the lock (100 milliseconds).
///
/// Used as [`LockAcrossOptions::retry`] by [`LockAcrossOptions::default`].
pub const DEFAULT_RETRY_DELAY: Duration = Duration::from_millis(100);
/// The default duration to attempt to acquire the lock (10 seconds).
///
/// Used as [`LockAcrossOptions::duration`] by [`LockAcrossOptions::default`].
pub const DEFAULT_DURATION: Duration = Duration::from_secs(10);
/// The clock drift factor.
///
/// When computing how long an acquired lock remains valid, the TTL in
/// milliseconds is multiplied by this factor and 2 ms are added; the result is
/// subtracted from the TTL together with the time spent acquiring.
const CLOCK_DRIFT_FACTOR: f64 = 0.01;

/// Lock metadata.
///
/// Built by `acquire_lock` when a quorum of connections accepted the lock and
/// later handed to `release_lock`, which uses `resource` and `value` to delete
/// the key.
// TODO Remove this `expect` once every field is read; in this file
// `validity_time` is only ever written, which is what trips `dead_code`.
#[expect(dead_code, reason = "see todo")]
struct Lock {
    /// The name of the resource to lock.
    ///
    /// This is the full Redis key, i.e. the lock key prefix followed by the
    /// caller-supplied resource name.
    resource: String,
    /// The unique value of the lock.
    ///
    /// A freshly generated UUID (v4) string; release only deletes the key
    /// while it still holds this value.
    value: String,
    /// The time the lock is valid for.
    ///
    /// Computed as the TTL minus the time spent acquiring and the clock drift
    /// allowance (saturating at zero).
    validity_time: Duration,
}

/// Options to configure [`lock_across`].
///
/// [`lock_across`] takes the options by value, so the type is `Clone` to let a
/// single configuration be reused across many calls, and `PartialEq`/`Eq` so
/// configurations can be compared (for example in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LockAcrossOptions {
    /// The time-to-live for the lock.
    ///
    /// Sent to Redis as the `PX` expiry (whole milliseconds) of the lock key.
    pub ttl: Duration,
    /// The delay between retries when attempting to acquire the lock.
    pub retry: Duration,
    /// The maximum duration to attempt to acquire the lock.
    ///
    /// A new acquisition attempt is only started while less than this much
    /// time has elapsed since the first attempt began.
    pub duration: Duration,
}
impl Default for LockAcrossOptions {
    /// Builds the options from [`DEFAULT_TTL`], [`DEFAULT_RETRY_DELAY`] and
    /// [`DEFAULT_DURATION`].
    #[inline]
    fn default() -> Self {
        let ttl = DEFAULT_TTL;
        let retry = DEFAULT_RETRY_DELAY;
        let duration = DEFAULT_DURATION;
        Self {
            ttl,
            retry,
            duration,
        }
    }
}

/// Executes a function while locking on a single resource using the
/// [RedLock algorithm](https://redis.io/docs/latest/develop/use/patterns/distributed-locks/).
///
/// This is much more efficient than [`crate::MultiResourceLock`] when you only need to lock a single
/// resource. Ideally you should architect your application so you never need [`crate::MultiResourceLock`].
///
/// - `connections` is used to acquire mutable references on connections to acquire the lock and
///   then used to acquire mutable references on connections to release the lock.
/// - `resource` is the name of the resource to lock.
/// - `f` is the future to run while the lock is held. It is awaited in place (never spawned), so
///   it may borrow from the caller.
/// - `options` the options to configure acquisition.
///
/// The lock key is written with an expiry of `options.ttl` and is not renewed while `f` runs, so
/// `f` should complete well within that time. If the returned future is dropped while `f` is
/// running the key is not deleted explicitly and only disappears once it expires.
///
/// # Errors
///
/// - When the lock could not be acquired within `options.duration`, an error of kind
///   [`redis::ErrorKind::IoError`] with the message `"Failed to acquire lock"` is returned and `f`
///   is never polled.
/// - When releasing the lock fails, that error is returned and the output of `f` is discarded even
///   though `f` ran to completion.
///
/// ```no_run
/// # use tokio::{task, sync::Mutex};
/// # use std::sync::Arc;
/// # use redis_lock::{lock_across, LockAcrossOptions};
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # tokio::runtime::Runtime::new()?.block_on(async {
/// # let client: redis::Client = todo!();
/// // Get connection.
/// let connection = Arc::new(Mutex::new(client.get_multiplexed_async_connection().await?));
/// // Set state.
/// let mut x: usize = 0;
/// let ptr = &mut x as *mut usize as usize;
/// // Execute racy functions with lock.
/// const N: usize = 100_000;
/// let futures = (0..N).map(|_|{
///     let cconnection = connection.clone();
///     task::spawn(async move {
///         lock_across(
///             &[cconnection],
///             "resource",
///             async move {
///                 // SAFETY: `x` outlives every task (they are all awaited below) and the
///                 // lock is what serialises these writes.
///                 unsafe { *(ptr as *mut usize) += 1 };
///             },
///             LockAcrossOptions::default(),
///         )
///         .await
///     })
/// }).collect::<Vec<_>>();
/// for future in futures {
///     // The first `?` handles the task join error, the second the lock error.
///     future.await??;
/// }
/// // Assert state.
/// assert_eq!(x, N);
/// # Ok(())
/// # })
/// # }
/// ```
#[inline]
pub async fn lock_across<C, F>(
    connections: &[Arc<Mutex<C>>],
    resource: &str,
    f: F,
    options: LockAcrossOptions,
) -> Result<F::Output, redis::RedisError>
where
    C: redis::aio::ConnectionLike,
    F: Future,
{
    trace!("acquiring lock");
    let Some(lock) = acquire_lock(connections, resource, options).await? else {
        return Err(redis::RedisError::from((
            redis::ErrorKind::IoError,
            "Failed to acquire lock",
        )));
    };
    trace!("acquired lock");

    // Execute the provided function while the lock is held.
    let output = f.await;
    trace!("executed function");

    // Take every connection again so the lock can be released on each of them.
    let mut release_connections = Vec::with_capacity(connections.len());
    for connection in connections {
        trace!("getting connection");
        release_connections.push(connection.lock().await);
        trace!("acquired connection");
    }

    // Release the lock; a failure here is reported to the caller.
    trace!("releasing lock");
    release_lock(&mut release_connections, &lock).await?;
    trace!("released lock");
    Ok(output)
}

/// Attempts to acquire a lock on multiple connections.
#[expect(
    clippy::as_conversions,
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::float_arithmetic,
    clippy::cast_sign_loss,
    clippy::arithmetic_side_effects,
    clippy::integer_division_remainder_used,
    clippy::integer_division,
    reason = "I can't be bothered to fix these right now."
)]
async fn acquire_lock<'a, C: redis::aio::ConnectionLike>(
    connections: &[Arc<Mutex<C>>],
    resource: &str,
    options: LockAcrossOptions,
) -> Result<Option<Lock>, redis::RedisError> {
    let value = uuid::Uuid::new_v4().to_string();
    let resource_key = format!("{LOCK_KEY_PREFIX}{resource}");
    let quorum = (connections.len() / 2) + 1;
    let outer_start = Instant::now();
    let mut attempts = 0u64;
    let ttl_millis = options.ttl.as_millis() as u64;

    while outer_start.elapsed() < options.duration {
        attempts += 1;
        trace!("Attempting to acquire lock (attempt {attempts})");
        let mut successful_locks = Vec::new();
        let start = Instant::now();

        for conn_future in connections {
            trace!("getting connection");
            let mut conn = conn_future.lock().await;
            if let Ok(true) = try_acquire_lock(&mut *conn, &resource_key, &value, ttl_millis).await
            {
                trace!("acquired lock");
                successful_locks.push(conn);
            }
        }

        let drift = (ttl_millis as f64 * CLOCK_DRIFT_FACTOR + 2.0f64) as u64;
        let elapsed = start.elapsed().as_millis() as u64;
        let validity_time = ttl_millis.saturating_sub(elapsed).saturating_sub(drift);

        if successful_locks.len() >= quorum && validity_time > 0 {
            trace!("Lock acquired successfully");
            return Ok(Some(Lock {
                resource: resource_key,
                value,
                validity_time: Duration::from_millis(validity_time),
            }));
        }
        trace!(
            "Failed to acquire lock, releasing {} successful locks",
            successful_locks.len()
        );
        for mut conn in successful_locks {
            if let Err(e) = release_lock_on_connection(&mut *conn, &resource_key, &value).await {
                trace!("Error releasing lock: {:?}", e);
            }
        }

        trace!("Waiting before next attempt");
        tokio::time::sleep(options.retry).await;
    }

    warn!(
        "Failed to acquire lock after {:?} and {attempts} attempts",
        options.duration
    );
    Ok(None)
}

/// Attempts to acquire a lock on a connection.
///
/// Issues `SET resource value NX PX ttl`, i.e. the key is only written when it does not exist yet
/// and expires after `ttl` milliseconds. Returns `Ok(true)` when Redis answered with a value (the
/// key was set) and `Ok(false)` when it answered with nil; command errors are propagated.
async fn try_acquire_lock<C: redis::aio::ConnectionLike>(
    conn: &mut C,
    resource: &str,
    value: &str,
    ttl: u64,
) -> Result<bool, redis::RedisError> {
    let mut command = redis::cmd("SET");
    command
        .arg(resource)
        .arg(value)
        .arg("NX")
        .arg("PX")
        .arg(ttl);

    command
        .query_async(conn)
        .await
        .map(|reply: Option<String>| reply.is_some())
}

/// Releases a lock on multiple connections.
async fn release_lock<C: redis::aio::ConnectionLike>(
    connections: &mut [MutexGuard<'_, C>],
    lock: &Lock,
) -> Result<(), redis::RedisError> {
    for conn in connections.iter_mut() {
        let x: &mut C = &mut *conn;
        release_lock_on_connection(x, &lock.resource, &lock.value).await?;
    }
    Ok(())
}

/// Releases a lock on a connection.
///
/// Runs a Lua script through `EVAL` with `resource` as the single key and `value` as the single
/// argument. The script deletes the key only while it still holds `value`, so a lock that has
/// since been taken with a different value is left untouched. The script's reply is ignored;
/// command errors are propagated.
async fn release_lock_on_connection<C: redis::aio::ConnectionLike>(
    conn: &mut C,
    resource: &str,
    value: &str,
) -> Result<(), redis::RedisError> {
    /// Compare-and-delete script: removes `KEYS[1]` when its value equals `ARGV[1]`.
    const UNLOCK_SCRIPT: &str = r#"
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
    "#;

    let mut command = redis::cmd("EVAL");
    command
        .arg(UNLOCK_SCRIPT)
        .arg(1i32)
        .arg(resource)
        .arg(value);

    command.query_async(conn).await
}