use sqlx::postgres::PgListener;
use sqlx::PgPool;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::{oneshot, Notify};
use tokio::time::sleep_until;
use tracing::{debug, error, info, warn};
#[cfg(feature = "test-fault-injection")]
use crate::fault_injection::FaultInjector;
/// Tuning knobs for the long-poll notifier.
#[derive(Debug, Clone)]
pub struct LongPollConfig {
    /// Long-poll on/off switch. Nothing in this file reads it; it is carried
    /// for whoever constructs or consumes the configuration.
    pub enabled: bool,
    /// Length of the look-ahead window: timers due within this span of "now"
    /// are tracked, and a completed refresh schedules the next one this far
    /// in the future.
    pub notifier_poll_interval: Duration,
    /// Extra delay added on top of a timer's `visible_at` before it fires.
    pub timer_grace_period: Duration,
}

impl Default for LongPollConfig {
    /// Enabled, with a 60 second poll interval and a 100 millisecond grace period.
    fn default() -> Self {
        let notifier_poll_interval = Duration::from_secs(60);
        let timer_grace_period = Duration::from_millis(100);
        Self {
            enabled: true,
            notifier_poll_interval,
            timer_grace_period,
        }
    }
}
/// Outcome of one background refresh, handed back to the notifier loop.
///
/// Every entry is a `visible_at` timestamp in milliseconds since the Unix epoch.
struct RefreshResult {
    /// Timestamps read from the orchestrator queue.
    orch_timers: Vec<i64>,
    /// Timestamps for the worker side.
    worker_timers: Vec<i64>,
}
/// Listens for PostgreSQL NOTIFY events and wakes dispatchers, either
/// immediately or when a tracked timer becomes due.
pub struct Notifier {
/// LISTEN connection that `run` receives notifications from.
pg_listener: PgListener,
/// Pool used to open listener connections and to run the refresh query.
pool: PgPool,
/// Prefix of the NOTIFY channel names and schema qualifier of the refresh query.
schema_name: String,
/// Pending orchestrator timers; `Reverse` makes the earliest instant the heap top.
orch_heap: BinaryHeap<Reverse<Instant>>,
/// Pending worker timers; `Reverse` makes the earliest instant the heap top.
worker_heap: BinaryHeap<Reverse<Instant>>,
/// Signalled with `notify_one` when orchestrator work is ready.
orch_notify: Arc<Notify>,
/// Signalled with `notify_waiters` when worker work is ready.
worker_notify: Arc<Notify>,
/// Earliest instant at which the next refresh query may be started.
next_refresh: Instant,
/// Receiver for the refresh currently in flight, `None` when there is none.
pending_refresh: Option<oneshot::Receiver<RefreshResult>>,
/// Poll interval and grace period settings.
config: LongPollConfig,
/// Optional test hook that can force panics, reconnects, refresh delays and
/// simulated refresh errors.
#[cfg(feature = "test-fault-injection")]
fault_injector: Option<Arc<FaultInjector>>,
}
impl Notifier {
/// Creates a notifier without fault injection.
///
/// Delegates to `new_internal`, which opens the LISTEN connection and
/// subscribes to both channels before returning.
///
/// # Errors
///
/// Returns the `sqlx::Error` reported by `new_internal`.
pub async fn new(
pool: PgPool,
schema_name: String,
orch_notify: Arc<Notify>,
worker_notify: Arc<Notify>,
config: LongPollConfig,
) -> Result<Self, sqlx::Error> {
Self::new_internal(
pool,
schema_name,
orch_notify,
worker_notify,
config,
// The extra argument only exists when the fault-injection feature is on.
#[cfg(feature = "test-fault-injection")]
None,
)
.await
}
/// Creates a notifier whose loop and refresh task consult `fault_injector`.
///
/// Only compiled with the `test-fault-injection` feature. Otherwise identical
/// to `new`: it delegates to `new_internal`.
///
/// # Errors
///
/// Returns the `sqlx::Error` reported by `new_internal`.
#[cfg(feature = "test-fault-injection")]
pub async fn new_with_fault_injection(
pool: PgPool,
schema_name: String,
orch_notify: Arc<Notify>,
worker_notify: Arc<Notify>,
config: LongPollConfig,
fault_injector: Arc<FaultInjector>,
) -> Result<Self, sqlx::Error> {
Self::new_internal(
pool,
schema_name,
orch_notify,
worker_notify,
config,
Some(fault_injector),
)
.await
}
/// Shared constructor behind `new` and `new_with_fault_injection`.
///
/// Opens a LISTEN connection from `pool`, builds the notifier with empty
/// timer heaps and no refresh in flight, then subscribes to both channels.
///
/// # Errors
///
/// Returns the `sqlx::Error` from connecting the listener or from subscribing.
async fn new_internal(
    pool: PgPool,
    schema_name: String,
    orch_notify: Arc<Notify>,
    worker_notify: Arc<Notify>,
    config: LongPollConfig,
    #[cfg(feature = "test-fault-injection")] fault_injector: Option<Arc<FaultInjector>>,
) -> Result<Self, sqlx::Error> {
    let listener = PgListener::connect_with(&pool).await?;

    // Starting at "now" makes the first refresh due as soon as the loop's
    // sleep branch runs.
    let first_refresh_due = Instant::now();

    let mut this = Self {
        pg_listener: listener,
        pool,
        schema_name,
        orch_heap: BinaryHeap::new(),
        worker_heap: BinaryHeap::new(),
        orch_notify,
        worker_notify,
        next_refresh: first_refresh_due,
        pending_refresh: None,
        config,
        #[cfg(feature = "test-fault-injection")]
        fault_injector,
    };

    this.subscribe_channels().await?;

    info!(
        target = "duroxide::providers::postgres::notifier",
        schema = %this.schema_name,
        "Notifier started, listening for NOTIFY events"
    );

    Ok(this)
}
/// Issues LISTEN for `<schema>_orch_work` and then `<schema>_worker_work` on
/// the current listener.
///
/// # Errors
///
/// Returns the first `sqlx::Error` from a LISTEN call; the second channel is
/// not attempted if the first one fails.
async fn subscribe_channels(&mut self) -> Result<(), sqlx::Error> {
    let schema = &self.schema_name;
    let orch_channel = format!("{schema}_orch_work");
    let worker_channel = format!("{schema}_worker_work");

    for channel in [orch_channel.as_str(), worker_channel.as_str()] {
        self.pg_listener.listen(channel).await?;
    }

    debug!(
        target = "duroxide::providers::postgres::notifier",
        orch_channel = %orch_channel,
        worker_channel = %worker_channel,
        "Subscribed to NOTIFY channels"
    );
    Ok(())
}
/// Runs the notifier event loop; the loop has no exit path.
///
/// Each iteration waits for whichever of these happens first:
///
/// * a NOTIFY arrives on the listener (handled by `handle_notify`), or the
///   listener reports an error (handled by `handle_reconnect`);
/// * the next wake-up instant passes, at which point expired timers are
///   fired and a refresh is started if one is due;
/// * the in-flight refresh task finishes.
///
/// If the refresh task ends without sending a result (its sender was
/// dropped), the pending slot is cleared and the next refresh is scheduled one
/// poll interval out. Without that, `pending_refresh` would stay occupied and
/// `maybe_start_refresh` would never start another refresh.
///
/// # Panics
///
/// With the `test-fault-injection` feature, panics when the injector requests it.
pub async fn run(&mut self) {
    loop {
        #[cfg(feature = "test-fault-injection")]
        if let Some(ref fi) = self.fault_injector {
            if fi.should_notifier_panic() {
                panic!("Fault injection: notifier panic triggered");
            }
            if fi.should_reconnect() {
                warn!(
                    target = "duroxide::providers::postgres::notifier",
                    "Fault injection: forcing reconnect"
                );
                self.handle_reconnect().await;
                continue;
            }
        }

        let next_timer = self.earliest_timer();
        let refresh_in_progress = self.pending_refresh.is_some();
        // While a refresh is in flight `next_refresh` is stale, so only timers
        // decide the wake-up; with no timers, wake after a fixed 60 seconds.
        let next_wake = if refresh_in_progress {
            next_timer.unwrap_or_else(|| Instant::now() + Duration::from_secs(60))
        } else {
            match next_timer {
                Some(t) => t.min(self.next_refresh),
                None => self.next_refresh,
            }
        };

        tokio::select! {
            result = self.pg_listener.recv() => {
                match result {
                    Ok(notification) => {
                        self.handle_notify(notification);
                    }
                    Err(e) => {
                        warn!(
                            target = "duroxide::providers::postgres::notifier",
                            error = %e,
                            "LISTEN connection error, reconnecting..."
                        );
                        self.handle_reconnect().await;
                    }
                }
            }
            _ = sleep_until(next_wake.into()) => {
                self.pop_and_wake_expired_timers();
                self.maybe_start_refresh();
            }
            // With no refresh in flight this branch never completes.
            refresh = async {
                match &mut self.pending_refresh {
                    Some(rx) => rx.await,
                    None => std::future::pending().await,
                }
            } => {
                // Always clear the slot, even on failure, so a later refresh
                // can be started.
                self.pending_refresh = None;
                match refresh {
                    Ok(result) => self.handle_refresh_result(result),
                    Err(_) => {
                        warn!(
                            target = "duroxide::providers::postgres::notifier",
                            "Refresh task ended without a result, scheduling a new refresh"
                        );
                        self.next_refresh =
                            Instant::now() + self.config.notifier_poll_interval;
                    }
                }
            }
        }
    }
}
fn earliest_timer(&self) -> Option<Instant> {
let orch = self.orch_heap.peek().map(|r| r.0);
let worker = self.worker_heap.peek().map(|r| r.0);
match (orch, worker) {
(Some(a), Some(b)) => Some(a.min(b)),
(Some(a), None) => Some(a),
(None, Some(b)) => Some(b),
(None, None) => None,
}
}
/// Reacts to one NOTIFY event.
///
/// The payload is classified by `parse_notify_action` against a window that
/// ends one poll interval from now. Work that is already visible wakes the
/// matching dispatchers, a timer inside the window is pushed onto the matching
/// heap, and anything later is only logged. A channel name ending in
/// `_orch_work` selects the orchestrator side; every other channel is treated
/// as the worker side.
fn handle_notify(&mut self, notification: sqlx::postgres::PgNotification) {
    let channel = notification.channel();
    let payload = notification.payload();
    let is_orch = channel.ends_with("_orch_work");

    let now_ms = current_epoch_ms();
    let window_end_ms = now_ms + self.config.notifier_poll_interval.as_millis() as i64;
    let action = parse_notify_action(
        payload,
        now_ms,
        window_end_ms,
        self.config.timer_grace_period,
        Instant::now(),
    );

    match action {
        NotifyAction::Ignore => {
            debug!(
                target = "duroxide::providers::postgres::notifier",
                channel = %channel,
                payload = %payload,
                "Timer beyond window, ignoring"
            );
        }
        NotifyAction::AddTimer { fire_at } => {
            debug!(
                target = "duroxide::providers::postgres::notifier",
                channel = %channel,
                payload = %payload,
                "Future timer, adding to heap"
            );
            let heap = if is_orch {
                &mut self.orch_heap
            } else {
                &mut self.worker_heap
            };
            heap.push(Reverse(fire_at));
        }
        NotifyAction::WakeNow => {
            debug!(
                target = "duroxide::providers::postgres::notifier",
                channel = %channel,
                payload = %payload,
                "Immediate work, waking dispatchers"
            );
            self.wake_dispatchers(is_orch);
        }
    }
}
/// Signals the dispatcher side selected by `is_orch`: `notify_one` on the
/// orchestrator handle, `notify_waiters` on the worker handle.
fn wake_dispatchers(&self, is_orch: bool) {
    if !is_orch {
        self.worker_notify.notify_waiters();
        return;
    }
    self.orch_notify.notify_one();
}
/// Removes every timer whose fire time is at or before now and signals the
/// matching dispatcher once per removed timer.
///
/// The orchestrator heap is drained first, then the worker heap; both are
/// compared against the same `now` sampled on entry.
fn pop_and_wake_expired_timers(&mut self) {
    let now = Instant::now();

    while matches!(self.orch_heap.peek(), Some(Reverse(due)) if *due <= now) {
        self.orch_heap.pop();
        self.orch_notify.notify_one();
    }

    while matches!(self.worker_heap.peek(), Some(Reverse(due)) if *due <= now) {
        self.worker_heap.pop();
        self.worker_notify.notify_waiters();
    }
}
fn maybe_start_refresh(&mut self) {
if self.pending_refresh.is_some() || Instant::now() < self.next_refresh {
return;
}
let (tx, rx) = oneshot::channel();
self.pending_refresh = Some(rx);
let pool = self.pool.clone();
let schema = self.schema_name.clone();
let now_ms = current_epoch_ms();
let window_end_ms = now_ms + self.config.notifier_poll_interval.as_millis() as i64;
#[cfg(feature = "test-fault-injection")]
let fault_injector = self.fault_injector.clone();
debug!(
target = "duroxide::providers::postgres::notifier",
now_ms = now_ms,
window_end_ms = window_end_ms,
"Starting refresh query"
);
tokio::spawn(async move {
#[cfg(feature = "test-fault-injection")]
if let Some(ref fi) = fault_injector {
let delay = fi.get_refresh_delay();
if !delay.is_zero() {
tokio::time::sleep(delay).await;
}
}
#[cfg(feature = "test-fault-injection")]
if let Some(ref fi) = fault_injector {
if fi.should_refresh_error() {
warn!(
target = "duroxide::providers::postgres::notifier",
"Fault injection: simulating refresh error"
);
let _ = tx.send(RefreshResult {
orch_timers: Vec::new(),
worker_timers: Vec::new(),
});
return;
}
}
let orch_timers = sqlx::query_scalar::<_, i64>(&format!(
"SELECT (EXTRACT(EPOCH FROM visible_at) * 1000)::BIGINT
FROM {schema}.orchestrator_queue
WHERE (EXTRACT(EPOCH FROM visible_at) * 1000)::BIGINT > $1
AND (EXTRACT(EPOCH FROM visible_at) * 1000)::BIGINT <= $2
AND lock_token IS NULL"
))
.bind(now_ms)
.bind(window_end_ms)
.fetch_all(&pool)
.await
.unwrap_or_default();
let worker_timers: Vec<i64> = Vec::new();
let _ = tx.send(RefreshResult {
orch_timers,
worker_timers,
});
});
}
fn handle_refresh_result(&mut self, result: RefreshResult) {
let now_ms = current_epoch_ms();
let now_instant = Instant::now();
debug!(
target = "duroxide::providers::postgres::notifier",
orch_count = result.orch_timers.len(),
worker_count = result.worker_timers.len(),
"Refresh query completed"
);
for fire_at in timers_from_refresh(
&result.orch_timers,
now_ms,
self.config.timer_grace_period,
now_instant,
) {
self.orch_heap.push(Reverse(fire_at));
}
for fire_at in timers_from_refresh(
&result.worker_timers,
now_ms,
self.config.timer_grace_period,
now_instant,
) {
self.worker_heap.push(Reverse(fire_at));
}
self.next_refresh = Instant::now() + self.config.notifier_poll_interval;
}
/// Waits one second, then tries to replace the LISTEN connection.
///
/// On success the new listener is subscribed to both channels, both
/// dispatcher sides are woken, and a refresh is made due immediately.
///
/// If connecting fails, the existing listener is kept. If subscribing on the
/// new listener fails, the previous listener is put back and the error is
/// logged, so the notifier is never left holding a listener with missing
/// subscriptions. In both failure cases the next `run` iteration retries.
async fn handle_reconnect(&mut self) {
    // Fixed back-off so a persistent failure does not spin the loop.
    tokio::time::sleep(Duration::from_secs(1)).await;

    let listener = match PgListener::connect_with(&self.pool).await {
        Ok(listener) => listener,
        Err(e) => {
            error!(
                target = "duroxide::providers::postgres::notifier",
                error = %e,
                "Failed to reconnect, will retry on next loop iteration"
            );
            return;
        }
    };

    // Install the new listener so `subscribe_channels` acts on it, but keep
    // the old one until the subscription is known to have worked.
    let previous = std::mem::replace(&mut self.pg_listener, listener);
    if let Err(e) = self.subscribe_channels().await {
        self.pg_listener = previous;
        error!(
            target = "duroxide::providers::postgres::notifier",
            error = %e,
            "Failed to re-subscribe after reconnect, will retry on next loop iteration"
        );
        return;
    }
    drop(previous);

    info!(
        target = "duroxide::providers::postgres::notifier",
        "Reconnected to PostgreSQL LISTEN"
    );
    // Notifications may have been missed while disconnected: wake both sides
    // and make a refresh due right away.
    self.orch_notify.notify_one();
    self.worker_notify.notify_waiters();
    self.next_refresh = Instant::now();
}
}
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn current_epoch_ms() -> i64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::default(),
    };
    since_epoch.as_millis() as i64
}
/// What the notifier should do with one NOTIFY payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NotifyAction {
    /// The work is already visible: wake dispatchers right away.
    WakeNow,
    /// The work becomes visible inside the current window: fire at `fire_at`.
    AddTimer { fire_at: Instant },
    /// The work becomes visible after the window ends: do nothing now.
    Ignore,
}

