// db_library/database/postgres.rs

//! PostgreSQL Table Listener
//!
//! This module provides `PostgresTableListener`, which listens for changes in specified database tables.
//! It uses PostgreSQL triggers and `LISTEN/NOTIFY` to capture and relay events asynchronously.
//!
//! # Features
//! - Monitors specified tables for `INSERT`, `UPDATE`, and `DELETE` operations.
//! - Captures changes at the column level: each watched column produces its own notification.
//! - Uses PostgreSQL triggers and functions for efficient event detection.
//! - Sends notifications through async channels.
//! - Provides a structured API to start and stop listeners.
//!
//! # Dependencies
//! - `sqlx` for database interaction
//! - `tokio` for async execution
//! - `serde` and `serde_json` for JSON serialization
//! - `tracing` for logging

19use async_trait::async_trait;
20use once_cell::sync::Lazy;
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23use sqlx::{
24    postgres::{PgListener, PgNotification, PgPoolOptions},
25    Executor, PgPool,
26};
27use std::{collections::HashMap, sync::Arc, time::Duration};
28use tokio::{
29    sync::{
30        mpsc::{self, Receiver, Sender},
31        Mutex, RwLock,
32    },
33    task::JoinHandle,
34};
35use tracing::{error, info};
36
37use crate::config::DBListenerError;
38
39use super::{DBListenerTrait, EventType};
40
41static PG_POOL_REGISTRY: Lazy<RwLock<HashMap<String, Arc<PgPool>>>> =
42    Lazy::new(|| RwLock::new(HashMap::new()));
43
44async fn get_or_create_pool(db_url: &str) -> Result<Arc<PgPool>, DBListenerError> {
45    // acquiring read lock first to check the existence of pool.
46    {
47        let pools = PG_POOL_REGISTRY.read().await;
48
49        if let Some(pool) = pools.get(db_url) {
50            return Ok(Arc::clone(pool));
51        }
52    }
53
54    let mut pools = PG_POOL_REGISTRY.write().await;
55
56    if let Some(pool) = pools.get(db_url) {
57        Ok(Arc::clone(pool))
58    } else {
59        let new_pool = PgPoolOptions::new()
60            .max_connections(10)
61            .acquire_timeout(Duration::from_secs(2))
62            .connect(db_url)
63            .await;
64
65        if let Err(e) = new_pool {
66            error!("Failed to connect to the database: {:?}", e);
67            return Err(DBListenerError::CreationError(format!(
68                "Failed to connect to the database url : {:#?}",
69                e
70            )));
71        }
72        let new_pool = Arc::new(new_pool.unwrap());
73        pools.insert(db_url.to_string(), Arc::clone(&new_pool));
74        Ok(new_pool)
75    }
76}
77
78#[derive(Debug, Clone)]
79pub struct PostgresTableListener {
80    pub pool: Arc<PgPool>,
81    pub table_name: String,
82    pub columns: Vec<String>,
83    pub sender: Sender<Value>,
84    pub receiver: Arc<Mutex<tokio::sync::mpsc::Receiver<Value>>>,
85    pub table_identifier: String,
86    pub pg_trigger_name: String,
87    pub pg_function_name: String,
88    pub pg_column_updates_name: String,
89    pub events: Vec<EventType>,
90}
91
92impl PostgresTableListener {
93    pub async fn new(
94        url: &str,
95        table_name: &str,
96        columns: Vec<String>,
97        table_identifier: &str,
98        events: Vec<EventType>,
99    ) -> Result<Self, DBListenerError> {
100        let pool = get_or_create_pool(url).await?;
101
102        let uniquekey_uuid = uuid::Uuid::new_v4().to_string();
103        let uniquekey = uniquekey_uuid.replace("-", "_");
104
105        let pg_trigger_name = format!("{}_{}_trigger", table_name, &uniquekey);
106        let pg_function_name = format!("{}_{}_function", table_name, &uniquekey);
107        let pg_column_updates_name = format!("{}_{}_column_updates", table_name, &uniquekey);
108
109        let (sender, receiver) = mpsc::channel::<Value>(100); // Channel with buffer size 100
110
111        let postgres_table_listener = Self {
112            pool,
113            table_name: table_name.to_string(),
114            columns: columns.into_iter().map(|c| c.to_string()).collect(),
115            table_identifier: table_identifier.to_string(),
116            sender,
117            receiver: Arc::new(Mutex::new(receiver)),
118            pg_trigger_name,
119            pg_function_name,
120            pg_column_updates_name,
121            events,
122        };
123
124        postgres_table_listener.verify_members().await?;
125
126        Ok(postgres_table_listener)
127    }
128
129    async fn verify_members(&self) -> Result<(), DBListenerError> {
130        info!("--> Verifying table and columns");
131        // Verify that the table exists
132        let table_exists = sqlx::query_as::<_, (bool,)>(
133            r#"
134            SELECT EXISTS (
135                SELECT 1
136                FROM information_schema.tables
137                WHERE table_name = $1
138            );
139            "#,
140        )
141        .bind(&self.table_name)
142        .fetch_one(&*self.pool)
143        .await
144        .map_err(|e| {
145            DBListenerError::ListenerVerifyError(format!(
146                "Failed to verify table existence : {:?}",
147                e
148            ))
149        })?;
150
151        info!("table exists : {:#?}", table_exists);
152
153        if !table_exists.0 {
154            return Err(DBListenerError::ListenerVerifyError(format!(
155                "Table '{}' does not exist",
156                &self.table_name
157            )));
158        }
159
160        // verify all columns exist
161        for column in &self.columns {
162            let column_exists = sqlx::query_as::<_, (bool,)>(
163                r#"
164                SELECT EXISTS (
165                    SELECT 1
166                    FROM information_schema.columns
167                    WHERE table_name = $1
168                    AND column_name = $2
169                );
170                "#,
171            )
172            .bind(&self.table_name)
173            .bind(column)
174            .fetch_one(&*self.pool)
175            .await
176            .map_err(|e| {
177                DBListenerError::ListenerVerifyError(format!(
178                    "Failed to verify column existence '{}': {:?}",
179                    column, e
180                ))
181            })?;
182
183            if !column_exists.0 {
184                return Err(DBListenerError::ListenerVerifyError(format!(
185                    "Column '{}' does not exist",
186                    column
187                )));
188            }
189        }
190
191        // Verify that the table identifier exists
192        let table_identifier_exists = sqlx::query_as::<_, (bool,)>(
193            r#"
194            SELECT EXISTS (
195                SELECT 1
196                FROM information_schema.columns
197                WHERE table_name = $1
198                AND column_name = $2
199            );
200            "#,
201        )
202        .bind(&self.table_name)
203        .bind(&self.table_identifier)
204        .fetch_one(&*self.pool)
205        .await
206        .map_err(|e| {
207            DBListenerError::ListenerVerifyError(format!(
208                "Failed to verify table identifier '{}': {:?}",
209                &self.table_identifier, e
210            ))
211        })?;
212
213        if !table_identifier_exists.0 {
214            return Err(DBListenerError::ListenerVerifyError(format!(
215                "Table identifier '{}' does not exist",
216                &self.table_identifier
217            )));
218        }
219
220        info!("✅ Table and columns verified successfully");
221        Ok(())
222    }
223
224    async fn create_trigger_and_function(&self) -> Result<(), DBListenerError> {
225        let table_identifier = &self.table_identifier;
226
227        // Create the function with clean INSERT handling
228        let create_function = format!(
229            r#"
230                CREATE OR REPLACE FUNCTION {function_name}()
231                RETURNS TRIGGER AS $$
232                BEGIN
233                    IF TG_OP = 'INSERT' THEN 
234                        {insert_blocks}
235                    ELSIF TG_OP = 'UPDATE' THEN 
236                        {update_blocks}
237                    ELSIF TG_OP = 'DELETE' THEN
238                        {delete_blocks}
239                    END IF;
240                    RETURN NEW;
241                END;
242                $$ LANGUAGE plpgsql;
243            "#,
244            function_name = self.pg_function_name,
245            insert_blocks = self
246                .columns
247                .iter()
248                .map(|col| {
249                    format!(
250                        r#"
251                        IF NEW.{col} IS NOT NULL THEN
252                            PERFORM pg_notify(
253                                '{column_updates}',
254                                json_build_object(
255                                    'operation', TG_OP,
256                                    'table', TG_TABLE_NAME,
257                                    'column', '{col}',
258                                    'id', NEW.{table_identifier},
259                                    'new_value', NEW.{col},
260                                    'timestamp', NOW(),
261                                    'new_row_data', row_to_json(NEW)
262                                )::text
263                            );
264                        END IF;
265                    "#,
266                        column_updates = self.pg_column_updates_name
267                    )
268                })
269                .collect::<Vec<_>>()
270                .join("\n"),
271            update_blocks = self
272                .columns
273                .iter()
274                .map(|col| {
275                    format!(
276                        r#"
277                        IF NEW.{col} IS DISTINCT FROM OLD.{col} THEN
278                            PERFORM pg_notify(
279                                '{column_updates}',
280                                json_build_object(
281                                    'operation', TG_OP,
282                                    'table', TG_TABLE_NAME,
283                                    'column', '{col}',
284                                    'id', NEW.{table_identifier},
285                                    'old_value', OLD.{col},
286                                    'new_value', NEW.{col},
287                                    'timestamp', NOW(),
288                                    'old_row_data', row_to_json(OLD),
289                                    'new_row_data', row_to_json(NEW)
290                                )::text
291                            );
292                        END IF;
293                    "#,
294                        column_updates = self.pg_column_updates_name,
295                    )
296                })
297                .collect::<Vec<_>>()
298                .join("\n"),
299            delete_blocks = self
300                .columns
301                .iter()
302                .map(|col| {
303                    format!(
304                        r#"
305                        IF OLD.{col} IS NOT NULL THEN
306                            PERFORM pg_notify(
307                                '{column_updates}',
308                                json_build_object(
309                                    'operation', TG_OP,
310                                    'table', TG_TABLE_NAME,
311                                    'column', '{col}',
312                                    'id', OLD.{table_identifier},
313                                    'old_value', OLD.{col},
314                                    'timestamp', NOW(),
315                                    'old_row_data', row_to_json(OLD)
316                                )::text
317                            );
318                        END IF;
319                    "#,
320                        column_updates = self.pg_column_updates_name,
321                    )
322                })
323                .collect::<Vec<_>>()
324                .join("\n"),
325        );
326
327        // Execute function creation
328        self.pool
329            .execute(create_function.as_str())
330            .await
331            .map_err(|e| {
332                DBListenerError::CreationError(format!(
333                    "Failed to execute function creation : {:#?}",
334                    e
335                ))
336            })?;
337
338        let events_list = self.get_events_list();
339
340        let create_trigger = format!(
341            r#"
342                DO $$ 
343                BEGIN 
344                    IF NOT EXISTS (
345                        SELECT 1 
346                        FROM pg_trigger 
347                        WHERE tgname = '{trigger_name}'
348                    ) THEN 
349                        CREATE TRIGGER {trigger_name}
350                        AFTER {events_list} ON {table_name} 
351                        FOR EACH ROW
352                        EXECUTE FUNCTION {function_name}();
353                    END IF;
354                END $$;
355            "#,
356            trigger_name = self.pg_trigger_name,
357            events_list = events_list,
358            table_name = self.table_name,
359            function_name = self.pg_function_name
360        );
361
362        // Execute trigger creation
363        self.pool
364            .execute(create_trigger.as_str())
365            .await
366            .map_err(|e| {
367                DBListenerError::CreationError(format!(
368                    "Failed to execute trigger creation : {:#?}",
369                    e
370                ))
371            })?;
372
373        Ok(())
374    }
375
376    fn get_events_list(&self) -> String {
377        self.events
378            .iter()
379            .map(|event| match event {
380                EventType::INSERT => "INSERT".to_string(),
381                EventType::UPDATE => format!("UPDATE OF {}", self.columns.join(", ")),
382                EventType::DELETE => "DELETE".to_string(),
383            })
384            .collect::<Vec<_>>()
385            .join(" OR ")
386    }
387
388    async fn drop_trigger_and_function(&self) -> Result<(), DBListenerError> {
389        let trigger_name = &self.pg_trigger_name;
390        let function_name = &self.pg_function_name;
391        let table_name = &self.table_name;
392
393        // Drop the trigger if it exists
394        let drop_trigger = format!(
395            r#"
396            DO $$
397            BEGIN
398                IF EXISTS (
399                    SELECT 1
400                    FROM pg_trigger
401                    WHERE tgname = '{trigger_name}'
402                ) THEN
403                    DROP TRIGGER {trigger_name} ON {table_name};
404                END IF;
405            END $$;
406            "#,
407            trigger_name = trigger_name,
408            table_name = table_name
409        );
410
411        self.pool
412            .execute(drop_trigger.as_str())
413            .await
414            .map_err(|e| {
415                DBListenerError::DeletionError(format!(
416                    "Failed to execute trigger deletion : {:#?}",
417                    e
418                ))
419            })?;
420
421        // Drop the function if it exists
422        let drop_function = format!(
423            r#"
424            DO $$
425            BEGIN
426                IF EXISTS (
427                    SELECT 1
428                    FROM pg_proc
429                    WHERE proname = '{function_name}'
430                ) THEN
431                    DROP FUNCTION {function_name}();
432                END IF;
433            END $$;
434            "#,
435            function_name = function_name
436        );
437
438        self.pool
439            .execute(drop_function.as_str())
440            .await
441            .map_err(|e| {
442                DBListenerError::DeletionError(format!(
443                    "Failed to execute trigger deletion : {:#?}",
444                    e
445                ))
446            })?;
447
448        info!("Trigger and function removed for table: {}", table_name);
449        Ok(())
450    }
451
452    async fn initialize_listener(&self) -> Result<Arc<Mutex<PgListener>>, DBListenerError> {
453        // Create trigger and function if they don't exist
454        self.create_trigger_and_function()
455            .await
456            .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;
457
458        // Initialize and configure the listener
459        let listener = PgListener::connect_with(&self.pool)
460            .await
461            .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;
462
463        let listener = Arc::new(Mutex::new(listener));
464
465        // Start listening for notifications
466        {
467            let mut locked_listener = listener.lock().await;
468            locked_listener
469                .listen(&self.pg_column_updates_name)
470                .await
471                .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;
472        }
473
474        info!(
475            "Listening for column update notifications on {}",
476            self.table_name
477        );
478
479        Ok(listener)
480    }
481
482    fn spawn_listener_task(&self, listener: Arc<Mutex<PgListener>>) -> JoinHandle<()> {
483        let sender_clone = self.sender.clone();
484        let table_name = self.table_name.clone();
485
486        tokio::spawn(async move {
487            info!("Listener spawned and waiting for notifications");
488
489            loop {
490                let mut locked_listener = listener.lock().await;
491                match locked_listener.recv().await {
492                    Ok(notification) => {
493                        if let Some(pg_notify) = process_notification(&notification, &table_name) {
494                            // Convert PgNotify to Value and send
495                            if let Ok(json_data) = serde_json::to_value(pg_notify) {
496                                if let Err(e) = sender_clone.send(json_data).await {
497                                    error!("Failed to send payload to channel: {:?}", e);
498                                }
499                            } else {
500                                error!("Failed to serialize PgNotify to JSON");
501                            }
502                        }
503                    }
504                    Err(e) => {
505                        error!("Listener encountered an error: {:?}", e);
506                        break;
507                    }
508                }
509            }
510        })
511    }
512}
513
514#[async_trait]
515impl DBListenerTrait for PostgresTableListener {
516    async fn start(
517        &self,
518    ) -> Result<(Arc<Mutex<Receiver<Value>>>, JoinHandle<()>), DBListenerError> {
519        let listener = self.initialize_listener().await?;
520
521        let handle = self.spawn_listener_task(listener);
522
523        Ok((Arc::clone(&self.receiver), handle))
524    }
525
526    async fn stop(&self) -> Result<(), DBListenerError> {
527        self.drop_trigger_and_function()
528            .await
529            .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;
530        Ok(())
531    }
532}
533
534fn process_notification(notification: &PgNotification, table_name: &str) -> Option<PgNotify> {
535    match serde_json::from_str::<Value>(&notification.payload()) {
536        Ok(payload) => Some(PgNotify {
537            operation: payload
538                .get("operation")
539                .and_then(|v| v.as_str().map(String::from))
540                .unwrap_or_default(),
541            table: table_name.to_string(),
542            column: payload
543                .get("column")
544                .and_then(|v| v.as_str().map(String::from))
545                .unwrap_or_default(),
546            id: payload.get("id").map(|v| v.to_string()).unwrap_or_default(),
547            new_row_data: payload
548                .get("new_row_data")
549                .cloned()
550                .unwrap_or_else(|| Value::Null),
551            old_row_data: payload
552                .get("old_row_data")
553                .cloned()
554                .unwrap_or_else(|| Value::Null),
555            timestamp: chrono::Utc::now().to_rfc3339(),
556        }),
557        Err(e) => {
558            error!("Failed to parse notification payload: {:?}", e);
559            None
560        }
561    }
562}
563
564#[derive(Debug, Deserialize, Serialize)]
565pub struct PgNotify {
566    pub operation: String,
567    pub table: String,
568    pub id: String,
569    pub column: String,
570    pub new_row_data: Value,
571    pub old_row_data: Value,
572    pub timestamp: String,
573}
574
#[cfg(test)]
mod tests {
    //! Integration tests. They need a PostgreSQL database reachable through
    //! `POSTGRES_DATABASE_URL` that has a `swaps` table and no `atomic_swaps` table.

