use super::{BoxedFut, ScenarioConfig};
use crate::transaction::{TransactionData, TRANSACTION_HOOK};
#[cfg(feature = "rt")]
use crate::{runtime::BALTER_OUT, scenario::ScenarioKind};
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use humantime::format_duration;
use std::{
future::Future,
num::NonZeroU32,
sync::{atomic::Ordering, Arc},
time::{Duration, Instant},
};
#[cfg(feature = "rt")]
use tokio::runtime::Handle;
use tokio::task::JoinSet;
#[allow(unused_imports)]
use tracing::{debug, error, info, instrument, trace, Instrument};
#[instrument(name="goal_tps", skip_all, fields(name=config.name, goal_tps=goal_tps))]
pub(crate) async fn run_goal_tps(
scenario: fn() -> BoxedFut,
config: ScenarioConfig,
goal_tps: u32,
) {
info!(
"Running {} at {goal_tps}tps for {}",
config.name,
format_duration(config.duration)
);
let start = Instant::now();
let mut timer = Instant::now();
let mut task_learner = GoalTpsTaskLearner::new(goal_tps, &config, &start);
let mut transaction_data = TransactionData {
limiter: Some(task_learner.limiter.clone()),
success: Arc::new(0.into()),
error: Arc::new(0.into()),
};
let mut jobs = JoinSet::new();
jobs.spawn(spawn_scenario(scenario, transaction_data.clone()));
#[allow(clippy::redundant_pattern_matching)]
while let Some(_) = jobs.join_next().await {
let elapsed = timer.elapsed();
if elapsed > Duration::from_millis(1000) {
handle_statistics(&transaction_data, &mut task_learner, elapsed);
transaction_data.limiter = Some(task_learner.limiter.clone());
timer = Instant::now();
}
while jobs.len() < task_learner.task_count() && start.elapsed() < config.duration {
jobs.spawn(spawn_scenario(scenario, transaction_data.clone()));
}
}
debug!("Scenario complete.");
}
/// Drains the success/error counters accumulated since the previous call and feeds the
/// observed throughput over the `elapsed` window into `task_learner`.
///
/// `elapsed` is expected to be non-zero; the caller only invokes this after more than a second.
fn handle_statistics(
    transaction_data: &TransactionData,
    task_learner: &mut GoalTpsTaskLearner,
    elapsed: Duration,
) {
    // `swap(0)` returns the current count and resets it in one atomic step, so a transaction
    // recorded concurrently lands in either this window or the next one.
    let success_count = transaction_data.success.swap(0, Ordering::Relaxed);
    let error_count = transaction_data.error.swap(0, Ordering::Relaxed);
    let total_count = success_count + error_count;
    // Use the full-precision window length rather than truncated milliseconds.
    let actual_tps = total_count as f64 / elapsed.as_secs_f64();
    task_learner.push_statistic(actual_tps, total_count);
}
/// Builds the future for a single run of `scenario`, with `transaction_data` scoped to it via
/// `TRANSACTION_HOOK` and the caller's current tracing span attached.
///
/// Despite the name, nothing is spawned here; the caller hands the returned future to a `JoinSet`.
fn spawn_scenario(
    scenario: fn() -> BoxedFut,
    transaction_data: TransactionData,
) -> impl Future<Output = ()> + Send {
    // `scenario` is only invoked once this future is first polled, i.e. inside the hook's scope.
    let single_run = async move { scenario().await };
    let hooked_run = TRANSACTION_HOOK.scope(transaction_data, single_run);
    hooked_run.in_current_span()
}
/// Incrementally searches for the number of concurrent scenario tasks needed to reach a goal TPS.
// `config` and `start` are only read when the `rt` feature is enabled, hence the allowance.
#[allow(dead_code)]
struct GoalTpsTaskLearner<'a> {
    /// Target transactions per second; lowered to the measured rate if the goal proves unreachable.
    goal_tps: f64,
    /// Rate limiter handed to newly spawned scenario tasks.
    limiter: Arc<DefaultDirectRateLimiter>,
    /// Number of concurrent tasks currently requested.
    task_count: usize,
    /// Set once the learner has settled; statistics pushed afterwards are ignored.
    complete: bool,
    /// TPS samples collected for the current task count.
    samples: Vec<f64>,
    /// Transactions observed across the current `samples`.
    measurements: u64,
    /// Mean TPS recorded at each earlier task count (`previous[i]` was measured with `i + 1` tasks).
    previous: Vec<f64>,
    /// Scenario configuration, used as the template for scale-out requests.
    config: &'a ScenarioConfig,
    /// Instant the run began, used to compute the remaining duration when scaling out.
    start: &'a Instant,
}
impl<'a> GoalTpsTaskLearner<'a> {
    /// Creates a learner that starts with one task and a limiter allowing `goal_tps` per second.
    ///
    /// # Panics
    ///
    /// Panics if `goal_tps` is zero, because a per-second quota must be non-zero.
    fn new(goal_tps: u32, config: &'a ScenarioConfig, start: &'a Instant) -> Self {
        let quota = NonZeroU32::new(goal_tps).expect("goal TPS must be non-zero");
        Self {
            samples: vec![],
            measurements: 0,
            task_count: 1,
            previous: vec![],
            complete: false,
            limiter: Arc::new(RateLimiter::direct(Quota::per_second(quota))),
            goal_tps: goal_tps as f64,
            config,
            start,
        }
    }

