Skip to main content

lmn_core/command/
run.rs

1use std::path::PathBuf;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use tokio::sync::{Semaphore, mpsc};
6use tokio_util::sync::CancellationToken;
7use tracing::Instrument;
8use tracing::info_span;
9
10use crate::command::{Body, Command};
11use crate::config::secret::SensitiveString;
12use crate::http::{Request, RequestConfig, RequestResult};
13use crate::load_curve::LoadCurve;
14use crate::load_curve::executor::{CurveExecutor, CurveExecutorParams};
15use crate::monitoring::SpanName;
16use crate::request_template::Template;
17use crate::response_template::ResponseTemplate;
18use crate::response_template::extractor;
19use crate::response_template::field::TrackedField;
20use crate::response_template::stats::ResponseStats;
21use crate::sampling::{ReservoirAction, SamplingParams, SamplingState};
22
// ── RunMode ───────────────────────────────────────────────────────────────────

/// Records which execution strategy produced a set of run results.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RunMode {
    /// Fixed request count with concurrency bounded by a semaphore.
    Fixed,
    /// Duration-driven execution whose virtual-user count follows a `LoadCurve`.
    Curve,
}
33
34// ── RunStats ──────────────────────────────────────────────────────────────────
35
36pub struct RunStats {
37    pub elapsed: Duration,
38    pub template_duration: Option<Duration>,
39    pub response_stats: Option<ResponseStats>,
40    pub results: Vec<RequestResult>,
41    pub mode: RunMode,
42    /// Total curve duration (only meaningful when `mode == RunMode::Curve`).
43    pub curve_duration: Option<Duration>,
44    /// Curve stages captured from the `LoadCurve` after execution.
45    /// `Some` only when `mode == RunMode::Curve`.
46    pub curve_stages: Option<Vec<crate::load_curve::Stage>>,
47    /// Actual (unsampled) total request count.
48    pub total_requests: usize,
49    /// Actual (unsampled) failure count.
50    pub total_failures: usize,
51    /// Final VU-threshold sample rate at end of run (1.0 = no threshold sampling).
52    pub sample_rate: f64,
53    /// Lowest sample rate observed at any point during the run.
54    pub min_sample_rate: f64,
55}
56
57// ── RequestSpec ───────────────────────────────────────────────────────────────
58
59/// All request-level parameters for a run.
60pub struct RequestSpec {
61    pub host: String,
62    pub method: crate::command::HttpMethod,
63    pub body: Option<Body>,
64    pub template_path: Option<PathBuf>,
65    pub response_template_path: Option<PathBuf>,
66    /// Custom HTTP headers to send with every request in this run.
67    pub headers: Vec<(String, SensitiveString)>,
68}
69
// ── SamplingConfig ────────────────────────────────────────────────────────────