    use std::{env, sync::Arc};
    use tokio::time::{sleep, timeout, Duration};

    use dotenv::dotenv;
    use sqlx::Executor;

    use crate::{
        database::{
            postgres::{get_or_create_pool, PostgresTableListener},
            DBListenerTrait,
        },
        EventType,
    };

    /// Loads `.env` (if present) and returns `POSTGRES_DATABASE_URL`.
    fn database_url() -> String {
        dotenv().ok();
        env::var("POSTGRES_DATABASE_URL").expect("POSTGRES_DATABASE_URL must be set")
    }

    /// The watched columns shared by the listener tests.
    fn tx_hash_columns() -> Vec<String> {
        ["initiate_tx_hash", "redeem_tx_hash", "refund_tx_hash"]
            .iter()
            .map(|c| c.to_string())
            .collect()
    }

    /// All trigger events.
    fn all_events() -> Vec<EventType> {
        vec![EventType::UPDATE, EventType::INSERT, EventType::DELETE]
    }

    #[tokio::test]
    async fn create_new_listener_with_props() {
        let database_url = database_url();

        let result = PostgresTableListener::new(
            &database_url,
            "swaps",
            tx_hash_columns(),
            "swap_id",
            all_events(),
        )
        .await;

        assert!(
            result.is_ok(),
            "Listener failed to connect: {:?}",
            result.as_ref().err()
        );

        sleep(Duration::from_secs(1)).await;
    }

