//! lmn_core/execution/curve/mod.rs
//!
//! Curve-driven load execution: scales virtual users (VUs) up and down to
//! follow a `LoadCurve` and aggregates global and per-stage request stats.
1use std::sync::Arc;
2use std::time::{Duration, Instant};
3
4use tokio::sync::mpsc;
5use tokio::task::JoinHandle;
6use tokio_util::sync::CancellationToken;
7use tracing::debug;
8
9use crate::execution::StageStats;
10use crate::histogram::{LatencyHistogram, StatusCodeHistogram};
11use crate::http::{RequestConfig, RequestRecord};
12use crate::load_curve::{LoadCurve, Stage};
13use crate::request_template::Template;
14use crate::response_template::stats::ResponseStats;
15use crate::vu::Vu;
16
17// ── CurveExecutorParams ───────────────────────────────────────────────────────
18
/// Parameters for constructing a `CurveExecutor`.
pub struct CurveExecutorParams {
    /// The load curve that drives the target VU count over elapsed time.
    pub curve: LoadCurve,
    /// Shared request configuration; `headers` and `tracked_fields` are read
    /// by the executor, the rest is passed through to each VU.
    pub request_config: Arc<RequestConfig>,
    /// Optional request template handed to every spawned VU (`None` when no
    /// templating is configured).
    pub template: Option<Arc<Template>>,
    /// Parent cancellation token; cancelling it shuts down the whole run.
    pub cancellation_token: CancellationToken,
}
26
27// ── CurveExecutionResult ──────────────────────────────────────────────────────
28
/// Result returned by `CurveExecutor::execute`.
pub struct CurveExecutionResult {
    /// Latency histogram aggregated over the entire run.
    pub latency: LatencyHistogram,
    /// Status-code histogram aggregated over the entire run.
    pub status_codes: StatusCodeHistogram,
    /// Total number of request records received from all VUs.
    pub total_requests: u64,
    /// Count of records whose `success` flag was false.
    pub total_failures: u64,
    /// Extraction statistics; `Some` only when `tracked_fields` was configured
    /// on the request config.
    pub response_stats: Option<ResponseStats>,
    /// Per-stage accumulators, indexed by the stage's position in the curve.
    /// Records are attributed to stages by their wall-clock completion time.
    pub stage_stats: Vec<StageStats>,
}
38
39// ── stage_index_at ────────────────────────────────────────────────────────────
40
41/// Returns the 0-based stage index for a given elapsed duration.
42fn stage_index_at(stages: &[Stage], elapsed: Duration) -> usize {
43    let mut offset = Duration::ZERO;
44    for (i, stage) in stages.iter().enumerate() {
45        offset += stage.duration;
46        if elapsed < offset {
47            return i;
48        }
49    }
50    stages.len().saturating_sub(1)
51}
52
53// ── CurveExecutor ─────────────────────────────────────────────────────────────
54
/// Executes a load test driven by a `LoadCurve`, dynamically scaling VUs.
pub struct CurveExecutor {
    /// Construction parameters; consumed when `execute` runs.
    params: CurveExecutorParams,
}
59
impl CurveExecutor {
    /// Creates a new executor. No tasks are spawned until [`Self::execute`]
    /// is called.
    pub fn new(params: CurveExecutorParams) -> Self {
        Self { params }
    }

