1use eyre::{Context, Result};
2use hdrhistogram::{sync::SyncHistogram, Histogram};
3use hyper::client::HttpConnector;
4use hyper::http::{self, Request as HttpRequest};
5use hyper::{Body, Client as HyperClient};
6use reqwest::blocking::{Client, Response};
7use reqwest::header::HeaderMap;
8use reqwest::{Method, Version};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::io::{Read, Write};
12use std::net::TcpStream;
13use std::sync::Mutex;
14use std::time::Duration;
15
/// Refresh interval shared by the progress bar's steady tick and the load-test
/// metrics updater that rewrites the progress message.
const PROGRESS_TICK: Duration = Duration::from_millis(20);
17
/// Fully-buffered result of [`fetch`]; serializable to JSON via [`to_json`].
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct FetchResponse {
    /// Numeric HTTP status code. Always `200` for HTTP/0.9, which has no status line.
    pub status: u16,
    /// Response headers flattened to `name -> value` strings. Empty for HTTP/0.9.
    pub headers: HashMap<String, String>,
    /// Response body as text; `None` when the request method was `HEAD`.
    pub body: Option<String>,
    /// Protocol version label such as `"HTTP/1.1"` or `"HTTP/2"`.
    pub version: String,
}
25
/// Result of [`fetch_stream`]: status and headers are already extracted, while the
/// body is left unread inside `response` so the caller can consume it incrementally.
pub struct StreamingResponse {
    /// Numeric HTTP status code.
    pub status: u16,
    /// Response headers flattened to `name -> value` strings.
    pub headers: HashMap<String, String>,
    /// Underlying blocking response whose body has not been read yet.
    pub response: Response,
    /// Protocol version label such as `"HTTP/1.1"` or `"HTTP/2"`.
    pub version: String,
}
32
/// Protocol version a request can be pinned to.
///
/// Passing `None` where an `Option<HttpVersion>` is expected leaves the client
/// builder unrestricted, i.e. the client picks the protocol itself.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HttpVersion {
    /// HTTP/0.9: served by a raw-TCP path in [`fetch`]; rejected by [`fetch_stream`].
    Http09,
    /// HTTP/1.0 over an HTTP/1-only client.
    Http10,
    /// HTTP/1.1 over an HTTP/1-only client.
    Http11,
    /// HTTP/2 using prior knowledge (the client speaks HTTP/2 from the first byte).
    Http2,
}
40
41fn headers_to_map(headers: &HeaderMap) -> HashMap<String, String> {
42 headers
43 .iter()
44 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
45 .collect()
46}
47
48fn version_to_string(v: Version) -> String {
49 match v {
50 Version::HTTP_09 => "HTTP/0.9".into(),
51 Version::HTTP_10 => "HTTP/1.0".into(),
52 Version::HTTP_11 => "HTTP/1.1".into(),
53 Version::HTTP_2 => "HTTP/2".into(),
54 Version::HTTP_3 => "HTTP/3".into(),
55 _ => "HTTP/?".into(),
56 }
57}
58
59fn fetch_http09(url: &str) -> Result<FetchResponse> {
60 let parsed = reqwest::Url::parse(url).context("invalid url")?;
61 if parsed.scheme() != "http" {
62 eyre::bail!("HTTP/0.9 only supports http scheme");
63 }
64 let host = parsed
65 .host_str()
66 .ok_or_else(|| eyre::eyre!("missing host"))?;
67 let port = parsed.port_or_known_default().unwrap_or(80);
68 let mut stream = TcpStream::connect((host, port)).context("connect failed")?;
69 let path = if parsed.path().is_empty() {
70 "/"
71 } else {
72 parsed.path()
73 };
74 let request = format!("GET {}\r\n", path);
75 stream
76 .write_all(request.as_bytes())
77 .context("write failed")?;
78 let mut buf = Vec::new();
79 stream.read_to_end(&mut buf).context("read failed")?;
80 let body = String::from_utf8_lossy(&buf).to_string();
81 Ok(FetchResponse {
82 status: 200,
83 headers: HashMap::new(),
84 body: Some(body),
85 version: "HTTP/0.9".into(),
86 })
87}
88
89pub fn fetch(
90 url: &str,
91 method: &str,
92 headers: HashMap<String, String>,
93 body: Option<String>,
94 version: Option<HttpVersion>,
95 keep_alive: bool,
96) -> Result<FetchResponse> {
97 if matches!(version, Some(HttpVersion::Http09)) {
98 return fetch_http09(url);
99 }
100 let _ = reqwest::Url::parse(url).context("invalid url")?;
102 let method = Method::from_bytes(method.as_bytes()).context("invalid HTTP method")?;
103 let mut builder = Client::builder();
104 if let Some(v) = version {
105 builder = match v {
106 HttpVersion::Http09 | HttpVersion::Http10 | HttpVersion::Http11 => builder.http1_only(),
107 HttpVersion::Http2 => builder.http2_prior_knowledge(),
108 };
109 }
110 let client = builder.build().context("failed to build client")?;
111 let mut req = client.request(method.clone(), url);
112 if let Some(v) = version {
113 req = req.version(match v {
114 HttpVersion::Http09 => Version::HTTP_09,
115 HttpVersion::Http10 => Version::HTTP_10,
116 HttpVersion::Http11 => Version::HTTP_11,
117 HttpVersion::Http2 => Version::HTTP_2,
118 });
119 }
120 for (k, v) in &headers {
121 req = req.header(k, v);
122 }
123 match version {
124 Some(HttpVersion::Http10) => {
125 if keep_alive {
126 req = req.header("Connection", "keep-alive");
127 }
128 }
129 Some(HttpVersion::Http11) => {
130 if !keep_alive {
131 req = req.header("Connection", "close");
132 }
133 }
134 _ => {}
135 }
136 if let Some(b) = body.clone() {
137 req = req.body(b);
138 }
139 let resp = req.send().context("request failed")?;
140 let status = resp.status().as_u16();
141 let headers = headers_to_map(resp.headers());
142 let version = version_to_string(resp.version());
143 let body = if method == Method::HEAD {
144 None
145 } else {
146 Some(resp.text()?)
147 };
148 Ok(FetchResponse {
149 status,
150 headers,
151 body,
152 version,
153 })
154}
155
156pub fn fetch_stream(
157 url: &str,
158 method: &str,
159 headers: HashMap<String, String>,
160 body: Option<String>,
161 version: Option<HttpVersion>,
162 keep_alive: bool,
163) -> Result<StreamingResponse> {
164 if matches!(version, Some(HttpVersion::Http09)) {
165 eyre::bail!("streaming not supported for HTTP/0.9");
166 }
167 let _ = reqwest::Url::parse(url).context("invalid url")?;
168 let method = Method::from_bytes(method.as_bytes()).context("invalid HTTP method")?;
169 let mut builder = Client::builder();
170 if let Some(v) = version {
171 builder = match v {
172 HttpVersion::Http09 | HttpVersion::Http10 | HttpVersion::Http11 => builder.http1_only(),
173 HttpVersion::Http2 => builder.http2_prior_knowledge(),
174 };
175 }
176 let client = builder.build().context("failed to build client")?;
177 let mut req = client.request(method.clone(), url);
178 if let Some(v) = version {
179 req = req.version(match v {
180 HttpVersion::Http09 => Version::HTTP_09,
181 HttpVersion::Http10 => Version::HTTP_10,
182 HttpVersion::Http11 => Version::HTTP_11,
183 HttpVersion::Http2 => Version::HTTP_2,
184 });
185 }
186 for (k, v) in &headers {
187 req = req.header(k, v);
188 }
189 match version {
190 Some(HttpVersion::Http10) => {
191 if keep_alive {
192 req = req.header("Connection", "keep-alive");
193 }
194 }
195 Some(HttpVersion::Http11) => {
196 if !keep_alive {
197 req = req.header("Connection", "close");
198 }
199 }
200 _ => {}
201 }
202 if let Some(b) = body.clone() {
203 req = req.body(b);
204 }
205 let resp = req.send().context("request failed")?;
206 let status = resp.status().as_u16();
207 let headers = headers_to_map(resp.headers());
208 let version = version_to_string(resp.version());
209 Ok(StreamingResponse {
210 status,
211 headers,
212 response: resp,
213 version,
214 })
215}
216
217pub fn to_json(resp: &FetchResponse) -> Result<String> {
218 Ok(serde_json::to_string_pretty(resp)?)
219}
220
/// Configuration for [`load_test`].
#[derive(Clone)]
pub struct LoadOptions {
    /// Target URL; must parse as an `http::Uri`.
    pub url: String,
    /// HTTP method name, parsed as an `http::Method` (e.g. `"GET"`).
    pub method: String,
    /// Extra headers added to every request.
    pub headers: HashMap<String, String>,
    /// Body sent with every request; `None` (or an empty string) sends an empty body.
    pub body: Option<String>,
    /// Protocol to pin requests to; `Some(HttpVersion::Http2)` makes the client HTTP/2-only.
    pub version: Option<HttpVersion>,
    /// Enables TCP keepalive on the connector. For HTTP/1.0 and HTTP/1.1 it also picks the
    /// `Connection` header: `keep-alive` for 1.0 when set, `close` for 1.1 when unset.
    pub keep_alive: bool,
    /// Total number of requests to issue when `duration` is `None`; ignored otherwise.
    pub requests: u32,
    /// Number of worker tasks; also used as the idle-connection pool size per host.
    pub connections: usize,
    /// HTTP/2 only: multiplier on the worker count (workers = `connections * http2_parallel`).
    pub http2_parallel: usize,
    /// When set, the run is bounded by this wall-clock time instead of by `requests`.
    pub duration: Option<std::time::Duration>,
    /// Intended to let requests still in flight at the `duration` deadline finish and be
    /// counted. Only consulted when `duration` is set; see `load_test` for exact semantics.
    pub wait_after_deadline: bool,
    /// Per-worker rate limit: each worker sleeps `1 / qps` seconds before every request.
    /// `None` or `Some(0)` disables the delay.
    pub qps: Option<u64>,
    /// Show an `indicatif` progress display (a spinner for timed runs, a bar otherwise).
    pub show_progress: bool,
}
237
/// Aggregate statistics returned by [`load_test`].
#[derive(Debug, Serialize)]
pub struct LoadResult {
    /// Final value of the run's shared request counter (requests the workers counted).
    pub total: usize,
    /// Requests for which the client obtained a response, whatever its status code.
    pub successes: usize,
    /// Requests that failed inside the HTTP client (connection, protocol or I/O errors).
    pub errors: usize,
    /// Responses bucketed by status class as `[1xx, 2xx, 3xx, 4xx, 5xx]`.
    /// Index 0 also catches anything below 200, and index 4 anything 500 and above.
    pub status_counts: [usize; 5],
    /// Wall-clock duration of the run, in seconds.
    pub duration_secs: f64,
    /// Total response-body bytes received.
    pub bytes: u64,
    /// Lowest recorded request latency, in seconds.
    pub fastest: f64,
    /// Highest recorded request latency, in seconds.
    pub slowest: f64,
    /// Mean recorded request latency, in seconds.
    pub average: f64,
    /// 95th-percentile recorded request latency, in seconds.
    pub p95: f64,
    /// 99th-percentile recorded request latency, in seconds.
    pub p99: f64,
}
252
253pub async fn load_test(opts: LoadOptions) -> Result<LoadResult> {
254 use std::sync::{
255 atomic::{AtomicU64, AtomicUsize, Ordering},
256 Arc,
257 };
258 use tokio::time::{interval, sleep, Duration, Instant};
259
260 let mut connector = HttpConnector::new();
261 connector.set_nodelay(true);
262 if opts.keep_alive {
263 connector.set_keepalive(Some(Duration::from_secs(60)));
264 }
265
266 let mut builder = HyperClient::builder();
267 builder.pool_max_idle_per_host(opts.connections);
268 builder.pool_idle_timeout(Duration::from_secs(60));
269 builder.http1_writev(true);
270 if matches!(opts.version, Some(HttpVersion::Http2)) {
271 builder.http2_only(true);
272 builder.http2_adaptive_window(true);
273 }
274 let client: HyperClient<_, Body> = builder.build(connector);
275 let uri: http::Uri = opts.url.parse()?;
276
277 let total = Arc::new(AtomicUsize::new(0));
278 let successes = Arc::new(AtomicUsize::new(0));
279 let errors = Arc::new(AtomicUsize::new(0));
280 let bytes_total = Arc::new(AtomicU64::new(0));
281 let histogram = Arc::new(Mutex::new(SyncHistogram::<u64>::from(
282 Histogram::<u64>::new(3).unwrap(),
283 )));
284 let status_counts = Arc::new([
285 AtomicUsize::new(0), AtomicUsize::new(0), AtomicUsize::new(0), AtomicUsize::new(0), AtomicUsize::new(0), ]);
291 let start = Instant::now();
292
293 let method: http::Method = opts.method.parse()?;
294 let headers_vec: Arc<Vec<(String, String)>> = Arc::new(
295 opts.headers
296 .iter()
297 .map(|(k, v)| (k.clone(), v.clone()))
298 .collect(),
299 );
300 let body_arc: Arc<Vec<u8>> = Arc::new(opts.body.clone().unwrap_or_default().into_bytes());
301
302 let metrics_total = total.clone();
303 let metrics_successes = successes.clone();
304 let metrics_duration = opts.duration;
305 let metrics_requests = opts.requests;
306 let metrics_start = start.clone();
307 let metrics_errors = errors.clone();
308 let metrics_status = status_counts.clone();
309
310 let pb = if !opts.show_progress {
311 None
312 } else if opts.duration.is_some() {
313 let pb = indicatif::ProgressBar::new_spinner();
314 pb.enable_steady_tick(PROGRESS_TICK);
315 Some(pb)
316 } else {
317 let pb = indicatif::ProgressBar::new(opts.requests as u64);
318 pb.set_style(
319 indicatif::ProgressStyle::with_template(
320 "{spinner:.green} {elapsed_precise} [{bar:40.cyan/blue}] {pos}/{len} rps:{msg}",
321 )
322 .unwrap(),
323 );
324 pb.enable_steady_tick(PROGRESS_TICK);
325 Some(pb)
326 };
327
328 let metrics_handle = if let Some(pb) = pb.clone() {
329 Some(tokio::spawn(async move {
330 let mut ticker = interval(PROGRESS_TICK);
331 loop {
332 ticker.tick().await;
333 let elapsed = metrics_start.elapsed().as_secs_f64();
334 let total = metrics_total.load(Ordering::SeqCst);
335 let success = metrics_successes.load(Ordering::SeqCst);
336 let err = metrics_errors.load(Ordering::SeqCst);
337 let s1 = metrics_status[0].load(Ordering::SeqCst);
338 let s2 = metrics_status[1].load(Ordering::SeqCst);
339 let s3 = metrics_status[2].load(Ordering::SeqCst);
340 let s4 = metrics_status[3].load(Ordering::SeqCst);
341 let s5 = metrics_status[4].load(Ordering::SeqCst);
342 pb.set_position(success as u64);
343 pb.set_message(format!(
344 "{:.1} 1xx:{} 2xx:{} 3xx:{} 4xx:{} 5xx:{} err:{}",
345 success as f64 / elapsed.max(0.0001),
346 s1,
347 s2,
348 s3,
349 s4,
350 s5,
351 err
352 ));
353 if let Some(d) = metrics_duration {
354 if elapsed >= d.as_secs_f64() {
355 break;
356 }
357 } else if total as u32 >= metrics_requests {
358 break;
359 }
360 }
361 pb.finish_and_clear();
362 }))
363 } else {
364 None
365 };
366
367 let workers = if matches!(opts.version, Some(HttpVersion::Http2)) {
368 opts.connections * opts.http2_parallel
369 } else {
370 opts.connections
371 };
372
373 let mut handles = Vec::new();
374 for _ in 0..workers {
375 let client = client.clone();
378 let opts = opts.clone();
379 let total = total.clone();
380 let successes = successes.clone();
381 let errors = errors.clone();
382 let status_counts = status_counts.clone();
383 let headers_vec = headers_vec.clone();
384 let body_arc = body_arc.clone();
385 let bytes_total = bytes_total.clone();
386 let uri = uri.clone();
387 let method = method.clone();
388 let mut recorder = {
389 let hist = histogram.lock().unwrap();
390 hist.recorder()
391 };
392 handles.push(tokio::spawn(async move {
393 loop {
394 let current = total.fetch_add(1, Ordering::SeqCst);
395 if let Some(dur) = opts.duration {
396 if start.elapsed() >= dur {
397 if !opts.wait_after_deadline {
398 break;
399 }
400 }
401 } else if current as u32 >= opts.requests {
402 break;
403 }
404
405 if let Some(qps) = opts.qps {
406 if qps > 0 {
407 sleep(std::time::Duration::from_secs_f64(1.0 / qps as f64)).await;
408 }
409 }
410
411 let mut req_builder = HttpRequest::builder()
412 .method(method.clone())
413 .uri(uri.clone());
414 if let Some(v) = opts.version {
415 req_builder = req_builder.version(match v {
416 HttpVersion::Http09 => http::Version::HTTP_09,
417 HttpVersion::Http10 => http::Version::HTTP_10,
418 HttpVersion::Http11 => http::Version::HTTP_11,
419 HttpVersion::Http2 => http::Version::HTTP_2,
420 });
421 }
422 for (k, v) in headers_vec.iter() {
423 req_builder = req_builder.header(k, v);
424 }
425 match opts.version {
426 Some(HttpVersion::Http10) => {
427 if opts.keep_alive {
428 req_builder = req_builder.header("Connection", "keep-alive");
429 }
430 }
431 Some(HttpVersion::Http11) => {
432 if !opts.keep_alive {
433 req_builder = req_builder.header("Connection", "close");
434 }
435 }
436 _ => {}
437 }
438 let body = if body_arc.is_empty() {
439 Body::empty()
440 } else {
441 Body::from((*body_arc).clone())
442 };
443 let req = req_builder.body(body).expect("build request");
444 let start_req = Instant::now();
445 match client.request(req).await {
446 Ok(mut resp) => {
447 let status = resp.status().as_u16();
448 successes.fetch_add(1, Ordering::SeqCst);
449 use hyper::body::HttpBody;
450 let mut body_bytes = 0u64;
451 while let Some(chunk) = resp.body_mut().data().await {
452 match chunk {
453 Ok(c) => body_bytes += c.len() as u64,
454 Err(_) => break,
455 }
456 }
457 bytes_total.fetch_add(body_bytes, Ordering::SeqCst);
458 let latency = start_req.elapsed().as_micros() as u64;
459 let _ = recorder.record(latency);
460 if status < 200 {
461 status_counts[0].fetch_add(1, Ordering::SeqCst);
462 } else if status < 300 {
463 status_counts[1].fetch_add(1, Ordering::SeqCst);
464 } else if status < 400 {
465 status_counts[2].fetch_add(1, Ordering::SeqCst);
466 } else if status < 500 {
467 status_counts[3].fetch_add(1, Ordering::SeqCst);
468 } else {
469 status_counts[4].fetch_add(1, Ordering::SeqCst);
470 }
471 }
472 Err(_) => {
473 errors.fetch_add(1, Ordering::SeqCst);
474 }
475 }
476 if let Some(dur) = opts.duration {
477 if start.elapsed() >= dur {
478 if !opts.wait_after_deadline {
479 break;
480 }
481 }
482 } else if total.load(Ordering::SeqCst) as u32 >= opts.requests {
483 break;
484 }
485 }
486 }));
487 }
488 for h in handles {
489 let _ = h.await;
490 }
491 if let Some(handle) = metrics_handle {
492 let _ = handle.await;
493 }
494 let total_val = total.load(Ordering::SeqCst);
495 let success_val = successes.load(Ordering::SeqCst);
496 let error_val = errors.load(Ordering::SeqCst);
497 let status_vals = [
498 status_counts[0].load(Ordering::SeqCst),
499 status_counts[1].load(Ordering::SeqCst),
500 status_counts[2].load(Ordering::SeqCst),
501 status_counts[3].load(Ordering::SeqCst),
502 status_counts[4].load(Ordering::SeqCst),
503 ];
504 let mut hist = histogram.lock().unwrap();
505 hist.refresh();
506 let bytes = bytes_total.load(Ordering::SeqCst);
507 Ok(LoadResult {
508 total: total_val,
509 successes: success_val,
510 errors: error_val,
511 status_counts: status_vals,
512 duration_secs: start.elapsed().as_secs_f64(),
513 bytes,
514 fastest: hist.min() as f64 / 1_000_000.0,
515 slowest: hist.max() as f64 / 1_000_000.0,
516 average: hist.mean() / 1_000_000.0,
517 p95: hist.value_at_quantile(0.95) as f64 / 1_000_000.0,
518 p99: hist.value_at_quantile(0.99) as f64 / 1_000_000.0,
519 })
520}
521
#[cfg(test)]
mod tests {
    //! Unit tests for request helpers plus end-to-end tests against an in-process
    //! hyper server bound to an ephemeral localhost port.
    use super::*;
    use std::collections::HashMap;
    use std::convert::Infallible;
    use hyper::{Body, Response, Server};
    use hyper::service::{make_service_fn, service_fn};
    use tokio::task;

