use std::time::Duration;
use async_nats::connection::State;
use async_nats::jetstream;
use rand::Rng;
use tracing::{debug, info, warn};
use crate::staleness::Tracker;
/// Delay used for the first retry of every `wait_for_*` loop; `next_backoff` doubles it from here.
const INITIAL_BACKOFF: Duration = Duration::from_secs(1);
/// Ceiling that `next_backoff` clamps the doubled delay to.
const MAX_BACKOFF: Duration = Duration::from_secs(300);
/// Half-width of the random band used by `jitter`: the delay is scaled by a
/// factor drawn from `[1 - JITTER_FRACTION, 1 + JITTER_FRACTION]`.
const JITTER_FRACTION: f64 = 0.25;
/// Keeps calling `js.get_key_value(bucket)` until it succeeds, then returns the store.
///
/// This function never yields an error: every failed attempt is reported via
/// `log_failure` and retried. Each attempt is preceded by `gate_on_connection`,
/// and each failure is followed by `sleep_or_wake` with the current delay, which
/// starts at `INITIAL_BACKOFF` and is advanced with `next_backoff`.
///
/// * `js` - JetStream context used to look up the bucket.
/// * `client` / `tracker` - passed to the connection gate and the wake-able sleep.
/// * `bucket` - name of the key-value bucket to open.
/// * `label` - caller-supplied tag attached to every log record emitted here.
pub async fn wait_for_kv(
    js: &jetstream::Context,
    client: &async_nats::Client,
    tracker: &Tracker,
    bucket: &str,
    label: &'static str,
) -> jetstream::kv::Store {
    let mut delay = INITIAL_BACKOFF;
    let mut failures: u32 = 0;

    loop {
        gate_on_connection(client, tracker, label, "kv").await;

        // Success leaves the function right here; only an error falls through.
        let err = match js.get_key_value(bucket).await {
            Ok(store) => {
                // Only announce a recovery when at least one attempt failed before.
                if failures != 0 {
                    info!(
                        label,
                        kind = "kv",
                        resource = bucket,
                        consecutive_failures = failures,
                        "nats_retry: kv recovered",
                    );
                }
                debug!(label, bucket, "nats_retry: kv ready");
                return store;
            }
            Err(err) => err,
        };

        failures = failures.saturating_add(1);
        log_failure(failures, delay, "kv", label, bucket, &err);
        sleep_or_wake(delay, tracker).await;
        delay = next_backoff(delay);
    }
}
/// Keeps calling `js.get_stream(name)` until it succeeds, then returns the stream handle.
///
/// This function never yields an error: every failed attempt is reported via
/// `log_failure` and retried. Each attempt is preceded by `gate_on_connection`,
/// and each failure is followed by `sleep_or_wake` with the current delay, which
/// starts at `INITIAL_BACKOFF` and is advanced with `next_backoff`.
///
/// * `js` - JetStream context used to look up the stream.
/// * `client` / `tracker` - passed to the connection gate and the wake-able sleep.
/// * `name` - name of the stream to fetch.
/// * `label` - caller-supplied tag attached to every log record emitted here.
pub async fn wait_for_stream(
    js: &jetstream::Context,
    client: &async_nats::Client,
    tracker: &Tracker,
    name: &str,
    label: &'static str,
) -> jetstream::stream::Stream {
    let mut delay = INITIAL_BACKOFF;
    let mut failures: u32 = 0;

    loop {
        gate_on_connection(client, tracker, label, "stream").await;

        // Success leaves the function right here; only an error falls through.
        let err = match js.get_stream(name).await {
            Ok(found) => {
                // Only announce a recovery when at least one attempt failed before.
                if failures != 0 {
                    info!(
                        label,
                        kind = "stream",
                        resource = name,
                        consecutive_failures = failures,
                        "nats_retry: stream recovered",
                    );
                }
                debug!(label, stream = name, "nats_retry: stream ready");
                return found;
            }
            Err(err) => err,
        };

        failures = failures.saturating_add(1);
        log_failure(failures, delay, "stream", label, name, &err);
        sleep_or_wake(delay, tracker).await;
        delay = next_backoff(delay);
    }
}
/// Keeps calling `stream.get_or_create_consumer(name, config)` until it succeeds,
/// then returns the consumer.
///
/// This function never yields an error: every failed attempt is reported via
/// `log_failure` and retried. Each attempt is preceded by `gate_on_connection`,
/// and each failure is followed by `sleep_or_wake` with the current delay, which
/// starts at `INITIAL_BACKOFF` and is advanced with `next_backoff`.
///
/// * `stream` - stream on which the consumer is fetched or created.
/// * `client` / `tracker` - passed to the connection gate and the wake-able sleep.
/// * `name` - consumer name.
/// * `label` - caller-supplied tag attached to every log record emitted here.
/// * `config` - consumer configuration; cloned for every attempt, hence the `Clone` bound.
pub async fn wait_for_consumer<C>(
    stream: &jetstream::stream::Stream,
    client: &async_nats::Client,
    tracker: &Tracker,
    name: &str,
    label: &'static str,
    config: C,
) -> jetstream::consumer::Consumer<C>
where
    C: jetstream::consumer::IntoConsumerConfig + jetstream::consumer::FromConsumer + Clone,
{
    let mut delay = INITIAL_BACKOFF;
    let mut failures: u32 = 0;

    loop {
        gate_on_connection(client, tracker, label, "consumer").await;

        // Success leaves the function right here; only an error falls through.
        let err = match stream.get_or_create_consumer(name, config.clone()).await {
            Ok(consumer) => {
                // Only announce a recovery when at least one attempt failed before.
                if failures != 0 {
                    info!(
                        label,
                        kind = "consumer",
                        resource = name,
                        consecutive_failures = failures,
                        "nats_retry: consumer recovered",
                    );
                }
                debug!(label, consumer = name, "nats_retry: consumer ready");
                return consumer;
            }
            Err(err) => err,
        };

        failures = failures.saturating_add(1);
        log_failure(failures, delay, "consumer", label, name, &err);
        sleep_or_wake(delay, tracker).await;
        delay = next_backoff(delay);
    }
}
/// Suspends the calling task for a fixed one-second pause.
///
/// NOTE(review): the name suggests this spaces out re-open attempts; the callers
/// are not visible from this file, so confirm the intended use there.
pub async fn reopen_pause() {
    const PAUSE: Duration = Duration::from_secs(1);
    tokio::time::sleep(PAUSE).await;
}
/// Emits one log record for a failed attempt.
///
/// The first failure of a streak (`consecutive == 1`) is logged at `warn`;
/// every other count is logged at `debug` and additionally carries the
/// `consecutive_failures` field.
///
/// * `consecutive` - number of failures in the current streak, including this one.
/// * `backoff` - nominal delay before the next attempt; logged in whole seconds.
/// * `kind` / `label` / `resource` - identify what was being fetched and for whom.
/// * `error` - the failure, rendered through its `Display` implementation.
fn log_failure(
    consecutive: u32,
    backoff: Duration,
    kind: &'static str,
    label: &'static str,
    resource: &str,
    error: &dyn std::fmt::Display,
) {
    let backoff_secs = backoff.as_secs();
    match consecutive {
        1 => warn!(
            label,
            kind,
            resource,
            error = %error,
            backoff_secs,
            "nats_retry: unavailable, retrying after backoff",
        ),
        _ => debug!(
            label,
            kind,
            resource,
            error = %error,
            backoff_secs,
            consecutive_failures = consecutive,
            "nats_retry: still unavailable, retrying after backoff",
        ),
    }
}
/// Holds the caller back while the NATS client reports a state other than `Connected`.
///
/// When the client is already connected this returns immediately. Otherwise it
/// logs at `debug` and awaits `tracker.wait_connected()` for at most five
/// seconds; the function returns whether or not the wait timed out.
async fn gate_on_connection(
    client: &async_nats::Client,
    tracker: &Tracker,
    label: &'static str,
    kind: &'static str,
) {
    const CONNECT_WAIT: Duration = Duration::from_secs(5);

    if client.connection_state() != State::Connected {
        debug!(
            label,
            kind, "nats_retry: client not connected, waiting on Notify"
        );
        // The outcome (woken vs. timed out) is intentionally discarded.
        let _ = tokio::time::timeout(CONNECT_WAIT, tracker.wait_connected()).await;
    }
}
async fn sleep_or_wake(d: Duration, tracker: &Tracker) {
tokio::select! {
biased;
_ = tracker.wait_connected() => {}
_ = tokio::time::sleep(jitter(d)) => {}
}
}
fn next_backoff(current: Duration) -> Duration {
let doubled = current.saturating_mul(2);
if doubled > MAX_BACKOFF {
MAX_BACKOFF
} else {
doubled
}
}
/// Scales `d` by a random factor drawn from
/// `[1 - JITTER_FRACTION, 1 + JITTER_FRACTION]` (inclusive on both ends).
fn jitter(d: Duration) -> Duration {
    let offset = rand::rng().random_range(-JITTER_FRACTION..=JITTER_FRACTION);
    d.mul_f64(1.0 + offset)
}
// Unit tests for the backoff arithmetic and for the wake-able sleep.
#[cfg(test)]
mod tests {
use super::*;
/// Starting from `INITIAL_BACKOFF`, repeated `next_backoff` calls double the
/// delay until it would exceed the cap, then stay at 300 s.
#[test]
fn next_backoff_doubles_then_caps() {
let mut d = INITIAL_BACKOFF;
let expected_secs = [1u64, 2, 4, 8, 16, 32, 64, 128, 256, 300, 300, 300];
let mut observed = Vec::new();
for _ in 0..expected_secs.len() {
observed.push(d.as_secs());
d = next_backoff(d);
}
assert_eq!(observed, expected_secs);
}
/// Samples `jitter` 1000 times on a 10 s base and checks each sample lies in
/// the ±25% band (7.5 s ..= 12.5 s, with 1 ms of slack on each side).
#[test]
fn jitter_stays_within_band() {
let base = Duration::from_secs(10);
for _ in 0..1000 {
let j = jitter(base);
let secs = j.as_secs_f64();
assert!(
(7.499..=12.501).contains(&secs),
"jitter sample {secs:.3}s outside ±25% band of 10s",
);
}
}
/// A `Connected` event delivered through the tracker's callback must end a
/// 60 s `sleep_or_wake`. Runs with Tokio's test clock paused.
#[tokio::test(start_paused = true)]
async fn sleep_or_wake_returns_early_on_connected_event() {
let tracker = Tracker::new();
let cb = tracker.on_event();
let tracker_for_waiter = tracker.clone();
let waiter = tokio::spawn(async move {
sleep_or_wake(Duration::from_secs(60), &tracker_for_waiter).await;
});
// Yield so the spawned waiter can run and park inside `sleep_or_wake`
// before the event is fired — presumably two yields are enough on the
// test runtime; confirm if this test ever becomes flaky.
tokio::task::yield_now().await;
tokio::task::yield_now().await;
let _ = cb(async_nats::Event::Connected).await;
waiter.await.expect("waiter task panicked");
}
/// A `Disconnected` event must NOT end `sleep_or_wake`: after the event and
/// a few yields the waiter task is still expected to be pending.
#[tokio::test(start_paused = true)]
async fn sleep_or_wake_ignores_non_connected_events() {
let tracker = Tracker::new();
let cb = tracker.on_event();
let tracker_for_waiter = tracker.clone();
let mut waiter = tokio::spawn(async move {
sleep_or_wake(Duration::from_secs(60), &tracker_for_waiter).await;
});
// Let the waiter start before firing the event (see the note in the
// previous test about the number of yields).
tokio::task::yield_now().await;
tokio::task::yield_now().await;
let _ = cb(async_nats::Event::Disconnected).await;
// Give the waiter several chances to (incorrectly) complete.
for _ in 0..5 {
tokio::task::yield_now().await;
}
// Poll the join handle once without blocking: still pending is the pass condition.
match futures::poll!(&mut waiter) {
std::task::Poll::Pending => { }
std::task::Poll::Ready(r) => {
panic!("waiter woke on Disconnected event (incorrectly): {r:?}");
}
}
// The waiter would otherwise sit in its 60 s sleep; cancel it.
waiter.abort();
}
}