Skip to main content

forge_runtime/pg/
notify_bus.rs

1//! Process-wide PostgreSQL LISTEN/NOTIFY multiplexer.
2//!
3//! Reduces O(N) `PgListener` connections (one per subsystem) to a single
4//! connection that fans out received notifications to per-channel broadcast
5//! subscribers. Subsystems call [`PgNotifyBus::subscribe`] to obtain a
6//! `broadcast::Receiver<String>` for their channel instead of managing their
7//! own `PgListener` lifecycle and reconnection logic.
8//!
9//! # Reconnection
10//!
//! The bus owns the only LISTEN connection. When that connection drops, the
//! run loop reconnects immediately, retrying failed attempts with exponential
//! backoff (500 ms doubling up to 30 s), and re-issues LISTEN on every
//! registered channel. Subscribers see no interruption other than a brief gap
//! in notifications (the same gap the old per-subsystem listeners had, except
//! now there is exactly one reconnect path to maintain). Subscribers that must
//! recover that gap can watch [`PgNotifyBus::subscribe_reconnects`] and resync
//! when the connection generation advances.
17//!
18//! # Payload semantics
19//!
20//! The bus forwards the raw `notification.payload()` string. Channels that
21//! use structured JSON payloads (typed via [`NotifyChannel`]) decode on the
22//! subscriber side, just as before. Channels with custom string formats
23//! (e.g. `forge_changes` with its `v1:table:OP:...` wire format) pass
24//! through unmodified.
25
26use std::collections::HashMap;
27use std::sync::Arc;
28use std::time::Duration;
29
30use tokio::sync::{broadcast, watch};
31
/// Capacity of each per-channel broadcast ring. A subscriber that falls more
/// than this many messages behind gets `RecvError::Lagged` from its next
/// `recv()` and can then choose between catching up and resyncing.
const CHANNEL_BUFFER_SIZE: usize = 256;

/// Delay before retrying after a failed connect attempt; the run loop
/// doubles it after every consecutive failure.
const INITIAL_BACKOFF: Duration = Duration::from_millis(500);

