use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use crate::http::{Request, RequestConfig, RequestResult};
use crate::request_template::Template;
use crate::sampling::{ReservoirAction, SamplingParams, SamplingState};
use super::LoadCurve;
/// Inputs consumed by [`CurveExecutor::execute`].
pub struct CurveExecutorParams {
/// Load curve queried for the total run duration and for the target number
/// of virtual users at a given elapsed time.
pub curve: LoadCurve,
/// Shared request configuration; each spawned virtual user receives a clone
/// of this `Arc`.
pub request_config: Arc<RequestConfig>,
/// Optional template; when present, each virtual user calls
/// `generate_one()` on it before every request to produce a body.
pub template: Option<Arc<Template>>,
/// Parent cancellation token; when it fires, the executor leaves its control
/// loop and shuts down all virtual users.
pub cancellation_token: CancellationToken,
/// Parameters handed to `SamplingState::new` to configure result sampling.
pub sampling: SamplingParams,
}
/// Outcome of a completed [`CurveExecutor::execute`] run.
pub struct CurveExecutionResult {
/// Request results retained by the sampler (not necessarily every request
/// that was executed).
pub results: Vec<RequestResult>,
/// Value of `SamplingState::total_requests()` at the end of the run.
pub total_requests: usize,
/// Value of `SamplingState::total_failures()` at the end of the run.
pub total_failures: usize,
/// Value of `SamplingState::sample_rate()` at the end of the run.
pub sample_rate: f64,
/// Value of `SamplingState::min_sample_rate()` at the end of the run.
pub min_sample_rate: f64,
}
/// Runs a load test whose number of concurrent virtual users follows a
/// [`LoadCurve`] over time.
pub struct CurveExecutor {
// Consumed (moved out) by `execute`.
params: CurveExecutorParams,
}
impl CurveExecutor {
pub fn new(params: CurveExecutorParams) -> Self {
Self { params }
}
pub async fn execute(self) -> CurveExecutionResult {
let CurveExecutorParams {
curve,
request_config,
template,
cancellation_token,
sampling,
} = self.params;
let total_duration = curve.total_duration();
let started_at = Instant::now();
let plain_headers: Arc<Vec<(String, String)>> = Arc::new(
request_config
.headers
.iter()
.map(|(k, v)| (k.clone(), v.to_string()))
.collect(),
);
let (tx, mut rx) = mpsc::unbounded_channel::<RequestResult>();
let mut vu_handles: Vec<(JoinHandle<()>, CancellationToken)> = Vec::new();
let mut sampling = SamplingState::new(sampling);
let mut results: Vec<RequestResult> = Vec::new();
let mut ticker = tokio::time::interval(tokio::time::Duration::from_millis(100));
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
debug!("curve executor: parent cancellation received");
break;
}
_ = ticker.tick() => {
let elapsed = started_at.elapsed();
if elapsed >= total_duration {
debug!("curve executor: total duration elapsed, shutting down");
break;
}
let target = curve.target_vus_at(elapsed) as usize;
let current = vu_handles.len();
match target.cmp(¤t) {
std::cmp::Ordering::Greater => {
let to_add = target - current;
for _ in 0..to_add {
let vu_token = CancellationToken::new();
let handle = spawn_vu(VuParams {
request_config: Arc::clone(&request_config),
plain_headers: Arc::clone(&plain_headers),
template: template.as_ref().map(Arc::clone),
cancellation_token: vu_token.clone(),
result_tx: tx.clone(),
});
vu_handles.push((handle, vu_token));
}
}
std::cmp::Ordering::Less => {
let to_remove = current - target;
let drain_start = vu_handles.len() - to_remove;
let excess: Vec<_> = vu_handles.drain(drain_start..).collect();
for (_, token) in &excess {
token.cancel();
}
for (handle, _) in excess {
let _ = handle.await;
}
}
std::cmp::Ordering::Equal => {}
}
sampling.set_active_vus(vu_handles.len());
while let Ok(result) = rx.try_recv() {
sampling.record_request(result.success);
if sampling.should_collect() {
match sampling.reservoir_slot(results.len()) {
ReservoirAction::Push => results.push(result),
ReservoirAction::Replace(idx) => results[idx] = result,
ReservoirAction::Discard => {}
}
}
}
}
}
}
for (_, token) in &vu_handles {
token.cancel();
}
for (handle, _) in vu_handles {
let _ = handle.await;
}
drop(tx);
while let Some(result) = rx.recv().await {
sampling.record_request(result.success);
if sampling.should_collect() {
match sampling.reservoir_slot(results.len()) {
ReservoirAction::Push => results.push(result),
ReservoirAction::Replace(idx) => results[idx] = result,
ReservoirAction::Discard => {}
}
}
}
CurveExecutionResult {
results,
total_requests: sampling.total_requests(),
total_failures: sampling.total_failures(),
sample_rate: sampling.sample_rate(),
min_sample_rate: sampling.min_sample_rate(),
}
}
}
/// Everything a single virtual-user task needs; moved into the task by
/// `spawn_vu`.
struct VuParams {
/// Shared request configuration (client, host, method, body resolution,
/// tracked fields).
request_config: Arc<RequestConfig>,
/// Header name/value pairs already converted to strings, shared between
/// virtual users.
plain_headers: Arc<Vec<(String, String)>>,
/// Optional template used to generate a body before each request.
template: Option<Arc<Template>>,
/// Token that stops this virtual user when cancelled.
cancellation_token: CancellationToken,
/// Channel on which each completed request's result is sent back.
result_tx: mpsc::UnboundedSender<RequestResult>,
}
/// Spawns one virtual-user task that issues requests back-to-back.
///
/// Each iteration prepares a request from the shared configuration, then
/// races it against the VU's cancellation token. A completed request's result
/// is sent on `result_tx`; the task ends when the token is cancelled (the
/// in-flight request is abandoned and produces no result) or when the
/// receiving side of the channel is gone.
fn spawn_vu(params: VuParams) -> JoinHandle<()> {
    let VuParams {
        request_config: config,
        plain_headers: shared_headers,
        template: body_template,
        cancellation_token: stop,
        result_tx: sink,
    } = params;

    tokio::spawn(async move {
        loop {
            // Per-request inputs are prepared eagerly, before the race below.
            let generated = body_template.as_ref().map(|t| t.generate_one());
            let payload = config.resolve_body(generated);
            let http_client = config.client.clone();
            let target_url = config.host.as_str().to_string();
            let http_method = config.method;
            let wants_response_body = config.tracked_fields.is_some();
            let header_set = Arc::clone(&shared_headers);

            // The request itself is only built and sent once this future is
            // polled by the `select!` below.
            let in_flight = async {
                let req = Request::new(http_client, target_url, http_method);
                let req = match payload {
                    Some((content, content_type)) => req.body(content, content_type),
                    None => req,
                };
                let req = if wants_response_body {
                    req.read_response()
                } else {
                    req
                };
                let req = if header_set.is_empty() {
                    req
                } else {
                    req.headers((*header_set).clone())
                };
                req.execute().await
            };

            // `false` means this VU should stop: either it was cancelled or
            // nobody is listening for results any more.
            let keep_running = tokio::select! {
                _ = stop.cancelled() => false,
                outcome = in_flight => sink.send(outcome).is_ok(),
            };
            if !keep_running {
                break;
            }
        }
    })
}