use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;
use serde::{Deserialize, Serialize};
/// Configuration describing when and how a connection should be disrupted.
///
/// Every field carries a serde `default`, so any of them may be omitted from
/// serialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionDisruptionConfig {
/// Number of recorded requests at which the drop threshold is reached.
/// The comparison is `count >= n`, so the threshold stays reached for all
/// later requests. `None` disables the threshold.
#[serde(default)]
pub drop_after_requests: Option<u64>,
/// Selects `DisruptionAction::DropMidFrame` as the disconnect action,
/// unless `rst_after_partial` is also set.
#[serde(default)]
pub drop_mid_frame: bool,
/// Selects `DisruptionAction::RstAfterPartial` as the disconnect action.
/// Takes precedence over `drop_mid_frame`.
#[serde(default)]
pub rst_after_partial: bool,
/// Byte count reported in `DisruptionAction::RstAfterPartial`.
/// Falls back to `default_partial_bytes()` when missing from the input.
#[serde(default = "default_partial_bytes")]
pub partial_byte_count: usize,
/// Duration carried by `DisruptionAction::HoldOpen`.
#[serde(default)]
pub hold_open_timeout: Option<Duration>,
/// Produces a disconnect action on every request whose 1-based ordinal is
/// a multiple of this value. `None` and `Some(0)` both disable it.
#[serde(default)]
pub periodic_drop: Option<u64>,
/// Copied unchanged into every disconnect-style action.
// NOTE(review): presumably a delay applied before the close is performed;
// confirm against the code that consumes `DisruptionAction`.
#[serde(default)]
pub close_delay: Option<Duration>,
/// Copied unchanged into every disconnect-style action.
// NOTE(review): presumably asks for a TCP RST instead of a graceful close;
// confirm against the code that consumes `DisruptionAction`.
#[serde(default)]
pub use_rst: bool,
}
/// Fallback value for `partial_byte_count` when the field is not supplied.
fn default_partial_bytes() -> usize {
    const DEFAULT_PARTIAL_BYTES: usize = 4;
    DEFAULT_PARTIAL_BYTES
}
impl Default for ConnectionDisruptionConfig {
    /// Returns a configuration with every disruption switched off.
    ///
    /// All optional settings are `None`, all flags are `false`, and
    /// `partial_byte_count` takes the value of `default_partial_bytes()`.
    fn default() -> Self {
        // The only field whose default is not the type's zero value.
        let partial_byte_count = default_partial_bytes();

        Self {
            // Request-count triggers.
            drop_after_requests: None,
            periodic_drop: None,
            // Shape of the disconnect.
            drop_mid_frame: false,
            rst_after_partial: false,
            partial_byte_count,
            use_rst: false,
            // Timing.
            hold_open_timeout: None,
            close_delay: None,
        }
    }
}
impl ConnectionDisruptionConfig {
    /// Creates a configuration identical to [`Default::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the request count at which the drop threshold is reached.
    pub fn with_drop_after_requests(self, n: u64) -> Self {
        Self {
            drop_after_requests: Some(n),
            ..self
        }
    }

    /// Turns on the `drop_mid_frame` flag.
    pub fn with_drop_mid_frame(self) -> Self {
        Self {
            drop_mid_frame: true,
            ..self
        }
    }

    /// Turns on `rst_after_partial` and records `byte_count` as the partial
    /// byte count.
    pub fn with_rst_after_partial(self, byte_count: usize) -> Self {
        Self {
            rst_after_partial: true,
            partial_byte_count: byte_count,
            ..self
        }
    }

    /// Sets the hold-open duration.
    pub fn with_hold_open_timeout(self, duration: Duration) -> Self {
        Self {
            hold_open_timeout: Some(duration),
            ..self
        }
    }

    /// Sets the period (in requests) of the periodic drop.
    pub fn with_periodic_drop(self, n: u64) -> Self {
        Self {
            periodic_drop: Some(n),
            ..self
        }
    }

    /// Turns on the `use_rst` flag.
    pub fn with_rst(self) -> Self {
        Self {
            use_rst: true,
            ..self
        }
    }

    /// Sets the close delay.
    pub fn with_close_delay(self, delay: Duration) -> Self {
        Self {
            close_delay: Some(delay),
            ..self
        }
    }
}
/// Runtime bookkeeping for one disruption configuration.
///
/// The mutable parts are atomics, so they can be updated through a shared
/// reference.
#[derive(Debug)]
pub struct ConnectionDisruptionState {
    /// The configuration this state was created with.
    config: ConnectionDisruptionConfig,
    /// Number of requests recorded so far.
    request_count: AtomicU64,
    /// Flag that can be set and queried to mark that the connection is
    /// currently being held open.
    holding_open: AtomicBool,
}
impl ConnectionDisruptionState {
    /// Creates a state for `config` with a request count of zero and the
    /// hold-open flag cleared.
    pub fn new(config: ConnectionDisruptionConfig) -> Self {
        Self {
            config,
            request_count: AtomicU64::new(0),
            holding_open: AtomicBool::new(false),
        }
    }

