distributed_lock_postgres/
lock.rs

1//! PostgreSQL distributed lock implementation.
2
3use std::time::Duration;
4
5use distributed_lock_core::error::{LockError, LockResult};
6use distributed_lock_core::timeout::TimeoutValue;
7use distributed_lock_core::traits::DistributedLock;
8use tokio::sync::watch;
9use tracing::{instrument, Span};
10
11use crate::handle::PostgresLockHandle;
12use crate::key::PostgresAdvisoryLockKey;
13use sqlx::{PgPool, Postgres, Row, Transaction};
14
15/// A PostgreSQL-based distributed lock.
16pub struct PostgresDistributedLock {
17    /// The lock key.
18    key: PostgresAdvisoryLockKey,
19    /// Original lock name.
20    name: String,
21    /// Connection pool.
22    pool: PgPool,
23    /// Whether to use transaction-scoped locks.
24    use_transaction: bool,
25    /// Keepalive cadence for long-held locks.
26    keepalive_cadence: Option<Duration>,
27}
28
29impl PostgresDistributedLock {
30    pub(crate) fn new(
31        name: String,
32        key: PostgresAdvisoryLockKey,
33        pool: PgPool,
34        use_transaction: bool,
35        keepalive_cadence: Option<Duration>,
36    ) -> Self {
37        Self {
38            key,
39            name,
40            pool,
41            use_transaction,
42            keepalive_cadence,
43        }
44    }
45
46    /// Attempts to acquire the lock without waiting.
47    async fn try_acquire_internal(&self) -> LockResult<Option<PostgresLockHandle>> {
48        if self.use_transaction {
49            // Transaction-scoped lock
50            let mut transaction = self.pool.begin().await.map_err(|e| {
51                LockError::Connection(Box::new(std::io::Error::other(format!(
52                    "failed to start transaction: {e}"
53                ))))
54            })?;
55
56            let sql = match self.key {
57                PostgresAdvisoryLockKey::Single(_) => {
58                    format!("SELECT pg_try_advisory_lock({})", self.key.to_sql_args())
59                }
60                PostgresAdvisoryLockKey::Pair(_, _) => {
61                    format!("SELECT pg_try_advisory_lock({})", self.key.to_sql_args())
62                }
63            };
64
65            let row = sqlx::query(&sql)
66                .fetch_one(&mut *transaction)
67                .await
68                .map_err(|e| {
69                    LockError::Backend(Box::new(std::io::Error::other(format!(
70                        "failed to acquire lock: {e}"
71                    ))))
72                })?;
73
74            let acquired: bool = row.get(0);
75            if !acquired {
76                return Ok(None);
77            }
78
79            // Store transaction using raw pointer to avoid lifetime issues
80            // SAFETY: We manually manage the transaction lifetime in the handle
81            let transaction_ptr = unsafe {
82                std::mem::transmute::<Transaction<'_, Postgres>, Transaction<'static, Postgres>>(
83                    transaction,
84                )
85            };
86            let transaction_ptr = Box::into_raw(Box::new(transaction_ptr));
87
88            let (sender, receiver) = watch::channel(false);
89            Ok(Some(PostgresLockHandle::new(
90                crate::handle::PostgresConnectionInner::Transaction(transaction_ptr),
91                self.key,
92                sender,
93                receiver,
94                self.keepalive_cadence,
95            )))
96        } else {
97            // Session-scoped lock
98            let mut connection = self.pool.acquire().await.map_err(|e| {
99                LockError::Connection(Box::new(std::io::Error::other(format!(
100                    "failed to get connection from pool: {e}"
101                ))))
102            })?;
103
104            let sql = match self.key {
105                PostgresAdvisoryLockKey::Single(_) => {
106                    format!("SELECT pg_try_advisory_lock({})", self.key.to_sql_args())
107                }
108                PostgresAdvisoryLockKey::Pair(_, _) => {
109                    format!("SELECT pg_try_advisory_lock({})", self.key.to_sql_args())
110                }
111            };
112
113            let row = sqlx::query(&sql)
114                .fetch_one(&mut *connection)
115                .await
116                .map_err(|e| {
117                    LockError::Backend(Box::new(std::io::Error::other(format!(
118                        "failed to acquire lock: {e}"
119                    ))))
120                })?;
121
122            let acquired: bool = row.get(0);
123            if !acquired {
124                return Ok(None);
125            }
126
127            // Store pool connection to keep it alive
128            // PoolConnection will be returned to pool when dropped
129
130            let (sender, receiver) = watch::channel(false);
131            Ok(Some(PostgresLockHandle::new(
132                crate::handle::PostgresConnectionInner::Connection(Box::new(connection)),
133                self.key,
134                sender,
135                receiver,
136                self.keepalive_cadence,
137            )))
138        }
139    }
140}
141
142impl DistributedLock for PostgresDistributedLock {
143    type Handle = PostgresLockHandle;
144
145    fn name(&self) -> &str {
146        &self.name
147    }
148
149    #[instrument(skip(self), fields(lock.name = %self.name, timeout = ?timeout, backend = "postgres", use_transaction = self.use_transaction))]
150    async fn acquire(&self, timeout: Option<Duration>) -> LockResult<Self::Handle> {
151        let timeout_value = TimeoutValue::from(timeout);
152        let start = std::time::Instant::now();
153        Span::current().record("operation", "acquire");
154
155        // Busy-wait with exponential backoff
156        let mut sleep_duration = Duration::from_millis(50);
157        const MAX_SLEEP: Duration = Duration::from_secs(1);
158
159        loop {
160            match self.try_acquire_internal().await {
161                Ok(Some(handle)) => {
162                    let elapsed = start.elapsed();
163                    Span::current().record("acquired", true);
164                    Span::current().record("elapsed_ms", elapsed.as_millis() as u64);
165                    return Ok(handle);
166                }
167                Ok(None) => {
168                    // Check timeout
169                    if !timeout_value.is_infinite()
170                        && start.elapsed() >= timeout_value.as_duration().unwrap()
171                    {
172                        Span::current().record("acquired", false);
173                        Span::current().record("error", "timeout");
174                        return Err(LockError::Timeout(timeout_value.as_duration().unwrap()));
175                    }
176
177                    // Sleep before retry
178                    tokio::time::sleep(sleep_duration).await;
179                    sleep_duration = (sleep_duration * 2).min(MAX_SLEEP);
180                }
181                Err(e) => return Err(e),
182            }
183        }
184    }
185
186    #[instrument(skip(self), fields(lock.name = %self.name, backend = "postgres", use_transaction = self.use_transaction))]
187    async fn try_acquire(&self) -> LockResult<Option<Self::Handle>> {
188        Span::current().record("operation", "try_acquire");
189        let result = self.try_acquire_internal().await;
190        match &result {
191            Ok(Some(_)) => {
192                Span::current().record("acquired", true);
193            }
194            Ok(None) => {
195                Span::current().record("acquired", false);
196                Span::current().record("reason", "lock_held");
197            }
198            Err(e) => {
199                Span::current().record("acquired", false);
200                Span::current().record("error", e.to_string());
201            }
202        }
203        result
204    }
205}