/// Ceiling for the retry delay between failed connect attempts.
const MAX_BACKOFF: Duration = Duration::from_millis(30_000);
42
43/// Process-wide PostgreSQL LISTEN multiplexer.
44///
45/// Create one at runtime startup, register all channels up front, then call
46/// [`run`](Self::run) in a background task. Subsystems obtain a
47/// `broadcast::Receiver<String>` via [`subscribe`](Self::subscribe).
48pub struct PgNotifyBus {
49    pool: sqlx::PgPool,
50    /// Channel name -> broadcast sender. Populated at construction time;
51    /// immutable afterwards.
52    senders: Arc<HashMap<String, broadcast::Sender<String>>>,
53    /// Ticks once per successful (re)connect. The initial connect publishes
54    /// generation `1`; subsequent reconnects publish `2`, `3`, ... so a
55    /// subscriber that snapshots the value at startup can detect a reconnect
56    /// strictly later than its own start without being woken by the boot
57    /// connect. Subscribers that need replay-on-reconnect call
58    /// [`subscribe_reconnects`](Self::subscribe_reconnects) and react to
59    /// changes whose value is greater than the snapshot they observed when
60    /// they subscribed.
61    reconnect_tx: watch::Sender<u64>,
62}
63
64impl PgNotifyBus {
65    /// Create a new notify bus for the given channels.
66    ///
67    /// Each channel gets a `broadcast::channel(256)` so subscribers can lag
68    /// slightly without losing messages. The bus does not start listening
69    /// until [`run`](Self::run) is called.
70    pub fn new(pool: sqlx::PgPool, channels: &[&str]) -> Self {
71        let mut senders = HashMap::with_capacity(channels.len());
72        for &ch in channels {
73            let (tx, _) = broadcast::channel(CHANNEL_BUFFER_SIZE);
74            senders.insert(ch.to_string(), tx);
75        }
76        // Generation starts at 0 so the first successful connect (which
77        // publishes 1) is distinguishable from "never connected" and the
78        // value a subscriber sees the moment they call
79        // `subscribe_reconnects()` doesn't already look like a reconnect.
80        let (reconnect_tx, _) = watch::channel(0u64);
81        Self {
82            pool,
83            senders: Arc::new(senders),
84            reconnect_tx,
85        }
86    }
87
88    /// Subscribe to notifications on `channel`.
89    ///
90    /// Returns `None` if `channel` was not registered at construction time.
91    /// The returned receiver yields the raw NOTIFY payload string.
92    pub fn subscribe(&self, channel: &str) -> Option<broadcast::Receiver<String>> {
93        self.senders.get(channel).map(|tx| tx.subscribe())
94    }
95
96    /// Returns the set of channel names this bus is configured for.
97    pub fn channels(&self) -> Vec<&str> {
98        self.senders.keys().map(|s| s.as_str()).collect()
99    }
100
101    /// Subscribe to reconnect events.
102    ///
103    /// The returned receiver's value is bumped once per successful (re)connect
104    /// of the underlying `PgListener`. The initial connect publishes `1`;
105    /// subsequent reconnects publish `2`, `3`, etc. Subscribers that want to
106    /// trigger gap recovery on reconnect should:
107    ///
108    /// 1. Call `subscribe_reconnects()` and snapshot the current generation
109    ///    via `*rx.borrow()`.
110    /// 2. In their main loop, `select!` on `rx.changed()` alongside their
111    ///    payload `recv()`.
112    /// 3. On a change, compare `*rx.borrow()` to the snapshot and only treat
113    ///    it as a reconnect if it is strictly greater than the snapshot —
114    ///    this filters the first-boot connect for subscribers that attach
115    ///    before `run()` succeeds.
116    pub fn subscribe_reconnects(&self) -> watch::Receiver<u64> {
117        self.reconnect_tx.subscribe()
118    }
119
120    /// Run the listener loop until `shutdown` fires.
121    ///
122    /// This must be spawned as a background task. It owns the single
123    /// `PgListener` connection, reconnects on failure, and fans out every
124    /// received notification to the matching broadcast channel.
125    pub async fn run(&self, shutdown: tokio::sync::watch::Receiver<bool>) {
126        let channel_names: Vec<String> = self.senders.keys().cloned().collect();
127        let mut backoff = INITIAL_BACKOFF;
128        let mut shutdown = shutdown;
129
130        loop {
131            let listener = match self.connect_and_listen(&channel_names).await {
132                Ok(l) => {
133                    backoff = INITIAL_BACKOFF;
134                    // Bump the reconnect generation. Subscribers compare the
135                    // observed value against the snapshot they took at
136                    // subscribe-time, so the first connect (generation 1)
137                    // only fires gap recovery for late subscribers — which
138                    // is the safe behaviour anyway since a late subscriber
139                    // could have missed events before attaching.
140                    self.reconnect_tx.send_modify(|g| *g = g.saturating_add(1));
141                    l
142                }
143                Err(e) => {
144                    tracing::warn!(error = %e, "PgNotifyBus: connect/listen failed, retrying");
145                    tokio::select! {
146                        biased;
147                        _ = shutdown.changed() => {
148                            if *shutdown.borrow() {
149                                tracing::debug!("PgNotifyBus: shutdown during reconnect");
150                                return;
151                            }
152                        }
153                        _ = tokio::time::sleep(backoff) => {}
154                    }
155                    backoff = (backoff * 2).min(MAX_BACKOFF);
156                    continue;
157                }
158            };
159
160            tracing::info!(
161                channels = ?channel_names,
162                "PgNotifyBus: listening on {} channel(s)",
163                channel_names.len(),
164            );
165
166            if self.recv_loop(listener, &mut shutdown).await {
167                tracing::debug!("PgNotifyBus: shutting down");
168                return;
169            }
170
171            tracing::warn!("PgNotifyBus: connection lost, reconnecting");
172        }
173    }
174
175    /// Connect a `PgListener` and LISTEN on every channel. Returns the
176    /// ready listener or the first error encountered.
177    async fn connect_and_listen(
178        &self,
179        channels: &[String],
180    ) -> Result<sqlx::postgres::PgListener, sqlx::Error> {
181        let mut listener = sqlx::postgres::PgListener::connect_with(&self.pool).await?;
182        for ch in channels {
183            listener.listen(ch).await?;
184        }
185        Ok(listener)
186    }
187
188    /// Receive notifications and fan out. Returns `true` if shutdown was
189    /// requested, `false` if the connection broke.
190    async fn recv_loop(
191        &self,
192        mut listener: sqlx::postgres::PgListener,
193        shutdown: &mut tokio::sync::watch::Receiver<bool>,
194    ) -> bool {
195        loop {
196            tokio::select! {
197                biased;
198                _ = shutdown.changed() => {
199                    if *shutdown.borrow() {
200                        return true;
201                    }
202                }
203                notification = listener.recv() => {
204                    match notification {
205                        Ok(n) => {
206                            let channel = n.channel();
207                            let payload = n.payload().to_string();
208                            if let Some(tx) = self.senders.get(channel) {
209                                // Ignore send errors — they mean no active receivers.
210                                let _ = tx.send(payload);
211                            } else {
212                                tracing::debug!(
213                                    channel = channel,
214                                    "PgNotifyBus: notification on unregistered channel, ignoring",
215                                );
216                            }
217                        }
218                        Err(e) => {
219                            tracing::warn!(error = %e, "PgNotifyBus: recv error");
220                            return false;
221                        }
222                    }
223                }
224            }
225        }
226    }
227}
228
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Build a bus over a lazily-connecting pool. None of these tests call
    /// `run()`, so no live PostgreSQL backend is required.
    fn make_bus(channels: &[&str]) -> PgNotifyBus {
        let pool = sqlx::PgPool::connect_lazy("postgres://localhost/test").unwrap();
        PgNotifyBus::new(pool, channels)
    }

    #[tokio::test]
    async fn subscribe_returns_receiver_for_registered_channel() {
        let bus = make_bus(&["forge_changes", "forge_jobs_available"]);
        assert!(bus.subscribe("forge_changes").is_some());
        assert!(bus.subscribe("forge_jobs_available").is_some());
    }

    #[tokio::test]
    async fn subscribe_returns_none_for_unknown_channel() {
        let bus = make_bus(&["forge_changes"]);
        assert!(bus.subscribe("forge_nonexistent").is_none());
    }

    #[tokio::test]
    async fn channels_returns_all_registered_names() {
        let bus = make_bus(&[
            "forge_changes",
            "forge_jobs_available",
            "forge_workflow_wakeup",
        ]);
        let mut names = bus.channels();
        names.sort();
        assert_eq!(
            names,
            vec![
                "forge_changes",
                "forge_jobs_available",
                "forge_workflow_wakeup"
            ],
        );
    }

    #[tokio::test]
    async fn duplicate_channel_names_collapse_to_one_entry() {
        let bus = make_bus(&["forge_changes", "forge_changes"]);
        assert_eq!(bus.channels(), vec!["forge_changes"]);
    }

    #[tokio::test]
    async fn fan_out_delivers_to_all_subscribers() {
        let bus = make_bus(&["test_channel"]);
        let mut rx1 = bus.subscribe("test_channel").unwrap();
        let mut rx2 = bus.subscribe("test_channel").unwrap();

        // Simulate what the recv loop does: send directly on the broadcast.
        let tx = bus.senders.get("test_channel").unwrap();
        tx.send("hello".to_string()).unwrap();

        assert_eq!(rx1.recv().await.unwrap(), "hello");
        assert_eq!(rx2.recv().await.unwrap(), "hello");
    }

    #[tokio::test]
    async fn channels_are_isolated() {
        let bus = make_bus(&["chan_a", "chan_b"]);
        let mut rx_a = bus.subscribe("chan_a").unwrap();
        let mut rx_b = bus.subscribe("chan_b").unwrap();

        bus.senders
            .get("chan_b")
            .unwrap()
            .send("for-b".to_string())
            .unwrap();

        assert_eq!(rx_b.recv().await.unwrap(), "for-b");
        assert!(matches!(
            rx_a.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }

    #[tokio::test]
    async fn send_without_subscribers_is_reported_but_harmless() {
        let bus = make_bus(&["test_channel"]);
        let tx = bus.senders.get("test_channel").unwrap();
        // With no receivers `broadcast::Sender::send` returns `Err`. The recv
        // loop deliberately ignores that result, so pin the behaviour it
        // relies on instead of discarding it here as well.
        assert!(tx.send("orphan".to_string()).is_err());
    }

    #[tokio::test]
    async fn empty_channels_list_produces_empty_bus() {
        let bus = make_bus(&[]);
        assert!(bus.channels().is_empty());
        assert!(bus.subscribe("anything").is_none());
    }

    #[tokio::test]
    async fn reconnect_subscriber_starts_at_zero_and_observes_ticks() {
        // The reconnect generation starts at 0 so a freshly-attached
        // subscriber can distinguish the boot-time first connect from a
        // genuine reconnect by snapshotting the value at subscribe time.
        // Every successful (re)connect bumps the generation by one — we
        // simulate that by directly calling `send_modify` the way `run()`
        // does, since the real connect path requires a live PG backend.
        let bus = make_bus(&["test_channel"]);
        let mut rx = bus.subscribe_reconnects();
        assert_eq!(*rx.borrow(), 0, "fresh bus starts at generation 0");

        // First connect.
        bus.reconnect_tx.send_modify(|g| *g = g.saturating_add(1));
        rx.changed().await.unwrap();
        assert_eq!(*rx.borrow(), 1, "first connect publishes generation 1");

        // Reconnect.
        bus.reconnect_tx.send_modify(|g| *g = g.saturating_add(1));
        rx.changed().await.unwrap();
        assert_eq!(*rx.borrow(), 2, "reconnect publishes generation 2");
    }

    #[tokio::test]
    async fn reconnect_generation_advances_without_receivers() {
        // `run()` bumps the generation with `send_modify`, which updates the
        // value even when nobody is subscribed. A subscriber attaching after
        // the first connect must therefore snapshot generation 1, not 0.
        let bus = make_bus(&["test_channel"]);
        bus.reconnect_tx.send_modify(|g| *g = g.saturating_add(1));
        let rx = bus.subscribe_reconnects();
        assert_eq!(*rx.borrow(), 1, "late subscriber sees the current generation");
    }
}