use async_trait::async_trait;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sqlx::{
postgres::{PgListener, PgNotification, PgPoolOptions},
Executor, PgPool,
};
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::{
sync::{
mpsc::{self, Receiver, Sender},
Mutex, RwLock,
},
task::JoinHandle,
};
use tracing::{error, info};
use crate::config::DBListenerError;
use super::{DBListenerTrait, EventType};
/// Process-wide cache of connection pools keyed by database URL, so every
/// listener pointed at the same database shares a single `PgPool`.
static PG_POOL_REGISTRY: Lazy<RwLock<HashMap<String, Arc<PgPool>>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

/// Returns the shared pool for `db_url`, creating and caching it on first use.
///
/// New pools are created with at most 10 connections and a 2 second acquire
/// timeout.
///
/// # Errors
/// Returns [`DBListenerError::CreationError`] if a new pool cannot connect.
async fn get_or_create_pool(db_url: &str) -> Result<Arc<PgPool>, DBListenerError> {
    // Fast path: a shared read lock is enough when the pool already exists.
    {
        let pools = PG_POOL_REGISTRY.read().await;
        if let Some(pool) = pools.get(db_url) {
            return Ok(Arc::clone(pool));
        }
    }

    // Slow path: re-check under the write lock, because another task may have
    // created the pool while we waited. The lock is held across `connect` so
    // concurrent callers never build two pools for the same URL.
    let mut pools = PG_POOL_REGISTRY.write().await;
    if let Some(pool) = pools.get(db_url) {
        return Ok(Arc::clone(pool));
    }

    let pool = PgPoolOptions::new()
        .max_connections(10)
        .acquire_timeout(Duration::from_secs(2))
        .connect(db_url)
        .await
        .map_err(|e| {
            error!("Failed to connect to the database: {:?}", e);
            DBListenerError::CreationError(format!(
                "Failed to connect to the database url : {:#?}",
                e
            ))
        })?;
    let pool = Arc::new(pool);
    pools.insert(db_url.to_owned(), Arc::clone(&pool));
    Ok(pool)
}
/// Watches a PostgreSQL table for changes to selected columns.
///
/// A row trigger calls `pg_notify` on a per-instance channel. A `LISTEN`
/// connection receives those notifications, and a background task forwards
/// them as JSON values on an mpsc channel.
///
/// Cloning shares the same pool, channel sender and receiver: they are `Arc`s
/// or an mpsc `Sender`, which are cloned by handle.
#[derive(Debug, Clone)]
pub struct PostgresTableListener {
/// Connection pool, obtained via `get_or_create_pool` (shared per database URL).
pub pool: Arc<PgPool>,
/// Name of the watched table; `new` checks it exists in `information_schema.tables`.
pub table_name: String,
/// Columns whose changes produce notifications.
pub columns: Vec<String>,
/// Sending half of the channel the background listener task forwards JSON into.
pub sender: Sender<Value>,
/// Receiving half of that channel, shared so `start` can hand it to callers.
pub receiver: Arc<Mutex<tokio::sync::mpsc::Receiver<Value>>>,
/// Column whose value is reported as `id` in each notification.
pub table_identifier: String,
/// Generated per-instance (random-UUID based) name of the row trigger.
pub pg_trigger_name: String,
/// Generated per-instance name of the PL/pgSQL trigger function.
pub pg_function_name: String,
/// Generated per-instance NOTIFY/LISTEN channel name.
pub pg_column_updates_name: String,
/// DML events (INSERT / UPDATE / DELETE) that fire the trigger.
pub events: Vec<EventType>,
}
impl PostgresTableListener {
    /// Maximum length, in bytes, of the table-derived prefix embedded in the
    /// generated trigger, function and channel names.
    ///
    /// PostgreSQL identifiers are capped at 63 bytes, and longer ones are
    /// silently truncated. `pg_notify` raises an error for channel names of 64+
    /// bytes; inside the trigger, that error would abort the user's write.
    /// The longest generated name is `{prefix}_{36-char key}_column_updates`,
    /// so the prefix may use at most 63 - 1 - 36 - 15 = 11 bytes.
    const MAX_NAME_PREFIX_LEN: usize = 11;

