kaccy_db/event_notifications.rs

//! Database event notifications using PostgreSQL LISTEN/NOTIFY
//!
//! This module provides real-time event notification capabilities using PostgreSQL's
//! LISTEN/NOTIFY mechanism. It enables event-driven architectures, cache invalidation,
//! and cross-service communication through the database.
//!
//! # Features
//!
//! - Real-time event notifications from database triggers
//! - Type-safe event payloads with JSON serialization
//! - Automatic reconnection and subscription recovery
//! - Multiple concurrent listeners with different channels
//! - Handler registration for specific event types
//! - SQL helpers for generating notification triggers
//!
//! # Example
//!
//! ```rust,no_run
//! use kaccy_db::event_notifications::{EventListener, DatabaseEvent};
//! use sqlx::PgPool;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let pool = PgPool::connect("postgresql://localhost/kaccy").await?;
//!
//!     // Create event listener
//!     let listener = EventListener::new(pool.clone()).await?;
//!
//!     // Subscribe to a channel
//!     listener.subscribe("user_events").await?;
//!
//!     // Register event handler
//!     listener.on_event("user_events", |event| {
//!         println!("Received event: {:?}", event);
//!     }).await;
//!
//!     // Listen for events (runs until stopped)
//!     listener.listen().await?;
//!
//!     Ok(())
//! }
//! ```

use crate::error::{DbError, Result};
use serde::{Deserialize, Serialize};
use sqlx::{
    postgres::{PgListener, PgNotification},
    PgPool,
};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, error, info, warn};

/// Database event with typed payload
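///
/// # Example
///
/// A minimal sketch of constructing an event by hand; the channel, event-type,
/// and correlation-ID values are illustrative, not fixed by this crate:
///
/// ```rust
/// use kaccy_db::event_notifications::DatabaseEvent;
///
/// let event = DatabaseEvent::new(
///     "user_events",
///     "user.created",
///     serde_json::json!({ "user_id": "123" }),
/// )
/// .with_correlation_id("trace-abc-123");
///
/// assert_eq!(event.channel, "user_events");
/// assert!(event.correlation_id.is_some());
/// ```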
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseEvent {
    /// Event channel name
    pub channel: String,
    /// Event type (e.g., "user.created", "token.updated")
    pub event_type: String,
    /// Event payload as JSON
    pub payload: serde_json::Value,
    /// Optional correlation ID for distributed tracing
    pub correlation_id: Option<String>,
    /// Event timestamp (ISO 8601)
    pub timestamp: String,
}

impl DatabaseEvent {
    /// Create a new database event
    pub fn new(
        channel: impl Into<String>,
        event_type: impl Into<String>,
        payload: serde_json::Value,
    ) -> Self {
        Self {
            channel: channel.into(),
            event_type: event_type.into(),
            payload,
            correlation_id: None,
            timestamp: chrono::Utc::now().to_rfc3339(),
        }
    }

    /// Set correlation ID for distributed tracing
    pub fn with_correlation_id(mut self, correlation_id: impl Into<String>) -> Self {
        self.correlation_id = Some(correlation_id.into());
        self
    }

    /// Parse event from PostgreSQL notification
    pub fn from_notification(notification: &PgNotification) -> Result<Self> {
        let payload_str = notification.payload();
        serde_json::from_str(payload_str)
            .map_err(|e| DbError::Validation(format!("Failed to parse event payload: {}", e)))
    }

    /// Serialize event to JSON string for NOTIFY
    pub fn to_json_string(&self) -> Result<String> {
        serde_json::to_string(self)
            .map_err(|e| DbError::Validation(format!("Failed to serialize event: {}", e)))
    }
}

/// Event handler function type
pub type EventHandlerFn = Arc<dyn Fn(DatabaseEvent) + Send + Sync>;

/// Event listener for PostgreSQL LISTEN/NOTIFY
pub struct EventListener {
    /// PostgreSQL connection pool for sending notifications
    pool: PgPool,
    /// PostgreSQL listener for receiving notifications
    listener: Arc<Mutex<PgListener>>,
    /// Registered event handlers per channel
    handlers: Arc<RwLock<HashMap<String, Vec<EventHandlerFn>>>>,
    /// Active subscriptions
    subscriptions: Arc<RwLock<Vec<String>>>,
    /// Listener configuration
    config: ListenerConfig,
}

/// Configuration for event listener
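///
/// # Example
///
/// A sketch of a bounded-retry configuration; the connection string is a
/// placeholder and the values should be tuned to the deployment:
///
/// ```rust,no_run
/// use kaccy_db::event_notifications::{EventListener, ListenerConfig};
/// use sqlx::PgPool;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let pool = PgPool::connect("postgresql://localhost/kaccy").await?;
///
/// let config = ListenerConfig {
///     max_reconnect_attempts: 5, // give up after five attempts
///     reconnect_delay_ms: 2_000, // wait two seconds between attempts
///     event_buffer_size: 500,
/// };
///
/// let listener = EventListener::new_with_config(pool, config).await?;
/// # Ok(())
/// # }
/// ```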
#[derive(Debug, Clone)]
pub struct ListenerConfig {
    /// Maximum reconnection attempts (0 = infinite)
    pub max_reconnect_attempts: usize,
    /// Delay between reconnection attempts in milliseconds
    pub reconnect_delay_ms: u64,
    /// Buffer size for event queue
    pub event_buffer_size: usize,
}

impl Default for ListenerConfig {
    fn default() -> Self {
        Self {
            max_reconnect_attempts: 0, // Infinite retries
            reconnect_delay_ms: 1000,  // 1 second
            event_buffer_size: 1000,
        }
    }
}

impl EventListener {
    /// Create a new event listener
    pub async fn new(pool: PgPool) -> Result<Self> {
        Self::new_with_config(pool, ListenerConfig::default()).await
    }

    /// Create a new event listener with custom configuration
    pub async fn new_with_config(pool: PgPool, config: ListenerConfig) -> Result<Self> {
        let listener = PgListener::connect_with(&pool).await?;

        info!("Created PostgreSQL event listener");

        Ok(Self {
            pool,
            listener: Arc::new(Mutex::new(listener)),
            handlers: Arc::new(RwLock::new(HashMap::new())),
            subscriptions: Arc::new(RwLock::new(Vec::new())),
            config,
        })
    }

    /// Subscribe to a notification channel
    pub async fn subscribe(&self, channel: &str) -> Result<()> {
        let mut listener = self.listener.lock().await;
        listener.listen(channel).await?;

        let mut subs = self.subscriptions.write().await;
        if !subs.contains(&channel.to_string()) {
            subs.push(channel.to_string());
        }

        info!("Subscribed to channel: {}", channel);
        Ok(())
    }

    /// Unsubscribe from a notification channel
    pub async fn unsubscribe(&self, channel: &str) -> Result<()> {
        let mut listener = self.listener.lock().await;
        listener.unlisten(channel).await?;

        let mut subs = self.subscriptions.write().await;
        subs.retain(|c| c != channel);

        info!("Unsubscribed from channel: {}", channel);
        Ok(())
    }

    /// Register an event handler for a specific channel
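    ///
    /// # Example
    ///
    /// A sketch of a handler that deserializes the JSON payload into a local
    /// struct; `UserCreated` is illustrative, not a type defined by this crate:
    ///
    /// ```rust,no_run
    /// use kaccy_db::event_notifications::EventListener;
    /// use serde::Deserialize;
    /// use sqlx::PgPool;
    ///
    /// #[derive(Deserialize)]
    /// struct UserCreated {
    ///     user_id: String,
    /// }
    ///
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let pool = PgPool::connect("postgresql://localhost/kaccy").await?;
    /// let listener = EventListener::new(pool).await?;
    ///
    /// listener.subscribe("user_events").await?;
    /// listener.on_event("user_events", |event| {
    ///     if event.event_type == "user.created" {
    ///         match serde_json::from_value::<UserCreated>(event.payload) {
    ///             Ok(user) => println!("new user: {}", user.user_id),
    ///             Err(e) => eprintln!("unexpected payload: {}", e),
    ///         }
    ///     }
    /// }).await;
    /// # Ok(())
    /// # }
    /// ```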
    pub async fn on_event<F>(&self, channel: &str, handler: F)
    where
        F: Fn(DatabaseEvent) + Send + Sync + 'static,
    {
        let mut handlers = self.handlers.write().await;
        handlers
            .entry(channel.to_string())
            .or_insert_with(Vec::new)
            .push(Arc::new(handler));

        debug!("Registered event handler for channel: {}", channel);
    }

    /// Start listening for events (blocking until stopped)
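    ///
    /// # Example
    ///
    /// `listen` loops forever, so it is typically driven from a dedicated task.
    /// A sketch, assuming the listener is shared via `Arc`:
    ///
    /// ```rust,no_run
    /// use kaccy_db::event_notifications::EventListener;
    /// use sqlx::PgPool;
    /// use std::sync::Arc;
    ///
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let pool = PgPool::connect("postgresql://localhost/kaccy").await?;
    /// let listener = Arc::new(EventListener::new(pool).await?);
    ///
    /// listener.subscribe("user_events").await?;
    ///
    /// // Drive the receive loop in the background while the rest of the
    /// // application keeps running.
    /// let background = Arc::clone(&listener);
    /// tokio::spawn(async move {
    ///     if let Err(e) = background.listen().await {
    ///         eprintln!("event listener stopped: {}", e);
    ///     }
    /// });
    /// # Ok(())
    /// # }
    /// ```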
    pub async fn listen(&self) -> Result<()> {
        info!("Started listening for database events");

        loop {
            let notification = {
                let mut listener = self.listener.lock().await;
                match listener.try_recv().await {
                    Ok(Some(notif)) => notif,
                    Ok(None) => {
                        // `try_recv` yields `Ok(None)` when the connection was lost;
                        // back off briefly before polling again.
                        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
                        continue;
                    }
                    Err(e) => {
                        error!("Error receiving notification: {}", e);
                        // Attempt reconnection
                        if let Err(e) = self.reconnect().await {
                            error!("Failed to reconnect: {}", e);
                        }
                        continue;
                    }
                }
            };

            self.handle_notification(notification).await;
        }
    }

    /// Handle a received notification
    async fn handle_notification(&self, notification: PgNotification) {
        let channel = notification.channel();

        // Parse event
        let event = match DatabaseEvent::from_notification(&notification) {
            Ok(event) => event,
            Err(e) => {
                warn!("Failed to parse event from channel {}: {}", channel, e);
                return;
            }
        };

        debug!(
            "Received event: channel={}, type={}",
            event.channel, event.event_type
        );

        // Call registered handlers
        let handlers = self.handlers.read().await;
        if let Some(channel_handlers) = handlers.get(channel) {
            for handler in channel_handlers {
                handler(event.clone());
            }
        }
    }

    /// Reconnect and re-subscribe to all channels
    async fn reconnect(&self) -> Result<()> {
        warn!("Attempting to reconnect to PostgreSQL...");

        let mut attempts = 0;
        loop {
            if self.config.max_reconnect_attempts > 0
                && attempts >= self.config.max_reconnect_attempts
            {
                return Err(DbError::Connection(format!(
                    "Failed to reconnect after {} attempts",
                    attempts
                )));
            }

            // Wait before retrying
            if attempts > 0 {
                tokio::time::sleep(tokio::time::Duration::from_millis(
                    self.config.reconnect_delay_ms,
                ))
                .await;
            }

            // Attempt to create new listener
            match PgListener::connect_with(&self.pool).await {
                Ok(new_listener) => {
                    let mut listener = self.listener.lock().await;
                    *listener = new_listener;

                    // Re-subscribe to all channels
                    let subscriptions = self.subscriptions.read().await;
                    for channel in subscriptions.iter() {
                        if let Err(e) = listener.listen(channel).await {
                            error!("Failed to re-subscribe to {}: {}", channel, e);
                        }
                    }

                    info!("Successfully reconnected and re-subscribed");
                    return Ok(());
                }
                Err(e) => {
                    error!("Reconnection attempt {} failed: {}", attempts + 1, e);
                    attempts += 1;
                }
            }
        }
    }

    /// Publish an event to a channel
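    ///
    /// # Example
    ///
    /// A sketch of publishing an application-level event; the channel and
    /// payload are illustrative:
    ///
    /// ```rust,no_run
    /// use kaccy_db::event_notifications::{DatabaseEvent, EventListener};
    /// use sqlx::PgPool;
    ///
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let pool = PgPool::connect("postgresql://localhost/kaccy").await?;
    /// let listener = EventListener::new(pool).await?;
    ///
    /// let event = DatabaseEvent::new(
    ///     "user_events",
    ///     "user.created",
    ///     serde_json::json!({ "user_id": "123" }),
    /// );
    /// listener.notify(&event).await?;
    /// # Ok(())
    /// # }
    /// ```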
    pub async fn notify(&self, event: &DatabaseEvent) -> Result<()> {
        let payload = event.to_json_string()?;

        // Plain NOTIFY cannot take bind parameters, so use pg_notify(), which
        // lets both the channel name and the payload be bound safely.
        sqlx::query("SELECT pg_notify($1, $2)")
            .bind(&event.channel)
            .bind(&payload)
            .execute(&self.pool)
            .await?;

        debug!("Published event to channel: {}", event.channel);
        Ok(())
    }

    /// Get list of active subscriptions
    pub async fn get_subscriptions(&self) -> Vec<String> {
        self.subscriptions.read().await.clone()
    }

    /// Get number of registered handlers per channel
    pub async fn get_handler_counts(&self) -> HashMap<String, usize> {
        let handlers = self.handlers.read().await;
        handlers
            .iter()
            .map(|(channel, handlers)| (channel.clone(), handlers.len()))
            .collect()
    }
}

/// Event notification statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NotificationStats {
    /// Total notifications received
    pub notifications_received: u64,
    /// Total notifications sent
    pub notifications_sent: u64,
    /// Active subscriptions
    pub active_subscriptions: usize,
    /// Registered handlers count per channel
    pub handlers_per_channel: HashMap<String, usize>,
}

/// Helper trait for creating database triggers that send notifications
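///
/// # Example
///
/// A sketch of generating the INSERT-trigger DDL for a hypothetical `users`
/// table and applying it with `sqlx`; the table and channel names are
/// placeholders:
///
/// ```rust,no_run
/// use kaccy_db::event_notifications::{NotificationTriggers, PostgresNotificationTriggers};
/// use sqlx::{Executor, PgPool};
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let pool = PgPool::connect("postgresql://localhost/kaccy").await?;
///
/// // The generated string contains both the trigger function and the trigger
/// // itself; executing it as a plain string sends the whole batch at once.
/// let ddl = PostgresNotificationTriggers::create_insert_trigger_sql("users", "user_events");
/// pool.execute(ddl.as_str()).await?;
/// # Ok(())
/// # }
/// ```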
pub trait NotificationTriggers {
    /// Create a trigger that sends a notification on INSERT
    fn create_insert_trigger_sql(table: &str, channel: &str) -> String;

    /// Create a trigger that sends a notification on UPDATE
    fn create_update_trigger_sql(table: &str, channel: &str) -> String;

    /// Create a trigger that sends a notification on DELETE
    fn create_delete_trigger_sql(table: &str, channel: &str) -> String;
}

/// PostgreSQL implementation of [`NotificationTriggers`].
pub struct PostgresNotificationTriggers;

impl NotificationTriggers for PostgresNotificationTriggers {
    fn create_insert_trigger_sql(table: &str, channel: &str) -> String {
        format!(
            r#"
CREATE OR REPLACE FUNCTION notify_{table}_insert()
RETURNS TRIGGER AS $$
BEGIN
    PERFORM pg_notify(
        '{channel}',
        json_build_object(
            'channel', '{channel}',
            'event_type', '{table}.created',
            'payload', row_to_json(NEW),
            'timestamp', to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
        )::text
    );
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;

CREATE TRIGGER {table}_insert_notify
AFTER INSERT ON {table}
FOR EACH ROW
EXECUTE FUNCTION notify_{table}_insert();
            "#,
            table = table,
            channel = channel
        )
    }

    fn create_update_trigger_sql(table: &str, channel: &str) -> String {
        format!(
            r#"
CREATE OR REPLACE FUNCTION notify_{table}_update()
RETURNS TRIGGER AS $$
BEGIN
    PERFORM pg_notify(
        '{channel}',
        json_build_object(
            'channel', '{channel}',
            'event_type', '{table}.updated',
            'payload', json_build_object('old', row_to_json(OLD), 'new', row_to_json(NEW)),
            'timestamp', to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
        )::text
    );
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;

CREATE TRIGGER {table}_update_notify
AFTER UPDATE ON {table}
FOR EACH ROW
EXECUTE FUNCTION notify_{table}_update();
            "#,
            table = table,
            channel = channel
        )
    }

    fn create_delete_trigger_sql(table: &str, channel: &str) -> String {
        format!(
            r#"
CREATE OR REPLACE FUNCTION notify_{table}_delete()
RETURNS TRIGGER AS $$
BEGIN
    PERFORM pg_notify(
        '{channel}',
        json_build_object(
            'channel', '{channel}',
            'event_type', '{table}.deleted',
            'payload', row_to_json(OLD),
            'timestamp', to_char(NOW() AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"')
        )::text
    );
    RETURN OLD;
END;
$$ LANGUAGE plpgsql;

CREATE TRIGGER {table}_delete_notify
AFTER DELETE ON {table}
FOR EACH ROW
EXECUTE FUNCTION notify_{table}_delete();
            "#,
            table = table,
            channel = channel
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_database_event_creation() {
        let payload = serde_json::json!({"user_id": "123", "action": "login"});
        let event = DatabaseEvent::new("user_events", "user.login", payload.clone());

        assert_eq!(event.channel, "user_events");
        assert_eq!(event.event_type, "user.login");
        assert_eq!(event.payload, payload);
        assert!(event.correlation_id.is_none());
        assert!(!event.timestamp.is_empty());
    }

    #[test]
    fn test_database_event_with_correlation_id() {
        let payload = serde_json::json!({"order_id": "456"});
        let event = DatabaseEvent::new("order_events", "order.created", payload)
            .with_correlation_id("trace-123-456");

        assert_eq!(event.correlation_id, Some("trace-123-456".to_string()));
    }

    #[test]
    fn test_database_event_serialization() {
        let payload = serde_json::json!({"token_id": "789", "amount": 1000});
        let event = DatabaseEvent::new("token_events", "token.transfer", payload);

        let json_str = event.to_json_string().unwrap();
        assert!(json_str.contains("token_events"));
        assert!(json_str.contains("token.transfer"));
        assert!(json_str.contains("token_id"));
    }

    #[test]
    fn test_database_event_deserialization() {
        let json_str = r#"{
            "channel": "trade_events",
            "event_type": "trade.executed",
            "payload": {"trade_id": "999", "price": 50000},
            "correlation_id": null,
            "timestamp": "2024-01-01T12:00:00Z"
        }"#;

        let event: DatabaseEvent = serde_json::from_str(json_str).unwrap();
        assert_eq!(event.channel, "trade_events");
        assert_eq!(event.event_type, "trade.executed");
        assert_eq!(event.payload["trade_id"], "999");
    }

    #[test]
    fn test_listener_config_default() {
        let config = ListenerConfig::default();
        assert_eq!(config.max_reconnect_attempts, 0); // Infinite
        assert_eq!(config.reconnect_delay_ms, 1000);
        assert_eq!(config.event_buffer_size, 1000);
    }

    #[test]
    fn test_listener_config_custom() {
        let config = ListenerConfig {
            max_reconnect_attempts: 5,
            reconnect_delay_ms: 2000,
            event_buffer_size: 500,
        };

        assert_eq!(config.max_reconnect_attempts, 5);
        assert_eq!(config.reconnect_delay_ms, 2000);
        assert_eq!(config.event_buffer_size, 500);
    }

    #[test]
    fn test_create_insert_trigger_sql() {
        let sql = PostgresNotificationTriggers::create_insert_trigger_sql("users", "user_events");

        assert!(sql.contains("CREATE OR REPLACE FUNCTION notify_users_insert()"));
        assert!(sql.contains("pg_notify"));
        assert!(sql.contains("user_events"));
        assert!(sql.contains("users.created"));
        assert!(sql.contains("CREATE TRIGGER users_insert_notify"));
    }

    #[test]
    fn test_create_update_trigger_sql() {
        let sql = PostgresNotificationTriggers::create_update_trigger_sql("tokens", "token_events");

        assert!(sql.contains("CREATE OR REPLACE FUNCTION notify_tokens_update()"));
        assert!(sql.contains("pg_notify"));
        assert!(sql.contains("token_events"));
        assert!(sql.contains("tokens.updated"));
        assert!(sql.contains("CREATE TRIGGER tokens_update_notify"));
    }

    #[test]
    fn test_create_delete_trigger_sql() {
        let sql = PostgresNotificationTriggers::create_delete_trigger_sql("orders", "order_events");

        assert!(sql.contains("CREATE OR REPLACE FUNCTION notify_orders_delete()"));
        assert!(sql.contains("pg_notify"));
        assert!(sql.contains("order_events"));
        assert!(sql.contains("orders.deleted"));
        assert!(sql.contains("CREATE TRIGGER orders_delete_notify"));
    }

    #[test]
    fn test_notification_stats_serialization() {
        let stats = NotificationStats {
            notifications_received: 1000,
            notifications_sent: 500,
            active_subscriptions: 5,
            handlers_per_channel: [("users".to_string(), 3), ("tokens".to_string(), 2)]
                .iter()
                .cloned()
                .collect(),
        };

        let json = serde_json::to_string(&stats).unwrap();
        assert!(json.contains("notifications_received"));
        assert!(json.contains("1000"));
        assert!(json.contains("handlers_per_channel"));
    }

    #[test]
    fn test_event_payload_types() {
        // Test with different payload types
        let string_payload = serde_json::json!("simple string");
        let event1 = DatabaseEvent::new("test", "test.string", string_payload);
        assert!(event1.to_json_string().is_ok());

        let object_payload = serde_json::json!({"key": "value", "number": 42});
        let event2 = DatabaseEvent::new("test", "test.object", object_payload);
        assert!(event2.to_json_string().is_ok());

        let array_payload = serde_json::json!([1, 2, 3, 4, 5]);
        let event3 = DatabaseEvent::new("test", "test.array", array_payload);
        assert!(event3.to_json_string().is_ok());
    }

    #[test]
    fn test_event_timestamp_format() {
        let event = DatabaseEvent::new("test", "test.event", serde_json::json!({"test": true}));

        // Verify timestamp is in RFC3339 format
        assert!(chrono::DateTime::parse_from_rfc3339(&event.timestamp).is_ok());
    }
}