use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use indicatif::ProgressBar;
use crate::client::HttpClient;
use crate::config::{
AfterRequestHook, BeforeRequestHook, BenchConfig, HttpMethod, RateContext, RateFunction,
RequestConfig, RequestContext, RequestGenerator, RequestSource, StopCondition,
};
use crate::error::{Error, Result};
use crate::executor::Executor;
use crate::metrics::BenchmarkResults;
/// Fluent builder for a [`Benchmark`].
///
/// Requests come either from a single static request ([`url`](Self::url) plus the
/// optional [`method`](Self::method), [`header`](Self::header) and [`body`](Self::body))
/// or from a generator closure ([`request_fn`](Self::request_fn)). The two sources are
/// mutually exclusive; [`build`](Self::build) validates the combination.
#[must_use = "a BenchmarkBuilder does nothing until `build()` is called"]
pub struct BenchmarkBuilder {
    /// Static target URL; mutually exclusive with `request_fn`.
    url: Option<String>,
    /// HTTP method for the static request; `build()` falls back to GET when unset.
    method: Option<HttpMethod>,
    /// Concurrency level handed to the config and HTTP client (default 10).
    concurrency: usize,
    /// When the run stops (default `StopCondition::Infinite`).
    stop_condition: StopCondition,
    /// Headers for the static request; only valid together with `url`.
    headers: HashMap<String, String>,
    /// Body for the static request; only valid together with `url`.
    body: Option<String>,
    /// Timeout handed to the HTTP client (default 30 s).
    timeout: Duration,
    /// Fixed rate set via `rate()`; mutually exclusive with `rate_fn`.
    rate: Option<u64>,
    /// Dynamic rate callback; mutually exclusive with `rate`.
    rate_fn: Option<RateFunction>,
    /// Request generator; mutually exclusive with `url`.
    request_fn: Option<RequestGenerator>,
    /// Hooks registered via `before_request()`, in registration order.
    before_request_hooks: Vec<BeforeRequestHook>,
    /// Hooks registered via `after_request()`, in registration order.
    after_request_hooks: Vec<AfterRequestHook>,
    /// Retry limit forwarded to the config (default 3).
    max_retries: usize,
    /// Whether `build()` attaches a progress bar (default false).
    show_progress: bool,
    /// `insecure` flag forwarded to the HTTP client (default false).
    insecure: bool,
}
impl BenchmarkBuilder {
    /// Creates a builder with the default settings.
    ///
    /// Defaults: concurrency 10, `StopCondition::Infinite`, a 30 second timeout,
    /// 3 max retries, no rate limit, no hooks, no progress bar, `insecure = false`.
    pub fn new() -> Self {
        BenchmarkBuilder {
            url: None,
            method: None,
            concurrency: 10,
            stop_condition: StopCondition::Infinite,
            headers: HashMap::new(),
            body: None,
            timeout: Duration::from_secs(30),
            rate: None,
            rate_fn: None,
            request_fn: None,
            before_request_hooks: Vec::new(),
            after_request_hooks: Vec::new(),
            max_retries: 3,
            show_progress: false,
            insecure: false,
        }
    }

    /// Sets a single static target URL. Cannot be combined with [`request_fn`](Self::request_fn).
    pub fn url(mut self, url: &str) -> Self {
        self.url = Some(url.to_string());
        self
    }

    /// Sets the HTTP method of the static request (GET if never called).
    /// Only valid together with [`url`](Self::url).
    pub fn method(mut self, method: HttpMethod) -> Self {
        self.method = Some(method);
        self
    }

    /// Sets the concurrency level (default 10). [`build`](Self::build) rejects 0.
    pub fn concurrency(mut self, n: usize) -> Self {
        self.concurrency = n;
        self
    }

    /// Stops the run after `d`. Overrides any earlier [`requests`](Self::requests) call.
    pub fn duration(mut self, d: Duration) -> Self {
        self.stop_condition = StopCondition::Duration(d);
        self
    }

