1use crate::error::{DbError, Result};
45use serde::{Deserialize, Serialize};
46use sqlx::{
47 postgres::{PgListener, PgNotification},
48 PgPool,
49};
50use std::collections::HashMap;
51use std::sync::Arc;
52use tokio::sync::{Mutex, RwLock};
53use tracing::{debug, error, info, warn};
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct DatabaseEvent {
58 pub channel: String,
60 pub event_type: String,
62 pub payload: serde_json::Value,
64 pub correlation_id: Option<String>,
66 pub timestamp: String,
68}
69
70impl DatabaseEvent {
71 pub fn new(
73 channel: impl Into<String>,
74 event_type: impl Into<String>,
75 payload: serde_json::Value,
76 ) -> Self {
77 Self {
78 channel: channel.into(),
79 event_type: event_type.into(),
80 payload,
81 correlation_id: None,
82 timestamp: chrono::Utc::now().to_rfc3339(),
83 }
84 }
85
86 pub fn with_correlation_id(mut self, correlation_id: impl Into<String>) -> Self {
88 self.correlation_id = Some(correlation_id.into());
89 self
90 }
91
92 pub fn from_notification(notification: &PgNotification) -> Result<Self> {
94 let payload_str = notification.payload();
95 serde_json::from_str(payload_str)
96 .map_err(|e| DbError::Validation(format!("Failed to parse event payload: {}", e)))
97 }
98
99 pub fn to_json_string(&self) -> Result<String> {
101 serde_json::to_string(self)
102 .map_err(|e| DbError::Validation(format!("Failed to serialize event: {}", e)))
103 }
104}
105
106pub type EventHandlerFn = Arc<dyn Fn(DatabaseEvent) + Send + Sync>;
108
109pub struct EventListener {
111 pool: PgPool,
113 listener: Arc<Mutex<PgListener>>,
115 handlers: Arc<RwLock<HashMap<String, Vec<EventHandlerFn>>>>,
117 subscriptions: Arc<RwLock<Vec<String>>>,
119 config: ListenerConfig,
121}
122
/// Tuning knobs for an `EventListener`.
#[derive(Debug, Clone)]
pub struct ListenerConfig {
    /// Maximum number of reconnection attempts; `0` means retry without limit.
    pub max_reconnect_attempts: usize,

    /// Pause between consecutive reconnection attempts, in milliseconds.
    pub reconnect_delay_ms: u64,