    /// Creates a listener for `table_name`.
    ///
    /// It also verifies that the table, every column in `columns`, and the
    /// `table_identifier` column exist. No database objects are created here;
    /// the trigger and function are installed by [`DBListenerTrait::start`].
    ///
    /// # Errors
    /// - [`DBListenerError::CreationError`] if the pool cannot connect.
    /// - [`DBListenerError::ListenerVerifyError`] if verification fails.
    pub async fn new(
        url: &str,
        table_name: &str,
        columns: Vec<String>,
        table_identifier: &str,
        events: Vec<EventType>,
    ) -> Result<Self, DBListenerError> {
        let pool = get_or_create_pool(url).await?;
        let unique_key = uuid::Uuid::new_v4().to_string().replace('-', "_");
        let prefix = Self::name_prefix(table_name);
        let (sender, receiver) = mpsc::channel::<Value>(100);
        let listener = Self {
            pool,
            table_name: table_name.to_string(),
            columns,
            table_identifier: table_identifier.to_string(),
            sender,
            receiver: Arc::new(Mutex::new(receiver)),
            pg_trigger_name: format!("{}_{}_trigger", prefix, unique_key),
            pg_function_name: format!("{}_{}_function", prefix, unique_key),
            pg_column_updates_name: format!("{}_{}_column_updates", prefix, unique_key),
            events,
        };
        listener.verify_members().await?;
        Ok(listener)
    }

    /// Derives a short prefix for generated object names from `table_name`.
    ///
    /// Only ASCII letters, digits and `_` are kept, and the result is
    /// lowercased. This keeps the names plain identifiers, makes their byte
    /// length equal their character count, and keeps them safe inside SQL
    /// literals.
    fn name_prefix(table_name: &str) -> String {
        let mut prefix: String = table_name
            .chars()
            .filter(|c| c.is_ascii_alphanumeric() || *c == '_')
            .map(|c| c.to_ascii_lowercase())
            .take(Self::MAX_NAME_PREFIX_LEN)
            .collect();
        // Identifiers must not start with a digit, and the prefix must not be empty.
        if !prefix.starts_with(|c: char| c.is_ascii_lowercase() || c == '_') {
            prefix.insert(0, '_');
            prefix.truncate(Self::MAX_NAME_PREFIX_LEN);
        }
        prefix
    }

    /// Quotes `name` as a PostgreSQL identifier (`"name"`, embedded `"` doubled).
    ///
    /// Names are used exactly as verified against `information_schema`, so
    /// mixed-case names are not case-folded.
    fn quote_ident(name: &str) -> String {
        format!("\"{}\"", name.replace('"', "\"\""))
    }

    /// Quotes `value` as a PostgreSQL string literal (`'value'`, embedded `'` doubled).
    ///
    /// NOTE(review): assumes `standard_conforming_strings` is on (the
    /// PostgreSQL default), so backslashes need no escaping.
    fn quote_literal(value: &str) -> String {
        format!("'{}'", value.replace('\'', "''"))
    }