/// Sampling and reservoir parameters for a run.
///
/// Plain `Copy` data; both execution paths forward these values into
/// `SamplingParams`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SamplingConfig {
    /// VU threshold for result sampling (forwarded as `SamplingParams::vu_threshold`).
    pub sample_threshold: usize,
    /// Size of the result reservoir (forwarded as `SamplingParams::reservoir_size`).
    pub result_buffer: usize,
}
77
78// ── ExecutionMode ─────────────────────────────────────────────────────────────
79
80/// Determines the execution strategy for a run.
81pub enum ExecutionMode {
82    /// Classic semaphore-based fixed-count execution.
83    Fixed {
84        request_count: usize,
85        concurrency: usize,
86    },
87    /// Time-based dynamic VU execution driven by a `LoadCurve`.
88    Curve(LoadCurve),
89}
90
91// ── RunCommand ────────────────────────────────────────────────────────────────
92
93pub struct RunCommand {
94    pub request: RequestSpec,
95    pub execution: ExecutionMode,
96    pub sampling: SamplingConfig,
97}
98
99impl Command for RunCommand {
100    async fn execute(self) -> Result<Option<RunStats>, Box<dyn std::error::Error>> {
101        match self.execution {
102            ExecutionMode::Fixed {
103                request_count,
104                concurrency,
105            } => execute_fixed(self.request, self.sampling, request_count, concurrency).await,
106            ExecutionMode::Curve(curve) => execute_curve(self.request, self.sampling, curve).await,
107        }
108    }
109}
110
111// ── Shared helpers ────────────────────────────────────────────────────────────
112
113fn resolve_tracked_fields(
114    path: Option<PathBuf>,
115) -> Result<Option<Arc<Vec<TrackedField>>>, Box<dyn std::error::Error>> {
116    path.map(|p| {
117        ResponseTemplate::parse(&p)
118            .map(|rt| Arc::new(rt.fields))
119            .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
120    })
121    .transpose()
122}
123
124fn build_request_config(
125    host: String,
126    method: crate::command::HttpMethod,
127    body: Option<Body>,
128    tracked_fields: Option<Arc<Vec<TrackedField>>>,
129    headers: Vec<(String, SensitiveString)>,
130) -> Arc<RequestConfig> {
131    Arc::new(RequestConfig {
132        client: reqwest::Client::new(),
133        host: Arc::new(host),
134        method,
135        body: Arc::new(body),
136        tracked_fields,
137        headers: Arc::new(headers),
138    })
139}
140
141fn compute_response_stats(
142    results: &[RequestResult],
143    tracked_fields: &Option<Arc<Vec<TrackedField>>>,
144) -> Option<ResponseStats> {
145    tracked_fields.as_ref().map(|fields| {
146        let mut rs = ResponseStats::new();
147        for result in results {
148            if let Some(ref body_str) = result.response_body
149                && let Ok(body_val) = serde_json::from_str(body_str)
150            {
151                rs.record(extractor::extract(&body_val, fields));
152            }
153        }
154        rs
155    })
156}
157
158// ── execute_fixed ─────────────────────────────────────────────────────────────
159
160/// Fixed-count semaphore-based execution path (original behaviour, unchanged).
161async fn execute_fixed(
162    request_spec: RequestSpec,
163    sampling: SamplingConfig,
164    total: usize,
165    concurrency: usize,
166) -> Result<Option<RunStats>, Box<dyn std::error::Error>> {
167    let RequestSpec {
168        host,
169        method,
170        body,
171        template_path,
172        response_template_path,
173        headers,
174    } = request_spec;
175
176    // Pre-generate all template bodies before any requests fire
177    let gen_start = Instant::now();
178    let all_bodies: Option<Vec<String>> = template_path
179        .map(|path| {
180            let template = Template::parse(&path)?;
181            let bodies = template.pre_generate(total);
182            Ok::<Vec<String>, Box<dyn std::error::Error>>(bodies)
183        })
184        .transpose()?;
185    let template_duration = all_bodies.as_ref().map(|_| gen_start.elapsed());
186
187    let tracked_fields = resolve_tracked_fields(response_template_path)?;
188    let request = build_request_config(host, method, body, tracked_fields, headers);
189
190    let token = CancellationToken::new();
191    let cancel = token.clone();
192    tokio::spawn(async move {
193        tokio::signal::ctrl_c()
194            .await
195            .expect("failed to listen for ctrl_c");
196        eprintln!("\nShutdown signal received — waiting for in-flight requests to finish...");
197        cancel.cancel();
198    });
199
200    let started_at = Instant::now();
201
202    let sample_threshold = sampling.sample_threshold;
203    let result_buffer = sampling.result_buffer;
204
205    // Pre-convert headers once before the hot loop to avoid per-request allocation.
206    let plain_headers: Arc<Vec<(String, String)>> = Arc::new(
207        request
208            .headers
209            .iter()
210            .map(|(k, v)| (k.clone(), v.to_string()))
211            .collect(),
212    );
213
214    let (all_results, sampling_state) = async {
215        let sem = Arc::new(Semaphore::new(concurrency));
216        let (tx, mut rx) = mpsc::channel::<RequestResult>(concurrency);
217
218        for i in 0..total {
219            let resolved = request.resolve_body(all_bodies.as_ref().map(|bs| bs[i].clone()));
220
221            let client = request.client.clone();
222            let url = request.host.as_str().to_string();
223            let method = request.method;
224            let capture_body = request.tracked_fields.is_some();
225            let headers = Arc::clone(&plain_headers);
226            let tx = tx.clone();
227
228            tokio::select! {
229                _ = token.cancelled() => break,
230                permit = sem.clone().acquire_owned() => {
231                    let permit = permit.unwrap();
232                    tokio::spawn(async move {
233                        let _permit = permit;
234                        let mut req = Request::new(client, url, method);
235                        if let Some((content, content_type)) = resolved {
236                            req = req.body(content, content_type);
237                        }
238                        if capture_body {
239                            req = req.read_response();
240                        }
241                        if !headers.is_empty() {
242                            req = req.headers((*headers).clone());
243                        }
244                        let _ = tx.send(req.execute().await).await;
245                    });
246                }
247            }
248        }
249
250        // Close the last sender — rx drains once all tasks have finished
251        drop(tx);
252
253        let mut sampling_state = SamplingState::new(SamplingParams {
254            vu_threshold: sample_threshold,
255            reservoir_size: result_buffer,
256        });
257        // In fixed mode the VU count is constant — set it once before draining.
258        sampling_state.set_active_vus(concurrency);
259
260        let mut results: Vec<RequestResult> = Vec::with_capacity(total.min(result_buffer));
261        while let Some(result) = rx.recv().await {
262            sampling_state.record_request(result.success);
263            if sampling_state.should_collect() {
264                match sampling_state.reservoir_slot(results.len()) {
265                    ReservoirAction::Push => results.push(result),
266                    ReservoirAction::Replace(idx) => results[idx] = result,
267                    ReservoirAction::Discard => {}
268                }
269            }
270        }
271        (results, sampling_state)
272    }
273    .instrument(info_span!(SpanName::REQUESTS, total))
274    .await;
275
276    let response_stats = compute_response_stats(&all_results, &request.tracked_fields);
277
278    Ok(Some(RunStats {
279        elapsed: started_at.elapsed(),
280        template_duration,
281        response_stats,
282        results: all_results,
283        mode: RunMode::Fixed,
284        curve_duration: None,
285        curve_stages: None,
286        total_requests: sampling_state.total_requests(),
287        total_failures: sampling_state.total_failures(),
288        sample_rate: sampling_state.sample_rate(),
289        min_sample_rate: sampling_state.min_sample_rate(),
290    }))
291}
292
293// ── execute_curve ─────────────────────────────────────────────────────────────
294
295/// Curve-based dynamic VU execution path.
296async fn execute_curve(
297    request_spec: RequestSpec,
298    sampling: SamplingConfig,
299    curve: LoadCurve,
300) -> Result<Option<RunStats>, Box<dyn std::error::Error>> {
301    let RequestSpec {
302        host,
303        method,
304        body,
305        template_path,
306        response_template_path,
307        headers,
308    } = request_spec;
309    let curve_duration = curve.total_duration();
310    let curve_stages = curve.stages.clone();
311
312    // Parse template for on-demand body generation (no pre-generation in curve mode)
313    let template: Option<Arc<Template>> = template_path
314        .map(|path| Template::parse(&path).map(Arc::new))
315        .transpose()
316        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
317
318    let tracked_fields = resolve_tracked_fields(response_template_path)?;
319    let request_config = build_request_config(host, method, body, tracked_fields, headers);
320
321    let cancellation_token = CancellationToken::new();
322    let cancel = cancellation_token.clone();
323    tokio::spawn(async move {
324        tokio::signal::ctrl_c()
325            .await
326            .expect("failed to listen for ctrl_c");
327        eprintln!("\nShutdown signal received — cancelling curve execution...");
328        cancel.cancel();
329    });
330
331    let started_at = Instant::now();
332
333    let executor = CurveExecutor::new(CurveExecutorParams {
334        curve,
335        request_config: Arc::clone(&request_config),
336        template,
337        cancellation_token,
338        sampling: SamplingParams {
339            vu_threshold: sampling.sample_threshold,
340            reservoir_size: sampling.result_buffer,
341        },
342    });
343
344    let curve_result = executor.execute().await;
345
346    let response_stats =
347        compute_response_stats(&curve_result.results, &request_config.tracked_fields);
348
349    Ok(Some(RunStats {
350        elapsed: started_at.elapsed(),
351        template_duration: None,
352        response_stats,
353        results: curve_result.results,
354        mode: RunMode::Curve,
355        curve_duration: Some(curve_duration),
356        curve_stages: Some(curve_stages),
357        total_requests: curve_result.total_requests,
358        total_failures: curve_result.total_failures,
359        sample_rate: curve_result.sample_rate,
360        min_sample_rate: curve_result.min_sample_rate,
361    }))
362}
363
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::{RunMode, RunStats};
    use crate::load_curve::{RampType, Stage};

    /// Minimal fixed-mode fixture: 10 requests, no failures, no curve data.
    fn fixed_fixture() -> RunStats {
        RunStats {
            mode: RunMode::Fixed,
            elapsed: Duration::from_secs(1),
            curve_duration: None,
            curve_stages: None,
            template_duration: None,
            response_stats: None,
            results: Vec::new(),
            total_requests: 10,
            total_failures: 0,
            sample_rate: 1.0,
            min_sample_rate: 1.0,
        }
    }

    /// Curve-mode fixture that stores the given `stages` verbatim.
    fn curve_fixture(stages: Vec<Stage>) -> RunStats {
        RunStats {
            mode: RunMode::Curve,
            elapsed: Duration::from_secs(10),
            curve_duration: Some(Duration::from_secs(10)),
            curve_stages: Some(stages),
            template_duration: None,
            response_stats: None,
            results: Vec::new(),
            total_requests: 100,
            total_failures: 2,
            sample_rate: 1.0,
            min_sample_rate: 1.0,
        }
    }

    // ── curve_stages_none_for_fixed_mode ──────────────────────────────────────

    #[test]
    fn curve_stages_none_for_fixed_mode() {
        let stats = fixed_fixture();
        assert!(
            stats.curve_stages.is_none(),
            "fixed-mode RunStats must have curve_stages == None"
        );
    }

    // ── curve_stages_some_for_curve_mode ──────────────────────────────────────

    #[test]
    fn curve_stages_some_for_curve_mode() {
        let warm_up = Stage {
            duration: Duration::from_secs(5),
            target_vus: 50,
            ramp: RampType::Linear,
        };
        let peak = Stage {
            duration: Duration::from_secs(5),
            target_vus: 100,
            ramp: RampType::Step,
        };

        let stored = curve_fixture(vec![warm_up, peak])
            .curve_stages
            .expect("curve_stages must be Some in curve mode");

        assert_eq!(stored.len(), 2);
        let expected = [(50, RampType::Linear), (100, RampType::Step)];
        for (stage, (vus, ramp)) in stored.iter().zip(expected.iter()) {
            assert_eq!(stage.target_vus, *vus);
            assert_eq!(stage.ramp, *ramp);
        }
    }

    // ── curve_stages_count_matches_original ───────────────────────────────────

    #[test]
    fn curve_stages_count_matches_original() {
        let stages: Vec<Stage> = (1..=5)
            .map(|n| Stage {
                duration: Duration::from_secs(10),
                target_vus: n * 20,
                ramp: RampType::Linear,
            })
            .collect();
        let original_len = stages.len();

        let stored = curve_fixture(stages)
            .curve_stages
            .expect("curve_stages must be Some in curve mode");
        assert_eq!(
            stored.len(),
            original_len,
            "stored stage count must equal original stage count"
        );
    }

    // ── run_mode_fixed_variant ────────────────────────────────────────────────

    #[test]
    fn run_mode_fixed_variant() {
        assert_eq!(fixed_fixture().mode, RunMode::Fixed);
    }

    // ── run_mode_curve_variant ────────────────────────────────────────────────

    #[test]
    fn run_mode_curve_variant() {
        let single = vec![Stage {
            duration: Duration::from_secs(5),
            target_vus: 10,
            ramp: RampType::Linear,
        }];
        assert_eq!(curve_fixture(single).mode, RunMode::Curve);
    }
}