use std::time::Duration;
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;
/// Limit on how many times in a row the same error kind is treated as
/// noteworthy.
///
/// NOTE(review): presumably passed to `ErrorTracker::new` as its
/// `max_consecutive` argument; the use site is not visible in this part of
/// the file, so confirm against the callers.
pub(crate) const MAX_CONSECUTIVE_SAME_ERROR: u32 = 3;
/// Identifies which phase of stream handling is currently running.
///
/// NOTE(review): no use site is visible in this part of the file; the
/// per-variant descriptions below are inferred from the variant names and
/// should be confirmed against the callers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum StreamPhase {
    /// Presumably the first synchronisation pass, before steady state.
    InitialSync,
    /// Presumably the long-running supervision loop that follows.
    Supervisor,
}
/// Coarse category of a failure, used as the identity that
/// `ErrorTracker` compares to detect a run of repeated errors.
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// names; the code that produces each variant is not visible here.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) enum ErrorKey {
    /// Presumably: constructing the client failed.
    ClientCreation,
    /// Presumably: establishing the stream connection failed.
    StreamConnect,
    /// Presumably: an update was rejected.
    UpdateRejected,
    /// Presumably: no identity was issued.
    NoIdentityIssued,
}
/// Tracks runs of identical consecutive errors so that repeated reports of
/// the same failure can be throttled.
pub(crate) struct ErrorTracker {
/// Kind of the most recently recorded error; `None` before any error has
/// been recorded and after a reset.
last_error_kind: Option<ErrorKey>,
/// Length of the current run of `last_error_kind` (0 when nothing is
/// recorded).
consecutive_same_error: u32,
/// Run length at which `record_error` stops returning `true` for a
/// repeated kind.
max_consecutive: u32,
}
impl ErrorTracker {
    /// Creates a tracker with no error recorded yet.
    ///
    /// `max_consecutive` is how many occurrences in a row of the same error
    /// kind are reported before [`record_error`](Self::record_error) starts
    /// returning `false`. The first occurrence of a kind is always reported,
    /// even when `max_consecutive` is 0.
    pub(crate) const fn new(max_consecutive: u32) -> Self {
        Self {
            last_error_kind: None,
            consecutive_same_error: 0,
            max_consecutive,
        }
    }

    /// Records one occurrence of `error_kind` and returns whether it should
    /// still be reported.
    ///
    /// Returns `true` when `error_kind` differs from the previously recorded
    /// kind (a new run starts), or when the current run has not yet reached
    /// `max_consecutive` occurrences. Returns `false` for every further
    /// repetition of the same kind.
    pub(crate) fn record_error(&mut self, error_kind: ErrorKey) -> bool {
        if self.last_error_kind == Some(error_kind) {
            // Decide using the run length *before* counting this occurrence.
            let should_warn = self.consecutive_same_error < self.max_consecutive;
            // Saturate instead of `+= 1`: an unchecked increment panics in
            // overflow-checked builds and wraps to 0 otherwise, which would
            // silently re-enable warnings in the middle of a run.
            self.consecutive_same_error = self.consecutive_same_error.saturating_add(1);
            should_warn
        } else {
            // A different kind (or the very first error) starts a new run.
            self.last_error_kind = Some(error_kind);
            self.consecutive_same_error = 1;
            true
        }
    }

    /// Forgets the current run: clears the last error kind and zeroes the
    /// counter, so the next recorded error is treated as new.
    pub(crate) const fn reset(&mut self) {
        self.consecutive_same_error = 0;
        self.last_error_kind = None;
    }

    /// Length of the current run of identical errors (0 if none recorded).
    pub(crate) const fn consecutive_count(&self) -> u32 {
        self.consecutive_same_error
    }