    /// Counts one request and returns the action to apply to the connection.
    ///
    /// Rules are evaluated in this order:
    ///
    /// 1. If `drop_after_requests` is set and the (1-based) request count has
    ///    reached it, the result is [`DisruptionAction::HoldOpen`] when
    ///    `hold_open_timeout` is configured, otherwise the configured
    ///    disconnect action.
    /// 2. If `periodic_drop` is set to a non-zero period and the request count
    ///    is a multiple of it, the result is the configured disconnect action.
    /// 3. Otherwise the result is [`DisruptionAction::None`].
    ///
    /// `hold_open_timeout` has no effect unless `drop_after_requests` is set.
    pub fn record_request(&self) -> DisruptionAction {
        // `fetch_add` yields the previous value; adding one gives the ordinal
        // of this request. The counter is not used here to synchronise access
        // to other data, so relaxed ordering is enough for the increment.
        let count = self.request_count.fetch_add(1, Ordering::Relaxed) + 1;

        let threshold_reached =
            matches!(self.config.drop_after_requests, Some(threshold) if count >= threshold);
        if threshold_reached {
            // Hold-open must be decided here: if it were checked after the
            // threshold disconnect returned, it could never be selected.
            return match self.config.hold_open_timeout {
                Some(duration) => DisruptionAction::HoldOpen { duration },
                None => self.build_disconnect_action(),
            };
        }

        match self.config.periodic_drop {
            // A period of zero is treated as "disabled" (and would otherwise
            // divide by zero).
            Some(period) if period > 0 && count % period == 0 => self.build_disconnect_action(),
            _ => DisruptionAction::None,
        }
    }

    /// Builds the disconnect-style action selected by the configuration.
    ///
    /// `rst_after_partial` wins over `drop_mid_frame`; with neither flag set
    /// the plain [`DisruptionAction::Disconnect`] is used. `close_delay` and
    /// `use_rst` are copied into whichever variant is produced.
    fn build_disconnect_action(&self) -> DisruptionAction {
        let close_delay = self.config.close_delay;
        let use_rst = self.config.use_rst;

        if self.config.rst_after_partial {
            DisruptionAction::RstAfterPartial {
                byte_count: self.config.partial_byte_count,
                close_delay,
                use_rst,
            }
        } else if self.config.drop_mid_frame {
            DisruptionAction::DropMidFrame {
                close_delay,
                use_rst,
            }
        } else {
            DisruptionAction::Disconnect {
                close_delay,
                use_rst,
            }
        }
    }

    /// Returns the number of requests recorded since creation or the last
    /// [`reset`](Self::reset).
    pub fn request_count(&self) -> u64 {
        self.request_count.load(Ordering::Acquire)
    }

    /// Returns the current value of the hold-open flag.
    ///
    /// The flag is only changed by [`set_holding_open`](Self::set_holding_open)
    /// and [`reset`](Self::reset); `record_request` does not touch it.
    pub fn is_holding_open(&self) -> bool {
        self.holding_open.load(Ordering::Acquire)
    }

    /// Sets the hold-open flag to `holding`.
    pub fn set_holding_open(&self, holding: bool) {
        self.holding_open.store(holding, Ordering::Release);
    }

