postrust_graphql/subscription/
broker.rs1use futures::stream::{Stream, StreamExt};
7use sqlx::postgres::PgListener;
8use sqlx::PgPool;
9use std::collections::HashMap;
10use std::pin::Pin;
11use std::sync::Arc;
12use tokio::sync::broadcast;
13use tokio::sync::RwLock;
14use tracing::{debug, error, info, warn};
15
/// Default capacity of each per-channel `tokio::sync::broadcast` queue;
/// subscribers that fall further behind than this will miss (lag) messages.
const DEFAULT_CHANNEL_CAPACITY: usize = 256;
18
/// A notification received from PostgreSQL via LISTEN/NOTIFY.
#[derive(Debug, Clone)]
pub struct PgNotification {
    // Channel name the notification arrived on.
    pub channel: String,
    // Raw NOTIFY payload; for triggers generated by `create_notify_trigger_sql`
    // this is a JSON document (jsonb cast to text).
    pub payload: String,
    // Backend process id of the notifying PostgreSQL session.
    pub process_id: u32,
}
29
/// Fans PostgreSQL NOTIFY messages out to in-process subscribers.
///
/// A background task (spawned by `start` / `listen_channel`) receives
/// notifications on a dedicated `PgListener` connection and re-broadcasts
/// them on a per-channel `tokio::sync::broadcast` sender.
pub struct NotifyBroker {
    // Pool used to open dedicated LISTEN connections.
    pool: PgPool,
    // Per-channel broadcast senders, keyed by PostgreSQL channel name.
    channels: Arc<RwLock<HashMap<String, broadcast::Sender<PgNotification>>>>,
    // Capacity used when creating each broadcast channel.
    channel_capacity: usize,
    // Shared flag: listener loops exit when this becomes false.
    running: Arc<RwLock<bool>>,
}
41
42impl NotifyBroker {
43 pub fn new(pool: PgPool) -> Self {
45 Self {
46 pool,
47 channels: Arc::new(RwLock::new(HashMap::new())),
48 channel_capacity: DEFAULT_CHANNEL_CAPACITY,
49 running: Arc::new(RwLock::new(false)),
50 }
51 }
52
53 pub fn with_capacity(pool: PgPool, capacity: usize) -> Self {
55 Self {
56 pool,
57 channels: Arc::new(RwLock::new(HashMap::new())),
58 channel_capacity: capacity,
59 running: Arc::new(RwLock::new(false)),
60 }
61 }
62
63 pub async fn start(&self, listen_channels: Vec<String>) -> Result<(), BrokerError> {
68 {
70 let running = self.running.read().await;
71 if *running {
72 return Err(BrokerError::AlreadyRunning);
73 }
74 }
75
76 {
78 let mut running = self.running.write().await;
79 *running = true;
80 }
81
82 {
84 let mut channels = self.channels.write().await;
85 for channel_name in &listen_channels {
86 if !channels.contains_key(channel_name) {
87 let (tx, _) = broadcast::channel(self.channel_capacity);
88 channels.insert(channel_name.clone(), tx);
89 }
90 }
91 }
92
93 let mut listener = PgListener::connect_with(&self.pool)
95 .await
96 .map_err(BrokerError::Database)?;
97
98 for channel in &listen_channels {
100 listener
101 .listen(channel)
102 .await
103 .map_err(BrokerError::Database)?;
104 info!("Listening on PostgreSQL channel: {}", channel);
105 }
106
107 let channels = Arc::clone(&self.channels);
109 let running = Arc::clone(&self.running);
110
111 tokio::spawn(async move {
113 loop {
114 {
116 let is_running = running.read().await;
117 if !*is_running {
118 info!("Broker stopped, exiting listener loop");
119 break;
120 }
121 }
122
123 match listener.try_recv().await {
124 Ok(Some(notification)) => {
125 let pg_notification = PgNotification {
126 channel: notification.channel().to_string(),
127 payload: notification.payload().to_string(),
128 process_id: notification.process_id() as u32,
129 };
130
131 debug!(
132 "Received notification on channel '{}': {}",
133 pg_notification.channel,
134 &pg_notification.payload[..pg_notification.payload.len().min(100)]
135 );
136
137 let channels_read = channels.read().await;
139 if let Some(sender) = channels_read.get(&pg_notification.channel) {
140 let _ = sender.send(pg_notification);
142 }
143 }
144 Ok(None) => {
145 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
147 }
148 Err(e) => {
149 error!("Error receiving notification: {:?}", e);
150 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
152 }
153 }
154 }
155 });
156
157 Ok(())
158 }
159
160 pub async fn stop(&self) {
162 let mut running = self.running.write().await;
163 *running = false;
164 info!("Broker stop requested");
165 }
166
167 pub async fn subscribe(
171 &self,
172 channel: &str,
173 ) -> Result<Pin<Box<dyn Stream<Item = PgNotification> + Send>>, BrokerError> {
174 let channels = self.channels.read().await;
175
176 let sender = channels
177 .get(channel)
178 .ok_or_else(|| BrokerError::ChannelNotFound(channel.to_string()))?;
179
180 let receiver = sender.subscribe();
181
182 let stream = tokio_stream::wrappers::BroadcastStream::new(receiver).filter_map(|result| {
184 futures::future::ready(result.ok())
185 });
186
187 Ok(Box::pin(stream))
188 }
189
190 pub async fn subscribe_or_create(
195 &self,
196 channel: &str,
197 ) -> Pin<Box<dyn Stream<Item = PgNotification> + Send>> {
198 {
200 let channels = self.channels.read().await;
201 if let Some(sender) = channels.get(channel) {
202 let receiver = sender.subscribe();
203 let stream = tokio_stream::wrappers::BroadcastStream::new(receiver)
204 .filter_map(|result| futures::future::ready(result.ok()));
205 return Box::pin(stream);
206 }
207 }
208
209 {
211 let mut channels = self.channels.write().await;
212 if !channels.contains_key(channel) {
214 let (tx, _) = broadcast::channel(self.channel_capacity);
215 channels.insert(channel.to_string(), tx);
216 }
217 }
218
219 let channels = self.channels.read().await;
221 let sender = channels.get(channel).expect("just created");
222 let receiver = sender.subscribe();
223 let stream = tokio_stream::wrappers::BroadcastStream::new(receiver)
224 .filter_map(|result| futures::future::ready(result.ok()));
225 Box::pin(stream)
226 }
227
228 pub async fn listen_channel(&self, channel: &str) -> Result<(), BrokerError> {
230 let mut listener = PgListener::connect_with(&self.pool)
232 .await
233 .map_err(BrokerError::Database)?;
234
235 listener
236 .listen(channel)
237 .await
238 .map_err(BrokerError::Database)?;
239
240 {
242 let mut channels = self.channels.write().await;
243 if !channels.contains_key(channel) {
244 let (tx, _) = broadcast::channel(self.channel_capacity);
245 channels.insert(channel.to_string(), tx);
246 }
247 }
248
249 let channels = Arc::clone(&self.channels);
250 let running = Arc::clone(&self.running);
251 let channel_name = channel.to_string();
252
253 tokio::spawn(async move {
255 info!("Started dynamic listener for channel: {}", channel_name);
256
257 loop {
258 {
259 let is_running = running.read().await;
260 if !*is_running {
261 break;
262 }
263 }
264
265 match listener.try_recv().await {
266 Ok(Some(notification)) => {
267 let pg_notification = PgNotification {
268 channel: notification.channel().to_string(),
269 payload: notification.payload().to_string(),
270 process_id: notification.process_id() as u32,
271 };
272
273 let channels_read = channels.read().await;
274 if let Some(sender) = channels_read.get(&pg_notification.channel) {
275 let _ = sender.send(pg_notification);
276 }
277 }
278 Ok(None) => {
279 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
280 }
281 Err(e) => {
282 warn!("Error on channel {}: {:?}", channel_name, e);
283 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
284 }
285 }
286 }
287
288 info!("Stopped dynamic listener for channel: {}", channel_name);
289 });
290
291 Ok(())
292 }
293
294 pub async fn is_running(&self) -> bool {
296 *self.running.read().await
297 }
298
299 pub async fn channel_count(&self) -> usize {
301 self.channels.read().await.len()
302 }
303}
304
/// Errors produced by [`NotifyBroker`] operations.
#[derive(Debug, thiserror::Error)]
pub enum BrokerError {
    // Underlying sqlx/PostgreSQL failure (connect, LISTEN, ...).
    #[error("Database error: {0}")]
    Database(#[from] sqlx::Error),

    // `subscribe` was called for a channel with no broadcast sender.
    #[error("Channel not found: {0}")]
    ChannelNotFound(String),

    // `start` was called while the broker was already running.
    #[error("Broker is already running")]
    AlreadyRunning,
}
317
/// Derive the PostgreSQL NOTIFY channel name used for change events on
/// `schema.table` (e.g. `postrust_public_users`).
pub fn table_channel_name(schema: &str, table: &str) -> String {
    ["postrust", schema, table].join("_")
}
322
323pub fn create_notify_trigger_sql(schema: &str, table: &str) -> String {
325 let channel = table_channel_name(schema, table);
326 let trigger_name = format!("postrust_notify_{}_{}", schema, table);
327 let function_name = format!("postrust_notify_{}_{}_fn", schema, table);
328
329 format!(
330 r#"
331-- Create notification function
332CREATE OR REPLACE FUNCTION {schema}.{function_name}()
333RETURNS TRIGGER AS $$
334DECLARE
335 payload jsonb;
336BEGIN
337 IF TG_OP = 'DELETE' THEN
338 payload := jsonb_build_object(
339 'operation', 'DELETE',
340 'table', TG_TABLE_NAME,
341 'schema', TG_TABLE_SCHEMA,
342 'old', row_to_json(OLD)
343 );
344 ELSIF TG_OP = 'UPDATE' THEN
345 payload := jsonb_build_object(
346 'operation', 'UPDATE',
347 'table', TG_TABLE_NAME,
348 'schema', TG_TABLE_SCHEMA,
349 'old', row_to_json(OLD),
350 'new', row_to_json(NEW)
351 );
352 ELSIF TG_OP = 'INSERT' THEN
353 payload := jsonb_build_object(
354 'operation', 'INSERT',
355 'table', TG_TABLE_NAME,
356 'schema', TG_TABLE_SCHEMA,
357 'new', row_to_json(NEW)
358 );
359 END IF;
360
361 PERFORM pg_notify('{channel}', payload::text);
362
363 RETURN COALESCE(NEW, OLD);
364END;
365$$ LANGUAGE plpgsql;
366
367-- Create trigger
368DROP TRIGGER IF EXISTS {trigger_name} ON {schema}.{table};
369CREATE TRIGGER {trigger_name}
370 AFTER INSERT OR UPDATE OR DELETE ON {schema}.{table}
371 FOR EACH ROW
372 EXECUTE FUNCTION {schema}.{function_name}();
373"#,
374 schema = schema,
375 table = table,
376 channel = channel,
377 function_name = function_name,
378 trigger_name = trigger_name
379 )
380}
381
/// Build the DDL that removes the NOTIFY trigger and its function from
/// `schema.table` — the inverse of `create_notify_trigger_sql`.
///
/// Identifiers are double-quoted for the same reason as in the creation
/// DDL: reserved words (e.g. a table named `user`) or embedded quotes in
/// the raw names would otherwise break or inject into the statement.
pub fn drop_notify_trigger_sql(schema: &str, table: &str) -> String {
    // Double-quote a SQL identifier, doubling any embedded double quotes.
    fn quote_ident(ident: &str) -> String {
        format!("\"{}\"", ident.replace('"', "\"\""))
    }

    let trigger_name = quote_ident(&format!("postrust_notify_{}_{}", schema, table));
    let function_name = quote_ident(&format!("postrust_notify_{}_{}_fn", schema, table));
    let schema_ident = quote_ident(schema);
    let table_ident = quote_ident(table);

    format!(
        r#"
DROP TRIGGER IF EXISTS {trigger_name} ON {schema}.{table};
DROP FUNCTION IF EXISTS {schema}.{function_name}();
"#,
        schema = schema_ident,
        table = table_ident,
        trigger_name = trigger_name,
        function_name = function_name
    )
}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Channel names are the fixed prefix plus schema and table.
    #[test]
    fn test_table_channel_name() {
        let cases = [
            ("public", "users", "postrust_public_users"),
            ("api", "orders", "postrust_api_orders"),
        ];
        for (schema, table, expected) in cases {
            assert_eq!(table_channel_name(schema, table), expected);
        }
    }

    /// The creation DDL contains the function, trigger, and notify pieces.
    #[test]
    fn test_create_notify_trigger_sql() {
        let sql = create_notify_trigger_sql("public", "users");
        let needles = [
            "CREATE OR REPLACE FUNCTION",
            "postrust_notify_public_users_fn",
            "CREATE TRIGGER",
            "pg_notify",
            "postrust_public_users",
        ];
        for needle in needles {
            assert!(sql.contains(needle), "missing {needle}");
        }
    }

    /// The teardown DDL drops both the trigger and the function.
    #[test]
    fn test_drop_notify_trigger_sql() {
        let sql = drop_notify_trigger_sql("public", "users");
        for needle in ["DROP TRIGGER IF EXISTS", "DROP FUNCTION IF EXISTS"] {
            assert!(sql.contains(needle), "missing {needle}");
        }
    }
}