    /// Records one throughput sample (`measured_tps`) covering `measurements` transactions.
    ///
    /// Once more than 10 transactions or more than 5 samples have accumulated for the current
    /// task count, their mean is compared against the goal:
    /// - within 5% of the goal (or above it): the current task count is kept and learning stops;
    /// - short of the goal, but an earlier task count measured a higher rate: the best such count
    ///   is adopted, the goal and limiter are lowered to the measured rate, the shortfall is
    ///   offered for scale-out (only with the `rt` feature), and learning stops;
    /// - otherwise: the mean is remembered and one more task is requested.
    ///
    /// Samples pushed after learning has stopped are ignored.
    fn push_statistic(&mut self, measured_tps: f64, measurements: u64) {
        if self.complete {
            return;
        }

        self.samples.push(measured_tps);
        self.measurements += measurements;
        trace!(
            "Push statistic: sample count={}, measurements={}",
            self.samples.len(),
            self.measurements
        );

        // Not enough data for the current task count yet.
        if self.measurements <= 10 && self.samples.len() <= 5 {
            return;
        }

        let mean_tps = mean(&self.samples);
        // Relative shortfall below the goal; overshooting counts as zero error.
        let error = ((self.goal_tps - mean_tps) / self.goal_tps).max(0.0);
        if error > 0.05 {
            // `previous[i]` was measured with `i + 1` tasks. Compare the f64 rates directly:
            // truncating them to integers would treat e.g. 85.9 and 85.1 as a tie.
            let best_count = self
                .previous
                .iter()
                .enumerate()
                .map(|(idx, x)| (idx + 1, x))
                .filter(|(_, x)| **x > mean_tps)
                .max_by(|(_, a), (_, b)| a.total_cmp(b));

            if let Some(best_count) = best_count {
                info!("Goal TPS exceeds capability of this server. Found best task count: {best_count:?}.");
                #[cfg(feature = "rt")]
                {
                    // Offer the unmet share of the goal to another node for the rest of the run.
                    let remaining_tps = (self.goal_tps - mean_tps).floor() as u32;
                    if remaining_tps == 0 {
                        // A zero-TPS request is meaningless and `new` rejects a zero goal.
                        debug!("Remaining TPS rounds down to zero; not scaling out.");
                    } else {
                        let mut new_config = self.config.clone();
                        // Statistics can still arrive after the configured duration has passed,
                        // and plain `Duration` subtraction panics on underflow.
                        new_config.duration =
                            self.config.duration.saturating_sub(self.start.elapsed());
                        match &mut new_config.kind {
                            ScenarioKind::Tps(tps) => *tps = remaining_tps,
                            _ => unreachable!(),
                        }
                        Handle::current().spawn(async move {
                            let (ref tx, _) = *BALTER_OUT;
                            let _ = tx.send(new_config).await;
                        });
                    }
                }
                #[cfg(not(feature = "rt"))]
                {
                    error!("No distributed runtime available to scale out.");
                }
                self.goal_tps = mean_tps;
                // Clamp to 1: a measured rate below 1 TPS would otherwise floor to a zero quota.
                let quota = NonZeroU32::new((self.goal_tps.floor() as u32).max(1))
                    .expect("quota is clamped to at least 1");
                self.limiter = Arc::new(RateLimiter::direct(Quota::per_second(quota)));
                self.task_count = best_count.0;
                self.complete = true;
            } else {
                self.previous.push(mean_tps);
                self.task_count += 1;
                debug!(
                    "Measured {mean_tps}, increasing task count to {}",
                    self.task_count
                );
            }
        } else {
            debug!("Found task count: {}", self.task_count);
            self.complete = true;
        }

        // Start a fresh measurement window for the (possibly new) task count.
        self.samples.clear();
        self.measurements = 0;
    }

