#[cfg(feature = "chrono")]
use chrono::{DateTime, SecondsFormat, Utc};
#[cfg(not(feature = "chrono"))]
use std::time::SystemTime;

use {
    futures::{StreamExt, stream},
    std::{
        collections::BTreeMap,
        fmt::{self, Display},
        str::FromStr,
        sync::{Arc, PoisonError, RwLock},
    },
    tokio::{
        io::{AsyncRead, AsyncWrite},
        task::JoinHandle,
    },
    tokio_postgres::{
        AsyncMessage, Client as PGClient, Connection as PGConnection, Notification, error::DbError,
    },
};
104
105pub type NotifyCallbacks =
107 Arc<RwLock<BTreeMap<String, Vec<Box<dyn for<'a> Fn(&'a PGNotify) + Send + Sync + 'static>>>>>;
108
109pub type RaiseCallbacks =
111 Arc<RwLock<Vec<Box<dyn for<'a> Fn(&'a PGRaise) + Send + Sync + 'static>>>>;
112
113pub struct PGNotifier {
117 pub client: PGClient,
118 listen_handle: JoinHandle<()>,
119 log: Arc<RwLock<Option<Vec<PGRaise>>>>,
120 raise_callbacks: RaiseCallbacks,
121 notify_callbacks: NotifyCallbacks,
122}
123
124impl PGNotifier {
125 pub fn spawn<S, T>(client: PGClient, mut conn: PGConnection<S, T>) -> Self
129 where
130 S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
131 T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
132 {
133 let log = Arc::new(RwLock::new(Some(Vec::default())));
134 let notify_callbacks: NotifyCallbacks = Arc::new(RwLock::new(BTreeMap::new()));
135 let raise_callbacks: RaiseCallbacks = Arc::new(RwLock::new(Vec::new()));
136
137 let listen_handle = {
139 let log = log.clone();
141 let notify_callbacks = notify_callbacks.clone();
142 let raise_callbacks = raise_callbacks.clone();
143
144 tokio::spawn(async move {
145 let mut stream =
147 stream::poll_fn(move |cx| conn.poll_message(cx).map_err(|e| panic!("{}", e)));
148
149 while let Some(msg) = stream.next().await {
150 match msg {
151 Ok(AsyncMessage::Notice(raise)) => {
152 Self::handle_raise(&raise_callbacks, &log, raise)
153 }
154 Ok(AsyncMessage::Notification(notice)) => {
155 Self::handle_notify(¬ify_callbacks, notice)
156 }
157 _ => {
158 #[cfg(feature = "tracing")]
159 tracing::error!("connection to the server was closed");
160 #[cfg(not(feature = "tracing"))]
161 eprintln!("connection to the server was closed");
162 break;
163 }
164 }
165 }
166 })
167 };
168
169 Self {
170 client,
171 listen_handle,
172 log,
173 notify_callbacks,
174 raise_callbacks,
175 }
176 }
177
178 fn handle_notify(callbacks: &NotifyCallbacks, note: Notification) {
182 let notice = PGNotify::new(note.channel(), note.payload());
183 if let Ok(guard) = callbacks.read() {
184 if let Some(cbs) = guard.get(note.channel()) {
185 for callback in cbs.iter() {
186 callback(¬ice);
187 }
188 }
189 }
190 }
191
192 fn handle_raise(
196 callbacks: &RaiseCallbacks,
197 log: &Arc<RwLock<Option<Vec<PGRaise>>>>,
198 raise: DbError,
199 ) {
200 let log_item = PGRaise {
201 #[cfg(feature = "chrono")]
202 timestamp: Utc::now(),
203 #[cfg(not(feature = "chrono"))]
204 timestamp: SystemTime::now(),
205 level: PGRaiseLevel::from_str(raise.severity()).unwrap_or(PGRaiseLevel::Error),
206 message: raise.message().into(),
207 };
208
209 if let Ok(guard) = callbacks.read() {
210 for callback in guard.iter() {
211 callback(&log_item);
212 }
213 }
214
215 if let Ok(mut guard) = log.write() {
216 guard.as_mut().map(|log| log.push(log_item));
217 }
218 }
219
220 pub async fn subscribe_notify<F>(
228 &mut self,
229 channel: impl Into<String>,
230 callback: F,
231 ) -> Result<(), tokio_postgres::Error>
232 where
233 F: Fn(&PGNotify) + Send + Sync + 'static,
234 {
235 let channel = channel.into();
237 self.client
238 .execute(&format!("LISTEN {}", &channel), &[])
239 .await?;
240
241 if let Ok(mut guard) = self.notify_callbacks.write() {
243 guard.entry(channel).or_default().push(Box::new(callback));
244 }
245
246 Ok(())
247 }
248
249 pub fn subscribe_raise(&mut self, callback: impl Fn(&PGRaise) + Send + Sync + 'static) {
256 if let Ok(mut guard) = self.raise_callbacks.write() {
257 guard.push(Box::new(callback));
258 }
259 }
260
261 pub fn capture_log(&self) -> Option<Vec<PGRaise>> {
271 if let Ok(mut guard) = self.log.write() {
272 let captured = guard.take();
273 *guard = Some(Vec::default());
274 captured
275 } else {
276 None
277 }
278 }
279
280 pub async fn with_captured_log<F, T>(
289 &self,
290 f: F,
291 ) -> Result<(T, Vec<PGRaise>), tokio_postgres::Error>
292 where
293 F: AsyncFnOnce(&PGClient) -> Result<T, tokio_postgres::Error>,
294 {
295 self.capture_log(); let result = f(&self.client).await?;
297 let log = self.capture_log().unwrap_or_default();
298 Ok((result, log))
299 }
300}
301
302impl Drop for PGNotifier {
303 fn drop(&mut self) {
304 self.listen_handle.abort();
305 }
306}
307
/// A notification (`NOTIFY channel, 'payload'`) received from the server.
///
/// Equality and hashing are derived so that received notifications can be compared,
/// deduplicated, or used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(any(feature = "serde", test), derive(serde::Serialize))]
pub struct PGNotify {
    /// Name of the channel the notification was sent on.
    pub channel: String,
    /// Payload string carried by the notification.
    pub payload: String,
}
317
318impl PGNotify {
319 pub fn new(channel: impl Into<String>, payload: impl Into<String>) -> Self {
320 Self {
321 channel: channel.into(),
322 payload: payload.into(),
323 }
324 }
325}
326
327#[derive(Debug, Clone)]
331#[cfg_attr(any(feature = "serde", test), derive(serde::Serialize))]
332pub struct PGRaise {
333 #[cfg(feature = "chrono")]
334 pub timestamp: DateTime<Utc>,
335 #[cfg(not(feature = "chrono"))]
336 pub timestamp: SystemTime,
337 pub level: PGRaiseLevel,
338 pub message: String,
339}
340
341impl From<DbError> for PGRaise {
342 #[cfg(feature = "chrono")]
343 fn from(raise: DbError) -> Self {
344 PGRaise {
345 timestamp: Utc::now(),
346 level: PGRaiseLevel::from_str(raise.severity()).unwrap_or(PGRaiseLevel::Error),
347 message: raise.message().into(),
348 }
349 }
350
351 #[cfg(not(feature = "chrono"))]
352 fn from(raise: DbError) -> Self {
353 PGRaise {
354 timestamp: SystemTime::now(),
355 level: PGRaiseLevel::from_str(raise.severity()).unwrap_or(PGRaiseLevel::Error),
356 message: raise.message().into(),
357 }
358 }
359}
360
361impl Display for PGRaise {
362 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
363 #[cfg(feature = "chrono")]
364 let ts = self.timestamp.to_rfc3339_opts(SecondsFormat::Millis, true);
365
366 #[cfg(not(feature = "chrono"))]
367 let ts = {
368 let duration = self
369 .timestamp
370 .duration_since(SystemTime::UNIX_EPOCH)
371 .unwrap();
372 let millis = duration.as_millis();
373 format!("{}", millis)
374 };
375
376 write!(f, "{}{:>8}: {}", &ts, &self.level.as_ref(), self.message)
377 }
378}
379
/// Severity level of a server message, using PostgreSQL's severity names.
///
/// Equality and hashing are derived so levels can be compared directly, for example
/// `raise.level == PGRaiseLevel::Warning`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(any(feature = "serde", test), derive(serde::Serialize))]
#[cfg_attr(any(feature = "serde", test), serde(rename_all = "UPPERCASE"))]
pub enum PGRaiseLevel {
    /// `DEBUG`
    Debug,
    /// `LOG`
    Log,
    /// `INFO`
    Info,
    /// `NOTICE`
    Notice,
    /// `WARNING`
    Warning,
    /// `ERROR`
    Error,
    /// `FATAL`
    Fatal,
    /// `PANIC`
    Panic,
}
393
394impl AsRef<str> for PGRaiseLevel {
395 fn as_ref(&self) -> &str {
396 use PGRaiseLevel::*;
397 match self {
398 Debug => "DEBUG",
399 Log => "LOG",
400 Info => "INFO",
401 Notice => "NOTICE",
402 Warning => "WARNING",
403 Error => "ERROR",
404 Fatal => "FATAL",
405 Panic => "PANIC",
406 }
407 }
408}
409
410impl FromStr for PGRaiseLevel {
411 type Err = ();
412 fn from_str(s: &str) -> Result<Self, Self::Err> {
413 match s {
414 "DEBUG" => Ok(PGRaiseLevel::Debug),
415 "LOG" => Ok(PGRaiseLevel::Log),
416 "INFO" => Ok(PGRaiseLevel::Info),
417 "NOTICE" => Ok(PGRaiseLevel::Notice),
418 "WARNING" => Ok(PGRaiseLevel::Warning),
419 "ERROR" => Ok(PGRaiseLevel::Error),
420 "FATAL" => Ok(PGRaiseLevel::Fatal),
421 "PANIC" => Ok(PGRaiseLevel::Panic),
422 _ => Err(()),
423 }
424 }
425}
426
#[cfg(test)]
mod tests {

