inquisitor_core/
lib.rs

1use hdrhistogram::Histogram;
2use reqwest::ClientBuilder;
3use std::collections::HashMap;
4use std::io::Read;
5use std::path::PathBuf;
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::sync::Arc;
8use std::time::SystemTime;
9use tokio::sync::Mutex;
10
11pub mod error;
12pub use error::{InquisitorError, Result};
13
14pub mod config;
15pub use config::{Config, Method};
16
17pub mod time;
18use time::Microseconds;
19
/// Default maximum number of HTTP connections used.
///
/// NOTE(review): not referenced in this file; presumably consumed by the `config` module as the
/// default for `Config::connections` — confirm whether it is a default or a hard cap.
pub const MAX_CONNS: usize = 12;
22
23/// Set a handler to gracefully handle CTRL+C
24fn setup_ctrl_c_handler(should_exit: Arc<AtomicBool>) -> Result<()> {
25    ctrlc::set_handler(move || {
26        let previously_set = should_exit.fetch_or(true, Ordering::SeqCst);
27
28        if previously_set {
29            std::process::exit(130);
30        }
31    })?;
32
33    Ok(())
34}
35
36/// Get the certificate if specified
37fn get_ca_cert(path: &PathBuf) -> Result<reqwest::Certificate> {
38    let mut buf = Vec::new();
39    std::fs::File::open(path)
40        .map_err(|e| InquisitorError::CouldNotOpenFile(e, path.to_owned()))?
41        .read_to_end(&mut buf)
42        .map_err(|e| InquisitorError::CouldNotReadFile(e, path.to_owned()))?;
43
44    reqwest::Certificate::from_pem(&buf)
45        .map_err(|e| InquisitorError::CouldNotConvertCert(e, path.to_owned()))
46}
47
48/// Run load tests with the given configuration
49pub fn run<C: Into<Config>>(config: C) -> Result<()> {
50    let config: Config = config.into();
51
52    let should_exit = Arc::new(AtomicBool::new(false));
53    setup_ctrl_c_handler(should_exit.clone())?;
54
55    let (iterations, duration) = config.iterations_and_duration();
56
57    let mut headers = HashMap::new();
58    for header in config.header {
59        if let Some((k, v)) = header.split_once(':') {
60            headers.insert(k.to_string(), v.to_string());
61        }
62    }
63
64    // histogram of response times, recorded in microseconds
65    let times = Arc::new(Mutex::new(Histogram::<u64>::new_with_max(
66        1_000_000_000_000,
67        3, // significant digits
68    )?));
69
70    let passes = Arc::new(AtomicUsize::new(0));
71    let errors = Arc::new(AtomicUsize::new(0));
72
73    let failed_regex = match config.failed_body {
74        Some(regex) => Some(regex::Regex::new(&regex)?),
75        None => None,
76    };
77
78    let request_body = Box::leak(Box::new(config.request_body)) as &Option<_>;
79
80    let mut handles = Vec::new();
81
82    let rt = tokio::runtime::Builder::new_multi_thread()
83        .enable_io()
84        .enable_time()
85        .build()?;
86
87    let cert = if let Some(cert_path) = config.ca_cert.as_ref() {
88        Some(get_ca_cert(cert_path)?)
89    } else {
90        None
91    };
92
93    let test_start_time = SystemTime::now();
94
95    for _ in 0..config.connections {
96        let mut client = ClientBuilder::new().danger_accept_invalid_certs(config.insecure);
97
98        if let Some(cert) = cert.clone() {
99            client = client.add_root_certificate(cert);
100        }
101
102        let client = client
103            .build()
104            .map_err(InquisitorError::FailedToCreateClient)?;
105
106        let passes = passes.clone();
107        let errors = errors.clone();
108        let url = config.url.clone();
109        let headers = headers.clone();
110        let failed_regex = failed_regex.clone();
111        let times = times.clone();
112        let should_exit = should_exit.clone();
113
114        let task = rt.spawn(async move {
115            let mut total = passes.load(Ordering::Relaxed) + errors.load(Ordering::Relaxed);
116            let mut total_elapsed = test_start_time.elapsed()?.as_micros() as u64;
117
118            while total < iterations && total_elapsed < duration {
119                if should_exit.load(Ordering::Relaxed) {
120                    break;
121                }
122
123                let mut builder = match config.method {
124                    Method::Get => client.get(&url),
125                    Method::Post => client.post(&url),
126                };
127
128                if let Some(body) = request_body.as_deref() {
129                    builder = builder.body(body);
130                }
131
132                for (k, v) in &headers {
133                    builder = builder.header(k, v);
134                }
135
136                let req_start_time = SystemTime::now();
137                let response = builder.send().await;
138                let elapsed = req_start_time.elapsed()?.as_micros() as u64;
139                times.lock().await.record(elapsed)?;
140
141                match (response, failed_regex.as_ref()) {
142                    (Ok(res), _) if res.status().is_success() && failed_regex.is_none() => {
143                        passes.fetch_add(1, Ordering::SeqCst);
144                        if config.print_response {
145                            println!(
146                                "Response successful. Content: {}",
147                                res.text()
148                                    .await
149                                    .map_err(InquisitorError::FailedToReadResponseBody)?
150                            );
151                        }
152                    }
153                    (Ok(res), Some(failed_regex)) if res.status().is_success() => {
154                        let body = res
155                            .text()
156                            .await
157                            .map_err(InquisitorError::FailedToReadResponseBody)?;
158
159                        if failed_regex.is_match(&body) {
160                            if !config.hide_errors {
161                                eprintln!("Response is 200 but body indicates an error: {}", body);
162                            }
163                            errors.fetch_add(1, Ordering::SeqCst);
164                        } else {
165                            passes.fetch_add(1, Ordering::SeqCst);
166
167                            if config.print_response {
168                                println!("Response successful. Contents: {}", body);
169                            }
170                        }
171                    }
172                    (Ok(res), _) => {
173                        if !config.hide_errors {
174                            eprintln!("Response is not 200. Status code: {}", res.status());
175                        }
176                        errors.fetch_add(1, Ordering::SeqCst);
177                    }
178                    (Err(e), _) => {
179                        if !config.hide_errors {
180                            eprintln!("Request failed: {}", e);
181                        }
182                        errors.fetch_add(1, Ordering::SeqCst);
183                    }
184                };
185
186                total = passes.load(Ordering::Relaxed) + errors.load(Ordering::Relaxed);
187                total_elapsed = test_start_time.elapsed()?.as_micros() as u64;
188            }
189
190            Result::<()>::Ok(())
191        });
192
193        handles.push(task);
194    }
195
196    let times = rt.block_on(async {
197        futures::future::join_all(handles).await;
198        Result::<Histogram<u64>>::Ok(
199            Arc::try_unwrap(times)
200                .map_err(|_| InquisitorError::FailedToUnwrapArc)?
201                .into_inner(),
202        )
203    })?;
204
205    let elapsed_us = test_start_time
206        .elapsed()
207        .map_err(InquisitorError::FailedToGetTimeInterval)?
208        .as_micros() as f64;
209    print_results(
210        times,
211        elapsed_us,
212        errors.load(Ordering::Relaxed),
213        passes.load(Ordering::Relaxed),
214    );
215
216    Ok(())
217}
218
219fn print_results(times: Histogram<u64>, elapsed_us: f64, errors: usize, passes: usize) {
220    let iterations = passes + errors;
221    let rps = (iterations as f64 / (elapsed_us / 1_000_000.0)) as usize;
222
223    println!("total time: {}", Microseconds(elapsed_us));
224    print!("errors: {}/{}", errors, iterations);
225
226    if errors > 0 {
227        println!(" ({:.2}%)", (errors as f64 / iterations as f64) * 100.0);
228    } else {
229        println!();
230    }
231    println!("throughput: {} req./s", rps,);
232
233    println!(
234        "response times:\n\tmean\t{}\n\tst.dev\t{}\n\tmin\t{}\n\tmax\t{}",
235        Microseconds(times.mean()),
236        Microseconds(times.stdev()),
237        Microseconds(times.min() as f64),
238        Microseconds(times.max() as f64),
239    );
240
241    println!(
242        "latencies:\n\t50%\t{}\n\t75%\t{}\n\t90%\t{}\n\t95%\t{}\n\t99%\t{}\n\t99.9%\t{}",
243        Microseconds(times.value_at_quantile(0.5) as f64),
244        Microseconds(times.value_at_quantile(0.75) as f64),
245        Microseconds(times.value_at_quantile(0.9) as f64),
246        Microseconds(times.value_at_quantile(0.95) as f64),
247        Microseconds(times.value_at_quantile(0.99) as f64),
248        Microseconds(times.value_at_quantile(0.999) as f64),
249    );
250}