    /// Checks that the table, every watched column and the identifier column
    /// exist according to `information_schema`.
    ///
    /// NOTE(review): the lookups match on name only (no schema filter), so a
    /// same-named table in another schema also satisfies them.
    async fn verify_members(&self) -> Result<(), DBListenerError> {
        info!("--> Verifying table and columns");
        let table_exists = sqlx::query_as::<_, (bool,)>(
            r#"
            SELECT EXISTS (
                SELECT 1
                FROM information_schema.tables
                WHERE table_name = $1
            );
            "#,
        )
        .bind(&self.table_name)
        .fetch_one(&*self.pool)
        .await
        .map_err(|e| {
            DBListenerError::ListenerVerifyError(format!(
                "Failed to verify table existence : {:?}",
                e
            ))
        })?;
        info!("table exists : {:#?}", table_exists);
        if !table_exists.0 {
            return Err(DBListenerError::ListenerVerifyError(format!(
                "Table '{}' does not exist",
                &self.table_name
            )));
        }

        for column in &self.columns {
            let found = self.column_exists(column).await.map_err(|e| {
                DBListenerError::ListenerVerifyError(format!(
                    "Failed to verify column existence '{}': {:?}",
                    column, e
                ))
            })?;
            if !found {
                return Err(DBListenerError::ListenerVerifyError(format!(
                    "Column '{}' does not exist",
                    column
                )));
            }
        }

        let identifier_found = self
            .column_exists(&self.table_identifier)
            .await
            .map_err(|e| {
                DBListenerError::ListenerVerifyError(format!(
                    "Failed to verify table identifier '{}': {:?}",
                    &self.table_identifier, e
                ))
            })?;
        if !identifier_found {
            return Err(DBListenerError::ListenerVerifyError(format!(
                "Table identifier '{}' does not exist",
                &self.table_identifier
            )));
        }

        info!("✅ Table and columns verified successfully");
        Ok(())
    }

    /// Returns whether `information_schema.columns` lists `column` on this table.
    async fn column_exists(&self, column: &str) -> Result<bool, sqlx::Error> {
        let (exists,) = sqlx::query_as::<_, (bool,)>(
            r#"
            SELECT EXISTS (
                SELECT 1
                FROM information_schema.columns
                WHERE table_name = $1
                AND column_name = $2
            );
            "#,
        )
        .bind(&self.table_name)
        .bind(column)
        .fetch_one(&*self.pool)
        .await?;
        Ok(exists)
    }

