forge_runtime/pg/
notify.rs1use std::marker::PhantomData;
36
37use futures_util::stream::{Stream, StreamExt};
38use serde::Serialize;
39use serde::de::DeserializeOwned;
40use sqlx::PgExecutor;
41use sqlx::postgres::PgListener;
42
43use forge_core::error::{ForgeError, Result};
44
/// Largest JSON-encoded payload, in bytes, that [`NotifyChannel::publish`] will send.
///
/// This is 7 KiB (7 × 1024 = 7168), which keeps a margin below PostgreSQL's
/// 8000-byte `NOTIFY` payload ceiling. Larger bodies are rejected with
/// `ForgeError::InvalidArgument`.
pub const MAX_PAYLOAD_BYTES: usize = 7_168;
49
/// Typed handle to a PostgreSQL `LISTEN`/`NOTIFY` channel.
///
/// The handle stores only the channel name. `T` is the payload type that
/// [`NotifyChannel::publish`] serializes to JSON and [`NotifyChannel::subscribe`]
/// deserializes from JSON.
///
/// `PhantomData<fn(T) -> T>` ties the handle to `T` without owning a `T`. As a
/// result, the handle is `Send`/`Sync` regardless of `T`, and `T` is invariant.
pub struct NotifyChannel<T> {
    /// Channel name passed to `LISTEN` and `pg_notify`.
    name: &'static str,
    _marker: PhantomData<fn(T) -> T>,
}

// These impls are written by hand on purpose. `#[derive]` would add `T: Clone` /
// `T: Debug` bounds even though no `T` is stored, which would make handles for
// non-`Clone` payload types uncopyable. The handle is just a `&'static str`, so
// copying it is always free.
impl<T> Clone for NotifyChannel<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for NotifyChannel<T> {}

impl<T> std::fmt::Debug for NotifyChannel<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NotifyChannel")
            .field("name", &self.name)
            .finish()
    }
}
59
60impl<T> NotifyChannel<T> {
61 pub const fn new(name: &'static str) -> Self {
65 Self {
66 name,
67 _marker: PhantomData,
68 }
69 }
70
71 pub const fn name(&self) -> &'static str {
73 self.name
74 }
75}
76
77impl<T> NotifyChannel<T>
78where
79 T: Serialize,
80{
81 pub async fn publish<'e, E>(&self, executor: E, payload: &T) -> Result<()>
90 where
91 E: PgExecutor<'e>,
92 {
93 let body =
94 serde_json::to_string(payload).map_err(|e| ForgeError::Serialization(e.to_string()))?;
95 if body.len() > MAX_PAYLOAD_BYTES {
96 return Err(ForgeError::InvalidArgument(format!(
97 "NotifyChannel `{}` payload is {} bytes, exceeds {} byte limit; \
98 write the body to forge_change_log and emit only the row id",
99 self.name,
100 body.len(),
101 MAX_PAYLOAD_BYTES,
102 )));
103 }
104 crate::observability::record_notify_payload_bytes(self.name, body.len());
105 sqlx::query!("SELECT pg_notify($1, $2)", self.name, &body)
106 .execute(executor)
107 .await
108 .map_err(ForgeError::Database)?;
109 Ok(())
110 }
111}
112
113impl<T> NotifyChannel<T>
114where
115 T: DeserializeOwned + Send + 'static,
116{
117 pub async fn subscribe(&self, mut listener: PgListener) -> Result<impl Stream<Item = T>> {
126 listener
127 .listen(self.name)
128 .await
129 .map_err(ForgeError::Database)?;
130 let channel_name = self.name;
131 let raw = listener.into_stream();
132 let stream = raw
133 .take_while(|res| {
134 let cont = res.is_ok();
135 async move { cont }
136 })
137 .filter_map(move |res| async move {
138 let notification = match res {
139 Ok(n) => n,
140 Err(_) => return None,
141 };
142 match serde_json::from_str::<T>(notification.payload()) {
143 Ok(value) => Some(value),
144 Err(e) => {
145 tracing::debug!(
146 channel = channel_name,
147 error = %e,
148 payload = notification.payload(),
149 "NotifyChannel: dropping malformed payload",
150 );
151 None
152 }
153 }
154 });
155 Ok(stream)
156 }
157}
158
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod unit_tests {
    use super::*;

    #[derive(serde::Serialize)]
    struct Tiny {
        v: u32,
    }

    #[test]
    fn channel_constructor_records_name() {
        // `new` is a `const fn`, so it can initialise a `const` handle.
        const CH: NotifyChannel<Tiny> = NotifyChannel::new("forge_test_channel");
        assert_eq!(CH.name(), "forge_test_channel");
    }

    #[test]
    fn max_payload_bytes_stays_below_pg_notify_ceiling() {
        // These are compile-time assertions, so a bad edit to the constant breaks
        // the build rather than only this test. PostgreSQL's default NOTIFY
        // payload limit is just under 8000 bytes.
        const _: () = assert!(MAX_PAYLOAD_BYTES < 8000);
        const _: () = assert!(MAX_PAYLOAD_BYTES == 7 * 1024);
    }

    #[test]
    fn channel_handle_is_one_str_slice_wide() {
        use std::mem::size_of;
        // `PhantomData` adds no bytes, so the handle costs exactly one `&'static str`.
        assert_eq!(size_of::<NotifyChannel<Tiny>>(), size_of::<&'static str>());
    }
}
193
#[cfg(all(test, feature = "testcontainers"))]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::panic,
    clippy::disallowed_methods
)]
mod integration_tests {
    use super::*;
    use forge_core::testing::{IsolatedTestDb, TestDatabase};
    use serde::Deserialize;
    use sqlx::postgres::PgListener;
    use std::time::Duration;

    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Wakeup {
        id: i64,
        kind: String,
    }

