//! PostgreSQL advisory-lock implementation of `DistributedLock` (distributed_lock_postgres/lock.rs).
use std::time::Duration;

use distributed_lock_core::error::{LockError, LockResult};
use distributed_lock_core::timeout::TimeoutValue;
use distributed_lock_core::traits::DistributedLock;
use sqlx::{PgConnection, PgPool, Postgres, Row, Transaction};
use tokio::sync::watch;
use tracing::{instrument, Span};

use crate::handle::PostgresLockHandle;
use crate::key::PostgresAdvisoryLockKey;

15pub struct PostgresDistributedLock {
17 key: PostgresAdvisoryLockKey,
19 name: String,
21 pool: PgPool,
23 use_transaction: bool,
25 keepalive_cadence: Option<Duration>,
27}
28
29impl PostgresDistributedLock {
30 pub(crate) fn new(
31 name: String,
32 key: PostgresAdvisoryLockKey,
33 pool: PgPool,
34 use_transaction: bool,
35 keepalive_cadence: Option<Duration>,
36 ) -> Self {
37 Self {
38 key,
39 name,
40 pool,
41 use_transaction,
42 keepalive_cadence,
43 }
44 }
45
46 async fn try_acquire_internal(&self) -> LockResult<Option<PostgresLockHandle>> {
48 if self.use_transaction {
49 let mut transaction = self.pool.begin().await.map_err(|e| {
51 LockError::Connection(Box::new(std::io::Error::other(format!(
52 "failed to start transaction: {e}"
53 ))))
54 })?;
55
56 let sql = match self.key {
57 PostgresAdvisoryLockKey::Single(_) => {
58 format!("SELECT pg_try_advisory_lock({})", self.key.to_sql_args())
59 }
60 PostgresAdvisoryLockKey::Pair(_, _) => {
61 format!("SELECT pg_try_advisory_lock({})", self.key.to_sql_args())
62 }
63 };
64
65 let row = sqlx::query(&sql)
66 .fetch_one(&mut *transaction)
67 .await
68 .map_err(|e| {
69 LockError::Backend(Box::new(std::io::Error::other(format!(
70 "failed to acquire lock: {e}"
71 ))))
72 })?;
73
74 let acquired: bool = row.get(0);
75 if !acquired {
76 return Ok(None);
77 }
78
79 let transaction_ptr = unsafe {
82 std::mem::transmute::<Transaction<'_, Postgres>, Transaction<'static, Postgres>>(
83 transaction,
84 )
85 };
86 let transaction_ptr = Box::into_raw(Box::new(transaction_ptr));
87
88 let (sender, receiver) = watch::channel(false);
89 Ok(Some(PostgresLockHandle::new(
90 crate::handle::PostgresConnectionInner::Transaction(transaction_ptr),
91 self.key,
92 sender,
93 receiver,
94 self.keepalive_cadence,
95 )))
96 } else {
97 let mut connection = self.pool.acquire().await.map_err(|e| {
99 LockError::Connection(Box::new(std::io::Error::other(format!(
100 "failed to get connection from pool: {e}"
101 ))))
102 })?;
103
104 let sql = match self.key {
105 PostgresAdvisoryLockKey::Single(_) => {
106 format!("SELECT pg_try_advisory_lock({})", self.key.to_sql_args())
107 }
108 PostgresAdvisoryLockKey::Pair(_, _) => {
109 format!("SELECT pg_try_advisory_lock({})", self.key.to_sql_args())
110 }
111 };
112
113 let row = sqlx::query(&sql)
114 .fetch_one(&mut *connection)
115 .await
116 .map_err(|e| {
117 LockError::Backend(Box::new(std::io::Error::other(format!(
118 "failed to acquire lock: {e}"
119 ))))
120 })?;
121
122 let acquired: bool = row.get(0);
123 if !acquired {
124 return Ok(None);
125 }
126
127 let (sender, receiver) = watch::channel(false);
131 Ok(Some(PostgresLockHandle::new(
132 crate::handle::PostgresConnectionInner::Connection(Box::new(connection)),
133 self.key,
134 sender,
135 receiver,
136 self.keepalive_cadence,
137 )))
138 }
139 }
140}
141
142impl DistributedLock for PostgresDistributedLock {
143 type Handle = PostgresLockHandle;
144
145 fn name(&self) -> &str {
146 &self.name
147 }
148
149 #[instrument(skip(self), fields(lock.name = %self.name, timeout = ?timeout, backend = "postgres", use_transaction = self.use_transaction))]
150 async fn acquire(&self, timeout: Option<Duration>) -> LockResult<Self::Handle> {
151 let timeout_value = TimeoutValue::from(timeout);
152 let start = std::time::Instant::now();
153 Span::current().record("operation", "acquire");
154
155 let mut sleep_duration = Duration::from_millis(50);
157 const MAX_SLEEP: Duration = Duration::from_secs(1);
158
159 loop {
160 match self.try_acquire_internal().await {
161 Ok(Some(handle)) => {
162 let elapsed = start.elapsed();
163 Span::current().record("acquired", true);
164 Span::current().record("elapsed_ms", elapsed.as_millis() as u64);
165 return Ok(handle);
166 }
167 Ok(None) => {
168 if !timeout_value.is_infinite()
170 && start.elapsed() >= timeout_value.as_duration().unwrap()
171 {
172 Span::current().record("acquired", false);
173 Span::current().record("error", "timeout");
174 return Err(LockError::Timeout(timeout_value.as_duration().unwrap()));
175 }
176
177 tokio::time::sleep(sleep_duration).await;
179 sleep_duration = (sleep_duration * 2).min(MAX_SLEEP);
180 }
181 Err(e) => return Err(e),
182 }
183 }
184 }
185
186 #[instrument(skip(self), fields(lock.name = %self.name, backend = "postgres", use_transaction = self.use_transaction))]
187 async fn try_acquire(&self) -> LockResult<Option<Self::Handle>> {
188 Span::current().record("operation", "try_acquire");
189 let result = self.try_acquire_internal().await;
190 match &result {
191 Ok(Some(_)) => {
192 Span::current().record("acquired", true);
193 }
194 Ok(None) => {
195 Span::current().record("acquired", false);
196 Span::current().record("reason", "lock_held");
197 }
198 Err(e) => {
199 Span::current().record("acquired", false);
200 Span::current().record("error", e.to_string());
201 }
202 }
203 result
204 }
205}