distributed_lock_postgres/
rw_lock.rs

1//! PostgreSQL reader-writer lock implementation.
2
3use std::time::Duration;
4
5use distributed_lock_core::error::{LockError, LockResult};
6use distributed_lock_core::traits::{DistributedReaderWriterLock, LockHandle};
7use tokio::sync::watch;
8use tracing::{Span, instrument};
9
10use crate::key::PostgresAdvisoryLockKey;
11use sqlx::pool::PoolConnection;
12use sqlx::{Executor, PgPool, Postgres, Row};
13
14/// A PostgreSQL-based distributed reader-writer lock.
15pub struct PostgresDistributedReaderWriterLock {
16    /// The lock key.
17    key: PostgresAdvisoryLockKey,
18    /// Original lock name.
19    name: String,
20    /// Connection pool.
21    pool: PgPool,
22    /// Whether to use transaction-scoped locks.
23    use_transaction: bool,
24    /// Keepalive cadence for long-held locks.
25    keepalive_cadence: Option<Duration>,
26}
27
28impl PostgresDistributedReaderWriterLock {
29    pub(crate) fn new(
30        name: String,
31        key: PostgresAdvisoryLockKey,
32        pool: PgPool,
33        use_transaction: bool,
34        keepalive_cadence: Option<Duration>,
35    ) -> Self {
36        Self {
37            key,
38            name,
39            pool,
40            use_transaction,
41            keepalive_cadence,
42        }
43    }
44
45    async fn acquire_internal<H, F>(
46        &self,
47        timeout: Option<Duration>,
48        lock_func_shared: bool,
49        constructor: F,
50    ) -> LockResult<Option<H>>
51    where
52        F: FnOnce(
53            PoolConnection<Postgres>,
54            bool,
55            PostgresAdvisoryLockKey,
56            watch::Sender<bool>,
57            watch::Receiver<bool>,
58        ) -> H,
59    {
60        let mut conn = self.pool.acquire().await.map_err(|e| {
61            LockError::Connection(Box::new(std::io::Error::other(format!(
62                "failed to get connection from pool: {e}"
63            ))))
64        })?;
65
66        // Always start transaction to scope SET LOCAL
67        conn.execute("BEGIN").await.map_err(|e| {
68            LockError::Connection(Box::new(std::io::Error::other(format!(
69                "failed to start transaction: {e}"
70            ))))
71        })?;
72
73        let use_transaction_lock = self.use_transaction;
74        let savepoint_name = "medallion_rwlock_acquire";
75
76        let sql = format!("SAVEPOINT {}", savepoint_name);
77        conn.execute(sql.as_str()).await.map_err(|e| {
78            LockError::Backend(Box::new(std::io::Error::other(format!(
79                "failed to create savepoint: {e}"
80            ))))
81        })?;
82
83        let timeout_ms = timeout.map(|d| d.as_millis() as i64).unwrap_or(0);
84        let set_timeout_sql = format!("SET LOCAL lock_timeout = {}", timeout_ms);
85        if let Err(e) = conn.execute(set_timeout_sql.as_str()).await {
86            let _ = conn
87                .execute(format!("ROLLBACK TO SAVEPOINT {}", savepoint_name).as_str())
88                .await;
89
90            if !use_transaction_lock {
91                let _ = conn.execute("ROLLBACK").await;
92            }
93            return Err(LockError::Backend(Box::new(std::io::Error::other(
94                format!("failed to set lock_timeout: {e}"),
95            ))));
96        }
97
98        let lock_func = match (use_transaction_lock, lock_func_shared) {
99            (true, true) => "pg_advisory_xact_lock_shared",
100            (true, false) => "pg_advisory_xact_lock",
101            (false, true) => "pg_advisory_lock_shared",
102            (false, false) => "pg_advisory_lock",
103        };
104
105        let sql = format!("SELECT {}({})", lock_func, self.key.to_sql_args());
106
107        match conn.fetch_one(sql.as_str()).await {
108            Ok(_) => {
109                if !use_transaction_lock {
110                    // Commit to persist session lock but close transaction
111                    if let Err(e) = conn.execute("COMMIT").await {
112                        return Err(LockError::Backend(Box::new(std::io::Error::other(
113                            format!("failed to commit transaction after locking: {e}"),
114                        ))));
115                    }
116                }
117
118                let (sender, receiver) = watch::channel(false);
119                Ok(Some(constructor(
120                    conn,
121                    use_transaction_lock,
122                    self.key,
123                    sender,
124                    receiver,
125                )))
126            }
127            Err(e) => {
128                let db_err = e.as_database_error();
129                let code = db_err.and_then(|db_err| db_err.code()).unwrap_or_default();
130
131                let _ = conn
132                    .execute(format!("ROLLBACK TO SAVEPOINT {}", savepoint_name).as_str())
133                    .await;
134                if !use_transaction_lock {
135                    let _ = conn.execute("ROLLBACK").await;
136                }
137
138                if code == "55P03" {
139                    return Ok(None);
140                }
141                if code == "40P01" {
142                    return Err(LockError::Deadlock(
143                        "deadlock detected by postgres".to_string(),
144                    ));
145                }
146
147                Err(LockError::Backend(Box::new(std::io::Error::other(
148                    format!("failed to acquire lock: {e}"),
149                ))))
150            }
151        }
152    }
153
154    async fn try_acquire_internal_immediate<H, F>(
155        &self,
156        lock_func_shared: bool,
157        constructor: F,
158    ) -> LockResult<Option<H>>
159    where
160        F: FnOnce(
161            PoolConnection<Postgres>,
162            bool,
163            PostgresAdvisoryLockKey,
164            watch::Sender<bool>,
165            watch::Receiver<bool>,
166        ) -> H,
167    {
168        let mut conn = self.pool.acquire().await.map_err(|e| {
169            LockError::Connection(Box::new(std::io::Error::other(format!(
170                "failed to get connection from pool: {e}"
171            ))))
172        })?;
173
174        let use_transaction = self.use_transaction;
175        if use_transaction {
176            conn.execute("BEGIN").await.map_err(|e| {
177                LockError::Connection(Box::new(std::io::Error::other(format!(
178                    "failed to start transaction: {e}"
179                ))))
180            })?;
181        }
182
183        let lock_func = match (use_transaction, lock_func_shared) {
184            (true, true) => "pg_try_advisory_xact_lock_shared",
185            (true, false) => "pg_try_advisory_xact_lock",
186            (false, true) => "pg_try_advisory_lock_shared",
187            (false, false) => "pg_try_advisory_lock",
188        };
189
190        let sql = format!("SELECT {}({})", lock_func, self.key.to_sql_args());
191        let row = conn.fetch_one(sql.as_str()).await.map_err(|e| {
192            LockError::Backend(Box::new(std::io::Error::other(format!(
193                "failed to try_acquire lock: {e}"
194            ))))
195        })?;
196
197        let acquired: bool = row.get(0);
198        if !acquired {
199            if use_transaction {
200                let _ = conn.execute("ROLLBACK").await;
201            }
202            return Ok(None);
203        }
204
205        let (sender, receiver) = watch::channel(false);
206        Ok(Some(constructor(
207            conn,
208            use_transaction,
209            self.key,
210            sender,
211            receiver,
212        )))
213    }
214}
215
216impl DistributedReaderWriterLock for PostgresDistributedReaderWriterLock {
217    type ReadHandle = PostgresReadLockHandle;
218    type WriteHandle = PostgresWriteLockHandle;
219
220    fn name(&self) -> &str {
221        &self.name
222    }
223
224    #[instrument(skip(self), fields(lock.name = %self.name, timeout = ?timeout, backend = "postgres", use_transaction = self.use_transaction))]
225    async fn acquire_read(&self, timeout: Option<Duration>) -> LockResult<Self::ReadHandle> {
226        Span::current().record("operation", "acquire_read");
227        match self
228            .acquire_internal(timeout, true, |c, t, k, s, r| {
229                PostgresReadLockHandle::new(c, t, k, s, r, self.keepalive_cadence)
230            })
231            .await
232        {
233            Ok(Some(handle)) => {
234                Span::current().record("acquired", true);
235                Ok(handle)
236            }
237            Ok(None) => {
238                Span::current().record("acquired", false);
239                Span::current().record("error", "timeout");
240                Err(LockError::Timeout(timeout.unwrap_or(Duration::MAX)))
241            }
242            Err(e) => Err(e),
243        }
244    }
245
246    #[instrument(skip(self), fields(lock.name = %self.name, backend = "postgres", use_transaction = self.use_transaction))]
247    async fn try_acquire_read(&self) -> LockResult<Option<Self::ReadHandle>> {
248        Span::current().record("operation", "try_acquire_read");
249        match self
250            .try_acquire_internal_immediate(true, |c, t, k, s, r| {
251                PostgresReadLockHandle::new(c, t, k, s, r, self.keepalive_cadence)
252            })
253            .await
254        {
255            Ok(Some(handle)) => {
256                Span::current().record("acquired", true);
257                Ok(Some(handle))
258            }
259            Ok(None) => {
260                Span::current().record("acquired", false);
261                Ok(None)
262            }
263            Err(e) => Err(e),
264        }
265    }
266
267    #[instrument(skip(self), fields(lock.name = %self.name, timeout = ?timeout, backend = "postgres", use_transaction = self.use_transaction))]
268    async fn acquire_write(&self, timeout: Option<Duration>) -> LockResult<Self::WriteHandle> {
269        Span::current().record("operation", "acquire_write");
270        match self
271            .acquire_internal(timeout, false, |c, t, k, s, r| {
272                PostgresWriteLockHandle::new(c, t, k, s, r, self.keepalive_cadence)
273            })
274            .await
275        {
276            Ok(Some(handle)) => {
277                Span::current().record("acquired", true);
278                Ok(handle)
279            }
280            Ok(None) => {
281                Span::current().record("acquired", false);
282                Span::current().record("error", "timeout");
283                Err(LockError::Timeout(timeout.unwrap_or(Duration::MAX)))
284            }
285            Err(e) => Err(e),
286        }
287    }
288
289    #[instrument(skip(self), fields(lock.name = %self.name, backend = "postgres", use_transaction = self.use_transaction))]
290    async fn try_acquire_write(&self) -> LockResult<Option<Self::WriteHandle>> {
291        Span::current().record("operation", "try_acquire_write");
292        match self
293            .try_acquire_internal_immediate(false, |c, t, k, s, r| {
294                PostgresWriteLockHandle::new(c, t, k, s, r, self.keepalive_cadence)
295            })
296            .await
297        {
298            Ok(Some(handle)) => {
299                Span::current().record("acquired", true);
300                Ok(Some(handle))
301            }
302            Ok(None) => {
303                Span::current().record("acquired", false);
304                Ok(None)
305            }
306            Err(e) => Err(e),
307        }
308    }
309}
310
311/// Handle for a held PostgreSQL read lock.
312pub struct PostgresReadLockHandle {
313    conn: Option<PoolConnection<Postgres>>,
314    is_transaction: bool,
315    key: PostgresAdvisoryLockKey,
316    lost_receiver: watch::Receiver<bool>,
317    _monitor_task: tokio::task::JoinHandle<()>,
318}
319
320impl PostgresReadLockHandle {
321    pub(crate) fn new(
322        conn: PoolConnection<Postgres>,
323        is_transaction: bool,
324        key: PostgresAdvisoryLockKey,
325        _lost_sender: watch::Sender<bool>,
326        lost_receiver: watch::Receiver<bool>,
327        _keepalive_cadence: Option<Duration>,
328    ) -> Self {
329        let monitor_task = tokio::spawn(async move {});
330        Self {
331            conn: Some(conn),
332            is_transaction,
333            key,
334            lost_receiver,
335            _monitor_task: monitor_task,
336        }
337    }
338}
339
340impl LockHandle for PostgresReadLockHandle {
341    fn lost_token(&self) -> &watch::Receiver<bool> {
342        &self.lost_receiver
343    }
344
345    async fn release(mut self) -> LockResult<()> {
346        if let Some(mut conn) = self.conn.take() {
347            if self.is_transaction {
348                match conn.execute("ROLLBACK").await {
349                    Ok(_) => tracing::debug!("Transaction rolled back successfully"),
350                    Err(e) => tracing::warn!("Failed to rollback transaction: {}", e),
351                }
352            } else {
353                let sql = format!(
354                    "SELECT pg_advisory_unlock_shared({})",
355                    self.key.to_sql_args()
356                );
357                if let Err(e) = conn.execute(sql.as_str()).await {
358                    tracing::warn!("Failed to release read lock explicitly: {}", e);
359                }
360            }
361        }
362        Ok(())
363    }
364}
365
366impl Drop for PostgresReadLockHandle {
367    fn drop(&mut self) {
368        self._monitor_task.abort();
369    }
370}
371
372/// Handle for a held PostgreSQL write lock.
373pub struct PostgresWriteLockHandle {
374    conn: Option<PoolConnection<Postgres>>,
375    is_transaction: bool,
376    key: PostgresAdvisoryLockKey,
377    lost_receiver: watch::Receiver<bool>,
378    _monitor_task: tokio::task::JoinHandle<()>,
379}
380
381impl PostgresWriteLockHandle {
382    pub(crate) fn new(
383        conn: PoolConnection<Postgres>,
384        is_transaction: bool,
385        key: PostgresAdvisoryLockKey,
386        _lost_sender: watch::Sender<bool>,
387        lost_receiver: watch::Receiver<bool>,
388        _keepalive_cadence: Option<Duration>,
389    ) -> Self {
390        let monitor_task = tokio::spawn(async move {});
391        Self {
392            conn: Some(conn),
393            is_transaction,
394            key,
395            lost_receiver,
396            _monitor_task: monitor_task,
397        }
398    }
399}
400
401impl LockHandle for PostgresWriteLockHandle {
402    fn lost_token(&self) -> &watch::Receiver<bool> {
403        &self.lost_receiver
404    }
405
406    async fn release(mut self) -> LockResult<()> {
407        if let Some(mut conn) = self.conn.take() {
408            if self.is_transaction {
409                match conn.execute("ROLLBACK").await {
410                    Ok(_) => tracing::debug!("Transaction rolled back successfully"),
411                    Err(e) => tracing::warn!("Failed to rollback transaction: {}", e),
412                }
413            } else {
414                let sql = format!("SELECT pg_advisory_unlock({})", self.key.to_sql_args());
415                if let Err(e) = conn.execute(sql.as_str()).await {
416                    tracing::warn!("Failed to release write lock explicitly: {}", e);
417                }
418            }
419        }
420        Ok(())
421    }
422}
423
424impl Drop for PostgresWriteLockHandle {
425    fn drop(&mut self) {
426        self._monitor_task.abort();
427    }
428}