    #[tokio::test]
    async fn create_new_listener_with_invalid_props() {
        let database_url = database_url();

        let result = PostgresTableListener::new(
            &database_url,
            "atomic_swaps",
            tx_hash_columns(),
            "swap_id",
            all_events(),
        )
        .await;

        assert!(
            result.is_err(),
            "Listener creation should fail for a non-existent table"
        );

        sleep(Duration::from_secs(1)).await;
    }

    #[tokio::test]
    async fn get_same_pool_for_same_url() {
        let database_url = database_url();

        let pool1 = get_or_create_pool(&database_url).await.unwrap();
        let pool2 = get_or_create_pool(&database_url).await.unwrap();

        assert!(
            Arc::ptr_eq(&pool1, &pool2),
            "Expected the same pool instance, but got different ones"
        );

        sleep(Duration::from_secs(1)).await;
    }

    #[tokio::test]
    // #[ignore = "this should be tested alone.. as it requires client connection for the same db url"]
    async fn postgres_table_listener() {
        sleep(Duration::from_secs(1)).await;

        let database_url = database_url();
        let table_name = "swaps";

        let postgres_table_listener = PostgresTableListener::new(
            &database_url,
            table_name,
            tx_hash_columns(),
            "swap_id",
            all_events(),
        )
        .await
        .expect("Failed to initialize postgres table listener");

        let (rx, handle) = postgres_table_listener.start().await.unwrap();

        // The trigger emits one notification per non-null watched column, so the INSERT
        // below (which sets all three watched columns) can alone supply the first three.
        let notification_task = tokio::spawn(async move {
            let mut received_events = Vec::new();
            while let Some(payload) = rx.lock().await.recv().await {
                println!("Notification received: {:#?}", payload);
                received_events.push(payload);
                if received_events.len() >= 3 {
                    break; // Stop after receiving all expected events
                }
            }
            received_events
        });

        let pool = sqlx::PgPool::connect(&database_url)
            .await
            .expect("Failed to connect to DB");

        async fn execute_query(pool: &sqlx::PgPool, query: &str) {
            pool.execute(query).await.expect("Query execution failed");
        }

        execute_query(
            &pool,
            &format!(
                "INSERT INTO {} (id, initiate_tx_hash, redeem_tx_hash, refund_tx_hash) VALUES (1, 'tx1', 'tx2', 'tx3')",
                table_name
            ),
        )
        .await;

        sleep(Duration::from_millis(100)).await;

        execute_query(
            &pool,
            &format!(
                "UPDATE {} SET redeem_tx_hash = 'updated_tx2' WHERE id = 1",
                table_name
            ),
        )
        .await;

        sleep(Duration::from_millis(100)).await;

        execute_query(&pool, &format!("DELETE FROM {} WHERE id = 1", table_name)).await;

        // Bound the wait so missing notifications fail the test instead of hanging it.
        let outcome = timeout(Duration::from_secs(5), notification_task).await;

        // Clean up before asserting so a failed assertion does not leave the trigger
        // installed on the table.
        postgres_table_listener.stop().await.unwrap();
        handle.abort();

        let received_events = outcome
            .expect("Timed out waiting for notifications")
            .expect("Notification task panicked");

        assert_eq!(
            received_events.len(),
            3,
            "Expected 3 events but received {}",
            received_events.len()
        );
    }
}
770}