forge_runtime/pg/notify_bus.rs
1//! Process-wide PostgreSQL LISTEN/NOTIFY multiplexer.
2//!
3//! Reduces O(N) `PgListener` connections (one per subsystem) to a single
4//! connection that fans out received notifications to per-channel broadcast
5//! subscribers. Subsystems call [`PgNotifyBus::subscribe`] to obtain a
6//! `broadcast::Receiver<String>` for their channel instead of managing their
7//! own `PgListener` lifecycle and reconnection logic.
8//!
9//! # Reconnection
10//!
11//! The bus owns the only LISTEN connection. When that connection drops, the
12//! run loop reconnects with exponential backoff (500 ms to 30 s) and
13//! re-issues LISTEN on every registered channel. Subscribers see no
14//! interruption other than a brief gap in notifications (the same gap the
15//! old per-subsystem listeners had, except now there is exactly one
16//! reconnect path to maintain).
17//!
18//! # Payload semantics
19//!
20//! The bus forwards the raw `notification.payload()` string. Channels that
21//! use structured JSON payloads (typed via [`NotifyChannel`]) decode on the
22//! subscriber side, just as before. Channels with custom string formats
23//! (e.g. `forge_changes` with its `v1:table:OP:...` wire format) pass
24//! through unmodified.
25
26use std::collections::HashMap;
27use std::sync::Arc;
28use std::time::Duration;
29
30use tokio::sync::{broadcast, watch};
31
/// Capacity of every per-channel broadcast ring. A subscriber that falls
/// more than this many payloads behind gets `RecvError::Lagged` on its next
/// `recv()` and decides for itself whether to skip ahead or resync.
const CHANNEL_BUFFER_SIZE: usize = 256;

/// Delay before the first retry after the `PgListener` connection fails.
/// Each further failure doubles it, up to `MAX_BACKOFF`.
const INITIAL_BACKOFF: Duration = Duration::from_millis(500);

