use crate::NodeAddr;
use crate::time::{Instant, instant_now};
use std::collections::HashMap;
use std::time::Duration;
/// Base delay, in whole seconds, that `DiscoveryBackoff::new` passes to
/// `DiscoveryBackoff::with_params`. With a base of 0 every delay computed by
/// `record_failure` is zero seconds long.
// NOTE(review): both defaults being 0 looks like "backoff disabled unless
// configured" — confirm this is intentional.
const DEFAULT_BACKOFF_BASE_SECS: u64 = 0;
/// Upper bound, in whole seconds, on a single backoff delay for trackers
/// built with `DiscoveryBackoff::new`.
const DEFAULT_BACKOFF_MAX_SECS: u64 = 0;
/// Growth factor: the delay is `base * BACKOFF_MULTIPLIER^(failures - 1)`
/// before the cap is applied.
const BACKOFF_MULTIPLIER: u64 = 2;
/// Per-target failure tracker that computes a suppression deadline which
/// grows exponentially with the number of recorded failures.
///
/// Presumably used to throttle discovery attempts (going by the type name);
/// the callers are not visible from this file.
pub struct DiscoveryBackoff {
/// Backoff state for every target that has a recorded failure which has
/// not been cleared by `record_success` or `reset_all`.
entries: HashMap<NodeAddr, BackoffEntry>,
/// Delay applied after the first failure.
base: Duration,
/// Upper bound on any single delay.
max: Duration,
}
/// Backoff state stored for a single target.
struct BackoffEntry {
/// The target counts as suppressed while the current time is strictly
/// before this instant.
suppress_until: Instant,
/// Number of failures recorded since the entry was created.
failures: u32,
}
impl DiscoveryBackoff {
    /// Creates a tracker using `DEFAULT_BACKOFF_BASE_SECS` and
    /// `DEFAULT_BACKOFF_MAX_SECS`.
    pub fn new() -> Self {
        Self::with_params(DEFAULT_BACKOFF_BASE_SECS, DEFAULT_BACKOFF_MAX_SECS)
    }

    /// Creates a tracker with an explicit base delay and delay cap.
    ///
    /// * `base_secs` - delay after the first failure, in whole seconds.
    /// * `max_secs` - upper bound on any single delay, in whole seconds.
    pub fn with_params(base_secs: u64, max_secs: u64) -> Self {
        Self {
            entries: HashMap::new(),
            base: Duration::from_secs(base_secs),
            max: Duration::from_secs(max_secs),
        }
    }

    /// Returns `true` while the current time is strictly before the
    /// suppression deadline recorded for `target`.
    ///
    /// Targets without an entry are never suppressed.
    pub fn is_suppressed(&self, target: &NodeAddr) -> bool {
        self.entries
            .get(target)
            .is_some_and(|entry| instant_now() < entry.suppress_until)
    }

    /// Records one more failure for `target` and pushes its suppression
    /// deadline to `now + min(base * BACKOFF_MULTIPLIER^(failures - 1), max)`.
    ///
    /// All arithmetic on the delay and on the failure counter saturates, so
    /// an arbitrarily long failure streak keeps the delay pinned at `max`.
    pub fn record_failure(&mut self, target: &NodeAddr) {
        let now = instant_now();

        // Saturate instead of `+ 1`: a wrapped counter would either panic
        // (overflow checks on) or restart the streak at zero delay.
        let failures = self
            .entries
            .get(target)
            .map_or(0, |e| e.failures)
            .saturating_add(1);

        // `failures` is at least 1 here, so the first failure gets exponent 0
        // and therefore exactly the base delay.
        let growth = BACKOFF_MULTIPLIER.saturating_pow(failures - 1);
        let uncapped_secs = self.base.as_secs().saturating_mul(growth);
        let backoff = Duration::from_secs(uncapped_secs.min(self.max.as_secs()));

        // NOTE(review): `now + backoff` relies on `Instant: Add<Duration>`;
        // if that implementation panics on overflow (std's does), an extreme
        // `max` could panic here — confirm against `crate::time::Instant`.
        self.entries.insert(
            *target,
            BackoffEntry {
                suppress_until: now + backoff,
                failures,
            },
        );
    }

    /// Forgets all backoff state for `target`, resetting its failure streak.
    pub fn record_success(&mut self, target: &NodeAddr) {
        self.entries.remove(target);
    }

    /// Forgets the backoff state of every target.
    pub fn reset_all(&mut self) {
        self.entries.clear();
    }

