1#[cfg(feature = "chrono")]
84use chrono::{DateTime, SecondsFormat, Utc};
85#[cfg(not(feature = "chrono"))]
86use std::time::SystemTime;
87
88use {
89 futures::{StreamExt, stream},
90 std::{
91 collections::BTreeMap,
92 fmt::{self, Display},
93 str::FromStr,
94 sync::{Arc, RwLock},
95 },
96 tokio::{
97 io::{AsyncRead, AsyncWrite},
98 task::JoinHandle,
99 },
100 tokio_postgres::{
101 AsyncMessage, Client as PGClient, Connection as PGConnection, Error as PGError,
102 Notification, error::DbError,
103 },
104};
105
106pub type PGResult<T> = Result<T, PGError>;
108
109pub type NotifyCallbacks =
111 Arc<RwLock<BTreeMap<String, Vec<Box<dyn for<'a> Fn(&'a PGNotify) + Send + Sync + 'static>>>>>;
112
113pub type RaiseCallbacks =
115 Arc<RwLock<Vec<Box<dyn for<'a> Fn(&'a PGRaise) + Send + Sync + 'static>>>>;
116
117#[allow(unused)]
125pub struct PGRobustNotifier<F> {
126 notify_callbacks: NotifyCallbacks,
127 raise_callbacks: RaiseCallbacks,
128 subscriptions: Vec<JoinHandle<()>>,
129 connect: F,
130 inner: PGNotifier,
131}
132
133impl<F, S, T> PGRobustNotifier<F>
134where
135 F: AsyncFn() -> PGResult<(PGClient, PGConnection<S, T>)>,
136 S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
137 T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
138{
139 pub async fn new(connect: F) -> PGResult<Self> {
140 let (client, conn) = connect().await?;
142 let inner = PGNotifier::spawn(client, conn);
143 let notify_callbacks = inner.notify_callbacks.clone();
144 let raise_callbacks = inner.raise_callbacks.clone();
145
146 Ok(Self {
147 notify_callbacks,
148 raise_callbacks,
149 subscriptions: vec![],
150 connect,
151 inner,
152 })
153 }
154
155 async fn reconnect(&mut self) -> PGResult<()> {
159 let (client, conn) = (self.connect)().await?;
160 self.inner =
161 PGNotifier::respawn(client, conn, &self.notify_callbacks, &self.raise_callbacks)
162 .await?;
163 Ok(())
164 }
165
166 pub async fn client(&mut self) -> PGResult<&PGClient> {
180 if let Err(e) = self.inner.client.execute("/* PING */", &[]).await
181 && e.is_closed()
182 {
183 let mut k = 1;
185 let mut attempts = 1;
186
187 loop {
188 tracing::info!("Connection is closed. Reconnect attempt #{}", attempts);
189 attempts += 1;
190
191 match self.reconnect().await {
192 Ok(_) => {
193 break;
194 }
195 Err(e) if e.is_closed() => {
196 k *= std::cmp::min(k, 60);
197 let t = k + rand::random_range(0..k);
198 tokio::time::sleep(tokio::time::Duration::from_secs(t)).await;
199 }
200 Err(e) => return Err(e),
201 }
202 }
203 }
204
205 Ok(&self.inner.client)
206 }
207
208 pub async fn subscribe_notify<CB>(
210 &mut self,
211 channel: impl Into<String>,
212 callback: CB,
213 ) -> PGResult<()>
214 where
215 CB: Fn(&PGNotify) + Send + Sync + 'static,
216 {
217 self.inner.subscribe_notify(channel, callback).await
218 }
219
220 pub async fn subscribe_raise(&mut self, callback: impl Fn(&PGRaise) + Send + Sync + 'static) {
222 self.inner.subscribe_raise(callback)
223 }
224
225 pub async fn capture_log(&mut self) -> Option<Vec<PGRaise>> {
227 self.inner.capture_log()
228 }
229
230 pub async fn with_captured_log<CB, Data>(&mut self, f: CB) -> PGResult<(Data, Vec<PGRaise>)>
232 where
233 CB: AsyncFnOnce(&PGClient) -> PGResult<Data>,
234 {
235 self.inner.with_captured_log(f).await
236 }
237}
238
/// A `tokio_postgres` client paired with a background task that polls its
/// connection and dispatches asynchronous server messages.
///
/// Notice messages are passed to `raise_callbacks` and appended to `log`;
/// notification messages are passed to the `notify_callbacks` for their channel.
pub struct PGNotifier {
    /// Client whose connection is polled by the `listen_handle` task.
    pub client: PGClient,
    /// Background task polling the connection; aborted when the notifier is dropped.
    listen_handle: JoinHandle<()>,
    /// Captured notices; `capture_log` swaps the contents out for an empty `Vec`.
    log: Arc<RwLock<Option<Vec<PGRaise>>>>,
    /// Callbacks invoked for every notice message.
    raise_callbacks: RaiseCallbacks,
    /// Callbacks invoked for notification messages, keyed by channel name.
    notify_callbacks: NotifyCallbacks,
}
249
250impl PGNotifier {
251 pub fn spawn<S, T>(client: PGClient, mut conn: PGConnection<S, T>) -> Self
255 where
256 S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
257 T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
258 {
259 let log = Arc::new(RwLock::new(Some(Vec::default())));
260 let notify_callbacks: NotifyCallbacks = Arc::new(RwLock::new(BTreeMap::new()));
261 let raise_callbacks: RaiseCallbacks = Arc::new(RwLock::new(Vec::new()));
262
263 let listen_handle = {
265 let log = log.clone();
267 let notify_callbacks = notify_callbacks.clone();
268 let raise_callbacks = raise_callbacks.clone();
269
270 tokio::spawn(async move {
271 let mut stream =
273 stream::poll_fn(move |cx| conn.poll_message(cx).map_err(|e| panic!("{}", e)));
274
275 while let Some(msg) = stream.next().await {
276 match msg {
277 Ok(AsyncMessage::Notice(raise)) => {
278 Self::handle_raise(&raise_callbacks, &log, raise)
279 }
280 Ok(AsyncMessage::Notification(notice)) => {
281 Self::handle_notify(¬ify_callbacks, notice)
282 }
283 _ => {
284 #[cfg(feature = "tracing")]
285 tracing::error!("connection to the server was closed");
286 #[cfg(not(feature = "tracing"))]
287 eprintln!("connection to the server was closed");
288 break;
289 }
290 }
291 }
292 })
293 };
294
295 Self {
296 client,
297 listen_handle,
298 log,
299 notify_callbacks,
300 raise_callbacks,
301 }
302 }
303
304 pub async fn respawn<S, T>(
308 client: PGClient,
309 conn: PGConnection<S, T>,
310 notify_callbacks: &NotifyCallbacks,
311 raise_callbacks: &RaiseCallbacks,
312 ) -> PGResult<Self>
313 where
314 S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
315 T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
316 {
317 let mut notifier = Self::spawn(client, conn);
318 notifier.notify_callbacks = notify_callbacks.clone();
319 notifier.raise_callbacks = raise_callbacks.clone();
320
321 if let Ok(guard) = notify_callbacks.read() {
322 let sql = guard
323 .keys()
324 .map(|channel| format!("LISTEN {}", channel))
325 .collect::<Vec<_>>()
326 .join(";\n");
327 notifier.client.batch_execute(&sql).await?;
328 }
329
330 Ok(notifier)
331 }
332
333 fn handle_notify(callbacks: &NotifyCallbacks, note: Notification) {
337 let notice = PGNotify::new(note.channel(), note.payload());
338 if let Ok(guard) = callbacks.read() {
339 if let Some(cbs) = guard.get(note.channel()) {
340 for callback in cbs.iter() {
341 callback(¬ice);
342 }
343 }
344 }
345 }
346
347 fn handle_raise(
351 callbacks: &RaiseCallbacks,
352 log: &Arc<RwLock<Option<Vec<PGRaise>>>>,
353 raise: DbError,
354 ) {
355 let log_item = PGRaise {
356 #[cfg(feature = "chrono")]
357 timestamp: Utc::now(),
358 #[cfg(not(feature = "chrono"))]
359 timestamp: SystemTime::now(),
360 level: PGRaiseLevel::from_str(raise.severity()).unwrap_or(PGRaiseLevel::Error),
361 message: raise.message().into(),
362 };
363
364 if let Ok(guard) = callbacks.read() {
365 for callback in guard.iter() {
366 callback(&log_item);
367 }
368 }
369
370 if let Ok(mut guard) = log.write() {
371 guard.as_mut().map(|log| log.push(log_item));
372 }
373 }
374
375 pub async fn subscribe_notify<F>(
383 &mut self,
384 channel: impl Into<String>,
385 callback: F,
386 ) -> PGResult<()>
387 where
388 F: Fn(&PGNotify) + Send + Sync + 'static,
389 {
390 let channel = channel.into();
392 self.client
393 .execute(&format!("LISTEN {}", &channel), &[])
394 .await?;
395
396 if let Ok(mut guard) = self.notify_callbacks.write() {
398 guard.entry(channel).or_default().push(Box::new(callback));
399 }
400
401 Ok(())
402 }
403
404 pub fn subscribe_raise(&mut self, callback: impl Fn(&PGRaise) + Send + Sync + 'static) {
411 if let Ok(mut guard) = self.raise_callbacks.write() {
412 guard.push(Box::new(callback));
413 }
414 }
415
416 pub fn capture_log(&self) -> Option<Vec<PGRaise>> {
426 if let Ok(mut guard) = self.log.write() {
427 let captured = guard.take();
428 *guard = Some(Vec::default());
429 captured
430 } else {
431 None
432 }
433 }
434
435 pub async fn with_captured_log<F, T>(&self, f: F) -> PGResult<(T, Vec<PGRaise>)>
444 where
445 F: AsyncFnOnce(&PGClient) -> PGResult<T>,
446 {
447 self.capture_log(); let result = f(&self.client).await?;
449 let log = self.capture_log().unwrap_or_default();
450 Ok((result, log))
451 }
452}
453
454impl Drop for PGNotifier {
455 fn drop(&mut self) {
456 self.listen_handle.abort();
457 }
458}
459
/// A notification received on a channel the notifier is listening to.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(any(feature = "serde", test), derive(serde::Serialize))]
pub struct PGNotify {
    /// Name of the channel the notification arrived on.
    pub channel: String,
    /// Payload string carried by the notification.
    pub payload: String,
}
469
470impl PGNotify {
471 pub fn new(channel: impl Into<String>, payload: impl Into<String>) -> Self {
472 Self {
473 channel: channel.into(),
474 payload: payload.into(),
475 }
476 }
477}
478
/// A server notice message (e.g. produced by `RAISE`) captured by the notifier.
#[derive(Debug, Clone)]
#[cfg_attr(any(feature = "serde", test), derive(serde::Serialize))]
pub struct PGRaise {
    /// Local time at which the notice was converted (taken from `Utc::now()`).
    #[cfg(feature = "chrono")]
    pub timestamp: DateTime<Utc>,
    /// Local time at which the notice was converted (taken from `SystemTime::now()`).
    #[cfg(not(feature = "chrono"))]
    pub timestamp: SystemTime,
    /// Severity parsed from the notice; unrecognised values become `Error`.
    pub level: PGRaiseLevel,
    /// The notice's message text.
    pub message: String,
}
492
493impl From<DbError> for PGRaise {
494 #[cfg(feature = "chrono")]
495 fn from(raise: DbError) -> Self {
496 PGRaise {
497 timestamp: Utc::now(),
498 level: PGRaiseLevel::from_str(raise.severity()).unwrap_or(PGRaiseLevel::Error),
499 message: raise.message().into(),
500 }
501 }
502
503 #[cfg(not(feature = "chrono"))]
504 fn from(raise: DbError) -> Self {
505 PGRaise {
506 timestamp: SystemTime::now(),
507 level: PGRaiseLevel::from_str(raise.severity()).unwrap_or(PGRaiseLevel::Error),
508 message: raise.message().into(),
509 }
510 }
511}
512
513impl Display for PGRaise {
514 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
515 #[cfg(feature = "chrono")]
516 let ts = self.timestamp.to_rfc3339_opts(SecondsFormat::Millis, true);
517
518 #[cfg(not(feature = "chrono"))]
519 let ts = {
520 let duration = self
521 .timestamp
522 .duration_since(SystemTime::UNIX_EPOCH)
523 .unwrap();
524 let millis = duration.as_millis();
525 format!("{}", millis)
526 };
527
528 write!(f, "{}{:>8}: {}", &ts, &self.level.as_ref(), self.message)
529 }
530}
531
/// Severity level of a server notice.
///
/// Parsed from and rendered as the upper-case level name (see the `FromStr` and
/// `AsRef<str>` impls). With `serde`, it serialises as that upper-case name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(any(feature = "serde", test), derive(serde::Serialize))]
#[cfg_attr(any(feature = "serde", test), serde(rename_all = "UPPERCASE"))]
pub enum PGRaiseLevel {
    Debug,
    Log,
    Info,
    Notice,
    Warning,
    Error,
    Fatal,
    Panic,
}
545
546impl AsRef<str> for PGRaiseLevel {
547 fn as_ref(&self) -> &str {
548 use PGRaiseLevel::*;
549 match self {
550 Debug => "DEBUG",
551 Log => "LOG",
552 Info => "INFO",
553 Notice => "NOTICE",
554 Warning => "WARNING",
555 Error => "ERROR",
556 Fatal => "FATAL",
557 Panic => "PANIC",
558 }
559 }
560}
561
562impl FromStr for PGRaiseLevel {
563 type Err = ();
564 fn from_str(s: &str) -> Result<Self, Self::Err> {
565 match s {
566 "DEBUG" => Ok(PGRaiseLevel::Debug),
567 "LOG" => Ok(PGRaiseLevel::Log),
568 "INFO" => Ok(PGRaiseLevel::Info),
569 "NOTICE" => Ok(PGRaiseLevel::Notice),
570 "WARNING" => Ok(PGRaiseLevel::Warning),
571 "ERROR" => Ok(PGRaiseLevel::Error),
572 "FATAL" => Ok(PGRaiseLevel::Fatal),
573 "PANIC" => Ok(PGRaiseLevel::Panic),
574 _ => Err(()),
575 }
576 }
577}
578
#[cfg(test)]
mod tests {

