use async_stream::stream;
use futures::StreamExt;
use std::{
collections::HashMap,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::{Duration, Instant},
};
use tokio::sync::Mutex;
#[derive(Debug, Clone)]
/// Configuration for one chunked-transfer upload benchmark run.
pub struct ChunkedBenchConfig {
    /// URL every benchmark request is sent to.
    pub target_url: String,
    /// HTTP method used for every request.
    pub method: reqwest::Method,
    /// Number of concurrent worker tasks; `run` rejects 0.
    pub concurrency: u32,
    /// Wall-clock budget; workers stop starting new requests once it elapses.
    pub duration: Duration,
    /// Size of each streamed body chunk in bytes; `run` rejects 0.
    pub chunk_size_bytes: usize,
    /// Total body size per request in bytes; `run` rejects 0.
    pub total_size_bytes: usize,
    /// Pause between consecutive body chunks, in milliseconds; 0 disables it.
    pub chunk_interval_ms: u64,
    /// Extra headers applied to every request.
    pub headers: HashMap<String, String>,
    /// When true, the client accepts invalid TLS certificates
    /// (`danger_accept_invalid_certs`).
    pub skip_tls_verify: bool,
}
#[derive(Debug, Clone)]
/// Aggregated outcome of a benchmark run produced by [`run`].
pub struct ChunkedBenchResult {
    /// Requests attempted (successful + failed).
    pub total_requests: u64,
    /// Requests that received any HTTP response.
    pub successful: u64,
    /// Requests that errored before a response was received.
    pub failed: u64,
    /// Body bytes counted for successful requests only
    /// (`total_size_bytes` per success).
    pub bytes_sent: u64,
    /// Actual wall-clock time the run took.
    pub elapsed: Duration,
    /// `total_requests / elapsed` in requests per second.
    pub req_per_sec: f64,
    /// Per-request latencies in milliseconds, sorted ascending.
    pub latencies_ms: Vec<u64>,
    /// Mean latency in milliseconds; 0.0 when there were no samples.
    pub avg_latency_ms: f64,
    /// 50th-percentile latency (nearest sample), ms.
    pub p50_ms: u64,
    /// 95th-percentile latency (nearest sample), ms.
    pub p95_ms: u64,
    /// 99th-percentile latency (nearest sample), ms.
    pub p99_ms: u64,
    /// Count of responses per HTTP status code.
    pub status_counts: HashMap<u16, u64>,
}
/// Runs the chunked-upload benchmark described by `cfg`.
///
/// Spawns `cfg.concurrency` worker tasks that each repeatedly send a
/// chunked-transfer request to `cfg.target_url` until `cfg.duration` has
/// elapsed, then aggregates per-request statistics into a
/// [`ChunkedBenchResult`].
///
/// # Errors
///
/// Fails when the configuration is invalid (`chunk_size_bytes`,
/// `total_size_bytes`, or `concurrency` is zero) or when the HTTP client
/// cannot be built.
pub async fn run(cfg: ChunkedBenchConfig) -> anyhow::Result<ChunkedBenchResult> {
    if cfg.chunk_size_bytes == 0 {
        anyhow::bail!("chunk_size_bytes must be > 0");
    }
    if cfg.total_size_bytes == 0 {
        anyhow::bail!("total_size_bytes must be > 0");
    }
    if cfg.concurrency == 0 {
        anyhow::bail!("concurrency must be >= 1");
    }
    let client = reqwest::Client::builder()
        .danger_accept_invalid_certs(cfg.skip_tls_verify)
        .build()?;
    let total_requests = Arc::new(AtomicU64::new(0));
    let successful = Arc::new(AtomicU64::new(0));
    let failed = Arc::new(AtomicU64::new(0));
    let bytes_sent = Arc::new(AtomicU64::new(0));
    let latencies: Arc<Mutex<Vec<u64>>> = Arc::new(Mutex::new(Vec::with_capacity(8192)));
    let status_counts: Arc<Mutex<HashMap<u16, u64>>> = Arc::new(Mutex::new(HashMap::new()));
    // Derive the deadline from the same instant that `elapsed` is measured
    // against, so the workers' cutoff and the reported elapsed time agree.
    let started_at = Instant::now();
    let deadline = started_at + cfg.duration;
    let mut workers = Vec::with_capacity(cfg.concurrency as usize);
    for _ in 0..cfg.concurrency {
        let cfg = cfg.clone();
        let client = client.clone();
        let total_requests = total_requests.clone();
        let successful = successful.clone();
        let failed = failed.clone();
        let bytes_sent = bytes_sent.clone();
        let latencies = latencies.clone();
        let status_counts = status_counts.clone();
        workers.push(tokio::spawn(async move {
            // Buffer samples locally and merge once when the worker exits,
            // instead of taking the shared mutexes on every request — the
            // per-request locks serialize all workers on the hot path.
            let mut local_latencies: Vec<u64> = Vec::new();
            let mut local_statuses: HashMap<u16, u64> = HashMap::new();
            while Instant::now() < deadline {
                let req_started = Instant::now();
                match send_one_chunked_request(&client, &cfg).await {
                    Ok(status) => {
                        successful.fetch_add(1, Ordering::Relaxed);
                        bytes_sent.fetch_add(cfg.total_size_bytes as u64, Ordering::Relaxed);
                        local_latencies.push(req_started.elapsed().as_millis() as u64);
                        *local_statuses.entry(status).or_insert(0) += 1;
                    }
                    Err(_e) => {
                        failed.fetch_add(1, Ordering::Relaxed);
                    }
                }
                total_requests.fetch_add(1, Ordering::Relaxed);
            }
            latencies.lock().await.extend(local_latencies);
            let mut counts = status_counts.lock().await;
            for (status, n) in local_statuses {
                *counts.entry(status).or_insert(0) += n;
            }
        }));
    }
    for w in workers {
        // A worker panicking loses its samples but must not abort the run.
        let _ = w.await;
    }
    let elapsed = started_at.elapsed();
    let total = total_requests.load(Ordering::Relaxed);
    let mut samples: Vec<u64> = {
        let mut g = latencies.lock().await;
        std::mem::take(&mut *g)
    };
    let final_status_counts: HashMap<u16, u64> = {
        let mut g = status_counts.lock().await;
        std::mem::take(&mut *g)
    };
    samples.sort_unstable();
    // Accumulate in f64 so a huge sample set cannot overflow a u64 sum.
    let avg = if samples.is_empty() {
        0.0
    } else {
        samples.iter().map(|&ms| ms as f64).sum::<f64>() / samples.len() as f64
    };
    // Percentile as the nearest sample by rounded rank; 0 with no samples.
    let p = |q: f64| -> u64 {
        if samples.is_empty() {
            return 0;
        }
        let idx = ((samples.len() as f64 - 1.0) * q).round() as usize;
        samples[idx]
    };
    Ok(ChunkedBenchResult {
        total_requests: total,
        successful: successful.load(Ordering::Relaxed),
        failed: failed.load(Ordering::Relaxed),
        bytes_sent: bytes_sent.load(Ordering::Relaxed),
        elapsed,
        req_per_sec: if elapsed.as_secs_f64() > 0.0 {
            total as f64 / elapsed.as_secs_f64()
        } else {
            0.0
        },
        avg_latency_ms: avg,
        p50_ms: p(0.50),
        p95_ms: p(0.95),
        p99_ms: p(0.99),
        latencies_ms: samples,
        status_counts: final_status_counts,
    })
}
/// Sends a single request whose body is streamed in
/// `cfg.chunk_size_bytes`-sized chunks until `cfg.total_size_bytes` bytes
/// have been produced, pausing `cfg.chunk_interval_ms` milliseconds
/// *between* consecutive chunks.
///
/// Returns the numeric HTTP status code of the response.
///
/// # Errors
///
/// Fails when the request cannot be completed (connection failure, invalid
/// header name/value, body stream error, ...).
async fn send_one_chunked_request(
    client: &reqwest::Client,
    cfg: &ChunkedBenchConfig,
) -> anyhow::Result<u16> {
    let chunk_size = cfg.chunk_size_bytes;
    let total = cfg.total_size_bytes;
    let interval_ms = cfg.chunk_interval_ms;
    let body_stream = stream! {
        // One filler buffer; every chunk is a copy of a prefix of it.
        let payload = vec![b'X'; chunk_size];
        let mut sent: usize = 0;
        while sent < total {
            let next = std::cmp::min(chunk_size, total - sent);
            sent += next;
            // Yield first, then sleep: the interval belongs between chunks.
            // (Sleeping before the yield delayed the very first chunk and
            // left no pause before the final one.)
            yield Ok::<_, std::io::Error>(payload[..next].to_vec());
            if interval_ms > 0 && sent < total {
                tokio::time::sleep(Duration::from_millis(interval_ms)).await;
            }
        }
    };
    let body = reqwest::Body::wrap_stream(body_stream.boxed());
    let mut req = client.request(cfg.method.clone(), &cfg.target_url).body(body);
    for (k, v) in &cfg.headers {
        req = req.header(k, v);
    }
    let resp = req.send().await?;
    Ok(resp.status().as_u16())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Valid baseline config aimed at an unroutable local port; each test
    /// overrides only the field it is exercising.
    fn base_cfg() -> ChunkedBenchConfig {
        ChunkedBenchConfig {
            target_url: "http://127.0.0.1:1".into(),
            method: reqwest::Method::POST,
            concurrency: 1,
            duration: Duration::from_millis(10),
            chunk_size_bytes: 1024,
            total_size_bytes: 4096,
            chunk_interval_ms: 0,
            headers: HashMap::new(),
            skip_tls_verify: false,
        }
    }

    #[tokio::test]
    async fn rejects_zero_concurrency() {
        let cfg = ChunkedBenchConfig {
            concurrency: 0,
            ..base_cfg()
        };
        assert!(run(cfg).await.is_err());
    }

    #[tokio::test]
    async fn rejects_zero_chunk_size() {
        let cfg = ChunkedBenchConfig {
            chunk_size_bytes: 0,
            ..base_cfg()
        };
        assert!(run(cfg).await.is_err());
    }
}