use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use sqlx::PgPool;
use sqlx::postgres::PgListener;
use tokio::sync::{broadcast, mpsc};
use tokio::task::JoinHandle;
use tracing::{debug, warn};
/// Map from Postgres channel name to the broadcast sender that fans out that
/// channel's events; shared between [`StreamNotifier`] and its background
/// listener task (guarded by a std `Mutex`, so it must never be held across `.await`).
type ChannelMap = Arc<Mutex<HashMap<String, broadcast::Sender<StreamEventId>>>>;
/// An event delivered to subscribers of a stream channel.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum StreamEventId {
    /// A notification whose payload parsed as `"<ts_ms>-<seq>"`.
    Frame { ts_ms: i64, seq: i32 },
    /// Notifications may have been missed: sent after the listener connection is
    /// (re-)established, and also used for payloads that do not parse as a frame
    /// id. Presumably subscribers should resynchronise from storage — confirm
    /// against callers.
    Reconnect,
}
/// Multiplexes Postgres `LISTEN`/`NOTIFY` over one background listener task and
/// fans each channel's notifications out to Tokio broadcast subscribers.
///
/// Construct with `StreamNotifier::spawn` (real database) or
/// `StreamNotifier::placeholder` (no database).
pub struct StreamNotifier {
    /// Channel name -> broadcast sender; the same map is read by the listener task.
    channels: ChannelMap,
    /// Requests for the listener task (currently only "LISTEN on this channel").
    command_tx: mpsc::UnboundedSender<NotifierCommand>,
    /// Handle to the background listener task.
    _listener_task: JoinHandle<()>,
}
/// Requests sent from [`StreamNotifier`] to its background listener task.
enum NotifierCommand {
    /// Issue `LISTEN` on the named channel. The oneshot is signalled once the
    /// attempt has finished — on success and on failure alike.
    Listen(String, tokio::sync::oneshot::Sender<()>),
}
impl StreamNotifier {
    /// Spawns the background listener task on `pool` and returns a shared handle.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (`tokio::spawn` requires one).
    pub fn spawn(pool: PgPool) -> Arc<Self> {
        let channels: ChannelMap = Arc::new(Mutex::new(HashMap::new()));
        let (command_tx, command_rx) = mpsc::unbounded_channel();
        let task_channels = Arc::clone(&channels);
        let listener_task = tokio::spawn(async move {
            run_listener(pool, task_channels, command_rx).await;
        });
        Arc::new(Self {
            channels,
            command_tx,
            _listener_task: listener_task,
        })
    }

    /// Returns a notifier with no database connection behind it.
    ///
    /// The command receiver is dropped when this function returns, so
    /// [`subscribe`](Self::subscribe) never waits and nothing ever sends on the
    /// receivers it hands out. Still requires a Tokio runtime, because it spawns a
    /// no-op task to fill the handle slot.
    pub fn placeholder() -> Arc<Self> {
        let channels: ChannelMap = Arc::new(Mutex::new(HashMap::new()));
        let (command_tx, _command_rx) = mpsc::unbounded_channel();
        let listener_task = tokio::spawn(async move {});
        Arc::new(Self {
            channels,
            command_tx,
            _listener_task: listener_task,
        })
    }

    /// Subscribes to events on the Postgres channel `channel`.
    ///
    /// The first caller for a channel registers it and then waits until the
    /// listener task has attempted `LISTEN` for it. The returned receiver is
    /// created before that wait, so every event broadcast after registration is
    /// delivered to it. That includes the `Reconnect` sent if the LISTEN fails
    /// and the task reconnects.
    ///
    /// NOTE(review): a concurrent caller that finds the channel already
    /// registered returns immediately, possibly before the first caller's
    /// `LISTEN` has completed; notifications in that window are not delivered.
    pub async fn subscribe(&self, channel: &str) -> broadcast::Receiver<StreamEventId> {
        // Create the receiver while holding the lock, so it exists before the
        // listener task can broadcast anything for this channel.
        let (needs_listen, receiver) = {
            let mut guard = self.channels.lock().expect("stream notifier mutex");
            if let Some(existing) = guard.get(channel) {
                (false, existing.subscribe())
            } else {
                let (tx, rx) = broadcast::channel(64);
                guard.insert(channel.to_owned(), tx);
                (true, rx)
            }
        };
        if needs_listen {
            let (ack_tx, ack_rx) = tokio::sync::oneshot::channel();
            // If the listener task is gone (e.g. `placeholder`), `send` fails and
            // drops `ack_tx` with the rejected command, so the await below
            // resolves immediately instead of hanging.
            let _ = self
                .command_tx
                .send(NotifierCommand::Listen(channel.to_owned(), ack_tx));
            let _ = ack_rx.await;
        }
        receiver
    }