    use super::{PGNotifier, PGNotify};
    use insta::*;
    use std::sync::{Arc, RwLock};
    use testcontainers::{ImageExt, runners::AsyncRunner};
    use testcontainers_modules::postgres::Postgres;

    /// End-to-end check against a throw-away PostgreSQL 16.4 container. It runs a `DO` block
    /// that interleaves `RAISE` and `NOTIFY`, then snapshots the captured notices and the
    /// notifications received on the `test` channel.
    #[tokio::test]
    async fn test_integration() {
        let container = Postgres::default()
            .with_tag("16.4")
            .start()
            .await
            .expect("could not start postgres server");

        let host = container.get_host().await.unwrap();
        let port = container.get_host_port_ipv4(5432).await.unwrap();
        let database_url = format!("postgres://postgres:postgres@{}:{}/postgres", host, port);

        let (client, connection) = tokio_postgres::connect(&database_url, tokio_postgres::NoTls)
            .await
            .expect("could not connect to postgres server");

        let mut notifier = PGNotifier::spawn(client, connection);

        // Every notification on `test` is copied in here by the callback below.
        let received: Arc<RwLock<Vec<PGNotify>>> = Arc::default();
        let sink = Arc::clone(&received);

        notifier
            .subscribe_notify("test", move |notify: &PGNotify| {
                if let Ok(mut list) = sink.write() {
                    list.push(notify.clone());
                }
            })
            .await
            .expect("could not subscribe to notifications");

        let script = r#"
            set client_min_messages to 'debug';
            do $$
            begin
                raise debug 'this is a DEBUG notification';
                notify test, 'test#1';
                raise log 'this is a LOG notification';
                notify test, 'test#2';
                raise info 'this is a INFO notification';
                notify test, 'test#3';
                raise notice 'this is a NOTICE notification';
                notify test, 'test#4';
                raise warning 'this is a WARNING notification';
                notify test, 'test#5';
            end;
            $$;
            "#;

        let (_, execution_log) = notifier
            .with_captured_log(async |client| client.batch_execute(script).await)
            .await
            .expect("could not execute queries on postgres");

        assert_json_snapshot!("raise-notices", &execution_log, {
            "[].timestamp" => "<timestamp>"
        });

        let notifications = received.read().expect("could not read notices").clone();
        assert_json_snapshot!("listen/notify", &notifications);
    }
}