use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use bytes::Bytes;
use futures_util::stream::{FuturesUnordered, StreamExt};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
/// Parameters for one benchmark run, consumed by [`run`].
#[derive(Debug, Clone)]
pub struct BenchConfig {
    /// Target URL such as `http://host:port/path`. The `http://` scheme is
    /// optional; connections are plain TCP (no TLS).
    pub uri: String,
    /// Request method, written verbatim into the request line.
    pub method: String,
    /// Extra header fields sent with every request.
    pub headers: Vec<(String, String)>,
    /// Request body sent with every request.
    pub body: Bytes,
    /// Total number of requests to issue; ignored when `duration` is set.
    pub num_requests: usize,
    /// Number of connections opened before the run starts.
    pub concurrency: usize,
    /// When set, keep issuing requests until this much time has elapsed
    /// instead of stopping after `num_requests`.
    pub duration: Option<Duration>,
    /// Upper bound on each connection attempt and on each request/response
    /// exchange.
    pub timeout: Duration,
    /// Target request rate.
    // NOTE(review): not read by the load generator in this module, hence the
    // `dead_code` allowance; presumably kept for CLI compatibility — confirm
    // whether rate limiting is planned.
    #[allow(dead_code)]
    pub qps: f64,
    /// Send `Connection: close` and open a new connection for every request.
    pub disable_keepalive: bool,
}
/// Live counters shared between the worker tasks and the progress renderer.
#[derive(Debug, Default)]
pub struct Progress {
    /// Requests finished so far, successful or not.
    pub completed: AtomicU64,
    /// Failed-request counter.
    pub errors: AtomicU64,
    /// Sum of the latencies of successful requests, in nanoseconds.
    pub total_ns: AtomicU64,
}

impl Progress {
    /// Creates a set of counters, all starting at zero.
    pub fn new() -> Self {
        Self::default()
    }
}
/// Everything one worker task measured during the run.
pub struct WorkerResult {
    /// Failure count: failed or timed-out requests (or, when no connection
    /// at all could be opened, the failed connection attempts).
    pub errors: u64,
    /// Response body bytes received over successful requests.
    pub bytes_recv: u64,
    /// Latency of each successful request in nanoseconds, measured from just
    /// before the request was started (including any reconnect).
    pub latencies: Vec<u64>,
    /// HTTP status code of each successful request, parallel to `latencies`.
    pub status_codes: Vec<u16>,
    /// Time spent writing requests to the socket.
    pub write: PhaseAcc,
    /// Time from the end of the write until the response headers were complete.
    pub wait: PhaseAcc,
    /// Time spent parsing response headers and draining the body.
    pub read: PhaseAcc,
}
/// Running sum / min / max / count accumulator for one request phase, with
/// samples in nanoseconds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PhaseAcc {
    /// Sum of all recorded samples.
    pub sum: u64,
    /// Smallest recorded sample; `u64::MAX` while `count == 0`.
    pub min: u64,
    /// Largest recorded sample; `0` while `count == 0`.
    pub max: u64,
    /// Number of recorded samples.
    pub count: u64,
}

impl PhaseAcc {
    /// Creates an empty accumulator. `min` starts at `u64::MAX` so the first
    /// recorded sample always replaces it.
    pub fn new() -> Self {
        Self { sum: 0, min: u64::MAX, max: 0, count: 0 }
    }

    /// Adds one sample.
    #[inline]
    pub fn record(&mut self, ns: u64) {
        self.sum += ns;
        self.min = self.min.min(ns);
        self.max = self.max.max(ns);
        self.count += 1;
    }
}

