use crate::error::Result;
use crate::postgres::PgConnection;
use crate::Either;
use hkdf::Hkdf;
use once_cell::sync::OnceCell;
use sha2::Sha256;
use std::ops::{Deref, DerefMut};
/// A Postgres advisory lock identified by a [`PgAdvisoryLockKey`].
///
/// Take the lock on a connection with [`PgAdvisoryLock::acquire()`] or
/// [`PgAdvisoryLock::try_acquire()`]. Both hand back a [`PgAdvisoryLockGuard`] that owns the
/// connection and releases the lock when dropped.
#[derive(Clone, Debug)]
pub struct PgAdvisoryLock {
    /// Key passed to the `pg_advisory_*` SQL functions.
    key: PgAdvisoryLockKey,
    /// Lazily-built `SELECT pg_advisory_unlock(...)` statement. It is cached so that every
    /// guard drop reuses the same formatted string.
    release_query: OnceCell<String>,
}
/// The key that identifies a Postgres advisory lock.
///
/// The variant selects which form of the `pg_advisory_*` functions is called:
/// - `BigInt` uses the single-argument form (`$1`).
/// - `IntPair` uses the two-argument form (`$1, $2`).
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum PgAdvisoryLockKey {
    /// A single `i64` key.
    BigInt(i64),
    /// A pair of `i32` keys.
    IntPair(i32, i32),
}
/// Wrapper around a connection `C` that currently holds a [`PgAdvisoryLock`].
///
/// Returned by [`PgAdvisoryLock::acquire()`] and [`PgAdvisoryLock::try_acquire()`]. When
/// `C: AsRef<PgConnection>`, the guard derefs to [`PgConnection`], so the connection stays usable
/// while the lock is held.
///
/// ### Release-on-drop is deferred
/// Dropping the guard does not run the unlock query right away. Instead it queues
/// `pg_advisory_unlock(...)` on the connection via `queue_simple_query`, which presumably sends it
/// the next time the connection is used.
/// - Call [`Self::release_now()`] to release the lock eagerly.
/// - Call [`Self::leak()`] to keep the lock held.
///
/// The `C: AsMut<PgConnection>` bound sits on the struct because the `Drop` impl requires it, and
/// Rust requires `Drop` impl bounds to match the type's.
#[must_use = "if unused, the advisory lock is released (queued for release) immediately"]
pub struct PgAdvisoryLockGuard<'lock, C: AsMut<PgConnection>> {
    /// The lock this guard holds. It is borrowed so its cached release query can be reused.
    lock: &'lock PgAdvisoryLock,
    /// The lock-holding connection. It is `None` only after `release_now()`, `leak()` or `Drop`
    /// has taken it.
    conn: Option<C>,
}
impl PgAdvisoryLock {
pub fn new(key_string: impl AsRef<str>) -> Self {
let input_key_material = key_string.as_ref();
let hkdf = Hkdf::<Sha256>::new(None, input_key_material.as_bytes());
let mut output_key_material = [0u8; 8];
hkdf.expand(
b"SQLx (Rust) Postgres advisory lock",
&mut output_key_material,
)
.expect("BUG: `output_key_material` should be of acceptable length");
let key = PgAdvisoryLockKey::BigInt(i64::from_le_bytes(output_key_material));
log::trace!(
"generated {:?} from key string {:?}",
key,
input_key_material
);
Self::with_key(key)
}
pub fn with_key(key: PgAdvisoryLockKey) -> Self {
Self {
key,
release_query: OnceCell::new(),
}
}
pub fn key(&self) -> &PgAdvisoryLockKey {
&self.key
}
pub async fn acquire<C: AsMut<PgConnection>>(
&self,
mut conn: C,
) -> Result<PgAdvisoryLockGuard<'_, C>> {
match &self.key {
PgAdvisoryLockKey::BigInt(key) => {
crate::query::query("SELECT pg_advisory_lock($1)")
.bind(key)
.execute(conn.as_mut())
.await?;
}
PgAdvisoryLockKey::IntPair(key1, key2) => {
crate::query::query("SELECT pg_advisory_lock($1, $2)")
.bind(key1)
.bind(key2)
.execute(conn.as_mut())
.await?;
}
}
Ok(PgAdvisoryLockGuard::new(self, conn))
}
pub async fn try_acquire<C: AsMut<PgConnection>>(
&self,
mut conn: C,
) -> Result<Either<PgAdvisoryLockGuard<'_, C>, C>> {
let locked: bool = match &self.key {
PgAdvisoryLockKey::BigInt(key) => {
crate::query_scalar::query_scalar("SELECT pg_try_advisory_lock($1)")
.bind(key)
.fetch_one(conn.as_mut())
.await?
}
PgAdvisoryLockKey::IntPair(key1, key2) => {
crate::query_scalar::query_scalar("SELECT pg_try_advisory_lock($1, $2)")
.bind(key1)
.bind(key2)
.fetch_one(conn.as_mut())
.await?
}
};
if locked {
Ok(Either::Left(PgAdvisoryLockGuard::new(self, conn)))
} else {
Ok(Either::Right(conn))
}
}
pub async fn force_release<C: AsMut<PgConnection>>(&self, mut conn: C) -> Result<(C, bool)> {
let released: bool = match &self.key {
PgAdvisoryLockKey::BigInt(key) => {
crate::query_scalar::query_scalar("SELECT pg_advisory_unlock($1)")
.bind(key)
.fetch_one(conn.as_mut())
.await?
}
PgAdvisoryLockKey::IntPair(key1, key2) => {
crate::query_scalar::query_scalar("SELECT pg_advisory_unlock($1, $2)")
.bind(key1)
.bind(key2)
.fetch_one(conn.as_mut())
.await?
}
};
Ok((conn, released))
}
fn get_release_query(&self) -> &str {
self.release_query.get_or_init(|| match &self.key {
PgAdvisoryLockKey::BigInt(key) => format!("SELECT pg_advisory_unlock({})", key),
PgAdvisoryLockKey::IntPair(key1, key2) => {
format!("SELECT pg_advisory_unlock({}, {})", key1, key2)
}
})
}
}
impl PgAdvisoryLockKey {
    /// Returns the key as an `i64` if this is a [`PgAdvisoryLockKey::BigInt`] key.
    ///
    /// Returns `None` for [`PgAdvisoryLockKey::IntPair`]. No conversion between the two forms is
    /// attempted.
    pub fn as_bigint(&self) -> Option<i64> {
        match *self {
            Self::BigInt(value) => Some(value),
            Self::IntPair(..) => None,
        }
    }
}
/// Panic message used when a guard's connection is accessed after it has been taken.
///
/// `PgAdvisoryLockGuard::conn` is only set to `None` by `release_now()`, `leak()` and `Drop`.
/// Each of these consumes or finishes with the guard, so hitting this message means an internal
/// bug rather than a usage error.
const NONE_ERR: &str = "BUG: PgAdvisoryLockGuard.conn taken";
impl<'lock, C: AsMut<PgConnection>> PgAdvisoryLockGuard<'lock, C> {
    /// Wrap `conn`, which has just acquired `lock`.
    fn new(lock: &'lock PgAdvisoryLock, conn: C) -> Self {
        PgAdvisoryLockGuard {
            lock,
            conn: Some(conn),
        }
    }

