use std::collections::HashMap;
use crate::ids::PeerId;
/// Default base delay (10 ms, in nanoseconds) used by `BackoffTable::new`.
pub const DEFAULT_BASE_NS: u64 = 10 * 1_000_000;
/// Default upper bound on a single backoff delay (60 s, in nanoseconds).
pub const DEFAULT_MAX_DELAY_NS: u64 = 60 * 1_000_000_000;
/// Retry bookkeeping for a single peer, as stored in a `BackoffTable`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackoffState {
    /// Number of failures / remote advisories recorded since the entry was created
    /// (the table drops the entry on success). Saturates at `u32::MAX`.
    pub attempts: u32,
    /// Timestamp (ns) passed with the most recently recorded failure or advisory.
    pub last_attempt_ns: u64,
    /// Earliest timestamp (ns) at which another attempt is permitted.
    pub next_retry_ns: u64,
}
/// Tracks exponential retry backoff for a set of peers.
///
/// Delays grow as `base_ns * 2^(attempts - 1)` and never exceed `max_delay_ns`.
pub struct BackoffTable {
    /// Delay applied after the first recorded failure, in nanoseconds.
    base_ns: u64,
    /// Cap on any single computed delay, in nanoseconds.
    max_delay_ns: u64,
    /// Per-peer state; a peer with no entry may be retried immediately.
    states: HashMap<PeerId, BackoffState>,
}
impl Default for BackoffTable {
fn default() -> Self {
Self::new()
}
}
impl BackoffTable {
    /// Creates an empty table using [`DEFAULT_BASE_NS`] and [`DEFAULT_MAX_DELAY_NS`].
    pub fn new() -> Self {
        Self::with_schedule(DEFAULT_BASE_NS, DEFAULT_MAX_DELAY_NS)
    }

    /// Creates an empty table with a custom schedule.
    ///
    /// After the `n`-th recorded failure the delay is `base_ns * 2^(n - 1)`
    /// (saturating), capped at `max_delay_ns`. Both arguments are raised to at
    /// least 1 ns so a zero configuration still produces a non-zero delay.
    pub fn with_schedule(base_ns: u64, max_delay_ns: u64) -> Self {
        Self {
            states: HashMap::new(),
            base_ns: base_ns.max(1),
            max_delay_ns: max_delay_ns.max(1),
        }
    }

    /// Records a failed attempt to reach `peer` at time `now_ns`.
    ///
    /// Increments the peer's attempt count (saturating) and schedules the next
    /// retry `delay_for(attempts)` after `now_ns`. If a later retry is already
    /// pending, it is kept (bounded by `now_ns + max_delay_ns`), so a failure
    /// never shortens an outstanding backoff.
    pub fn record_failure(&mut self, peer: PeerId, now_ns: u64) {
        let attempts = self.next_attempts(peer);
        let delay_ns = self.delay_for(attempts);
        self.schedule(peer, attempts, now_ns, delay_ns);
    }

    /// Records a remote request that we wait at least `min_backoff_ns` before
    /// contacting `peer` again.
    ///
    /// Counts as an attempt (the attempt count is incremented). The requested
    /// delay is capped at `max_delay_ns` so a peer cannot push us beyond our own
    /// maximum, and, because it is a *minimum*, it never moves an already-pending
    /// later retry earlier (that pending deadline is itself bounded by
    /// `now_ns + max_delay_ns`).
    pub fn record_remote_advisory(&mut self, peer: PeerId, now_ns: u64, min_backoff_ns: u64) {
        let attempts = self.next_attempts(peer);
        let delay_ns = min_backoff_ns.min(self.max_delay_ns);
        self.schedule(peer, attempts, now_ns, delay_ns);
    }

    /// Clears all backoff state for `peer`, allowing an immediate retry.
    pub fn record_success(&mut self, peer: PeerId) {
        self.states.remove(&peer);
    }

    /// Returns `true` if `peer` may be contacted at `now_ns`: either it has no
    /// backoff entry or `now_ns` has reached its `next_retry_ns`.
    pub fn should_retry(&self, peer: PeerId, now_ns: u64) -> bool {
        self.states
            .get(&peer)
            .map_or(true, |state| now_ns >= state.next_retry_ns)
    }

    /// Returns a copy of the backoff state for `peer`, if any.
    pub fn state(&self, peer: PeerId) -> Option<BackoffState> {
        self.states.get(&peer).copied()
    }

    /// Number of peers currently tracked.
    pub fn len(&self) -> usize {
        self.states.len()
    }

    /// Returns `true` when no peer is being tracked.
    pub fn is_empty(&self) -> bool {
        self.states.is_empty()
    }

    /// Iterates over `(peer, state)` pairs in unspecified (hash map) order.
    pub fn iter(&self) -> impl Iterator<Item = (PeerId, BackoffState)> + '_ {
        self.states.iter().map(|(p, s)| (*p, *s))
    }

    /// Inserts `state` for `peer` verbatim, replacing any existing entry.
    ///
    /// No validation is performed on the restored values.
    pub fn restore_state(&mut self, peer: PeerId, state: BackoffState) {
        self.states.insert(peer, state);
    }

    /// Attempt count to store after recording one more failure/advisory for `peer`.
    fn next_attempts(&self, peer: PeerId) -> u32 {
        self.states
            .get(&peer)
            .map_or(1, |s| s.attempts.saturating_add(1))
    }

    /// Stores the new state for `peer`, with the next retry `delay_ns` after
    /// `now_ns` unless a later retry is already pending.
    fn schedule(&mut self, peer: PeerId, attempts: u32, now_ns: u64, delay_ns: u64) {
        let proposed = now_ns.saturating_add(delay_ns);
        // Carry over a later pending deadline so recording an event never
        // shortens an existing backoff, but never beyond `now + max_delay` so a
        // stale or bogus (e.g. restored) deadline cannot block a peer forever.
        let horizon = now_ns.saturating_add(self.max_delay_ns);
        let next_retry_ns = match self.states.get(&peer) {
            Some(prev) => proposed.max(prev.next_retry_ns.min(horizon)),
            None => proposed,
        };
        self.states.insert(
            peer,
            BackoffState {
                attempts,
                last_attempt_ns: now_ns,
                next_retry_ns,
            },
        );
    }

    /// Delay after `attempts` consecutive failures:
    /// `base_ns * 2^(attempts - 1)`, saturating, capped at `max_delay_ns`.
    fn delay_for(&self, attempts: u32) -> u64 {
        // Clamping the exponent to 63 keeps the shift in range; beyond that the
        // multiplication saturates to u64::MAX anyway.
        let exponent = attempts.saturating_sub(1).min(63);
        self.base_ns
            .saturating_mul(1u64 << exponent)
            .min(self.max_delay_ns)
    }
}