    /// Creates an isolated database named after `test_name` from the env-configured test database.
    async fn setup_db(test_name: &str) -> IsolatedTestDb {
        let base = TestDatabase::from_env()
            .await
            .expect("Failed to create test database");
        base.isolated(test_name)
            .await
            .expect("Failed to create isolated db")
    }

    #[tokio::test]
    async fn publish_then_subscribe_round_trip() {
        let db = setup_db("notify_round_trip").await;
        let channel: NotifyChannel<Wakeup> = NotifyChannel::new("forge_test_notify_round_trip");

        // Subscribe (and thus LISTEN) before publishing so the notification isn't missed.
        let listener = PgListener::connect_with(db.pool()).await.unwrap();
        let mut stream = Box::pin(channel.subscribe(listener).await.unwrap());

        let payload = Wakeup {
            id: 42,
            kind: "test".into(),
        };
        channel.publish(db.pool(), &payload).await.unwrap();

        let received = tokio::time::timeout(Duration::from_secs(5), stream.next())
            .await
            .expect("stream did not yield within 5s")
            .expect("stream ended before yielding");
        assert_eq!(received, payload);
    }

    #[tokio::test]
    async fn publish_rejects_oversize_payload() {
        let db = setup_db("notify_oversize").await;
        let channel: NotifyChannel<String> = NotifyChannel::new("forge_test_notify_oversize");

        let big = "x".repeat(MAX_PAYLOAD_BYTES + 1);
        let err = channel.publish(db.pool(), &big).await.unwrap_err();
        assert!(matches!(err, ForgeError::InvalidArgument(_)));
        let msg = err.to_string();
        assert!(
            msg.contains("forge_change_log"),
            "error should hint at the change-log fallback, got: {msg}",
        );
    }

    #[tokio::test]
    async fn publish_accepts_payload_at_limit() {
        let db = setup_db("notify_at_limit").await;
        let channel: NotifyChannel<String> = NotifyChannel::new("forge_test_notify_at_limit");

        // JSON-encoding a plain ASCII string adds two quote bytes, so this body
        // is exactly MAX_PAYLOAD_BYTES long. The limit is inclusive, so it must be accepted.
        let at_limit = "x".repeat(MAX_PAYLOAD_BYTES - 2);
        assert_eq!(
            serde_json::to_string(&at_limit).unwrap().len(),
            MAX_PAYLOAD_BYTES
        );
        channel.publish(db.pool(), &at_limit).await.unwrap();
    }

    #[tokio::test]
    async fn subscribe_skips_malformed_payloads() {
        let db = setup_db("notify_malformed").await;
        let channel: NotifyChannel<Wakeup> = NotifyChannel::new("forge_test_notify_malformed");
        let listener = PgListener::connect_with(db.pool()).await.unwrap();
        let mut stream = Box::pin(channel.subscribe(listener).await.unwrap());

        // Send raw non-JSON on the same channel, bypassing `publish`. The channel
        // name comes from the handle so the malformed notification can't silently
        // go to a different channel and make this test pass vacuously.
        sqlx::query("SELECT pg_notify($1, $2)")
            .bind(channel.name())
            .bind("not-json")
            .execute(db.pool())
            .await
            .unwrap();

        let payload = Wakeup {
            id: 7,
            kind: "ok".into(),
        };
        channel.publish(db.pool(), &payload).await.unwrap();

        // The first item yielded must be the valid payload: the malformed one was skipped.
        let received = tokio::time::timeout(Duration::from_secs(5), stream.next())
            .await
            .expect("stream did not yield within 5s")
            .expect("stream ended before yielding");
        assert_eq!(received, payload);
    }
}