    /// Runs the load curve, spawning and cancelling VU tasks as the curve
    /// dictates. Returns a `CurveExecutionResult` when the curve completes or a
    /// cancellation signal is received.
    ///
    /// Internally this runs two cooperating pieces:
    /// 1. a spawned *drain task* that owns the result-channel receiver and
    ///    folds every `RequestRecord` into global and per-stage accumulators;
    /// 2. a 100 ms control loop that compares the live VU count against
    ///    `curve.target_vus_at(elapsed)` and spawns or cancels VUs to match.
    ///
    /// # Errors
    ///
    /// The final `?` propagates the drain task's `JoinError` (panic/abort)
    /// converted into `crate::execution::RunError`.
    pub async fn execute(self) -> Result<CurveExecutionResult, crate::execution::RunError> {
        let CurveExecutorParams {
            curve,
            request_config,
            template,
            cancellation_token,
        } = self.params;

        let total_duration = curve.total_duration();
        // Anchor for both the control loop's elapsed time and the drain task's
        // stage attribution; both must share the same origin.
        let run_start = Instant::now();

        // Pre-convert headers once before spawning any VUs to avoid per-VU allocation.
        let plain_headers: Arc<Vec<(String, String)>> = Arc::new(
            request_config
                .headers
                .iter()
                .map(|(k, v)| (k.clone(), v.to_string()))
                .collect(),
        );

        let has_tracked_fields = request_config.tracked_fields.is_some();
        let n_stages = curve.stages.len();

        // Clone the stages vec so the drain task can own it without holding onto `curve`.
        let drain_stages = curve.stages.clone();

        // Unbounded channel; VUs push results as they complete without risk of blocking.
        let (tx, rx) = mpsc::unbounded_channel::<RequestRecord>();

        // Spawn a dedicated drain task that owns the receiver and all accumulator
        // state. It attributes each record to the correct stage via `completed_at`.
        // It finishes only after every sender (the VU clones and `tx` below) is
        // dropped, which closes the channel and ends the `recv` loop.
        let drain_handle = tokio::spawn(async move {
            // Rebind mutably inside the task: `recv` takes `&mut self`.
            let mut rx = rx;
            let mut latency = LatencyHistogram::new();
            let mut status_codes = StatusCodeHistogram::new();
            let mut total_requests: u64 = 0;
            let mut total_failures: u64 = 0;
            let mut response_stats: Option<ResponseStats> = if has_tracked_fields {
                Some(ResponseStats::new())
            } else {
                None
            };

            // Pre-allocate per-stage accumulators.
            let mut stage_stats: Vec<StageStats> = (0..n_stages)
                .map(|_| StageStats {
                    latency: LatencyHistogram::new(),
                    status_codes: StatusCodeHistogram::new(),
                    total_requests: 0,
                    total_failures: 0,
                })
                .collect();

            while let Some(record) = rx.recv().await {
                // Global accumulators first — every record counts toward the
                // run-wide totals regardless of stage.
                total_requests += 1;
                if !record.success {
                    total_failures += 1;
                }
                latency.record(record.duration);
                status_codes.record(record.status_code);

                // Determine which stage this record belongs to using its
                // wall-clock completion time relative to the run start.
                // `checked_duration_since` guards against a `completed_at`
                // earlier than `run_start` (falls back to Duration::ZERO).
                let elapsed = record
                    .completed_at
                    .checked_duration_since(run_start)
                    .unwrap_or_default();
                let stage_idx = stage_index_at(&drain_stages, elapsed);

                stage_stats[stage_idx].latency.record(record.duration);
                stage_stats[stage_idx]
                    .status_codes
                    .record(record.status_code);
                stage_stats[stage_idx].total_requests += 1;
                if !record.success {
                    stage_stats[stage_idx].total_failures += 1;
                }

                // Only folded in when tracking was requested (response_stats
                // is Some) and the record carries an extraction payload.
                if let Some(extraction) = record.extraction
                    && let Some(ref mut rs) = response_stats
                {
                    rs.record(extraction);
                }
            }

            CurveExecutionResult {
                latency,
                status_codes,
                total_requests,
                total_failures,
                response_stats,
                stage_stats,
            }
        });

        // Track active VU handles and their per-VU cancellation tokens.
        let mut vu_handles: Vec<(JoinHandle<()>, CancellationToken)> = Vec::new();

        // 100 ms rescaling cadence. Note: `tokio::time::interval`'s first tick
        // completes immediately, so the initial scale-up happens right away.
        let mut ticker = tokio::time::interval(tokio::time::Duration::from_millis(100));

        loop {
            tokio::select! {
                _ = cancellation_token.cancelled() => {
                    debug!("curve executor: parent cancellation received");
                    break;
                }
                _ = ticker.tick() => {
                    let elapsed = run_start.elapsed();

                    if elapsed >= total_duration {
                        debug!("curve executor: total duration elapsed, shutting down");
                        break;
                    }

                    let target = curve.target_vus_at(elapsed) as usize;
                    let current = vu_handles.len();

                    match target.cmp(&current) {
                        std::cmp::Ordering::Greater => {
                            // Spawn additional VUs
                            let to_add = target - current;
                            for _ in 0..to_add {
                                // Each VU gets its own token so it can be
                                // cancelled individually during scale-down.
                                let vu_token = CancellationToken::new();
                                let handle = Vu {
                                    request_config: Arc::clone(&request_config),
                                    plain_headers: Arc::clone(&plain_headers),
                                    template: template.as_ref().map(Arc::clone),
                                    cancellation_token: vu_token.clone(),
                                    result_tx: tx.clone(),
                                    budget: None,
                                }
                                .spawn();
                                vu_handles.push((handle, vu_token));
                            }
                        }
                        std::cmp::Ordering::Less => {
                            // Cancel excess VUs (cancel from the end of the list)
                            let to_remove = current - target;
                            let drain_start = vu_handles.len() - to_remove;
                            let excess: Vec<_> = vu_handles.drain(drain_start..).collect();
                            // Cancel all tokens first so all VUs begin exiting simultaneously
                            for (_, token) in &excess {
                                token.cancel();
                            }
                            // Await sequentially — VUs are already exiting in parallel on the runtime
                            for (handle, _) in excess {
                                let _ = handle.await;
                            }
                        }
                        std::cmp::Ordering::Equal => {}
                    }
                }
            }
        }

        // Cancel all remaining VU tasks — cancel all tokens first, then await.
        for (_, token) in &vu_handles {
            token.cancel();
        }
        for (handle, _) in vu_handles {
            let _ = handle.await;
        }

        // Drop the coordinator's sender so the channel closes once all VU
        // senders (clones) are also dropped (they are, since tasks ended).
        drop(tx);

        // Await the drain task to get the fully accumulated result.
        Ok(drain_handle.await?)
    }
}