1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use sqlx::{postgres::PgListener, Executor, PgPool};
5use std::sync::Arc;
6use tokio::{
7 sync::{
8 mpsc::{self, Receiver, Sender},
9 Mutex,
10 },
11 task::JoinHandle,
12};
13use tracing::{error, info};
14
15use crate::config::DBListenerError;
16use crate::utils::stringify;
17
18use super::{DBListenerTrait, EventType};
19
20#[derive(Debug, Clone)]
21pub struct PostgresTableListener {
22 pub pool: Arc<PgPool>,
23 pub table_name: String,
24 pub columns: Vec<String>,
25 pub sender: Sender<Value>,
26 pub receiver: Arc<Mutex<tokio::sync::mpsc::Receiver<Value>>>,
27 pub table_identifier: String,
28 pub pg_trigger_name: String,
29 pub pg_function_name: String,
30 pub pg_column_updates_name: String,
31 pub events: Vec<EventType>,
32}
33
34impl PostgresTableListener {
35 pub async fn new(
36 pool: Arc<PgPool>,
37 table_name: &str,
38 columns: Vec<String>,
39 table_identifier: &str,
40 events: Vec<EventType>,
41 ) -> Result<Self, DBListenerError> {
42 let uniquekey_uuid = uuid::Uuid::new_v4().to_string();
43 let uniquekey = uniquekey_uuid.replace("-", "_");
44
45 let pg_trigger_name = format!("{}_{}_trigger", table_name, &uniquekey);
46 let pg_function_name = format!("{}_{}_function", table_name, &uniquekey);
47 let pg_column_updates_name = format!("{}_{}_column_updates", table_name, &uniquekey);
48
49 let (sender, receiver) = mpsc::channel::<Value>(100); let postgres_table_listener = Self {
52 pool,
53 table_name: table_name.to_string(),
54 columns: columns.into_iter().map(|c| c.to_string()).collect(),
55 table_identifier: table_identifier.to_string(),
56 sender,
57 receiver: Arc::new(Mutex::new(receiver)),
58 pg_trigger_name,
59 pg_function_name,
60 pg_column_updates_name,
61 events,
62 };
63
64 postgres_table_listener.verify_members().await?;
65
66 Ok(postgres_table_listener)
67 }
68
69 async fn verify_members(&self) -> Result<(), DBListenerError> {
70 info!("--> Verifying table and columns");
71 let table_exists = sqlx::query_as::<_, (bool,)>(
73 r#"
74 SELECT EXISTS (
75 SELECT 1
76 FROM information_schema.tables
77 WHERE table_name = $1
78 );
79 "#,
80 )
81 .bind(&self.table_name)
82 .fetch_one(&*self.pool)
83 .await
84 .map_err(|e| {
85 DBListenerError::ListenerVerifyError(format!(
86 "Failed to verify table existence : {:?}",
87 e
88 ))
89 })?;
90
91 info!("table exists : {:#?}", table_exists);
92
93 if !table_exists.0 {
94 return Err(DBListenerError::ListenerVerifyError(format!(
95 "Table '{}' does not exist",
96 &self.table_name
97 )));
98 }
99
100 for column in &self.columns {
102 let column_exists = sqlx::query_as::<_, (bool,)>(
103 r#"
104 SELECT EXISTS (
105 SELECT 1
106 FROM information_schema.columns
107 WHERE table_name = $1
108 AND column_name = $2
109 );
110 "#,
111 )
112 .bind(&self.table_name)
113 .bind(column)
114 .fetch_one(&*self.pool)
115 .await
116 .map_err(|e| {
117 DBListenerError::ListenerVerifyError(format!(
118 "Failed to verify column existence '{}': {:?}",
119 column, e
120 ))
121 })?;
122
123 if !column_exists.0 {
124 return Err(DBListenerError::ListenerVerifyError(format!(
125 "Column '{}' does not exist",
126 column
127 )));
128 }
129 }
130
131 let table_identifier_exists = sqlx::query_as::<_, (bool,)>(
133 r#"
134 SELECT EXISTS (
135 SELECT 1
136 FROM information_schema.columns
137 WHERE table_name = $1
138 AND column_name = $2
139 );
140 "#,
141 )
142 .bind(&self.table_name)
143 .bind(&self.table_identifier)
144 .fetch_one(&*self.pool)
145 .await
146 .map_err(|e| {
147 DBListenerError::ListenerVerifyError(format!(
148 "Failed to verify table identifier '{}': {:?}",
149 &self.table_identifier, e
150 ))
151 })?;
152
153 if !table_identifier_exists.0 {
154 return Err(DBListenerError::ListenerVerifyError(format!(
155 "Table identifier '{}' does not exist",
156 &self.table_identifier
157 )));
158 }
159
160 info!("✅ Table and columns verified successfully");
161 Ok(())
162 }
163
164 async fn create_trigger_and_function(&self) -> Result<(), DBListenerError> {
165 let table_identifier = &self.table_identifier;
166
167 let create_function = format!(
169 r#"
170 CREATE OR REPLACE FUNCTION {function_name}()
171 RETURNS TRIGGER AS $$
172 BEGIN
173 IF TG_OP = 'INSERT' THEN
174 {insert_blocks}
175 ELSIF TG_OP = 'UPDATE' THEN
176 {update_blocks}
177 ELSIF TG_OP = 'DELETE' THEN
178 {delete_blocks}
179 END IF;
180 RETURN NEW;
181 END;
182 $$ LANGUAGE plpgsql;
183 "#,
184 function_name = self.pg_function_name,
185 insert_blocks = self
186 .columns
187 .iter()
188 .map(|col| {
189 format!(
190 r#"
191 IF NEW.{col} IS NOT NULL THEN
192 PERFORM pg_notify(
193 '{column_updates}',
194 json_build_object(
195 'operation', TG_OP,
196 'table', TG_TABLE_NAME,
197 'column', '{col}',
198 'id', NEW.{table_identifier},
199 'new_value', NEW.{col},
200 'timestamp', NOW()
201 )::text
202 );
203 END IF;
204 "#,
205 column_updates = self.pg_column_updates_name
206 )
207 })
208 .collect::<Vec<_>>()
209 .join("\n"),
210 update_blocks = self
211 .columns
212 .iter()
213 .map(|col| {
214 format!(
215 r#"
216 IF NEW.{col} IS DISTINCT FROM OLD.{col} THEN
217 PERFORM pg_notify(
218 '{column_updates}',
219 json_build_object(
220 'operation', TG_OP,
221 'table', TG_TABLE_NAME,
222 'column', '{col}',
223 'id', NEW.{table_identifier},
224 'old_value', OLD.{col},
225 'new_value', NEW.{col},
226 'timestamp', NOW()
227 )::text
228 );
229 END IF;
230 "#,
231 column_updates = self.pg_column_updates_name,
232 )
233 })
234 .collect::<Vec<_>>()
235 .join("\n"),
236 delete_blocks = self
237 .columns
238 .iter()
239 .map(|col| {
240 format!(
241 r#"
242 IF OLD.{col} IS NOT NULL THEN
243 PERFORM pg_notify(
244 '{column_updates}',
245 json_build_object(
246 'operation', TG_OP,
247 'table', TG_TABLE_NAME,
248 'column', '{col}',
249 'id', OLD.{table_identifier},
250 'old_value', OLD.{col},
251 'timestamp', NOW()
252 )::text
253 );
254 END IF;
255 "#,
256 column_updates = self.pg_column_updates_name,
257 )
258 })
259 .collect::<Vec<_>>()
260 .join("\n"),
261 );
262
263 self.pool
265 .execute(create_function.as_str())
266 .await
267 .map_err(|e| {
268 DBListenerError::CreationError(format!(
269 "Failed to execute function creation : {:#?}",
270 e
271 ))
272 })?;
273
274 let events_list = EventType::get_events_list(&self.events, &self.columns.join(", "));
275
276 let create_trigger = format!(
278 r#"
279 DO $$
280 BEGIN
281 IF NOT EXISTS (
282 SELECT 1
283 FROM pg_trigger
284 WHERE tgname = '{trigger_name}'
285 ) THEN
286 CREATE TRIGGER {trigger_name}
287 AFTER {events_list} ON {table_name}
288 FOR EACH ROW
289 EXECUTE FUNCTION {function_name}();
290 END IF;
291 END $$;
292 "#,
293 trigger_name = self.pg_trigger_name,
294 table_name = self.table_name,
295 function_name = self.pg_function_name
296 );
297
298 self.pool
300 .execute(create_trigger.as_str())
301 .await
302 .map_err(|e| {
303 DBListenerError::CreationError(format!(
304 "Failed to execute trigger creation : {:#?}",
305 e
306 ))
307 })?;
308
309 Ok(())
310 }
311
312 async fn drop_trigger_and_function(&self) -> Result<(), DBListenerError> {
313 let trigger_name = &self.pg_trigger_name;
314 let function_name = &self.pg_function_name;
315 let table_name = &self.table_name;
316
317 let drop_trigger = format!(
319 r#"
320 DO $$
321 BEGIN
322 IF EXISTS (
323 SELECT 1
324 FROM pg_trigger
325 WHERE tgname = '{trigger_name}'
326 ) THEN
327 DROP TRIGGER {trigger_name} ON {table_name};
328 END IF;
329 END $$;
330 "#,
331 trigger_name = trigger_name,
332 table_name = table_name
333 );
334
335 self.pool
336 .execute(drop_trigger.as_str())
337 .await
338 .map_err(|e| {
339 DBListenerError::DeletionError(format!(
340 "Failed to execute trigger deletion : {:#?}",
341 e
342 ))
343 })?;
344
345 let drop_function = format!(
347 r#"
348 DO $$
349 BEGIN
350 IF EXISTS (
351 SELECT 1
352 FROM pg_proc
353 WHERE proname = '{function_name}'
354 ) THEN
355 DROP FUNCTION {function_name}();
356 END IF;
357 END $$;
358 "#,
359 function_name = function_name
360 );
361
362 self.pool
363 .execute(drop_function.as_str())
364 .await
365 .map_err(|e| {
366 DBListenerError::DeletionError(format!(
367 "Failed to execute trigger deletion : {:#?}",
368 e
369 ))
370 })?;
371
372 info!("Trigger and function removed for table: {}", table_name);
373 Ok(())
374 }
375}
376
377#[async_trait]
378impl DBListenerTrait for PostgresTableListener {
379 async fn start(
380 &self,
381 ) -> Result<(Arc<Mutex<Receiver<Value>>>, JoinHandle<()>), DBListenerError> {
382 self.create_trigger_and_function()
384 .await
385 .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;
386
387 let listener = PgListener::connect_with(&self.pool)
388 .await
389 .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;
390
391 let listener = Arc::new(Mutex::new(listener));
392
393 {
394 let mut locked_listener = listener.lock().await;
395 locked_listener
396 .listen(&self.pg_column_updates_name)
397 .await
398 .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;
399 }
400
401 info!(
402 "Listening for column update notifications on {}",
403 self.table_name
404 );
405
406 let listener_clone = Arc::clone(&listener);
407 let sender_clone = self.sender.clone();
408
409 let handle = tokio::spawn(async move {
410 info!("Listener spawned and waiting for notifications");
411
412 loop {
413 let mut locked_listener = listener_clone.lock().await;
414 match locked_listener.recv().await {
415 Ok(notification) => {
416 if let Ok(payload) = serde_json::from_str::<Value>(¬ification.payload())
417 {
418 if let Err(e) = sender_clone.send(payload).await {
419 error!("Failed to send payload to channel: {:?}", e);
420 }
421 }
422 }
423 Err(e) => {
424 error!("Listener encountered an error: {:?}", e);
425 break;
426 }
427 }
428 }
429 });
430
431 Ok((Arc::clone(&self.receiver), handle))
432 }
433
434 async fn stop(&self) -> Result<(), DBListenerError> {
435 self.drop_trigger_and_function()
436 .await
437 .map_err(|e| DBListenerError::ListenerError(e.to_string()))?;
438 Ok(())
439 }
440}
441
442#[derive(Debug, Deserialize, Serialize)]
443pub struct PgNotify {
444 pub operation: Option<String>,
445 pub table: String,
446
447 #[serde(deserialize_with = "stringify")]
448 pub id: String, pub old_value: Option<String>,
451 pub new_value: Option<String>,
452 pub timestamp: String,
453}