Skip to main content

forge_runtime/pg/
notify.rs

1//! Typed wrapper around `pg_notify` and `LISTEN`.
2//!
3//! Doctrine: every consumer that ships a typed JSON payload over PostgreSQL
4//! NOTIFY goes through [`NotifyChannel`]. Custom string-format channels (the
5//! `forge_changes` `v1:table:OP:row_id[#seq]` payload, for example) stay as
6//! they are because their wire shape is part of the public schema and can
7//! evolve independently of any Rust type.
8//!
9//! # Why a typed channel
10//!
11//! Without it, every site reinvents:
12//!
13//! - JSON serialization (`serde_json::to_string`) and deserialization on the
14//!   listening side.
//! - The PostgreSQL `NOTIFY` payload limit (the payload must be shorter than
//!   8000 bytes). Hit that limit at runtime and the publish does not fail
//!   gracefully: `pg_notify` (or the trigger calling it) raises
//!   `ERROR:  payload string too long` and the wrapping transaction rolls back.
19//! - `PgListener::connect_with` plus `listen` plus a `recv()` loop with
20//!   parse-and-skip handling.
21//!
22//! [`NotifyChannel`] centralises these. It enforces a 7 KiB ceiling on the
23//! serialized payload to leave headroom for PG framing, returns a typed
24//! `ForgeError::InvalidArgument` when the caller exceeds it, and exposes a
25//! `Stream<Item = T>` for subscribers so the listener loop is no longer
26//! every consumer's problem.
27//!
28//! # When the payload is too big
29//!
30//! For records that don't fit, write the full row to `forge_change_log` and
31//! publish only the row id over the channel. Subscribers fetch the body from
//! the log when they receive the notification. Helpers for the change-log
//! side are planned for a forthcoming `crate::pg::change_log` module.
34
35use std::marker::PhantomData;
36
37use futures_util::stream::{Stream, StreamExt};
38use serde::Serialize;
39use serde::de::DeserializeOwned;
40use sqlx::PgExecutor;
41use sqlx::postgres::PgListener;
42
43use forge_core::error::{ForgeError, Result};
44
/// Maximum serialized JSON payload size, in bytes, accepted by
/// [`NotifyChannel::publish`].
///
/// PostgreSQL rejects `NOTIFY` payloads that are not shorter than 8000 bytes
/// (default server build) with `payload string too long`. The channel name is
/// not counted against that limit. 7 KiB (7168 bytes) is therefore a
/// deliberately conservative ceiling that leaves roughly 800 bytes of
/// headroom, not a tight fit.
pub const MAX_PAYLOAD_BYTES: usize = 7 * 1024;
49
/// Typed handle to a single PostgreSQL `NOTIFY` channel.
///
/// `T` is the JSON payload shape. Construct one per channel as a `const`
/// value (`name` is `&'static str` and [`NotifyChannel::new`] is a `const fn`)
/// and reuse it everywhere that channel is touched, so publish and subscribe
/// sites can never disagree on the shape.
///
/// The handle stores only the channel name. It is therefore `Copy`, `Clone`
/// and `Debug` for every `T`. `PhantomData<fn(T) -> T>` ties the handle to `T`
/// without storing one, so `T` needs no `Send`/`Sync`/`Clone` bounds for the
/// handle itself.
pub struct NotifyChannel<T> {
    name: &'static str,
    _marker: PhantomData<fn(T) -> T>,
}

// `#[derive(Clone, Copy, Debug)]` would add `T: Clone` / `T: Copy` /
// `T: Debug` bounds even though no `T` is stored, so these are hand-written.
impl<T> Clone for NotifyChannel<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for NotifyChannel<T> {}

impl<T> std::fmt::Debug for NotifyChannel<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NotifyChannel")
            .field("name", &self.name)
            .finish()
    }
}

impl<T> NotifyChannel<T> {
    /// Create a typed channel handle.
    ///
    /// `name` is the PostgreSQL channel identifier passed to `pg_notify` and
    /// `LISTEN`. It must be a valid SQL identifier; the framework uses
    /// snake_case `forge_*` names.
    pub const fn new(name: &'static str) -> Self {
        Self {
            name,
            _marker: PhantomData,
        }
    }