    /// An unparsable URL must be rejected with an error rather than a panic.
    #[test]
    fn test_invalid_url() {
        assert!(fetch("ht!tp://bad", "GET", HashMap::new(), None, None, true).is_err());
    }

    /// `to_json` output must round-trip back into an identical `FetchResponse`.
    #[test]
    fn test_json_format() {
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "text/plain".to_string());
        let resp = FetchResponse {
            status: 200,
            headers,
            body: Some("hello".into()),
            version: "HTTP/1.1".into(),
        };
        let json = to_json(&resp).unwrap();
        let parsed: FetchResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, resp);
    }

    /// Streams a body over HTTP/1.1 from a local server that always answers "hello".
    #[tokio::test]
    async fn test_fetch_stream_http11() {
        let make_svc = make_service_fn(|_conn| async {
            Ok::<_, Infallible>(service_fn(|_req| async {
                Ok::<_, Infallible>(Response::new(Body::from("hello")))
            }))
        });

        // Port 0 lets the OS pick a free port; `local_addr` reports the one chosen.
        let builder = Server::bind(&([127, 0, 0, 1], 0).into());
        let addr = builder.local_addr();
        let server = builder.serve(make_svc);
        let handle = tokio::spawn(server);

        let url = format!("http://{}", addr);
        // `fetch_stream` uses reqwest's blocking client, so run it off the async
        // runtime's worker thread.
        let (status, version, body) = task::spawn_blocking(move || {
            let resp = fetch_stream(&url, "GET", HashMap::new(), None, Some(HttpVersion::Http11), true).unwrap();
            let text = resp.response.text().unwrap();
            (resp.status, resp.version, text)
        })
        .await
        .unwrap();

        assert_eq!(status, 200);
        assert_eq!(version, "HTTP/1.1");
        assert_eq!(body, "hello");

        handle.abort();
    }

    /// Same as the HTTP/1.1 test, but against an HTTP/2-only server using prior knowledge.
    #[tokio::test]
    async fn test_fetch_stream_http2() {
        let make_svc = make_service_fn(|_conn| async {
            Ok::<_, Infallible>(service_fn(|_req| async {
                Ok::<_, Infallible>(Response::new(Body::from("world")))
            }))
        });

        let builder = Server::bind(&([127, 0, 0, 1], 0).into()).http2_only(true);
        let addr = builder.local_addr();
        let server = builder.serve(make_svc);
        let handle = tokio::spawn(server);

        let url = format!("http://{}", addr);
        let (status, version, body) = task::spawn_blocking(move || {
            let resp = fetch_stream(&url, "GET", HashMap::new(), None, Some(HttpVersion::Http2), true).unwrap();
            let text = resp.response.text().unwrap();
            (resp.status, resp.version, text)
        })
        .await
        .unwrap();

        assert_eq!(status, 200);
        assert_eq!(version, "HTTP/2");
        assert_eq!(body, "world");

        handle.abort();
    }

    /// A fixed-count run of 5 requests over one HTTP/1.1 worker: all must succeed with 2xx.
    #[tokio::test]
    async fn test_load_test_basic() {
        let make_svc = make_service_fn(|_conn| async {
            Ok::<_, Infallible>(service_fn(|_req| async {
                Ok::<_, Infallible>(Response::new(Body::from("ok")))
            }))
        });

        let builder = Server::bind(&([127, 0, 0, 1], 0).into());
        let addr = builder.local_addr();
        let server = builder.serve(make_svc);
        let handle = tokio::spawn(server);

        let url = format!("http://{}", addr);
        let opts = LoadOptions {
            url,
            method: "GET".into(),
            headers: HashMap::new(),
            body: None,
            version: Some(HttpVersion::Http11),
            keep_alive: false,
            requests: 5,
            connections: 1,
            http2_parallel: 1,
            duration: None,
            wait_after_deadline: false,
            qps: None,
            show_progress: false,
        };

        let result = load_test(opts).await.unwrap();

        assert_eq!(result.total, 5);
        assert_eq!(result.successes, 5);
        assert_eq!(result.errors, 0);
        // Index 1 is the 2xx bucket.
        assert_eq!(result.status_counts[1], 5);

        handle.abort();
    }
}
648}