distributed_lock_mysql/
rw_lock.rs

1//! MySQL distributed reader-writer lock implementation.
2
3use std::time::Duration;
4
5use distributed_lock_core::error::{LockError, LockResult};
6use distributed_lock_core::traits::DistributedReaderWriterLock;
7
8use sqlx::Row;
9
10use crate::name::encode_lock_name;
11
/// Snapshot of a single lock's row in the `distributed_locks` table.
///
/// When the row does not exist yet, callers represent it as all zeros, i.e. an
/// unheld lock.
#[derive(Clone, Debug)]
struct LockState {
    /// Value of the `reader_count` column: how many readers currently hold the lock.
    reader_count: i32,
    /// Value of the `writer_held` column: 1 while a writer holds the lock, 0 otherwise.
    writer_held: i32,
    /// Value of the `version` column, incremented on every acquire and release.
    // Loaded together with the other columns but not consulted by any decision yet,
    // hence the lint suppression.
    #[allow(dead_code)]
    version: i32,
}
23
24/// A MySQL-based distributed reader-writer lock.
25///
26/// Uses a database table to track reader-writer state and transactions
27/// to ensure atomic operations. This allows multiple readers to hold
28/// the lock simultaneously while ensuring writers get exclusive access.
29///
30/// The table schema is:
31/// ```sql
32/// CREATE TABLE distributed_locks (
33///     lock_name VARCHAR(255) PRIMARY KEY,
34///     reader_count INT NOT NULL DEFAULT 0,
35///     writer_held TINYINT(1) NOT NULL DEFAULT 0,
36///     version INT NOT NULL DEFAULT 0,
37///     created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
38///     updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
39/// );
40/// ```
41pub struct MySqlDistributedReaderWriterLock {
42    /// The original lock name.
43    name: String,
44    /// The encoded lock name used as database key.
45    encoded_name: String,
46    /// MySQL connection pool.
47    pool: sqlx::MySqlPool,
48    /// Keepalive cadence for long-held locks.
49    keepalive_cadence: Option<Duration>,
50}
51
52impl MySqlDistributedReaderWriterLock {
53    /// Creates a new MySQL distributed reader-writer lock.
54    pub(crate) fn new(
55        name: String,
56        pool: sqlx::MySqlPool,
57        keepalive_cadence: Option<Duration>,
58    ) -> Self {
59        let encoded_name = encode_lock_name(&name);
60        Self {
61            name,
62            encoded_name,
63            pool,
64            keepalive_cadence,
65        }
66    }
67
68    /// Ensures the lock table exists in the database.
69    async fn ensure_table_exists(&self) -> LockResult<()> {
70        sqlx::query(
71            r#"
72            CREATE TABLE IF NOT EXISTS distributed_locks (
73                lock_name VARCHAR(255) PRIMARY KEY,
74                reader_count INT NOT NULL DEFAULT 0,
75                writer_held TINYINT(1) NOT NULL DEFAULT 0,
76                version INT NOT NULL DEFAULT 0,
77                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
78                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
79            )
80            "#,
81        )
82        .execute(&self.pool)
83        .await
84        .map_err(|e| LockError::Connection(Box::new(e)))?;
85        Ok(())
86    }
87
88    /// Gets the current lock state from the database.
89    async fn get_lock_state(&self) -> LockResult<LockState> {
90        self.ensure_table_exists().await?;
91
92        let result = sqlx::query(
93            "SELECT reader_count, writer_held, version FROM distributed_locks WHERE lock_name = ?",
94        )
95        .bind(&self.encoded_name)
96        .fetch_optional(&self.pool)
97        .await
98        .map_err(|e| LockError::Connection(Box::new(e)))?;
99
100        match result {
101            Some(row) => Ok(LockState {
102                reader_count: row
103                    .try_get(0)
104                    .map_err(|e| LockError::Connection(Box::new(e)))?,
105                writer_held: row
106                    .try_get(1)
107                    .map_err(|e| LockError::Connection(Box::new(e)))?,
108                version: row
109                    .try_get(2)
110                    .map_err(|e| LockError::Connection(Box::new(e)))?,
111            }),
112            None => Ok(LockState {
113                reader_count: 0,
114                writer_held: 0,
115                version: 0,
116            }),
117        }
118    }
119
120    /// Attempts to acquire a read lock using database transactions.
121    async fn try_acquire_read_internal(&self) -> LockResult<Option<MySqlReadLockHandle>> {
122        self.ensure_table_exists().await?;
123
124        let mut transaction = self
125            .pool
126            .begin()
127            .await
128            .map_err(|e| LockError::Connection(Box::new(e)))?;
129
130        // Check if writer holds the lock
131        let state = self.get_lock_state().await?;
132
133        if state.writer_held != 0 {
134            // Writer holds the lock, cannot acquire read lock
135            transaction
136                .rollback()
137                .await
138                .map_err(|e| LockError::Connection(Box::new(e)))?;
139            return Ok(None);
140        }
141
142        // Insert or update the lock record to increment reader count
143        let result = sqlx::query(
144            r#"
145            INSERT INTO distributed_locks (lock_name, reader_count, writer_held, version)
146            VALUES (?, 1, 0, 1)
147            ON DUPLICATE KEY UPDATE
148                reader_count = reader_count + 1,
149                version = version + 1
150            "#,
151        )
152        .bind(&self.encoded_name)
153        .execute(&mut *transaction)
154        .await;
155
156        match result {
157            Ok(_) => {
158                transaction
159                    .commit()
160                    .await
161                    .map_err(|e| LockError::Connection(Box::new(e)))?;
162
163                Ok(Some(MySqlReadLockHandle::new(
164                    self.encoded_name.clone(),
165                    self.pool.clone(),
166                    self.keepalive_cadence,
167                )))
168            }
169            Err(e) => {
170                let error = LockError::Connection(Box::new(e));
171                transaction
172                    .rollback()
173                    .await
174                    .map_err(|rollback_e| LockError::Connection(Box::new(rollback_e)))?;
175                Err(error)
176            }
177        }
178    }
179
180    /// Attempts to acquire a write lock using database transactions.
181    async fn try_acquire_write_internal(&self) -> LockResult<Option<MySqlWriteLockHandle>> {
182        self.ensure_table_exists().await?;
183
184        let mut transaction = self
185            .pool
186            .begin()
187            .await
188            .map_err(|e| LockError::Connection(Box::new(e)))?;
189
190        // Check if any readers or writers hold the lock
191        let state = self.get_lock_state().await?;
192
193        if state.reader_count > 0 || state.writer_held != 0 {
194            // Lock is held by readers or writer
195            transaction
196                .rollback()
197                .await
198                .map_err(|e| LockError::Connection(Box::new(e)))?;
199            return Ok(None);
200        }
201
202        // Acquire the write lock - first check if we can acquire it
203        let check_result = sqlx::query(
204            "SELECT reader_count, writer_held FROM distributed_locks WHERE lock_name = ?",
205        )
206        .bind(&self.encoded_name)
207        .fetch_optional(&mut *transaction)
208        .await
209        .map_err(|e| LockError::Connection(Box::new(e)))?;
210
211        let can_acquire = match check_result {
212            Some(row) => {
213                let reader_count: i32 = row
214                    .try_get(0)
215                    .map_err(|e| LockError::Connection(Box::new(e)))?;
216                let writer_held: i32 = row
217                    .try_get(1)
218                    .map_err(|e| LockError::Connection(Box::new(e)))?;
219                reader_count == 0 && writer_held == 0
220            }
221            None => true, // Lock doesn't exist yet, we can create it
222        };
223
224        if !can_acquire {
225            transaction
226                .rollback()
227                .await
228                .map_err(|e| LockError::Connection(Box::new(e)))?;
229            return Ok(None);
230        }
231
232        // Now acquire the write lock
233        let result = sqlx::query(
234            r#"
235            INSERT INTO distributed_locks (lock_name, reader_count, writer_held, version)
236            VALUES (?, 0, 1, 1)
237            ON DUPLICATE KEY UPDATE
238                writer_held = VALUES(writer_held),
239                version = version + 1
240            "#,
241        )
242        .bind(&self.encoded_name)
243        .execute(&mut *transaction)
244        .await;
245
246        match result {
247            Ok(result) => {
248                if result.rows_affected() > 0 {
249                    transaction
250                        .commit()
251                        .await
252                        .map_err(|e| LockError::Connection(Box::new(e)))?;
253
254                    Ok(Some(MySqlWriteLockHandle::new(
255                        self.encoded_name.clone(),
256                        self.pool.clone(),
257                        self.keepalive_cadence,
258                    )))
259                } else {
260                    // Could not acquire (condition not met)
261                    transaction
262                        .rollback()
263                        .await
264                        .map_err(|e| LockError::Connection(Box::new(e)))?;
265                    Ok(None)
266                }
267            }
268            Err(e) => {
269                let error = LockError::Connection(Box::new(e));
270                transaction
271                    .rollback()
272                    .await
273                    .map_err(|rollback_e| LockError::Connection(Box::new(rollback_e)))?;
274                Err(error)
275            }
276        }
277    }
278}
279
280impl DistributedReaderWriterLock for MySqlDistributedReaderWriterLock {
281    type ReadHandle = MySqlReadLockHandle;
282    type WriteHandle = MySqlWriteLockHandle;
283
284    fn name(&self) -> &str {
285        &self.name
286    }
287
288    async fn acquire_read(&self, timeout: Option<Duration>) -> LockResult<Self::ReadHandle> {
289        let start_time = std::time::Instant::now();
290
291        loop {
292            match self.try_acquire_read_internal().await? {
293                Some(handle) => return Ok(handle),
294                None => {
295                    // Check if we've exceeded the timeout
296                    if let Some(timeout_duration) = timeout
297                        && start_time.elapsed() >= timeout_duration
298                    {
299                        return Err(LockError::Timeout(timeout_duration));
300                    }
301
302                    // Wait a bit before retrying
303                    tokio::time::sleep(Duration::from_millis(10)).await;
304                }
305            }
306        }
307    }
308
309    async fn try_acquire_read(&self) -> LockResult<Option<Self::ReadHandle>> {
310        self.try_acquire_read_internal().await
311    }
312
313    async fn acquire_write(&self, timeout: Option<Duration>) -> LockResult<Self::WriteHandle> {
314        let start_time = std::time::Instant::now();
315
316        loop {
317            match self.try_acquire_write_internal().await? {
318                Some(handle) => return Ok(handle),
319                None => {
320                    // Check if we've exceeded the timeout
321                    if let Some(timeout_duration) = timeout
322                        && start_time.elapsed() >= timeout_duration
323                    {
324                        return Err(LockError::Timeout(timeout_duration));
325                    }
326
327                    // Wait a bit before retrying
328                    tokio::time::sleep(Duration::from_millis(10)).await;
329                }
330            }
331        }
332    }
333
334    async fn try_acquire_write(&self) -> LockResult<Option<Self::WriteHandle>> {
335        self.try_acquire_write_internal().await
336    }
337}
338
339/// Handle for a MySQL read lock.
340pub struct MySqlReadLockHandle {
341    lock_name: String,
342    pool: sqlx::MySqlPool,
343    lost_sender: tokio::sync::watch::Sender<bool>,
344    lost_receiver: tokio::sync::watch::Receiver<bool>,
345    keepalive_handle: Option<tokio::task::JoinHandle<()>>,
346}
347
348impl MySqlReadLockHandle {
349    pub(crate) fn new(
350        lock_name: String,
351        pool: sqlx::MySqlPool,
352        keepalive_cadence: Option<Duration>,
353    ) -> Self {
354        let (lost_token_tx, lost_receiver) = tokio::sync::watch::channel(false);
355
356        let lost_token_tx_clone = lost_token_tx.clone();
357        let keepalive_handle = keepalive_cadence.map(|cadence| {
358            let pool_clone = pool.clone();
359            let mut lost_token_rx_clone = lost_token_tx_clone.subscribe();
360
361            tokio::spawn(async move {
362                loop {
363                    tokio::select! {
364                        _ = tokio::time::sleep(cadence) => {
365                            // Run a keepalive query
366                            let result = sqlx::query("SELECT 1")
367                                .execute(&pool_clone)
368                                .await;
369
370                            if result.is_err() {
371                                // Connection failed, signal lock loss
372                                let _ = lost_token_tx_clone.send(true);
373                                break;
374                            }
375                        }
376                        _ = lost_token_rx_clone.changed() => {
377                            // Lock was released, stop keepalive
378                            break;
379                        }
380                    }
381                }
382            })
383        });
384
385        Self {
386            lock_name,
387            pool,
388            lost_sender: lost_token_tx,
389            lost_receiver,
390            keepalive_handle,
391        }
392    }
393}
394
395impl distributed_lock_core::traits::LockHandle for MySqlReadLockHandle {
396    fn lost_token(&self) -> &tokio::sync::watch::Receiver<bool> {
397        &self.lost_receiver
398    }
399
400    async fn release(self) -> LockResult<()> {
401        // Stop keepalive task
402        if let Some(handle) = &self.keepalive_handle {
403            handle.abort();
404        }
405
406        // Signal that the lock is being released
407        let _ = self.lost_sender.send(true);
408
409        // Decrement reader count in database
410        let result = sqlx::query(
411            "UPDATE distributed_locks SET reader_count = GREATEST(reader_count - 1, 0), version = version + 1 WHERE lock_name = ?"
412        )
413        .bind(&self.lock_name)
414        .execute(&self.pool)
415        .await;
416
417        match result {
418            Ok(_) => Ok(()),
419            Err(e) => Err(distributed_lock_core::error::LockError::Connection(
420                Box::new(e),
421            )),
422        }
423    }
424}
425
426impl Drop for MySqlReadLockHandle {
427    fn drop(&mut self) {
428        // Signal that the lock is lost (being dropped)
429        let _ = self.lost_sender.send(true);
430
431        // Stop keepalive task
432        if let Some(handle) = self.keepalive_handle.take() {
433            handle.abort();
434        }
435    }
436}
437
438/// Handle for a MySQL write lock.
439pub struct MySqlWriteLockHandle {
440    lock_name: String,
441    pool: sqlx::MySqlPool,
442    lost_sender: tokio::sync::watch::Sender<bool>,
443    lost_receiver: tokio::sync::watch::Receiver<bool>,
444    keepalive_handle: Option<tokio::task::JoinHandle<()>>,
445}
446
447impl MySqlWriteLockHandle {
448    pub(crate) fn new(
449        lock_name: String,
450        pool: sqlx::MySqlPool,
451        keepalive_cadence: Option<Duration>,
452    ) -> Self {
453        let (lost_token_tx, lost_receiver) = tokio::sync::watch::channel(false);
454
455        let lost_token_tx_clone = lost_token_tx.clone();
456        let keepalive_handle = keepalive_cadence.map(|cadence| {
457            let pool_clone = pool.clone();
458            let mut lost_token_rx_clone = lost_token_tx_clone.subscribe();
459
460            tokio::spawn(async move {
461                loop {
462                    tokio::select! {
463                        _ = tokio::time::sleep(cadence) => {
464                            // Run a keepalive query
465                            let result = sqlx::query("SELECT 1")
466                                .execute(&pool_clone)
467                                .await;
468
469                            if result.is_err() {
470                                // Connection failed, signal lock loss
471                                let _ = lost_token_tx_clone.send(true);
472                                break;
473                            }
474                        }
475                        _ = lost_token_rx_clone.changed() => {
476                            // Lock was released, stop keepalive
477                            break;
478                        }
479                    }
480                }
481            })
482        });
483
484        Self {
485            lock_name,
486            pool,
487            lost_sender: lost_token_tx,
488            lost_receiver,
489            keepalive_handle,
490        }
491    }
492}
493
494impl distributed_lock_core::traits::LockHandle for MySqlWriteLockHandle {
495    fn lost_token(&self) -> &tokio::sync::watch::Receiver<bool> {
496        &self.lost_receiver
497    }
498
499    async fn release(self) -> LockResult<()> {
500        // Stop keepalive task
501        if let Some(handle) = &self.keepalive_handle {
502            handle.abort();
503        }
504
505        // Signal that the lock is being released
506        let _ = self.lost_sender.send(true);
507
508        // Release the write lock in database
509        let result = sqlx::query(
510            "UPDATE distributed_locks SET writer_held = 0, version = version + 1 WHERE lock_name = ?"
511        )
512        .bind(&self.lock_name)
513        .execute(&self.pool)
514        .await;
515
516        match result {
517            Ok(_) => Ok(()),
518            Err(e) => Err(distributed_lock_core::error::LockError::Connection(
519                Box::new(e),
520            )),
521        }
522    }
523}
524
525impl Drop for MySqlWriteLockHandle {
526    fn drop(&mut self) {
527        // Signal that the lock is lost (being dropped)
528        let _ = self.lost_sender.send(true);
529
530        // Stop keepalive task
531        if let Some(handle) = self.keepalive_handle.take() {
532            handle.abort();
533        }
534    }
535}