1use std::time::Duration;
4
5use distributed_lock_core::error::{LockError, LockResult};
6use distributed_lock_core::traits::DistributedLock;
7use fred::prelude::*;
8use tracing::{instrument, Span};
9
10use crate::redlock::{acquire::acquire_redlock, helper::RedLockHelper, timeouts::RedLockTimeouts};
11
12#[derive(Debug, Clone)]
14pub struct RedisLockState {
15 pub key: String,
17 pub lock_id: String,
19 pub timeouts: RedLockTimeouts,
21}
22
23impl RedisLockState {
24 pub fn new(key: String, timeouts: RedLockTimeouts) -> Self {
26 Self {
27 key,
28 lock_id: RedLockHelper::create_lock_id(),
29 timeouts,
30 }
31 }
32
33 pub async fn try_acquire(&self, client: &RedisClient) -> LockResult<bool> {
35 let expiry_millis = self.timeouts.expiry.as_millis() as i64;
36
37 let result: Option<String> = client
40 .set(
41 &self.key,
42 &self.lock_id,
43 Some(Expiration::PX(expiry_millis)),
44 Some(SetOptions::NX),
45 false,
46 )
47 .await
48 .map_err(|e| {
49 LockError::Backend(Box::new(std::io::Error::other(format!(
50 "Redis SET NX failed: {}",
51 e
52 ))))
53 })?;
54
55 Ok(result.is_some())
57 }
58
59 pub async fn try_extend(&self, client: &RedisClient) -> LockResult<bool> {
64 let expiry_millis = self.timeouts.expiry.as_millis() as i64;
65
66 let current_value: Option<String> = client.get(&self.key).await.map_err(|e| {
68 LockError::Backend(Box::new(std::io::Error::other(format!(
69 "Redis GET failed: {}",
70 e
71 ))))
72 })?;
73
74 match current_value {
75 Some(value) if value == self.lock_id => {
76 let _: bool = client
78 .pexpire(&self.key, expiry_millis, None)
79 .await
80 .map_err(|e| {
81 LockError::Backend(Box::new(std::io::Error::other(format!(
82 "Redis PEXPIRE failed: {}",
83 e
84 ))))
85 })?;
86 Ok(true)
87 }
88 _ => Ok(false), }
90 }
91
92 pub async fn try_release(&self, client: &RedisClient) -> LockResult<()> {
97 let current_value: Option<String> = client.get(&self.key).await.map_err(|e| {
99 LockError::Backend(Box::new(std::io::Error::other(format!(
100 "Redis GET failed: {}",
101 e
102 ))))
103 })?;
104
105 match current_value {
106 Some(value) if value == self.lock_id => {
107 let _: i64 = client.del(&self.key).await.map_err(|e| {
109 LockError::Backend(Box::new(std::io::Error::other(format!(
110 "Redis DEL failed: {}",
111 e
112 ))))
113 })?;
114 Ok(())
115 }
116 _ => {
117 Ok(())
119 }
120 }
121 }
122}
123
124pub struct RedisDistributedLock {
128 state: RedisLockState,
130 clients: Vec<RedisClient>,
132 extension_cadence: Duration,
134}
135
136impl RedisDistributedLock {
137 pub(crate) fn new(
139 name: String,
140 clients: Vec<RedisClient>,
141 expiry: Duration,
142 min_validity: Duration,
143 extension_cadence: Duration,
144 ) -> Self {
145 let key = format!("distributed-lock:{}", name);
146 let timeouts = RedLockTimeouts::new(expiry, min_validity);
147
148 Self {
149 state: RedisLockState::new(key, timeouts),
150 clients,
151 extension_cadence,
152 }
153 }
154
155 pub fn name(&self) -> &str {
157 self.state
159 .key
160 .strip_prefix("distributed-lock:")
161 .unwrap_or(&self.state.key)
162 }
163}
164
165impl DistributedLock for RedisDistributedLock {
166 type Handle = crate::handle::RedisLockHandle;
167
168 fn name(&self) -> &str {
169 self.name()
170 }
171
172 #[instrument(skip(self), fields(lock.name = %self.name(), lock.key = %self.state.key, timeout = ?timeout, backend = "redis", servers = self.clients.len()))]
173 async fn acquire(&self, timeout: Option<Duration>) -> LockResult<Self::Handle> {
174 use tokio::sync::watch;
175
176 let start = std::time::Instant::now();
177 Span::current().record("operation", "acquire");
178
179 let (cancel_sender, cancel_receiver) = watch::channel(false);
181
182 if let Some(timeout_duration) = timeout {
184 let cancel_sender_clone = cancel_sender.clone();
185 tokio::spawn(async move {
186 tokio::time::sleep(timeout_duration).await;
187 let _ = cancel_sender_clone.send(true);
188 });
189 }
190
191 let state = self.state.clone();
193 let clients = self.clients.clone();
194 let timeouts = self.state.timeouts.clone();
195 let acquire_result = acquire_redlock(
196 move |client| {
197 let state = state.clone();
198 let client = client.clone();
199 async move { state.try_acquire(&client).await }
200 },
201 &clients,
202 &timeouts,
203 &cancel_receiver,
204 )
205 .await?;
206
207 let acquire_result = match acquire_result {
208 Some(result) if result.is_successful(clients.len()) => {
209 let elapsed = start.elapsed();
210 Span::current().record("acquired", true);
211 Span::current().record("elapsed_ms", elapsed.as_millis() as u64);
212 Span::current().record(
213 "servers_acquired",
214 result.acquire_results.iter().filter(|&&b| b).count(),
215 );
216 result
217 }
218 _ => {
219 Span::current().record("acquired", false);
220 Span::current().record("error", "timeout");
221 return Err(LockError::Timeout(
222 timeout.unwrap_or(Duration::from_secs(0)),
223 ));
224 }
225 };
226
227 Ok(crate::handle::RedisLockHandle::new(
229 self.state.clone(),
230 acquire_result.acquire_results,
231 clients,
232 self.extension_cadence,
233 self.state.timeouts.expiry,
234 ))
235 }
236
237 #[instrument(skip(self), fields(lock.name = %self.name(), lock.key = %self.state.key, backend = "redis", servers = self.clients.len()))]
238 async fn try_acquire(&self) -> LockResult<Option<Self::Handle>> {
239 use tokio::sync::watch;
240
241 Span::current().record("operation", "try_acquire");
242
243 let (_cancel_sender, cancel_receiver) = watch::channel(false);
245
246 let state = self.state.clone();
248 let clients = self.clients.clone();
249 let timeouts = self.state.timeouts.clone();
250 let acquire_result = acquire_redlock(
251 move |client| {
252 let state = state.clone();
253 let client = client.clone();
254 async move { state.try_acquire(&client).await }
255 },
256 &clients,
257 &timeouts,
258 &cancel_receiver,
259 )
260 .await?;
261
262 let result = match acquire_result {
263 Some(result) if result.is_successful(clients.len()) => {
264 Span::current().record("acquired", true);
265 Span::current().record(
266 "servers_acquired",
267 result.acquire_results.iter().filter(|&&b| b).count(),
268 );
269 Ok(Some(crate::handle::RedisLockHandle::new(
270 self.state.clone(),
271 result.acquire_results,
272 clients,
273 self.extension_cadence,
274 self.state.timeouts.expiry,
275 )))
276 }
277 _ => {
278 Span::current().record("acquired", false);
279 Span::current().record("reason", "lock_held");
280 Ok(None)
281 }
282 };
283 result
284 }
285}