// distributed_lock_postgres/rw_lock.rs
use std::future::Future;
use std::time::Duration;

use distributed_lock_core::error::{LockError, LockResult};
use distributed_lock_core::timeout::TimeoutValue;
use distributed_lock_core::traits::{DistributedReaderWriterLock, LockHandle};
use sqlx::{PgPool, Postgres, Row, Transaction};
use tokio::sync::watch;

use crate::handle::PostgresConnectionInner;
use crate::key::PostgresAdvisoryLockKey;
14pub struct PostgresDistributedReaderWriterLock {
16 key: PostgresAdvisoryLockKey,
18 name: String,
20 pool: PgPool,
22 use_transaction: bool,
24 keepalive_cadence: Option<Duration>,
26}
27
28impl PostgresDistributedReaderWriterLock {
29 pub(crate) fn new(
30 name: String,
31 key: PostgresAdvisoryLockKey,
32 pool: PgPool,
33 use_transaction: bool,
34 keepalive_cadence: Option<Duration>,
35 ) -> Self {
36 Self {
37 key,
38 name,
39 pool,
40 use_transaction,
41 keepalive_cadence,
42 }
43 }
44
45 async fn try_acquire_read_internal(&self) -> LockResult<Option<PostgresReadLockHandle>> {
47 let mut connection = self.pool.acquire().await.map_err(|e| {
48 LockError::Connection(Box::new(std::io::Error::other(format!(
49 "failed to get connection from pool: {e}"
50 ))))
51 })?;
52
53 let sql = format!(
54 "SELECT pg_try_advisory_lock_shared({})",
55 self.key.to_sql_args()
56 );
57
58 let row = sqlx::query(&sql)
59 .fetch_one(&mut *connection)
60 .await
61 .map_err(|e| {
62 LockError::Backend(Box::new(std::io::Error::other(format!(
63 "failed to acquire read lock: {e}"
64 ))))
65 })?;
66
67 let acquired: bool = row.get(0);
68 if !acquired {
69 return Ok(None);
70 }
71
72 let (sender, receiver) = watch::channel(false);
76 Ok(Some(PostgresReadLockHandle::new(
77 PostgresConnectionInner::Connection(Box::new(connection)),
78 self.key,
79 sender,
80 receiver,
81 self.keepalive_cadence,
82 )))
83 }
84
85 async fn try_acquire_write_internal(&self) -> LockResult<Option<PostgresWriteLockHandle>> {
87 if self.use_transaction {
88 let mut transaction = self.pool.begin().await.map_err(|e| {
90 LockError::Connection(Box::new(std::io::Error::other(format!(
91 "failed to start transaction: {e}"
92 ))))
93 })?;
94
95 let sql = format!("SELECT pg_try_advisory_lock({})", self.key.to_sql_args());
96
97 let row = sqlx::query(&sql)
98 .fetch_one(&mut *transaction)
99 .await
100 .map_err(|e| {
101 LockError::Backend(Box::new(std::io::Error::other(format!(
102 "failed to acquire write lock: {e}"
103 ))))
104 })?;
105
106 let acquired: bool = row.get(0);
107 if !acquired {
108 return Ok(None);
109 }
110
111 let transaction_ptr = unsafe {
114 std::mem::transmute::<Transaction<'_, Postgres>, Transaction<'static, Postgres>>(
115 transaction,
116 )
117 };
118 let transaction_ptr = Box::into_raw(Box::new(transaction_ptr));
119
120 let (sender, receiver) = watch::channel(false);
121 Ok(Some(PostgresWriteLockHandle::new(
122 PostgresConnectionInner::Transaction(transaction_ptr),
123 self.key,
124 sender,
125 receiver,
126 self.keepalive_cadence,
127 )))
128 } else {
129 let mut connection = self.pool.acquire().await.map_err(|e| {
131 LockError::Connection(Box::new(std::io::Error::other(format!(
132 "failed to get connection from pool: {e}"
133 ))))
134 })?;
135
136 let sql = format!("SELECT pg_try_advisory_lock({})", self.key.to_sql_args());
137
138 let row = sqlx::query(&sql)
139 .fetch_one(&mut *connection)
140 .await
141 .map_err(|e| {
142 LockError::Backend(Box::new(std::io::Error::other(format!(
143 "failed to acquire write lock: {e}"
144 ))))
145 })?;
146
147 let acquired: bool = row.get(0);
148 if !acquired {
149 return Ok(None);
150 }
151
152 let (sender, receiver) = watch::channel(false);
156 Ok(Some(PostgresWriteLockHandle::new(
157 PostgresConnectionInner::Connection(Box::new(connection)),
158 self.key,
159 sender,
160 receiver,
161 self.keepalive_cadence,
162 )))
163 }
164 }
165}
166
167impl DistributedReaderWriterLock for PostgresDistributedReaderWriterLock {
168 type ReadHandle = PostgresReadLockHandle;
169 type WriteHandle = PostgresWriteLockHandle;
170
171 fn name(&self) -> &str {
172 &self.name
173 }
174
175 async fn acquire_read(&self, timeout: Option<Duration>) -> LockResult<Self::ReadHandle> {
176 let timeout_value = TimeoutValue::from(timeout);
177 let start = std::time::Instant::now();
178
179 let mut sleep_duration = Duration::from_millis(50);
181 const MAX_SLEEP: Duration = Duration::from_secs(1);
182
183 loop {
184 match self.try_acquire_read_internal().await {
185 Ok(Some(handle)) => return Ok(handle),
186 Ok(None) => {
187 if !timeout_value.is_infinite()
189 && start.elapsed() >= timeout_value.as_duration().unwrap()
190 {
191 return Err(LockError::Timeout(timeout_value.as_duration().unwrap()));
192 }
193
194 tokio::time::sleep(sleep_duration).await;
196 sleep_duration = (sleep_duration * 2).min(MAX_SLEEP);
197 }
198 Err(e) => return Err(e),
199 }
200 }
201 }
202
203 async fn try_acquire_read(&self) -> LockResult<Option<Self::ReadHandle>> {
204 self.try_acquire_read_internal().await
205 }
206
207 async fn acquire_write(&self, timeout: Option<Duration>) -> LockResult<Self::WriteHandle> {
208 let timeout_value = TimeoutValue::from(timeout);
209 let start = std::time::Instant::now();
210
211 let mut sleep_duration = Duration::from_millis(50);
213 const MAX_SLEEP: Duration = Duration::from_secs(1);
214
215 loop {
216 match self.try_acquire_write_internal().await {
217 Ok(Some(handle)) => return Ok(handle),
218 Ok(None) => {
219 if !timeout_value.is_infinite()
221 && start.elapsed() >= timeout_value.as_duration().unwrap()
222 {
223 return Err(LockError::Timeout(timeout_value.as_duration().unwrap()));
224 }
225
226 tokio::time::sleep(sleep_duration).await;
228 sleep_duration = (sleep_duration * 2).min(MAX_SLEEP);
229 }
230 Err(e) => return Err(e),
231 }
232 }
233 }
234
235 async fn try_acquire_write(&self) -> LockResult<Option<Self::WriteHandle>> {
236 self.try_acquire_write_internal().await
237 }
238}
239
240pub struct PostgresReadLockHandle {
242 _connection: Option<PostgresConnectionInner>,
244 key: PostgresAdvisoryLockKey,
246 lost_receiver: watch::Receiver<bool>,
248}
249
250impl PostgresReadLockHandle {
251 pub(crate) fn new(
252 connection: PostgresConnectionInner,
253 key: PostgresAdvisoryLockKey,
254 _lost_sender: watch::Sender<bool>,
255 lost_receiver: watch::Receiver<bool>,
256 _keepalive_cadence: Option<Duration>,
257 ) -> Self {
258 Self {
259 _connection: Some(connection),
260 key,
261 lost_receiver,
262 }
263 }
264}
265
266impl LockHandle for PostgresReadLockHandle {
267 fn lost_token(&self) -> &watch::Receiver<bool> {
268 &self.lost_receiver
269 }
270
271 async fn release(mut self) -> LockResult<()> {
272 if let Some(connection) = self._connection.take() {
274 match connection {
275 PostgresConnectionInner::Connection(mut conn) => {
276 let sql = format!(
277 "SELECT pg_advisory_unlock_shared({})",
278 self.key.to_sql_args()
279 );
280 let _ = sqlx::query(&sql).execute(&mut **conn).await;
281 }
282 PostgresConnectionInner::Transaction(transaction_ptr) => {
283 let transaction = unsafe { Box::from_raw(transaction_ptr) };
285 if let Err(e) = transaction.rollback().await {
286 tracing::warn!("Failed to rollback transaction: {}", e);
287 }
288 }
290 }
291 }
292 Ok(())
293 }
294}
295
296pub struct PostgresWriteLockHandle {
298 _connection: Option<PostgresConnectionInner>,
300 key: PostgresAdvisoryLockKey,
302 lost_receiver: watch::Receiver<bool>,
304}
305
306impl PostgresWriteLockHandle {
307 pub(crate) fn new(
308 connection: PostgresConnectionInner,
309 key: PostgresAdvisoryLockKey,
310 _lost_sender: watch::Sender<bool>,
311 lost_receiver: watch::Receiver<bool>,
312 _keepalive_cadence: Option<Duration>,
313 ) -> Self {
314 Self {
315 _connection: Some(connection),
316 key,
317 lost_receiver,
318 }
319 }
320}
321
322impl LockHandle for PostgresWriteLockHandle {
323 fn lost_token(&self) -> &watch::Receiver<bool> {
324 &self.lost_receiver
325 }
326
327 async fn release(mut self) -> LockResult<()> {
328 if let Some(connection) = self._connection.take() {
330 match connection {
331 PostgresConnectionInner::Connection(mut conn) => {
332 let sql = format!("SELECT pg_advisory_unlock({})", self.key.to_sql_args());
333 let _ = sqlx::query(&sql).execute(&mut **conn).await;
334 }
335 PostgresConnectionInner::Transaction(transaction_ptr) => {
336 let transaction = unsafe { Box::from_raw(transaction_ptr) };
338 if let Err(e) = transaction.rollback().await {
339 tracing::warn!("Failed to rollback transaction: {}", e);
340 }
341 }
343 }
344 }
345 Ok(())
346 }
347}