1use std::time::Duration;
4
5use distributed_lock_core::error::{LockError, LockResult};
6use distributed_lock_core::traits::DistributedReaderWriterLock;
7
8use sqlx::Row;
9
10use crate::name::encode_lock_name;
11
/// Snapshot of one `distributed_locks` row, as read by `get_lock_state`.
///
/// A lock name with no row yet is represented by all-zero fields.
#[derive(Debug, Clone)]
struct LockState {
    /// Value of the `reader_count` column: number of shared holders recorded in the row.
    reader_count: i32,
    /// Value of the `writer_held` column; non-zero means a writer holds the lock.
    writer_held: i32,
    /// Value of the `version` column, bumped by every acquire/release statement.
    // Not read anywhere in this module; it is kept so the full row shows up in `Debug` output.
    #[allow(dead_code)]
    version: i32,
}
23
24pub struct MySqlDistributedReaderWriterLock {
42 name: String,
44 encoded_name: String,
46 pool: sqlx::MySqlPool,
48 keepalive_cadence: Option<Duration>,
50}
51
52impl MySqlDistributedReaderWriterLock {
53 pub(crate) fn new(
55 name: String,
56 pool: sqlx::MySqlPool,
57 keepalive_cadence: Option<Duration>,
58 ) -> Self {
59 let encoded_name = encode_lock_name(&name);
60 Self {
61 name,
62 encoded_name,
63 pool,
64 keepalive_cadence,
65 }
66 }
67
68 async fn ensure_table_exists(&self) -> LockResult<()> {
70 sqlx::query(
71 r#"
72 CREATE TABLE IF NOT EXISTS distributed_locks (
73 lock_name VARCHAR(255) PRIMARY KEY,
74 reader_count INT NOT NULL DEFAULT 0,
75 writer_held TINYINT(1) NOT NULL DEFAULT 0,
76 version INT NOT NULL DEFAULT 0,
77 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
78 updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
79 )
80 "#,
81 )
82 .execute(&self.pool)
83 .await
84 .map_err(|e| LockError::Connection(Box::new(e)))?;
85 Ok(())
86 }
87
88 async fn get_lock_state(&self) -> LockResult<LockState> {
90 self.ensure_table_exists().await?;
91
92 let result = sqlx::query(
93 "SELECT reader_count, writer_held, version FROM distributed_locks WHERE lock_name = ?",
94 )
95 .bind(&self.encoded_name)
96 .fetch_optional(&self.pool)
97 .await
98 .map_err(|e| LockError::Connection(Box::new(e)))?;
99
100 match result {
101 Some(row) => Ok(LockState {
102 reader_count: row
103 .try_get(0)
104 .map_err(|e| LockError::Connection(Box::new(e)))?,
105 writer_held: row
106 .try_get(1)
107 .map_err(|e| LockError::Connection(Box::new(e)))?,
108 version: row
109 .try_get(2)
110 .map_err(|e| LockError::Connection(Box::new(e)))?,
111 }),
112 None => Ok(LockState {
113 reader_count: 0,
114 writer_held: 0,
115 version: 0,
116 }),
117 }
118 }
119
120 async fn try_acquire_read_internal(&self) -> LockResult<Option<MySqlReadLockHandle>> {
122 self.ensure_table_exists().await?;
123
124 let mut transaction = self
125 .pool
126 .begin()
127 .await
128 .map_err(|e| LockError::Connection(Box::new(e)))?;
129
130 let state = self.get_lock_state().await?;
132
133 if state.writer_held != 0 {
134 transaction
136 .rollback()
137 .await
138 .map_err(|e| LockError::Connection(Box::new(e)))?;
139 return Ok(None);
140 }
141
142 let result = sqlx::query(
144 r#"
145 INSERT INTO distributed_locks (lock_name, reader_count, writer_held, version)
146 VALUES (?, 1, 0, 1)
147 ON DUPLICATE KEY UPDATE
148 reader_count = reader_count + 1,
149 version = version + 1
150 "#,
151 )
152 .bind(&self.encoded_name)
153 .execute(&mut *transaction)
154 .await;
155
156 match result {
157 Ok(_) => {
158 transaction
159 .commit()
160 .await
161 .map_err(|e| LockError::Connection(Box::new(e)))?;
162
163 Ok(Some(MySqlReadLockHandle::new(
164 self.encoded_name.clone(),
165 self.pool.clone(),
166 self.keepalive_cadence,
167 )))
168 }
169 Err(e) => {
170 let error = LockError::Connection(Box::new(e));
171 transaction
172 .rollback()
173 .await
174 .map_err(|rollback_e| LockError::Connection(Box::new(rollback_e)))?;
175 Err(error)
176 }
177 }
178 }
179
180 async fn try_acquire_write_internal(&self) -> LockResult<Option<MySqlWriteLockHandle>> {
182 self.ensure_table_exists().await?;
183
184 let mut transaction = self
185 .pool
186 .begin()
187 .await
188 .map_err(|e| LockError::Connection(Box::new(e)))?;
189
190 let state = self.get_lock_state().await?;
192
193 if state.reader_count > 0 || state.writer_held != 0 {
194 transaction
196 .rollback()
197 .await
198 .map_err(|e| LockError::Connection(Box::new(e)))?;
199 return Ok(None);
200 }
201
202 let check_result = sqlx::query(
204 "SELECT reader_count, writer_held FROM distributed_locks WHERE lock_name = ?",
205 )
206 .bind(&self.encoded_name)
207 .fetch_optional(&mut *transaction)
208 .await
209 .map_err(|e| LockError::Connection(Box::new(e)))?;
210
211 let can_acquire = match check_result {
212 Some(row) => {
213 let reader_count: i32 = row
214 .try_get(0)
215 .map_err(|e| LockError::Connection(Box::new(e)))?;
216 let writer_held: i32 = row
217 .try_get(1)
218 .map_err(|e| LockError::Connection(Box::new(e)))?;
219 reader_count == 0 && writer_held == 0
220 }
221 None => true, };
223
224 if !can_acquire {
225 transaction
226 .rollback()
227 .await
228 .map_err(|e| LockError::Connection(Box::new(e)))?;
229 return Ok(None);
230 }
231
232 let result = sqlx::query(
234 r#"
235 INSERT INTO distributed_locks (lock_name, reader_count, writer_held, version)
236 VALUES (?, 0, 1, 1)
237 ON DUPLICATE KEY UPDATE
238 writer_held = VALUES(writer_held),
239 version = version + 1
240 "#,
241 )
242 .bind(&self.encoded_name)
243 .execute(&mut *transaction)
244 .await;
245
246 match result {
247 Ok(result) => {
248 if result.rows_affected() > 0 {
249 transaction
250 .commit()
251 .await
252 .map_err(|e| LockError::Connection(Box::new(e)))?;
253
254 Ok(Some(MySqlWriteLockHandle::new(
255 self.encoded_name.clone(),
256 self.pool.clone(),
257 self.keepalive_cadence,
258 )))
259 } else {
260 transaction
262 .rollback()
263 .await
264 .map_err(|e| LockError::Connection(Box::new(e)))?;
265 Ok(None)
266 }
267 }
268 Err(e) => {
269 let error = LockError::Connection(Box::new(e));
270 transaction
271 .rollback()
272 .await
273 .map_err(|rollback_e| LockError::Connection(Box::new(rollback_e)))?;
274 Err(error)
275 }
276 }
277 }
278}
279
280impl DistributedReaderWriterLock for MySqlDistributedReaderWriterLock {
281 type ReadHandle = MySqlReadLockHandle;
282 type WriteHandle = MySqlWriteLockHandle;
283
284 fn name(&self) -> &str {
285 &self.name
286 }
287
288 async fn acquire_read(&self, timeout: Option<Duration>) -> LockResult<Self::ReadHandle> {
289 let start_time = std::time::Instant::now();
290
291 loop {
292 match self.try_acquire_read_internal().await? {
293 Some(handle) => return Ok(handle),
294 None => {
295 if let Some(timeout_duration) = timeout
297 && start_time.elapsed() >= timeout_duration
298 {
299 return Err(LockError::Timeout(timeout_duration));
300 }
301
302 tokio::time::sleep(Duration::from_millis(10)).await;
304 }
305 }
306 }
307 }
308
309 async fn try_acquire_read(&self) -> LockResult<Option<Self::ReadHandle>> {
310 self.try_acquire_read_internal().await
311 }
312
313 async fn acquire_write(&self, timeout: Option<Duration>) -> LockResult<Self::WriteHandle> {
314 let start_time = std::time::Instant::now();
315
316 loop {
317 match self.try_acquire_write_internal().await? {
318 Some(handle) => return Ok(handle),
319 None => {
320 if let Some(timeout_duration) = timeout
322 && start_time.elapsed() >= timeout_duration
323 {
324 return Err(LockError::Timeout(timeout_duration));
325 }
326
327 tokio::time::sleep(Duration::from_millis(10)).await;
329 }
330 }
331 }
332 }
333
334 async fn try_acquire_write(&self) -> LockResult<Option<Self::WriteHandle>> {
335 self.try_acquire_write_internal().await
336 }
337}
338
339pub struct MySqlReadLockHandle {
341 lock_name: String,
342 pool: sqlx::MySqlPool,
343 lost_sender: tokio::sync::watch::Sender<bool>,
344 lost_receiver: tokio::sync::watch::Receiver<bool>,
345 keepalive_handle: Option<tokio::task::JoinHandle<()>>,
346}
347
348impl MySqlReadLockHandle {
349 pub(crate) fn new(
350 lock_name: String,
351 pool: sqlx::MySqlPool,
352 keepalive_cadence: Option<Duration>,
353 ) -> Self {
354 let (lost_token_tx, lost_receiver) = tokio::sync::watch::channel(false);
355
356 let lost_token_tx_clone = lost_token_tx.clone();
357 let keepalive_handle = keepalive_cadence.map(|cadence| {
358 let pool_clone = pool.clone();
359 let mut lost_token_rx_clone = lost_token_tx_clone.subscribe();
360
361 tokio::spawn(async move {
362 loop {
363 tokio::select! {
364 _ = tokio::time::sleep(cadence) => {
365 let result = sqlx::query("SELECT 1")
367 .execute(&pool_clone)
368 .await;
369
370 if result.is_err() {
371 let _ = lost_token_tx_clone.send(true);
373 break;
374 }
375 }
376 _ = lost_token_rx_clone.changed() => {
377 break;
379 }
380 }
381 }
382 })
383 });
384
385 Self {
386 lock_name,
387 pool,
388 lost_sender: lost_token_tx,
389 lost_receiver,
390 keepalive_handle,
391 }
392 }
393}
394
395impl distributed_lock_core::traits::LockHandle for MySqlReadLockHandle {
396 fn lost_token(&self) -> &tokio::sync::watch::Receiver<bool> {
397 &self.lost_receiver
398 }
399
400 async fn release(self) -> LockResult<()> {
401 if let Some(handle) = &self.keepalive_handle {
403 handle.abort();
404 }
405
406 let _ = self.lost_sender.send(true);
408
409 let result = sqlx::query(
411 "UPDATE distributed_locks SET reader_count = GREATEST(reader_count - 1, 0), version = version + 1 WHERE lock_name = ?"
412 )
413 .bind(&self.lock_name)
414 .execute(&self.pool)
415 .await;
416
417 match result {
418 Ok(_) => Ok(()),
419 Err(e) => Err(distributed_lock_core::error::LockError::Connection(
420 Box::new(e),
421 )),
422 }
423 }
424}
425
426impl Drop for MySqlReadLockHandle {
427 fn drop(&mut self) {
428 let _ = self.lost_sender.send(true);
430
431 if let Some(handle) = self.keepalive_handle.take() {
433 handle.abort();
434 }
435 }
436}
437
438pub struct MySqlWriteLockHandle {
440 lock_name: String,
441 pool: sqlx::MySqlPool,
442 lost_sender: tokio::sync::watch::Sender<bool>,
443 lost_receiver: tokio::sync::watch::Receiver<bool>,
444 keepalive_handle: Option<tokio::task::JoinHandle<()>>,
445}
446
447impl MySqlWriteLockHandle {
448 pub(crate) fn new(
449 lock_name: String,
450 pool: sqlx::MySqlPool,
451 keepalive_cadence: Option<Duration>,
452 ) -> Self {
453 let (lost_token_tx, lost_receiver) = tokio::sync::watch::channel(false);
454
455 let lost_token_tx_clone = lost_token_tx.clone();
456 let keepalive_handle = keepalive_cadence.map(|cadence| {
457 let pool_clone = pool.clone();
458 let mut lost_token_rx_clone = lost_token_tx_clone.subscribe();
459
460 tokio::spawn(async move {
461 loop {
462 tokio::select! {
463 _ = tokio::time::sleep(cadence) => {
464 let result = sqlx::query("SELECT 1")
466 .execute(&pool_clone)
467 .await;
468
469 if result.is_err() {
470 let _ = lost_token_tx_clone.send(true);
472 break;
473 }
474 }
475 _ = lost_token_rx_clone.changed() => {
476 break;
478 }
479 }
480 }
481 })
482 });
483
484 Self {
485 lock_name,
486 pool,
487 lost_sender: lost_token_tx,
488 lost_receiver,
489 keepalive_handle,
490 }
491 }
492}
493
494impl distributed_lock_core::traits::LockHandle for MySqlWriteLockHandle {
495 fn lost_token(&self) -> &tokio::sync::watch::Receiver<bool> {
496 &self.lost_receiver
497 }
498
499 async fn release(self) -> LockResult<()> {
500 if let Some(handle) = &self.keepalive_handle {
502 handle.abort();
503 }
504
505 let _ = self.lost_sender.send(true);
507
508 let result = sqlx::query(
510 "UPDATE distributed_locks SET writer_held = 0, version = version + 1 WHERE lock_name = ?"
511 )
512 .bind(&self.lock_name)
513 .execute(&self.pool)
514 .await;
515
516 match result {
517 Ok(_) => Ok(()),
518 Err(e) => Err(distributed_lock_core::error::LockError::Connection(
519 Box::new(e),
520 )),
521 }
522 }
523}
524
525impl Drop for MySqlWriteLockHandle {
526 fn drop(&mut self) {
527 let _ = self.lost_sender.send(true);
529
530 if let Some(handle) = self.keepalive_handle.take() {
532 handle.abort();
533 }
534 }
535}