use std::marker::PhantomData;
use futures_util::stream::{Stream, StreamExt};
use serde::Serialize;
use serde::de::DeserializeOwned;
use sqlx::PgExecutor;
use sqlx::postgres::PgListener;
use forge_core::error::{ForgeError, Result};
/// Largest serialized (JSON) body, in bytes, that [`NotifyChannel::publish`] will send.
///
/// Anything larger is rejected with `ForgeError::InvalidArgument`, whose message steers the
/// caller toward writing the body to `forge_change_log` and notifying only the row id.
/// The value (7 KiB) leaves headroom below Postgres's 8000-byte `NOTIFY` payload ceiling.
pub const MAX_PAYLOAD_BYTES: usize = 7_168;
/// Typed handle for a single Postgres `LISTEN`/`NOTIFY` channel.
///
/// `T` is the payload type: [`NotifyChannel::publish`] serializes it to JSON and
/// [`NotifyChannel::subscribe`] parses notifications back into it. The handle stores only the
/// channel name. `PhantomData<fn(T) -> T>` ties it to `T` without owning a `T`, which keeps the
/// handle `Send`/`Sync` for any `T` and makes it invariant in `T`.
pub struct NotifyChannel<T> {
    name: &'static str,
    _marker: PhantomData<fn(T) -> T>,
}

// Hand-written rather than derived: `#[derive(Clone, Copy, Debug)]` would require `T: Clone`,
// `T: Copy` and `T: Debug`, although the handle never stores a `T`.
impl<T> Clone for NotifyChannel<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for NotifyChannel<T> {}

impl<T> std::fmt::Debug for NotifyChannel<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NotifyChannel")
            .field("name", &self.name)
            .finish()
    }
}
impl<T> NotifyChannel<T> {
    /// Builds a handle for the channel called `name`.
    ///
    /// This is a `const fn`, so handles can be declared as `const` items.
    pub const fn new(name: &'static str) -> Self {
        let _marker = PhantomData;
        Self { name, _marker }
    }

