anti_web/
lib.rs

1use eyre::{Context, Result};
2use hdrhistogram::{sync::SyncHistogram, Histogram};
3use hyper::client::HttpConnector;
4use hyper::http::{self, Request as HttpRequest};
5use hyper::{Body, Client as HyperClient};
6use reqwest::blocking::{Client, Response};
7use reqwest::header::HeaderMap;
8use reqwest::{Method, Version};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::io::{Read, Write};
12use std::net::TcpStream;
13use std::sync::Mutex;
14use std::time::Duration;
15
/// Refresh period shared by the progress bar's steady tick and the live metrics updater.
const PROGRESS_TICK: Duration = Duration::from_micros(20_000);
17
18#[derive(Debug, Serialize, Deserialize, PartialEq)]
19pub struct FetchResponse {
20    pub status: u16,
21    pub headers: HashMap<String, String>,
22    pub body: Option<String>,
23    pub version: String,
24}
25
26pub struct StreamingResponse {
27    pub status: u16,
28    pub headers: HashMap<String, String>,
29    pub response: Response,
30    pub version: String,
31}
32
/// Protocol version a caller can pin a request to.
///
/// [`fetch`] serves `Http09` over a raw TCP socket; [`fetch_stream`] rejects it.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum HttpVersion {
    /// HTTP/0.9: a bare `GET <target>` request answered with a body-only response.
    Http09,
    /// HTTP/1.0.
    Http10,
    /// HTTP/1.1.
    Http11,
    /// HTTP/2.
    Http2,
}
40
41fn headers_to_map(headers: &HeaderMap) -> HashMap<String, String> {
42    headers
43        .iter()
44        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
45        .collect()
46}
47
48fn version_to_string(v: Version) -> String {
49    match v {
50        Version::HTTP_09 => "HTTP/0.9".into(),
51        Version::HTTP_10 => "HTTP/1.0".into(),
52        Version::HTTP_11 => "HTTP/1.1".into(),
53        Version::HTTP_2 => "HTTP/2".into(),
54        Version::HTTP_3 => "HTTP/3".into(),
55        _ => "HTTP/?".into(),
56    }
57}
58
59fn fetch_http09(url: &str) -> Result<FetchResponse> {
60    let parsed = reqwest::Url::parse(url).context("invalid url")?;
61    if parsed.scheme() != "http" {
62        eyre::bail!("HTTP/0.9 only supports http scheme");
63    }
64    let host = parsed
65        .host_str()
66        .ok_or_else(|| eyre::eyre!("missing host"))?;
67    let port = parsed.port_or_known_default().unwrap_or(80);
68    let mut stream = TcpStream::connect((host, port)).context("connect failed")?;
69    let path = if parsed.path().is_empty() {
70        "/"
71    } else {
72        parsed.path()
73    };
74    let request = format!("GET {}\r\n", path);
75    stream
76        .write_all(request.as_bytes())
77        .context("write failed")?;
78    let mut buf = Vec::new();
79    stream.read_to_end(&mut buf).context("read failed")?;
80    let body = String::from_utf8_lossy(&buf).to_string();
81    Ok(FetchResponse {
82        status: 200,
83        headers: HashMap::new(),
84        body: Some(body),
85        version: "HTTP/0.9".into(),
86    })
87}
88
89pub fn fetch(
90    url: &str,
91    method: &str,
92    headers: HashMap<String, String>,
93    body: Option<String>,
94    version: Option<HttpVersion>,
95    keep_alive: bool,
96) -> Result<FetchResponse> {
97    if matches!(version, Some(HttpVersion::Http09)) {
98        return fetch_http09(url);
99    }
100    // validate url
101    let _ = reqwest::Url::parse(url).context("invalid url")?;
102    let method = Method::from_bytes(method.as_bytes()).context("invalid HTTP method")?;
103    let mut builder = Client::builder();
104    if let Some(v) = version {
105        builder = match v {
106            HttpVersion::Http09 | HttpVersion::Http10 | HttpVersion::Http11 => builder.http1_only(),
107            HttpVersion::Http2 => builder.http2_prior_knowledge(),
108        };
109    }
110    let client = builder.build().context("failed to build client")?;
111    let mut req = client.request(method.clone(), url);
112    if let Some(v) = version {
113        req = req.version(match v {
114            HttpVersion::Http09 => Version::HTTP_09,
115            HttpVersion::Http10 => Version::HTTP_10,
116            HttpVersion::Http11 => Version::HTTP_11,
117            HttpVersion::Http2 => Version::HTTP_2,
118        });
119    }
120    for (k, v) in &headers {
121        req = req.header(k, v);
122    }
123    match version {
124        Some(HttpVersion::Http10) => {
125            if keep_alive {
126                req = req.header("Connection", "keep-alive");
127            }
128        }
129        Some(HttpVersion::Http11) => {
130            if !keep_alive {
131                req = req.header("Connection", "close");
132            }
133        }
134        _ => {}
135    }
136    if let Some(b) = body.clone() {
137        req = req.body(b);
138    }
139    let resp = req.send().context("request failed")?;
140    let status = resp.status().as_u16();
141    let headers = headers_to_map(resp.headers());
142    let version = version_to_string(resp.version());
143    let body = if method == Method::HEAD {
144        None
145    } else {
146        Some(resp.text()?)
147    };
148    Ok(FetchResponse {
149        status,
150        headers,
151        body,
152        version,
153    })
154}
155
156pub fn fetch_stream(
157    url: &str,
158    method: &str,
159    headers: HashMap<String, String>,
160    body: Option<String>,
161    version: Option<HttpVersion>,
162    keep_alive: bool,
163) -> Result<StreamingResponse> {
164    if matches!(version, Some(HttpVersion::Http09)) {
165        eyre::bail!("streaming not supported for HTTP/0.9");
166    }
167    let _ = reqwest::Url::parse(url).context("invalid url")?;
168    let method = Method::from_bytes(method.as_bytes()).context("invalid HTTP method")?;
169    let mut builder = Client::builder();
170    if let Some(v) = version {
171        builder = match v {
172            HttpVersion::Http09 | HttpVersion::Http10 | HttpVersion::Http11 => builder.http1_only(),
173            HttpVersion::Http2 => builder.http2_prior_knowledge(),
174        };
175    }
176    let client = builder.build().context("failed to build client")?;
177    let mut req = client.request(method.clone(), url);
178    if let Some(v) = version {
179        req = req.version(match v {
180            HttpVersion::Http09 => Version::HTTP_09,
181            HttpVersion::Http10 => Version::HTTP_10,
182            HttpVersion::Http11 => Version::HTTP_11,
183            HttpVersion::Http2 => Version::HTTP_2,
184        });
185    }
186    for (k, v) in &headers {
187        req = req.header(k, v);
188    }
189    match version {
190        Some(HttpVersion::Http10) => {
191            if keep_alive {
192                req = req.header("Connection", "keep-alive");
193            }
194        }
195        Some(HttpVersion::Http11) => {
196            if !keep_alive {
197                req = req.header("Connection", "close");
198            }
199        }
200        _ => {}
201    }
202    if let Some(b) = body.clone() {
203        req = req.body(b);
204    }
205    let resp = req.send().context("request failed")?;
206    let status = resp.status().as_u16();
207    let headers = headers_to_map(resp.headers());
208    let version = version_to_string(resp.version());
209    Ok(StreamingResponse {
210        status,
211        headers,
212        response: resp,
213        version,
214    })
215}
216
217pub fn to_json(resp: &FetchResponse) -> Result<String> {
218    Ok(serde_json::to_string_pretty(resp)?)
219}
220
221#[derive(Clone)]
222pub struct LoadOptions {
223    pub url: String,
224    pub method: String,
225    pub headers: HashMap<String, String>,
226    pub body: Option<String>,
227    pub version: Option<HttpVersion>,
228    pub keep_alive: bool,
229    pub requests: u32,
230    pub connections: usize,
231    pub http2_parallel: usize,
232    pub duration: Option<std::time::Duration>,
233    pub wait_after_deadline: bool,
234    pub qps: Option<u64>,
235    pub show_progress: bool,
236}
237
238#[derive(Debug, Serialize)]
239pub struct LoadResult {
240    pub total: usize,
241    pub successes: usize,
242    pub errors: usize,
243    pub status_counts: [usize; 5],
244    pub duration_secs: f64,
245    pub bytes: u64,
246    pub fastest: f64,
247    pub slowest: f64,
248    pub average: f64,
249    pub p95: f64,
250    pub p99: f64,
251}
252
253pub async fn load_test(opts: LoadOptions) -> Result<LoadResult> {
254    use std::sync::{
255        atomic::{AtomicU64, AtomicUsize, Ordering},
256        Arc,
257    };
258    use tokio::time::{interval, sleep, Duration, Instant};
259
260    let mut connector = HttpConnector::new();
261    connector.set_nodelay(true);
262    if opts.keep_alive {
263        connector.set_keepalive(Some(Duration::from_secs(60)));
264    }
265
266    let mut builder = HyperClient::builder();
267    builder.pool_max_idle_per_host(opts.connections);
268    builder.pool_idle_timeout(Duration::from_secs(60));
269    builder.http1_writev(true);
270    if matches!(opts.version, Some(HttpVersion::Http2)) {
271        builder.http2_only(true);
272        builder.http2_adaptive_window(true);
273    }
274    let client: HyperClient<_, Body> = builder.build(connector);
275    let uri: http::Uri = opts.url.parse()?;
276
277    let total = Arc::new(AtomicUsize::new(0));
278    let successes = Arc::new(AtomicUsize::new(0));
279    let errors = Arc::new(AtomicUsize::new(0));
280    let bytes_total = Arc::new(AtomicU64::new(0));
281    let histogram = Arc::new(Mutex::new(SyncHistogram::<u64>::from(
282        Histogram::<u64>::new(3).unwrap(),
283    )));
284    let status_counts = Arc::new([
285        AtomicUsize::new(0), // 1xx
286        AtomicUsize::new(0), // 2xx
287        AtomicUsize::new(0), // 3xx
288        AtomicUsize::new(0), // 4xx
289        AtomicUsize::new(0), // 5xx
290    ]);
291    let start = Instant::now();
292
293    let method: http::Method = opts.method.parse()?;
294    let headers_vec: Arc<Vec<(String, String)>> = Arc::new(
295        opts.headers
296            .iter()
297            .map(|(k, v)| (k.clone(), v.clone()))
298            .collect(),
299    );
300    let body_arc: Arc<Vec<u8>> = Arc::new(opts.body.clone().unwrap_or_default().into_bytes());
301
302    let metrics_total = total.clone();
303    let metrics_successes = successes.clone();
304    let metrics_duration = opts.duration;
305    let metrics_requests = opts.requests;
306    let metrics_start = start.clone();
307    let metrics_errors = errors.clone();
308    let metrics_status = status_counts.clone();
309
310    let pb = if !opts.show_progress {
311        None
312    } else if opts.duration.is_some() {
313        let pb = indicatif::ProgressBar::new_spinner();
314        pb.enable_steady_tick(PROGRESS_TICK);
315        Some(pb)
316    } else {
317        let pb = indicatif::ProgressBar::new(opts.requests as u64);
318        pb.set_style(
319            indicatif::ProgressStyle::with_template(
320                "{spinner:.green} {elapsed_precise} [{bar:40.cyan/blue}] {pos}/{len} rps:{msg}",
321            )
322            .unwrap(),
323        );
324        pb.enable_steady_tick(PROGRESS_TICK);
325        Some(pb)
326    };
327
328    let metrics_handle = if let Some(pb) = pb.clone() {
329        Some(tokio::spawn(async move {
330            let mut ticker = interval(PROGRESS_TICK);
331            loop {
332                ticker.tick().await;
333                let elapsed = metrics_start.elapsed().as_secs_f64();
334                let total = metrics_total.load(Ordering::SeqCst);
335                let success = metrics_successes.load(Ordering::SeqCst);
336                let err = metrics_errors.load(Ordering::SeqCst);
337                let s1 = metrics_status[0].load(Ordering::SeqCst);
338                let s2 = metrics_status[1].load(Ordering::SeqCst);
339                let s3 = metrics_status[2].load(Ordering::SeqCst);
340                let s4 = metrics_status[3].load(Ordering::SeqCst);
341                let s5 = metrics_status[4].load(Ordering::SeqCst);
342                pb.set_position(success as u64);
343                pb.set_message(format!(
344                    "{:.1} 1xx:{} 2xx:{} 3xx:{} 4xx:{} 5xx:{} err:{}",
345                    success as f64 / elapsed.max(0.0001),
346                    s1,
347                    s2,
348                    s3,
349                    s4,
350                    s5,
351                    err
352                ));
353                if let Some(d) = metrics_duration {
354                    if elapsed >= d.as_secs_f64() {
355                        break;
356                    }
357                } else if total as u32 >= metrics_requests {
358                    break;
359                }
360            }
361            pb.finish_and_clear();
362        }))
363    } else {
364        None
365    };
366
367    let workers = if matches!(opts.version, Some(HttpVersion::Http2)) {
368        opts.connections * opts.http2_parallel
369    } else {
370        opts.connections
371    };
372
373    let mut handles = Vec::new();
374    for _ in 0..workers {
375        // Cloning the client shares the underlying connection pool,
376        // which is cheaper than building a new one per worker.
377        let client = client.clone();
378        let opts = opts.clone();
379        let total = total.clone();
380        let successes = successes.clone();
381        let errors = errors.clone();
382        let status_counts = status_counts.clone();
383        let headers_vec = headers_vec.clone();
384        let body_arc = body_arc.clone();
385        let bytes_total = bytes_total.clone();
386        let uri = uri.clone();
387        let method = method.clone();
388        let mut recorder = {
389            let hist = histogram.lock().unwrap();
390            hist.recorder()
391        };
392        handles.push(tokio::spawn(async move {
393            loop {
394                let current = total.fetch_add(1, Ordering::SeqCst);
395                if let Some(dur) = opts.duration {
396                    if start.elapsed() >= dur {
397                        if !opts.wait_after_deadline {
398                            break;
399                        }
400                    }
401                } else if current as u32 >= opts.requests {
402                    break;
403                }
404
405                if let Some(qps) = opts.qps {
406                    if qps > 0 {
407                        sleep(std::time::Duration::from_secs_f64(1.0 / qps as f64)).await;
408                    }
409                }
410
411                let mut req_builder = HttpRequest::builder()
412                    .method(method.clone())
413                    .uri(uri.clone());
414                if let Some(v) = opts.version {
415                    req_builder = req_builder.version(match v {
416                        HttpVersion::Http09 => http::Version::HTTP_09,
417                        HttpVersion::Http10 => http::Version::HTTP_10,
418                        HttpVersion::Http11 => http::Version::HTTP_11,
419                        HttpVersion::Http2 => http::Version::HTTP_2,
420                    });
421                }
422                for (k, v) in headers_vec.iter() {
423                    req_builder = req_builder.header(k, v);
424                }
425                match opts.version {
426                    Some(HttpVersion::Http10) => {
427                        if opts.keep_alive {
428                            req_builder = req_builder.header("Connection", "keep-alive");
429                        }
430                    }
431                    Some(HttpVersion::Http11) => {
432                        if !opts.keep_alive {
433                            req_builder = req_builder.header("Connection", "close");
434                        }
435                    }
436                    _ => {}
437                }
438                let body = if body_arc.is_empty() {
439                    Body::empty()
440                } else {
441                    Body::from((*body_arc).clone())
442                };
443                let req = req_builder.body(body).expect("build request");
444                let start_req = Instant::now();
445                match client.request(req).await {
446                    Ok(mut resp) => {
447                        let status = resp.status().as_u16();
448                        successes.fetch_add(1, Ordering::SeqCst);
449                        use hyper::body::HttpBody;
450                        let mut body_bytes = 0u64;
451                        while let Some(chunk) = resp.body_mut().data().await {
452                            match chunk {
453                                Ok(c) => body_bytes += c.len() as u64,
454                                Err(_) => break,
455                            }
456                        }
457                        bytes_total.fetch_add(body_bytes, Ordering::SeqCst);
458                        let latency = start_req.elapsed().as_micros() as u64;
459                        let _ = recorder.record(latency);
460                        if status < 200 {
461                            status_counts[0].fetch_add(1, Ordering::SeqCst);
462                        } else if status < 300 {
463                            status_counts[1].fetch_add(1, Ordering::SeqCst);
464                        } else if status < 400 {
465                            status_counts[2].fetch_add(1, Ordering::SeqCst);
466                        } else if status < 500 {
467                            status_counts[3].fetch_add(1, Ordering::SeqCst);
468                        } else {
469                            status_counts[4].fetch_add(1, Ordering::SeqCst);
470                        }
471                    }
472                    Err(_) => {
473                        errors.fetch_add(1, Ordering::SeqCst);
474                    }
475                }
476                if let Some(dur) = opts.duration {
477                    if start.elapsed() >= dur {
478                        if !opts.wait_after_deadline {
479                            break;
480                        }
481                    }
482                } else if total.load(Ordering::SeqCst) as u32 >= opts.requests {
483                    break;
484                }
485            }
486        }));
487    }
488    for h in handles {
489        let _ = h.await;
490    }
491    if let Some(handle) = metrics_handle {
492        let _ = handle.await;
493    }
494    let total_val = total.load(Ordering::SeqCst);
495    let success_val = successes.load(Ordering::SeqCst);
496    let error_val = errors.load(Ordering::SeqCst);
497    let status_vals = [
498        status_counts[0].load(Ordering::SeqCst),
499        status_counts[1].load(Ordering::SeqCst),
500        status_counts[2].load(Ordering::SeqCst),
501        status_counts[3].load(Ordering::SeqCst),
502        status_counts[4].load(Ordering::SeqCst),
503    ];
504    let mut hist = histogram.lock().unwrap();
505    hist.refresh();
506    let bytes = bytes_total.load(Ordering::SeqCst);
507    Ok(LoadResult {
508        total: total_val,
509        successes: success_val,
510        errors: error_val,
511        status_counts: status_vals,
512        duration_secs: start.elapsed().as_secs_f64(),
513        bytes,
514        fastest: hist.min() as f64 / 1_000_000.0,
515        slowest: hist.max() as f64 / 1_000_000.0,
516        average: hist.mean() / 1_000_000.0,
517        p95: hist.value_at_quantile(0.95) as f64 / 1_000_000.0,
518        p99: hist.value_at_quantile(0.99) as f64 / 1_000_000.0,
519    })
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::service::{make_service_fn, service_fn};
    use hyper::{Body, Response, Server};
    use std::collections::HashMap;
    use std::convert::Infallible;
    use std::net::SocketAddr;
    use tokio::task::{self, JoinHandle};

    /// Spawns a hyper server on an ephemeral 127.0.0.1 port that answers every request
    /// with `reply`; `http2_only` restricts it to HTTP/2. Returns the bound address and
    /// the server task so the test can abort it when done.
    fn spawn_server(
        reply: &'static str,
        http2_only: bool,
    ) -> (SocketAddr, JoinHandle<hyper::Result<()>>) {
        let make_svc = make_service_fn(move |_conn| async move {
            Ok::<_, Infallible>(service_fn(move |_req| async move {
                Ok::<_, Infallible>(Response::new(Body::from(reply)))
            }))
        });
        let builder = Server::bind(&([127, 0, 0, 1], 0).into()).http2_only(http2_only);
        let addr = builder.local_addr();
        (addr, tokio::spawn(builder.serve(make_svc)))
    }

    #[test]
    fn test_invalid_url() {
        let outcome = fetch("ht!tp://bad", "GET", HashMap::new(), None, None, true);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_json_format() {
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "text/plain".to_string());
        let original = FetchResponse {
            status: 200,
            headers,
            body: Some("hello".into()),
            version: "HTTP/1.1".into(),
        };
        let round_tripped: FetchResponse =
            serde_json::from_str(&to_json(&original).unwrap()).unwrap();
        assert_eq!(round_tripped, original);
    }

    #[tokio::test]
    async fn test_fetch_stream_http11() {
        let (addr, server) = spawn_server("hello", false);
        let url = format!("http://{}", addr);
        let (status, version, body) = task::spawn_blocking(move || {
            let StreamingResponse {
                status,
                version,
                response,
                ..
            } = fetch_stream(&url, "GET", HashMap::new(), None, Some(HttpVersion::Http11), true)
                .unwrap();
            (status, version, response.text().unwrap())
        })
        .await
        .unwrap();

        assert_eq!(status, 200);
        assert_eq!(version, "HTTP/1.1");
        assert_eq!(body, "hello");

        server.abort();
    }

    #[tokio::test]
    async fn test_fetch_stream_http2() {
        let (addr, server) = spawn_server("world", true);
        let url = format!("http://{}", addr);
        let (status, version, body) = task::spawn_blocking(move || {
            let StreamingResponse {
                status,
                version,
                response,
                ..
            } = fetch_stream(&url, "GET", HashMap::new(), None, Some(HttpVersion::Http2), true)
                .unwrap();
            (status, version, response.text().unwrap())
        })
        .await
        .unwrap();

        assert_eq!(status, 200);
        assert_eq!(version, "HTTP/2");
        assert_eq!(body, "world");

        server.abort();
    }

    #[tokio::test]
    async fn test_load_test_basic() {
        let (addr, server) = spawn_server("ok", false);
        let opts = LoadOptions {
            url: format!("http://{}", addr),
            method: "GET".into(),
            headers: HashMap::new(),
            body: None,
            version: Some(HttpVersion::Http11),
            keep_alive: false,
            requests: 5,
            connections: 1,
            http2_parallel: 1,
            duration: None,
            wait_after_deadline: false,
            qps: None,
            show_progress: false,
        };

        let result = load_test(opts).await.unwrap();

        assert_eq!(result.total, 5);
        assert_eq!(result.successes, 5);
        assert_eq!(result.errors, 0);
        assert_eq!(result.status_counts[1], 5);

        server.abort();
    }
}