    /// PostgreSQL channel name this handle publishes to and listens on.
    pub const fn name(&self) -> &'static str {
        self.name
    }
}
76
77impl<T> NotifyChannel<T>
78where
79    T: Serialize,
80{
81    /// Publish `payload` on this channel.
82    ///
83    /// Errors:
84    /// - `ForgeError::Serialization` if `serde_json::to_string(payload)` fails.
85    /// - `ForgeError::InvalidArgument` if the serialized payload exceeds
86    ///   [`MAX_PAYLOAD_BYTES`]. Use the change-log fallback for larger bodies.
87    /// - `ForgeError::Database` if the underlying `SELECT pg_notify(...)`
88    ///   fails (transaction rolled back, connection dropped, etc.).
89    pub async fn publish<'e, E>(&self, executor: E, payload: &T) -> Result<()>
90    where
91        E: PgExecutor<'e>,
92    {
93        let body =
94            serde_json::to_string(payload).map_err(|e| ForgeError::Serialization(e.to_string()))?;
95        if body.len() > MAX_PAYLOAD_BYTES {
96            return Err(ForgeError::InvalidArgument(format!(
97                "NotifyChannel `{}` payload is {} bytes, exceeds {} byte limit; \
98                 write the body to forge_change_log and emit only the row id",
99                self.name,
100                body.len(),
101                MAX_PAYLOAD_BYTES,
102            )));
103        }
104        crate::observability::record_notify_payload_bytes(self.name, body.len());
105        sqlx::query!("SELECT pg_notify($1, $2)", self.name, &body)
106            .execute(executor)
107            .await
108            .map_err(ForgeError::Database)?;
109        Ok(())
110    }
111}
112
113impl<T> NotifyChannel<T>
114where
115    T: DeserializeOwned + Send + 'static,
116{
117    /// Subscribe to this channel and return a stream of decoded payloads.
118    ///
119    /// `listener` is consumed; the caller surrenders the connection to the
120    /// stream for the duration of the subscription. Notifications whose
121    /// payload fails JSON decoding are logged and skipped, so a malformed
122    /// publish from one peer cannot tear down a long-running subscriber.
123    /// Errors from the underlying `recv` (connection dropped, etc.) end the
124    /// stream; the caller decides whether to reconnect.
125    pub async fn subscribe(&self, mut listener: PgListener) -> Result<impl Stream<Item = T>> {
126        listener
127            .listen(self.name)
128            .await
129            .map_err(ForgeError::Database)?;
130        let channel_name = self.name;
131        let raw = listener.into_stream();
132        let stream = raw
133            .take_while(|res| {
134                let cont = res.is_ok();
135                async move { cont }
136            })
137            .filter_map(move |res| async move {
138                let notification = match res {
139                    Ok(n) => n,
140                    Err(_) => return None,
141                };
142                match serde_json::from_str::<T>(notification.payload()) {
143                    Ok(value) => Some(value),
144                    Err(e) => {
145                        tracing::debug!(
146                            channel = channel_name,
147                            error = %e,
148                            payload = notification.payload(),
149                            "NotifyChannel: dropping malformed payload",
150                        );
151                        None
152                    }
153                }
154            });
155        Ok(stream)
156    }
157}
158
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod unit_tests {
    use super::*;

    #[derive(serde::Serialize)]
    struct Tiny {
        v: u32,
    }

    #[test]
    fn channel_constructor_records_name() {
        const CH: NotifyChannel<Tiny> = NotifyChannel::new("forge_test_channel");
        assert_eq!(CH.name(), "forge_test_channel");
    }

    #[test]
    fn max_payload_bytes_stays_below_pg_notify_ceiling() {
        // PG rejects NOTIFY payloads that are not shorter than 8000 bytes. The
        // constant must stay strictly under that with some headroom. These are
        // compile-time assertions, so a bad bump fails the build itself.
        const _: () = assert!(MAX_PAYLOAD_BYTES < 8000);
        const _: () = assert!(MAX_PAYLOAD_BYTES == 7 * 1024);
    }

    #[test]
    fn channel_handle_is_one_str_reference_wide() {
        // The handle stores one `&'static str` (a pointer + length fat pointer)
        // plus a zero-sized `PhantomData<fn(T) -> T>`, so it must be exactly as
        // large as `&'static str`. Guards against accidentally growing it.
        use std::mem::size_of;
        assert_eq!(size_of::<NotifyChannel<Tiny>>(), size_of::<&'static str>());
    }
}
193
#[cfg(all(test, feature = "testcontainers"))]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::panic,
    clippy::disallowed_methods
)]
mod integration_tests {
    use super::*;
    use forge_core::testing::{IsolatedTestDb, TestDatabase};
    use serde::Deserialize;
    use sqlx::postgres::PgListener;
    use std::time::Duration;

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Wakeup {
        id: i64,
        kind: String,
    }