    /// The channel name this handle passes to `pg_notify` and `LISTEN`.
    pub const fn name(&self) -> &'static str {
        let Self { name, .. } = *self;
        name
    }
}
impl<T> NotifyChannel<T>
where
    T: Serialize,
{
    /// Serializes `payload` to JSON and sends it on this channel with `SELECT pg_notify(..)`.
    ///
    /// `executor` is any sqlx `PgExecutor`. The encoded body size is reported to
    /// `crate::observability::record_notify_payload_bytes` before the query is issued.
    ///
    /// # Errors
    ///
    /// - `ForgeError::Serialization` if `payload` cannot be encoded as JSON.
    /// - `ForgeError::InvalidArgument` if the encoded body is longer than
    ///   [`MAX_PAYLOAD_BYTES`]; nothing is sent in that case.
    /// - `ForgeError::Database` if the `pg_notify` query fails.
    pub async fn publish<'e, E>(&self, executor: E, payload: &T) -> Result<()>
    where
        E: PgExecutor<'e>,
    {
        let body = match serde_json::to_string(payload) {
            Ok(json) => json,
            Err(e) => return Err(ForgeError::Serialization(e.to_string())),
        };
        let body_len = body.len();

        // Enforce the crate's own cap before touching the database; the message tells the
        // caller which fallback to use for large bodies.
        if body_len > MAX_PAYLOAD_BYTES {
            let msg = format!(
                "NotifyChannel `{}` payload is {} bytes, exceeds {} byte limit; \
                 write the body to forge_change_log and emit only the row id",
                self.name, body_len, MAX_PAYLOAD_BYTES,
            );
            return Err(ForgeError::InvalidArgument(msg));
        }

        crate::observability::record_notify_payload_bytes(self.name, body_len);

        sqlx::query!("SELECT pg_notify($1, $2)", self.name, &body)
            .execute(executor)
            .await
            .map(|_| ())
            .map_err(ForgeError::Database)
    }
}
impl<T> NotifyChannel<T>
where
    T: DeserializeOwned + Send + 'static,
{
    /// Issues `LISTEN` for this channel on `listener` and returns a stream of decoded payloads.
    ///
    /// Each notification payload is parsed as JSON into `T`. Payloads that fail to parse are
    /// logged at `debug` level and skipped, so one malformed message does not end the
    /// subscription.
    ///
    /// The stream ends at the first error the listener yields. That error is logged at `warn`
    /// level, because the caller only sees the stream return `None`.
    ///
    /// # Errors
    ///
    /// Returns `ForgeError::Database` if the `LISTEN` command fails.
    pub async fn subscribe(&self, mut listener: PgListener) -> Result<impl Stream<Item = T>> {
        listener
            .listen(self.name)
            .await
            .map_err(ForgeError::Database)?;

        // `&'static str` is `Copy`, so each `move` closure below captures its own copy.
        let channel_name = self.name;

        let stream = listener
            .into_stream()
            // End the stream at the first listener error, but record why. Otherwise a dropped
            // connection would look identical to a clean end of stream.
            .take_while(move |res| {
                let keep_going = match res {
                    Ok(_) => true,
                    Err(e) => {
                        tracing::warn!(
                            channel = channel_name,
                            error = %e,
                            "NotifyChannel: listener failed; ending subscription stream",
                        );
                        false
                    }
                };
                async move { keep_going }
            })
            .filter_map(move |res| async move {
                // `take_while` has already cut the stream at the first `Err`; `?` only unwraps
                // the `Ok` here.
                let notification = res.ok()?;
                match serde_json::from_str::<T>(notification.payload()) {
                    Ok(value) => Some(value),
                    Err(e) => {
                        tracing::debug!(
                            channel = channel_name,
                            error = %e,
                            payload = notification.payload(),
                            "NotifyChannel: dropping malformed payload",
                        );
                        None
                    }
                }
            });
        Ok(stream)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod unit_tests {
    use super::*;

    /// Minimal payload type, used only as the `T` parameter of a channel handle.
    #[derive(serde::Serialize)]
    struct Tiny {
        v: u32,
    }

    #[test]
    fn channel_constructor_records_name() {
        const CH: NotifyChannel<Tiny> = NotifyChannel::new("forge_test_channel");
        assert_eq!(CH.name(), "forge_test_channel");
    }

    /// Compile-time checks. The cap must stay under the 8000-byte `pg_notify` ceiling, and
    /// its exact value is pinned so that changing it is a deliberate edit.
    #[test]
    fn max_payload_bytes_stays_below_pg_notify_ceiling() {
        const _: () = assert!(MAX_PAYLOAD_BYTES < 8000);
        const _: () = assert!(MAX_PAYLOAD_BYTES == 7 * 1024);
    }

    /// The `PhantomData` marker adds no storage. The handle is exactly as large as the
    /// `&'static str` it wraps, whatever the payload type is.
    #[test]
    fn phantom_marker_adds_no_size() {
        use std::mem::size_of;
        assert_eq!(size_of::<NotifyChannel<Tiny>>(), size_of::<&'static str>());
        assert_eq!(
            size_of::<NotifyChannel<[u8; 1024]>>(),
            size_of::<&'static str>()
        );
    }
}
#[cfg(all(test, feature = "testcontainers"))]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::panic,
    clippy::disallowed_methods
)]
mod integration_tests {
    use super::*;
    use forge_core::testing::{IsolatedTestDb, TestDatabase};
    use serde::Deserialize;
    use sqlx::postgres::PgListener;
    use std::time::Duration;

    /// Payload for the round-trip tests; it derives both `Serialize` and `Deserialize`.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Wakeup {
        id: i64,
        kind: String,
    }

    /// Builds an isolated database named after `test_name` from the base test database
    /// obtained via `TestDatabase::from_env`. Panics if either step fails.
    async fn setup_db(test_name: &str) -> IsolatedTestDb {
        let base = TestDatabase::from_env()
            .await
            .expect("Failed to create test database");
        base.isolated(test_name)
            .await
            .expect("Failed to create isolated db")
    }

    /// A published value comes back out of `subscribe` unchanged.
    #[tokio::test]
    async fn publish_then_subscribe_round_trip() {
        let db = setup_db("notify_round_trip").await;
        let channel: NotifyChannel<Wakeup> = NotifyChannel::new("forge_test_notify_round_trip");
        let listener = PgListener::connect_with(db.pool()).await.unwrap();
        let mut stream = Box::pin(channel.subscribe(listener).await.unwrap());
        let payload = Wakeup {
            id: 42,
            kind: "test".into(),
        };
        channel.publish(db.pool(), &payload).await.unwrap();
        let received = tokio::time::timeout(Duration::from_secs(5), stream.next())
            .await
            .expect("stream did not yield within 5s")
            .expect("stream ended before yielding");
        assert_eq!(received, payload);
    }

    /// One byte over the limit is rejected, and the error points at the change-log fallback.
    #[tokio::test]
    async fn publish_rejects_oversize_payload() {
        let db = setup_db("notify_oversize").await;
        let channel: NotifyChannel<String> = NotifyChannel::new("forge_test_notify_oversize");
        let big = "x".repeat(MAX_PAYLOAD_BYTES + 1);
        let err = channel.publish(db.pool(), &big).await.unwrap_err();
        assert!(matches!(err, ForgeError::InvalidArgument(_)));
        let msg = err.to_string();
        assert!(
            msg.contains("forge_change_log"),
            "error should hint at the change-log fallback, got: {msg}",
        );
    }

    /// A body of exactly `MAX_PAYLOAD_BYTES` is accepted, which pins the limit as inclusive.
    #[tokio::test]
    async fn publish_accepts_payload_at_limit() {
        let db = setup_db("notify_at_limit").await;
        let channel: NotifyChannel<String> = NotifyChannel::new("forge_test_notify_at_limit");
        // JSON-encoding a plain ASCII string adds exactly two quote characters.
        let exact = "x".repeat(MAX_PAYLOAD_BYTES - 2);
        assert_eq!(
            serde_json::to_string(&exact).unwrap().len(),
            MAX_PAYLOAD_BYTES
        );
        channel.publish(db.pool(), &exact).await.unwrap();
    }

    /// A non-JSON notification on the channel is skipped, and the stream keeps yielding
    /// later valid payloads.
    #[tokio::test]
    async fn subscribe_skips_malformed_payloads() {
        let db = setup_db("notify_malformed").await;
        let channel: NotifyChannel<Wakeup> = NotifyChannel::new("forge_test_notify_malformed");
        let listener = PgListener::connect_with(db.pool()).await.unwrap();
        let mut stream = Box::pin(channel.subscribe(listener).await.unwrap());
        sqlx::query("SELECT pg_notify($1, $2)")
            .bind("forge_test_notify_malformed")
            .bind("not-json")
            .execute(db.pool())
            .await
            .unwrap();
        let payload = Wakeup {
            id: 7,
            kind: "ok".into(),
        };
        channel.publish(db.pool(), &payload).await.unwrap();
        let received = tokio::time::timeout(Duration::from_secs(5), stream.next())
            .await
            .expect("stream did not yield within 5s")
            .expect("stream ended before yielding");
        assert_eq!(received, payload);
    }
}