    /// Returns `true` when no target currently has a stored entry.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Number of targets with a stored entry, whether or not their
    /// suppression deadline has already passed.
    pub fn entry_count(&self) -> usize {
        self.entries.len()
    }

    /// Number of failures recorded for `target`, or 0 when it has no entry.
    pub fn failure_count(&self, target: &NodeAddr) -> u32 {
        self.entries.get(target).map_or(0, |e| e.failures)
    }

    /// Test-only alias for [`Self::entry_count`].
    #[cfg(test)]
    pub fn len(&self) -> usize {
        self.entry_count()
    }
}
impl Default for DiscoveryBackoff {
fn default() -> Self {
Self::new()
}
}
/// Minimum spacing between two accepted forwards for the same target when a
/// limiter is built with `DiscoveryForwardRateLimiter::new`.
const DEFAULT_FORWARD_MIN_INTERVAL: Duration = Duration::from_secs(2);
/// Age threshold used by `DiscoveryForwardRateLimiter::cleanup` when purging
/// stale per-target timestamps.
const FORWARD_MAX_AGE: Duration = Duration::from_secs(60);
/// Per-target rate limiter: `should_forward` accepts at most one forward per
/// target within each `min_interval` window.
pub struct DiscoveryForwardRateLimiter {
/// Time of the most recent accepted forward for each target.
last_forwarded: HashMap<NodeAddr, Instant>,
/// Minimum time that must pass before the same target is accepted again.
min_interval: Duration,
/// Age threshold consulted by `cleanup` when dropping old timestamps.
max_age: Duration,
}
impl DiscoveryForwardRateLimiter {
    /// Creates a limiter using `DEFAULT_FORWARD_MIN_INTERVAL`.
    pub fn new() -> Self {
        Self::with_interval(DEFAULT_FORWARD_MIN_INTERVAL)
    }

    /// Creates a limiter that accepts at most one forward per target every
    /// `min_interval`.
    pub fn with_interval(min_interval: Duration) -> Self {
        Self {
            last_forwarded: HashMap::new(),
            min_interval,
            max_age: FORWARD_MAX_AGE,
        }
    }

    /// Decides whether a forward for `target` is allowed right now.
    ///
    /// Returns `false` when the previous accepted forward for `target` is
    /// less than `min_interval` old; the stored timestamp is left untouched
    /// in that case. Otherwise records the current time for `target`, purges
    /// stale timestamps, and returns `true`.
    pub fn should_forward(&mut self, target: &NodeAddr) -> bool {
        let now = instant_now();
        let min_interval = self.min_interval;

        let too_soon = self
            .last_forwarded
            .get(target)
            .is_some_and(|&last| now.duration_since(last) < min_interval);
        if too_soon {
            return false;
        }

        self.last_forwarded.insert(*target, now);
        self.cleanup(now);
        true
    }

    /// Test-only override of the minimum interval.
    #[cfg(test)]
    pub fn set_interval(&mut self, interval: Duration) {
        self.min_interval = interval;
    }

    /// Drops timestamps that can no longer influence `should_forward`.
    ///
    /// A timestamp is kept while its age relative to `now` is below
    /// `max(max_age, min_interval)`. Using the larger of the two matters when
    /// the configured interval exceeds `max_age`: evicting at `max_age` alone
    /// would erase a timestamp that is still inside its rate-limit window and
    /// let the target through early.
    fn cleanup(&mut self, now: Instant) {
        let retention = self.max_age.max(self.min_interval);
        self.last_forwarded
            .retain(|_, last| now.duration_since(*last) < retention);
    }

