httpress/
config.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::cli::{Args, Method};
6use crate::error::Error;
7
/// Defines when the benchmark should stop.
///
/// This is a plain value type, so it derives `PartialEq`/`Eq` and stop
/// conditions can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StopCondition {
    /// Stop after N requests
    Requests(usize),
    /// Stop after duration
    Duration(Duration),
    /// Run until interrupted (Ctrl+C)
    Infinite,
}
18
/// HTTP verb used when issuing benchmark requests.
///
/// Each variant corresponds one-to-one with a variant of the CLI `Method` enum
/// (see the `From<Method>` conversion).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HttpMethod {
    /// `GET`
    Get,
    /// `POST`
    Post,
    /// `PUT`
    Put,
    /// `DELETE`
    Delete,
    /// `PATCH`
    Patch,
    /// `HEAD`
    Head,
    /// `OPTIONS`
    Options,
}
30
31impl From<Method> for HttpMethod {
32    fn from(m: Method) -> Self {
33        match m {
34            Method::Get => HttpMethod::Get,
35            Method::Post => HttpMethod::Post,
36            Method::Put => HttpMethod::Put,
37            Method::Delete => HttpMethod::Delete,
38            Method::Patch => HttpMethod::Patch,
39            Method::Head => HttpMethod::Head,
40            Method::Options => HttpMethod::Options,
41        }
42    }
43}
44
45/// Configuration for a single HTTP request
46#[derive(Debug, Clone)]
47pub struct RequestConfig {
48    pub url: String,
49    pub method: HttpMethod,
50    pub headers: HashMap<String, String>,
51    pub body: Option<String>,
52}
53
/// Per-request information handed to a [`RequestGenerator`] closure.
#[derive(Clone, Copy, Debug)]
pub struct RequestContext {
    /// Identifier of the worker that will issue the request.
    pub worker_id: usize,
    /// Sequential request number (assumed per-worker, as in the hook contexts).
    pub request_number: usize,
}
60
/// Snapshot of benchmark progress handed to a [`RateFunction`] closure.
#[derive(Clone, Copy, Debug)]
pub struct RateContext {
    /// Wall-clock time since the benchmark started.
    pub elapsed: Duration,
    /// Count of finished requests, successful and failed alike.
    pub total_requests: usize,
    /// Count of requests that got a 2xx status.
    pub successful_requests: usize,
    /// Count of requests that got a non-2xx status or errored out.
    pub failed_requests: usize,
    /// Rate currently in effect, provided for reference.
    pub current_rate: f64,
}
75
76/// Type alias for request generator function
77pub type RequestGenerator = Arc<dyn Fn(RequestContext) -> RequestConfig + Send + Sync>;
78
79/// Type alias for rate generator function
80pub type RateFunction = Arc<dyn Fn(RateContext) -> f64 + Send + Sync>;
81
/// Information available to a `before_request` hook ([`BeforeRequestHook`]).
#[derive(Clone, Copy, Debug)]
pub struct BeforeRequestContext {
    /// Worker about to execute the request.
    pub worker_id: usize,
    /// Position of this request in the worker's own sequence.
    pub request_number: usize,
    /// Wall-clock time since the benchmark started.
    pub elapsed: Duration,
    /// Finished requests so far, successful and failed alike.
    pub total_requests: usize,
    /// Requests so far that got a 2xx status.
    pub successful_requests: usize,
    /// Requests so far with a non-2xx status or an error.
    pub failed_requests: usize,
}
98
/// Information available to an `after_request` hook ([`AfterRequestHook`]).
#[derive(Clone, Copy, Debug)]
pub struct AfterRequestContext {
    /// Worker that executed the request.
    pub worker_id: usize,
    /// Position of this request in the worker's own sequence.
    pub request_number: usize,
    /// Wall-clock time since the benchmark started.
    pub elapsed: Duration,
    /// Finished requests so far, successful and failed alike.
    pub total_requests: usize,
    /// Requests so far that got a 2xx status.
    pub successful_requests: usize,
    /// Requests so far with a non-2xx status or an error.
    pub failed_requests: usize,
    /// How long this particular request took.
    pub latency: Duration,
    /// Response status code, or `None` when the request itself failed.
    pub status: Option<u16>,
}
119
/// Verdict returned by a request hook to steer what happens to the current request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HookAction {
    /// Carry on as normal.
    Continue,
    /// Drop this request: it is counted as failed, but the benchmark keeps running.
    Abort,
    /// Re-issue this request, bounded by the configured `max_retries`.
    Retry,
}
130
131/// Type alias for before_request hook function
132pub type BeforeRequestHook = Arc<dyn Fn(BeforeRequestContext) -> HookAction + Send + Sync>;
133
134/// Type alias for after_request hook function
135pub type AfterRequestHook = Arc<dyn Fn(AfterRequestContext) -> HookAction + Send + Sync>;
136
137/// Source of request configuration - either static or dynamically generated
138#[derive(Clone)]
139pub enum RequestSource {
140    /// Static configuration used for all requests
141    Static(RequestConfig),
142    /// Dynamic generator function called for each request
143    Dynamic(RequestGenerator),
144}
145
146/// Benchmark configuration
147#[derive(Clone)]
148pub struct BenchConfig {
149    pub request_source: RequestSource,
150    pub concurrency: usize,
151    pub stop_condition: StopCondition,
152    pub timeout: Duration,
153    pub rate: Option<u64>,
154    pub rate_fn: Option<RateFunction>,
155    pub before_request_hooks: Vec<BeforeRequestHook>,
156    pub after_request_hooks: Vec<AfterRequestHook>,
157    pub max_retries: usize,
158}
159
160impl BenchConfig {
161    /// Create config from CLI arguments
162    pub fn from_args(args: Args) -> Result<Self, Error> {
163        let stop_condition = match (args.requests, args.duration) {
164            (Some(n), None) => StopCondition::Requests(n),
165            (None, Some(d)) => StopCondition::Duration(parse_duration(&d)?),
166            (None, None) => StopCondition::Infinite,
167            (Some(_), Some(_)) => unreachable!("clap prevents this"),
168        };
169
170        let headers = parse_headers(&args.headers)?;
171
172        let request_config = RequestConfig {
173            url: args.url,
174            method: args.method.into(),
175            headers,
176            body: args.body,
177        };
178
179        Ok(BenchConfig {
180            request_source: RequestSource::Static(request_config),
181            concurrency: args.concurrency,
182            stop_condition,
183            timeout: Duration::from_secs(args.timeout),
184            rate: args.rate,
185            rate_fn: None,
186            before_request_hooks: Vec::new(),
187            after_request_hooks: Vec::new(),
188            max_retries: 3,
189        })
190    }
191}
192
193/// Parse duration string like "10s", "1m", "500ms"
194fn parse_duration(s: &str) -> Result<Duration, Error> {
195    let s = s.trim();
196
197    if let Some(ms) = s.strip_suffix("ms") {
198        let ms: u64 = ms.parse().map_err(|_| Error::InvalidDuration(s.to_string()))?;
199        return Ok(Duration::from_millis(ms));
200    }
201
202    if let Some(secs) = s.strip_suffix('s') {
203        let secs: u64 = secs.parse().map_err(|_| Error::InvalidDuration(s.to_string()))?;
204        return Ok(Duration::from_secs(secs));
205    }
206
207    if let Some(mins) = s.strip_suffix('m') {
208        let mins: u64 = mins.parse().map_err(|_| Error::InvalidDuration(s.to_string()))?;
209        return Ok(Duration::from_secs(mins * 60));
210    }
211
212    // Try parsing as plain seconds
213    if let Ok(secs) = s.parse::<u64>() {
214        return Ok(Duration::from_secs(secs));
215    }
216
217    Err(Error::InvalidDuration(s.to_string()))
218}
219
220/// Parse header strings like "Content-Type: application/json"
221fn parse_headers(headers: &[String]) -> Result<HashMap<String, String>, Error> {
222    let mut map = HashMap::new();
223
224    for h in headers {
225        let (key, value) = h
226            .split_once(':')
227            .ok_or_else(|| Error::InvalidHeader(h.clone()))?;
228
229        map.insert(key.trim().to_string(), value.trim().to_string());
230    }
231
232    Ok(map)
233}