1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::cli::{Args, Method};
6use crate::error::Error;
7
/// Condition that ends a benchmark run.
///
/// Derives `PartialEq`/`Eq` (like [`HttpMethod`] and [`HookAction`]) so configurations can be
/// compared directly; every payload type here supports equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StopCondition {
    /// Request-count limit; `BenchConfig::from_args` fills it from `Args::requests`.
    Requests(usize),
    /// Time limit; `BenchConfig::from_args` parses it from `Args::duration`.
    /// Presumably measured from the start of the run — confirm in the runner.
    Duration(Duration),
    /// Neither a request count nor a duration was given; no built-in limit applies.
    Infinite,
}
18
/// HTTP method used by the benchmark engine.
///
/// Engine-side counterpart of the CLI's `Method`; converted one-to-one via `From<Method>`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HttpMethod {
    /// `GET`
    Get,
    /// `POST`
    Post,
    /// `PUT`
    Put,
    /// `DELETE`
    Delete,
    /// `PATCH`
    Patch,
    /// `HEAD`
    Head,
    /// `OPTIONS`
    Options,
}
30
31impl From<Method> for HttpMethod {
32 fn from(m: Method) -> Self {
33 match m {
34 Method::Get => HttpMethod::Get,
35 Method::Post => HttpMethod::Post,
36 Method::Put => HttpMethod::Put,
37 Method::Delete => HttpMethod::Delete,
38 Method::Patch => HttpMethod::Patch,
39 Method::Head => HttpMethod::Head,
40 Method::Options => HttpMethod::Options,
41 }
42 }
43}
44
45#[derive(Debug, Clone)]
47pub struct RequestConfig {
48 pub url: String,
49 pub method: HttpMethod,
50 pub headers: HashMap<String, String>,
51 pub body: Option<String>,
52}
53
/// Input handed to a [`RequestGenerator`] when it is asked to build a request.
#[derive(Clone, Copy, Debug)]
pub struct RequestContext {
    /// Identifier of the worker asking for the request.
    pub worker_id: usize,
    /// Sequence number of the request — presumably per worker; confirm in the runner.
    pub request_number: usize,
}
60
/// Snapshot of run progress passed to a [`RateFunction`].
#[derive(Clone, Copy, Debug)]
pub struct RateContext {
    /// Time elapsed so far — presumably since the run started; confirm in the runner.
    pub elapsed: Duration,
    /// Requests issued so far (presumably successful + failed).
    pub total_requests: usize,
    /// Requests counted as successful so far; the success criterion is decided elsewhere.
    pub successful_requests: usize,
    /// Requests counted as failed so far.
    pub failed_requests: usize,
    /// Current rate; units are not visible here (presumably requests per second).
    pub current_rate: f64,
}
75
76pub type RequestGenerator = Arc<dyn Fn(RequestContext) -> RequestConfig + Send + Sync>;
78
79pub type RateFunction = Arc<dyn Fn(RateContext) -> f64 + Send + Sync>;
81
/// State passed to every [`BeforeRequestHook`].
#[derive(Clone, Copy, Debug)]
pub struct BeforeRequestContext {
    /// Identifier of the worker about to issue the request.
    pub worker_id: usize,
    /// Sequence number of the request — presumably per worker; confirm in the runner.
    pub request_number: usize,
    /// Time elapsed so far — presumably since the run started.
    pub elapsed: Duration,
    /// Requests issued so far (presumably run-wide, not per worker).
    pub total_requests: usize,
    /// Requests counted as successful so far.
    pub successful_requests: usize,
    /// Requests counted as failed so far.
    pub failed_requests: usize,
}
98
/// State passed to every [`AfterRequestHook`] once a request attempt has finished.
#[derive(Clone, Copy, Debug)]
pub struct AfterRequestContext {
    /// Identifier of the worker that issued the request.
    pub worker_id: usize,
    /// Sequence number of the request — presumably per worker; confirm in the runner.
    pub request_number: usize,
    /// Time elapsed so far — presumably since the run started.
    pub elapsed: Duration,
    /// Requests issued so far (presumably run-wide, not per worker).
    pub total_requests: usize,
    /// Requests counted as successful so far.
    pub successful_requests: usize,
    /// Requests counted as failed so far.
    pub failed_requests: usize,
    /// Latency measured for this request.
    pub latency: Duration,
    /// HTTP status code, or `None` when no status was obtained
    /// (presumably transport errors or timeouts — confirm in the runner).
    pub status: Option<u16>,
}
119
/// Decision returned by a request hook.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HookAction {
    /// Proceed normally.
    Continue,
    /// Stop — presumably the whole run; confirm in the runner.
    Abort,
    /// Re-issue the request — presumably bounded by `BenchConfig::max_retries`.
    Retry,
}
130
131pub type BeforeRequestHook = Arc<dyn Fn(BeforeRequestContext) -> HookAction + Send + Sync>;
133
134pub type AfterRequestHook = Arc<dyn Fn(AfterRequestContext) -> HookAction + Send + Sync>;
136
137#[derive(Clone)]
139pub enum RequestSource {
140 Static(RequestConfig),
142 Dynamic(RequestGenerator),
144}
145
146#[derive(Clone)]
148pub struct BenchConfig {
149 pub request_source: RequestSource,
150 pub concurrency: usize,
151 pub stop_condition: StopCondition,
152 pub timeout: Duration,
153 pub rate: Option<u64>,
154 pub rate_fn: Option<RateFunction>,
155 pub before_request_hooks: Vec<BeforeRequestHook>,
156 pub after_request_hooks: Vec<AfterRequestHook>,
157 pub max_retries: usize,
158}
159
160impl BenchConfig {
161 pub fn from_args(args: Args) -> Result<Self, Error> {
163 let stop_condition = match (args.requests, args.duration) {
164 (Some(n), None) => StopCondition::Requests(n),
165 (None, Some(d)) => StopCondition::Duration(parse_duration(&d)?),
166 (None, None) => StopCondition::Infinite,
167 (Some(_), Some(_)) => unreachable!("clap prevents this"),
168 };
169
170 let headers = parse_headers(&args.headers)?;
171
172 let request_config = RequestConfig {
173 url: args.url,
174 method: args.method.into(),
175 headers,
176 body: args.body,
177 };
178
179 Ok(BenchConfig {
180 request_source: RequestSource::Static(request_config),
181 concurrency: args.concurrency,
182 stop_condition,
183 timeout: Duration::from_secs(args.timeout),
184 rate: args.rate,
185 rate_fn: None,
186 before_request_hooks: Vec::new(),
187 after_request_hooks: Vec::new(),
188 max_retries: 3,
189 })
190 }
191}
192
193fn parse_duration(s: &str) -> Result<Duration, Error> {
195 let s = s.trim();
196
197 if let Some(ms) = s.strip_suffix("ms") {
198 let ms: u64 = ms.parse().map_err(|_| Error::InvalidDuration(s.to_string()))?;
199 return Ok(Duration::from_millis(ms));
200 }
201
202 if let Some(secs) = s.strip_suffix('s') {
203 let secs: u64 = secs.parse().map_err(|_| Error::InvalidDuration(s.to_string()))?;
204 return Ok(Duration::from_secs(secs));
205 }
206
207 if let Some(mins) = s.strip_suffix('m') {
208 let mins: u64 = mins.parse().map_err(|_| Error::InvalidDuration(s.to_string()))?;
209 return Ok(Duration::from_secs(mins * 60));
210 }
211
212 if let Ok(secs) = s.parse::<u64>() {
214 return Ok(Duration::from_secs(secs));
215 }
216
217 Err(Error::InvalidDuration(s.to_string()))
218}
219
220fn parse_headers(headers: &[String]) -> Result<HashMap<String, String>, Error> {
222 let mut map = HashMap::new();
223
224 for h in headers {
225 let (key, value) = h
226 .split_once(':')
227 .ok_or_else(|| Error::InvalidHeader(h.clone()))?;
228
229 map.insert(key.trim().to_string(), value.trim().to_string());
230 }
231
232 Ok(map)
233}