1use std::time::Duration;
4
5use distributed_lock_core::error::{LockError, LockResult};
6use distributed_lock_core::traits::{DistributedReaderWriterLock, LockHandle};
7use tokio::sync::watch;
8use tracing::{Span, instrument};
9
10use crate::key::PostgresAdvisoryLockKey;
11use sqlx::pool::PoolConnection;
12use sqlx::{Executor, PgPool, Postgres, Row};
13
14pub struct PostgresDistributedReaderWriterLock {
16 key: PostgresAdvisoryLockKey,
18 name: String,
20 pool: PgPool,
22 use_transaction: bool,
24 keepalive_cadence: Option<Duration>,
26}
27
28impl PostgresDistributedReaderWriterLock {
29 pub(crate) fn new(
30 name: String,
31 key: PostgresAdvisoryLockKey,
32 pool: PgPool,
33 use_transaction: bool,
34 keepalive_cadence: Option<Duration>,
35 ) -> Self {
36 Self {
37 key,
38 name,
39 pool,
40 use_transaction,
41 keepalive_cadence,
42 }
43 }
44
45 async fn acquire_internal<H, F>(
46 &self,
47 timeout: Option<Duration>,
48 lock_func_shared: bool,
49 constructor: F,
50 ) -> LockResult<Option<H>>
51 where
52 F: FnOnce(
53 PoolConnection<Postgres>,
54 bool,
55 PostgresAdvisoryLockKey,
56 watch::Sender<bool>,
57 watch::Receiver<bool>,
58 ) -> H,
59 {
60 let mut conn = self.pool.acquire().await.map_err(|e| {
61 LockError::Connection(Box::new(std::io::Error::other(format!(
62 "failed to get connection from pool: {e}"
63 ))))
64 })?;
65
66 conn.execute("BEGIN").await.map_err(|e| {
68 LockError::Connection(Box::new(std::io::Error::other(format!(
69 "failed to start transaction: {e}"
70 ))))
71 })?;
72
73 let use_transaction_lock = self.use_transaction;
74 let savepoint_name = "medallion_rwlock_acquire";
75
76 let sql = format!("SAVEPOINT {}", savepoint_name);
77 conn.execute(sql.as_str()).await.map_err(|e| {
78 LockError::Backend(Box::new(std::io::Error::other(format!(
79 "failed to create savepoint: {e}"
80 ))))
81 })?;
82
83 let timeout_ms = timeout.map(|d| d.as_millis() as i64).unwrap_or(0);
84 let set_timeout_sql = format!("SET LOCAL lock_timeout = {}", timeout_ms);
85 if let Err(e) = conn.execute(set_timeout_sql.as_str()).await {
86 let _ = conn
87 .execute(format!("ROLLBACK TO SAVEPOINT {}", savepoint_name).as_str())
88 .await;
89
90 if !use_transaction_lock {
91 let _ = conn.execute("ROLLBACK").await;
92 }
93 return Err(LockError::Backend(Box::new(std::io::Error::other(
94 format!("failed to set lock_timeout: {e}"),
95 ))));
96 }
97
98 let lock_func = match (use_transaction_lock, lock_func_shared) {
99 (true, true) => "pg_advisory_xact_lock_shared",
100 (true, false) => "pg_advisory_xact_lock",
101 (false, true) => "pg_advisory_lock_shared",
102 (false, false) => "pg_advisory_lock",
103 };
104
105 let sql = format!("SELECT {}({})", lock_func, self.key.to_sql_args());
106
107 match conn.fetch_one(sql.as_str()).await {
108 Ok(_) => {
109 if !use_transaction_lock {
110 if let Err(e) = conn.execute("COMMIT").await {
112 return Err(LockError::Backend(Box::new(std::io::Error::other(
113 format!("failed to commit transaction after locking: {e}"),
114 ))));
115 }
116 }
117
118 let (sender, receiver) = watch::channel(false);
119 Ok(Some(constructor(
120 conn,
121 use_transaction_lock,
122 self.key,
123 sender,
124 receiver,
125 )))
126 }
127 Err(e) => {
128 let db_err = e.as_database_error();
129 let code = db_err.and_then(|db_err| db_err.code()).unwrap_or_default();
130
131 let _ = conn
132 .execute(format!("ROLLBACK TO SAVEPOINT {}", savepoint_name).as_str())
133 .await;
134 if !use_transaction_lock {
135 let _ = conn.execute("ROLLBACK").await;
136 }
137
138 if code == "55P03" {
139 return Ok(None);
140 }
141 if code == "40P01" {
142 return Err(LockError::Deadlock(
143 "deadlock detected by postgres".to_string(),
144 ));
145 }
146
147 Err(LockError::Backend(Box::new(std::io::Error::other(
148 format!("failed to acquire lock: {e}"),
149 ))))
150 }
151 }
152 }
153
154 async fn try_acquire_internal_immediate<H, F>(
155 &self,
156 lock_func_shared: bool,
157 constructor: F,
158 ) -> LockResult<Option<H>>
159 where
160 F: FnOnce(
161 PoolConnection<Postgres>,
162 bool,
163 PostgresAdvisoryLockKey,
164 watch::Sender<bool>,
165 watch::Receiver<bool>,
166 ) -> H,
167 {
168 let mut conn = self.pool.acquire().await.map_err(|e| {
169 LockError::Connection(Box::new(std::io::Error::other(format!(
170 "failed to get connection from pool: {e}"
171 ))))
172 })?;
173
174 let use_transaction = self.use_transaction;
175 if use_transaction {
176 conn.execute("BEGIN").await.map_err(|e| {
177 LockError::Connection(Box::new(std::io::Error::other(format!(
178 "failed to start transaction: {e}"
179 ))))
180 })?;
181 }
182
183 let lock_func = match (use_transaction, lock_func_shared) {
184 (true, true) => "pg_try_advisory_xact_lock_shared",
185 (true, false) => "pg_try_advisory_xact_lock",
186 (false, true) => "pg_try_advisory_lock_shared",
187 (false, false) => "pg_try_advisory_lock",
188 };
189
190 let sql = format!("SELECT {}({})", lock_func, self.key.to_sql_args());
191 let row = conn.fetch_one(sql.as_str()).await.map_err(|e| {
192 LockError::Backend(Box::new(std::io::Error::other(format!(
193 "failed to try_acquire lock: {e}"
194 ))))
195 })?;
196
197 let acquired: bool = row.get(0);
198 if !acquired {
199 if use_transaction {
200 let _ = conn.execute("ROLLBACK").await;
201 }
202 return Ok(None);
203 }
204
205 let (sender, receiver) = watch::channel(false);
206 Ok(Some(constructor(
207 conn,
208 use_transaction,
209 self.key,
210 sender,
211 receiver,
212 )))
213 }
214}
215
216impl DistributedReaderWriterLock for PostgresDistributedReaderWriterLock {
217 type ReadHandle = PostgresReadLockHandle;
218 type WriteHandle = PostgresWriteLockHandle;
219
220 fn name(&self) -> &str {
221 &self.name
222 }
223
224 #[instrument(skip(self), fields(lock.name = %self.name, timeout = ?timeout, backend = "postgres", use_transaction = self.use_transaction))]
225 async fn acquire_read(&self, timeout: Option<Duration>) -> LockResult<Self::ReadHandle> {
226 Span::current().record("operation", "acquire_read");
227 match self
228 .acquire_internal(timeout, true, |c, t, k, s, r| {
229 PostgresReadLockHandle::new(c, t, k, s, r, self.keepalive_cadence)
230 })
231 .await
232 {
233 Ok(Some(handle)) => {
234 Span::current().record("acquired", true);
235 Ok(handle)
236 }
237 Ok(None) => {
238 Span::current().record("acquired", false);
239 Span::current().record("error", "timeout");
240 Err(LockError::Timeout(timeout.unwrap_or(Duration::MAX)))
241 }
242 Err(e) => Err(e),
243 }
244 }
245
246 #[instrument(skip(self), fields(lock.name = %self.name, backend = "postgres", use_transaction = self.use_transaction))]
247 async fn try_acquire_read(&self) -> LockResult<Option<Self::ReadHandle>> {
248 Span::current().record("operation", "try_acquire_read");
249 match self
250 .try_acquire_internal_immediate(true, |c, t, k, s, r| {
251 PostgresReadLockHandle::new(c, t, k, s, r, self.keepalive_cadence)
252 })
253 .await
254 {
255 Ok(Some(handle)) => {
256 Span::current().record("acquired", true);
257 Ok(Some(handle))
258 }
259 Ok(None) => {
260 Span::current().record("acquired", false);
261 Ok(None)
262 }
263 Err(e) => Err(e),
264 }
265 }
266
267 #[instrument(skip(self), fields(lock.name = %self.name, timeout = ?timeout, backend = "postgres", use_transaction = self.use_transaction))]
268 async fn acquire_write(&self, timeout: Option<Duration>) -> LockResult<Self::WriteHandle> {
269 Span::current().record("operation", "acquire_write");
270 match self
271 .acquire_internal(timeout, false, |c, t, k, s, r| {
272 PostgresWriteLockHandle::new(c, t, k, s, r, self.keepalive_cadence)
273 })
274 .await
275 {
276 Ok(Some(handle)) => {
277 Span::current().record("acquired", true);
278 Ok(handle)
279 }
280 Ok(None) => {
281 Span::current().record("acquired", false);
282 Span::current().record("error", "timeout");
283 Err(LockError::Timeout(timeout.unwrap_or(Duration::MAX)))
284 }
285 Err(e) => Err(e),
286 }
287 }
288
289 #[instrument(skip(self), fields(lock.name = %self.name, backend = "postgres", use_transaction = self.use_transaction))]
290 async fn try_acquire_write(&self) -> LockResult<Option<Self::WriteHandle>> {
291 Span::current().record("operation", "try_acquire_write");
292 match self
293 .try_acquire_internal_immediate(false, |c, t, k, s, r| {
294 PostgresWriteLockHandle::new(c, t, k, s, r, self.keepalive_cadence)
295 })
296 .await
297 {
298 Ok(Some(handle)) => {
299 Span::current().record("acquired", true);
300 Ok(Some(handle))
301 }
302 Ok(None) => {
303 Span::current().record("acquired", false);
304 Ok(None)
305 }
306 Err(e) => Err(e),
307 }
308 }
309}
310
311pub struct PostgresReadLockHandle {
313 conn: Option<PoolConnection<Postgres>>,
314 is_transaction: bool,
315 key: PostgresAdvisoryLockKey,
316 lost_receiver: watch::Receiver<bool>,
317 _monitor_task: tokio::task::JoinHandle<()>,
318}
319
320impl PostgresReadLockHandle {
321 pub(crate) fn new(
322 conn: PoolConnection<Postgres>,
323 is_transaction: bool,
324 key: PostgresAdvisoryLockKey,
325 _lost_sender: watch::Sender<bool>,
326 lost_receiver: watch::Receiver<bool>,
327 _keepalive_cadence: Option<Duration>,
328 ) -> Self {
329 let monitor_task = tokio::spawn(async move {});
330 Self {
331 conn: Some(conn),
332 is_transaction,
333 key,
334 lost_receiver,
335 _monitor_task: monitor_task,
336 }
337 }
338}
339
340impl LockHandle for PostgresReadLockHandle {
341 fn lost_token(&self) -> &watch::Receiver<bool> {
342 &self.lost_receiver
343 }
344
345 async fn release(mut self) -> LockResult<()> {
346 if let Some(mut conn) = self.conn.take() {
347 if self.is_transaction {
348 match conn.execute("ROLLBACK").await {
349 Ok(_) => tracing::debug!("Transaction rolled back successfully"),
350 Err(e) => tracing::warn!("Failed to rollback transaction: {}", e),
351 }
352 } else {
353 let sql = format!(
354 "SELECT pg_advisory_unlock_shared({})",
355 self.key.to_sql_args()
356 );
357 if let Err(e) = conn.execute(sql.as_str()).await {
358 tracing::warn!("Failed to release read lock explicitly: {}", e);
359 }
360 }
361 }
362 Ok(())
363 }
364}
365
366impl Drop for PostgresReadLockHandle {
367 fn drop(&mut self) {
368 self._monitor_task.abort();
369 }
370}
371
372pub struct PostgresWriteLockHandle {
374 conn: Option<PoolConnection<Postgres>>,
375 is_transaction: bool,
376 key: PostgresAdvisoryLockKey,
377 lost_receiver: watch::Receiver<bool>,
378 _monitor_task: tokio::task::JoinHandle<()>,
379}
380
381impl PostgresWriteLockHandle {
382 pub(crate) fn new(
383 conn: PoolConnection<Postgres>,
384 is_transaction: bool,
385 key: PostgresAdvisoryLockKey,
386 _lost_sender: watch::Sender<bool>,
387 lost_receiver: watch::Receiver<bool>,
388 _keepalive_cadence: Option<Duration>,
389 ) -> Self {
390 let monitor_task = tokio::spawn(async move {});
391 Self {
392 conn: Some(conn),
393 is_transaction,
394 key,
395 lost_receiver,
396 _monitor_task: monitor_task,
397 }
398 }
399}
400
401impl LockHandle for PostgresWriteLockHandle {
402 fn lost_token(&self) -> &watch::Receiver<bool> {
403 &self.lost_receiver
404 }
405
406 async fn release(mut self) -> LockResult<()> {
407 if let Some(mut conn) = self.conn.take() {
408 if self.is_transaction {
409 match conn.execute("ROLLBACK").await {
410 Ok(_) => tracing::debug!("Transaction rolled back successfully"),
411 Err(e) => tracing::warn!("Failed to rollback transaction: {}", e),
412 }
413 } else {
414 let sql = format!("SELECT pg_advisory_unlock({})", self.key.to_sql_args());
415 if let Err(e) = conn.execute(sql.as_str()).await {
416 tracing::warn!("Failed to release write lock explicitly: {}", e);
417 }
418 }
419 }
420 Ok(())
421 }
422}
423
424impl Drop for PostgresWriteLockHandle {
425 fn drop(&mut self) {
426 self._monitor_task.abort();
427 }
428}