// distributed_lock_redis/semaphore.rs

//! Redis distributed semaphore implementation.

use std::time::{Duration, SystemTime, UNIX_EPOCH};

use distributed_lock_core::error::{LockError, LockResult};
use distributed_lock_core::timeout::TimeoutValue;
use distributed_lock_core::traits::{DistributedSemaphore, LockHandle};
use fred::prelude::*;
use rand::Rng;
use tokio::sync::watch;

12/// A Redis-based distributed semaphore.
13///
14/// Uses Redis sorted sets to track semaphore tickets. Each ticket has an expiry time,
15/// and expired tickets are automatically purged before acquisition attempts.
16pub struct RedisDistributedSemaphore {
17    /// Redis key for the semaphore.
18    key: String,
19    /// Original semaphore name.
20    name: String,
21    /// Maximum concurrent holders.
22    max_count: u32,
23    /// Redis client.
24    client: RedisClient,
25    /// Lock expiry time.
26    expiry: Duration,
27    /// Extension cadence for held locks.
28    extension_cadence: Duration,
29}
31impl RedisDistributedSemaphore {
32    pub(crate) fn new(
33        name: String,
34        max_count: u32,
35        client: RedisClient,
36        expiry: Duration,
37        extension_cadence: Duration,
38    ) -> Self {
39        // Prefix the key to avoid collisions
40        let key = format!("distributed-lock:semaphore:{}", name);
41        Self {
42            key,
43            name,
44            max_count,
45            client,
46            expiry,
47            extension_cadence,
48        }
49    }
50
51    /// Generates a unique lock ID for this acquisition.
52    fn generate_lock_id() -> String {
53        let mut rng = rand::thread_rng();
54        format!("{:016x}", rng.gen::<u64>())
55    }
56
57    /// Gets current time in milliseconds since Unix epoch.
58    fn now_millis() -> u64 {
59        SystemTime::now()
60            .duration_since(UNIX_EPOCH)
61            .unwrap()
62            .as_millis() as u64
63    }
64
65    /// Attempts to acquire a semaphore ticket without waiting.
66    async fn try_acquire_internal(&self) -> LockResult<Option<RedisSemaphoreHandle>> {
67        let lock_id = Self::generate_lock_id();
68        let now_millis = Self::now_millis();
69        let expiry_millis = self.expiry.as_millis() as u64;
70        let expiry_time = now_millis + expiry_millis;
71
72        // Use Redis commands (non-atomic but simpler for now)
73        // TODO: Use Lua script for atomicity when fred API is clarified
74
75        // Remove expired entries
76        let _: u32 = self
77            .client
78            .zremrangebyscore(&self.key, 0.0, now_millis as f64)
79            .await
80            .map_err(|e| {
81                LockError::Backend(Box::new(std::io::Error::other(format!(
82                    "Redis error: {}",
83                    e
84                ))))
85            })?;
86
87        // Check current count
88        let count: u32 = self.client.zcard(&self.key).await.map_err(|e| {
89            LockError::Backend(Box::new(std::io::Error::other(format!(
90                "Redis error: {}",
91                e
92            ))))
93        })?;
94
95        if count >= self.max_count {
96            return Ok(None);
97        }
98
99        // Add our ticket
100        let _: () = self
101            .client
102            .zadd(
103                &self.key,
104                None,
105                None,
106                false,
107                false,
108                (expiry_time as f64, lock_id.clone()),
109            )
110            .await
111            .map_err(|e| {
112                LockError::Backend(Box::new(std::io::Error::other(format!(
113                    "Redis error: {}",
114                    e
115                ))))
116            })?;
117
118        // Set TTL on the key
119        let set_expiry = expiry_millis * 2;
120        let _: bool = self
121            .client
122            .pexpire(&self.key, set_expiry as i64, None)
123            .await
124            .map_err(|e| {
125                LockError::Backend(Box::new(std::io::Error::other(format!(
126                    "Redis error: {}",
127                    e
128                ))))
129            })?;
130
131        // Successfully acquired
132        let (sender, receiver) = watch::channel(false);
133        Ok(Some(RedisSemaphoreHandle::new(
134            self.key.clone(),
135            lock_id,
136            self.client.clone(),
137            self.expiry,
138            self.extension_cadence,
139            sender,
140            receiver,
141        )))
142    }
143}
145impl DistributedSemaphore for RedisDistributedSemaphore {
146    type Handle = RedisSemaphoreHandle;
147
148    fn name(&self) -> &str {
149        &self.name
150    }
151
152    fn max_count(&self) -> u32 {
153        self.max_count
154    }
155
156    async fn acquire(&self, timeout: Option<Duration>) -> LockResult<Self::Handle> {
157        let timeout_value = TimeoutValue::from(timeout);
158        let start = std::time::Instant::now();
159
160        // Busy-wait with exponential backoff
161        let mut sleep_duration = Duration::from_millis(10);
162        const MAX_SLEEP: Duration = Duration::from_millis(200);
163
164        loop {
165            match self.try_acquire_internal().await {
166                Ok(Some(handle)) => return Ok(handle),
167                Ok(None) => {
168                    // Check timeout
169                    if !timeout_value.is_infinite()
170                        && start.elapsed() >= timeout_value.as_duration().unwrap()
171                    {
172                        return Err(LockError::Timeout(timeout_value.as_duration().unwrap()));
173                    }
174
175                    // Sleep before retry
176                    tokio::time::sleep(sleep_duration).await;
177                    sleep_duration = (sleep_duration * 2).min(MAX_SLEEP);
178                }
179                Err(e) => return Err(e),
180            }
181        }
182    }
183
184    async fn try_acquire(&self) -> LockResult<Option<Self::Handle>> {
185        self.try_acquire_internal().await
186    }
187}
189/// Handle for a held semaphore ticket.
190pub struct RedisSemaphoreHandle {
191    /// Redis key for the semaphore.
192    key: String,
193    /// Unique lock ID for this ticket.
194    lock_id: String,
195    /// Redis client.
196    client: RedisClient,
197    /// Lock expiry time.
198    #[allow(dead_code)]
199    expiry: Duration,
200    /// Extension cadence.
201    #[allow(dead_code)]
202    extension_cadence: Duration,
203    /// Watch channel for lock lost detection.
204    lost_receiver: watch::Receiver<bool>,
205    /// Background task handle for lock extension.
206    _extension_task: tokio::task::JoinHandle<()>,
207}
209impl RedisSemaphoreHandle {
210    pub(crate) fn new(
211        key: String,
212        lock_id: String,
213        client: RedisClient,
214        expiry: Duration,
215        extension_cadence: Duration,
216        lost_sender: watch::Sender<bool>,
217        lost_receiver: watch::Receiver<bool>,
218    ) -> Self {
219        let extension_key = key.clone();
220        let extension_lock_id = lock_id.clone();
221        let extension_client = client.clone();
222        let extension_expiry = expiry;
223        let extension_lost_sender = lost_sender.clone();
224
225        // Spawn background task to extend the lock
226        let extension_task = tokio::spawn(async move {
227            let mut interval = tokio::time::interval(extension_cadence);
228            loop {
229                interval.tick().await;
230
231                // Check if we should stop (sender closed)
232                if extension_lost_sender.is_closed() {
233                    break;
234                }
235
236                // Extend the lock
237                let now_millis = RedisDistributedSemaphore::now_millis();
238                let expiry_millis = extension_expiry.as_millis() as u64;
239                let expiry_time = now_millis + expiry_millis;
240
241                // Remove expired entries
242                let _: u32 = match extension_client
243                    .zremrangebyscore(&extension_key, 0.0, now_millis as f64)
244                    .await
245                {
246                    Ok(count) => count,
247                    Err(_) => {
248                        // Connection error - signal lock lost
249                        let _ = extension_lost_sender.send(true);
250                        break;
251                    }
252                };
253
254                // Update our ticket expiry
255                // Note: We use zadd with the same score to update expiry time
256                let result: u32 = match extension_client
257                    .zadd(
258                        &extension_key,
259                        None,
260                        None,
261                        false,
262                        false,
263                        (expiry_time as f64, extension_lock_id.clone()),
264                    )
265                    .await
266                {
267                    Ok(count) => count,
268                    Err(_) => {
269                        // Connection error - signal lock lost
270                        let _ = extension_lost_sender.send(true);
271                        break;
272                    }
273                };
274
275                // Renew set TTL if needed
276                let set_expiry = expiry_millis * 2;
277                let _: bool = match extension_client
278                    .pexpire(&extension_key, set_expiry as i64, None)
279                    .await
280                {
281                    Ok(result) => result,
282                    Err(_) => {
283                        // Connection error - signal lock lost
284                        let _ = extension_lost_sender.send(true);
285                        break;
286                    }
287                };
288
289                // If update failed (result == 0), lock was removed - signal lost
290                // Note: zadd returns the number of elements added, so 0 means it wasn't found
291                if result == 0 {
292                    let _ = extension_lost_sender.send(true);
293                    break;
294                }
295            }
296        });
297
298        Self {
299            key,
300            lock_id,
301            client,
302            expiry,
303            extension_cadence,
304            lost_receiver,
305            _extension_task: extension_task,
306        }
307    }
308}
310impl LockHandle for RedisSemaphoreHandle {
311    fn lost_token(&self) -> &watch::Receiver<bool> {
312        &self.lost_receiver
313    }
314
315    async fn release(self) -> LockResult<()> {
316        // Abort the extension task
317        self._extension_task.abort();
318
319        // Remove our ticket from the sorted set
320        let _: () = self
321            .client
322            .zrem(&self.key, &self.lock_id)
323            .await
324            .map_err(|e| {
325                LockError::Backend(Box::new(std::io::Error::other(format!(
326                    "failed to release semaphore ticket: {}",
327                    e
328                ))))
329            })?;
330
331        Ok(())
332    }
333}
335impl Drop for RedisSemaphoreHandle {
336    fn drop(&mut self) {
337        // Abort extension task
338        self._extension_task.abort();
339        // Note: We cannot async release in Drop, so the ticket will expire naturally
340        // For proper cleanup, users should call release() explicitly
341    }
342}