1use hdrhistogram::Histogram;
2use reqwest::ClientBuilder;
3use std::collections::HashMap;
4use std::io::Read;
5use std::path::PathBuf;
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::sync::Arc;
8use std::time::SystemTime;
9use tokio::sync::Mutex;
10
11pub mod error;
12pub use error::{InquisitorError, Result};
13
14pub mod config;
15pub use config::{Config, Method};
16
17pub mod time;
18use time::Microseconds;
19
/// Public connection-count constant (value 12).
///
/// NOTE(review): not referenced in this part of the file; presumably an upper
/// bound on `Config::connections` enforced elsewhere (e.g. in `config`) — confirm.
pub const MAX_CONNS: usize = 12;
22
23fn setup_ctrl_c_handler(should_exit: Arc<AtomicBool>) -> Result<()> {
25 ctrlc::set_handler(move || {
26 let previously_set = should_exit.fetch_or(true, Ordering::SeqCst);
27
28 if previously_set {
29 std::process::exit(130);
30 }
31 })?;
32
33 Ok(())
34}
35
36fn get_ca_cert(path: &PathBuf) -> Result<reqwest::Certificate> {
38 let mut buf = Vec::new();
39 std::fs::File::open(path)
40 .map_err(|e| InquisitorError::CouldNotOpenFile(e, path.to_owned()))?
41 .read_to_end(&mut buf)
42 .map_err(|e| InquisitorError::CouldNotReadFile(e, path.to_owned()))?;
43
44 reqwest::Certificate::from_pem(&buf)
45 .map_err(|e| InquisitorError::CouldNotConvertCert(e, path.to_owned()))
46}
47
48pub fn run<C: Into<Config>>(config: C) -> Result<()> {
50 let config: Config = config.into();
51
52 let should_exit = Arc::new(AtomicBool::new(false));
53 setup_ctrl_c_handler(should_exit.clone())?;
54
55 let (iterations, duration) = config.iterations_and_duration();
56
57 let mut headers = HashMap::new();
58 for header in config.header {
59 if let Some((k, v)) = header.split_once(':') {
60 headers.insert(k.to_string(), v.to_string());
61 }
62 }
63
64 let times = Arc::new(Mutex::new(Histogram::<u64>::new_with_max(
66 1_000_000_000_000,
67 3, )?));
69
70 let passes = Arc::new(AtomicUsize::new(0));
71 let errors = Arc::new(AtomicUsize::new(0));
72
73 let failed_regex = match config.failed_body {
74 Some(regex) => Some(regex::Regex::new(®ex)?),
75 None => None,
76 };
77
78 let request_body = Box::leak(Box::new(config.request_body)) as &Option<_>;
79
80 let mut handles = Vec::new();
81
82 let rt = tokio::runtime::Builder::new_multi_thread()
83 .enable_io()
84 .enable_time()
85 .build()?;
86
87 let cert = if let Some(cert_path) = config.ca_cert.as_ref() {
88 Some(get_ca_cert(cert_path)?)
89 } else {
90 None
91 };
92
93 let test_start_time = SystemTime::now();
94
95 for _ in 0..config.connections {
96 let mut client = ClientBuilder::new().danger_accept_invalid_certs(config.insecure);
97
98 if let Some(cert) = cert.clone() {
99 client = client.add_root_certificate(cert);
100 }
101
102 let client = client
103 .build()
104 .map_err(InquisitorError::FailedToCreateClient)?;
105
106 let passes = passes.clone();
107 let errors = errors.clone();
108 let url = config.url.clone();
109 let headers = headers.clone();
110 let failed_regex = failed_regex.clone();
111 let times = times.clone();
112 let should_exit = should_exit.clone();
113
114 let task = rt.spawn(async move {
115 let mut total = passes.load(Ordering::Relaxed) + errors.load(Ordering::Relaxed);
116 let mut total_elapsed = test_start_time.elapsed()?.as_micros() as u64;
117
118 while total < iterations && total_elapsed < duration {
119 if should_exit.load(Ordering::Relaxed) {
120 break;
121 }
122
123 let mut builder = match config.method {
124 Method::Get => client.get(&url),
125 Method::Post => client.post(&url),
126 };
127
128 if let Some(body) = request_body.as_deref() {
129 builder = builder.body(body);
130 }
131
132 for (k, v) in &headers {
133 builder = builder.header(k, v);
134 }
135
136 let req_start_time = SystemTime::now();
137 let response = builder.send().await;
138 let elapsed = req_start_time.elapsed()?.as_micros() as u64;
139 times.lock().await.record(elapsed)?;
140
141 match (response, failed_regex.as_ref()) {
142 (Ok(res), _) if res.status().is_success() && failed_regex.is_none() => {
143 passes.fetch_add(1, Ordering::SeqCst);
144 if config.print_response {
145 println!(
146 "Response successful. Content: {}",
147 res.text()
148 .await
149 .map_err(InquisitorError::FailedToReadResponseBody)?
150 );
151 }
152 }
153 (Ok(res), Some(failed_regex)) if res.status().is_success() => {
154 let body = res
155 .text()
156 .await
157 .map_err(InquisitorError::FailedToReadResponseBody)?;
158
159 if failed_regex.is_match(&body) {
160 if !config.hide_errors {
161 eprintln!("Response is 200 but body indicates an error: {}", body);
162 }
163 errors.fetch_add(1, Ordering::SeqCst);
164 } else {
165 passes.fetch_add(1, Ordering::SeqCst);
166
167 if config.print_response {
168 println!("Response successful. Contents: {}", body);
169 }
170 }
171 }
172 (Ok(res), _) => {
173 if !config.hide_errors {
174 eprintln!("Response is not 200. Status code: {}", res.status());
175 }
176 errors.fetch_add(1, Ordering::SeqCst);
177 }
178 (Err(e), _) => {
179 if !config.hide_errors {
180 eprintln!("Request failed: {}", e);
181 }
182 errors.fetch_add(1, Ordering::SeqCst);
183 }
184 };
185
186 total = passes.load(Ordering::Relaxed) + errors.load(Ordering::Relaxed);
187 total_elapsed = test_start_time.elapsed()?.as_micros() as u64;
188 }
189
190 Result::<()>::Ok(())
191 });
192
193 handles.push(task);
194 }
195
196 let times = rt.block_on(async {
197 futures::future::join_all(handles).await;
198 Result::<Histogram<u64>>::Ok(
199 Arc::try_unwrap(times)
200 .map_err(|_| InquisitorError::FailedToUnwrapArc)?
201 .into_inner(),
202 )
203 })?;
204
205 let elapsed_us = test_start_time
206 .elapsed()
207 .map_err(InquisitorError::FailedToGetTimeInterval)?
208 .as_micros() as f64;
209 print_results(
210 times,
211 elapsed_us,
212 errors.load(Ordering::Relaxed),
213 passes.load(Ordering::Relaxed),
214 );
215
216 Ok(())
217}
218
219fn print_results(times: Histogram<u64>, elapsed_us: f64, errors: usize, passes: usize) {
220 let iterations = passes + errors;
221 let rps = (iterations as f64 / (elapsed_us / 1_000_000.0)) as usize;
222
223 println!("total time: {}", Microseconds(elapsed_us));
224 print!("errors: {}/{}", errors, iterations);
225
226 if errors > 0 {
227 println!(" ({:.2}%)", (errors as f64 / iterations as f64) * 100.0);
228 } else {
229 println!();
230 }
231 println!("throughput: {} req./s", rps,);
232
233 println!(
234 "response times:\n\tmean\t{}\n\tst.dev\t{}\n\tmin\t{}\n\tmax\t{}",
235 Microseconds(times.mean()),
236 Microseconds(times.stdev()),
237 Microseconds(times.min() as f64),
238 Microseconds(times.max() as f64),
239 );
240
241 println!(
242 "latencies:\n\t50%\t{}\n\t75%\t{}\n\t90%\t{}\n\t95%\t{}\n\t99%\t{}\n\t99.9%\t{}",
243 Microseconds(times.value_at_quantile(0.5) as f64),
244 Microseconds(times.value_at_quantile(0.75) as f64),
245 Microseconds(times.value_at_quantile(0.9) as f64),
246 Microseconds(times.value_at_quantile(0.95) as f64),
247 Microseconds(times.value_at_quantile(0.99) as f64),
248 Microseconds(times.value_at_quantile(0.999) as f64),
249 );
250}