use std::time::Duration;
use distributed_lock_core::error::{LockError, LockResult};
use distributed_lock_core::traits::DistributedLock;
use tokio::sync::watch;
use tracing::{Span, instrument};
use crate::handle::PostgresLockHandle;
use crate::key::PostgresAdvisoryLockKey;
use sqlx::{Executor, PgPool, Row};
/// A distributed lock backed by PostgreSQL advisory locks.
///
/// Every acquisition attempt checks a connection out of `pool` and issues one
/// of the `pg_*advisory*lock` functions for `key` on it.
pub struct PostgresDistributedLock {
/// Advisory lock key; rendered into the lock function call via `to_sql_args()`.
key: PostgresAdvisoryLockKey,
/// Name reported by `DistributedLock::name` and attached to tracing spans.
name: String,
/// Pool from which a connection is acquired for each lock attempt.
pool: PgPool,
/// When `true` the transaction-scoped functions (`pg_advisory_xact_lock`,
/// `pg_try_advisory_xact_lock`) are used; otherwise the session-scoped ones.
use_transaction: bool,
/// Forwarded unchanged to `PostgresLockHandle::new`.
/// NOTE(review): presumably the interval of a keepalive on the held
/// connection — confirm against the handle implementation.
keepalive_cadence: Option<Duration>,
}
impl PostgresDistributedLock {
pub(crate) fn new(
name: String,
key: PostgresAdvisoryLockKey,
pool: PgPool,
use_transaction: bool,
keepalive_cadence: Option<Duration>,
) -> Self {
Self {
key,
name,
pool,
use_transaction,
keepalive_cadence,
}
}
async fn acquire_internal(
&self,
timeout: Option<Duration>,
) -> LockResult<Option<PostgresLockHandle>> {
let mut conn = self.pool.acquire().await.map_err(|e| {
LockError::Connection(Box::new(std::io::Error::other(format!(
"failed to get connection from pool: {e}"
))))
})?;
conn.execute("BEGIN").await.map_err(|e| {
LockError::Connection(Box::new(std::io::Error::other(format!(
"failed to start transaction: {e}"
))))
})?;
let use_transaction_lock = self.use_transaction;
let savepoint_name = "medallion_lock_acquire";
let sql = format!("SAVEPOINT {}", savepoint_name);
conn.execute(sql.as_str()).await.map_err(|e| {
LockError::Backend(Box::new(std::io::Error::other(format!(
"failed to create savepoint: {e}"
))))
})?;
let timeout_ms = timeout.map(|d| d.as_millis() as i64).unwrap_or(0);
let set_timeout_sql = format!("SET LOCAL lock_timeout = {}", timeout_ms);
if let Err(e) = conn.execute(set_timeout_sql.as_str()).await {
let _ = conn
.execute(format!("ROLLBACK TO SAVEPOINT {}", savepoint_name).as_str())
.await;
if !use_transaction_lock {
let _ = conn.execute("ROLLBACK").await;
}
return Err(LockError::Backend(Box::new(std::io::Error::other(
format!("failed to set lock_timeout: {e}"),
))));
}
let lock_func = if use_transaction_lock {
"pg_advisory_xact_lock"
} else {
"pg_advisory_lock"
};
let sql = format!("SELECT {}({})", lock_func, self.key.to_sql_args());
match conn.fetch_one(sql.as_str()).await {
Ok(_) => {
if !use_transaction_lock {
if let Err(e) = conn.execute("COMMIT").await {
return Err(LockError::Backend(Box::new(std::io::Error::other(
format!("failed to commit transaction after locking: {e}"),
))));
}
}
let (sender, receiver) = watch::channel(false);
Ok(Some(PostgresLockHandle::new(
conn,
use_transaction_lock,
self.key,
sender,
receiver,
self.keepalive_cadence,
)))
}
Err(e) => {
let db_err = e.as_database_error();
let code = db_err.and_then(|db_err| db_err.code()).unwrap_or_default();
let _ = conn
.execute(format!("ROLLBACK TO SAVEPOINT {}", savepoint_name).as_str())
.await;
if !use_transaction_lock {
let _ = conn.execute("ROLLBACK").await;
}
if code == "55P03" {
return Ok(None); }
if code == "40P01" {
return Err(LockError::Deadlock(
"deadlock detected by postgres".to_string(),
));
}
Err(LockError::Backend(Box::new(std::io::Error::other(
format!("failed to acquire lock: {e}"),
))))
}
}
}
async fn try_acquire_internal_immediate(&self) -> LockResult<Option<PostgresLockHandle>> {
let mut conn = self.pool.acquire().await.map_err(|e| {
LockError::Connection(Box::new(std::io::Error::other(format!(
"failed to get connection from pool: {e}"
))))
})?;
let use_transaction = self.use_transaction;
if use_transaction {
conn.execute("BEGIN").await.map_err(|e| {
LockError::Connection(Box::new(std::io::Error::other(format!(
"failed to start transaction: {e}"
))))
})?;
}
let lock_func = if use_transaction {
"pg_try_advisory_xact_lock"
} else {
"pg_try_advisory_lock"
};
let sql = format!("SELECT {}({})", lock_func, self.key.to_sql_args());
let row = conn.fetch_one(sql.as_str()).await.map_err(|e| {
LockError::Backend(Box::new(std::io::Error::other(format!(
"failed to try_acquire lock: {e}"
))))
})?;
let acquired: bool = row.get(0);
if !acquired {
if use_transaction {
let _ = conn.execute("ROLLBACK").await;
}
return Ok(None);
}
let (sender, receiver) = watch::channel(false);
Ok(Some(PostgresLockHandle::new(
conn,
use_transaction,
self.key,
sender,
receiver,
self.keepalive_cadence,
)))
}
}
impl DistributedLock for PostgresDistributedLock {
    type Handle = PostgresLockHandle;