    /// Creates (or replaces) the trigger function, then creates the row trigger
    /// if it does not exist yet.
    ///
    /// For each watched column, the function sends one `pg_notify` on this
    /// listener's channel in these cases:
    /// - INSERT, when the new value is not NULL;
    /// - UPDATE, when the value changed;
    /// - DELETE, when the old value was not NULL.
    ///
    /// NOTE(review): NOTIFY payloads are limited to 8000 bytes by default; a
    /// very wide row serialized with `row_to_json` would make `pg_notify` fail.
    ///
    /// # Errors
    /// Returns [`DBListenerError::CreationError`] in these cases:
    /// - `columns` or `events` is empty (the trigger DDL would be invalid);
    /// - either statement fails.
    async fn create_trigger_and_function(&self) -> Result<(), DBListenerError> {
        if self.columns.is_empty() {
            return Err(DBListenerError::CreationError(
                "At least one column is required to create the trigger".to_string(),
            ));
        }
        if self.events.is_empty() {
            return Err(DBListenerError::CreationError(
                "At least one event type is required to create the trigger".to_string(),
            ));
        }

        let id_column = Self::quote_ident(&self.table_identifier);
        let channel = Self::quote_literal(&self.pg_column_updates_name);
        let function_name = Self::quote_ident(&self.pg_function_name);

        let insert_blocks = self.column_blocks(|col, col_name| {
            format!(
                r#"
            IF NEW.{col} IS NOT NULL THEN
                PERFORM pg_notify(
                    {channel},
                    json_build_object(
                        'operation', TG_OP,
                        'table', TG_TABLE_NAME,
                        'column', {col_name},
                        'id', NEW.{id_column},
                        'new_value', NEW.{col},
                        'timestamp', NOW(),
                        'new_row_data', row_to_json(NEW)
                    )::text
                );
            END IF;
            "#
            )
        });
        let update_blocks = self.column_blocks(|col, col_name| {
            format!(
                r#"
            IF NEW.{col} IS DISTINCT FROM OLD.{col} THEN
                PERFORM pg_notify(
                    {channel},
                    json_build_object(
                        'operation', TG_OP,
                        'table', TG_TABLE_NAME,
                        'column', {col_name},
                        'id', NEW.{id_column},
                        'old_value', OLD.{col},
                        'new_value', NEW.{col},
                        'timestamp', NOW(),
                        'old_row_data', row_to_json(OLD),
                        'new_row_data', row_to_json(NEW)
                    )::text
                );
            END IF;
            "#
            )
        });
        let delete_blocks = self.column_blocks(|col, col_name| {
            format!(
                r#"
            IF OLD.{col} IS NOT NULL THEN
                PERFORM pg_notify(
                    {channel},
                    json_build_object(
                        'operation', TG_OP,
                        'table', TG_TABLE_NAME,
                        'column', {col_name},
                        'id', OLD.{id_column},
                        'old_value', OLD.{col},
                        'timestamp', NOW(),
                        'old_row_data', row_to_json(OLD)
                    )::text
                );
            END IF;
            "#
            )
        });

        let create_function = format!(
            r#"
    CREATE OR REPLACE FUNCTION {function_name}()
    RETURNS TRIGGER AS $$
    BEGIN
        IF TG_OP = 'INSERT' THEN
            {insert_blocks}
        ELSIF TG_OP = 'UPDATE' THEN
            {update_blocks}
        ELSIF TG_OP = 'DELETE' THEN
            {delete_blocks}
        END IF;
        RETURN NEW;
    END;
    $$ LANGUAGE plpgsql;
    "#
        );
        self.pool
            .execute(create_function.as_str())
            .await
            .map_err(|e| {
                DBListenerError::CreationError(format!(
                    "Failed to execute function creation : {:#?}",
                    e
                ))
            })?;

        let events_list = self.get_events_list();
        let create_trigger = format!(
            r#"
    DO $$
    BEGIN
        IF NOT EXISTS (
            SELECT 1
            FROM pg_trigger
            WHERE tgname = {trigger_literal}
        ) THEN
            CREATE TRIGGER {trigger_ident}
            AFTER {events_list} ON {table_ident}
            FOR EACH ROW
            EXECUTE FUNCTION {function_name}();
        END IF;
    END $$;
    "#,
            trigger_literal = Self::quote_literal(&self.pg_trigger_name),
            trigger_ident = Self::quote_ident(&self.pg_trigger_name),
            table_ident = Self::quote_ident(&self.table_name),
        );
        self.pool
            .execute(create_trigger.as_str())
            .await
            .map_err(|e| {
                DBListenerError::CreationError(format!(
                    "Failed to execute trigger creation : {:#?}",
                    e
                ))
            })?;
        Ok(())
    }