    /// Intended event buffer capacity.
    /// NOTE(review): not read anywhere in this module — confirm whether it is used elsewhere.
    pub event_buffer_size: usize,
}
133
134impl Default for ListenerConfig {
135 fn default() -> Self {
136 Self {
137 max_reconnect_attempts: 0, reconnect_delay_ms: 1000, event_buffer_size: 1000,
140 }
141 }
142}
143
144impl EventListener {
145 pub async fn new(pool: PgPool) -> Result<Self> {
147 Self::new_with_config(pool, ListenerConfig::default()).await
148 }
149
150 pub async fn new_with_config(pool: PgPool, config: ListenerConfig) -> Result<Self> {
152 let listener = PgListener::connect_with(&pool).await?;
153
154 info!("Created PostgreSQL event listener");
155
156 Ok(Self {
157 pool,
158 listener: Arc::new(Mutex::new(listener)),
159 handlers: Arc::new(RwLock::new(HashMap::new())),
160 subscriptions: Arc::new(RwLock::new(Vec::new())),
161 config,
162 })
163 }
164
165 pub async fn subscribe(&self, channel: &str) -> Result<()> {
167 let mut listener = self.listener.lock().await;
168 listener.listen(channel).await?;
169
170 let mut subs = self.subscriptions.write().await;
171 if !subs.contains(&channel.to_string()) {
172 subs.push(channel.to_string());
173 }
174
175 info!("Subscribed to channel: {}", channel);
176 Ok(())
177 }
178
179 pub async fn unsubscribe(&self, channel: &str) -> Result<()> {
181 let mut listener = self.listener.lock().await;
182 listener.unlisten(channel).await?;
183
184 let mut subs = self.subscriptions.write().await;
185 subs.retain(|c| c != channel);
186
187 info!("Unsubscribed from channel: {}", channel);
188 Ok(())
189 }
190
191 pub async fn on_event<F>(&self, channel: &str, handler: F)
193 where
194 F: Fn(DatabaseEvent) + Send + Sync + 'static,
195 {
196 let mut handlers = self.handlers.write().await;
197 handlers
198 .entry(channel.to_string())
199 .or_insert_with(Vec::new)
200 .push(Arc::new(handler));
201
202 debug!("Registered event handler for channel: {}", channel);
203 }
204
205 pub async fn listen(&self) -> Result<()> {
207 info!("Started listening for database events");
208
209 loop {
210 let notification = {
211 let mut listener = self.listener.lock().await;
212 match listener.try_recv().await {
213 Ok(Some(notif)) => notif,
214 Ok(None) => {
215 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
217 continue;
218 }
219 Err(e) => {
220 error!("Error receiving notification: {}", e);
221 if let Err(e) = self.reconnect().await {
223 error!("Failed to reconnect: {}", e);
224 }
225 continue;
226 }
227 }
228 };
229
230 self.handle_notification(notification).await;
231 }
232 }
233
234 async fn handle_notification(&self, notification: PgNotification) {
236 let channel = notification.channel();
237
238 let event = match DatabaseEvent::from_notification(¬ification) {
240 Ok(event) => event,
241 Err(e) => {
242 warn!("Failed to parse event from channel {}: {}", channel, e);
243 return;
244 }
245 };
246
247 debug!(
248 "Received event: channel={}, type={}",
249 event.channel, event.event_type
250 );
251
252 let handlers = self.handlers.read().await;
254 if let Some(channel_handlers) = handlers.get(channel) {
255 for handler in channel_handlers {
256 handler(event.clone());
257 }
258 }
259 }
260
261 async fn reconnect(&self) -> Result<()> {
263 warn!("Attempting to reconnect to PostgreSQL...");
264
265 let mut attempts = 0;
266 loop {
267 if self.config.max_reconnect_attempts > 0
268 && attempts >= self.config.max_reconnect_attempts
269 {
270 return Err(DbError::Connection(format!(
271 "Failed to reconnect after {} attempts",
272 attempts
273 )));
274 }
275
276 if attempts > 0 {
278 tokio::time::sleep(tokio::time::Duration::from_millis(
279 self.config.reconnect_delay_ms,
280 ))
281 .await;
282 }
283
284 match PgListener::connect_with(&self.pool).await {
286 Ok(new_listener) => {
287 let mut listener = self.listener.lock().await;
288 *listener = new_listener;
289
290 let subscriptions = self.subscriptions.read().await;
292 for channel in subscriptions.iter() {
293 if let Err(e) = listener.listen(channel).await {
294 error!("Failed to re-subscribe to {}: {}", channel, e);
295 }
296 }
297
298 info!("Successfully reconnected and re-subscribed");
299 return Ok(());
300 }
301 Err(e) => {
302 error!("Reconnection attempt {} failed: {}", attempts + 1, e);
303 attempts += 1;
304 }
305 }
306 }
307 }
308
309 pub async fn notify(&self, event: &DatabaseEvent) -> Result<()> {
311 let payload = event.to_json_string()?;
312
313 sqlx::query(&format!("NOTIFY {}, $1", event.channel))
314 .bind(&payload)
315 .execute(&self.pool)
316 .await?;
317
318 debug!("Published event to channel: {}", event.channel);
319 Ok(())
320 }
321
322 pub async fn get_subscriptions(&self) -> Vec<String> {
324 self.subscriptions.read().await.clone()
325 }
326
327 pub async fn get_handler_counts(&self) -> HashMap<String, usize> {
329 let handlers = self.handlers.read().await;
330 handlers
331 .iter()
332 .map(|(channel, handlers)| (channel.clone(), handlers.len()))
333 .collect()
334 }
335}
336
337#[derive(Debug, Clone, Serialize, Deserialize)]
339pub struct NotificationStats {
340 pub notifications_received: u64,
342 pub notifications_sent: u64,
344 pub active_subscriptions: usize,
346 pub handlers_per_channel: HashMap<String, usize>,
348}
349
/// Generators for SQL that installs row-level triggers publishing table changes
/// as notifications.
///
/// Each method takes the target table name and the notification channel, and
/// returns a SQL script as a `String`; nothing is executed.
pub trait NotificationTriggers {
    /// SQL for a trigger that fires after rows are inserted into `table`.
    fn create_insert_trigger_sql(table: &str, channel: &str) -> String;

