use std::{
collections::HashMap,
io::ErrorKind,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc,
},
time::Duration,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
sync::Mutex as TokioMutex,
task::JoinHandle,
time::sleep,
};
use tokio_netem::{
delayer::DynamicDuration, io::NetEmWriteExt, probability::DynamicProbability,
slicer::DynamicSize, terminator::Terminator, throttler::DynamicRate,
};
/// Loopback TCP proxy that forwards connections to a fixed target while letting the
/// caller simulate outages and degraded networks.
///
/// Connections accepted on [`TcpDisconnectProxy::base_url`] are relayed to the target.
/// The proxy can refuse new connections, drop established ones, and inject latency,
/// periodic stalls, blackholing and tokio-netem write impairments.
pub struct TcpDisconnectProxy {
    // `http://<bind addr>` of the proxy's own listener; clients connect here.
    base_url: String,
    // When set, newly accepted connections are shut down instead of being proxied.
    paused: Arc<AtomicBool>,
    // Impairment knobs shared with every relay task.
    impairments: Arc<ProxyImpairments>,
    // Relay tasks for live connections, keyed by a connection id assigned by the accept loop.
    active_connections: Arc<TokioMutex<HashMap<u64, JoinHandle<()>>>>,
    // Task running the accept loop; aborted by `shutdown`.
    accept_task: JoinHandle<()>,
}
/// Runtime-adjustable impairments shared (via `Arc`) between a proxy and its relay tasks.
///
/// Plain knobs are atomics and the netem knobs are tokio-netem dynamic values held in
/// `Arc`s, so the proxy can change them while relays are running.
struct ProxyImpairments {
    /// While `true`, relays hold each chunk they have read instead of forwarding it.
    blackholed: AtomicBool,
    /// Extra delay in milliseconds before each chunk is forwarded (0 disables it).
    latency_ms: AtomicU64,
    /// Stall on every N-th chunk of a relay direction (0 disables stalls).
    stall_every_n_chunks: AtomicU64,
    /// Length of each stall in milliseconds (0 disables stalls).
    stall_duration_ms: AtomicU64,
    /// Dynamic delay handed to the netem delay layer on relay writers.
    netem_write_delay: Arc<DynamicDuration>,
    /// Dynamic rate handed to the netem throttle layer on relay writers.
    netem_write_rate: Arc<DynamicRate>,
    /// Dynamic slice size handed to the netem slicing layer on relay writers.
    netem_write_slice_size: Arc<DynamicSize>,
    /// Dynamic probability handed to the netem `Terminator` wrapping relay writers.
    netem_termination_probability: Arc<DynamicProbability>,
}
impl ProxyImpairments {
    /// Creates an impairment set whose knobs all start at zero / `false`.
    ///
    /// # Panics
    ///
    /// Panics only if tokio-netem rejects a `0.0` termination probability.
    fn new() -> Self {
        let termination = DynamicProbability::new(0.0)
            .expect("zero termination probability should be valid");
        Self {
            blackholed: AtomicBool::new(false),
            latency_ms: AtomicU64::new(0),
            stall_every_n_chunks: AtomicU64::new(0),
            stall_duration_ms: AtomicU64::new(0),
            netem_write_delay: DynamicDuration::new(Duration::ZERO),
            netem_write_rate: DynamicRate::new(0),
            netem_write_slice_size: DynamicSize::new(0),
            netem_termination_probability: termination,
        }
    }
    /// Applies the forwarding impairments for the `chunk_index`-th chunk of a relay.
    ///
    /// In order: waits (re-checking every 25 ms) while the blackhole is active, then
    /// sleeps for the configured latency, then sleeps for the stall duration if stalls
    /// are enabled and `chunk_index` is a multiple of the stall period.
    async fn wait_before_forwarding(&self, chunk_index: u64) {
        loop {
            if !self.blackholed.load(Ordering::SeqCst) {
                break;
            }
            sleep(Duration::from_millis(25)).await;
        }

        let added_latency = self.latency_ms.load(Ordering::SeqCst);
        if added_latency != 0 {
            sleep(Duration::from_millis(added_latency)).await;
        }

        let period = self.stall_every_n_chunks.load(Ordering::SeqCst);
        let stall_for = self.stall_duration_ms.load(Ordering::SeqCst);
        // `period != 0` is checked first so the modulo never divides by zero.
        let stall_due = period != 0 && stall_for != 0 && chunk_index % period == 0;
        if stall_due {
            sleep(Duration::from_millis(stall_for)).await;
        }
    }
}
impl TcpDisconnectProxy {
pub async fn start(target_base_url: &str) -> Self {
let target_addr = extract_host_port(target_base_url);
let listener =
bind_loopback_listener().await.expect("proxy should bind to an ephemeral port");
let bind_addr = listener.local_addr().expect("proxy should have a local addr");
let paused = Arc::new(AtomicBool::new(false));
let impairments = Arc::new(ProxyImpairments::new());
let active_connections = Arc::new(TokioMutex::new(HashMap::new()));
let next_id = Arc::new(AtomicU64::new(1));
let paused_clone = paused.clone();
let impairments_clone = impairments.clone();
let active_clone = active_connections.clone();
let next_id_clone = next_id.clone();
let accept_task = tokio::spawn(async move {
while let Ok((mut inbound, _peer)) = listener.accept().await {
if paused_clone.load(Ordering::SeqCst) {
let _ = inbound.shutdown().await;
drop(inbound);
continue;
}
let id = next_id_clone.fetch_add(1, Ordering::SeqCst);
let target_addr = target_addr.clone();
let active_for_task = active_clone.clone();
let impairments_for_task = impairments_clone.clone();
let task = tokio::spawn(async move {
if let Ok(mut outbound) = TcpStream::connect(&target_addr).await {
let (inbound_reader, inbound_writer) = inbound.split();
let (outbound_reader, outbound_writer) = outbound.split();
let _ = tokio::try_join!(
relay_with_impairments(
inbound_reader,
outbound_writer,
impairments_for_task.clone(),
),
relay_with_impairments(
outbound_reader,
inbound_writer,
impairments_for_task,
),
);
}
active_for_task.lock().await.remove(&id);
});
active_clone.lock().await.insert(id, task);
}
});
Self {
base_url: format!("http://{}", bind_addr),
paused,
impairments,
active_connections,
accept_task,
}
}
pub fn base_url(&self) -> &str {
&self.base_url
}
pub fn pause(&self) {
self.paused.store(true, Ordering::SeqCst);
}
pub fn resume(&self) {
self.paused.store(false, Ordering::SeqCst);
}
pub fn blackhole(&self) {
self.impairments.blackholed.store(true, Ordering::SeqCst);
}
pub fn restore_traffic(&self) {
self.impairments.blackholed.store(false, Ordering::SeqCst);
}
pub fn set_latency(&self, latency: Duration) {
self.impairments.latency_ms.store(latency.as_millis() as u64, Ordering::SeqCst);
}
pub fn clear_latency(&self) {
self.impairments.latency_ms.store(0, Ordering::SeqCst);
}
pub fn set_chunk_stall_pattern(&self, every_n_chunks: u64, stall_duration: Duration) {
self.impairments.stall_every_n_chunks.store(every_n_chunks, Ordering::SeqCst);
self.impairments
.stall_duration_ms
.store(stall_duration.as_millis() as u64, Ordering::SeqCst);
}
pub fn clear_chunk_stall_pattern(&self) {
self.impairments.stall_every_n_chunks.store(0, Ordering::SeqCst);
self.impairments.stall_duration_ms.store(0, Ordering::SeqCst);
}
pub fn set_netem_write_delay(&self, delay: Duration) {
self.impairments.netem_write_delay.set(delay);
}
pub fn clear_netem_write_delay(&self) {
self.impairments.netem_write_delay.set(Duration::ZERO);
}
pub fn set_netem_write_rate(&self, bytes_per_second: usize) {
self.impairments.netem_write_rate.set(bytes_per_second);
}
pub fn clear_netem_write_rate(&self) {
self.impairments.netem_write_rate.set(0);
}
pub fn set_netem_write_slice_size(&self, size: usize) {
self.impairments.netem_write_slice_size.set(size);
}
pub fn clear_netem_write_slice_size(&self) {
self.impairments.netem_write_slice_size.set(0);
}
pub fn set_netem_termination_probability(&self, probability: f64) {
self.impairments
.netem_termination_probability
.set(probability)
.expect("termination probability should be between 0.0 and 1.0");
}
pub fn clear_netem_termination_probability(&self) {
self.impairments
.netem_termination_probability
.set(0.0)
.expect("zero termination probability should be valid");
}
pub async fn drop_active_connections(&self) {
let mut active = self.active_connections.lock().await;
for (_id, task) in active.drain() {
task.abort();
}
}
pub async fn active_count(&self) -> usize {
self.active_connections.lock().await.len()
}
pub async fn wait_for_active_connections(
&self,
min_count: usize,
timeout_dur: Duration,
) -> bool {
let start = std::time::Instant::now();
loop {
if self.active_connections.lock().await.len() >= min_count {
return true;
}
if start.elapsed() >= timeout_dur {
return false;
}
sleep(Duration::from_millis(50)).await;
}
}
pub async fn simulate_server_down(&self) {
self.pause();
self.drop_active_connections().await;
}
pub fn simulate_server_up(&self) {
self.resume();
}
pub async fn shutdown(self) {
self.accept_task.abort();
self.drop_active_connections().await;
}
}
/// Binds a TCP listener on an OS-assigned loopback port (`127.0.0.1:0`).
///
/// `AddrNotAvailable` and `AddrInUse` failures are retried up to 20 times, 50 ms
/// apart; any other bind error is returned immediately.
///
/// # Errors
///
/// Returns the first non-retryable bind error, or the last retryable one once all
/// attempts are exhausted.
async fn bind_loopback_listener() -> std::io::Result<TcpListener> {
    const BIND_ATTEMPTS: usize = 20;
    const RETRY_DELAY: Duration = Duration::from_millis(50);

    let mut last_error: Option<std::io::Error> = None;
    for _ in 0..BIND_ATTEMPTS {
        let err = match TcpListener::bind("127.0.0.1:0").await {
            Ok(listener) => return Ok(listener),
            Err(err) => err,
        };
        let retryable = matches!(err.kind(), ErrorKind::AddrNotAvailable | ErrorKind::AddrInUse);
        if !retryable {
            return Err(err);
        }
        last_error = Some(err);
        sleep(RETRY_DELAY).await;
    }

    // Only reachable after retryable failures, so `last_error` is normally `Some`.
    let exhausted = last_error.unwrap_or_else(|| {
        std::io::Error::new(
            ErrorKind::AddrNotAvailable,
            "failed to bind loopback listener after retries",
        )
    });
    Err(exhausted)
}
/// Copies bytes from `reader` to `writer` until EOF, applying the proxy's impairments.
///
/// `writer` is wrapped in the tokio-netem delay, throttle and slice layers plus a
/// [`Terminator`], all driven by the dynamic knobs in `impairments`. Every chunk read
/// (up to 16 KiB) first waits on [`ProxyImpairments::wait_before_forwarding`], then is
/// written in full and flushed. When `reader` reports EOF, the wrapped writer is shut
/// down and the relay returns.
///
/// # Errors
///
/// Returns the first I/O error from reading, writing, flushing or shutting down.
async fn relay_with_impairments(
    mut reader: tokio::net::tcp::ReadHalf<'_>,
    writer: tokio::net::tcp::WriteHalf<'_>,
    impairments: Arc<ProxyImpairments>,
) -> std::io::Result<()> {
    let delayed = writer.delay_writes_dyn(impairments.netem_write_delay.clone());
    let throttled = delayed.throttle_writes_dyn(impairments.netem_write_rate.clone());
    let sliced = throttled.slice_writes_dyn(impairments.netem_write_slice_size.clone());
    let mut sink = Terminator::new(sliced, impairments.netem_termination_probability.clone());

    let mut chunk = [0_u8; 16 * 1024];
    let mut chunks_seen: u64 = 0;
    loop {
        let len = match reader.read(&mut chunk).await? {
            0 => {
                sink.shutdown().await?;
                return Ok(());
            },
            n => n,
        };
        // Chunk numbering starts at 1 so a stall period of N hits chunks N, 2N, ...
        chunks_seen = chunks_seen.saturating_add(1);
        impairments.wait_before_forwarding(chunks_seen).await;
        sink.write_all(&chunk[..len]).await?;
        sink.flush().await?;
    }
}
/// Extracts the `host:port` authority from a base URL such as `http://127.0.0.1:8080/api`.
///
/// A single leading `http://` or `https://` scheme is removed, then everything from the
/// first `/`, `?` or `#` onwards is dropped. If nothing remains (empty input or a bare
/// scheme such as `"http://"`), the default `127.0.0.1:2900` is returned. The original
/// `split('/').next()` could never yield `None`, so its fallback was unreachable.
fn extract_host_port(base_url: &str) -> String {
    const DEFAULT_HOST_PORT: &str = "127.0.0.1:2900";

    let without_scheme = base_url
        .strip_prefix("http://")
        .or_else(|| base_url.strip_prefix("https://"))
        .unwrap_or(base_url);
    let authority_end = without_scheme
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(without_scheme.len());
    let authority = &without_scheme[..authority_end];

    if authority.is_empty() {
        DEFAULT_HOST_PORT.to_string()
    } else {
        authority.to_string()
    }
}