    /// Release the advisory lock right away and return the wrapped connection.
    ///
    /// Dropping the guard only queues the unlock statement. This method instead executes
    /// `pg_advisory_unlock()` immediately via [`PgAdvisoryLock::force_release()`]. If the server
    /// reports that the lock was not held by this connection, a warning is logged, but the
    /// connection is still returned.
    ///
    /// # Errors
    /// Returns any error from executing the unlock query. In that case the connection stays
    /// inside the guard, and the guard's `Drop` still queues the release instead of losing it.
    /// The same applies if the returned future is dropped before it completes.
    pub async fn release_now(mut self) -> Result<C> {
        // Lend the connection instead of moving it out of the guard: if the query fails or this
        // future is cancelled, `self` still owns it and `Drop` falls back to the queued release.
        // `&mut C: AsMut<PgConnection>` holds through std's blanket impl for `&mut T`.
        let (_, released) = self
            .lock
            .force_release(self.conn.as_mut().expect(NONE_ERR))
            .await?;

        if !released {
            log::warn!(
                "PgAdvisoryLockGuard: advisory lock {:?} was not held by the contained connection",
                self.lock.key
            );
        }

        // The unlock has run; take the connection so `Drop` does not queue a second unlock.
        Ok(self.conn.take().expect(NONE_ERR))
    }