    /// Stops the run after `n` requests. Overrides any earlier [`duration`](Self::duration) call.
    pub fn requests(mut self, n: usize) -> Self {
        self.stop_condition = StopCondition::Requests(n);
        self
    }

    /// Sets a fixed request rate. The parameter name suggests requests per second;
    /// the exact semantics are defined by the executor.
    ///
    /// Cannot be combined with [`rate_fn`](Self::rate_fn); [`build`](Self::build) rejects 0.
    pub fn rate(mut self, rps: u64) -> Self {
        self.rate = Some(rps);
        self
    }

    /// Sets a callback that computes the request rate from a [`RateContext`].
    /// Cannot be combined with [`rate`](Self::rate).
    pub fn rate_fn<F>(mut self, f: F) -> Self
    where
        F: Fn(RateContext) -> f64 + Send + Sync + 'static,
    {
        self.rate_fn = Some(Arc::new(f));
        self
    }

    /// Adds a header to the static request; a repeated key overwrites the earlier value.
    /// Only valid together with [`url`](Self::url).
    pub fn header(mut self, key: &str, value: &str) -> Self {
        self.headers.insert(key.to_string(), value.to_string());
        self
    }

    /// Sets the body of the static request. Only valid together with [`url`](Self::url).
    pub fn body(mut self, body: &str) -> Self {
        self.body = Some(body.to_string());
        self
    }

    /// Sets the timeout passed to the HTTP client (default 30 seconds).
    pub fn timeout(mut self, d: Duration) -> Self {
        self.timeout = d;
        self
    }

    /// Sets a closure that produces a [`RequestConfig`] for each request from a
    /// [`RequestContext`].
    ///
    /// Cannot be combined with [`url`](Self::url), [`method`](Self::method),
    /// [`header`](Self::header) or [`body`](Self::body).
    pub fn request_fn<F>(mut self, f: F) -> Self
    where
        F: Fn(RequestContext) -> RequestConfig + Send + Sync + 'static,
    {
        self.request_fn = Some(Arc::new(f));
        self
    }

    /// Appends a hook to the list of before-request hooks.
    /// Multiple calls register multiple hooks.
    pub fn before_request<F>(mut self, f: F) -> Self
    where
        F: Fn(crate::config::BeforeRequestContext) -> crate::config::HookAction
            + Send
            + Sync
            + 'static,
    {
        self.before_request_hooks.push(Arc::new(f));
        self
    }

    /// Appends a hook to the list of after-request hooks.
    /// Multiple calls register multiple hooks.
    pub fn after_request<F>(mut self, f: F) -> Self
    where
        F: Fn(crate::config::AfterRequestContext) -> crate::config::HookAction
            + Send
            + Sync
            + 'static,
    {
        self.after_request_hooks.push(Arc::new(f));
        self
    }

    /// Sets the retry limit forwarded to the benchmark config (default 3).
    pub fn max_retries(mut self, n: usize) -> Self {
        self.max_retries = n;
        self
    }

    /// Enables or disables the progress bar (default disabled).
    pub fn show_progress(mut self, show: bool) -> Self {
        self.show_progress = show;
        self
    }

    /// Sets the `insecure` flag passed to the HTTP client (default `false`).
    /// Presumably this disables TLS certificate verification; confirm in `HttpClient`.
    pub fn insecure(mut self, insecure: bool) -> Self {
        self.insecure = insecure;
        self
    }

