use crate::ConnectionID;
use crate::objects::{TableLock, TableObject};
use anyhow::Context;
use sqlx::{Connection, Executor, PgConnection, query, query_as};
use tracing::trace;
/// Holds a dedicated PostgreSQL connection with an open transaction in which table
/// locks are taken.
///
/// [`Locker::new`] issues `BEGIN`, and this type never issues `COMMIT` or `ROLLBACK`,
/// so any lock acquired through [`Locker::lock_tables`] stays held for as long as the
/// connection (and therefore the `Locker`) is alive.
pub struct Locker {
    /// Connection that owns the open transaction, and with it the acquired locks.
    conn: PgConnection,
}

impl Locker {
    /// Connects to `dsn`, checks the connection with a ping, and opens a transaction.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting, pinging, or executing `BEGIN` fails.
    pub async fn new(dsn: &str) -> anyhow::Result<Self> {
        let mut conn = PgConnection::connect(dsn)
            .await
            .context("Creating connection")?;
        conn.ping().await.context("Pinging postgres")?;
        query!("BEGIN;")
            .execute(&mut conn)
            .await
            .context("Starting transaction")?;
        Ok(Self { conn })
    }

    /// Takes an `ACCESS EXCLUSIVE` lock on each table, in iteration order, inside the
    /// transaction opened by [`Locker::new`].
    ///
    /// The table name is sent as an unqualified quoted identifier, so PostgreSQL resolves
    /// it against the connection's `search_path`. `LOCK TABLE` is issued without
    /// `NOWAIT`, so this call waits while another session holds a conflicting lock.
    ///
    /// # Errors
    ///
    /// Returns an error on the first statement that fails. Tables locked before it stay
    /// locked. PostgreSQL then treats the transaction as aborted, so later statements on
    /// this connection fail until it is rolled back.
    pub async fn lock_tables(
        &mut self,
        tables: impl IntoIterator<Item = &TableObject>,
    ) -> anyhow::Result<()> {
        for table in tables {
            trace!(?table, "Locking table");
            // `LOCK TABLE` cannot take its target as a bind parameter, so the identifier
            // has to be interpolated. Escape it rather than pasting the raw name between
            // quotes: an embedded `"` would otherwise end the identifier early.
            let lock_query = format!(
                "LOCK TABLE {} IN ACCESS EXCLUSIVE MODE;",
                Self::quote_identifier(&table.name.to_string())
            );
            self.conn
                .execute(lock_query.as_str())
                .await
                .with_context(|| format!("Query error while locking {table:?}"))?;
        }
        Ok(())
    }

    /// Lists the relation-level locks recorded in `pg_locks` for the backend whose PID
    /// is `connection_id.0`.
    ///
    /// Only ordinary (`relkind = 'r'`) and partitioned (`'p'`) tables are reported, and
    /// only those in `current_schema()` of the current database. The query does not
    /// filter on `pg_locks.granted`, so locks the backend is still waiting for are
    /// included as well.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails.
    pub async fn list_connection_locks(
        &mut self,
        connection_id: ConnectionID,
    ) -> anyhow::Result<Vec<TableLock>> {
        // The SQL text must stay byte-identical: sqlx's offline cache is keyed on it.
        query_as!(
            TableLock,
            r#"
select relation::regclass::text as "table!", mode as "lock!"
from pg_locks l
join pg_class c ON l.relation = c.oid
join pg_namespace n ON c.relnamespace = n.oid
WHERE l.pid = $1
AND n.nspname = current_schema()
AND c.relkind IN ('r', 'p')
AND l.locktype = 'relation'
AND l.mode IS NOT NULL
AND database = (SELECT oid FROM pg_database WHERE datname = current_database());
"#,
            connection_id.0
        )
        .fetch_all(&mut self.conn)
        .await
        .with_context(|| {
            format!("Query error while listing connection locks for {connection_id:?}")
        })
    }

    /// Renders `ident` as a PostgreSQL quoted identifier.
    ///
    /// The name is wrapped in double quotes and every embedded `"` is doubled. This is
    /// PostgreSQL's escaping rule for quoted identifiers, and it keeps the whole input
    /// inside the identifier.
    fn quote_identifier(ident: &str) -> String {
        format!("\"{}\"", ident.replace('"', "\"\""))
    }
}
#[cfg(test)]
mod tests {
    use crate::locker::Locker;
    use crate::tests::start_test_postgres;
    use crate::{ConnectionID, Lock, TableLock};
    use sqlx::query_scalar;
    use tracing_test::traced_test;

    /// After locking `orders` through the locker, listing the locker's own backend must
    /// report exactly one lock: `AccessExclusiveLock` on `orders`.
    #[traced_test]
    #[tokio::test]
    async fn test_list_connection_locks() {
        // Bound to an underscore-prefixed name rather than `_`, so the value is not
        // dropped until the end of the test.
        let (_container, dsn) = start_test_postgres().await;
        let mut locker = Locker::new(&dsn).await.unwrap();

        // `list_connection_locks` filters on this PID, so read it from the locker's own session.
        let backend_pid = query_scalar!(r#"select pg_backend_pid() as "pid!""#)
            .fetch_one(&mut locker.conn)
            .await
            .unwrap();
        let connection_id = ConnectionID(backend_pid);

        let orders_table = "orders".into();
        locker.lock_tables([&orders_table]).await.unwrap();

        let expected = vec![TableLock {
            table: orders_table,
            lock: Lock::AccessExclusiveLock,
        }];
        let actual = locker.list_connection_locks(connection_id).await.unwrap();
        assert_eq!(actual, expected);
    }
}