    /// Number of concurrent scenario tasks the learner currently asks for.
    fn task_count(&self) -> usize {
        self.task_count
    }
}
/// Arithmetic mean of `samples`; an empty slice yields NaN (0.0 / 0.0).
fn mean(samples: &[f64]) -> f64 {
    let count = samples.len() as f64;
    samples.iter().sum::<f64>() / count
}
#[cfg(test)]
mod tests {
    use super::*;
    // The parent module only imports `ScenarioKind` when the `rt` feature is enabled; import it
    // directly so the feature-independent tests still compile without `rt`.
    use crate::scenario::ScenarioKind;

    /// Builds a 30-second TPS scenario config named `some_task` targeting `goal_tps`.
    fn tps_config(goal_tps: u32) -> ScenarioConfig {
        ScenarioConfig {
            name: "some_task".to_string(),
            duration: Duration::from_secs(30),
            kind: ScenarioKind::Tps(goal_tps),
        }
    }

    #[test]
    fn test_scaling_up() {
        let config = tps_config(100);
        let start = Instant::now();
        let mut task_learner = GoalTpsTaskLearner::new(100, &config, &start);
        assert_eq!(task_learner.task_count(), 1);
        task_learner.push_statistic(50., 700);
        assert_eq!(task_learner.task_count(), 2);
    }

    #[test]
    fn test_waits_for_enough_measurements() {
        let config = tps_config(100);
        let start = Instant::now();
        let mut task_learner = GoalTpsTaskLearner::new(100, &config, &start);
        // A single sample covering 5 transactions is not enough data to act on.
        task_learner.push_statistic(50., 5);
        assert_eq!(task_learner.task_count(), 1);
        // Crossing 10 accumulated transactions triggers a decision.
        task_learner.push_statistic(50., 6);
        assert_eq!(task_learner.task_count(), 2);
    }

    #[test]
    fn test_settles_when_goal_is_met() {
        let config = tps_config(100);
        let start = Instant::now();
        let mut task_learner = GoalTpsTaskLearner::new(100, &config, &start);
        // Within 5% of the goal: keep the current task count and stop learning.
        task_learner.push_statistic(97., 700);
        assert_eq!(task_learner.task_count(), 1);
        // Samples after completion are ignored, even far below the goal.
        task_learner.push_statistic(10., 700);
        assert_eq!(task_learner.task_count(), 1);
    }

    // Scaling out uses `Handle::current()` and `BALTER_OUT`, which are only available with `rt`.
    #[cfg(feature = "rt")]
    #[tokio::test]
    async fn test_dont_overload() {
        let config = tps_config(100);
        let start = Instant::now();
        let mut task_learner = GoalTpsTaskLearner::new(100, &config, &start);
        assert_eq!(task_learner.task_count(), 1);
        task_learner.push_statistic(50., 700);
        assert_eq!(task_learner.task_count(), 2);
        task_learner.push_statistic(75., 700);
        assert_eq!(task_learner.task_count(), 3);
        task_learner.push_statistic(80., 700);
        assert_eq!(task_learner.task_count(), 4);
        task_learner.push_statistic(85., 700);
        assert_eq!(task_learner.task_count(), 5);
        task_learner.push_statistic(73., 700);
        assert_eq!(task_learner.task_count(), 4);
        let (_, ref rx) = *BALTER_OUT;
        assert!(rx.recv().await.is_ok());
    }
}