    /// Renders one PL/pgSQL block per watched column and joins them with newlines.
    ///
    /// `render` receives the column as a quoted identifier and as a string literal.
    fn column_blocks<F>(&self, render: F) -> String
    where
        F: Fn(&str, &str) -> String,
    {
        self.columns
            .iter()
            .map(|col| {
                let ident = Self::quote_ident(col);
                let literal = Self::quote_literal(col);
                render(ident.as_str(), literal.as_str())
            })
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Builds the trigger's event clause.
    ///
    /// Example: `INSERT OR UPDATE OF "a", "b" OR DELETE`. UPDATE is restricted
    /// to the watched columns.
    fn get_events_list(&self) -> String {
        let update_columns = self
            .columns
            .iter()
            .map(|c| Self::quote_ident(c))
            .collect::<Vec<_>>()
            .join(", ");
        self.events
            .iter()
            .map(|event| match event {
                EventType::INSERT => "INSERT".to_string(),
                EventType::UPDATE => format!("UPDATE OF {}", update_columns),
                EventType::DELETE => "DELETE".to_string(),
            })
            .collect::<Vec<_>>()
            .join(" OR ")
    }

    /// Drops this listener's trigger and trigger function, each only if it exists.
    ///
    /// # Errors
    /// Returns [`DBListenerError::DeletionError`] if either statement fails.
    async fn drop_trigger_and_function(&self) -> Result<(), DBListenerError> {
        let table_name = &self.table_name;
        let drop_trigger = format!(
            r#"
    DO $$
    BEGIN
        IF EXISTS (
            SELECT 1
            FROM pg_trigger
            WHERE tgname = {trigger_literal}
        ) THEN
            DROP TRIGGER {trigger_ident} ON {table_ident};
        END IF;
    END $$;
    "#,
            trigger_literal = Self::quote_literal(&self.pg_trigger_name),
            trigger_ident = Self::quote_ident(&self.pg_trigger_name),
            table_ident = Self::quote_ident(table_name),
        );
        self.pool
            .execute(drop_trigger.as_str())
            .await
            .map_err(|e| {
                DBListenerError::DeletionError(format!(
                    "Failed to execute trigger deletion : {:#?}",
                    e
                ))
            })?;

        let drop_function = format!(
            r#"
    DO $$
    BEGIN
        IF EXISTS (
            SELECT 1
            FROM pg_proc
            WHERE proname = {function_literal}
        ) THEN
            DROP FUNCTION {function_ident}();
        END IF;
    END $$;
    "#,
            function_literal = Self::quote_literal(&self.pg_function_name),
            function_ident = Self::quote_ident(&self.pg_function_name),
        );
        self.pool
            .execute(drop_function.as_str())
            .await
            .map_err(|e| {
                DBListenerError::DeletionError(format!(
                    "Failed to execute function deletion : {:#?}",
                    e
                ))
            })?;
        info!("Trigger and function removed for table: {}", table_name);
        Ok(())
    }

    /// Opens a `PgListener`, subscribes it to this listener's channel, then
    /// installs the trigger.
    ///
    /// # Errors
    /// Every failure is reported as [`DBListenerError::ListenerError`].
    async fn initialize_listener(&self) -> Result<Arc<Mutex<PgListener>>, DBListenerError> {
        // Subscribe before installing the trigger. NOTIFYs sent while nobody is
        // LISTENing are discarded, so the reverse order loses changes committed
        // in between. It would also leave an orphaned trigger if LISTEN failed.
        let mut listener = PgListener::connect_with(&self.pool)
            .await
            .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;
        listener
            .listen(&self.pg_column_updates_name)
            .await
            .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;
        self.create_trigger_and_function()
            .await
            .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;
        info!(
            "Listening for column update notifications on {}",
            self.table_name
        );
        Ok(Arc::new(Mutex::new(listener)))
    }

    /// Spawns the task that receives notifications and forwards them as JSON.
    ///
    /// Each notification is converted with `process_notification` and sent on
    /// this listener's channel. The task ends when `recv` returns an error or
    /// when the channel's receiving side has been closed.
    fn spawn_listener_task(&self, listener: Arc<Mutex<PgListener>>) -> JoinHandle<()> {
        let sender = self.sender.clone();
        let table_name = self.table_name.clone();
        tokio::spawn(async move {
            info!("Listener spawned and waiting for notifications");
            loop {
                // The guard is a temporary, so the lock is released as soon as
                // `recv` returns instead of being held while forwarding.
                let received = listener.lock().await.recv().await;
                let notification = match received {
                    Ok(notification) => notification,
                    Err(e) => {
                        error!("Listener encountered an error: {:?}", e);
                        break;
                    }
                };
                let pg_notify = match process_notification(&notification, &table_name) {
                    Some(pg_notify) => pg_notify,
                    None => continue,
                };
                let json_data = match serde_json::to_value(pg_notify) {
                    Ok(json_data) => json_data,
                    Err(_) => {
                        error!("Failed to serialize PgNotify to JSON");
                        continue;
                    }
                };
                if let Err(e) = sender.send(json_data).await {
                    // `send` only fails once the receiver is closed, so no
                    // further notification could ever be read.
                    error!("Failed to send payload to channel: {:?}", e);
                    break;
                }
            }
        })
    }
}
#[async_trait]
impl DBListenerTrait for PostgresTableListener {
    /// Sets up the database side and spawns the forwarding task.
    ///
    /// Returns the shared notification receiver together with the handle of
    /// that background task.
    async fn start(
        &self,
    ) -> Result<(Arc<Mutex<Receiver<Value>>>, JoinHandle<()>), DBListenerError> {
        let pg_listener = self.initialize_listener().await?;
        let forwarder = self.spawn_listener_task(pg_listener);
        let shared_receiver = self.receiver.clone();
        Ok((shared_receiver, forwarder))
    }

