Skip to main content

lmn_core/load_curve/
executor.rs

1use std::sync::Arc;
2use std::time::Instant;
3
4use tokio::sync::mpsc;
5use tokio::task::JoinHandle;
6use tokio_util::sync::CancellationToken;
7use tracing::debug;
8
9use crate::http::{Request, RequestConfig, RequestResult};
10use crate::request_template::Template;
11use crate::sampling::{ReservoirAction, SamplingParams, SamplingState};
12
13use super::LoadCurve;
14
15// ── CurveExecutorParams ───────────────────────────────────────────────────────
16
/// Parameters for constructing a `CurveExecutor`.
pub struct CurveExecutorParams {
    // The load curve that dictates the target VU count over elapsed time.
    pub curve: LoadCurve,
    // Shared request configuration; `execute` reads its headers, client,
    // host, method and tracked_fields, and calls `resolve_body` per request.
    pub request_config: Arc<RequestConfig>,
    // Optional body template; when present, `generate_one` is called for
    // every request to produce a fresh body.
    pub template: Option<Arc<Template>>,
    // Parent cancellation token; cancelling it ends the whole run.
    pub cancellation_token: CancellationToken,
    // Sampling configuration used to seed the executor's `SamplingState`.
    pub sampling: SamplingParams,
}
25
26// ── CurveExecutionResult ──────────────────────────────────────────────────────
27
/// Result returned by `CurveExecutor::execute`. Carries the reservoir-bounded
/// sample of results plus the four sampling counters for `RunStats`.
pub struct CurveExecutionResult {
    // Reservoir-bounded sample of individual request results.
    pub results: Vec<RequestResult>,
    // Total requests recorded, whether or not they were sampled.
    pub total_requests: usize,
    // Total failed requests recorded (counted via `RequestResult::success`).
    pub total_failures: usize,
    // Final sample rate, as reported by `SamplingState::sample_rate`.
    pub sample_rate: f64,
    // Minimum sample rate, as reported by `SamplingState::min_sample_rate`
    // (presumably the lowest rate reached during the run — defined elsewhere).
    pub min_sample_rate: f64,
}
37
38// ── CurveExecutor ─────────────────────────────────────────────────────────────
39
/// Executes a load test driven by a `LoadCurve`, dynamically scaling VUs.
pub struct CurveExecutor {
    // Stored whole and consumed by `execute`; construction stays trivial.
    params: CurveExecutorParams,
}
44
45impl CurveExecutor {
46    pub fn new(params: CurveExecutorParams) -> Self {
47        Self { params }
48    }
49
50    /// Runs the load curve, spawning and cancelling VU tasks as the curve
51    /// dictates. Applies VU-threshold + reservoir sampling to bound memory
52    /// usage. Returns a `CurveExecutionResult` when the curve completes or a
53    /// cancellation signal is received.
54    pub async fn execute(self) -> CurveExecutionResult {
55        let CurveExecutorParams {
56            curve,
57            request_config,
58            template,
59            cancellation_token,
60            sampling,
61        } = self.params;
62
63        let total_duration = curve.total_duration();
64        let started_at = Instant::now();
65
66        // Pre-convert headers once before spawning any VUs to avoid per-VU allocation.
67        let plain_headers: Arc<Vec<(String, String)>> = Arc::new(
68            request_config
69                .headers
70                .iter()
71                .map(|(k, v)| (k.clone(), v.to_string()))
72                .collect(),
73        );
74
75        // Unbounded channel; VUs push results as they complete without risk of blocking.
76        let (tx, mut rx) = mpsc::unbounded_channel::<RequestResult>();
77
78        // Track active VU handles and their per-VU cancellation tokens.
79        let mut vu_handles: Vec<(JoinHandle<()>, CancellationToken)> = Vec::new();
80
81        let mut sampling = SamplingState::new(sampling);
82        let mut results: Vec<RequestResult> = Vec::new();
83
84        let mut ticker = tokio::time::interval(tokio::time::Duration::from_millis(100));
85
86        loop {
87            tokio::select! {
88                _ = cancellation_token.cancelled() => {
89                    debug!("curve executor: parent cancellation received");
90                    break;
91                }
92                _ = ticker.tick() => {
93                    let elapsed = started_at.elapsed();
94
95                    if elapsed >= total_duration {
96                        debug!("curve executor: total duration elapsed, shutting down");
97                        break;
98                    }
99
100                    let target = curve.target_vus_at(elapsed) as usize;
101                    let current = vu_handles.len();
102
103                    match target.cmp(&current) {
104                        std::cmp::Ordering::Greater => {
105                            // Spawn additional VUs
106                            let to_add = target - current;
107                            for _ in 0..to_add {
108                                let vu_token = CancellationToken::new();
109                                let handle = spawn_vu(VuParams {
110                                    request_config: Arc::clone(&request_config),
111                                    plain_headers: Arc::clone(&plain_headers),
112                                    template: template.as_ref().map(Arc::clone),
113                                    cancellation_token: vu_token.clone(),
114                                    result_tx: tx.clone(),
115                                });
116                                vu_handles.push((handle, vu_token));
117                            }
118                        }
119                        std::cmp::Ordering::Less => {
120                            // Cancel excess VUs (cancel from the end of the list)
121                            let to_remove = current - target;
122                            let drain_start = vu_handles.len() - to_remove;
123                            let excess: Vec<_> = vu_handles.drain(drain_start..).collect();
124                            // Cancel all tokens first so all VUs begin exiting simultaneously
125                            for (_, token) in &excess {
126                                token.cancel();
127                            }
128                            // Await sequentially — VUs are already exiting in parallel on the runtime
129                            for (handle, _) in excess {
130                                let _ = handle.await;
131                            }
132                        }
133                        std::cmp::Ordering::Equal => {}
134                        // If target == current: nothing to do
135                    }
136
137                    // Update sampling rate based on the current active VU count.
138                    sampling.set_active_vus(vu_handles.len());
139
140                    // Drain all results currently in the channel without blocking.
141                    // This prevents channel backpressure from inflating latency at
142                    // high throughput — a correctness fix independent of sampling.
143                    while let Ok(result) = rx.try_recv() {
144                        sampling.record_request(result.success);
145                        if sampling.should_collect() {
146                            match sampling.reservoir_slot(results.len()) {
147                                ReservoirAction::Push => results.push(result),
148                                ReservoirAction::Replace(idx) => results[idx] = result,
149                                ReservoirAction::Discard => {}
150                            }
151                        }
152                    }
153                }
154            }
155        }
156
157        // Cancel all remaining VU tasks — cancel all tokens first, then await
158        for (_, token) in &vu_handles {
159            token.cancel();
160        }
161        for (handle, _) in vu_handles {
162            let _ = handle.await;
163        }
164
165        // Drop the coordinator's sender so the channel closes once all VU
166        // senders (clones) are also dropped (they are, since tasks ended).
167        drop(tx);
168
169        // Final drain: collect any results that arrived between the last tick
170        // and the VU tasks completing.
171        while let Some(result) = rx.recv().await {
172            sampling.record_request(result.success);
173            if sampling.should_collect() {
174                match sampling.reservoir_slot(results.len()) {
175                    ReservoirAction::Push => results.push(result),
176                    ReservoirAction::Replace(idx) => results[idx] = result,
177                    ReservoirAction::Discard => {}
178                }
179            }
180        }
181
182        CurveExecutionResult {
183            results,
184            total_requests: sampling.total_requests(),
185            total_failures: sampling.total_failures(),
186            sample_rate: sampling.sample_rate(),
187            min_sample_rate: sampling.min_sample_rate(),
188        }
189    }
190}
191
192// ── VU task ───────────────────────────────────────────────────────────────────
193
/// Everything a single VU task needs to run its request loop.
struct VuParams {
    // Shared request configuration; the VU loop reads client/host/method/
    // tracked_fields and calls `resolve_body` per request.
    request_config: Arc<RequestConfig>,
    /// Pre-converted header pairs shared across all VUs — avoids re-converting
    /// the header map per VU. NOTE(review): the Vec itself is still cloned per
    /// request when attached to the outgoing `Request`.
    plain_headers: Arc<Vec<(String, String)>>,
    // Optional body template; `generate_one` is invoked once per request.
    template: Option<Arc<Template>>,
    // Per-VU token; cancelling it makes this VU break out of its loop.
    cancellation_token: CancellationToken,
    // Channel on which completed request results are reported to the executor.
    result_tx: mpsc::UnboundedSender<RequestResult>,
}
202
203fn spawn_vu(params: VuParams) -> JoinHandle<()> {
204    tokio::spawn(async move {
205        let VuParams {
206            request_config,
207            plain_headers,
208            template,
209            cancellation_token,
210            result_tx,
211        } = params;
212
213        loop {
214            // Generate body on demand for this request
215            let body = template.as_ref().map(|t| t.generate_one());
216
217            let resolved = request_config.resolve_body(body);
218
219            let client = request_config.client.clone();
220            let url = request_config.host.as_str().to_string();
221            let method = request_config.method;
222            let capture_body = request_config.tracked_fields.is_some();
223
224            // Clone the Arc cheaply; dereference to get a Vec clone only when needed.
225            let headers = Arc::clone(&plain_headers);
226
227            let result_fut = async {
228                let mut req = Request::new(client, url, method);
229                if let Some((content, content_type)) = resolved {
230                    req = req.body(content, content_type);
231                }
232                if capture_body {
233                    req = req.read_response();
234                }
235                if !headers.is_empty() {
236                    req = req.headers((*headers).clone());
237                }
238                req.execute().await
239            };
240
241            tokio::select! {
242                _ = cancellation_token.cancelled() => {
243                    break;
244                }
245                result = result_fut => {
246                    // Best-effort send — if receiver is gone, we stop
247                    if result_tx.send(result).is_err() {
248                        break;
249                    }
250                }
251            }
252        }
253    })
254}