distributed_lock_redis/
lock.rs

//! Redis distributed lock implementation.
2
3use std::time::Duration;
4
5use distributed_lock_core::error::{LockError, LockResult};
6use distributed_lock_core::traits::DistributedLock;
7use fred::prelude::*;
8use tracing::{instrument, Span};
9
10use crate::redlock::{acquire::acquire_redlock, helper::RedLockHelper, timeouts::RedLockTimeouts};
11
/// Internal state for a Redis lock.
///
/// Cloneable: clones carry the same key and lock ID by value, so every clone
/// identifies the same logical lock acquisition.
#[derive(Debug, Clone)]
pub struct RedisLockState {
    /// Redis key for the lock.
    pub key: String,
    /// Unique lock ID for this acquisition. Stored as the key's value so
    /// extend/release can verify ownership before acting.
    pub lock_id: String,
    /// Timeout configuration (expiry and minimum validity).
    pub timeouts: RedLockTimeouts,
}
22
23impl RedisLockState {
24    /// Creates a new lock state.
25    pub fn new(key: String, timeouts: RedLockTimeouts) -> Self {
26        Self {
27            key,
28            lock_id: RedLockHelper::create_lock_id(),
29            timeouts,
30        }
31    }
32
33    /// Attempts to acquire the lock on a single Redis client.
34    pub async fn try_acquire(&self, client: &RedisClient) -> LockResult<bool> {
35        let expiry_millis = self.timeouts.expiry.as_millis() as i64;
36
37        // Use SET NX PX to atomically set the key if it doesn't exist
38        // Note: Using PX (milliseconds) instead of EX (seconds)
39        let result: Option<String> = client
40            .set(
41                &self.key,
42                &self.lock_id,
43                Some(Expiration::PX(expiry_millis)),
44                Some(SetOptions::NX),
45                false,
46            )
47            .await
48            .map_err(|e| {
49                LockError::Backend(Box::new(std::io::Error::other(format!(
50                    "Redis SET NX failed: {}",
51                    e
52                ))))
53            })?;
54
55        // SET NX returns Some(value) if key was set, None if key already exists
56        Ok(result.is_some())
57    }
58
59    /// Attempts to extend the lock on a single Redis client.
60    ///
61    /// Uses GET + PEXPIRE to verify ownership and extend TTL.
62    /// TODO: Use Lua script for atomicity once fred API is clarified.
63    pub async fn try_extend(&self, client: &RedisClient) -> LockResult<bool> {
64        let expiry_millis = self.timeouts.expiry.as_millis() as i64;
65
66        // First check if the key exists and value matches our lock_id
67        let current_value: Option<String> = client.get(&self.key).await.map_err(|e| {
68            LockError::Backend(Box::new(std::io::Error::other(format!(
69                "Redis GET failed: {}",
70                e
71            ))))
72        })?;
73
74        match current_value {
75            Some(value) if value == self.lock_id => {
76                // Value matches - extend TTL
77                let _: bool = client
78                    .pexpire(&self.key, expiry_millis, None)
79                    .await
80                    .map_err(|e| {
81                        LockError::Backend(Box::new(std::io::Error::other(format!(
82                            "Redis PEXPIRE failed: {}",
83                            e
84                        ))))
85                    })?;
86                Ok(true)
87            }
88            _ => Ok(false), // Key doesn't exist or value doesn't match
89        }
90    }
91
92    /// Attempts to release the lock on a single Redis client.
93    ///
94    /// Uses GET + DEL to verify ownership before deleting.
95    /// TODO: Use Lua script for atomicity once fred API is clarified.
96    pub async fn try_release(&self, client: &RedisClient) -> LockResult<()> {
97        // First check if the key exists and value matches our lock_id
98        let current_value: Option<String> = client.get(&self.key).await.map_err(|e| {
99            LockError::Backend(Box::new(std::io::Error::other(format!(
100                "Redis GET failed: {}",
101                e
102            ))))
103        })?;
104
105        match current_value {
106            Some(value) if value == self.lock_id => {
107                // Value matches - delete the key
108                let _: i64 = client.del(&self.key).await.map_err(|e| {
109                    LockError::Backend(Box::new(std::io::Error::other(format!(
110                        "Redis DEL failed: {}",
111                        e
112                    ))))
113                })?;
114                Ok(())
115            }
116            _ => {
117                // Key doesn't exist or value doesn't match - already released or not ours
118                Ok(())
119            }
120        }
121    }
122}
123
/// A Redis-based distributed lock.
///
/// Supports single-server and multi-server (RedLock) configurations.
pub struct RedisDistributedLock {
    /// Lock state (key, lock ID, timeouts); cloned into handles on acquire.
    state: RedisLockState,
    /// Redis clients (one for single-server, multiple for RedLock).
    clients: Vec<RedisClient>,
    /// Extension cadence for background renewal; passed to acquired handles.
    extension_cadence: Duration,
}
135
136impl RedisDistributedLock {
137    /// Creates a new Redis distributed lock.
138    pub(crate) fn new(
139        name: String,
140        clients: Vec<RedisClient>,
141        expiry: Duration,
142        min_validity: Duration,
143        extension_cadence: Duration,
144    ) -> Self {
145        let key = format!("distributed-lock:{}", name);
146        let timeouts = RedLockTimeouts::new(expiry, min_validity);
147
148        Self {
149            state: RedisLockState::new(key, timeouts),
150            clients,
151            extension_cadence,
152        }
153    }
154
155    /// Gets the lock name.
156    pub fn name(&self) -> &str {
157        // Extract name from key (remove "distributed-lock:" prefix)
158        self.state
159            .key
160            .strip_prefix("distributed-lock:")
161            .unwrap_or(&self.state.key)
162    }
163}
164
impl DistributedLock for RedisDistributedLock {
    type Handle = crate::handle::RedisLockHandle;