    /// Disarm the guard and return the connection without releasing the lock.
    ///
    /// After this, the guard's `Drop` queues nothing, so the lock stays held by the connection
    /// until it is released some other way, e.g. with [`PgAdvisoryLock::force_release()`].
    pub fn leak(mut self) -> C {
        self.conn.take().expect(NONE_ERR)
    }
}
impl<C> Deref for PgAdvisoryLockGuard<'_, C>
where
    C: AsMut<PgConnection> + AsRef<PgConnection>,
{
    type Target = PgConnection;

    /// Borrow the lock-holding connection.
    ///
    /// # Panics
    /// Only if the connection was already taken from the guard, which indicates an internal bug.
    fn deref(&self) -> &Self::Target {
        let conn: &C = self.conn.as_ref().expect(NONE_ERR);
        conn.as_ref()
    }
}
impl<C> DerefMut for PgAdvisoryLockGuard<'_, C>
where
    C: AsMut<PgConnection> + AsRef<PgConnection>,
{
    /// Mutably borrow the lock-holding connection.
    ///
    /// # Panics
    /// Only if the connection was already taken from the guard, which indicates an internal bug.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let conn: &mut C = self.conn.as_mut().expect(NONE_ERR);
        conn.as_mut()
    }
}
impl<C> AsRef<PgConnection> for PgAdvisoryLockGuard<'_, C>
where
    C: AsMut<PgConnection> + AsRef<PgConnection>,
{
    /// Borrow the lock-holding connection.
    ///
    /// # Panics
    /// Only if the connection was already taken from the guard, which indicates an internal bug.
    fn as_ref(&self) -> &PgConnection {
        let conn: &C = self.conn.as_ref().expect(NONE_ERR);
        conn.as_ref()
    }
}
impl<C> AsMut<PgConnection> for PgAdvisoryLockGuard<'_, C>
where
    C: AsMut<PgConnection>,
{
    /// Mutably borrow the lock-holding connection.
    ///
    /// This lets the guard itself be passed wherever `C: AsMut<PgConnection>` is accepted.
    ///
    /// # Panics
    /// Only if the connection was already taken from the guard, which indicates an internal bug.
    fn as_mut(&mut self) -> &mut PgConnection {
        let conn: &mut C = self.conn.as_mut().expect(NONE_ERR);
        conn.as_mut()
    }
}
impl<C> Drop for PgAdvisoryLockGuard<'_, C>
where
    C: AsMut<PgConnection>,
{
    /// Queue the unlock statement on the connection if the guard still owns it.
    ///
    /// `conn` is already `None` after `release_now()` or `leak()`, so nothing is queued then.
    /// `drop` cannot `.await`, so the statement is queued with `queue_simple_query` rather than
    /// executed here.
    fn drop(&mut self) {
        if let Some(mut conn) = self.conn.take() {
            let release_query = self.lock.get_release_query();
            conn.as_mut().queue_simple_query(release_query);
        }
    }
}