/// Ceiling for the doubling reconnect delay (30 s).
const MAX_BACKOFF: Duration = Duration::from_millis(30_000);
42
43/// Process-wide PostgreSQL LISTEN multiplexer.
44///
45/// Create one at runtime startup, register all channels up front, then call
46/// [`run`](Self::run) in a background task. Subsystems obtain a
47/// `broadcast::Receiver<String>` via [`subscribe`](Self::subscribe).
48pub struct PgNotifyBus {
49 pool: sqlx::PgPool,
50 /// Channel name -> broadcast sender. Populated at construction time;
51 /// immutable afterwards.
52 senders: Arc<HashMap<String, broadcast::Sender<String>>>,
53 /// Ticks once per successful (re)connect. The initial connect publishes
54 /// generation `1`; subsequent reconnects publish `2`, `3`, ... so a
55 /// subscriber that snapshots the value at startup can detect a reconnect
56 /// strictly later than its own start without being woken by the boot
57 /// connect. Subscribers that need replay-on-reconnect call
58 /// [`subscribe_reconnects`](Self::subscribe_reconnects) and react to
59 /// changes whose value is greater than the snapshot they observed when
60 /// they subscribed.
61 reconnect_tx: watch::Sender<u64>,
62}
63
64impl PgNotifyBus {
65 /// Create a new notify bus for the given channels.
66 ///
67 /// Each channel gets a `broadcast::channel(256)` so subscribers can lag
68 /// slightly without losing messages. The bus does not start listening
69 /// until [`run`](Self::run) is called.
70 pub fn new(pool: sqlx::PgPool, channels: &[&str]) -> Self {
71 let mut senders = HashMap::with_capacity(channels.len());
72 for &ch in channels {
73 let (tx, _) = broadcast::channel(CHANNEL_BUFFER_SIZE);
74 senders.insert(ch.to_string(), tx);
75 }
76 // Generation starts at 0 so the first successful connect (which
77 // publishes 1) is distinguishable from "never connected" and the
78 // value a subscriber sees the moment they call
79 // `subscribe_reconnects()` doesn't already look like a reconnect.
80 let (reconnect_tx, _) = watch::channel(0u64);
81 Self {
82 pool,
83 senders: Arc::new(senders),
84 reconnect_tx,
85 }
86 }
87
88 /// Subscribe to notifications on `channel`.
89 ///
90 /// Returns `None` if `channel` was not registered at construction time.
91 /// The returned receiver yields the raw NOTIFY payload string.
92 pub fn subscribe(&self, channel: &str) -> Option<broadcast::Receiver<String>> {
93 self.senders.get(channel).map(|tx| tx.subscribe())
94 }
95
96 /// Returns the set of channel names this bus is configured for.
97 pub fn channels(&self) -> Vec<&str> {
98 self.senders.keys().map(|s| s.as_str()).collect()
99 }
100
101 /// Subscribe to reconnect events.
102 ///
103 /// The returned receiver's value is bumped once per successful (re)connect
104 /// of the underlying `PgListener`. The initial connect publishes `1`;
105 /// subsequent reconnects publish `2`, `3`, etc. Subscribers that want to
106 /// trigger gap recovery on reconnect should:
107 ///
108 /// 1. Call `subscribe_reconnects()` and snapshot the current generation
109 /// via `*rx.borrow()`.
110 /// 2. In their main loop, `select!` on `rx.changed()` alongside their
111 /// payload `recv()`.
112 /// 3. On a change, compare `*rx.borrow()` to the snapshot and only treat
113 /// it as a reconnect if it is strictly greater than the snapshot —
114 /// this filters the first-boot connect for subscribers that attach
115 /// before `run()` succeeds.
116 pub fn subscribe_reconnects(&self) -> watch::Receiver<u64> {
117 self.reconnect_tx.subscribe()
118 }
119
120 /// Run the listener loop until `shutdown` fires.
121 ///
122 /// This must be spawned as a background task. It owns the single
123 /// `PgListener` connection, reconnects on failure, and fans out every
124 /// received notification to the matching broadcast channel.
125 pub async fn run(&self, shutdown: tokio::sync::watch::Receiver<bool>) {
126 let channel_names: Vec<String> = self.senders.keys().cloned().collect();
127 let mut backoff = INITIAL_BACKOFF;
128 let mut shutdown = shutdown;
129
130 loop {
131 let listener = match self.connect_and_listen(&channel_names).await {
132 Ok(l) => {
133 backoff = INITIAL_BACKOFF;
134 // Bump the reconnect generation. Subscribers compare the
135 // observed value against the snapshot they took at
136 // subscribe-time, so the first connect (generation 1)
137 // only fires gap recovery for late subscribers — which
138 // is the safe behaviour anyway since a late subscriber
139 // could have missed events before attaching.
140 self.reconnect_tx.send_modify(|g| *g = g.saturating_add(1));
141 l
142 }
143 Err(e) => {
144 tracing::warn!(error = %e, "PgNotifyBus: connect/listen failed, retrying");
145 tokio::select! {
146 biased;
147 _ = shutdown.changed() => {
148 if *shutdown.borrow() {
149 tracing::debug!("PgNotifyBus: shutdown during reconnect");
150 return;
151 }
152 }
153 _ = tokio::time::sleep(backoff) => {}
154 }
155 backoff = (backoff * 2).min(MAX_BACKOFF);
156 continue;
157 }
158 };
159
160 tracing::info!(
161 channels = ?channel_names,
162 "PgNotifyBus: listening on {} channel(s)",
163 channel_names.len(),
164 );
165
166 if self.recv_loop(listener, &mut shutdown).await {
167 tracing::debug!("PgNotifyBus: shutting down");
168 return;
169 }
170
171 tracing::warn!("PgNotifyBus: connection lost, reconnecting");
172 }
173 }
174
175 /// Connect a `PgListener` and LISTEN on every channel. Returns the
176 /// ready listener or the first error encountered.
177 async fn connect_and_listen(
178 &self,
179 channels: &[String],
180 ) -> Result<sqlx::postgres::PgListener, sqlx::Error> {
181 let mut listener = sqlx::postgres::PgListener::connect_with(&self.pool).await?;
182 for ch in channels {
183 listener.listen(ch).await?;
184 }
185 Ok(listener)
186 }
187
188 /// Receive notifications and fan out. Returns `true` if shutdown was
189 /// requested, `false` if the connection broke.
190 async fn recv_loop(
191 &self,
192 mut listener: sqlx::postgres::PgListener,
193 shutdown: &mut tokio::sync::watch::Receiver<bool>,
194 ) -> bool {
195 loop {
196 tokio::select! {
197 biased;
198 _ = shutdown.changed() => {
199 if *shutdown.borrow() {
200 return true;
201 }
202 }
203 notification = listener.recv() => {
204 match notification {
205 Ok(n) => {
206 let channel = n.channel();
207 let payload = n.payload().to_string();
208 if let Some(tx) = self.senders.get(channel) {
209 // Ignore send errors — they mean no active receivers.
210 let _ = tx.send(payload);
211 } else {
212 tracing::debug!(
213 channel = channel,
214 "PgNotifyBus: notification on unregistered channel, ignoring",
215 );
216 }
217 }
218 Err(e) => {
219 tracing::warn!(error = %e, "PgNotifyBus: recv error");
220 return false;
221 }
222 }
223 }
224 }
225 }
226 }
227}
228
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Build a bus over a lazily-connecting pool. None of these tests call
    /// `run()`, so no live PostgreSQL server is needed.
    fn make_bus(channels: &[&str]) -> PgNotifyBus {
        let pool = sqlx::PgPool::connect_lazy("postgres://localhost/test").unwrap();
        PgNotifyBus::new(pool, channels)
    }

    #[tokio::test]
    async fn subscribe_returns_receiver_for_registered_channel() {
        let bus = make_bus(&["forge_changes", "forge_jobs_available"]);
        assert!(bus.subscribe("forge_changes").is_some());
        assert!(bus.subscribe("forge_jobs_available").is_some());
    }

    #[tokio::test]
    async fn subscribe_returns_none_for_unknown_channel() {
        let bus = make_bus(&["forge_changes"]);
        assert!(bus.subscribe("forge_nonexistent").is_none());
    }

    #[tokio::test]
    async fn channels_returns_all_registered_names() {
        let bus = make_bus(&[
            "forge_changes",
            "forge_jobs_available",
            "forge_workflow_wakeup",
        ]);
        let mut names = bus.channels();
        names.sort();
        assert_eq!(
            names,
            vec![
                "forge_changes",
                "forge_jobs_available",
                "forge_workflow_wakeup"
            ],
        );
    }

    #[tokio::test]
    async fn duplicate_channel_names_are_registered_once() {
        let bus = make_bus(&["dup", "dup"]);
        assert_eq!(bus.channels(), vec!["dup"]);
    }

    #[tokio::test]
    async fn fan_out_delivers_to_all_subscribers() {
        let bus = make_bus(&["test_channel"]);
        let mut rx1 = bus.subscribe("test_channel").unwrap();
        let mut rx2 = bus.subscribe("test_channel").unwrap();

        // Simulate what the recv loop does: send directly on the broadcast.
        let tx = bus.senders.get("test_channel").unwrap();
        tx.send("hello".to_string()).unwrap();

        assert_eq!(rx1.recv().await.unwrap(), "hello");
        assert_eq!(rx2.recv().await.unwrap(), "hello");
    }

    #[tokio::test]
    async fn send_without_subscribers_is_reported_but_harmless() {
        let bus = make_bus(&["test_channel"]);
        let tx = bus.senders.get("test_channel").unwrap();
        // With no receivers the broadcast send reports Err; the bus
        // deliberately discards that error.
        assert!(tx.send("orphan".to_string()).is_err());

        // A subscriber attached afterwards does not see the dropped payload.
        let mut rx = bus.subscribe("test_channel").unwrap();
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }

    #[tokio::test]
    async fn subscriber_lagging_past_buffer_sees_lagged() {
        let bus = make_bus(&["test_channel"]);
        let mut rx = bus.subscribe("test_channel").unwrap();
        let tx = bus.senders.get("test_channel").unwrap();

        // One more payload than the buffer holds evicts the oldest one.
        for i in 0..=CHANNEL_BUFFER_SIZE {
            tx.send(i.to_string()).unwrap();
        }

        assert!(matches!(
            rx.recv().await,
            Err(broadcast::error::RecvError::Lagged(1))
        ));
        assert_eq!(rx.recv().await.unwrap(), "1");
    }

    #[tokio::test]
    async fn empty_channels_list_produces_empty_bus() {
        let bus = make_bus(&[]);
        assert!(bus.channels().is_empty());
        assert!(bus.subscribe("anything").is_none());
    }

    #[tokio::test]
    async fn reconnect_subscriber_starts_at_zero_and_observes_ticks() {
        // The reconnect generation starts at 0, so a subscriber attached
        // before the first connect snapshots 0. Every successful
        // (re)connect bumps the generation by one. We simulate that by
        // calling `send_modify` the way `run()` does, since the real connect
        // path requires a live PG backend.
        let bus = make_bus(&["test_channel"]);
        let mut rx = bus.subscribe_reconnects();
        assert_eq!(*rx.borrow(), 0, "fresh bus starts at generation 0");

        // First connect.
        bus.reconnect_tx.send_modify(|g| *g = g.saturating_add(1));
        rx.changed().await.unwrap();
        assert_eq!(*rx.borrow(), 1, "first connect publishes generation 1");

        // Reconnect.
        bus.reconnect_tx.send_modify(|g| *g = g.saturating_add(1));
        rx.changed().await.unwrap();
        assert_eq!(*rx.borrow(), 2, "reconnect publishes generation 2");
    }

    #[tokio::test]
    async fn late_reconnect_subscriber_snapshots_current_generation() {
        let bus = make_bus(&["test_channel"]);
        // The boot connect happens before this subscriber attaches.
        bus.reconnect_tx.send_modify(|g| *g = g.saturating_add(1));

        let mut rx = bus.subscribe_reconnects();
        assert_eq!(*rx.borrow(), 1, "late subscriber snapshots generation 1");

        // The next observed change is a genuine reconnect.
        bus.reconnect_tx.send_modify(|g| *g = g.saturating_add(1));
        rx.changed().await.unwrap();
        assert_eq!(*rx.borrow(), 2);
    }
}