/// Classifies a NOTIFY payload relative to the current time window.
///
/// `payload` is expected to hold a `visible_at` timestamp in milliseconds
/// since the Unix epoch. A payload that does not parse as an `i64` is treated
/// as timestamp 0, so any `now_ms >= 0` maps it to [`NotifyAction::WakeNow`].
///
/// * `visible_at <= now_ms` gives `WakeNow`.
/// * `now_ms < visible_at <= window_end_ms` gives `AddTimer`, firing at
///   `now_instant + (visible_at - now_ms) + grace_period`.
/// * `visible_at > window_end_ms` gives `Ignore`.
///
/// The delay is computed with saturating arithmetic, so extreme timestamps or
/// grace periods cannot overflow. If the resulting fire time cannot be
/// represented as an `Instant`, the payload is classified as `Ignore`.
pub fn parse_notify_action(
    payload: &str,
    now_ms: i64,
    window_end_ms: i64,
    grace_period: Duration,
    now_instant: Instant,
) -> NotifyAction {
    let visible_at_ms: i64 = payload.parse().unwrap_or(0);

    if visible_at_ms <= now_ms {
        return NotifyAction::WakeNow;
    }
    if visible_at_ms > window_end_ms {
        return NotifyAction::Ignore;
    }

    // A grace period too large for i64 milliseconds saturates instead of
    // wrapping to an arbitrary (possibly negative) value.
    let grace_ms = i64::try_from(grace_period.as_millis()).unwrap_or(i64::MAX);
    // Strictly positive here: visible_at_ms > now_ms and grace_ms >= 0.
    let delay_ms = visible_at_ms.saturating_sub(now_ms).saturating_add(grace_ms);

    match now_instant.checked_add(Duration::from_millis(delay_ms.unsigned_abs())) {
        Some(fire_at) => NotifyAction::AddTimer { fire_at },
        None => NotifyAction::Ignore,
    }
}
/// Converts `visible_at` timestamps from a refresh into timer fire times.
///
/// Each timestamp (milliseconds since the Unix epoch) becomes
/// `now_instant + (visible_at - now_ms) + grace_period`. Entries whose delay
/// is zero or negative are dropped, and the input order of the remaining
/// entries is preserved.
///
/// The delay is computed with saturating arithmetic, so extreme timestamps or
/// grace periods cannot overflow. An entry whose fire time cannot be
/// represented as an `Instant` is dropped as well.
pub fn timers_from_refresh(
    visible_at_times: &[i64],
    now_ms: i64,
    grace_period: Duration,
    now_instant: Instant,
) -> Vec<Instant> {
    // Loop-invariant; saturates instead of wrapping for oversized durations.
    let grace_ms = i64::try_from(grace_period.as_millis()).unwrap_or(i64::MAX);

    visible_at_times
        .iter()
        .filter_map(|&visible_at_ms| {
            let delay_ms = visible_at_ms.saturating_sub(now_ms).saturating_add(grace_ms);
            // Negative delays fail the conversion; zero is filtered explicitly.
            let delay_ms = u64::try_from(delay_ms).ok().filter(|&ms| ms > 0)?;
            now_instant.checked_add(Duration::from_millis(delay_ms))
        })
        .collect()
}
// Unit tests for the pure helpers (`current_epoch_ms`, `parse_notify_action`,
// `timers_from_refresh`) and for the heap behaviour the notifier relies on.
// None of them touch a database.
#[cfg(test)]
mod tests {
use super::*;
// Sanity bounds: 1_577_836_800_000 ms is 2020-01-01 UTC and
// 2_524_608_000_000 ms is 2050-01-01 UTC.
#[test]
fn test_current_epoch_ms() {
let ms = current_epoch_ms();
assert!(ms > 1_577_836_800_000); assert!(ms < 2_524_608_000_000); }
// Pins the documented defaults of `LongPollConfig`.
#[test]
fn test_longpoll_config_default() {
let config = LongPollConfig::default();
assert!(config.enabled);
assert_eq!(config.notifier_poll_interval, Duration::from_secs(60));
assert_eq!(config.timer_grace_period, Duration::from_millis(100));
}
// visible_at == now counts as already visible.
#[test]
fn notify_immediate_work_wakes_dispatchers() {
let now_ms = 1_700_000_000_000i64;
let window_end_ms = now_ms + 60_000;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let action = parse_notify_action(
&now_ms.to_string(),
now_ms,
window_end_ms,
grace,
now_instant,
);
assert_eq!(action, NotifyAction::WakeNow);
}
// visible_at five seconds in the past wakes immediately.
#[test]
fn notify_past_visible_at_wakes_immediately() {
let now_ms = 1_700_000_000_000i64;
let past_ms = now_ms - 5_000; let window_end_ms = now_ms + 60_000;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let action = parse_notify_action(
&past_ms.to_string(),
now_ms,
window_end_ms,
grace,
now_instant,
);
assert_eq!(action, NotifyAction::WakeNow);
}
// A timer 30 s ahead (inside the 60 s window) fires after 30 s + 100 ms grace.
#[test]
fn notify_future_timer_adds_to_heap() {
let now_ms = 1_700_000_000_000i64;
let future_ms = now_ms + 30_000; let window_end_ms = now_ms + 60_000;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let action = parse_notify_action(
&future_ms.to_string(),
now_ms,
window_end_ms,
grace,
now_instant,
);
match action {
NotifyAction::AddTimer { fire_at } => {
let expected_delay = Duration::from_millis(30_100);
let actual_delay = fire_at.duration_since(now_instant);
assert!(
actual_delay >= expected_delay - Duration::from_millis(10)
&& actual_delay <= expected_delay + Duration::from_millis(10),
"Expected delay ~30.1s, got {actual_delay:?}"
);
}
other => panic!("Expected AddTimer, got {other:?}"),
}
}
// A timer 90 s ahead lies beyond the 60 s window and is ignored.
#[test]
fn notify_beyond_window_ignored() {
let now_ms = 1_700_000_000_000i64;
let far_future_ms = now_ms + 90_000; let window_end_ms = now_ms + 60_000;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let action = parse_notify_action(
&far_future_ms.to_string(),
now_ms,
window_end_ms,
grace,
now_instant,
);
assert_eq!(action, NotifyAction::Ignore);
}
// An unparsable payload falls back to timestamp 0, i.e. immediate work.
#[test]
fn notify_invalid_payload_treated_as_immediate() {
let now_ms = 1_700_000_000_000i64;
let window_end_ms = now_ms + 60_000;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let action = parse_notify_action("garbage", now_ms, window_end_ms, grace, now_instant);
assert_eq!(action, NotifyAction::WakeNow);
}
// An empty payload is unparsable too and behaves the same way.
#[test]
fn notify_empty_payload_treated_as_immediate() {
let now_ms = 1_700_000_000_000i64;
let window_end_ms = now_ms + 60_000;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let action = parse_notify_action("", now_ms, window_end_ms, grace, now_instant);
assert_eq!(action, NotifyAction::WakeNow);
}
// The fire time is visible_at plus the grace period (10 s + 100 ms here).
#[test]
fn timer_fires_at_visible_at_plus_grace() {
let now_ms = 1_700_000_000_000i64;
let visible_at_ms = now_ms + 10_000; let window_end_ms = now_ms + 60_000;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let action = parse_notify_action(
&visible_at_ms.to_string(),
now_ms,
window_end_ms,
grace,
now_instant,
);
match action {
NotifyAction::AddTimer { fire_at } => {
let delay = fire_at.duration_since(now_instant);
let expected = Duration::from_millis(10_100);
assert!(
delay >= expected - Duration::from_millis(5)
&& delay <= expected + Duration::from_millis(5),
"Timer should fire at visible_at + grace, got {delay:?}"
);
}
other => panic!("Expected AddTimer, got {other:?}"),
}
}
// `Reverse` turns the max-heap into a min-heap: earliest instant pops first.
#[test]
fn timer_heap_ordering() {
let mut heap: BinaryHeap<Reverse<Instant>> = BinaryHeap::new();
let now = Instant::now();
let t1 = now + Duration::from_secs(10);
let t2 = now + Duration::from_secs(5);
let t3 = now + Duration::from_secs(15);
heap.push(Reverse(t1));
heap.push(Reverse(t2));
heap.push(Reverse(t3));
assert_eq!(heap.pop().unwrap().0, t2);
assert_eq!(heap.pop().unwrap().0, t1);
assert_eq!(heap.pop().unwrap().0, t3);
}
// Mirrors the peek/pop loop of `pop_and_wake_expired_timers` on a local heap:
// every expired entry is drained in one pass.
#[test]
fn expired_timers_popped_in_batch() {
let mut heap: BinaryHeap<Reverse<Instant>> = BinaryHeap::new();
let past = Instant::now() - Duration::from_secs(1);
heap.push(Reverse(past - Duration::from_millis(100)));
heap.push(Reverse(past - Duration::from_millis(200)));
heap.push(Reverse(past - Duration::from_millis(300)));
let now = Instant::now();
let mut fired = 0;
while let Some(Reverse(fire_at)) = heap.peek() {
if *fire_at <= now {
heap.pop();
fired += 1;
} else {
break;
}
}
assert_eq!(fired, 3);
assert!(heap.is_empty());
}
// A timer 10 s in the future is not yet due.
#[test]
fn timer_does_not_fire_early() {
let mut heap: BinaryHeap<Reverse<Instant>> = BinaryHeap::new();
let now = Instant::now();
let future = now + Duration::from_secs(10);
heap.push(Reverse(future));
if let Some(Reverse(fire_at)) = heap.peek() {
assert!(*fire_at > now, "Timer should not fire early");
}
}
// Both future timestamps survive; the first one lands at about 10 s + grace.
#[test]
fn refresh_adds_timers_to_heap() {
let now_ms = 1_700_000_000_000i64;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let timers = vec![
now_ms + 10_000, now_ms + 30_000, ];
let result = timers_from_refresh(&timers, now_ms, grace, now_instant);
assert_eq!(result.len(), 2);
let delay1 = result[0].duration_since(now_instant);
assert!(delay1 >= Duration::from_millis(10_000));
assert!(delay1 <= Duration::from_millis(10_200));
}
// A timestamp 5 s in the past is dropped even after adding the grace period;
// the one 100 ms ahead is kept.
#[test]
fn refresh_skips_already_passed_timers() {
let now_ms = 1_700_000_000_000i64;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let timers = vec![
now_ms - 5_000, now_ms + 100, ];
let result = timers_from_refresh(&timers, now_ms, grace, now_instant);
assert_eq!(result.len(), 1);
}
// No timestamps in, no timers out.
#[test]
fn refresh_with_empty_result() {
let now_ms = 1_700_000_000_000i64;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let timers: Vec<i64> = vec![];
let result = timers_from_refresh(&timers, now_ms, grace, now_instant);
assert!(result.is_empty());
}
// A 500 ms grace period shifts a 10 s timer to about 10.5 s.
#[test]
fn refresh_timer_includes_grace_period() {
let now_ms = 1_700_000_000_000i64;
let grace = Duration::from_millis(500); let now_instant = Instant::now();
let timers = vec![now_ms + 10_000];
let result = timers_from_refresh(&timers, now_ms, grace, now_instant);
assert_eq!(result.len(), 1);
let delay = result[0].duration_since(now_instant);
assert!(delay >= Duration::from_millis(10_400));
assert!(delay <= Duration::from_millis(10_600));
}
// visible_at == now is kept because the grace period makes the delay positive.
#[test]
fn refresh_boundary_timer_at_exactly_now() {
let now_ms = 1_700_000_000_000i64;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let timers = vec![now_ms];
let result = timers_from_refresh(&timers, now_ms, grace, now_instant);
assert_eq!(result.len(), 1);
let delay = result[0].duration_since(now_instant);
assert!(delay >= Duration::from_millis(90));
assert!(delay <= Duration::from_millis(110));
}
// The window end is inclusive.
#[test]
fn notify_at_window_boundary_included() {
let now_ms = 1_700_000_000_000i64;
let window_end_ms = now_ms + 60_000;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let action = parse_notify_action(
&window_end_ms.to_string(),
now_ms,
window_end_ms,
grace,
now_instant,
);
match action {
NotifyAction::AddTimer { .. } => {}
other => panic!("Expected AddTimer at window boundary, got {other:?}"),
}
}
// One millisecond past the window end is ignored.
#[test]
fn notify_just_past_window_boundary_ignored() {
let now_ms = 1_700_000_000_000i64;
let window_end_ms = now_ms + 60_000;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let action = parse_notify_action(
&(window_end_ms + 1).to_string(),
now_ms,
window_end_ms,
grace,
now_instant,
);
assert_eq!(action, NotifyAction::Ignore);
}
// A negative timestamp is before now and therefore immediate.
#[test]
fn notify_negative_timestamp_treated_as_immediate() {
let now_ms = 1_700_000_000_000i64;
let window_end_ms = now_ms + 60_000;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let action = parse_notify_action("-1000", now_ms, window_end_ms, grace, now_instant);
assert_eq!(action, NotifyAction::WakeNow);
}
// Timestamp zero is before now and therefore immediate.
#[test]
fn notify_zero_timestamp_treated_as_immediate() {
let now_ms = 1_700_000_000_000i64;
let window_end_ms = now_ms + 60_000;
let grace = Duration::from_millis(100);
let now_instant = Instant::now();
let action = parse_notify_action("0", now_ms, window_end_ms, grace, now_instant);
assert_eq!(action, NotifyAction::WakeNow);
}
}