postrust_graphql/subscription/
broker.rs

1//! PostgreSQL NOTIFY message broker for GraphQL subscriptions.
2//!
3//! This module provides a broker that listens to PostgreSQL NOTIFY events
4//! and broadcasts them to GraphQL subscription clients.
5
6use futures::stream::{Stream, StreamExt};
7use sqlx::postgres::PgListener;
8use sqlx::PgPool;
9use std::collections::HashMap;
10use std::pin::Pin;
11use std::sync::Arc;
12use tokio::sync::broadcast;
13use tokio::sync::RwLock;
14use tracing::{debug, error, info, warn};
15
/// Default capacity for broadcast channels.
///
/// `NotifyBroker::new` uses this value for every broadcast channel it
/// creates; `NotifyBroker::with_capacity` takes a caller-supplied value
/// instead.
const DEFAULT_CHANNEL_CAPACITY: usize = 256;
18
/// One NOTIFY event as handed to subscribers.
///
/// The broker copies the channel, payload and sender process id out of the
/// notification it receives from the database listener into this owned,
/// cloneable value before broadcasting it.
#[derive(Clone, Debug)]
pub struct PgNotification {
    /// Name of the channel the event was published on (a table channel or a
    /// custom channel).
    pub channel: String,
    /// Payload text attached to the event (usually JSON).
    pub payload: String,
    /// ID of the process that sent the notification.
    pub process_id: u32,
}
29
/// Message broker that distributes PostgreSQL NOTIFY events to subscribers.
///
/// The broker keeps one `broadcast::Sender` per channel name. Background
/// listener tasks look up the sender matching an incoming notification's
/// channel and publish the notification to it; subscribers obtain receivers
/// from the same senders.
pub struct NotifyBroker {
    /// Database connection pool used to open `PgListener` connections.
    pool: PgPool,
    /// Channel senders keyed by channel name. Shared (via `Arc`) with the
    /// spawned listener tasks.
    channels: Arc<RwLock<HashMap<String, broadcast::Sender<PgNotification>>>>,
    /// Capacity passed to `broadcast::channel` for new broadcast channels.
    channel_capacity: usize,
    /// Whether the broker is running. Listener tasks read this flag at the
    /// top of every loop iteration and exit once it is `false`.
    running: Arc<RwLock<bool>>,
}
41
42impl NotifyBroker {
43    /// Create a new notification broker.
44    pub fn new(pool: PgPool) -> Self {
45        Self {
46            pool,
47            channels: Arc::new(RwLock::new(HashMap::new())),
48            channel_capacity: DEFAULT_CHANNEL_CAPACITY,
49            running: Arc::new(RwLock::new(false)),
50        }
51    }
52
53    /// Create a new notification broker with custom channel capacity.
54    pub fn with_capacity(pool: PgPool, capacity: usize) -> Self {
55        Self {
56            pool,
57            channels: Arc::new(RwLock::new(HashMap::new())),
58            channel_capacity: capacity,
59            running: Arc::new(RwLock::new(false)),
60        }
61    }
62
63    /// Start listening for notifications on the given channels.
64    ///
65    /// This spawns a background task that listens for PostgreSQL NOTIFY events
66    /// and broadcasts them to all subscribers.
67    pub async fn start(&self, listen_channels: Vec<String>) -> Result<(), BrokerError> {
68        // Check if already running
69        {
70            let running = self.running.read().await;
71            if *running {
72                return Err(BrokerError::AlreadyRunning);
73            }
74        }
75
76        // Mark as running
77        {
78            let mut running = self.running.write().await;
79            *running = true;
80        }
81
82        // Create channels for each listen channel
83        {
84            let mut channels = self.channels.write().await;
85            for channel_name in &listen_channels {
86                if !channels.contains_key(channel_name) {
87                    let (tx, _) = broadcast::channel(self.channel_capacity);
88                    channels.insert(channel_name.clone(), tx);
89                }
90            }
91        }
92
93        // Create listener
94        let mut listener = PgListener::connect_with(&self.pool)
95            .await
96            .map_err(BrokerError::Database)?;
97
98        // Subscribe to all channels
99        for channel in &listen_channels {
100            listener
101                .listen(channel)
102                .await
103                .map_err(BrokerError::Database)?;
104            info!("Listening on PostgreSQL channel: {}", channel);
105        }
106
107        // Clone for the spawned task
108        let channels = Arc::clone(&self.channels);
109        let running = Arc::clone(&self.running);
110
111        // Spawn listener task
112        tokio::spawn(async move {
113            loop {
114                // Check if we should stop
115                {
116                    let is_running = running.read().await;
117                    if !*is_running {
118                        info!("Broker stopped, exiting listener loop");
119                        break;
120                    }
121                }
122
123                match listener.try_recv().await {
124                    Ok(Some(notification)) => {
125                        let pg_notification = PgNotification {
126                            channel: notification.channel().to_string(),
127                            payload: notification.payload().to_string(),
128                            process_id: notification.process_id() as u32,
129                        };
130
131                        debug!(
132                            "Received notification on channel '{}': {}",
133                            pg_notification.channel,
134                            &pg_notification.payload[..pg_notification.payload.len().min(100)]
135                        );
136
137                        // Broadcast to subscribers
138                        let channels_read = channels.read().await;
139                        if let Some(sender) = channels_read.get(&pg_notification.channel) {
140                            // Ignore send errors - means no active receivers
141                            let _ = sender.send(pg_notification);
142                        }
143                    }
144                    Ok(None) => {
145                        // No notification available, continue
146                        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
147                    }
148                    Err(e) => {
149                        error!("Error receiving notification: {:?}", e);
150                        // Try to reconnect after a delay
151                        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
152                    }
153                }
154            }
155        });
156
157        Ok(())
158    }
159
160    /// Stop the broker.
161    pub async fn stop(&self) {
162        let mut running = self.running.write().await;
163        *running = false;
164        info!("Broker stop requested");
165    }
166
167    /// Subscribe to notifications for a specific channel.
168    ///
169    /// Returns a stream of notifications for the given channel.
170    pub async fn subscribe(
171        &self,
172        channel: &str,
173    ) -> Result<Pin<Box<dyn Stream<Item = PgNotification> + Send>>, BrokerError> {
174        let channels = self.channels.read().await;
175
176        let sender = channels
177            .get(channel)
178            .ok_or_else(|| BrokerError::ChannelNotFound(channel.to_string()))?;
179
180        let receiver = sender.subscribe();
181
182        // Convert broadcast receiver to stream
183        let stream = tokio_stream::wrappers::BroadcastStream::new(receiver).filter_map(|result| {
184            futures::future::ready(result.ok())
185        });
186
187        Ok(Box::pin(stream))
188    }
189
190    /// Subscribe to a channel, creating it if it doesn't exist.
191    ///
192    /// Note: This only creates a broadcast channel. You must also call
193    /// `listen_channel` to start receiving PostgreSQL notifications.
194    pub async fn subscribe_or_create(
195        &self,
196        channel: &str,
197    ) -> Pin<Box<dyn Stream<Item = PgNotification> + Send>> {
198        // First try to get existing channel
199        {
200            let channels = self.channels.read().await;
201            if let Some(sender) = channels.get(channel) {
202                let receiver = sender.subscribe();
203                let stream = tokio_stream::wrappers::BroadcastStream::new(receiver)
204                    .filter_map(|result| futures::future::ready(result.ok()));
205                return Box::pin(stream);
206            }
207        }
208
209        // Create new channel
210        {
211            let mut channels = self.channels.write().await;
212            // Double-check after acquiring write lock
213            if !channels.contains_key(channel) {
214                let (tx, _) = broadcast::channel(self.channel_capacity);
215                channels.insert(channel.to_string(), tx);
216            }
217        }
218
219        // Now subscribe
220        let channels = self.channels.read().await;
221        let sender = channels.get(channel).expect("just created");
222        let receiver = sender.subscribe();
223        let stream = tokio_stream::wrappers::BroadcastStream::new(receiver)
224            .filter_map(|result| futures::future::ready(result.ok()));
225        Box::pin(stream)
226    }
227
228    /// Add a new channel to listen on dynamically.
229    pub async fn listen_channel(&self, channel: &str) -> Result<(), BrokerError> {
230        // Create a new listener for this channel
231        let mut listener = PgListener::connect_with(&self.pool)
232            .await
233            .map_err(BrokerError::Database)?;
234
235        listener
236            .listen(channel)
237            .await
238            .map_err(BrokerError::Database)?;
239
240        // Ensure broadcast channel exists
241        {
242            let mut channels = self.channels.write().await;
243            if !channels.contains_key(channel) {
244                let (tx, _) = broadcast::channel(self.channel_capacity);
245                channels.insert(channel.to_string(), tx);
246            }
247        }
248
249        let channels = Arc::clone(&self.channels);
250        let running = Arc::clone(&self.running);
251        let channel_name = channel.to_string();
252
253        // Spawn a listener for this channel
254        tokio::spawn(async move {
255            info!("Started dynamic listener for channel: {}", channel_name);
256
257            loop {
258                {
259                    let is_running = running.read().await;
260                    if !*is_running {
261                        break;
262                    }
263                }
264
265                match listener.try_recv().await {
266                    Ok(Some(notification)) => {
267                        let pg_notification = PgNotification {
268                            channel: notification.channel().to_string(),
269                            payload: notification.payload().to_string(),
270                            process_id: notification.process_id() as u32,
271                        };
272
273                        let channels_read = channels.read().await;
274                        if let Some(sender) = channels_read.get(&pg_notification.channel) {
275                            let _ = sender.send(pg_notification);
276                        }
277                    }
278                    Ok(None) => {
279                        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
280                    }
281                    Err(e) => {
282                        warn!("Error on channel {}: {:?}", channel_name, e);
283                        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
284                    }
285                }
286            }
287
288            info!("Stopped dynamic listener for channel: {}", channel_name);
289        });
290
291        Ok(())
292    }
293
294    /// Check if the broker is currently running.
295    pub async fn is_running(&self) -> bool {
296        *self.running.read().await
297    }
298
299    /// Get the number of active channels.
300    pub async fn channel_count(&self) -> usize {
301        self.channels.read().await.len()
302    }
303}
304
/// Errors that can occur in the broker.
#[derive(Debug, thiserror::Error)]
pub enum BrokerError {
    /// A database operation failed, e.g. connecting a `PgListener` or
    /// issuing `LISTEN`. Converts from `sqlx::Error` via `#[from]`.
    #[error("Database error: {0}")]
    Database(#[from] sqlx::Error),

    /// `NotifyBroker::subscribe` was called for a channel name that has no
    /// broadcast channel; carries the requested name.
    #[error("Channel not found: {0}")]
    ChannelNotFound(String),

    /// `NotifyBroker::start` was called while the broker was already marked
    /// as running.
    #[error("Broker is already running")]
    AlreadyRunning,
}
317
/// Build the NOTIFY channel name used for change events on a table.
///
/// The result has the form `postrust_<schema>_<table>`. Both parts are
/// inserted verbatim, so distinct pairs such as `("a_b", "c")` and
/// `("a", "b_c")` map to the same channel name.
pub fn table_channel_name(schema: &str, table: &str) -> String {
    ["postrust", schema, table].join("_")
}
322
323/// Generate SQL to create a notification trigger for a table.
324pub fn create_notify_trigger_sql(schema: &str, table: &str) -> String {
325    let channel = table_channel_name(schema, table);
326    let trigger_name = format!("postrust_notify_{}_{}", schema, table);
327    let function_name = format!("postrust_notify_{}_{}_fn", schema, table);
328
329    format!(
330        r#"
331-- Create notification function
332CREATE OR REPLACE FUNCTION {schema}.{function_name}()
333RETURNS TRIGGER AS $$
334DECLARE
335    payload jsonb;
336BEGIN
337    IF TG_OP = 'DELETE' THEN
338        payload := jsonb_build_object(
339            'operation', 'DELETE',
340            'table', TG_TABLE_NAME,
341            'schema', TG_TABLE_SCHEMA,
342            'old', row_to_json(OLD)
343        );
344    ELSIF TG_OP = 'UPDATE' THEN
345        payload := jsonb_build_object(
346            'operation', 'UPDATE',
347            'table', TG_TABLE_NAME,
348            'schema', TG_TABLE_SCHEMA,
349            'old', row_to_json(OLD),
350            'new', row_to_json(NEW)
351        );
352    ELSIF TG_OP = 'INSERT' THEN
353        payload := jsonb_build_object(
354            'operation', 'INSERT',
355            'table', TG_TABLE_NAME,
356            'schema', TG_TABLE_SCHEMA,
357            'new', row_to_json(NEW)
358        );
359    END IF;
360
361    PERFORM pg_notify('{channel}', payload::text);
362
363    RETURN COALESCE(NEW, OLD);
364END;
365$$ LANGUAGE plpgsql;
366
367-- Create trigger
368DROP TRIGGER IF EXISTS {trigger_name} ON {schema}.{table};
369CREATE TRIGGER {trigger_name}
370    AFTER INSERT OR UPDATE OR DELETE ON {schema}.{table}
371    FOR EACH ROW
372    EXECUTE FUNCTION {schema}.{function_name}();
373"#,
374        schema = schema,
375        table = table,
376        channel = channel,
377        function_name = function_name,
378        trigger_name = trigger_name
379    )
380}
381
/// Generate the SQL script that removes a table's notification trigger.
///
/// Emits two statements: `DROP TRIGGER IF EXISTS` for
/// `postrust_notify_<schema>_<table>` on `schema.table`, followed by
/// `DROP FUNCTION IF EXISTS` for `schema.postrust_notify_<schema>_<table>_fn()`.
///
/// `schema` and `table` are inserted as given, without quoting or escaping.
pub fn drop_notify_trigger_sql(schema: &str, table: &str) -> String {
    let trigger = format!("postrust_notify_{}_{}", schema, table);
    let func = format!("{}_fn", trigger);

    let drop_trigger = format!("DROP TRIGGER IF EXISTS {} ON {}.{};", trigger, schema, table);
    let drop_function = format!("DROP FUNCTION IF EXISTS {}.{}();", schema, func);

    // Leading and trailing newlines mirror the layout of the generated script.
    format!("\n{}\n{}\n", drop_trigger, drop_function)
}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Channel names follow the `postrust_<schema>_<table>` pattern.
    #[test]
    fn test_table_channel_name() {
        let cases = [
            ("public", "users", "postrust_public_users"),
            ("api", "orders", "postrust_api_orders"),
        ];

        for &(schema, table, expected) in cases.iter() {
            assert_eq!(table_channel_name(schema, table), expected);
        }
    }

    /// The create script mentions the function, the trigger, `pg_notify`
    /// and the table's channel name.
    #[test]
    fn test_create_notify_trigger_sql() {
        let sql = create_notify_trigger_sql("public", "users");
        let expected_fragments = [
            "CREATE OR REPLACE FUNCTION",
            "postrust_notify_public_users_fn",
            "CREATE TRIGGER",
            "pg_notify",
            "postrust_public_users",
        ];

        for fragment in expected_fragments.iter() {
            assert!(sql.contains(*fragment), "missing fragment: {}", fragment);
        }
    }

    /// The drop script removes both the trigger and its function.
    #[test]
    fn test_drop_notify_trigger_sql() {
        let sql = drop_notify_trigger_sql("public", "users");
        let expected_fragments = ["DROP TRIGGER IF EXISTS", "DROP FUNCTION IF EXISTS"];

        for fragment in expected_fragments.iter() {
            assert!(sql.contains(*fragment), "missing fragment: {}", fragment);
        }
    }
}
431}