    /// Number of channels registered so far.
    ///
    /// NOTE(review): no code visible here removes entries, so this only grows.
    pub fn active_channel_count(&self) -> usize {
        self.channels.lock().expect("stream notifier mutex").len()
    }
}

impl Drop for StreamNotifier {
    fn drop(&mut self) {
        // Dropping a Tokio `JoinHandle` only detaches the task. The listener exits
        // by itself once `command_tx` is dropped, but only while connected: during
        // connect retries it never polls the command channel. Abort explicitly so
        // the task cannot outlive the notifier.
        self._listener_task.abort();
    }
}
async fn run_listener(
pool: PgPool,
channels: ChannelMap,
mut command_rx: mpsc::UnboundedReceiver<NotifierCommand>,
) {
let mut backoff_ms: u64 = 100;
loop {
let mut listener = match PgListener::connect_with(&pool).await {
Ok(l) => l,
Err(e) => {
warn!(error = %e, "stream notifier: connect failed, backing off");
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
backoff_ms = (backoff_ms * 2).min(1_000);
continue;
}
};
backoff_ms = 100;
let snapshot: Vec<(String, broadcast::Sender<StreamEventId>)> = {
let guard = channels.lock().expect("stream notifier mutex");
guard.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
};
for (ch, _) in &snapshot {
if let Err(e) = listener.listen(ch).await {
warn!(channel = %ch, error = %e, "LISTEN failed during re-subscribe");
}
}
for (_, sender) in &snapshot {
let _ = sender.send(StreamEventId::Reconnect);
}
loop {
tokio::select! {
maybe_cmd = command_rx.recv() => {
match maybe_cmd {
Some(NotifierCommand::Listen(ch, ack)) => {
if let Err(e) = listener.listen(&ch).await {
warn!(channel = %ch, error = %e, "LISTEN failed");
let _ = ack.send(());
break;
}
let _ = ack.send(());
}
None => return,
}
}
notify = listener.recv() => {
match notify {
Ok(n) => {
let channel = n.channel().to_owned();
let event = parse_payload(n.payload());
let sender_opt = channels
.lock()
.expect("stream notifier mutex")
.get(&channel)
.cloned();
if let Some(sender) = sender_opt {
let _ = sender.send(event);
}
}
Err(e) => {
warn!(error = %e, "PgListener recv error; reconnecting");
break;
}
}
}
}
}
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
}
}
/// Parses a NOTIFY payload of the form `"<ts_ms>-<seq>"` into a
/// [`StreamEventId::Frame`].
///
/// The payload is split at its first `-`. The left half must parse as `i64`
/// and the right half as `i32`. Anything else yields
/// [`StreamEventId::Reconnect`]: no `-` at all, a non-numeric half, a leading
/// `-`, or more than one `-`.
fn parse_payload(raw: &str) -> StreamEventId {
    let frame = raw.split_once('-').and_then(|(ms_part, seq_part)| {
        let ts_ms = ms_part.parse::<i64>().ok()?;
        let seq = seq_part.parse::<i32>().ok()?;
        Some((ts_ms, seq))
    });
    match frame {
        Some((ts_ms, seq)) => {
            debug!(ts_ms, seq, "stream notify");
            StreamEventId::Frame { ts_ms, seq }
        }
        None => StreamEventId::Reconnect,
    }
}
/// Builds the Postgres NOTIFY channel name for one attempt of an execution:
/// `ff_stream_<uuid>_<attempt_index>`.
///
/// With the 36-character hyphenated `Uuid` display form and at most 10 digits
/// for a `u32`, the name is at most 57 bytes. That is under Postgres's 63-byte
/// identifier limit, so it is never truncated.
pub fn channel_name(execution_uuid: &uuid::Uuid, attempt_index: u32) -> String {
    let prefix = "ff_stream";
    format!("{prefix}_{execution_uuid}_{attempt_index}")
}