    /// Test-only: number of targets with a stored timestamp.
    #[cfg(test)]
    pub fn len(&self) -> usize {
        self.last_forwarded.len()
    }
}
impl Default for DiscoveryForwardRateLimiter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Builds a 16-byte test address whose first byte is `val`; all other
    /// bytes are zero.
    fn addr(val: u8) -> NodeAddr {
        let mut raw = [0u8; 16];
        raw[0] = val;
        NodeAddr::from_bytes(raw)
    }

    // ---- DiscoveryBackoff ----

    /// A target with no recorded failure is not suppressed.
    #[test]
    fn test_backoff_not_suppressed_initially() {
        let tracker = DiscoveryBackoff::new();
        let unknown = addr(1);
        assert!(!tracker.is_suppressed(&unknown));
    }

    /// A failure suppresses only the target it was recorded for.
    #[test]
    fn test_backoff_suppressed_after_failure() {
        let (failed, other) = (addr(1), addr(2));
        let mut tracker = DiscoveryBackoff::with_params(30, 300);
        tracker.record_failure(&failed);
        assert!(tracker.is_suppressed(&failed));
        assert!(!tracker.is_suppressed(&other));
    }

    /// A success lifts the suppression set by an earlier failure.
    #[test]
    fn test_backoff_cleared_on_success() {
        let target = addr(1);
        let mut tracker = DiscoveryBackoff::with_params(30, 300);
        tracker.record_failure(&target);
        assert!(tracker.is_suppressed(&target));
        tracker.record_success(&target);
        assert!(!tracker.is_suppressed(&target));
    }

    /// `reset_all` drops every stored entry.
    #[test]
    fn test_backoff_reset_all() {
        let mut tracker = DiscoveryBackoff::new();
        tracker.record_failure(&addr(1));
        tracker.record_failure(&addr(2));
        assert_eq!(tracker.len(), 2);
        tracker.reset_all();
        assert_eq!(tracker.len(), 0);
        assert!(!tracker.is_suppressed(&addr(1)));
    }

    /// Each failure bumps the per-target failure counter by one.
    #[test]
    fn test_backoff_exponential() {
        let target = addr(1);
        let mut tracker = DiscoveryBackoff::with_params(1, 300);
        tracker.record_failure(&target);
        assert_eq!(tracker.failure_count(&target), 1);
        tracker.record_failure(&target);
        assert_eq!(tracker.failure_count(&target), 2);
        tracker.record_failure(&target);
        assert_eq!(tracker.failure_count(&target), 3);
    }

    /// With a zero-length backoff the target is immediately eligible again.
    #[test]
    fn test_backoff_expires() {
        let target = addr(1);
        let mut tracker = DiscoveryBackoff::with_params(0, 0);
        tracker.record_failure(&target);
        assert!(!tracker.is_suppressed(&target));
    }

    /// Many failures in a row never push the deadline past the cap
    /// (checked with one second of slack).
    #[test]
    fn test_backoff_capped() {
        let mut tracker = DiscoveryBackoff::with_params(1, 10);
        (0..20).for_each(|_| tracker.record_failure(&addr(1)));
        let deadline = tracker.entries.get(&addr(1)).unwrap().suppress_until;
        let remaining = deadline.duration_since(Instant::now());
        assert!(remaining <= Duration::from_secs(11));
    }

    // ---- DiscoveryForwardRateLimiter ----

    /// The first forward for a target is always accepted.
    #[test]
    fn test_forward_first_allowed() {
        let mut gate = DiscoveryForwardRateLimiter::new();
        assert!(gate.should_forward(&addr(1)));
    }

    /// Back-to-back forwards for the same target are rejected.
    #[test]
    fn test_forward_rapid_rate_limited() {
        let target = addr(1);
        let mut gate = DiscoveryForwardRateLimiter::new();
        assert!(gate.should_forward(&target));
        assert!(!gate.should_forward(&target));
        assert!(!gate.should_forward(&target));
    }

    /// Each target has its own rate-limit window.
    #[test]
    fn test_forward_different_targets_independent() {
        let (first, second) = (addr(1), addr(2));
        let mut gate = DiscoveryForwardRateLimiter::new();
        assert!(gate.should_forward(&first));
        assert!(gate.should_forward(&second));
        assert!(!gate.should_forward(&first));
        assert!(!gate.should_forward(&second));
    }

    /// Once the interval has elapsed the same target is accepted again.
    #[test]
    fn test_forward_allowed_after_interval() {
        let target = addr(1);
        let mut gate = DiscoveryForwardRateLimiter::with_interval(Duration::from_millis(100));
        assert!(gate.should_forward(&target));
        thread::sleep(Duration::from_millis(110));
        assert!(gate.should_forward(&target));
    }

    /// Cleanup evaluated 61 seconds in the future evicts every timestamp.
    #[test]
    fn test_forward_cleanup_removes_old() {
        let mut gate = DiscoveryForwardRateLimiter::new();
        assert!(gate.should_forward(&addr(1)));
        assert!(gate.should_forward(&addr(2)));
        assert_eq!(gate.len(), 2);
        let later = Instant::now() + Duration::from_secs(61);
        gate.cleanup(later);
        assert_eq!(gate.len(), 0);
    }

    /// Cleanup evaluated at the current time keeps a fresh timestamp.
    #[test]
    fn test_forward_cleanup_preserves_recent() {
        let mut gate = DiscoveryForwardRateLimiter::new();
        assert!(gate.should_forward(&addr(1)));
        assert_eq!(gate.len(), 1);
        gate.cleanup(Instant::now());
        assert_eq!(gate.len(), 1);
    }
}