1use std::path::PathBuf;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use tokio::sync::{Semaphore, mpsc};
6use tokio_util::sync::CancellationToken;
7use tracing::Instrument;
8use tracing::info_span;
9
10use crate::command::{Body, Command};
11use crate::config::secret::SensitiveString;
12use crate::http::{Request, RequestConfig, RequestResult};
13use crate::load_curve::LoadCurve;
14use crate::load_curve::executor::{CurveExecutor, CurveExecutorParams};
15use crate::monitoring::SpanName;
16use crate::request_template::Template;
17use crate::response_template::ResponseTemplate;
18use crate::response_template::extractor;
19use crate::response_template::field::TrackedField;
20use crate::response_template::stats::ResponseStats;
21use crate::sampling::{ReservoirAction, SamplingParams, SamplingState};
22
/// Which execution strategy produced a `RunStats` value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RunMode {
    /// A fixed number of requests issued with bounded concurrency.
    Fixed,
    /// Requests driven by a load curve of virtual users.
    Curve,
}
33
34pub struct RunStats {
37 pub elapsed: Duration,
38 pub template_duration: Option<Duration>,
39 pub response_stats: Option<ResponseStats>,
40 pub results: Vec<RequestResult>,
41 pub mode: RunMode,
42 pub curve_duration: Option<Duration>,
44 pub curve_stages: Option<Vec<crate::load_curve::Stage>>,
47 pub total_requests: usize,
49 pub total_failures: usize,
51 pub sample_rate: f64,
53 pub min_sample_rate: f64,
55}
56
57pub struct RequestSpec {
61 pub host: String,
62 pub method: crate::command::HttpMethod,
63 pub body: Option<Body>,
64 pub template_path: Option<PathBuf>,
65 pub response_template_path: Option<PathBuf>,
66 pub headers: Vec<(String, SensitiveString)>,
68}
69
/// Tuning knobs for result sampling during a run.
///
/// Both values are forwarded to `SamplingParams` by the executors:
/// `sample_threshold` becomes `vu_threshold` and `result_buffer` becomes
/// `reservoir_size`.
///
/// The type only holds two `usize` values. It therefore derives the standard
/// value-type traits, so it can be copied, compared and printed in
/// diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SamplingConfig {
    /// Virtual-user threshold handed to the sampler as `vu_threshold`.
    pub sample_threshold: usize,
    /// Maximum number of individual results retained (the reservoir size).
    pub result_buffer: usize,
}
77
78pub enum ExecutionMode {
82 Fixed {
84 request_count: usize,
85 concurrency: usize,
86 },
87 Curve(LoadCurve),
89}
90
91pub struct RunCommand {
94 pub request: RequestSpec,
95 pub execution: ExecutionMode,
96 pub sampling: SamplingConfig,
97}
98
99impl Command for RunCommand {
100 async fn execute(self) -> Result<Option<RunStats>, Box<dyn std::error::Error>> {
101 match self.execution {
102 ExecutionMode::Fixed {
103 request_count,
104 concurrency,
105 } => execute_fixed(self.request, self.sampling, request_count, concurrency).await,
106 ExecutionMode::Curve(curve) => execute_curve(self.request, self.sampling, curve).await,
107 }
108 }
109}
110
111fn resolve_tracked_fields(
114 path: Option<PathBuf>,
115) -> Result<Option<Arc<Vec<TrackedField>>>, Box<dyn std::error::Error>> {
116 path.map(|p| {
117 ResponseTemplate::parse(&p)
118 .map(|rt| Arc::new(rt.fields))
119 .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
120 })
121 .transpose()
122}
123
124fn build_request_config(
125 host: String,
126 method: crate::command::HttpMethod,
127 body: Option<Body>,
128 tracked_fields: Option<Arc<Vec<TrackedField>>>,
129 headers: Vec<(String, SensitiveString)>,
130) -> Arc<RequestConfig> {
131 Arc::new(RequestConfig {
132 client: reqwest::Client::new(),
133 host: Arc::new(host),
134 method,
135 body: Arc::new(body),
136 tracked_fields,
137 headers: Arc::new(headers),
138 })
139}
140
141fn compute_response_stats(
142 results: &[RequestResult],
143 tracked_fields: &Option<Arc<Vec<TrackedField>>>,
144) -> Option<ResponseStats> {
145 tracked_fields.as_ref().map(|fields| {
146 let mut rs = ResponseStats::new();
147 for result in results {
148 if let Some(ref body_str) = result.response_body
149 && let Ok(body_val) = serde_json::from_str(body_str)
150 {
151 rs.record(extractor::extract(&body_val, fields));
152 }
153 }
154 rs
155 })
156}
157
158async fn execute_fixed(
162 request_spec: RequestSpec,
163 sampling: SamplingConfig,
164 total: usize,
165 concurrency: usize,
166) -> Result<Option<RunStats>, Box<dyn std::error::Error>> {
167 let RequestSpec {
168 host,
169 method,
170 body,
171 template_path,
172 response_template_path,
173 headers,
174 } = request_spec;
175
176 let gen_start = Instant::now();
178 let all_bodies: Option<Vec<String>> = template_path
179 .map(|path| {
180 let template = Template::parse(&path)?;
181 let bodies = template.pre_generate(total);
182 Ok::<Vec<String>, Box<dyn std::error::Error>>(bodies)
183 })
184 .transpose()?;
185 let template_duration = all_bodies.as_ref().map(|_| gen_start.elapsed());
186
187 let tracked_fields = resolve_tracked_fields(response_template_path)?;
188 let request = build_request_config(host, method, body, tracked_fields, headers);
189
190 let token = CancellationToken::new();
191 let cancel = token.clone();
192 tokio::spawn(async move {
193 tokio::signal::ctrl_c()
194 .await
195 .expect("failed to listen for ctrl_c");
196 eprintln!("\nShutdown signal received — waiting for in-flight requests to finish...");
197 cancel.cancel();
198 });
199
200 let started_at = Instant::now();
201
202 let sample_threshold = sampling.sample_threshold;
203 let result_buffer = sampling.result_buffer;
204
205 let plain_headers: Arc<Vec<(String, String)>> = Arc::new(
207 request
208 .headers
209 .iter()
210 .map(|(k, v)| (k.clone(), v.to_string()))
211 .collect(),
212 );
213
214 let (all_results, sampling_state) = async {
215 let sem = Arc::new(Semaphore::new(concurrency));
216 let (tx, mut rx) = mpsc::channel::<RequestResult>(concurrency);
217
218 for i in 0..total {
219 let resolved = request.resolve_body(all_bodies.as_ref().map(|bs| bs[i].clone()));
220
221 let client = request.client.clone();
222 let url = request.host.as_str().to_string();
223 let method = request.method;
224 let capture_body = request.tracked_fields.is_some();
225 let headers = Arc::clone(&plain_headers);
226 let tx = tx.clone();
227
228 tokio::select! {
229 _ = token.cancelled() => break,
230 permit = sem.clone().acquire_owned() => {
231 let permit = permit.unwrap();
232 tokio::spawn(async move {
233 let _permit = permit;
234 let mut req = Request::new(client, url, method);
235 if let Some((content, content_type)) = resolved {
236 req = req.body(content, content_type);
237 }
238 if capture_body {
239 req = req.read_response();
240 }
241 if !headers.is_empty() {
242 req = req.headers((*headers).clone());
243 }
244 let _ = tx.send(req.execute().await).await;
245 });
246 }
247 }
248 }
249
250 drop(tx);
252
253 let mut sampling_state = SamplingState::new(SamplingParams {
254 vu_threshold: sample_threshold,
255 reservoir_size: result_buffer,
256 });
257 sampling_state.set_active_vus(concurrency);
259
260 let mut results: Vec<RequestResult> = Vec::with_capacity(total.min(result_buffer));
261 while let Some(result) = rx.recv().await {
262 sampling_state.record_request(result.success);
263 if sampling_state.should_collect() {
264 match sampling_state.reservoir_slot(results.len()) {
265 ReservoirAction::Push => results.push(result),
266 ReservoirAction::Replace(idx) => results[idx] = result,
267 ReservoirAction::Discard => {}
268 }
269 }
270 }
271 (results, sampling_state)
272 }
273 .instrument(info_span!(SpanName::REQUESTS, total))
274 .await;
275
276 let response_stats = compute_response_stats(&all_results, &request.tracked_fields);
277
278 Ok(Some(RunStats {
279 elapsed: started_at.elapsed(),
280 template_duration,
281 response_stats,
282 results: all_results,
283 mode: RunMode::Fixed,
284 curve_duration: None,
285 curve_stages: None,
286 total_requests: sampling_state.total_requests(),
287 total_failures: sampling_state.total_failures(),
288 sample_rate: sampling_state.sample_rate(),
289 min_sample_rate: sampling_state.min_sample_rate(),
290 }))
291}
292
293async fn execute_curve(
297 request_spec: RequestSpec,
298 sampling: SamplingConfig,
299 curve: LoadCurve,
300) -> Result<Option<RunStats>, Box<dyn std::error::Error>> {
301 let RequestSpec {
302 host,
303 method,
304 body,
305 template_path,
306 response_template_path,
307 headers,
308 } = request_spec;
309 let curve_duration = curve.total_duration();
310 let curve_stages = curve.stages.clone();
311
312 let template: Option<Arc<Template>> = template_path
314 .map(|path| Template::parse(&path).map(Arc::new))
315 .transpose()
316 .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
317
318 let tracked_fields = resolve_tracked_fields(response_template_path)?;
319 let request_config = build_request_config(host, method, body, tracked_fields, headers);
320
321 let cancellation_token = CancellationToken::new();
322 let cancel = cancellation_token.clone();
323 tokio::spawn(async move {
324 tokio::signal::ctrl_c()
325 .await
326 .expect("failed to listen for ctrl_c");
327 eprintln!("\nShutdown signal received — cancelling curve execution...");
328 cancel.cancel();
329 });
330
331 let started_at = Instant::now();
332
333 let executor = CurveExecutor::new(CurveExecutorParams {
334 curve,
335 request_config: Arc::clone(&request_config),
336 template,
337 cancellation_token,
338 sampling: SamplingParams {
339 vu_threshold: sampling.sample_threshold,
340 reservoir_size: sampling.result_buffer,
341 },
342 });
343
344 let curve_result = executor.execute().await;
345
346 let response_stats =
347 compute_response_stats(&curve_result.results, &request_config.tracked_fields);
348
349 Ok(Some(RunStats {
350 elapsed: started_at.elapsed(),
351 template_duration: None,
352 response_stats,
353 results: curve_result.results,
354 mode: RunMode::Curve,
355 curve_duration: Some(curve_duration),
356 curve_stages: Some(curve_stages),
357 total_requests: curve_result.total_requests,
358 total_failures: curve_result.total_failures,
359 sample_rate: curve_result.sample_rate,
360 min_sample_rate: curve_result.min_sample_rate,
361 }))
362}
363
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::{RunMode, RunStats};
    use crate::load_curve::{RampType, Stage};

    /// Fixed-mode fixture: 10 requests, no failures, no curve metadata.
    fn make_stats_fixed() -> RunStats {
        RunStats {
            mode: RunMode::Fixed,
            elapsed: Duration::from_secs(1),
            template_duration: None,
            response_stats: None,
            results: Vec::new(),
            curve_duration: None,
            curve_stages: None,
            total_requests: 10,
            total_failures: 0,
            sample_rate: 1.0,
            min_sample_rate: 1.0,
        }
    }

    /// Curve-mode fixture over a 10 s curve, carrying the supplied `stages`.
    fn make_stats_curve(stages: Vec<Stage>) -> RunStats {
        RunStats {
            mode: RunMode::Curve,
            elapsed: Duration::from_secs(10),
            template_duration: None,
            response_stats: None,
            results: Vec::new(),
            curve_duration: Some(Duration::from_secs(10)),
            curve_stages: Some(stages),
            total_requests: 100,
            total_failures: 2,
            sample_rate: 1.0,
            min_sample_rate: 1.0,
        }
    }

    #[test]
    fn curve_stages_none_for_fixed_mode() {
        let RunStats { curve_stages, .. } = make_stats_fixed();
        assert!(
            curve_stages.is_none(),
            "fixed-mode RunStats must have curve_stages == None"
        );
    }

    #[test]
    fn curve_stages_some_for_curve_mode() {
        let linear = Stage {
            duration: Duration::from_secs(5),
            target_vus: 50,
            ramp: RampType::Linear,
        };
        let step = Stage {
            duration: Duration::from_secs(5),
            target_vus: 100,
            ramp: RampType::Step,
        };

        let Some(stored) = make_stats_curve(vec![linear, step]).curve_stages else {
            panic!("curve_stages must be Some in curve mode");
        };
        let [first, second] = stored.as_slice() else {
            panic!("expected 2 stored stages, got {}", stored.len());
        };

        assert_eq!(first.target_vus, 50);
        assert_eq!(first.ramp, RampType::Linear);
        assert_eq!(second.target_vus, 100);
        assert_eq!(second.ramp, RampType::Step);
    }

    #[test]
    fn curve_stages_count_matches_original() {
        let stages: Vec<Stage> = (1..=5)
            .map(|n| Stage {
                duration: Duration::from_secs(10),
                target_vus: n * 20,
                ramp: RampType::Linear,
            })
            .collect();
        let expected = stages.len();

        let stored = make_stats_curve(stages)
            .curve_stages
            .expect("curve_stages must be Some in curve mode");
        assert_eq!(
            stored.len(),
            expected,
            "stored stage count must equal original stage count"
        );
    }

    #[test]
    fn run_mode_fixed_variant() {
        assert_eq!(make_stats_fixed().mode, RunMode::Fixed);
    }

    #[test]
    fn run_mode_curve_variant() {
        let single = vec![Stage {
            duration: Duration::from_secs(5),
            target_vus: 10,
            ramp: RampType::Linear,
        }];
        assert_eq!(make_stats_curve(single).mode, RunMode::Curve);
    }
}
484}