use crate::util::backoff::{ExponentialBackoff, TrackedBackoff};
use std::time::Duration;
use super::Location;
/// Why a connection attempt to a target failed.
///
/// `ConnectionBackoff::record_failure_with_reason` uses the reason to decide
/// how many failures to record: `Timeout` counts as
/// `ConnectionBackoff::TIMEOUT_FAILURE_ESCALATION` failures, every other
/// variant counts as one.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionFailureReason {
/// The reason `ConnectionBackoff::record_failure` records by default.
RoutingFailed,
/// The attempt timed out; weighted more heavily than the other reasons.
Timeout,
/// NOTE(review): presumably the remote side refused the connection — confirm with callers.
Rejected,
// NOTE(review): the `dead_code` allowances below suggest these variants are
// not constructed outside tests yet — confirm before removing the attributes.
#[allow(dead_code)]
NatPunchFailed,
#[allow(dead_code)]
HandshakeError,
/// Currently recorded as a single failure, same as the non-timeout reasons.
#[allow(dead_code)]
TransientError,
}
/// Coarse key for a `Location`: the wrapped `u8` selects one of 256 buckets.
///
/// Used as the key type of the `TrackedBackoff` inside `ConnectionBackoff`, so
/// every location that maps to the same bucket shares one backoff entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct LocationBucket(u8);
impl LocationBucket {
fn from_location(loc: Location) -> Self {
let bucket = (loc.as_f64() * 256.0).min(255.0) as u8;
Self(bucket)
}
}
/// Tracks connection-attempt backoff per target location.
///
/// Locations are reduced to a `LocationBucket`, and all state is kept in the
/// wrapped `TrackedBackoff`.
#[derive(Debug)]
pub struct ConnectionBackoff {
/// Backoff state keyed by location bucket.
inner: TrackedBackoff<LocationBucket>,
}
impl Default for ConnectionBackoff {
fn default() -> Self {
Self::new()
}
}
impl ConnectionBackoff {
/// Base interval that `new()` hands to `ExponentialBackoff::new`: 30 seconds.
const DEFAULT_BASE_INTERVAL: Duration = Duration::from_secs(30);
/// Maximum backoff that `new()` hands to `ExponentialBackoff::new`: 600 seconds.
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(600);
/// Entry limit that `new()` hands to `TrackedBackoff::new`. A `LocationBucket`
/// wraps a `u8`, so there are at most 256 distinct keys.
const DEFAULT_MAX_ENTRIES: usize = 256;
/// Number of failures recorded for a single `ConnectionFailureReason::Timeout`.
const TIMEOUT_FAILURE_ESCALATION: u32 = 2;
/// Creates a tracker configured with the default base interval, maximum
/// backoff and entry limit.
pub fn new() -> Self {
    let inner = TrackedBackoff::new(
        ExponentialBackoff::new(Self::DEFAULT_BASE_INTERVAL, Self::DEFAULT_MAX_BACKOFF),
        Self::DEFAULT_MAX_ENTRIES,
    );
    Self { inner }
}
/// Test-only constructor that lets the caller choose the base interval,
/// maximum backoff and entry limit instead of the defaults.
#[cfg(test)]
pub fn with_config(base_interval: Duration, max_backoff: Duration, max_entries: usize) -> Self {
    let inner = TrackedBackoff::new(
        ExponentialBackoff::new(base_interval, max_backoff),
        max_entries,
    );
    Self { inner }
}
/// Returns whether the bucket containing `target` is currently in backoff,
/// as reported by the inner tracker.
pub fn is_in_backoff(&self, target: Location) -> bool {
    self.inner
        .is_in_backoff(&LocationBucket::from_location(target))
}
/// Records a failure for `target` using the default reason,
/// `ConnectionFailureReason::RoutingFailed`.
// NOTE(review): `dead_code` is presumably allowed because only tests call
// this convenience wrapper — confirm.
#[allow(dead_code)]
pub fn record_failure(&mut self, target: Location) {
    let default_reason = ConnectionFailureReason::RoutingFailed;
    self.record_failure_with_reason(target, default_reason);
}
/// Records a failed connection attempt for `target`, weighted by `reason`.
///
/// A `Timeout` is recorded as `TIMEOUT_FAILURE_ESCALATION` failures; every
/// other reason is recorded as a single failure. Afterwards a debug event is
/// emitted with the failure counts before and after, and the delay the
/// backoff configuration reports for the new count.
pub fn record_failure_with_reason(
    &mut self,
    target: Location,
    reason: ConnectionFailureReason,
) {
    let key = LocationBucket::from_location(target);
    let failures_before = self.inner.failure_count(&key);

    // Exhaustive on purpose: adding a variant must force a decision here.
    let weight = match reason {
        ConnectionFailureReason::Timeout => Self::TIMEOUT_FAILURE_ESCALATION,
        ConnectionFailureReason::TransientError
        | ConnectionFailureReason::RoutingFailed
        | ConnectionFailureReason::Rejected
        | ConnectionFailureReason::NatPunchFailed
        | ConnectionFailureReason::HandshakeError => 1,
    };

    let mut remaining = weight;
    while remaining > 0 {
        self.inner.record_failure(key);
        remaining -= 1;
    }

    // Re-read the count rather than assuming `failures_before + weight`.
    let failures_after = self.inner.failure_count(&key);
    let delay = self.inner.config().delay_for_failures(failures_after);

    tracing::debug!(
        bucket = key.0,
        failures_before = failures_before,
        failures_after = failures_after,
        reason = ?reason,
        backoff_secs = delay.as_secs(),
        "Connection target in backoff (with reason)"
    );
}
/// Records a successful connection to `target` on the inner tracker.
///
/// A debug event is emitted first when the bucket had at least one recorded
/// failure.
pub fn record_success(&mut self, target: Location) {
    let key = LocationBucket::from_location(target);
    let had_failures = self.inner.failure_count(&key) != 0;
    if had_failures {
        tracing::debug!(bucket = key.0, "Connection target backoff cleared");
    }
    self.inner.record_success(&key);
}
/// Delegates to `cleanup_expired` on the inner `TrackedBackoff`.
// NOTE(review): presumably drops entries whose backoff window has elapsed —
// confirm against `TrackedBackoff`.
pub fn cleanup_expired(&mut self) {
self.inner.cleanup_expired();
}
/// Discards all backoff state by delegating to `clear` on the inner
/// `TrackedBackoff`.
pub fn clear(&mut self) {
self.inner.clear();
}
/// Test-only accessor: number of failures currently recorded for the bucket
/// containing `target`.
#[cfg(test)]
fn failure_count(&self, target: Location) -> u32 {
    self.inner
        .failure_count(&LocationBucket::from_location(target))
}
}
#[cfg(test)]
mod tests {
use super::*;
// A freshly created tracker reports no backoff for an arbitrary location.
#[test]
fn test_backoff_not_in_backoff_initially() {
    let tracker = ConnectionBackoff::new();
    let in_backoff = tracker.is_in_backoff(Location::new(0.5));
    assert!(!in_backoff);
}
// One recorded failure is enough to put the target in backoff.
#[test]
fn test_backoff_after_failure() {
    let target = Location::new(0.5);
    let mut tracker = ConnectionBackoff::new();

    tracker.record_failure(target);

    assert!(tracker.is_in_backoff(target));
}
// A success after a failure takes the target out of backoff again.
#[test]
fn test_backoff_cleared_on_success() {
    let target = Location::new(0.5);
    let mut tracker = ConnectionBackoff::new();

    tracker.record_failure(target);
    assert!(tracker.is_in_backoff(target));

    tracker.record_success(target);
    assert!(!tracker.is_in_backoff(target));
}
// With a 1s base, the delay doubles with each additional failure.
#[test]
fn test_exponential_backoff_calculation() {
    let config = ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(300));
    // (failure count, expected delay in seconds)
    let expectations = [(1, 1), (2, 2), (3, 4), (4, 8)];
    for (failures, expected_secs) in expectations {
        assert_eq!(
            config.delay_for_failures(failures),
            Duration::from_secs(expected_secs)
        );
    }
}
// Large failure counts never produce a delay above the configured maximum.
#[test]
fn test_backoff_capped_at_max() {
    let cap = Duration::from_secs(60);
    let config = ExponentialBackoff::new(Duration::from_secs(10), cap);
    for failures in [10, 20] {
        assert_eq!(config.delay_for_failures(failures), cap);
    }
}
// Two locations 0.001 apart fall into the same bucket and share backoff state.
#[test]
fn test_nearby_locations_share_bucket() {
    let failed = Location::new(0.500);
    let neighbour = Location::new(0.501);
    let mut tracker = ConnectionBackoff::new();

    tracker.record_failure(failed);

    assert!(tracker.is_in_backoff(neighbour));
}
// A failure at one location does not affect a location far away from it.
#[test]
fn test_distant_locations_different_buckets() {
    let failed = Location::new(0.1);
    let far_away = Location::new(0.9);
    let mut tracker = ConnectionBackoff::new();

    tracker.record_failure(failed);

    assert!(!tracker.is_in_backoff(far_away));
}
// Recording failures in more buckets than the entry limit must not grow the
// tracker beyond that limit.
#[test]
fn test_eviction_when_max_entries_exceeded() {
    let base = Duration::from_secs(5);
    let cap = Duration::from_secs(300);
    let mut tracker = ConnectionBackoff::with_config(base, cap, 10);

    // Each step of 1/256 lands in a distinct bucket: 20 buckets, limit 10.
    for step in 0..20 {
        tracker.record_failure(Location::new(step as f64 / 256.0));
    }

    assert!(tracker.inner.len() <= 10);
}
// Each additional failure on the same target bumps its failure count by one.
#[test]
fn test_consecutive_failures_increase_backoff() {
    let target = Location::new(0.5);
    let mut tracker =
        ConnectionBackoff::with_config(Duration::from_secs(1), Duration::from_secs(300), 256);

    for expected_count in 1..=2 {
        tracker.record_failure(target);
        assert_eq!(tracker.failure_count(target), expected_count);
    }
}
// A single timeout counts as two failures, a routing failure as one, and two
// failures yield a longer delay than one.
#[test]
fn test_failure_reason_timeout_escalates_faster() {
    let timed_out = Location::new(0.5);
    let unroutable = Location::new(0.6);
    let mut tracker =
        ConnectionBackoff::with_config(Duration::from_secs(1), Duration::from_secs(300), 256);

    tracker.record_failure_with_reason(timed_out, ConnectionFailureReason::Timeout);
    tracker.record_failure_with_reason(unroutable, ConnectionFailureReason::RoutingFailed);

    assert!(tracker.is_in_backoff(timed_out));
    assert!(tracker.is_in_backoff(unroutable));

    let timeout_failures = tracker.failure_count(timed_out);
    assert_eq!(timeout_failures, 2, "Timeout should escalate to 2 failures");
    let routing_failures = tracker.failure_count(unroutable);
    assert_eq!(
        routing_failures, 1,
        "Routing failure should stay at 1 failure"
    );

    let timeout_delay = tracker.inner.config().delay_for_failures(2);
    let routing_delay = tracker.inner.config().delay_for_failures(1);
    assert!(
        timeout_delay > routing_delay,
        "Timeout backoff should be longer"
    );
}
// A transient error is recorded as exactly one failure and triggers backoff.
#[test]
fn test_failure_reason_transient_error_minimal_backoff() {
    let target = Location::new(0.5);
    let mut tracker =
        ConnectionBackoff::with_config(Duration::from_secs(1), Duration::from_secs(300), 256);

    tracker.record_failure_with_reason(target, ConnectionFailureReason::TransientError);

    let recorded = tracker.failure_count(target);
    assert_eq!(recorded, 1);
    assert!(tracker.is_in_backoff(target));
}
// A rejection is recorded as exactly one failure and triggers backoff.
#[test]
fn test_failure_reason_rejected_normal_backoff() {
    let target = Location::new(0.5);
    let mut tracker =
        ConnectionBackoff::with_config(Duration::from_secs(1), Duration::from_secs(300), 256);

    tracker.record_failure_with_reason(target, ConnectionFailureReason::Rejected);

    let recorded = tracker.failure_count(target);
    assert_eq!(recorded, 1);
    assert!(tracker.is_in_backoff(target));
}
// Every failure reason, recorded against its own bucket, puts that bucket
// into backoff.
#[test]
fn test_all_failure_reasons_recorded() {
    let all_reasons = [
        ConnectionFailureReason::RoutingFailed,
        ConnectionFailureReason::Timeout,
        ConnectionFailureReason::Rejected,
        ConnectionFailureReason::NatPunchFailed,
        ConnectionFailureReason::HandshakeError,
        ConnectionFailureReason::TransientError,
    ];
    let mut tracker =
        ConnectionBackoff::with_config(Duration::from_secs(1), Duration::from_secs(300), 256);

    for (index, reason) in all_reasons.iter().copied().enumerate() {
        // A step of 1/256 per reason keeps each one in a separate bucket.
        let target = Location::new(index as f64 / 256.0);
        tracker.record_failure_with_reason(target, reason);
        assert!(
            tracker.is_in_backoff(target),
            "Location should be in backoff after {:?}",
            reason
        );
    }
}
// A success clears backoff even after an escalated (timeout) failure.
#[test]
fn test_success_clears_all_failure_reasons() {
    let target = Location::new(0.5);
    let mut tracker =
        ConnectionBackoff::with_config(Duration::from_secs(1), Duration::from_secs(300), 256);

    tracker.record_failure_with_reason(target, ConnectionFailureReason::Timeout);
    assert!(tracker.is_in_backoff(target));

    tracker.record_success(target);
    assert!(!tracker.is_in_backoff(target));
}
// `clear` drops every entry: no target stays in backoff and the inner
// tracker is empty.
#[test]
fn test_clear_removes_all_backoff_state() {
    let targets = [Location::new(0.3), Location::new(0.7)];
    let mut tracker =
        ConnectionBackoff::with_config(Duration::from_secs(1), Duration::from_secs(300), 256);

    for &target in &targets {
        tracker.record_failure(target);
    }
    for &target in &targets {
        assert!(tracker.is_in_backoff(target));
    }

    tracker.clear();

    for &target in &targets {
        assert!(!tracker.is_in_backoff(target));
    }
    assert_eq!(tracker.inner.len(), 0);
}
}