    async fn setup_db(test_name: &str) -> IsolatedTestDb {
        let base = TestDatabase::from_env()
            .await
            .expect("Failed to create test database");
        base.isolated(test_name)
            .await
            .expect("Failed to create isolated db")
    }

    #[tokio::test]
    async fn publish_then_subscribe_round_trip() {
        let db = setup_db("notify_round_trip").await;
        let channel: NotifyChannel<Wakeup> = NotifyChannel::new("forge_test_notify_round_trip");

        let listener = PgListener::connect_with(db.pool()).await.unwrap();
        let mut stream = Box::pin(channel.subscribe(listener).await.unwrap());

        // Publish must happen on a separate connection so the listener
        // (a different backend) actually sees the NOTIFY.
        let payload = Wakeup {
            id: 42,
            kind: "test".into(),
        };
        channel.publish(db.pool(), &payload).await.unwrap();

        let received = tokio::time::timeout(Duration::from_secs(5), stream.next())
            .await
            .expect("stream did not yield within 5s")
            .expect("stream ended before yielding");
        assert_eq!(received, payload);
    }

    #[tokio::test]
    async fn publish_rejects_oversize_payload() {
        let db = setup_db("notify_oversize").await;
        let channel: NotifyChannel<String> = NotifyChannel::new("forge_test_notify_oversize");

        // A JSON string serializes with two surrounding quote bytes, so this
        // body is exactly one byte over the limit: it probes the boundary.
        let big = "x".repeat(MAX_PAYLOAD_BYTES - 1);
        assert_eq!(
            serde_json::to_string(&big).unwrap().len(),
            MAX_PAYLOAD_BYTES + 1
        );
        let err = channel.publish(db.pool(), &big).await.unwrap_err();
        assert!(matches!(err, ForgeError::InvalidArgument(_)));
        let msg = err.to_string();
        assert!(
            msg.contains("forge_change_log"),
            "error should hint at the change-log fallback, got: {msg}",
        );
    }

    #[tokio::test]
    async fn publish_accepts_payload_exactly_at_limit() {
        let db = setup_db("notify_at_limit").await;
        let channel: NotifyChannel<String> = NotifyChannel::new("forge_test_notify_at_limit");

        // Two quote bytes + (MAX - 2) body bytes == MAX_PAYLOAD_BYTES exactly,
        // which the `>` check must still accept.
        let at_limit = "x".repeat(MAX_PAYLOAD_BYTES - 2);
        assert_eq!(
            serde_json::to_string(&at_limit).unwrap().len(),
            MAX_PAYLOAD_BYTES
        );
        channel.publish(db.pool(), &at_limit).await.unwrap();
    }

    #[tokio::test]
    async fn subscribe_skips_malformed_payloads() {
        let db = setup_db("notify_malformed").await;
        // Subscriber expects {id, kind}; we will publish a non-JSON string and
        // then a real payload, and assert only the real one comes through.
        let channel: NotifyChannel<Wakeup> = NotifyChannel::new("forge_test_notify_malformed");
        let listener = PgListener::connect_with(db.pool()).await.unwrap();
        let mut stream = Box::pin(channel.subscribe(listener).await.unwrap());

        // Publish raw string via SQL (bypasses NotifyChannel's typed publish),
        // on the same channel name so the two publish paths cannot drift.
        sqlx::query("SELECT pg_notify($1, $2)")
            .bind(channel.name())
            .bind("not-json")
            .execute(db.pool())
            .await
            .unwrap();

        // Then a real payload through the typed publisher.
        let payload = Wakeup {
            id: 7,
            kind: "ok".into(),
        };
        channel.publish(db.pool(), &payload).await.unwrap();

        let received = tokio::time::timeout(Duration::from_secs(5), stream.next())
            .await
            .expect("stream did not yield within 5s")
            .expect("stream ended before yielding");
        assert_eq!(received, payload);
    }
}