distributed_lock_redis/lock.rs

//! Redis distributed lock implementation.

use std::time::Duration;

use distributed_lock_core::error::{LockError, LockResult};
use distributed_lock_core::traits::DistributedLock;
use fred::prelude::*;
use fred::types::CustomCommand;
use tracing::{Span, instrument};

use crate::redlock::{acquire::acquire_redlock, helper::RedLockHelper, timeouts::RedLockTimeouts};

/// Internal state for a Redis lock.
#[derive(Debug, Clone)]
pub struct RedisLockState {
    /// Redis key for the lock.
    pub key: String,
    /// Unique lock ID for this acquisition.
    pub lock_id: String,
    /// Timeout configuration.
    pub timeouts: RedLockTimeouts,
}

impl RedisLockState {
    /// Creates a new lock state.
    pub fn new(key: String, timeouts: RedLockTimeouts) -> Self {
        Self {
            key,
            lock_id: RedLockHelper::create_lock_id(),
            timeouts,
        }
    }

    /// Attempts to acquire the lock on a single Redis client.
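    ///
    /// Returns `Ok(true)` if this client set the key (lock acquired on this
    /// server), or `Ok(false)` if the key already existed under another owner.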
    pub async fn try_acquire(&self, client: &RedisClient) -> LockResult<bool> {
        let expiry_millis = self.timeouts.expiry.as_millis() as i64;

        // Use SET NX PX to atomically set the key if it doesn't exist
        // Note: Using PX (milliseconds) instead of EX (seconds)
        let result: Option<String> = client
            .set(
                &self.key,
                &self.lock_id,
                Some(Expiration::PX(expiry_millis)),
                Some(SetOptions::NX),
                false,
            )
            .await
            .map_err(|e| {
                LockError::Backend(Box::new(std::io::Error::other(format!(
                    "Redis SET NX failed: {}",
                    e
                ))))
            })?;

        // SET NX returns Some(value) if key was set, None if key already exists
        Ok(result.is_some())
    }

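    // Both scripts below follow the standard compare-and-delete pattern from
    // the Redis distributed-lock documentation: the key is only touched if its
    // value still matches this acquisition's `lock_id`, so a client whose lock
    // has already expired cannot extend or delete a lock now held by another
    // owner. The GET and the mutation run atomically inside one Lua script.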
    /// Lua script to extend the lock duration.
    const EXTEND_SCRIPT_LUA: &'static str = r#"
        if redis.call('get', KEYS[1]) == ARGV[1] then
            return redis.call('pexpire', KEYS[1], ARGV[2])
        end
        return 0
    "#;

    /// Lua script to release the lock.
    const RELEASE_SCRIPT_LUA: &'static str = r#"
        if redis.call('get', KEYS[1]) == ARGV[1] then
            return redis.call('del', KEYS[1])
        end
        return 0
    "#;

    /// Attempts to extend the lock on a single Redis client.
    ///
    /// Uses a Lua script to atomically verify ownership and extend TTL.
    pub async fn try_extend(&self, client: &RedisClient) -> LockResult<bool> {
        let expiry_millis = self.timeouts.expiry.as_millis() as i64;

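        // EVAL argument layout on the wire: script source, numkeys, then
        // KEYS..., then ARGV... (here: one key, followed by lock_id and TTL).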
        let args: Vec<RedisValue> = vec![
            Self::EXTEND_SCRIPT_LUA.into(),
            1_i64.into(), // numkeys
            self.key.clone().into(),
            self.lock_id.clone().into(),
            expiry_millis.into(),
        ];

        // EVAL has no typed wrapper here, so issue it as a custom command;
        // `new_static` avoids allocating for the static command name.
        let cmd = CustomCommand::new_static("EVAL", None, false);

        let result: i64 = client.custom(cmd, args).await.map_err(|e| {
            LockError::Backend(Box::new(std::io::Error::other(format!(
                "Redis custom EVAL (extend) failed: {}",
                e
            ))))
        })?;

        Ok(result == 1)
    }

    /// Attempts to release the lock on a single Redis client.
    ///
    /// Uses a Lua script to atomically verify ownership before deleting.
    pub async fn try_release(&self, client: &RedisClient) -> LockResult<()> {
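        // Same EVAL layout as `try_extend`, minus the TTL argument.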
        let args: Vec<RedisValue> = vec![
            Self::RELEASE_SCRIPT_LUA.into(),
            1_i64.into(), // numkeys
            self.key.clone().into(),
            self.lock_id.clone().into(),
        ];

        let cmd = CustomCommand::new_static("EVAL", None, false);

        let _: i64 = client.custom(cmd, args).await.map_err(|e| {
            LockError::Backend(Box::new(std::io::Error::other(format!(
                "Redis custom EVAL (release) failed: {}",
                e
            ))))
        })?;

        Ok(())
    }
}

/// A Redis-based distributed lock.
///
/// Supports single-server and multi-server (RedLock) configurations.
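///
/// # Example
///
/// A minimal usage sketch via the [`DistributedLock`] trait. Construction is
/// crate-internal (`new` is `pub(crate)`), so this assumes an already-built
/// lock obtained from this crate's public entry point:
///
/// ```no_run
/// # use std::time::Duration;
/// # use distributed_lock_core::traits::DistributedLock;
/// # async fn demo(lock: impl DistributedLock) -> distributed_lock_core::error::LockResult<()> {
/// // Wait up to five seconds for the lock.
/// let handle = lock.acquire(Some(Duration::from_secs(5))).await?;
/// // ... critical section; the handle renews the lock in the background ...
/// drop(handle); // dropping the handle stops background renewal
/// # Ok(())
/// # }
/// ```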
pub struct RedisDistributedLock {
    /// Lock state.
    state: RedisLockState,
    /// Redis clients (one for single-server, multiple for RedLock).
    clients: Vec<RedisClient>,
    /// Extension cadence for background renewal.
    extension_cadence: Duration,
}

impl RedisDistributedLock {
    /// Creates a new Redis distributed lock.
    pub(crate) fn new(
        name: String,
        clients: Vec<RedisClient>,
        expiry: Duration,
        min_validity: Duration,
        extension_cadence: Duration,
    ) -> Self {
        let key = format!("distributed-lock:{}", name);
        let timeouts = RedLockTimeouts::new(expiry, min_validity);

        Self {
            state: RedisLockState::new(key, timeouts),
            clients,
            extension_cadence,
        }
    }

    /// Gets the lock name.
    pub fn name(&self) -> &str {
        // Extract name from key (remove "distributed-lock:" prefix)
        self.state
            .key
            .strip_prefix("distributed-lock:")
            .unwrap_or(&self.state.key)
    }
}

impl DistributedLock for RedisDistributedLock {
    type Handle = crate::handle::RedisLockHandle;

    fn name(&self) -> &str {
        self.name()
    }

    #[instrument(skip(self), fields(lock.name = %self.name(), lock.key = %self.state.key, timeout = ?timeout, backend = "redis", servers = self.clients.len()))]
    async fn acquire(&self, timeout: Option<Duration>) -> LockResult<Self::Handle> {
        use tokio::sync::watch;

        let start = std::time::Instant::now();
        Span::current().record("operation", "acquire");

        // Create cancellation token
        let (cancel_sender, cancel_receiver) = watch::channel(false);

        // If timeout is provided, spawn a task to signal cancellation after timeout
        if let Some(timeout_duration) = timeout {
            let cancel_sender_clone = cancel_sender.clone();
            tokio::spawn(async move {
                tokio::time::sleep(timeout_duration).await;
                let _ = cancel_sender_clone.send(true);
            });
        }
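
        // Note: the timer task is detached; if acquisition completes first,
        // its eventual `send(true)` is ignored once the receiver is dropped.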

        // Acquire using RedLock algorithm
        let state = self.state.clone();
        let clients = self.clients.clone();
        let timeouts = self.state.timeouts.clone();
        let acquire_result = acquire_redlock(
            move |client| {
                let state = state.clone();
                let client = client.clone();
                async move { state.try_acquire(&client).await }
            },
            &clients,
            &timeouts,
            &cancel_receiver,
        )
        .await?;

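        // RedLock quorum check: `is_successful` decides whether enough of the
        // `clients.len()` servers granted the lock (a majority, per RedLock).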
        let acquire_result = match acquire_result {
            Some(result) if result.is_successful(clients.len()) => {
                let elapsed = start.elapsed();
                Span::current().record("acquired", true);
                Span::current().record("elapsed_ms", elapsed.as_millis() as u64);
                Span::current().record(
                    "servers_acquired",
                    result.acquire_results.iter().filter(|&&b| b).count(),
                );
                result
            }
            _ => {
                Span::current().record("acquired", false);
                Span::current().record("error", "timeout");
                return Err(LockError::Timeout(
                    timeout.unwrap_or(Duration::from_secs(0)),
                ));
            }
        };

        // Create handle with background extension
        Ok(crate::handle::RedisLockHandle::new(
            self.state.clone(),
            acquire_result.acquire_results,
            clients,
            self.extension_cadence,
            self.state.timeouts.expiry,
        ))
    }

    #[instrument(skip(self), fields(lock.name = %self.name(), lock.key = %self.state.key, backend = "redis", servers = self.clients.len()))]
    async fn try_acquire(&self) -> LockResult<Option<Self::Handle>> {
        use tokio::sync::watch;

        Span::current().record("operation", "try_acquire");

        // Create cancellation token (not used for try_acquire, but required by the API)
        let (_cancel_sender, cancel_receiver) = watch::channel(false);

        // Acquire using RedLock algorithm
        let state = self.state.clone();
        let clients = self.clients.clone();
        let timeouts = self.state.timeouts.clone();
        let acquire_result = acquire_redlock(
            move |client| {
                let state = state.clone();
                let client = client.clone();
                async move { state.try_acquire(&client).await }
            },
            &clients,
            &timeouts,
            &cancel_receiver,
        )
        .await?;

        match acquire_result {
            Some(result) if result.is_successful(clients.len()) => {
                Span::current().record("acquired", true);
                Span::current().record(
                    "servers_acquired",
                    result.acquire_results.iter().filter(|&&b| b).count(),
                );
                Ok(Some(crate::handle::RedisLockHandle::new(
                    self.state.clone(),
                    result.acquire_results,
                    clients,
                    self.extension_cadence,
                    self.state.timeouts.expiry,
                )))
            }
            _ => {
                Span::current().record("acquired", false);
                Span::current().record("reason", "lock_held");
                Ok(None)
            }
        }
    }
}