impl Default for PhaseAcc {
    /// Same as [`PhaseAcc::new`]; a derived `Default` would start `min` at 0
    /// and never record a smaller minimum.
    fn default() -> Self {
        Self::new()
    }
}
/// Per-phase durations of a single request, in nanoseconds.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RequestTimings {
    /// Writing the serialized request to the socket.
    pub write_ns: u64,
    /// Waiting until the complete response header block was available.
    pub wait_ns: u64,
    /// Parsing the response headers and draining the body.
    pub read_ns: u64,
}
/// Splits a target URL into `(host, port, path)`.
///
/// Accepts `http://host[:port][/path]`; the scheme is optional. The port
/// defaults to 80 (also when it fails to parse) and the path to `/`.
/// Bracketed IPv6 literals (`[::1]`, `[::1]:8080`) keep their brackets in the
/// returned host, so it can be used directly in `host:port` strings and in
/// the `Host` header. A query without a path (`host?q=1`) yields `/?q=1`.
fn parse_url(url: &str) -> (String, u16, String) {
    let rest = url.strip_prefix("http://").unwrap_or(url);
    let (authority, path) = match rest.find(|c: char| c == '/' || c == '?') {
        // Query directly after the authority: the path is implicitly "/".
        Some(i) if rest.as_bytes()[i] == b'?' => (&rest[..i], format!("/{}", &rest[i..])),
        Some(i) => (&rest[..i], rest[i..].to_string()),
        None => (rest, "/".to_string()),
    };
    // A ':' followed by a ']' lies inside an IPv6 literal, not before a port.
    let (host, port) = match authority.rfind(':') {
        Some(i) if !authority[i..].contains(']') => {
            (&authority[..i], authority[i + 1..].parse().unwrap_or(80))
        }
        _ => (authority, 80u16),
    };
    (host.to_string(), port, path)
}
/// Serializes one HTTP/1.1 request into a buffer that is reused for every
/// request of the run.
///
/// `Host`, `Content-Length` (only for a non-empty `body`) and `Connection`
/// (`keep-alive` or `close` depending on `keepalive`) are generated
/// automatically unless the caller supplies a header of the same name
/// (case-insensitive). Duplicating them would make the request malformed;
/// RFC 9112 requires servers to reject a request with more than one `Host`.
/// `Content-Length` is also omitted when the caller supplies
/// `Transfer-Encoding`. Caller headers are appended verbatim.
///
/// NOTE: a caller-supplied `Connection` header does not change whether the
/// client reuses the connection; that is still governed by `keepalive`.
fn build_raw_request(
    method: &str,
    host: &str,
    port: u16,
    path: &str,
    headers: &[(String, String)],
    body: &[u8],
    keepalive: bool,
) -> Bytes {
    let user_supplied = |name: &str| headers.iter().any(|(k, _)| k.eq_ignore_ascii_case(name));

    let mut buf = Vec::with_capacity(512 + body.len());
    buf.extend_from_slice(method.as_bytes());
    buf.push(b' ');
    buf.extend_from_slice(path.as_bytes());
    buf.extend_from_slice(b" HTTP/1.1\r\n");
    if !user_supplied("host") {
        buf.extend_from_slice(b"Host: ");
        buf.extend_from_slice(host.as_bytes());
        // Port 80 is implied for http and conventionally omitted.
        if port != 80 {
            buf.push(b':');
            buf.extend_from_slice(port.to_string().as_bytes());
        }
        buf.extend_from_slice(b"\r\n");
    }
    if !body.is_empty() && !user_supplied("content-length") && !user_supplied("transfer-encoding")
    {
        buf.extend_from_slice(b"Content-Length: ");
        buf.extend_from_slice(body.len().to_string().as_bytes());
        buf.extend_from_slice(b"\r\n");
    }
    if !user_supplied("connection") {
        let connection: &[u8] = if keepalive {
            b"Connection: keep-alive\r\n"
        } else {
            b"Connection: close\r\n"
        };
        buf.extend_from_slice(connection);
    }
    for (k, v) in headers {
        buf.extend_from_slice(k.as_bytes());
        buf.extend_from_slice(b": ");
        buf.extend_from_slice(v.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }
    buf.extend_from_slice(b"\r\n");
    buf.extend_from_slice(body);
    Bytes::from(buf)
}
pub async fn run(config: BenchConfig) -> (Vec<WorkerResult>, Duration) {
let progress = Arc::new(Progress::new());
let stop = Arc::new(AtomicBool::new(false));
let is_duration_mode = config.duration.is_some();
let total_display = if is_duration_mode { 0 } else { config.num_requests };
let prog = progress.clone();
let stop_r = stop.clone();
let render_handle = tokio::spawn(async move {
crate::report::render_progress(prog, total_display, is_duration_mode, stop_r).await;
});
if let Some(dur) = config.duration {
let s = stop.clone();
tokio::spawn(async move {
tokio::time::sleep(dur).await;
s.store(true, Ordering::Release);
});
}
let (host, port, path) = parse_url(&config.uri);
let addr: SocketAddr = tokio::net::lookup_host(format!("{}:{}", host, port))
.await
.expect("DNS lookup failed")
.next()
.expect("no addresses found");
let request_bytes = build_raw_request(
&config.method, &host, port, &path,
&config.headers, &config.body, !config.disable_keepalive,
);
const TARGET_CONNS: usize = 128;
let cpus = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(4);
let max_workers = cpus * 2;
let num_workers = (config.concurrency / TARGET_CONNS)
.max(1)
.min(max_workers)
.min(config.concurrency);
let c = config.concurrency;
let connect_limit = c.min(256);
let sem = Arc::new(tokio::sync::Semaphore::new(connect_limit));
let mut connect_futs = FuturesUnordered::new();
for _ in 0..c {
let sem = sem.clone();
let timeout = config.timeout;
connect_futs.push(async move {
let _permit = sem.acquire().await.ok()?;
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
Ok(Ok(stream)) => {
let _ = stream.set_nodelay(true);
Some(BufReader::with_capacity(32768, stream))
}
_ => None,
}
});
}
let mut all_conns: Vec<BufReader<TcpStream>> = Vec::with_capacity(c);
let mut connect_failures = 0u64;
while let Some(result) = connect_futs.next().await {
if let Some(conn) = result {
all_conns.push(conn);
} else {
connect_failures += 1;
}
}
drop(connect_futs);
if all_conns.is_empty() {
stop.store(true, Ordering::Release);
tokio::time::sleep(Duration::from_millis(150)).await;
let _ = render_handle.abort();
return (vec![WorkerResult {
latencies: vec![], status_codes: vec![],
errors: connect_failures,
bytes_recv: 0,
write: PhaseAcc::new(), wait: PhaseAcc::new(), read: PhaseAcc::new(),
}], Duration::ZERO);
}
let mut worker_conns: Vec<Vec<BufReader<TcpStream>>> =
(0..num_workers).map(|_| Vec::new()).collect();
for (i, conn) in all_conns.into_iter().enumerate() {
worker_conns[i % num_workers].push(conn);
}
let total_reqs = config.num_requests;
let reqs_base = if is_duration_mode { usize::MAX } else { total_reqs / num_workers };
let reqs_extra = if is_duration_mode { 0 } else { total_reqs % num_workers };
let start = Instant::now();
let mut handles = Vec::with_capacity(num_workers);
for i in 0..num_workers {
let nr = if is_duration_mode {
usize::MAX
} else {
reqs_base + if i < reqs_extra { 1 } else { 0 }
};
let conns = std::mem::take(&mut worker_conns[i]);
let rb = request_bytes.clone();
let progress = progress.clone();
let stop = stop.clone();
let timeout = config.timeout;
let keepalive = !config.disable_keepalive;
handles.push(tokio::spawn(async move {
core_worker(addr, rb, conns, nr, timeout, progress, stop, keepalive).await
}));
}
let mut results = Vec::with_capacity(num_workers);
for h in handles {
results.push(h.await.unwrap());
}
let elapsed = start.elapsed();
stop.store(true, Ordering::Release);
tokio::time::sleep(Duration::from_millis(150)).await;
let _ = render_handle.abort();
(results, elapsed)
}
/// Drives requests over a fixed set of pre-opened connections until
/// `num_requests` have been issued and completed, or `stop` is raised.
///
/// Each connection carries at most one request at a time. When a request
/// completes, the next one is issued on the connection it handed back. If the
/// connection was dropped (after an error, or because keep-alive is off), the
/// next request dials a new one. Successful responses record latency, status,
/// body size and phase timings; failures are counted. Both outcomes update the
/// shared `progress` counters. Once `stop` is observed, requests still in
/// flight are abandoned without being counted.
async fn core_worker(
    addr: SocketAddr,
    request_bytes: Bytes,
    pre_conns: Vec<BufReader<TcpStream>>,
    num_requests: usize,
    timeout: Duration,
    progress: Arc<Progress>,
    stop: Arc<AtomicBool>,
    keepalive: bool,
) -> WorkerResult {
    // Pre-size for the expected sample count, capped because duration mode
    // passes `usize::MAX` as the budget.
    let cap = num_requests.min(100_000);
    let mut latencies = Vec::with_capacity(cap);
    let mut status_codes = Vec::with_capacity(cap);
    let mut errors = 0u64;
    let mut bytes_recv = 0u64;
    let mut requests_sent = 0usize;
    let mut write_acc = PhaseAcc::new();
    let mut wait_acc = PhaseAcc::new();
    let mut read_acc = PhaseAcc::new();
    let mut in_flight: FuturesUnordered<_> = FuturesUnordered::new();

    // Prime one request per connection (connections beyond the budget are dropped).
    for conn in pre_conns {
        if requests_sent >= num_requests || stop.load(Ordering::Relaxed) {
            break;
        }
        requests_sent += 1;
        in_flight.push(do_one_request(addr, Some(conn), request_bytes.clone(), timeout, keepalive));
    }

    while let Some((conn_opt, result)) = in_flight.next().await {
        match result {
            Ok((status, size, latency_ns, timings)) => {
                latencies.push(latency_ns);
                status_codes.push(status);
                bytes_recv += size;
                write_acc.record(timings.write_ns);
                wait_acc.record(timings.wait_ns);
                read_acc.record(timings.read_ns);
                progress.total_ns.fetch_add(latency_ns, Ordering::Relaxed);
            }
            Err(_) => {
                errors += 1;
                // Keep the shared counter in sync so live progress shows failures.
                progress.errors.fetch_add(1, Ordering::Relaxed);
            }
        }
        progress.completed.fetch_add(1, Ordering::Relaxed);
        if stop.load(Ordering::Relaxed) {
            break;
        }
        // Refill the slot that just freed up, reusing its connection if any.
        if requests_sent < num_requests {
            requests_sent += 1;
            in_flight.push(do_one_request(
                addr, conn_opt, request_bytes.clone(), timeout, keepalive,
            ));
        }
    }

    WorkerResult {
        latencies, status_codes, errors, bytes_recv,
        write: write_acc, wait: wait_acc, read: read_acc,
    }
}
/// Performs a single request/response exchange.
///
/// Reuses `existing_conn` when given; otherwise it dials `addr`
/// (`TCP_NODELAY` requested, 32 KiB read buffer). The whole exchange —
/// connect, write and read — is bounded by `timeout`.
///
/// On success it returns `(status, body_bytes, latency_ns, timings)`, where the
/// latency runs from just before the connection was (re)used or dialed. The
/// connection is handed back for reuse only when the exchange succeeded and
/// `keepalive` is set. On any failure or timeout it returns `(None, Err(()))`.
async fn do_one_request(
    addr: SocketAddr,
    existing_conn: Option<BufReader<TcpStream>>,
    request_bytes: Bytes,
    timeout: Duration,
    keepalive: bool,
) -> (Option<BufReader<TcpStream>>, Result<(u16, u64, u64, RequestTimings), ()>) {
    let started = Instant::now();
    let exchange = async {
        let mut conn = if let Some(reused) = existing_conn {
            reused
        } else {
            let fresh = TcpStream::connect(addr).await.map_err(|_| ())?;
            let _ = fresh.set_nodelay(true);
            BufReader::with_capacity(32768, fresh)
        };

        let write_started = Instant::now();
        conn.get_mut().write_all(&request_bytes).await.map_err(|_| ())?;
        let write_ns = write_started.elapsed().as_nanos() as u64;

        let (status, size, wait_ns, read_ns) = read_response_timed(&mut conn).await?;
        let latency_ns = started.elapsed().as_nanos() as u64;

        let reusable = match keepalive {
            true => Some(conn),
            false => None,
        };
        Ok::<_, ()>((
            reusable,
            status,
            size,
            latency_ns,
            RequestTimings { write_ns, wait_ns, read_ns },
        ))
    };

    if let Ok(Ok((conn, status, size, latency_ns, timings))) =
        tokio::time::timeout(timeout, exchange).await
    {
        (conn, Ok((status, size, latency_ns, timings)))
    } else {
        (None, Err(()))
    }
}
/// Reads one HTTP/1.x response from `reader` and discards its body.
///
/// Returns `(status, body_len, wait_ns, read_ns)`:
/// - `wait_ns` is the time until the complete header block was available.
/// - `read_ns` is the time spent parsing the headers and draining the body.
///
/// Body framing follows RFC 9112 §6.3:
/// - 1xx, 204 and 304 responses have no body.
/// - Otherwise `Transfer-Encoding: chunked` takes precedence over
///   `Content-Length`.
/// - A response with neither is treated as having no body.
///
/// The header block may arrive split across several reads. Its bytes are
/// gathered until the blank line is seen, up to a 64 KiB limit. The common
/// case of a header block that is already fully buffered is parsed in place,
/// without copying.
///
/// Returns `Err(())` on an I/O error, premature EOF, an oversized header block
/// or an unparsable status line.
///
/// NOTE(review): responses to `HEAD` carry `Content-Length` without a body.
/// The request method is not visible here, so such responses wait for body
/// bytes until the caller's timeout.
async fn read_response_timed(
    reader: &mut BufReader<TcpStream>,
) -> Result<(u16, u64, u64, u64), ()> {
    const MAX_HEAD_BYTES: usize = 64 * 1024;

    let tw = Instant::now();
    // Header bytes gathered across reads; stays empty (no allocation) when the
    // whole header block is in the first buffer.
    let mut spill: Vec<u8> = Vec::new();
    let (head, wait_ns, tr) = loop {
        let buf = reader.fill_buf().await.map_err(|_| ())?;
        if buf.is_empty() {
            return Err(());
        }
        let n = buf.len();
        if spill.is_empty() {
            if let Some(end) = find_header_end(buf) {
                let wait_ns = tw.elapsed().as_nanos() as u64;
                let tr = Instant::now();
                let head = parse_response_head(&buf[..end]);
                reader.consume(end + 4);
                break (head, wait_ns, tr);
            }
            spill.extend_from_slice(buf);
        } else {
            // The terminator may straddle the previous read, so rescan the
            // last three spilled bytes together with the new data.
            let prev = spill.len();
            spill.extend_from_slice(buf);
            let from = prev.saturating_sub(3);
            if let Some(rel) = find_header_end(&spill[from..]) {
                let end = from + rel;
                let wait_ns = tw.elapsed().as_nanos() as u64;
                let tr = Instant::now();
                let head = parse_response_head(&spill[..end]);
                // Consume only up to the terminator; body bytes stay buffered.
                reader.consume(end + 4 - prev);
                break (head, wait_ns, tr);
            }
        }
        if spill.len() > MAX_HEAD_BYTES {
            return Err(());
        }
        reader.consume(n);
    };

    let (status, content_length, is_chunked) = head?;
    let size = if matches!(status, 100..=199 | 204 | 304) {
        0
    } else if is_chunked {
        drain_chunked(reader).await?
    } else if let Some(cl) = content_length {
        drain_exact(reader, cl).await?;
        cl as u64
    } else {
        0
    };
    let read_ns = tr.elapsed().as_nanos() as u64;
    Ok((status, size, wait_ns, read_ns))
}

