1use std::time::Duration;
4
5use distributed_lock_core::error::{LockError, LockResult};
6use distributed_lock_core::traits::DistributedLock;
7use fred::prelude::*;
8use fred::types::CustomCommand; use tracing::{Span, instrument};
10
11use crate::redlock::{acquire::acquire_redlock, helper::RedLockHelper, timeouts::RedLockTimeouts};
12
/// Per-lock state used when talking to each Redis server: the key, this holder's
/// unique lock id, and the lock's timing configuration.
#[derive(Debug, Clone)]
pub struct RedisLockState {
    /// Redis key the lock value is stored under.
    pub key: String,
    /// Value written to `key` when acquiring (generated by `RedLockHelper::create_lock_id`).
    /// The extend/release Lua scripts compare the stored value against it, so only the
    /// holder of this id can extend or delete the key.
    pub lock_id: String,
    /// Timing configuration; `timeouts.expiry` is used as the key's TTL.
    pub timeouts: RedLockTimeouts,
}
23
24impl RedisLockState {
25 pub fn new(key: String, timeouts: RedLockTimeouts) -> Self {
27 Self {
28 key,
29 lock_id: RedLockHelper::create_lock_id(),
30 timeouts,
31 }
32 }
33
34 pub async fn try_acquire(&self, client: &RedisClient) -> LockResult<bool> {
36 let expiry_millis = self.timeouts.expiry.as_millis() as i64;
37
38 let result: Option<String> = client
41 .set(
42 &self.key,
43 &self.lock_id,
44 Some(Expiration::PX(expiry_millis)),
45 Some(SetOptions::NX),
46 false,
47 )
48 .await
49 .map_err(|e| {
50 LockError::Backend(Box::new(std::io::Error::other(format!(
51 "Redis SET NX failed: {}",
52 e
53 ))))
54 })?;
55
56 Ok(result.is_some())
58 }
59
60 const EXTEND_SCRIPT_LUA: &'static str = r#"
68 if redis.call('get', KEYS[1]) == ARGV[1] then
69 return redis.call('pexpire', KEYS[1], ARGV[2])
70 end
71 return 0
72 "#;
73
74 const RELEASE_SCRIPT_LUA: &'static str = r#"
76 if redis.call('get', KEYS[1]) == ARGV[1] then
77 return redis.call('del', KEYS[1])
78 end
79 return 0
80 "#;
81
82 pub async fn try_extend(&self, client: &RedisClient) -> LockResult<bool> {
86 let expiry_millis = self.timeouts.expiry.as_millis() as i64;
87
88 let args: Vec<RedisValue> = vec![
89 Self::EXTEND_SCRIPT_LUA.into(),
90 1_i64.into(), self.key.clone().into(),
92 self.lock_id.clone().into(),
93 expiry_millis.into(),
94 ];
95
96 let cmd = CustomCommand::new_static("EVAL", None, false);
98
99 let result: i64 = client.custom(cmd, args).await.map_err(|e| {
100 LockError::Backend(Box::new(std::io::Error::other(format!(
101 "Redis custom EVAL (extend) failed: {}",
102 e
103 ))))
104 })?;
105
106 Ok(result == 1)
107 }
108
109 pub async fn try_release(&self, client: &RedisClient) -> LockResult<()> {
113 let args: Vec<RedisValue> = vec![
114 Self::RELEASE_SCRIPT_LUA.into(),
115 1_i64.into(), self.key.clone().into(),
117 self.lock_id.clone().into(),
118 ];
119
120 let cmd = CustomCommand::new_static("EVAL", None, false);
121
122 let _: i64 = client.custom(cmd, args).await.map_err(|e| {
123 LockError::Backend(Box::new(std::io::Error::other(format!(
124 "Redis custom EVAL (release) failed: {}",
125 e
126 ))))
127 })?;
128
129 Ok(())
130 }
131}
132
/// A distributed lock backed by one or more Redis servers and acquired through
/// `acquire_redlock`.
pub struct RedisDistributedLock {
    /// Key, lock id and timing configuration for this lock.
    state: RedisLockState,
    /// One client per Redis server the lock is acquired against.
    clients: Vec<RedisClient>,
    /// Interval handed to each `RedisLockHandle` this lock creates.
    /// NOTE(review): presumably how often the handle extends the lock — confirm in `handle`.
    extension_cadence: Duration,
}
144
145impl RedisDistributedLock {
146 pub(crate) fn new(
148 name: String,
149 clients: Vec<RedisClient>,
150 expiry: Duration,
151 min_validity: Duration,
152 extension_cadence: Duration,
153 ) -> Self {
154 let key = format!("distributed-lock:{}", name);
155 let timeouts = RedLockTimeouts::new(expiry, min_validity);
156
157 Self {
158 state: RedisLockState::new(key, timeouts),
159 clients,
160 extension_cadence,
161 }
162 }
163
164 pub fn name(&self) -> &str {
166 self.state
168 .key
169 .strip_prefix("distributed-lock:")
170 .unwrap_or(&self.state.key)
171 }
172}
173
174impl DistributedLock for RedisDistributedLock {
175 type Handle = crate::handle::RedisLockHandle;
176
177 fn name(&self) -> &str {
178 self.name()
179 }
180
181 #[instrument(skip(self), fields(lock.name = %self.name(), lock.key = %self.state.key, timeout = ?timeout, backend = "redis", servers = self.clients.len()))]
182 async fn acquire(&self, timeout: Option<Duration>) -> LockResult<Self::Handle> {
183 use tokio::sync::watch;
184
185 let start = std::time::Instant::now();
186 Span::current().record("operation", "acquire");
187
188 let (cancel_sender, cancel_receiver) = watch::channel(false);
190
191 if let Some(timeout_duration) = timeout {
193 let cancel_sender_clone = cancel_sender.clone();
194 tokio::spawn(async move {
195 tokio::time::sleep(timeout_duration).await;
196 let _ = cancel_sender_clone.send(true);
197 });
198 }
199
200 let state = self.state.clone();
202 let clients = self.clients.clone();
203 let timeouts = self.state.timeouts.clone();
204 let acquire_result = acquire_redlock(
205 move |client| {
206 let state = state.clone();
207 let client = client.clone();
208 async move { state.try_acquire(&client).await }
209 },
210 &clients,
211 &timeouts,
212 &cancel_receiver,
213 )
214 .await?;
215
216 let acquire_result = match acquire_result {
217 Some(result) if result.is_successful(clients.len()) => {
218 let elapsed = start.elapsed();
219 Span::current().record("acquired", true);
220 Span::current().record("elapsed_ms", elapsed.as_millis() as u64);
221 Span::current().record(
222 "servers_acquired",
223 result.acquire_results.iter().filter(|&&b| b).count(),
224 );
225 result
226 }
227 _ => {
228 Span::current().record("acquired", false);
229 Span::current().record("error", "timeout");
230 return Err(LockError::Timeout(
231 timeout.unwrap_or(Duration::from_secs(0)),
232 ));
233 }
234 };
235
236 Ok(crate::handle::RedisLockHandle::new(
238 self.state.clone(),
239 acquire_result.acquire_results,
240 clients,
241 self.extension_cadence,
242 self.state.timeouts.expiry,
243 ))
244 }
245
246 #[instrument(skip(self), fields(lock.name = %self.name(), lock.key = %self.state.key, backend = "redis", servers = self.clients.len()))]
247 async fn try_acquire(&self) -> LockResult<Option<Self::Handle>> {
248 use tokio::sync::watch;
249
250 Span::current().record("operation", "try_acquire");
251
252 let (_cancel_sender, cancel_receiver) = watch::channel(false);
254
255 let state = self.state.clone();
257 let clients = self.clients.clone();
258 let timeouts = self.state.timeouts.clone();
259 let acquire_result = acquire_redlock(
260 move |client| {
261 let state = state.clone();
262 let client = client.clone();
263 async move { state.try_acquire(&client).await }
264 },
265 &clients,
266 &timeouts,
267 &cancel_receiver,
268 )
269 .await?;
270
271 match acquire_result {
272 Some(result) if result.is_successful(clients.len()) => {
273 Span::current().record("acquired", true);
274 Span::current().record(
275 "servers_acquired",
276 result.acquire_results.iter().filter(|&&b| b).count(),
277 );
278 Ok(Some(crate::handle::RedisLockHandle::new(
279 self.state.clone(),
280 result.acquire_results,
281 clients,
282 self.extension_cadence,
283 self.state.timeouts.expiry,
284 )))
285 }
286 _ => {
287 Span::current().record("acquired", false);
288 Span::current().record("reason", "lock_held");
289 Ok(None)
290 }
291 }
292 }
293}