    /// Removes the trigger and trigger function this listener installed.
    ///
    /// Failures are reported as [`DBListenerError::ListenerError`]. The
    /// forwarding task spawned by `start` is not stopped here.
    async fn stop(&self) -> Result<(), DBListenerError> {
        self.drop_trigger_and_function()
            .await
            .map_err(|err| DBListenerError::ListenerError(err.to_string()))
    }
}
/// Converts a raw notification into a [`PgNotify`].
///
/// The payload is expected to be the JSON object built by the trigger function.
/// Field handling:
/// - Missing `operation`/`column` become `""`; missing row data becomes `null`.
/// - `table` comes from `table_name`, not from the payload.
/// - `timestamp` is the time this process handled the notification; the
///   payload's own `timestamp` is not used.
///
/// Returns `None` (after logging) when the payload is not valid JSON.
fn process_notification(notification: &PgNotification, table_name: &str) -> Option<PgNotify> {
    let payload = match serde_json::from_str::<Value>(notification.payload()) {
        Ok(payload) => payload,
        Err(e) => {
            error!("Failed to parse notification payload: {:?}", e);
            return None;
        }
    };

    let text_field = |key: &str| {
        payload
            .get(key)
            .and_then(Value::as_str)
            .map(String::from)
            .unwrap_or_default()
    };

    // Unwrap JSON strings so a text identifier becomes `abc` rather than `"abc"`.
    // Other JSON values (e.g. numbers) keep their JSON text form.
    let id = match payload.get("id") {
        Some(Value::String(s)) => s.clone(),
        Some(other) => other.to_string(),
        None => String::new(),
    };

    Some(PgNotify {
        operation: text_field("operation"),
        table: table_name.to_string(),
        column: text_field("column"),
        id,
        new_row_data: payload.get("new_row_data").cloned().unwrap_or_default(),
        old_row_data: payload.get("old_row_data").cloned().unwrap_or_default(),
        timestamp: chrono::Utc::now().to_rfc3339(),
    })
}
/// A single column-change notification, as serialized to JSON and forwarded on
/// the listener's channel.
#[derive(Debug, Deserialize, Serialize)]
pub struct PgNotify {
/// Trigger operation (`TG_OP`): `INSERT`, `UPDATE` or `DELETE`.
pub operation: String,
/// Name of the watched table, taken from the listener rather than the payload.
pub table: String,
/// The payload's `id` (the listener's identifier-column value) rendered as a string.
pub id: String,
/// Watched column whose change produced this notification.
pub column: String,
/// `row_to_json(NEW)` for INSERT/UPDATE; `null` when absent (e.g. DELETE).
pub new_row_data: Value,
/// `row_to_json(OLD)` for UPDATE/DELETE; `null` when absent (e.g. INSERT).
pub old_row_data: Value,
/// RFC 3339 UTC time at which this process handled the notification.
pub timestamp: String,
}
#[cfg(test)]
mod tests {
    // Integration tests: they need a reachable PostgreSQL instance whose URL is
    // in `POSTGRES_DATABASE_URL` (a `.env` file is honoured). The positive
    // cases also need a `swaps` table with the columns used below.
    use std::{env, sync::Arc};
    use tokio::time::{sleep, timeout, Duration};
    use dotenv::dotenv;
    use sqlx::Executor;
    use crate::{
        database::{
            postgres::{get_or_create_pool, PostgresTableListener},
            DBListenerTrait,
        },
        EventType,
    };

    /// Loads `.env` (if present) and returns the test database URL.
    fn test_database_url() -> String {
        dotenv().ok();
        env::var("POSTGRES_DATABASE_URL").expect("POSTGRES_DATABASE_URL must be set")
    }

