distributed_lock_postgres/
lock.rs

1//! PostgreSQL distributed lock implementation.
2
3use std::time::Duration;
4
5use distributed_lock_core::error::{LockError, LockResult};
6use distributed_lock_core::traits::DistributedLock;
7use tokio::sync::watch;
8use tracing::{Span, instrument};
9
10use crate::handle::PostgresLockHandle;
11use crate::key::PostgresAdvisoryLockKey;
12use sqlx::{Executor, PgPool, Row};
13
/// A PostgreSQL-based distributed lock.
///
/// Every acquisition attempt checks a connection out of `pool` and calls one
/// of PostgreSQL's advisory-lock functions with `key` rendered as the SQL
/// arguments. On success the connection is handed to the returned
/// `PostgresLockHandle`.
pub struct PostgresDistributedLock {
    /// The advisory lock key; rendered into the lock function's arguments via
    /// `to_sql_args()`.
    key: PostgresAdvisoryLockKey,
    /// Original lock name, returned unchanged by `name()`.
    name: String,
    /// Connection pool from which a connection is acquired per attempt.
    pool: PgPool,
    /// Whether to use transaction-scoped locks (`pg_advisory_xact_lock` /
    /// `pg_try_advisory_xact_lock`) instead of the session-scoped variants.
    use_transaction: bool,
    /// Keepalive cadence for long-held locks; forwarded as-is to each handle.
    keepalive_cadence: Option<Duration>,
}
27
28impl PostgresDistributedLock {
29    pub(crate) fn new(
30        name: String,
31        key: PostgresAdvisoryLockKey,
32        pool: PgPool,
33        use_transaction: bool,
34        keepalive_cadence: Option<Duration>,
35    ) -> Self {
36        Self {
37            key,
38            name,
39            pool,
40            use_transaction,
41            keepalive_cadence,
42        }
43    }
44
45    /// Attempts to acquire the lock.
46    async fn acquire_internal(
47        &self,
48        timeout: Option<Duration>,
49    ) -> LockResult<Option<PostgresLockHandle>> {
50        let mut conn = self.pool.acquire().await.map_err(|e| {
51            LockError::Connection(Box::new(std::io::Error::other(format!(
52                "failed to get connection from pool: {e}"
53            ))))
54        })?;
55
56        // Always start a transaction to ensure SET LOCAL lock_timeout applies to the lock command
57        conn.execute("BEGIN").await.map_err(|e| {
58            LockError::Connection(Box::new(std::io::Error::other(format!(
59                "failed to start transaction: {e}"
60            ))))
61        })?;
62
63        let use_transaction_lock = self.use_transaction;
64        let savepoint_name = "medallion_lock_acquire";
65
66        // We always use a savepoint to robustly handle errors without aborting the main transaction
67        let sql = format!("SAVEPOINT {}", savepoint_name);
68        conn.execute(sql.as_str()).await.map_err(|e| {
69            LockError::Backend(Box::new(std::io::Error::other(format!(
70                "failed to create savepoint: {e}"
71            ))))
72        })?;
73
74        // Set timeout
75        let timeout_ms = timeout.map(|d| d.as_millis() as i64).unwrap_or(0);
76        let set_timeout_sql = format!("SET LOCAL lock_timeout = {}", timeout_ms);
77        if let Err(e) = conn.execute(set_timeout_sql.as_str()).await {
78            // Rollback savepoint if setting timeout fails
79            let _ = conn
80                .execute(format!("ROLLBACK TO SAVEPOINT {}", savepoint_name).as_str())
81                .await;
82
83            // If we aren't using transaction locks, we should rollback the whole thing
84            if !use_transaction_lock {
85                let _ = conn.execute("ROLLBACK").await;
86            }
87
88            return Err(LockError::Backend(Box::new(std::io::Error::other(
89                format!("failed to set lock_timeout: {e}"),
90            ))));
91        }
92
93        let lock_func = if use_transaction_lock {
94            "pg_advisory_xact_lock"
95        } else {
96            "pg_advisory_lock"
97        };
98
99        let sql = format!("SELECT {}({})", lock_func, self.key.to_sql_args());
100
101        match conn.fetch_one(sql.as_str()).await {
102            Ok(_) => {
103                if !use_transaction_lock {
104                    // For session locks, we must COMMIT the transaction to release the "SET LOCAL" params
105                    // and allow the connection to be used normally, BUT the lock persists (session scope).
106                    if let Err(e) = conn.execute("COMMIT").await {
107                        // If commit fails, we might have lost the lock or connection is bad
108                        return Err(LockError::Backend(Box::new(std::io::Error::other(
109                            format!("failed to commit transaction after locking: {e}"),
110                        ))));
111                    }
112                }
113
114                let (sender, receiver) = watch::channel(false);
115                Ok(Some(PostgresLockHandle::new(
116                    conn,
117                    use_transaction_lock,
118                    self.key,
119                    sender,
120                    receiver,
121                    self.keepalive_cadence,
122                )))
123            }
124            Err(e) => {
125                let db_err = e.as_database_error();
126                let code = db_err.and_then(|db_err| db_err.code()).unwrap_or_default();
127
128                // Rollback to savepoint to clear error state
129                let _ = conn
130                    .execute(format!("ROLLBACK TO SAVEPOINT {}", savepoint_name).as_str())
131                    .await;
132
133                // For session locks, since we failed, rollback the whole transaction setup
134                if !use_transaction_lock {
135                    let _ = conn.execute("ROLLBACK").await;
136                }
137
138                if code == "55P03" {
139                    return Ok(None); // Timeout -> None
140                }
141                if code == "40P01" {
142                    return Err(LockError::Deadlock(
143                        "deadlock detected by postgres".to_string(),
144                    ));
145                }
146
147                Err(LockError::Backend(Box::new(std::io::Error::other(
148                    format!("failed to acquire lock: {e}"),
149                ))))
150            }
151        }
152    }
153
154    /// Attempts to acquire the lock immediately (try_lock).
155    async fn try_acquire_internal_immediate(&self) -> LockResult<Option<PostgresLockHandle>> {
156        let mut conn = self.pool.acquire().await.map_err(|e| {
157            LockError::Connection(Box::new(std::io::Error::other(format!(
158                "failed to get connection from pool: {e}"
159            ))))
160        })?;
161
162        let use_transaction = self.use_transaction;
163        if use_transaction {
164            conn.execute("BEGIN").await.map_err(|e| {
165                LockError::Connection(Box::new(std::io::Error::other(format!(
166                    "failed to start transaction: {e}"
167                ))))
168            })?;
169        }
170
171        let lock_func = if use_transaction {
172            "pg_try_advisory_xact_lock"
173        } else {
174            "pg_try_advisory_lock"
175        };
176
177        let sql = format!("SELECT {}({})", lock_func, self.key.to_sql_args());
178        let row = conn.fetch_one(sql.as_str()).await.map_err(|e| {
179            LockError::Backend(Box::new(std::io::Error::other(format!(
180                "failed to try_acquire lock: {e}"
181            ))))
182        })?;
183
184        let acquired: bool = row.get(0);
185        if !acquired {
186            // cleanup logic if needed?
187            if use_transaction {
188                let _ = conn.execute("ROLLBACK").await;
189            }
190            return Ok(None);
191        }
192
193        let (sender, receiver) = watch::channel(false);
194        Ok(Some(PostgresLockHandle::new(
195            conn,
196            use_transaction,
197            self.key,
198            sender,
199            receiver,
200            self.keepalive_cadence,
201        )))
202    }
203}
204
205impl DistributedLock for PostgresDistributedLock {
206    type Handle = PostgresLockHandle;
207
208    fn name(&self) -> &str {
209        &self.name
210    }
211
212    #[instrument(skip(self), fields(lock.name = %self.name, timeout = ?timeout, backend = "postgres", use_transaction = self.use_transaction))]
213    async fn acquire(&self, timeout: Option<Duration>) -> LockResult<Self::Handle> {
214        Span::current().record("operation", "acquire");
215
216        // Use the blocking implementation
217        match self.acquire_internal(timeout).await {
218            Ok(Some(handle)) => {
219                Span::current().record("acquired", true);
220                Ok(handle)
221            }
222            Ok(None) => {
223                Span::current().record("acquired", false);
224                Span::current().record("error", "timeout");
225                Err(LockError::Timeout(timeout.unwrap_or(Duration::MAX)))
226            }
227            Err(e) => Err(e),
228        }
229    }
230
231    #[instrument(skip(self), fields(lock.name = %self.name, backend = "postgres", use_transaction = self.use_transaction))]
232    async fn try_acquire(&self) -> LockResult<Option<Self::Handle>> {
233        Span::current().record("operation", "try_acquire");
234        match self.try_acquire_internal_immediate().await {
235            Ok(Some(handle)) => {
236                Span::current().record("acquired", true);
237                Ok(Some(handle))
238            }
239            Ok(None) => {
240                Span::current().record("acquired", false);
241                Ok(None)
242            }
243            Err(e) => Err(e),
244        }
245    }
246}