    /// Sets the request count back to zero and clears the hold-open flag.
    ///
    /// The two stores are separate atomic operations, not one combined
    /// update.
    pub fn reset(&self) {
        self.request_count.store(0, Ordering::Release);
        self.holding_open.store(false, Ordering::Release);
    }
}
/// The decision produced for a single recorded request.
///
/// How each action is carried out is up to the code that consumes it; this
/// type only carries the decision and its parameters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DisruptionAction {
    /// No disruption for this request.
    None,
    /// Plain disconnect; used when neither of the more specific disconnect
    /// flavours is configured.
    Disconnect {
        /// Value of the configuration's `close_delay`.
        close_delay: Option<Duration>,
        /// Value of the configuration's `use_rst`.
        use_rst: bool,
    },
    /// Disconnect flavour selected by the `drop_mid_frame` flag.
    DropMidFrame {
        /// Value of the configuration's `close_delay`.
        close_delay: Option<Duration>,
        /// Value of the configuration's `use_rst`.
        use_rst: bool,
    },
    /// Disconnect flavour selected by the `rst_after_partial` flag.
    RstAfterPartial {
        /// Value of the configuration's `partial_byte_count`.
        byte_count: usize,
        /// Value of the configuration's `close_delay`.
        close_delay: Option<Duration>,
        /// Value of the configuration's `use_rst`.
        use_rst: bool,
    },
    /// Hold the connection open.
    HoldOpen {
        /// Value of the configuration's `hold_open_timeout`.
        duration: Duration,
    },
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `config` in a fresh state so the tests below stay short.
    fn state_with(config: ConnectionDisruptionConfig) -> ConnectionDisruptionState {
        ConnectionDisruptionState::new(config)
    }

    /// The default configuration has no triggers and no special disconnect flags.
    #[test]
    fn test_default_config() {
        let ConnectionDisruptionConfig {
            drop_after_requests,
            drop_mid_frame,
            rst_after_partial,
            periodic_drop,
            ..
        } = ConnectionDisruptionConfig::default();

        assert!(drop_after_requests.is_none());
        assert!(!drop_mid_frame);
        assert!(!rst_after_partial);
        assert!(periodic_drop.is_none());
    }

    /// With a threshold of three, the first two requests pass and the third disconnects.
    #[test]
    fn test_drop_after_requests() {
        let state = state_with(ConnectionDisruptionConfig::new().with_drop_after_requests(3));

        for _ in 0..2 {
            assert!(matches!(state.record_request(), DisruptionAction::None));
        }
        let third = state.record_request();
        assert!(matches!(third, DisruptionAction::Disconnect { .. }));
    }

    /// With a period of two, every even-numbered request disconnects.
    #[test]
    fn test_periodic_drop() {
        let state = state_with(ConnectionDisruptionConfig::new().with_periodic_drop(2));

        for ordinal in 1..=4u32 {
            let action = state.record_request();
            if ordinal % 2 == 0 {
                assert!(matches!(action, DisruptionAction::Disconnect { .. }));
            } else {
                assert!(matches!(action, DisruptionAction::None));
            }
        }
    }

    /// The configured partial byte count is carried into the action.
    #[test]
    fn test_rst_after_partial() {
        let state = state_with(
            ConnectionDisruptionConfig::new()
                .with_drop_after_requests(1)
                .with_rst_after_partial(6),
        );

        let action = state.record_request();
        if let DisruptionAction::RstAfterPartial { byte_count, .. } = action {
            assert_eq!(byte_count, 6);
        } else {
            panic!("Expected RstAfterPartial, got {:?}", action);
        }
    }

    /// The mid-frame flag selects the `DropMidFrame` action.
    #[test]
    fn test_drop_mid_frame() {
        let state = state_with(
            ConnectionDisruptionConfig::new()
                .with_drop_after_requests(1)
                .with_drop_mid_frame(),
        );

        let action = state.record_request();
        assert!(matches!(action, DisruptionAction::DropMidFrame { .. }));
    }

    /// Chained builder calls each record their own setting.
    #[test]
    fn test_config_builder() {
        let delay = Duration::from_millis(100);
        let config = ConnectionDisruptionConfig::new()
            .with_drop_after_requests(5)
            .with_rst()
            .with_close_delay(delay);

        assert_eq!(config.drop_after_requests, Some(5));
        assert!(config.use_rst);
        assert_eq!(config.close_delay, Some(Duration::from_millis(100)));
    }

    /// After a reset, counting starts again from zero.
    #[test]
    fn test_state_reset() {
        let state = state_with(ConnectionDisruptionConfig::new().with_drop_after_requests(2));

        let _ = state.record_request();
        assert_eq!(state.request_count(), 1);

        state.reset();
        assert_eq!(state.request_count(), 0);

        let first_after_reset = state.record_request();
        assert!(matches!(first_after_reset, DisruptionAction::None));
        let second_after_reset = state.record_request();
        assert!(matches!(
            second_after_reset,
            DisruptionAction::Disconnect { .. }
        ));
    }

    /// A default configuration never produces a disruption.
    #[test]
    fn test_no_disruption_by_default() {
        let state = state_with(ConnectionDisruptionConfig::default());

        (0..100).for_each(|_| {
            assert!(matches!(state.record_request(), DisruptionAction::None));
        });
    }

    /// The `use_rst` flag is carried into the disconnect action.
    #[test]
    fn test_use_rst_flag() {
        let state = state_with(
            ConnectionDisruptionConfig::new()
                .with_drop_after_requests(1)
                .with_rst(),
        );

        let action = state.record_request();
        if let DisruptionAction::Disconnect { use_rst, .. } = action {
            assert!(use_rst);
        } else {
            panic!("Expected Disconnect, got {:?}", action);
        }
    }
}