distributed_lock_postgres/
rw_lock.rs

1//! PostgreSQL reader-writer lock implementation.
2
3use std::time::Duration;
4
5use distributed_lock_core::error::{LockError, LockResult};
6use distributed_lock_core::timeout::TimeoutValue;
7use distributed_lock_core::traits::{DistributedReaderWriterLock, LockHandle};
8use tokio::sync::watch;
9
10use crate::handle::PostgresConnectionInner;
11use crate::key::PostgresAdvisoryLockKey;
12use sqlx::{PgPool, Postgres, Row, Transaction};
13
14/// A PostgreSQL-based distributed reader-writer lock.
15pub struct PostgresDistributedReaderWriterLock {
16    /// The lock key.
17    key: PostgresAdvisoryLockKey,
18    /// Original lock name.
19    name: String,
20    /// Connection pool.
21    pool: PgPool,
22    /// Whether to use transaction-scoped locks.
23    use_transaction: bool,
24    /// Keepalive cadence for long-held locks.
25    keepalive_cadence: Option<Duration>,
26}
27
28impl PostgresDistributedReaderWriterLock {
29    pub(crate) fn new(
30        name: String,
31        key: PostgresAdvisoryLockKey,
32        pool: PgPool,
33        use_transaction: bool,
34        keepalive_cadence: Option<Duration>,
35    ) -> Self {
36        Self {
37            key,
38            name,
39            pool,
40            use_transaction,
41            keepalive_cadence,
42        }
43    }
44
45    /// Attempts to acquire a read lock without waiting.
46    async fn try_acquire_read_internal(&self) -> LockResult<Option<PostgresReadLockHandle>> {
47        let mut connection = self.pool.acquire().await.map_err(|e| {
48            LockError::Connection(Box::new(std::io::Error::other(format!(
49                "failed to get connection from pool: {e}"
50            ))))
51        })?;
52
53        let sql = format!(
54            "SELECT pg_try_advisory_lock_shared({})",
55            self.key.to_sql_args()
56        );
57
58        let row = sqlx::query(&sql)
59            .fetch_one(&mut *connection)
60            .await
61            .map_err(|e| {
62                LockError::Backend(Box::new(std::io::Error::other(format!(
63                    "failed to acquire read lock: {e}"
64                ))))
65            })?;
66
67        let acquired: bool = row.get(0);
68        if !acquired {
69            return Ok(None);
70        }
71
72        // Store pool connection to keep it alive
73        // PoolConnection will be returned to pool when dropped
74
75        let (sender, receiver) = watch::channel(false);
76        Ok(Some(PostgresReadLockHandle::new(
77            PostgresConnectionInner::Connection(Box::new(connection)),
78            self.key,
79            sender,
80            receiver,
81            self.keepalive_cadence,
82        )))
83    }
84
85    /// Attempts to acquire a write lock without waiting.
86    async fn try_acquire_write_internal(&self) -> LockResult<Option<PostgresWriteLockHandle>> {
87        if self.use_transaction {
88            // Transaction-scoped lock
89            let mut transaction = self.pool.begin().await.map_err(|e| {
90                LockError::Connection(Box::new(std::io::Error::other(format!(
91                    "failed to start transaction: {e}"
92                ))))
93            })?;
94
95            let sql = format!("SELECT pg_try_advisory_lock({})", self.key.to_sql_args());
96
97            let row = sqlx::query(&sql)
98                .fetch_one(&mut *transaction)
99                .await
100                .map_err(|e| {
101                    LockError::Backend(Box::new(std::io::Error::other(format!(
102                        "failed to acquire write lock: {e}"
103                    ))))
104                })?;
105
106            let acquired: bool = row.get(0);
107            if !acquired {
108                return Ok(None);
109            }
110
111            // Store transaction using raw pointer to avoid lifetime issues
112            // SAFETY: We manually manage the transaction lifetime in the handle
113            let transaction_ptr = unsafe {
114                std::mem::transmute::<Transaction<'_, Postgres>, Transaction<'static, Postgres>>(
115                    transaction,
116                )
117            };
118            let transaction_ptr = Box::into_raw(Box::new(transaction_ptr));
119
120            let (sender, receiver) = watch::channel(false);
121            Ok(Some(PostgresWriteLockHandle::new(
122                PostgresConnectionInner::Transaction(transaction_ptr),
123                self.key,
124                sender,
125                receiver,
126                self.keepalive_cadence,
127            )))
128        } else {
129            // Session-scoped lock
130            let mut connection = self.pool.acquire().await.map_err(|e| {
131                LockError::Connection(Box::new(std::io::Error::other(format!(
132                    "failed to get connection from pool: {e}"
133                ))))
134            })?;
135
136            let sql = format!("SELECT pg_try_advisory_lock({})", self.key.to_sql_args());
137
138            let row = sqlx::query(&sql)
139                .fetch_one(&mut *connection)
140                .await
141                .map_err(|e| {
142                    LockError::Backend(Box::new(std::io::Error::other(format!(
143                        "failed to acquire write lock: {e}"
144                    ))))
145                })?;
146
147            let acquired: bool = row.get(0);
148            if !acquired {
149                return Ok(None);
150            }
151
152            // Store pool connection to keep it alive
153            // PoolConnection will be returned to pool when dropped
154
155            let (sender, receiver) = watch::channel(false);
156            Ok(Some(PostgresWriteLockHandle::new(
157                PostgresConnectionInner::Connection(Box::new(connection)),
158                self.key,
159                sender,
160                receiver,
161                self.keepalive_cadence,
162            )))
163        }
164    }
165}
166
167impl DistributedReaderWriterLock for PostgresDistributedReaderWriterLock {
168    type ReadHandle = PostgresReadLockHandle;
169    type WriteHandle = PostgresWriteLockHandle;
170
171    fn name(&self) -> &str {
172        &self.name
173    }
174
175    async fn acquire_read(&self, timeout: Option<Duration>) -> LockResult<Self::ReadHandle> {
176        let timeout_value = TimeoutValue::from(timeout);
177        let start = std::time::Instant::now();
178
179        // Busy-wait with exponential backoff
180        let mut sleep_duration = Duration::from_millis(50);
181        const MAX_SLEEP: Duration = Duration::from_secs(1);
182
183        loop {
184            match self.try_acquire_read_internal().await {
185                Ok(Some(handle)) => return Ok(handle),
186                Ok(None) => {
187                    // Check timeout
188                    if !timeout_value.is_infinite()
189                        && start.elapsed() >= timeout_value.as_duration().unwrap()
190                    {
191                        return Err(LockError::Timeout(timeout_value.as_duration().unwrap()));
192                    }
193
194                    // Sleep before retry
195                    tokio::time::sleep(sleep_duration).await;
196                    sleep_duration = (sleep_duration * 2).min(MAX_SLEEP);
197                }
198                Err(e) => return Err(e),
199            }
200        }
201    }
202
203    async fn try_acquire_read(&self) -> LockResult<Option<Self::ReadHandle>> {
204        self.try_acquire_read_internal().await
205    }
206
207    async fn acquire_write(&self, timeout: Option<Duration>) -> LockResult<Self::WriteHandle> {
208        let timeout_value = TimeoutValue::from(timeout);
209        let start = std::time::Instant::now();
210
211        // Busy-wait with exponential backoff
212        let mut sleep_duration = Duration::from_millis(50);
213        const MAX_SLEEP: Duration = Duration::from_secs(1);
214
215        loop {
216            match self.try_acquire_write_internal().await {
217                Ok(Some(handle)) => return Ok(handle),
218                Ok(None) => {
219                    // Check timeout
220                    if !timeout_value.is_infinite()
221                        && start.elapsed() >= timeout_value.as_duration().unwrap()
222                    {
223                        return Err(LockError::Timeout(timeout_value.as_duration().unwrap()));
224                    }
225
226                    // Sleep before retry
227                    tokio::time::sleep(sleep_duration).await;
228                    sleep_duration = (sleep_duration * 2).min(MAX_SLEEP);
229                }
230                Err(e) => return Err(e),
231            }
232        }
233    }
234
235    async fn try_acquire_write(&self) -> LockResult<Option<Self::WriteHandle>> {
236        self.try_acquire_write_internal().await
237    }
238}
239
240/// Handle for a held PostgreSQL read lock.
241pub struct PostgresReadLockHandle {
242    /// The database connection (when dropped, the lock is released).
243    _connection: Option<PostgresConnectionInner>,
244    /// The lock key for explicit unlock.
245    key: PostgresAdvisoryLockKey,
246    /// Watch channel for lock lost detection.
247    lost_receiver: watch::Receiver<bool>,
248}
249
250impl PostgresReadLockHandle {
251    pub(crate) fn new(
252        connection: PostgresConnectionInner,
253        key: PostgresAdvisoryLockKey,
254        _lost_sender: watch::Sender<bool>,
255        lost_receiver: watch::Receiver<bool>,
256        _keepalive_cadence: Option<Duration>,
257    ) -> Self {
258        Self {
259            _connection: Some(connection),
260            key,
261            lost_receiver,
262        }
263    }
264}
265
266impl LockHandle for PostgresReadLockHandle {
267    fn lost_token(&self) -> &watch::Receiver<bool> {
268        &self.lost_receiver
269    }
270
271    async fn release(mut self) -> LockResult<()> {
272        // Explicitly release the shared lock before dropping the connection
273        if let Some(connection) = self._connection.take() {
274            match connection {
275                PostgresConnectionInner::Connection(mut conn) => {
276                    let sql = format!(
277                        "SELECT pg_advisory_unlock_shared({})",
278                        self.key.to_sql_args()
279                    );
280                    let _ = sqlx::query(&sql).execute(&mut **conn).await;
281                }
282                PostgresConnectionInner::Transaction(transaction_ptr) => {
283                    // SAFETY: We created this pointer and it's still valid
284                    let transaction = unsafe { Box::from_raw(transaction_ptr) };
285                    if let Err(e) = transaction.rollback().await {
286                        tracing::warn!("Failed to rollback transaction: {}", e);
287                    }
288                    // Transaction is consumed by rollback(), so no need to drop it
289                }
290            }
291        }
292        Ok(())
293    }
294}
295
296/// Handle for a held PostgreSQL write lock.
297pub struct PostgresWriteLockHandle {
298    /// The database connection/transaction (when dropped, the lock is released).
299    _connection: Option<PostgresConnectionInner>,
300    /// The lock key for explicit unlock.
301    key: PostgresAdvisoryLockKey,
302    /// Watch channel for lock lost detection.
303    lost_receiver: watch::Receiver<bool>,
304}
305
306impl PostgresWriteLockHandle {
307    pub(crate) fn new(
308        connection: PostgresConnectionInner,
309        key: PostgresAdvisoryLockKey,
310        _lost_sender: watch::Sender<bool>,
311        lost_receiver: watch::Receiver<bool>,
312        _keepalive_cadence: Option<Duration>,
313    ) -> Self {
314        Self {
315            _connection: Some(connection),
316            key,
317            lost_receiver,
318        }
319    }
320}
321
322impl LockHandle for PostgresWriteLockHandle {
323    fn lost_token(&self) -> &watch::Receiver<bool> {
324        &self.lost_receiver
325    }
326
327    async fn release(mut self) -> LockResult<()> {
328        // Explicitly release the exclusive lock before dropping the connection
329        if let Some(connection) = self._connection.take() {
330            match connection {
331                PostgresConnectionInner::Connection(mut conn) => {
332                    let sql = format!("SELECT pg_advisory_unlock({})", self.key.to_sql_args());
333                    let _ = sqlx::query(&sql).execute(&mut **conn).await;
334                }
335                PostgresConnectionInner::Transaction(transaction_ptr) => {
336                    // SAFETY: We created this pointer and it's still valid
337                    let transaction = unsafe { Box::from_raw(transaction_ptr) };
338                    if let Err(e) = transaction.rollback().await {
339                        tracing::warn!("Failed to rollback transaction: {}", e);
340                    }
341                    // Transaction is consumed by rollback(), so no need to drop it
342                }
343            }
344        }
345        Ok(())
346    }
347}