    /// SQL for a trigger that fires after rows in `table` are updated.
    fn create_update_trigger_sql(table: &str, channel: &str) -> String;

    /// SQL for a trigger that fires after rows are deleted from `table`.
    fn create_delete_trigger_sql(table: &str, channel: &str) -> String;
}
361
/// PostgreSQL (PL/pgSQL + `pg_notify`) implementation of `NotificationTriggers`.
pub struct PostgresNotificationTriggers;
363
364impl NotificationTriggers for PostgresNotificationTriggers {
365 fn create_insert_trigger_sql(table: &str, channel: &str) -> String {
366 format!(
367 r#"
368CREATE OR REPLACE FUNCTION notify_{table}_insert()
369RETURNS TRIGGER AS $$
370BEGIN
371 PERFORM pg_notify(
372 '{channel}',
373 json_build_object(
374 'channel', '{channel}',
375 'event_type', '{table}.created',
376 'payload', row_to_json(NEW),
377 'timestamp', to_char(NOW(), 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
378 )::text
379 );
380 RETURN NEW;
381END;
382$$ LANGUAGE plpgsql;
383
384CREATE TRIGGER {table}_insert_notify
385AFTER INSERT ON {table}
386FOR EACH ROW
387EXECUTE FUNCTION notify_{table}_insert();
388 "#,
389 table = table,
390 channel = channel
391 )
392 }
393
394 fn create_update_trigger_sql(table: &str, channel: &str) -> String {
395 format!(
396 r#"
397CREATE OR REPLACE FUNCTION notify_{table}_update()
398RETURNS TRIGGER AS $$
399BEGIN
400 PERFORM pg_notify(
401 '{channel}',
402 json_build_object(
403 'channel', '{channel}',
404 'event_type', '{table}.updated',
405 'payload', json_build_object('old', row_to_json(OLD), 'new', row_to_json(NEW)),
406 'timestamp', to_char(NOW(), 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
407 )::text
408 );
409 RETURN NEW;
410END;
411$$ LANGUAGE plpgsql;
412
413CREATE TRIGGER {table}_update_notify
414AFTER UPDATE ON {table}
415FOR EACH ROW
416EXECUTE FUNCTION notify_{table}_update();
417 "#,
418 table = table,
419 channel = channel
420 )
421 }
422
423 fn create_delete_trigger_sql(table: &str, channel: &str) -> String {
424 format!(
425 r#"
426CREATE OR REPLACE FUNCTION notify_{table}_delete()
427RETURNS TRIGGER AS $$
428BEGIN
429 PERFORM pg_notify(
430 '{channel}',
431 json_build_object(
432 'channel', '{channel}',
433 'event_type', '{table}.deleted',
434 'payload', row_to_json(OLD),
435 'timestamp', to_char(NOW(), 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
436 )::text
437 );
438 RETURN OLD;
439END;
440$$ LANGUAGE plpgsql;
441
442CREATE TRIGGER {table}_delete_notify
443AFTER DELETE ON {table}
444FOR EACH ROW
445EXECUTE FUNCTION notify_{table}_delete();
446 "#,
447 table = table,
448 channel = channel
449 )
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_database_event_creation() {
        let body = serde_json::json!({"user_id": "123", "action": "login"});
        let event = DatabaseEvent::new("user_events", "user.login", body.clone());

        assert_eq!(event.channel, "user_events");
        assert_eq!(event.event_type, "user.login");
        assert_eq!(event.payload, body);
        assert_eq!(event.correlation_id, None);
        assert!(!event.timestamp.is_empty());
    }

    #[test]
    fn test_database_event_with_correlation_id() {
        let event = DatabaseEvent::new(
            "order_events",
            "order.created",
            serde_json::json!({"order_id": "456"}),
        )
        .with_correlation_id("trace-123-456");

        assert_eq!(event.correlation_id.as_deref(), Some("trace-123-456"));
    }

    #[test]
    fn test_database_event_serialization() {
        let event = DatabaseEvent::new(
            "token_events",
            "token.transfer",
            serde_json::json!({"token_id": "789", "amount": 1000}),
        );

        let encoded = event.to_json_string().unwrap();
        for needle in ["token_events", "token.transfer", "token_id"] {
            assert!(encoded.contains(needle));
        }
    }

    #[test]
    fn test_database_event_deserialization() {
        let raw = r#"{
            "channel": "trade_events",
            "event_type": "trade.executed",
            "payload": {"trade_id": "999", "price": 50000},
            "correlation_id": null,
            "timestamp": "2024-01-01T12:00:00Z"
        }"#;

        let decoded: DatabaseEvent = serde_json::from_str(raw).unwrap();
        assert_eq!(decoded.channel, "trade_events");
        assert_eq!(decoded.event_type, "trade.executed");
        assert_eq!(decoded.payload["trade_id"], "999");
    }

    #[test]
    fn test_listener_config_default() {
        let defaults = ListenerConfig::default();

        assert_eq!(defaults.max_reconnect_attempts, 0);
        assert_eq!(defaults.reconnect_delay_ms, 1000);
        assert_eq!(defaults.event_buffer_size, 1000);
    }

    #[test]
    fn test_listener_config_custom() {
        let custom = ListenerConfig {
            max_reconnect_attempts: 5,
            reconnect_delay_ms: 2000,
            event_buffer_size: 500,
        };

        assert_eq!(custom.max_reconnect_attempts, 5);
        assert_eq!(custom.reconnect_delay_ms, 2000);
        assert_eq!(custom.event_buffer_size, 500);
    }

    #[test]
    fn test_create_insert_trigger_sql() {
        let sql = PostgresNotificationTriggers::create_insert_trigger_sql("users", "user_events");

        for needle in [
            "CREATE OR REPLACE FUNCTION notify_users_insert()",
            "pg_notify",
            "user_events",
            "users.created",
            "CREATE TRIGGER users_insert_notify",
        ] {
            assert!(sql.contains(needle));
        }
    }

    #[test]
    fn test_create_update_trigger_sql() {
        let sql = PostgresNotificationTriggers::create_update_trigger_sql("tokens", "token_events");

        for needle in [
            "CREATE OR REPLACE FUNCTION notify_tokens_update()",
            "pg_notify",
            "token_events",
            "tokens.updated",
            "CREATE TRIGGER tokens_update_notify",
        ] {
            assert!(sql.contains(needle));
        }
    }

    #[test]
    fn test_create_delete_trigger_sql() {
        let sql = PostgresNotificationTriggers::create_delete_trigger_sql("orders", "order_events");

        for needle in [
            "CREATE OR REPLACE FUNCTION notify_orders_delete()",
            "pg_notify",
            "order_events",
            "orders.deleted",
            "CREATE TRIGGER orders_delete_notify",
        ] {
            assert!(sql.contains(needle));
        }
    }

    #[test]
    fn test_notification_stats_serialization() {
        let stats = NotificationStats {
            notifications_received: 1000,
            notifications_sent: 500,
            active_subscriptions: 5,
            handlers_per_channel: vec![("users".to_string(), 3), ("tokens".to_string(), 2)]
                .into_iter()
                .collect(),
        };

        let encoded = serde_json::to_string(&stats).unwrap();
        for needle in ["notifications_received", "1000", "handlers_per_channel"] {
            assert!(encoded.contains(needle));
        }
    }

    #[test]
    fn test_event_payload_types() {
        let cases = [
            ("test.string", serde_json::json!("simple string")),
            ("test.object", serde_json::json!({"key": "value", "number": 42})),
            ("test.array", serde_json::json!([1, 2, 3, 4, 5])),
        ];

        for (event_type, payload) in cases {
            let event = DatabaseEvent::new("test", event_type, payload);
            assert!(event.to_json_string().is_ok());
        }
    }

    #[test]
    fn test_event_timestamp_format() {
        let event = DatabaseEvent::new("test", "test.event", serde_json::json!({"test": true}));

        let parsed = chrono::DateTime::parse_from_rfc3339(&event.timestamp);
        assert!(parsed.is_ok());
    }
}