Skip to main content

entelix_persistence/postgres/
lock.rs

//! `PostgresLock` — [`DistributedLock`] over Postgres advisory locks
//! (`pg_try_advisory_lock` / `pg_advisory_unlock`).
//!
//! Holds a session-scoped Postgres advisory lock. Each acquisition pins
//! a `PoolConnection` to its [`LockGuard`] (via an internal lookup map
//! keyed by the guard's token) so the lock survives until the holder
//! calls `release`, at which point the connection leaves the map. Only
//! `release` removes entries: a guard dropped without `release` leaves
//! its connection, and therefore the lock, pinned in that map.
//!
//! TTLs are ignored by this backend because Postgres advisory locks never
//! auto-expire. Callers that need true expiry use the Redis backend.
9
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use async_trait::async_trait;
14use dashmap::DashMap;
15use sqlx::pool::PoolConnection;
16use sqlx::postgres::{PgPool, Postgres};
17use tokio::time::sleep;
18
19use crate::advisory_key::AdvisoryKey;
20use crate::error::{PersistenceError, PersistenceResult};
21use crate::lock::{DistributedLock, LockGuard};
22
/// Pause between successive non-blocking lock attempts made by
/// `PostgresLock::acquire` while it waits for a contended key.
const POLL_INTERVAL: Duration = Duration::from_millis(50);
24
25/// Postgres-backed distributed lock.
26pub struct PostgresLock {
27    pool: Arc<PgPool>,
28    held: Arc<DashMap<String, PoolConnection<Postgres>>>,
29}
30
31impl PostgresLock {
32    pub(crate) fn new(pool: Arc<PgPool>) -> Self {
33        Self {
34            pool,
35            held: Arc::new(DashMap::new()),
36        }
37    }
38}
39
40#[async_trait]
41impl DistributedLock for PostgresLock {
42    async fn try_acquire(
43        &self,
44        key: &AdvisoryKey,
45        _ttl: Duration,
46    ) -> PersistenceResult<Option<LockGuard>> {
47        let mut conn = self
48            .pool
49            .acquire()
50            .await
51            .map_err(|e| PersistenceError::Backend(format!("pool acquire: {e}")))?;
52
53        let (high, low) = key.halves();
54        let acquired: (bool,) = sqlx::query_as("SELECT pg_try_advisory_lock($1, $2)")
55            .bind(high)
56            .bind(low)
57            .fetch_one(&mut *conn)
58            .await
59            .map_err(backend_err)?;
60
61        if !acquired.0 {
62            // Connection drops back to the pool implicitly.
63            return Ok(None);
64        }
65        let guard = LockGuard::new(*key);
66        self.held.insert(guard.token().to_owned(), conn);
67        Ok(Some(guard))
68    }
69
70    async fn acquire(
71        &self,
72        key: &AdvisoryKey,
73        ttl: Duration,
74        deadline: Duration,
75    ) -> PersistenceResult<LockGuard> {
76        let start = Instant::now();
77        let mut attempts: u32 = 0;
78        loop {
79            attempts = attempts.saturating_add(1);
80            if let Some(guard) = self.try_acquire(key, ttl).await? {
81                return Ok(guard);
82            }
83            if start.elapsed() >= deadline {
84                return Err(PersistenceError::LockAcquireTimeout {
85                    key: key.to_string(),
86                    attempts,
87                });
88            }
89            sleep(POLL_INTERVAL).await;
90        }
91    }
92
93    async fn extend(&self, _guard: &LockGuard, _ttl: Duration) -> PersistenceResult<bool> {
94        // Postgres advisory locks don't expire — extend is a no-op
95        // that succeeds when the lock is still tracked.
96        Ok(true)
97    }
98
99    async fn release(&self, mut guard: LockGuard) -> PersistenceResult<()> {
100        let Some((_, mut conn)) = self.held.remove(guard.token()) else {
101            // Already released or never tracked. Mark released so
102            // Drop doesn't warn.
103            guard.mark_released();
104            return Ok(());
105        };
106        let (high, low) = guard.key().halves();
107        let _: (bool,) = sqlx::query_as("SELECT pg_advisory_unlock($1, $2)")
108            .bind(high)
109            .bind(low)
110            .fetch_one(&mut *conn)
111            .await
112            .map_err(backend_err)?;
113        guard.mark_released();
114        // Connection drops back to the pool implicitly.
115        Ok(())
116    }
117}
118
119fn backend_err(e: sqlx::Error) -> PersistenceError {
120    PersistenceError::Backend(e.to_string())
121}