    use super::{PGClient, PGNotifier, PGNotify, PGRobustNotifier};
    use insta::*;
    use std::sync::{
        Arc, RwLock,
        atomic::{AtomicI32, Ordering},
    };
    use testcontainers::{ImageExt, runners::AsyncRunner};
    use testcontainers_modules::postgres::Postgres;

    /// End-to-end test against a PostgreSQL 16.4 test container.
    ///
    /// 1. Notices raised during a batch are returned by `with_captured_log`, and
    ///    `NOTIFY` messages reach the subscribed callback.
    /// 2. `PGRobustNotifier::client` reconnects after its backend is terminated,
    ///    calling the connect closure a second time.
    #[tokio::test]
    async fn test_integration() {
        let pg_server = Postgres::default()
            .with_tag("16.4")
            .start()
            .await
            .expect("could not start postgres server");

        let database_url = format!(
            "postgres://postgres:postgres@{}:{}/postgres",
            pg_server.get_host().await.unwrap(),
            pg_server.get_host_port_ipv4(5432).await.unwrap()
        );

        // Part 1: RAISE capture and LISTEN/NOTIFY on a plain notifier.
        let (client, conn) = tokio_postgres::connect(&database_url, tokio_postgres::NoTls)
            .await
            .expect("could not connect to postgres server");

        let mut notifier = PGNotifier::spawn(client, conn);

        let notices = Arc::new(RwLock::new(Vec::new()));
        let notices_clone = notices.clone();

        notifier
            .subscribe_notify("test", move |notify: &PGNotify| {
                if let Ok(mut guard) = notices_clone.write() {
                    guard.push(notify.clone());
                }
            })
            .await
            .expect("could not subscribe to notifications");

        let (_, execution_log) = notifier
            .with_captured_log(async |client| {
                client
                    .batch_execute(
                        r#"
                        set client_min_messages to 'debug';
                        do $$
                        begin
                            raise debug 'this is a DEBUG notification';
                            notify test, 'test#1';
                            raise log 'this is a LOG notification';
                            notify test, 'test#2';
                            raise info 'this is a INFO notification';
                            notify test, 'test#3';
                            raise notice 'this is a NOTICE notification';
                            notify test, 'test#4';
                            raise warning 'this is a WARNING notification';
                            notify test, 'test#5';
                        end;
                        $$;
                        "#,
                    )
                    .await
            })
            .await
            .expect("could not execute queries on postgres");

        assert_json_snapshot!("raise-notices", &execution_log, {
            "[].timestamp" => "<timestamp>"
        });

        // Clone out of the lock immediately so the guard is not held any longer.
        let raise_notices = notices.read().expect("could not read notices").clone();
        assert_json_snapshot!("listen/notify", &raise_notices);

        // Part 2: reconnect behaviour of PGRobustNotifier.
        let counter = Arc::new(AtomicI32::new(0));
        let (client, conn) = tokio_postgres::connect(&database_url, tokio_postgres::NoTls)
            .await
            .expect("could not connect to postgres server");
        let admin = PGNotifier::spawn(client, conn);

        let counter_clone = counter.clone();
        let mut notifier = PGRobustNotifier::new(async move || {
            counter_clone.fetch_add(1, Ordering::Relaxed);
            tokio_postgres::connect(&database_url, tokio_postgres::NoTls).await
        })
        .await
        .expect("could not connect to postgres server");

        let client: &PGClient = notifier.client().await.expect("could not get client");
        assert!(client.execute("select 1", &[]).await.is_ok());

        // Terminate every other client connection (not background processes), which
        // includes the robust notifier's current backend.
        admin
            .client
            .execute(
                r#"
                SELECT pg_terminate_backend(pg_stat_activity.pid)
                FROM pg_stat_activity
                WHERE pid <> pg_backend_pid()
                  AND backend_type = 'client backend';
                "#,
                &[],
            )
            .await
            .expect("could not kill other connections");

        let client: &PGClient = notifier.client().await.expect("could not get client");
        assert!(client.execute("select 1", &[]).await.is_ok());
        // One connect from `new` plus exactly one reconnect.
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }
}