    /// Kind of the most recently recorded error, or `None` if nothing has
    /// been recorded since construction or the last reset.
    pub(crate) const fn last_error_kind(&self) -> Option<ErrorKey> {
        self.last_error_kind
    }
}
pub(crate) async fn sleep_or_cancel(token: &CancellationToken, dur: Duration) -> bool {
tokio::select! {
() = token.cancelled() => true,
() = sleep(dur) => false,
}
}
/// Computes the next retry delay: double `current`, cap it at `max`, then
/// shave off a random jitter of up to 10%.
///
/// All arithmetic is done in whole milliseconds; durations too large for a
/// `u64` millisecond count are clamped to `u64::MAX`. The result always lies
/// in `[capped - capped / 10, capped]`, where `capped = min(2 * current, max)`.
/// If `capped` is zero (either input is zero at millisecond granularity) the
/// result is a zero duration.
pub(crate) fn next_backoff(current: Duration, max: Duration) -> Duration {
    let whole_millis = |d: Duration| u64::try_from(d.as_millis()).unwrap_or(u64::MAX);

    let ceiling = whole_millis(max);
    let capped = whole_millis(current).saturating_mul(2).min(ceiling);
    if capped == 0 {
        return Duration::ZERO;
    }

    // Jitter window is 10% of the capped delay; the result is drawn from the
    // top of the range downwards so it never exceeds `capped`.
    let window = capped / 10;
    let floor = capped.saturating_sub(window);
    let offset = match window {
        // Delays under 10ms have no room for jitter; skip the RNG entirely.
        0 => 0,
        width => fastrand::u64(0..=width),
    };
    Duration::from_millis(floor.saturating_add(offset))
}
/// Backoff variant for the "no identity issued" case: same doubling and
/// jitter as `next_backoff`, but with a 1s floor on the starting delay and a
/// cap of at most 10s.
///
/// `max` is honoured when it is below 10s; larger values are reduced to 10s.
/// `current` values below 1s are raised to 1s before doubling.
pub(crate) fn next_backoff_for_no_identity(current: Duration, max: Duration) -> Duration {
    /// Smallest delay that is fed into the doubling step.
    const FLOOR: Duration = Duration::from_millis(1000);
    /// Upper bound applied on top of the caller-supplied `max`, in milliseconds.
    const CEILING_MS: u64 = 10000;

    let requested_cap_ms = u64::try_from(max.as_millis()).unwrap_or(u64::MAX);
    let cap = Duration::from_millis(requested_cap_ms.min(CEILING_MS));
    let start = if current < FLOOR { FLOOR } else { current };
    next_backoff(start, cap)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Jitter in `next_backoff` is only ever subtracted from the capped value,
    // so every bound below has the form `[cap - 10%, cap]`.

    /// Once the delay has reached `max`, results must stay within
    /// `[max - 10%, max]` and must still vary from call to call.
    #[test]
    fn next_backoff_at_max_preserves_jitter() {
        let max = Duration::from_secs(30);
        let lo = max.saturating_sub(max / 10);
        for _ in 0..100 {
            let result = next_backoff(max, max);
            assert!(
                result >= lo && result <= max,
                "expected backoff in [{lo:?}, {max:?}], got {result:?}"
            );
        }

        let results: std::collections::HashSet<u128> = (0..100)
            .map(|_| next_backoff(max, max).as_millis())
            .collect();
        assert!(
            results.len() > 1,
            "expected jitter to produce varying results, got {results:?}"
        );
    }

    /// A sub-second `current` is raised to 1s before doubling, so the result
    /// is 2s minus at most 10%.
    #[test]
    fn no_identity_backoff_starts_at_minimum_1s() {
        let result =
            next_backoff_for_no_identity(Duration::from_millis(100), Duration::from_secs(30));
        assert!(
            result >= Duration::from_millis(1800) && result <= Duration::from_millis(2000),
            "expected 1800..=2000ms (2s minus up to 10% jitter), got {}ms",
            result.as_millis()
        );
    }

    /// A user-supplied max above 10s is reduced to the built-in 10s cap.
    #[test]
    fn no_identity_backoff_respects_default_10s_cap() {
        let result = next_backoff_for_no_identity(Duration::from_secs(8), Duration::from_secs(60));
        assert!(
            result >= Duration::from_millis(9000) && result <= Duration::from_millis(10000),
            "expected 9000..=10000ms (10s cap minus up to 10% jitter), got {}ms",
            result.as_millis()
        );
    }

    /// A user-supplied max below 10s takes precedence over the built-in cap.
    #[test]
    fn no_identity_backoff_respects_user_max_below_default() {
        let result = next_backoff_for_no_identity(Duration::from_secs(2), Duration::from_secs(3));
        assert!(
            result >= Duration::from_millis(2700) && result <= Duration::from_millis(3000),
            "expected 2700..=3000ms (3s cap minus up to 10% jitter), got {}ms",
            result.as_millis()
        );
    }

    /// Feeding a result back in produces a strictly larger delay while below
    /// the cap.
    #[test]
    fn no_identity_backoff_grows_exponentially() {
        let first = next_backoff_for_no_identity(Duration::from_secs(1), Duration::from_secs(30));
        let second = next_backoff_for_no_identity(first, Duration::from_secs(30));
        assert!(
            second > first,
            "expected growth: first={}ms, second={}ms",
            first.as_millis(),
            second.as_millis()
        );
    }
}