    /// Returns the name this lock was created with.
    fn name(&self) -> &str {
        &self.name
    }

    /// Acquires the lock, waiting up to `timeout` (`None` waits indefinitely).
    ///
    /// # Errors
    ///
    /// Returns `LockError::Timeout` when the lock could not be taken within the
    /// budget (carrying `timeout`, or `Duration::MAX` when none was given), and
    /// forwards any error produced by the underlying acquisition unchanged.
    // `acquired` and `error` are declared as empty fields up front: tracing
    // silently ignores `Span::record` calls for fields the span was not
    // created with.
    #[instrument(
        skip(self),
        fields(
            lock.name = %self.name,
            timeout = ?timeout,
            backend = "postgres",
            use_transaction = self.use_transaction,
            operation = "acquire",
            acquired = tracing::field::Empty,
            error = tracing::field::Empty,
        )
    )]
    async fn acquire(&self, timeout: Option<Duration>) -> LockResult<Self::Handle> {
        match self.acquire_internal(timeout).await? {
            Some(handle) => {
                Span::current().record("acquired", true);
                Ok(handle)
            }
            None => {
                let span = Span::current();
                span.record("acquired", false);
                span.record("error", "timeout");
                Err(LockError::Timeout(timeout.unwrap_or(Duration::MAX)))
            }
        }
    }

    /// Attempts to take the lock without waiting.
    ///
    /// Returns `Ok(None)` when the lock is currently unavailable.
    ///
    /// # Errors
    ///
    /// Forwards any error produced by the underlying attempt unchanged.
    // `acquired` is declared as an empty field so the later `record` call
    // actually lands on the span.
    #[instrument(
        skip(self),
        fields(
            lock.name = %self.name,
            backend = "postgres",
            use_transaction = self.use_transaction,
            operation = "try_acquire",
            acquired = tracing::field::Empty,
        )
    )]
    async fn try_acquire(&self) -> LockResult<Option<Self::Handle>> {
        let outcome = self.try_acquire_internal_immediate().await?;
        Span::current().record("acquired", outcome.is_some());
        Ok(outcome)
    }
}