/// Extracts `(status, content_length, is_chunked)` from a response head: the
/// status line plus header lines, without the terminating blank line.
fn parse_response_head(head: &[u8]) -> Result<(u16, Option<usize>, bool), ()> {
    // "HTTP/1.x NNN ...": the status code occupies bytes 9..12.
    let code = head.get(9..12).ok_or(())?;
    let status: u16 = std::str::from_utf8(code)
        .map_err(|_| ())?
        .parse()
        .map_err(|_| ())?;

    let mut content_length: Option<usize> = None;
    let mut is_chunked = false;
    for line in head.split(|&b| b == b'\n').skip(1) {
        if starts_with_ci(line, b"content-length:") {
            content_length = std::str::from_utf8(&line[15..])
                .ok()
                .and_then(|v| v.trim().parse().ok());
        } else if starts_with_ci(line, b"transfer-encoding:") {
            let value = std::str::from_utf8(&line[18..]).unwrap_or("");
            // Only the final transfer coding determines the body framing.
            let last = value.rsplit(',').next().unwrap_or("").trim();
            if last.eq_ignore_ascii_case("chunked") {
                is_chunked = true;
            }
        }
    }
    Ok((status, content_length, is_chunked))
}
/// Locates the blank line (`\r\n\r\n`) that terminates an HTTP header block.
///
/// Returns the index of the terminator's first byte, or `None` when `buf`
/// does not contain a complete terminator.
#[inline]
fn find_header_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    buf.windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)
}
/// Returns `true` when `haystack` begins with `needle`, ignoring ASCII case
/// in `haystack`.
///
/// `needle` must already be lowercase: only the haystack side is folded.
#[inline]
fn starts_with_ci(haystack: &[u8], needle: &[u8]) -> bool {
    match haystack.get(..needle.len()) {
        Some(prefix) => prefix
            .iter()
            .zip(needle)
            .all(|(&h, &n)| h.to_ascii_lowercase() == n),
        None => false,
    }
}
/// Discards exactly `remaining` bytes from `reader`, draining already-buffered
/// data before reading more.
///
/// Returns `Err(())` on an I/O error or if the stream ends before that many
/// bytes were seen.
async fn drain_exact(
    reader: &mut BufReader<TcpStream>,
    mut remaining: usize,
) -> Result<(), ()> {
    while remaining != 0 {
        let available = match reader.fill_buf().await {
            Ok(chunk) => chunk.len(),
            Err(_) => return Err(()),
        };
        if available == 0 {
            return Err(());
        }
        let step = available.min(remaining);
        reader.consume(step);
        remaining -= step;
    }
    Ok(())
}
/// Drains a chunked message body (RFC 9112 §7.1) and returns the total number
/// of decoded payload bytes.
///
/// Each chunk is `size CRLF data CRLF`. After the zero-size last chunk, any
/// trailer field lines and the terminating blank line are consumed too, so a
/// kept-alive connection is positioned at the start of the next response.
async fn drain_chunked(reader: &mut BufReader<TcpStream>) -> Result<u64, ()> {
    let mut total = 0u64;
    loop {
        let chunk_size = read_chunk_size(reader).await?;
        if chunk_size == 0 {
            break;
        }
        // Chunk data is followed by its own CRLF.
        drain_exact(reader, chunk_size.checked_add(2).ok_or(())?).await?;
        total = total.saturating_add(chunk_size as u64);
    }
    // Skip trailer fields up to the blank line that ends the message. A field
    // line never starts with CR or LF, so a line that does is the blank
    // terminator; EOF also ends the loop.
    loop {
        let buf = reader.fill_buf().await.map_err(|_| ())?;
        let at_blank = matches!(buf.first(), None | Some(&b'\r') | Some(&b'\n'));
        skip_line(reader).await?;
        if at_blank {
            break;
        }
    }
    Ok(total)
}
/// Reads one chunk-size line (`1*HEXDIG [ chunk-ext ] CRLF`) and returns the
/// chunk size.
///
/// The line may span several reads. Chunk extensions after `;` and
/// surrounding whitespace are ignored.
///
/// Returns `Err(())` on an I/O error, EOF, a line longer than 4 KiB, or a
/// size that is not valid hexadecimal. An unparsable size is never reported
/// as the zero-size last chunk.
async fn read_chunk_size(reader: &mut BufReader<TcpStream>) -> Result<usize, ()> {
    const MAX_LINE: usize = 4096;

    // Bytes of a size line that did not fit in a single buffer fill; stays
    // empty (no allocation) in the common case.
    let mut partial: Vec<u8> = Vec::new();
    loop {
        let buf = reader.fill_buf().await.map_err(|_| ())?;
        if buf.is_empty() {
            return Err(());
        }
        if let Some(nl) = buf.iter().position(|&b| b == b'\n') {
            let size = if partial.is_empty() {
                parse_chunk_size(&buf[..nl])
            } else {
                partial.extend_from_slice(&buf[..nl]);
                parse_chunk_size(&partial)
            };
            reader.consume(nl + 1);
            return size;
        }
        if partial.len() + buf.len() > MAX_LINE {
            return Err(());
        }
        partial.extend_from_slice(buf);
        let len = buf.len();
        reader.consume(len);
    }
}

/// Parses the hexadecimal size from a chunk-size line (without its `\n`).
fn parse_chunk_size(line: &[u8]) -> Result<usize, ()> {
    let text = std::str::from_utf8(line).map_err(|_| ())?;
    // Whitespace may surround the size (BWS before ';', trailing CR).
    let digits = text.split(';').next().unwrap_or("").trim();
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_hexdigit()) {
        return Err(());
    }
    usize::from_str_radix(digits, 16).map_err(|_| ())
}
/// Consumes bytes up to and including the next `\n`, reading more data when
/// the line spans several buffer fills.
///
/// Reaching EOF before a newline is not treated as an error: whatever was
/// buffered is consumed and `Ok(())` is returned. Only I/O errors yield
/// `Err(())`.
async fn skip_line(reader: &mut BufReader<TcpStream>) -> Result<(), ()> {
    loop {
        let buf = reader.fill_buf().await.map_err(|_| ())?;
        if buf.is_empty() {
            return Ok(());
        }
        if let Some(pos) = buf.iter().position(|&b| b == b'\n') {
            reader.consume(pos + 1);
            return Ok(());
        }
        // No newline yet: drop this part of the line and keep reading.
        let len = buf.len();
        reader.consume(len);
    }
}