distributed_lock_postgres/
lock.rs1use std::time::Duration;
4
5use distributed_lock_core::error::{LockError, LockResult};
6use distributed_lock_core::traits::DistributedLock;
7use tokio::sync::watch;
8use tracing::{Span, instrument};
9
10use crate::handle::PostgresLockHandle;
11use crate::key::PostgresAdvisoryLockKey;
12use sqlx::{Executor, PgPool, Row};
13
14pub struct PostgresDistributedLock {
16 key: PostgresAdvisoryLockKey,
18 name: String,
20 pool: PgPool,
22 use_transaction: bool,
24 keepalive_cadence: Option<Duration>,
26}
27
28impl PostgresDistributedLock {
29 pub(crate) fn new(
30 name: String,
31 key: PostgresAdvisoryLockKey,
32 pool: PgPool,
33 use_transaction: bool,
34 keepalive_cadence: Option<Duration>,
35 ) -> Self {
36 Self {
37 key,
38 name,
39 pool,
40 use_transaction,
41 keepalive_cadence,
42 }
43 }
44
45 async fn acquire_internal(
47 &self,
48 timeout: Option<Duration>,
49 ) -> LockResult<Option<PostgresLockHandle>> {
50 let mut conn = self.pool.acquire().await.map_err(|e| {
51 LockError::Connection(Box::new(std::io::Error::other(format!(
52 "failed to get connection from pool: {e}"
53 ))))
54 })?;
55
56 conn.execute("BEGIN").await.map_err(|e| {
58 LockError::Connection(Box::new(std::io::Error::other(format!(
59 "failed to start transaction: {e}"
60 ))))
61 })?;
62
63 let use_transaction_lock = self.use_transaction;
64 let savepoint_name = "medallion_lock_acquire";
65
66 let sql = format!("SAVEPOINT {}", savepoint_name);
68 conn.execute(sql.as_str()).await.map_err(|e| {
69 LockError::Backend(Box::new(std::io::Error::other(format!(
70 "failed to create savepoint: {e}"
71 ))))
72 })?;
73
74 let timeout_ms = timeout.map(|d| d.as_millis() as i64).unwrap_or(0);
76 let set_timeout_sql = format!("SET LOCAL lock_timeout = {}", timeout_ms);
77 if let Err(e) = conn.execute(set_timeout_sql.as_str()).await {
78 let _ = conn
80 .execute(format!("ROLLBACK TO SAVEPOINT {}", savepoint_name).as_str())
81 .await;
82
83 if !use_transaction_lock {
85 let _ = conn.execute("ROLLBACK").await;
86 }
87
88 return Err(LockError::Backend(Box::new(std::io::Error::other(
89 format!("failed to set lock_timeout: {e}"),
90 ))));
91 }
92
93 let lock_func = if use_transaction_lock {
94 "pg_advisory_xact_lock"
95 } else {
96 "pg_advisory_lock"
97 };
98
99 let sql = format!("SELECT {}({})", lock_func, self.key.to_sql_args());
100
101 match conn.fetch_one(sql.as_str()).await {
102 Ok(_) => {
103 if !use_transaction_lock {
104 if let Err(e) = conn.execute("COMMIT").await {
107 return Err(LockError::Backend(Box::new(std::io::Error::other(
109 format!("failed to commit transaction after locking: {e}"),
110 ))));
111 }
112 }
113
114 let (sender, receiver) = watch::channel(false);
115 Ok(Some(PostgresLockHandle::new(
116 conn,
117 use_transaction_lock,
118 self.key,
119 sender,
120 receiver,
121 self.keepalive_cadence,
122 )))
123 }
124 Err(e) => {
125 let db_err = e.as_database_error();
126 let code = db_err.and_then(|db_err| db_err.code()).unwrap_or_default();
127
128 let _ = conn
130 .execute(format!("ROLLBACK TO SAVEPOINT {}", savepoint_name).as_str())
131 .await;
132
133 if !use_transaction_lock {
135 let _ = conn.execute("ROLLBACK").await;
136 }
137
138 if code == "55P03" {
139 return Ok(None); }
141 if code == "40P01" {
142 return Err(LockError::Deadlock(
143 "deadlock detected by postgres".to_string(),
144 ));
145 }
146
147 Err(LockError::Backend(Box::new(std::io::Error::other(
148 format!("failed to acquire lock: {e}"),
149 ))))
150 }
151 }
152 }
153
154 async fn try_acquire_internal_immediate(&self) -> LockResult<Option<PostgresLockHandle>> {
156 let mut conn = self.pool.acquire().await.map_err(|e| {
157 LockError::Connection(Box::new(std::io::Error::other(format!(
158 "failed to get connection from pool: {e}"
159 ))))
160 })?;
161
162 let use_transaction = self.use_transaction;
163 if use_transaction {
164 conn.execute("BEGIN").await.map_err(|e| {
165 LockError::Connection(Box::new(std::io::Error::other(format!(
166 "failed to start transaction: {e}"
167 ))))
168 })?;
169 }
170
171 let lock_func = if use_transaction {
172 "pg_try_advisory_xact_lock"
173 } else {
174 "pg_try_advisory_lock"
175 };
176
177 let sql = format!("SELECT {}({})", lock_func, self.key.to_sql_args());
178 let row = conn.fetch_one(sql.as_str()).await.map_err(|e| {
179 LockError::Backend(Box::new(std::io::Error::other(format!(
180 "failed to try_acquire lock: {e}"
181 ))))
182 })?;
183
184 let acquired: bool = row.get(0);
185 if !acquired {
186 if use_transaction {
188 let _ = conn.execute("ROLLBACK").await;
189 }
190 return Ok(None);
191 }
192
193 let (sender, receiver) = watch::channel(false);
194 Ok(Some(PostgresLockHandle::new(
195 conn,
196 use_transaction,
197 self.key,
198 sender,
199 receiver,
200 self.keepalive_cadence,
201 )))
202 }
203}
204
205impl DistributedLock for PostgresDistributedLock {
206 type Handle = PostgresLockHandle;
207
208 fn name(&self) -> &str {
209 &self.name
210 }
211
212 #[instrument(skip(self), fields(lock.name = %self.name, timeout = ?timeout, backend = "postgres", use_transaction = self.use_transaction))]
213 async fn acquire(&self, timeout: Option<Duration>) -> LockResult<Self::Handle> {
214 Span::current().record("operation", "acquire");
215
216 match self.acquire_internal(timeout).await {
218 Ok(Some(handle)) => {
219 Span::current().record("acquired", true);
220 Ok(handle)
221 }
222 Ok(None) => {
223 Span::current().record("acquired", false);
224 Span::current().record("error", "timeout");
225 Err(LockError::Timeout(timeout.unwrap_or(Duration::MAX)))
226 }
227 Err(e) => Err(e),
228 }
229 }
230
231 #[instrument(skip(self), fields(lock.name = %self.name, backend = "postgres", use_transaction = self.use_transaction))]
232 async fn try_acquire(&self) -> LockResult<Option<Self::Handle>> {
233 Span::current().record("operation", "try_acquire");
234 match self.try_acquire_internal_immediate().await {
235 Ok(Some(handle)) => {
236 Span::current().record("acquired", true);
237 Ok(Some(handle))
238 }
239 Ok(None) => {
240 Span::current().record("acquired", false);
241 Ok(None)
242 }
243 Err(e) => Err(e),
244 }
245 }
246}