// lmn_core/load_curve/executor.rs
use std::cmp::Ordering;
use std::sync::Arc;
use std::time::{Duration, Instant};

use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio::time::MissedTickBehavior;
use tokio_util::sync::CancellationToken;
use tracing::debug;

use crate::http::{Request, RequestConfig, RequestResult};
use crate::request_template::Template;
use crate::sampling::{ReservoirAction, SamplingParams, SamplingState};

use super::LoadCurve;

15pub struct CurveExecutorParams {
19 pub curve: LoadCurve,
20 pub request_config: Arc<RequestConfig>,
21 pub template: Option<Arc<Template>>,
22 pub cancellation_token: CancellationToken,
23 pub sampling: SamplingParams,
24}
25
26pub struct CurveExecutionResult {
31 pub results: Vec<RequestResult>,
32 pub total_requests: usize,
33 pub total_failures: usize,
34 pub sample_rate: f64,
35 pub min_sample_rate: f64,
36}
37
38pub struct CurveExecutor {
42 params: CurveExecutorParams,
43}
44
45impl CurveExecutor {
46 pub fn new(params: CurveExecutorParams) -> Self {
47 Self { params }
48 }
49
50 pub async fn execute(self) -> CurveExecutionResult {
55 let CurveExecutorParams {
56 curve,
57 request_config,
58 template,
59 cancellation_token,
60 sampling,
61 } = self.params;
62
63 let total_duration = curve.total_duration();
64 let started_at = Instant::now();
65
66 let plain_headers: Arc<Vec<(String, String)>> = Arc::new(
68 request_config
69 .headers
70 .iter()
71 .map(|(k, v)| (k.clone(), v.to_string()))
72 .collect(),
73 );
74
75 let (tx, mut rx) = mpsc::unbounded_channel::<RequestResult>();
77
78 let mut vu_handles: Vec<(JoinHandle<()>, CancellationToken)> = Vec::new();
80
81 let mut sampling = SamplingState::new(sampling);
82 let mut results: Vec<RequestResult> = Vec::new();
83
84 let mut ticker = tokio::time::interval(tokio::time::Duration::from_millis(100));
85
86 loop {
87 tokio::select! {
88 _ = cancellation_token.cancelled() => {
89 debug!("curve executor: parent cancellation received");
90 break;
91 }
92 _ = ticker.tick() => {
93 let elapsed = started_at.elapsed();
94
95 if elapsed >= total_duration {
96 debug!("curve executor: total duration elapsed, shutting down");
97 break;
98 }
99
100 let target = curve.target_vus_at(elapsed) as usize;
101 let current = vu_handles.len();
102
103 match target.cmp(¤t) {
104 std::cmp::Ordering::Greater => {
105 let to_add = target - current;
107 for _ in 0..to_add {
108 let vu_token = CancellationToken::new();
109 let handle = spawn_vu(VuParams {
110 request_config: Arc::clone(&request_config),
111 plain_headers: Arc::clone(&plain_headers),
112 template: template.as_ref().map(Arc::clone),
113 cancellation_token: vu_token.clone(),
114 result_tx: tx.clone(),
115 });
116 vu_handles.push((handle, vu_token));
117 }
118 }
119 std::cmp::Ordering::Less => {
120 let to_remove = current - target;
122 let drain_start = vu_handles.len() - to_remove;
123 let excess: Vec<_> = vu_handles.drain(drain_start..).collect();
124 for (_, token) in &excess {
126 token.cancel();
127 }
128 for (handle, _) in excess {
130 let _ = handle.await;
131 }
132 }
133 std::cmp::Ordering::Equal => {}
134 }
136
137 sampling.set_active_vus(vu_handles.len());
139
140 while let Ok(result) = rx.try_recv() {
144 sampling.record_request(result.success);
145 if sampling.should_collect() {
146 match sampling.reservoir_slot(results.len()) {
147 ReservoirAction::Push => results.push(result),
148 ReservoirAction::Replace(idx) => results[idx] = result,
149 ReservoirAction::Discard => {}
150 }
151 }
152 }
153 }
154 }
155 }
156
157 for (_, token) in &vu_handles {
159 token.cancel();
160 }
161 for (handle, _) in vu_handles {
162 let _ = handle.await;
163 }
164
165 drop(tx);
168
169 while let Some(result) = rx.recv().await {
172 sampling.record_request(result.success);
173 if sampling.should_collect() {
174 match sampling.reservoir_slot(results.len()) {
175 ReservoirAction::Push => results.push(result),
176 ReservoirAction::Replace(idx) => results[idx] = result,
177 ReservoirAction::Discard => {}
178 }
179 }
180 }
181
182 CurveExecutionResult {
183 results,
184 total_requests: sampling.total_requests(),
185 total_failures: sampling.total_failures(),
186 sample_rate: sampling.sample_rate(),
187 min_sample_rate: sampling.min_sample_rate(),
188 }
189 }
190}
191
192struct VuParams {
195 request_config: Arc<RequestConfig>,
196 plain_headers: Arc<Vec<(String, String)>>,
198 template: Option<Arc<Template>>,
199 cancellation_token: CancellationToken,
200 result_tx: mpsc::UnboundedSender<RequestResult>,
201}
202
203fn spawn_vu(params: VuParams) -> JoinHandle<()> {
204 tokio::spawn(async move {
205 let VuParams {
206 request_config,
207 plain_headers,
208 template,
209 cancellation_token,
210 result_tx,
211 } = params;
212
213 loop {
214 let body = template.as_ref().map(|t| t.generate_one());
216
217 let resolved = request_config.resolve_body(body);
218
219 let client = request_config.client.clone();
220 let url = request_config.host.as_str().to_string();
221 let method = request_config.method;
222 let capture_body = request_config.tracked_fields.is_some();
223
224 let headers = Arc::clone(&plain_headers);
226
227 let result_fut = async {
228 let mut req = Request::new(client, url, method);
229 if let Some((content, content_type)) = resolved {
230 req = req.body(content, content_type);
231 }
232 if capture_body {
233 req = req.read_response();
234 }
235 if !headers.is_empty() {
236 req = req.headers((*headers).clone());
237 }
238 req.execute().await
239 };
240
241 tokio::select! {
242 _ = cancellation_token.cancelled() => {
243 break;
244 }
245 result = result_fut => {
246 if result_tx.send(result).is_err() {
248 break;
249 }
250 }
251 }
252 }
253 })
254}