use async_trait::async_trait;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sqlx::{
    postgres::{PgListener, PgNotification, PgPoolOptions},
    Executor, PgPool,
};
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::{
    sync::{
        mpsc::{self, Receiver, Sender},
        Mutex, RwLock,
    },
    task::JoinHandle,
};
use tracing::{error, info};

use crate::config::DBListenerError;

use super::{DBListenerTrait, EventType};
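
/// Process-wide registry of Postgres connection pools, keyed by database URL,
/// so that multiple listeners against the same database share a single `PgPool`.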
static PG_POOL_REGISTRY: Lazy<RwLock<HashMap<String, Arc<PgPool>>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
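
/// Returns the pool registered for `db_url`, creating and caching one on first
/// use. A read lock serves the fast path; the write lock re-checks the map to
/// avoid racing with a concurrent creator.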
async fn get_or_create_pool(db_url: &str) -> Result<Arc<PgPool>, DBListenerError> {
    {
        let pools = PG_POOL_REGISTRY.read().await;

        if let Some(pool) = pools.get(db_url) {
            return Ok(Arc::clone(pool));
        }
    }

    let mut pools = PG_POOL_REGISTRY.write().await;

    if let Some(pool) = pools.get(db_url) {
        Ok(Arc::clone(pool))
    } else {
        let new_pool = PgPoolOptions::new()
            .max_connections(10)
            .acquire_timeout(Duration::from_secs(2))
            .connect(db_url)
            .await
            .map_err(|e| {
                error!("Failed to connect to the database: {:?}", e);
                DBListenerError::CreationError(format!(
                    "Failed to connect to the database url: {:#?}",
                    e
                ))
            })?;

        let new_pool = Arc::new(new_pool);
        pools.insert(db_url.to_string(), Arc::clone(&new_pool));
        Ok(new_pool)
    }
}
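
/// Listens for row changes on a single Postgres table by installing a
/// per-listener trigger and PL/pgSQL function that emit `pg_notify` payloads on
/// a uniquely named channel, one payload per affected watched column.
///
/// Usage sketch (illustrative only, not compiled as a doctest; the URL and the
/// table/column names are placeholders):
///
/// ```ignore
/// let listener = PostgresTableListener::new(
///     "postgres://user:pass@localhost/db",
///     "swaps",
///     vec!["redeem_tx_hash".to_string()],
///     "swap_id",
///     vec![EventType::UPDATE],
/// )
/// .await?;
///
/// let (rx, handle) = listener.start().await?;
/// while let Some(event) = rx.lock().await.recv().await {
///     println!("change: {event}");
/// }
/// ```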
#[derive(Debug, Clone)]
pub struct PostgresTableListener {
    pub pool: Arc<PgPool>,
    pub table_name: String,
    pub columns: Vec<String>,
    pub sender: Sender<Value>,
    pub receiver: Arc<Mutex<Receiver<Value>>>,
    pub table_identifier: String,
    pub pg_trigger_name: String,
    pub pg_function_name: String,
    pub pg_column_updates_name: String,
    pub events: Vec<EventType>,
}

impl PostgresTableListener {
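    /// Builds a listener for `table_name`, watching `columns`, and verifies that
    /// the table, every watched column, and `table_identifier` exist before
    /// returning. Trigger, function, and channel names get a UUID suffix so
    /// multiple listeners on the same table do not collide.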
    pub async fn new(
        url: &str,
        table_name: &str,
        columns: Vec<String>,
        table_identifier: &str,
        events: Vec<EventType>,
    ) -> Result<Self, DBListenerError> {
        let pool = get_or_create_pool(url).await?;

        let uniquekey_uuid = uuid::Uuid::new_v4().to_string();
        let uniquekey = uniquekey_uuid.replace("-", "_");

        let pg_trigger_name = format!("{}_{}_trigger", table_name, &uniquekey);
        let pg_function_name = format!("{}_{}_function", table_name, &uniquekey);
        let pg_column_updates_name = format!("{}_{}_column_updates", table_name, &uniquekey);

        let (sender, receiver) = mpsc::channel::<Value>(100);

        let postgres_table_listener = Self {
            pool,
            table_name: table_name.to_string(),
            columns,
            table_identifier: table_identifier.to_string(),
            sender,
            receiver: Arc::new(Mutex::new(receiver)),
            pg_trigger_name,
            pg_function_name,
            pg_column_updates_name,
            events,
        };

        postgres_table_listener.verify_members().await?;

        Ok(postgres_table_listener)
    }
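
    /// Confirms via `information_schema` that the target table, every watched
    /// column, and the identifier column exist, returning a
    /// `ListenerVerifyError` otherwise.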
    async fn verify_members(&self) -> Result<(), DBListenerError> {
        info!("--> Verifying table and columns");

        let table_exists = sqlx::query_as::<_, (bool,)>(
            r#"
            SELECT EXISTS (
                SELECT 1
                FROM information_schema.tables
                WHERE table_name = $1
            );
            "#,
        )
        .bind(&self.table_name)
        .fetch_one(&*self.pool)
        .await
        .map_err(|e| {
            DBListenerError::ListenerVerifyError(format!(
                "Failed to verify table existence: {:?}",
                e
            ))
        })?;

        info!("table exists: {:#?}", table_exists);

        if !table_exists.0 {
            return Err(DBListenerError::ListenerVerifyError(format!(
                "Table '{}' does not exist",
                &self.table_name
            )));
        }

        for column in &self.columns {
            let column_exists = sqlx::query_as::<_, (bool,)>(
                r#"
                SELECT EXISTS (
                    SELECT 1
                    FROM information_schema.columns
                    WHERE table_name = $1
                    AND column_name = $2
                );
                "#,
            )
            .bind(&self.table_name)
            .bind(column)
            .fetch_one(&*self.pool)
            .await
            .map_err(|e| {
                DBListenerError::ListenerVerifyError(format!(
                    "Failed to verify column existence '{}': {:?}",
                    column, e
                ))
            })?;

            if !column_exists.0 {
                return Err(DBListenerError::ListenerVerifyError(format!(
                    "Column '{}' does not exist",
                    column
                )));
            }
        }

        let table_identifier_exists = sqlx::query_as::<_, (bool,)>(
            r#"
            SELECT EXISTS (
                SELECT 1
                FROM information_schema.columns
                WHERE table_name = $1
                AND column_name = $2
            );
            "#,
        )
        .bind(&self.table_name)
        .bind(&self.table_identifier)
        .fetch_one(&*self.pool)
        .await
        .map_err(|e| {
            DBListenerError::ListenerVerifyError(format!(
                "Failed to verify table identifier '{}': {:?}",
                &self.table_identifier, e
            ))
        })?;

        if !table_identifier_exists.0 {
            return Err(DBListenerError::ListenerVerifyError(format!(
                "Table identifier '{}' does not exist",
                &self.table_identifier
            )));
        }

        info!("✅ Table and columns verified successfully");
        Ok(())
    }
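
    /// Creates (or replaces) the PL/pgSQL notification function and installs an
    /// AFTER trigger for the configured events. Each watched column gets its own
    /// `pg_notify` block, so a single row change can emit one payload per
    /// affected column on this listener's channel.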
    async fn create_trigger_and_function(&self) -> Result<(), DBListenerError> {
        let table_identifier = &self.table_identifier;

        let create_function = format!(
            r#"
            CREATE OR REPLACE FUNCTION {function_name}()
            RETURNS TRIGGER AS $$
            BEGIN
                IF TG_OP = 'INSERT' THEN
                    {insert_blocks}
                ELSIF TG_OP = 'UPDATE' THEN
                    {update_blocks}
                ELSIF TG_OP = 'DELETE' THEN
                    {delete_blocks}
                END IF;
                RETURN NEW;
            END;
            $$ LANGUAGE plpgsql;
            "#,
            function_name = self.pg_function_name,
            insert_blocks = self
                .columns
                .iter()
                .map(|col| {
                    format!(
                        r#"
                        IF NEW.{col} IS NOT NULL THEN
                            PERFORM pg_notify(
                                '{column_updates}',
                                json_build_object(
                                    'operation', TG_OP,
                                    'table', TG_TABLE_NAME,
                                    'column', '{col}',
                                    'id', NEW.{table_identifier},
                                    'new_value', NEW.{col},
                                    'timestamp', NOW(),
                                    'new_row_data', row_to_json(NEW)
                                )::text
                            );
                        END IF;
                        "#,
                        column_updates = self.pg_column_updates_name
                    )
                })
                .collect::<Vec<_>>()
                .join("\n"),
            update_blocks = self
                .columns
                .iter()
                .map(|col| {
                    format!(
                        r#"
                        IF NEW.{col} IS DISTINCT FROM OLD.{col} THEN
                            PERFORM pg_notify(
                                '{column_updates}',
                                json_build_object(
                                    'operation', TG_OP,
                                    'table', TG_TABLE_NAME,
                                    'column', '{col}',
                                    'id', NEW.{table_identifier},
                                    'old_value', OLD.{col},
                                    'new_value', NEW.{col},
                                    'timestamp', NOW(),
                                    'old_row_data', row_to_json(OLD),
                                    'new_row_data', row_to_json(NEW)
                                )::text
                            );
                        END IF;
                        "#,
                        column_updates = self.pg_column_updates_name,
                    )
                })
                .collect::<Vec<_>>()
                .join("\n"),
            delete_blocks = self
                .columns
                .iter()
                .map(|col| {
                    format!(
                        r#"
                        IF OLD.{col} IS NOT NULL THEN
                            PERFORM pg_notify(
                                '{column_updates}',
                                json_build_object(
                                    'operation', TG_OP,
                                    'table', TG_TABLE_NAME,
                                    'column', '{col}',
                                    'id', OLD.{table_identifier},
                                    'old_value', OLD.{col},
                                    'timestamp', NOW(),
                                    'old_row_data', row_to_json(OLD)
                                )::text
                            );
                        END IF;
                        "#,
                        column_updates = self.pg_column_updates_name,
                    )
                })
                .collect::<Vec<_>>()
                .join("\n"),
        );

        self.pool
            .execute(create_function.as_str())
            .await
            .map_err(|e| {
                DBListenerError::CreationError(format!(
                    "Failed to execute function creation: {:#?}",
                    e
                ))
            })?;

        let events_list = self.get_events_list();

        let create_trigger = format!(
            r#"
            DO $$
            BEGIN
                IF NOT EXISTS (
                    SELECT 1
                    FROM pg_trigger
                    WHERE tgname = '{trigger_name}'
                ) THEN
                    CREATE TRIGGER {trigger_name}
                    AFTER {events_list} ON {table_name}
                    FOR EACH ROW
                    EXECUTE FUNCTION {function_name}();
                END IF;
            END $$;
            "#,
            trigger_name = self.pg_trigger_name,
            events_list = events_list,
            table_name = self.table_name,
            function_name = self.pg_function_name
        );

        self.pool
            .execute(create_trigger.as_str())
            .await
            .map_err(|e| {
                DBListenerError::CreationError(format!(
                    "Failed to execute trigger creation: {:#?}",
                    e
                ))
            })?;

        Ok(())
    }
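
    /// Renders the configured events as the trigger's event list, for example
    /// `INSERT OR UPDATE OF col_a, col_b OR DELETE`.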
    fn get_events_list(&self) -> String {
        self.events
            .iter()
            .map(|event| match event {
                EventType::INSERT => "INSERT".to_string(),
                EventType::UPDATE => format!("UPDATE OF {}", self.columns.join(", ")),
                EventType::DELETE => "DELETE".to_string(),
            })
            .collect::<Vec<_>>()
            .join(" OR ")
    }
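
    /// Drops this listener's trigger and notification function if they exist,
    /// leaving the table itself untouched.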
    async fn drop_trigger_and_function(&self) -> Result<(), DBListenerError> {
        let trigger_name = &self.pg_trigger_name;
        let function_name = &self.pg_function_name;
        let table_name = &self.table_name;

        let drop_trigger = format!(
            r#"
            DO $$
            BEGIN
                IF EXISTS (
                    SELECT 1
                    FROM pg_trigger
                    WHERE tgname = '{trigger_name}'
                ) THEN
                    DROP TRIGGER {trigger_name} ON {table_name};
                END IF;
            END $$;
            "#,
            trigger_name = trigger_name,
            table_name = table_name
        );

        self.pool
            .execute(drop_trigger.as_str())
            .await
            .map_err(|e| {
                DBListenerError::DeletionError(format!(
                    "Failed to execute trigger deletion: {:#?}",
                    e
                ))
            })?;

        let drop_function = format!(
            r#"
            DO $$
            BEGIN
                IF EXISTS (
                    SELECT 1
                    FROM pg_proc
                    WHERE proname = '{function_name}'
                ) THEN
                    DROP FUNCTION {function_name}();
                END IF;
            END $$;
            "#,
            function_name = function_name
        );

        self.pool
            .execute(drop_function.as_str())
            .await
            .map_err(|e| {
                DBListenerError::DeletionError(format!(
                    "Failed to execute function deletion: {:#?}",
                    e
                ))
            })?;

        info!("Trigger and function removed for table: {}", table_name);
        Ok(())
    }
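
    /// Installs the trigger and function, then opens a `PgListener` subscribed
    /// to this listener's notification channel.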
    async fn initialize_listener(&self) -> Result<Arc<Mutex<PgListener>>, DBListenerError> {
        self.create_trigger_and_function()
            .await
            .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;

        let listener = PgListener::connect_with(&self.pool)
            .await
            .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;

        let listener = Arc::new(Mutex::new(listener));

        {
            let mut locked_listener = listener.lock().await;
            locked_listener
                .listen(&self.pg_column_updates_name)
                .await
                .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;
        }

        info!(
            "Listening for column update notifications on {}",
            self.table_name
        );

        Ok(listener)
    }
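
    /// Spawns a background task that forwards each parsed notification to the
    /// internal channel as JSON and exits if the underlying listener errors.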
    fn spawn_listener_task(&self, listener: Arc<Mutex<PgListener>>) -> JoinHandle<()> {
        let sender_clone = self.sender.clone();
        let table_name = self.table_name.clone();

        tokio::spawn(async move {
            info!("Listener spawned and waiting for notifications");

            loop {
                let mut locked_listener = listener.lock().await;
                match locked_listener.recv().await {
                    Ok(notification) => {
                        if let Some(pg_notify) = process_notification(&notification, &table_name) {
                            if let Ok(json_data) = serde_json::to_value(pg_notify) {
                                if let Err(e) = sender_clone.send(json_data).await {
                                    error!("Failed to send payload to channel: {:?}", e);
                                }
                            } else {
                                error!("Failed to serialize PgNotify to JSON");
                            }
                        }
                    }
                    Err(e) => {
                        error!("Listener encountered an error: {:?}", e);
                        break;
                    }
                }
            }
        })
    }
}
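
/// `DBListenerTrait` implementation: `start` installs the trigger, spawns the
/// forwarding task, and returns the shared receiver with the task handle;
/// `stop` drops the trigger and function again.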
#[async_trait]
impl DBListenerTrait for PostgresTableListener {
    async fn start(
        &self,
    ) -> Result<(Arc<Mutex<Receiver<Value>>>, JoinHandle<()>), DBListenerError> {
        let listener = self.initialize_listener().await?;

        let handle = self.spawn_listener_task(listener);

        Ok((Arc::clone(&self.receiver), handle))
    }

    async fn stop(&self) -> Result<(), DBListenerError> {
        self.drop_trigger_and_function()
            .await
            .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;
        Ok(())
    }
}
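
/// Parses a `pg_notify` payload into a `PgNotify`, logging and returning `None`
/// if the payload is not valid JSON.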
fn process_notification(notification: &PgNotification, table_name: &str) -> Option<PgNotify> {
    match serde_json::from_str::<Value>(notification.payload()) {
        Ok(payload) => Some(PgNotify {
            operation: payload
                .get("operation")
                .and_then(|v| v.as_str().map(String::from))
                .unwrap_or_default(),
            table: table_name.to_string(),
            column: payload
                .get("column")
                .and_then(|v| v.as_str().map(String::from))
                .unwrap_or_default(),
            id: payload.get("id").map(|v| v.to_string()).unwrap_or_default(),
            new_row_data: payload
                .get("new_row_data")
                .cloned()
                .unwrap_or(Value::Null),
            old_row_data: payload
                .get("old_row_data")
                .cloned()
                .unwrap_or(Value::Null),
            timestamp: chrono::Utc::now().to_rfc3339(),
        }),
        Err(e) => {
            error!("Failed to parse notification payload: {:?}", e);
            None
        }
    }
}
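
/// Shape of the JSON payload forwarded to consumers for each notification.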
#[derive(Debug, Deserialize, Serialize)]
pub struct PgNotify {
    pub operation: String,
    pub table: String,
    pub id: String,
    pub column: String,
    pub new_row_data: Value,
    pub old_row_data: Value,
    pub timestamp: String,
}
#[cfg(test)]
mod tests {
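    //! Integration tests: they expect a reachable Postgres database at
    //! `POSTGRES_DATABASE_URL` containing a `swaps` table with the columns
    //! exercised below.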
    use std::{env, sync::Arc};
    use tokio::time::{sleep, Duration};

    use dotenv::dotenv;
    use sqlx::Executor;

    use crate::{
        database::{
            postgres::{get_or_create_pool, PostgresTableListener},
            DBListenerTrait,
        },
        EventType,
    };
    #[tokio::test]
    async fn create_new_listener_with_props() {
        dotenv().ok();
        let database_url =
            env::var("POSTGRES_DATABASE_URL").expect("POSTGRES_DATABASE_URL must be set");

        let table_name = "swaps".to_string();
        let columns = vec![
            "initiate_tx_hash".to_string(),
            "redeem_tx_hash".to_string(),
            "refund_tx_hash".to_string(),
        ];

        let table_identifier = "swap_id".to_string();

        let events = vec![EventType::UPDATE, EventType::INSERT, EventType::DELETE];

        let result = PostgresTableListener::new(
            &database_url,
            &table_name,
            columns,
            &table_identifier,
            events,
        )
        .await;

        assert!(result.is_ok(), "Listener failed to connect");

        sleep(Duration::from_secs(1)).await;
    }
    #[tokio::test]
    async fn create_new_listener_with_invalid_props() {
        dotenv().ok();
        let database_url =
            env::var("POSTGRES_DATABASE_URL").expect("POSTGRES_DATABASE_URL must be set");

        let table_name = "atomic_swaps".to_string();
        let columns = vec![
            "initiate_tx_hash".to_string(),
            "redeem_tx_hash".to_string(),
            "refund_tx_hash".to_string(),
        ];

        let table_identifier = "swap_id".to_string();

        let events = vec![EventType::UPDATE, EventType::INSERT, EventType::DELETE];

        let result = PostgresTableListener::new(
            &database_url,
            &table_name,
            columns,
            &table_identifier,
            events,
        )
        .await;

        assert!(
            result.is_err(),
            "Expected listener creation to fail for a non-existent table"
        );

        sleep(Duration::from_secs(1)).await;
    }
    #[tokio::test]
    async fn get_same_pool_for_same_url() {
        dotenv().ok();
        let database_url =
            env::var("POSTGRES_DATABASE_URL").expect("POSTGRES_DATABASE_URL must be set");

        let pool1 = get_or_create_pool(&database_url).await.unwrap();

        let pool2 = get_or_create_pool(&database_url).await.unwrap();

        assert!(
            Arc::ptr_eq(&pool1, &pool2),
            "Expected the same pool instance, but got different ones"
        );

        sleep(Duration::from_secs(1)).await;
    }
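
    /// End-to-end check: start the listener, run an INSERT, UPDATE, and DELETE
    /// against the `swaps` table, and assert that three notifications arrive.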
    #[tokio::test]
    async fn postgres_table_listener() {
        sleep(Duration::from_secs(1)).await;

        dotenv().ok();

        let database_url =
            env::var("POSTGRES_DATABASE_URL").expect("POSTGRES_DATABASE_URL must be set");

        let table_name = "swaps".to_string();
        let columns = vec![
            "initiate_tx_hash".to_string(),
            "redeem_tx_hash".to_string(),
            "refund_tx_hash".to_string(),
        ];

        let table_identifier = "swap_id".to_string();

        let events = vec![EventType::UPDATE, EventType::INSERT, EventType::DELETE];

        let postgres_table_listener = PostgresTableListener::new(
            &database_url,
            &table_name,
            columns.clone(),
            &table_identifier,
            events.clone(),
        )
        .await;

        assert!(
            postgres_table_listener.is_ok(),
            "Failed to initialize postgres table listener"
        );

        let postgres_table_listener = postgres_table_listener.unwrap();

        let (rx, handle) = postgres_table_listener.start().await.unwrap();

        let notification_task = tokio::spawn(async move {
            let mut received_events = Vec::new();
            while let Some(payload) = rx.lock().await.recv().await {
                println!("Notification received: {:#?}", payload);
                received_events.push(payload);
                if received_events.len() >= 3 {
                    break;
                }
            }
            received_events
        });

        let pool = sqlx::PgPool::connect(&database_url)
            .await
            .expect("Failed to connect to DB");

        async fn execute_query(pool: &sqlx::PgPool, query: &str) {
            pool.execute(query).await.expect("Query execution failed");
        }

        execute_query(
            &pool,
            &format!(
                "INSERT INTO {} (id, initiate_tx_hash, redeem_tx_hash, refund_tx_hash) VALUES (1, 'tx1', 'tx2', 'tx3')",
                table_name
            ),
        )
        .await;

        sleep(Duration::from_millis(100)).await;

        execute_query(
            &pool,
            &format!(
                "UPDATE {} SET redeem_tx_hash = 'updated_tx2' WHERE id = 1",
                table_name
            ),
        )
        .await;

        sleep(Duration::from_millis(100)).await;

        execute_query(&pool, &format!("DELETE FROM {} WHERE id = 1", table_name)).await;

        sleep(Duration::from_secs(2)).await;

        let received_events = notification_task.await.unwrap();

        assert_eq!(
            received_events.len(),
            3,
            "Expected 3 events but received {}",
            received_events.len()
        );

        postgres_table_listener.stop().await.unwrap();
        handle.abort();
    }
}
770}