    fn name(&self) -> &str {
        // Delegate to the inherent method (strips the "distributed-lock:" prefix).
        self.name()
    }

    /// Acquires the lock, waiting up to `timeout` (or indefinitely when `None`,
    /// subject to whatever limits `acquire_redlock` itself imposes).
    ///
    /// Returns a handle that drives background extension, or
    /// `LockError::Timeout` when no quorum was reached before cancellation.
    #[instrument(skip(self), fields(lock.name = %self.name(), lock.key = %self.state.key, timeout = ?timeout, backend = "redis", servers = self.clients.len()))]
    async fn acquire(&self, timeout: Option<Duration>) -> LockResult<Self::Handle> {
        use tokio::sync::watch;

        let start = std::time::Instant::now();
        Span::current().record("operation", "acquire");

        // Cancellation token: acquire_redlock observes the receiver and stops
        // retrying once `true` is sent.
        let (cancel_sender, cancel_receiver) = watch::channel(false);

        // If a timeout is provided, signal cancellation after it elapses.
        // NOTE(review): this task is detached and keeps running (sleeping)
        // even after a successful acquisition; harmless, but it holds the
        // sender clone until the sleep completes — confirm this is intended.
        if let Some(timeout_duration) = timeout {
            let cancel_sender_clone = cancel_sender.clone();
            tokio::spawn(async move {
                tokio::time::sleep(timeout_duration).await;
                let _ = cancel_sender_clone.send(true);
            });
        }

        // Acquire using the RedLock algorithm: the closure is invoked once per
        // client and performs the per-server SET NX attempt.
        let state = self.state.clone();
        let clients = self.clients.clone();
        let timeouts = self.state.timeouts.clone();
        let acquire_result = acquire_redlock(
            move |client| {
                // Clone per invocation so the returned future is 'static.
                let state = state.clone();
                let client = client.clone();
                async move { state.try_acquire(&client).await }
            },
            &clients,
            &timeouts,
            &cancel_receiver,
        )
        .await?;

        let acquire_result = match acquire_result {
            // Quorum reached (per is_successful for the given server count):
            // record tracing fields and keep the per-server results.
            Some(result) if result.is_successful(clients.len()) => {
                let elapsed = start.elapsed();
                Span::current().record("acquired", true);
                Span::current().record("elapsed_ms", elapsed.as_millis() as u64);
                Span::current().record(
                    "servers_acquired",
                    result.acquire_results.iter().filter(|&&b| b).count(),
                );
                result
            }
            // No result, or sub-quorum acquisition.
            // NOTE(review): servers that DID grant the lock in a failed
            // attempt are not explicitly released here — confirm
            // acquire_redlock rolls them back; otherwise those keys stay
            // held until their TTL expires.
            // NOTE(review): with `timeout == None` this reports
            // `Timeout(0s)`, which can mislabel a non-timeout quorum
            // failure — consider a dedicated error variant.
            _ => {
                Span::current().record("acquired", false);
                Span::current().record("error", "timeout");
                return Err(LockError::Timeout(
                    timeout.unwrap_or(Duration::from_secs(0)),
                ));
            }
        };

        // Create handle; it owns the state/clients and performs background
        // extension on the configured cadence.
        Ok(crate::handle::RedisLockHandle::new(
            self.state.clone(),
            acquire_result.acquire_results,
            clients,
            self.extension_cadence,
            self.state.timeouts.expiry,
        ))
    }

    /// Single-shot acquisition attempt: returns `Ok(None)` instead of an error
    /// when the lock is currently held.
    #[instrument(skip(self), fields(lock.name = %self.name(), lock.key = %self.state.key, backend = "redis", servers = self.clients.len()))]
    async fn try_acquire(&self) -> LockResult<Option<Self::Handle>> {
        use tokio::sync::watch;

        Span::current().record("operation", "try_acquire");

        // Cancellation token is never triggered here (sender kept alive but
        // unused); it exists only to satisfy acquire_redlock's signature.
        let (_cancel_sender, cancel_receiver) = watch::channel(false);

        // Acquire using the RedLock algorithm (same closure shape as `acquire`).
        let state = self.state.clone();
        let clients = self.clients.clone();
        let timeouts = self.state.timeouts.clone();
        let acquire_result = acquire_redlock(
            move |client| {
                let state = state.clone();
                let client = client.clone();
                async move { state.try_acquire(&client).await }
            },
            &clients,
            &timeouts,
            &cancel_receiver,
        )
        .await?;

        let result = match acquire_result {
            // Quorum reached: wrap the per-server results in a handle.
            Some(result) if result.is_successful(clients.len()) => {
                Span::current().record("acquired", true);
                Span::current().record(
                    "servers_acquired",
                    result.acquire_results.iter().filter(|&&b| b).count(),
                );
                Ok(Some(crate::handle::RedisLockHandle::new(
                    self.state.clone(),
                    result.acquire_results,
                    clients,
                    self.extension_cadence,
                    self.state.timeouts.expiry,
                )))
            }
            // Lock held elsewhere (or sub-quorum) — not an error for try_acquire.
            // NOTE(review): as in `acquire`, partially-acquired servers are not
            // explicitly released here — verify acquire_redlock handles rollback.
            _ => {
                Span::current().record("acquired", false);
                Span::current().record("reason", "lock_held");
                Ok(None)
            }
        };
        result
    }
}
285}