1use async_stream::stream;
35use futures::StreamExt;
36use std::{
37 collections::HashMap,
38 sync::{
39 atomic::{AtomicU64, Ordering},
40 Arc,
41 },
42 time::{Duration, Instant},
43};
44use tokio::sync::Mutex;
45
/// Configuration for one chunked-transfer-encoding benchmark run.
#[derive(Debug, Clone)]
pub struct ChunkedBenchConfig {
    /// URL every benchmark request is sent to.
    pub target_url: String,
    /// HTTP method used for every request.
    pub method: reqwest::Method,
    /// Number of concurrent worker tasks; validated to be >= 1 by `run`.
    pub concurrency: u32,
    /// Wall-clock length of the run; workers stop starting new requests once it elapses.
    pub duration: Duration,
    /// Size of each streamed body chunk in bytes; validated to be > 0 by `run`.
    pub chunk_size_bytes: usize,
    /// Total request body size in bytes; validated to be > 0 by `run`.
    pub total_size_bytes: usize,
    /// Pause between body chunks in milliseconds; 0 disables the pause.
    pub chunk_interval_ms: u64,
    /// Extra headers applied to every request.
    pub headers: HashMap<String, String>,
    /// When true, accept invalid TLS certificates
    /// (passed to reqwest's `danger_accept_invalid_certs`).
    pub skip_tls_verify: bool,
}
69
/// Diagnostic sample captured for a non-2xx response.
#[derive(Debug, Clone)]
pub struct ErrorSample {
    /// HTTP status code of the offending response.
    pub status: u16,
    /// Value of the `Server` response header, if present and valid UTF-8.
    pub server_header: Option<String>,
    /// First bytes of the response body (at most `ERROR_BODY_EXCERPT_BYTES`),
    /// lossily decoded to UTF-8 and whitespace-trimmed.
    pub body_excerpt: String,
}
83
/// Aggregated outcome of a benchmark run, produced by `run`.
#[derive(Debug, Clone)]
pub struct ChunkedBenchResult {
    /// Total requests attempted (completed + transport-failed).
    pub total_requests: u64,
    /// Requests that completed with an HTTP response (any status code).
    pub successful: u64,
    /// Requests that failed at the transport level (no response).
    pub failed: u64,
    /// Body bytes sent, counted as `total_size_bytes` per completed request.
    pub bytes_sent: u64,
    /// Measured wall-clock duration of the run.
    pub elapsed: Duration,
    /// `total_requests / elapsed` in requests per second.
    pub req_per_sec: f64,
    /// Per-request latencies in ms, sorted ascending; completed requests only.
    pub latencies_ms: Vec<u64>,
    /// Mean of `latencies_ms` (0.0 when empty).
    pub avg_latency_ms: f64,
    /// 50th-percentile latency in ms (0 when no samples).
    pub p50_ms: u64,
    /// 95th-percentile latency in ms (0 when no samples).
    pub p95_ms: u64,
    /// 99th-percentile latency in ms (0 when no samples).
    pub p99_ms: u64,
    /// Count of completed requests per HTTP status code.
    pub status_counts: HashMap<u16, u64>,
    /// Up to `MAX_ERROR_SAMPLES` diagnostic samples from non-2xx responses.
    pub error_samples: Vec<ErrorSample>,
}
103
/// Maximum number of `ErrorSample`s retained across the whole run.
const MAX_ERROR_SAMPLES: usize = 5;
/// Maximum number of response-body bytes kept in each error sample excerpt.
const ERROR_BODY_EXCERPT_BYTES: usize = 256;
108
109pub async fn run(cfg: ChunkedBenchConfig) -> anyhow::Result<ChunkedBenchResult> {
112 if cfg.chunk_size_bytes == 0 {
113 anyhow::bail!("chunk_size_bytes must be > 0");
114 }
115 if cfg.total_size_bytes == 0 {
116 anyhow::bail!("total_size_bytes must be > 0");
117 }
118 if cfg.concurrency == 0 {
119 anyhow::bail!("concurrency must be >= 1");
120 }
121
122 let client = reqwest::Client::builder()
123 .danger_accept_invalid_certs(cfg.skip_tls_verify)
124 .build()?;
125
126 let total_requests = Arc::new(AtomicU64::new(0));
127 let successful = Arc::new(AtomicU64::new(0));
128 let failed = Arc::new(AtomicU64::new(0));
129 let bytes_sent = Arc::new(AtomicU64::new(0));
130 let latencies: Arc<Mutex<Vec<u64>>> = Arc::new(Mutex::new(Vec::with_capacity(8192)));
131 let status_counts: Arc<Mutex<HashMap<u16, u64>>> = Arc::new(Mutex::new(HashMap::new()));
132 let error_samples: Arc<Mutex<Vec<ErrorSample>>> = Arc::new(Mutex::new(Vec::new()));
133
134 let deadline = Instant::now() + cfg.duration;
135 let started_at = Instant::now();
136
137 let mut workers = Vec::with_capacity(cfg.concurrency as usize);
138 for _ in 0..cfg.concurrency {
139 let cfg = cfg.clone();
140 let client = client.clone();
141 let total_requests = total_requests.clone();
142 let successful = successful.clone();
143 let failed = failed.clone();
144 let bytes_sent = bytes_sent.clone();
145 let latencies = latencies.clone();
146 let status_counts = status_counts.clone();
147 let error_samples = error_samples.clone();
148
149 workers.push(tokio::spawn(async move {
150 while Instant::now() < deadline {
151 let req_started = Instant::now();
152 match send_one_chunked_request(&client, &cfg).await {
153 Ok(SendResult { status, sample }) => {
154 successful.fetch_add(1, Ordering::Relaxed);
155 bytes_sent.fetch_add(cfg.total_size_bytes as u64, Ordering::Relaxed);
156 let elapsed_ms = req_started.elapsed().as_millis() as u64;
157 latencies.lock().await.push(elapsed_ms);
158 *status_counts.lock().await.entry(status).or_insert(0) += 1;
159 if let Some(s) = sample {
160 let mut g = error_samples.lock().await;
161 if g.len() < MAX_ERROR_SAMPLES {
162 g.push(s);
163 }
164 }
165 }
166 Err(_e) => {
167 failed.fetch_add(1, Ordering::Relaxed);
168 }
169 }
170 total_requests.fetch_add(1, Ordering::Relaxed);
171 }
172 }));
173 }
174
175 for w in workers {
176 let _ = w.await;
177 }
178
179 let elapsed = started_at.elapsed();
180 let total = total_requests.load(Ordering::Relaxed);
181 let mut samples: Vec<u64> = {
182 let mut g = latencies.lock().await;
183 std::mem::take(&mut *g)
184 };
185 let final_status_counts: HashMap<u16, u64> = {
186 let mut g = status_counts.lock().await;
187 std::mem::take(&mut *g)
188 };
189 let final_error_samples: Vec<ErrorSample> = {
190 let mut g = error_samples.lock().await;
191 std::mem::take(&mut *g)
192 };
193 samples.sort_unstable();
194 let avg = if samples.is_empty() {
195 0.0
196 } else {
197 samples.iter().copied().sum::<u64>() as f64 / samples.len() as f64
198 };
199 let p = |q: f64| -> u64 {
200 if samples.is_empty() {
201 return 0;
202 }
203 let idx = ((samples.len() as f64 - 1.0) * q).round() as usize;
204 samples[idx]
205 };
206
207 Ok(ChunkedBenchResult {
208 total_requests: total,
209 successful: successful.load(Ordering::Relaxed),
210 failed: failed.load(Ordering::Relaxed),
211 bytes_sent: bytes_sent.load(Ordering::Relaxed),
212 elapsed,
213 req_per_sec: if elapsed.as_secs_f64() > 0.0 {
214 total as f64 / elapsed.as_secs_f64()
215 } else {
216 0.0
217 },
218 avg_latency_ms: avg,
219 p50_ms: p(0.50),
220 p95_ms: p(0.95),
221 p99_ms: p(0.99),
222 latencies_ms: samples,
223 status_counts: final_status_counts,
224 error_samples: final_error_samples,
225 })
226}
227
/// Outcome of a single completed request: the HTTP status plus an optional
/// diagnostic sample (populated only for non-2xx responses).
struct SendResult {
    status: u16,
    sample: Option<ErrorSample>,
}
235
236async fn send_one_chunked_request(
237 client: &reqwest::Client,
238 cfg: &ChunkedBenchConfig,
239) -> anyhow::Result<SendResult> {
240 let chunk_size = cfg.chunk_size_bytes;
241 let total = cfg.total_size_bytes;
242 let interval_ms = cfg.chunk_interval_ms;
243
244 let body_stream = stream! {
248 let mut sent: usize = 0;
249 let payload = vec![b'X'; chunk_size];
250 while sent < total {
251 let next = std::cmp::min(chunk_size, total - sent);
252 let chunk = payload[..next].to_vec();
253 sent += next;
254 if interval_ms > 0 && sent < total {
255 tokio::time::sleep(Duration::from_millis(interval_ms)).await;
256 }
257 yield Ok::<_, std::io::Error>(chunk);
258 }
259 };
260
261 let body = reqwest::Body::wrap_stream(body_stream.boxed());
262
263 let mut req = client.request(cfg.method.clone(), &cfg.target_url).body(body);
264 for (k, v) in &cfg.headers {
265 req = req.header(k, v);
266 }
267 let resp = req.send().await?;
268 let status = resp.status().as_u16();
269
270 let sample = if !(200..300).contains(&status) {
275 let server_header = resp
276 .headers()
277 .get(reqwest::header::SERVER)
278 .and_then(|v| v.to_str().ok())
279 .map(str::to_owned);
280 let bytes = resp.bytes().await.unwrap_or_default();
281 let take = std::cmp::min(bytes.len(), ERROR_BODY_EXCERPT_BYTES);
282 let body_excerpt = String::from_utf8_lossy(&bytes[..take]).trim().to_owned();
283 Some(ErrorSample {
284 status,
285 server_header,
286 body_excerpt,
287 })
288 } else {
289 None
290 };
291
292 Ok(SendResult { status, sample })
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Valid baseline config aimed at an unroutable port; each test
    /// overrides exactly the field it wants to invalidate.
    fn base_cfg() -> ChunkedBenchConfig {
        ChunkedBenchConfig {
            target_url: "http://127.0.0.1:1".into(),
            method: reqwest::Method::POST,
            concurrency: 1,
            duration: Duration::from_millis(10),
            chunk_size_bytes: 1024,
            total_size_bytes: 4096,
            chunk_interval_ms: 0,
            headers: HashMap::new(),
            skip_tls_verify: false,
        }
    }

    #[test]
    fn error_sample_struct_holds_diagnostic_fields() {
        let sample = ErrorSample {
            status: 503,
            server_header: Some("nginx/1.21.0".into()),
            body_excerpt: "upstream timed out".into(),
        };
        assert_eq!(sample.status, 503);
        assert_eq!(sample.server_header.as_deref(), Some("nginx/1.21.0"));
        assert_eq!(sample.body_excerpt, "upstream timed out");
    }

    #[tokio::test]
    async fn rejects_zero_concurrency() {
        let cfg = ChunkedBenchConfig {
            concurrency: 0,
            ..base_cfg()
        };
        assert!(run(cfg).await.is_err());
    }

    #[tokio::test]
    async fn rejects_zero_chunk_size() {
        let cfg = ChunkedBenchConfig {
            chunk_size_bytes: 0,
            ..base_cfg()
        };
        assert!(run(cfg).await.is_err());
    }
}