distributed_lock_redis/
semaphore.rs

//! Redis distributed semaphore implementation.

use std::time::Duration;

use distributed_lock_core::error::{LockError, LockResult};
use distributed_lock_core::timeout::TimeoutValue;
use distributed_lock_core::traits::{DistributedSemaphore, LockHandle};
use fred::prelude::*;
use fred::types::CustomCommand;
use rand::Rng;
use tokio::sync::watch;

/// A Redis-based distributed semaphore.
///
/// Uses Redis sorted sets to track semaphore tickets. Each ticket has an expiry time,
/// and expired tickets are automatically purged before acquisition attempts.
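///
/// # Example
///
/// A minimal usage sketch through the `DistributedSemaphore` trait. Construction of a
/// `RedisDistributedSemaphore` (e.g. via a provider or builder) happens outside this
/// module, so the sketch is generic over the trait; `do_work` is a hypothetical
/// placeholder for the caller's own workload.
///
/// ```ignore
/// use std::time::Duration;
///
/// use distributed_lock_core::error::LockResult;
/// use distributed_lock_core::traits::{DistributedSemaphore, LockHandle};
///
/// async fn run_limited<S: DistributedSemaphore>(semaphore: &S) -> LockResult<()> {
///     // Wait up to 5 seconds for one of the `max_count` tickets.
///     let handle = semaphore.acquire(Some(Duration::from_secs(5))).await?;
///
///     do_work().await; // hypothetical workload
///
///     // Release the ticket explicitly; otherwise it only lapses once it expires.
///     handle.release().await
/// }
/// ```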
pub struct RedisDistributedSemaphore {
    /// Redis key for the semaphore.
    key: String,
    /// Original semaphore name.
    name: String,
    /// Maximum concurrent holders.
    max_count: u32,
    /// Redis client.
    client: RedisClient,
    /// Lock expiry time.
    expiry: Duration,
    /// Extension cadence for held locks.
    extension_cadence: Duration,
}

impl RedisDistributedSemaphore {
    /// Lua script for semaphore acquisition.
    /// 1. Gets Redis time
    /// 2. Removes expired tickets (zremrangebyscore)
    /// 3. Checks if count < max_count
    /// 4. Adds new ticket with expiry score if allowed
    /// 5. Extends set TTL
    const ACQUIRE_SCRIPT: &'static str = r#"
        redis.replicate_commands()
        local nowResult = redis.call('time')
        local nowMillis = (tonumber(nowResult[1]) * 1000.0) + (tonumber(nowResult[2]) / 1000.0)

        redis.call('zremrangebyscore', KEYS[1], '-inf', nowMillis)

        if redis.call('zcard', KEYS[1]) < tonumber(ARGV[1]) then
            redis.call('zadd', KEYS[1], nowMillis + tonumber(ARGV[2]), ARGV[3])

            -- Extend key TTL (set to 3x ticket expiry to be safe)
            local keyTtl = redis.call('pttl', KEYS[1])
            if keyTtl < tonumber(ARGV[4]) then
                redis.call('pexpire', KEYS[1], ARGV[4])
            end
            return 1
        end
        return 0
    "#;
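
    // For reference, the acquire script can be exercised by hand with EVAL.
    // Hypothetical values: semaphore name "jobs", max_count 5, a 30s ticket expiry,
    // a random lock ID, and a 90s set TTL:
    //
    //   EVAL "<ACQUIRE_SCRIPT>" 1 distributed-lock:semaphore:jobs 5 30000 00000000deadbeef 90000
    //
    // A reply of 1 means a ticket was granted; 0 means the semaphore is full.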

    pub(crate) fn new(
        name: String,
        max_count: u32,
        client: RedisClient,
        expiry: Duration,
        extension_cadence: Duration,
    ) -> Self {
        // Prefix the key to avoid collisions
        let key = format!("distributed-lock:semaphore:{}", name);
        Self {
            key,
            name,
            max_count,
            client,
            expiry,
            extension_cadence,
        }
    }

    /// Generates a random 64-bit lock ID (16 hex characters) for this acquisition.
    fn generate_lock_id() -> String {
        let mut rng = rand::thread_rng();
        // `gen` is a reserved keyword in the Rust 2024 edition, so the raw identifier
        // `r#gen` is used to call rand's `gen` method.
        format!("{:016x}", rng.r#gen::<u64>())
    }

    /// Attempts to acquire a semaphore ticket without waiting.
    async fn try_acquire_internal(&self) -> LockResult<Option<RedisSemaphoreHandle>> {
        let lock_id = Self::generate_lock_id();
        let expiry_millis = self.expiry.as_millis() as u64;

        // TTL for the whole set (3x expiry)
        let set_expiry_millis = expiry_millis * 3;

        let args: Vec<RedisValue> = vec![
            Self::ACQUIRE_SCRIPT.into(),
            1_i64.into(),                      // numkeys
            self.key.clone().into(),           // KEYS[1]
            (self.max_count as i64).into(),    // ARGV[1]
            (expiry_millis as i64).into(),     // ARGV[2]
            lock_id.clone().into(),            // ARGV[3]
            (set_expiry_millis as i64).into(), // ARGV[4]
        ];

        let cmd = CustomCommand::new_static("EVAL", None, false);
        let result: i64 = self.client.custom(cmd, args).await.map_err(|e| {
            LockError::Backend(Box::new(std::io::Error::other(format!(
                "Redis custom EVAL (acquire semaphore) failed: {}",
                e
            ))))
        })?;

        if result == 1 {
            // Successfully acquired
            let (sender, receiver) = watch::channel(false);
            Ok(Some(RedisSemaphoreHandle::new(
                self.key.clone(),
                lock_id,
                self.client.clone(),
                self.expiry,
                self.extension_cadence,
                sender,
                receiver,
            )))
        } else {
            Ok(None)
        }
    }
}

impl DistributedSemaphore for RedisDistributedSemaphore {
    type Handle = RedisSemaphoreHandle;

    fn name(&self) -> &str {
        &self.name
    }

    fn max_count(&self) -> u32 {
        self.max_count
    }

    async fn acquire(&self, timeout: Option<Duration>) -> LockResult<Self::Handle> {
        let timeout_value = TimeoutValue::from(timeout);
        let start = std::time::Instant::now();

        // Busy-wait with exponential backoff.
        // TODO: could optimize this further, but atomicity is the primary goal for now.
        let mut sleep_duration = Duration::from_millis(10);
        const MAX_SLEEP: Duration = Duration::from_millis(200);

        loop {
            match self.try_acquire_internal().await {
                Ok(Some(handle)) => return Ok(handle),
                Ok(None) => {
                    // Check timeout
                    if !timeout_value.is_infinite()
                        && start.elapsed() >= timeout_value.as_duration().unwrap()
                    {
                        return Err(LockError::Timeout(timeout_value.as_duration().unwrap()));
                    }

                    // Sleep before retry
                    tokio::time::sleep(sleep_duration).await;
                    sleep_duration = (sleep_duration * 2).min(MAX_SLEEP);
                }
                Err(e) => return Err(e),
            }
        }
    }

    async fn try_acquire(&self) -> LockResult<Option<Self::Handle>> {
        self.try_acquire_internal().await
    }
}

/// Handle for a held semaphore ticket.
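///
/// A background task periodically renews the ticket; if renewal fails (the ticket
/// expired or Redis is unreachable), the watch channel returned by `lost_token()`
/// flips to `true`.
///
/// # Example
///
/// A sketch of watching for lock loss while holding a ticket, written against the
/// `LockHandle` trait; `do_work` is a hypothetical placeholder for the caller's
/// workload.
///
/// ```ignore
/// use distributed_lock_core::error::LockResult;
/// use distributed_lock_core::traits::LockHandle;
///
/// async fn guarded(handle: impl LockHandle) -> LockResult<()> {
///     // Clone the receiver so awaiting changes does not borrow `handle`.
///     let mut lost = handle.lost_token().clone();
///     tokio::select! {
///         _ = do_work() => {
///             // Finished while the ticket was still held; release it explicitly.
///             handle.release().await
///         }
///         _ = lost.changed() => {
///             // The extension task could not renew the ticket; treat the lock as lost.
///             Ok(())
///         }
///     }
/// }
/// ```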
pub struct RedisSemaphoreHandle {
    /// Redis key for the semaphore.
    key: String,
    /// Unique lock ID for this ticket.
    lock_id: String,
    /// Redis client.
    client: RedisClient,
    /// Lock expiry time.
    #[allow(dead_code)]
    expiry: Duration,
    /// Extension cadence.
    #[allow(dead_code)]
    extension_cadence: Duration,
    /// Watch channel for lock lost detection.
    lost_receiver: watch::Receiver<bool>,
    /// Background task handle for lock extension.
    _extension_task: tokio::task::JoinHandle<()>,
}

impl RedisSemaphoreHandle {
    /// Lua script for semaphore extension.
    /// 1. Gets Redis time
    /// 2. Updates score (expiry) for our lock ID only if it exists (XX)
    /// 3. Extends set TTL
    const EXTEND_SCRIPT: &'static str = r#"
        redis.replicate_commands()
        local nowResult = redis.call('time')
        local nowMillis = (tonumber(nowResult[1]) * 1000.0) + (tonumber(nowResult[2]) / 1000.0)

        local result = redis.call('zadd', KEYS[1], 'XX', 'CH', nowMillis + tonumber(ARGV[1]), ARGV[2])

        -- Extend key TTL
        local keyTtl = redis.call('pttl', KEYS[1])
        if keyTtl < tonumber(ARGV[3]) then
            redis.call('pexpire', KEYS[1], ARGV[3])
        end
        return result
    "#;
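
    // For reference (hypothetical values matching the acquire example above):
    //
    //   EVAL "<EXTEND_SCRIPT>" 1 distributed-lock:semaphore:jobs 30000 00000000deadbeef 90000
    //
    // A reply of 1 means our ticket's expiry was pushed forward; 0 means the ticket
    // no longer exists, i.e. the lock has been lost.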

    /// Lua script for semaphore release.
    const RELEASE_SCRIPT: &'static str = r#"
        return redis.call('zrem', KEYS[1], ARGV[1])
    "#;

    pub(crate) fn new(
        key: String,
        lock_id: String,
        client: RedisClient,
        expiry: Duration,
        extension_cadence: Duration,
        lost_sender: watch::Sender<bool>,
        lost_receiver: watch::Receiver<bool>,
    ) -> Self {
        let extension_key = key.clone();
        let extension_lock_id = lock_id.clone();
        let extension_client = client.clone();
        let extension_expiry = expiry;
        let extension_lost_sender = lost_sender.clone();

        // Spawn background task to extend the lock
        let extension_task = tokio::spawn(async move {
            let mut interval = tokio::time::interval(extension_cadence);
            // TTL for the whole set (3x expiry)
            let set_expiry_millis = extension_expiry.as_millis() * 3;

            loop {
                interval.tick().await;

                // Check if we should stop (sender closed)
                if extension_lost_sender.is_closed() {
                    break;
                }

                let expiry_millis = extension_expiry.as_millis() as u64;

                let args: Vec<RedisValue> = vec![
                    Self::EXTEND_SCRIPT.into(),
                    1_i64.into(), // numkeys
                    extension_key.clone().into(),
                    (expiry_millis as i64).into(),
                    extension_lock_id.clone().into(),
                    (set_expiry_millis as i64).into(),
                ];

                let cmd = CustomCommand::new_static("EVAL", None, false);
                let result_op: Result<i64, _> = extension_client.custom(cmd, args).await;

                match result_op {
                    Ok(changed_count) => {
                        // zadd with CH returns number of changed elements.
                        // If 0, it means the element wasn't found (expired/lost).
                        if changed_count == 0 {
                            let _ = extension_lost_sender.send(true);
                            break;
                        }
                    }
                    Err(_) => {
                        // Connection error - signal lock lost
                        let _ = extension_lost_sender.send(true);
                        break;
                    }
                }
            }
        });

        Self {
            key,
            lock_id,
            client,
            expiry,
            extension_cadence,
            lost_receiver,
            _extension_task: extension_task,
        }
    }
}

impl LockHandle for RedisSemaphoreHandle {
    fn lost_token(&self) -> &watch::Receiver<bool> {
        &self.lost_receiver
    }

    async fn release(self) -> LockResult<()> {
        // Abort the extension task
        self._extension_task.abort();

        let args: Vec<RedisValue> = vec![
            Self::RELEASE_SCRIPT.into(),
            1_i64.into(), // numkeys
            self.key.clone().into(),
            self.lock_id.clone().into(),
        ];

        let cmd = CustomCommand::new_static("EVAL", None, false);

        // Remove our ticket from the sorted set
        let _: i64 = self.client.custom(cmd, args).await.map_err(|e| {
            LockError::Backend(Box::new(std::io::Error::other(format!(
                "failed to release semaphore ticket: {}",
                e
            ))))
        })?;

        Ok(())
    }
}

impl Drop for RedisSemaphoreHandle {
    fn drop(&mut self) {
        // Abort the extension task.
        self._extension_task.abort();
        // Note: we cannot perform an async release in Drop, so the ticket simply expires
        // on its own. For prompt cleanup, callers should call release() explicitly.
    }
}