    /// Validates the configuration and produces a runnable [`Benchmark`].
    ///
    /// When [`show_progress`](Self::show_progress) is enabled, the config is wrapped
    /// via `BenchConfig::with_progress()` and the returned bar is kept on the benchmark.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidConfig`] when:
    /// - both [`rate`](Self::rate) and [`rate_fn`](Self::rate_fn) were set;
    /// - both, or neither, of [`url`](Self::url) and [`request_fn`](Self::request_fn) were set;
    /// - [`method`](Self::method), [`header`](Self::header) or [`body`](Self::body) were
    ///   combined with [`request_fn`](Self::request_fn);
    /// - the concurrency is 0;
    /// - the fixed rate is 0.
    pub fn build(self) -> Result<Benchmark> {
        if self.rate.is_some() && self.rate_fn.is_some() {
            return Err(Error::InvalidConfig(
                "Cannot use both rate() and rate_fn()".to_string(),
            ));
        }

        let request_source = match (self.url, self.request_fn) {
            (Some(_), Some(_)) => {
                return Err(Error::InvalidConfig(
                    "Cannot use both url() and request_fn()".to_string(),
                ));
            }
            (None, None) => {
                return Err(Error::InvalidConfig(
                    "Must provide either url() or request_fn()".to_string(),
                ));
            }
            (Some(url), None) => {
                let request_config = RequestConfig {
                    url,
                    method: self.method.unwrap_or(HttpMethod::Get),
                    headers: self.headers,
                    body: self.body.map(Bytes::from),
                };
                RequestSource::Static(request_config)
            }
            (None, Some(generator)) => {
                // The generator builds complete requests, so static-request
                // settings would be silently ignored; reject them instead.
                if self.method.is_some() {
                    return Err(Error::InvalidConfig(
                        "Cannot use method() with request_fn()".to_string(),
                    ));
                }
                if !self.headers.is_empty() {
                    return Err(Error::InvalidConfig(
                        "Cannot use header() with request_fn()".to_string(),
                    ));
                }
                if self.body.is_some() {
                    return Err(Error::InvalidConfig(
                        "Cannot use body() with request_fn()".to_string(),
                    ));
                }
                RequestSource::Dynamic(generator)
            }
        };

        // A concurrency of 0 or a fixed rate of 0 cannot issue any request, so the
        // run could never make progress. These checks come after the conflict checks
        // above so configurations that were already rejected report the same error.
        if self.concurrency == 0 {
            return Err(Error::InvalidConfig(
                "concurrency() must be at least 1".to_string(),
            ));
        }
        if self.rate == Some(0) {
            return Err(Error::InvalidConfig(
                "rate() must be greater than 0".to_string(),
            ));
        }

        let config = BenchConfig {
            request_source,
            concurrency: self.concurrency,
            stop_condition: self.stop_condition,
            timeout: self.timeout,
            rate: self.rate,
            rate_fn: self.rate_fn,
            before_request_hooks: self.before_request_hooks,
            after_request_hooks: self.after_request_hooks,
            max_retries: self.max_retries,
            progress_fn: None,
            insecure: self.insecure,
        };

        let (config, progress_bar) = if self.show_progress {
            let (c, pb) = config.with_progress();
            (c, Some(pb))
        } else {
            (config, None)
        };

        Ok(Benchmark {
            config,
            progress_bar,
        })
    }
}
impl Default for BenchmarkBuilder {
fn default() -> Self {
Self::new()
}
}
/// A validated benchmark produced by [`BenchmarkBuilder::build`]; execute it with `run()`.
#[must_use = "a Benchmark does nothing until `run()` is awaited"]
pub struct Benchmark {
    /// Fully assembled configuration handed to the executor.
    config: BenchConfig,
    /// Progress bar created by `BenchConfig::with_progress()` when progress display
    /// was requested; `None` otherwise.
    progress_bar: Option<Arc<ProgressBar>>,
}
impl Benchmark {
pub fn builder() -> BenchmarkBuilder {
BenchmarkBuilder::new()
}
pub async fn run(self) -> Result<BenchmarkResults> {
let client = HttpClient::new(
self.config.timeout,
self.config.concurrency,
self.config.insecure,
)?;
let executor = Executor::new(client, self.config);
let results = executor.run().await?;
if let Some(pb) = self.progress_bar {
pb.finish_and_clear();
}
Ok(results)
}
}