    /// Columns of the `swaps` test table that the listeners watch.
    fn watched_columns() -> Vec<String> {
        ["initiate_tx_hash", "redeem_tx_hash", "refund_tx_hash"]
            .iter()
            .map(|column| column.to_string())
            .collect()
    }

    /// All event types the tests subscribe to.
    fn all_events() -> Vec<EventType> {
        vec![EventType::UPDATE, EventType::INSERT, EventType::DELETE]
    }

    #[tokio::test]
    async fn create_new_listener_with_props() {
        let database_url = test_database_url();
        let result = PostgresTableListener::new(
            &database_url,
            "swaps",
            watched_columns(),
            "swap_id",
            all_events(),
        )
        .await;
        assert!(
            result.is_ok(),
            "Listener creation failed: {:?}",
            result.as_ref().err().map(|e| e.to_string())
        );
        sleep(Duration::from_secs(1)).await;
    }

    #[tokio::test]
    async fn create_new_listener_with_invalid_props() {
        let database_url = test_database_url();
        let result = PostgresTableListener::new(
            &database_url,
            "atomic_swaps",
            watched_columns(),
            "swap_id",
            all_events(),
        )
        .await;
        assert!(
            result.is_err(),
            "Expected listener creation to fail for table 'atomic_swaps'"
        );
        sleep(Duration::from_secs(1)).await;
    }

    #[tokio::test]
    async fn get_same_pool_for_same_url() {
        let database_url = test_database_url();
        let pool1 = get_or_create_pool(&database_url).await.unwrap();
        let pool2 = get_or_create_pool(&database_url).await.unwrap();
        assert!(
            Arc::ptr_eq(&pool1, &pool2),
            "Expected the same pool instance, but got different ones"
        );
        sleep(Duration::from_secs(1)).await;
    }

    #[tokio::test]
    async fn postgres_table_listener() {
        sleep(Duration::from_secs(1)).await;
        let database_url = test_database_url();
        let table_name = "swaps";
        let postgres_table_listener = PostgresTableListener::new(
            &database_url,
            table_name,
            watched_columns(),
            "swap_id",
            all_events(),
        )
        .await
        .expect("Failed to initialize postgres table listener");
        let (rx, handle) = postgres_table_listener.start().await.unwrap();

        let notification_task = tokio::spawn(async move {
            let mut received_events = Vec::new();
            while let Some(payload) = rx.lock().await.recv().await {
                println!("Notification received: {:#?}", payload);
                received_events.push(payload);
                if received_events.len() >= 3 {
                    break;
                }
            }
            received_events
        });

        let pool = sqlx::PgPool::connect(&database_url)
            .await
            .expect("Failed to connect to DB");
        async fn execute_query(pool: &sqlx::PgPool, query: &str) {
            pool.execute(query).await.expect("Query execution failed");
        }
        execute_query(
            &pool,
            &format!(
                "INSERT INTO {} (id, initiate_tx_hash, redeem_tx_hash, refund_tx_hash) VALUES (1, 'tx1', 'tx2', 'tx3')",
                table_name
            ),
        )
        .await;
        sleep(Duration::from_millis(100)).await;
        execute_query(
            &pool,
            &format!(
                "UPDATE {} SET redeem_tx_hash = 'updated_tx2' WHERE id = 1",
                table_name
            ),
        )
        .await;
        sleep(Duration::from_millis(100)).await;
        execute_query(&pool, &format!("DELETE FROM {} WHERE id = 1", table_name)).await;

        // Bound the wait so a missing notification fails the test instead of hanging it.
        let outcome = timeout(Duration::from_secs(10), notification_task).await;

        // Tear down before asserting so a failure does not leave the trigger installed.
        postgres_table_listener.stop().await.unwrap();
        handle.abort();

        let received_events = outcome
            .expect("Timed out waiting for notifications")
            .expect("Notification task panicked");
        assert_eq!(
            received_events.len(),
            3,
            "Expected 3 events but received {}",
            received_events.len()
        );
    }
}