use std::cmp::min;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
use colored::Colorize;
use indicatif::ProgressBar;
use num_cpus;
use reqwest::Client;
use tokio::{
self, runtime,
sync::{self as tsync, mpsc},
};
use crate::Arg;
use crate::client::build_client;
use crate::dispatcher::DurationDispatcher;
use crate::dispatcher::{CountDispatcher, Dispatcher};
use crate::limiter::Limiter;
use crate::output::{Output, sync_text_output};
use crate::request::build_request;
use crate::statistics::{Message, Statistics};
/// Shared state for a single benchmark run.
///
/// A `Task` is wrapped in an `Arc` and cloned into every spawned job
/// (workers, statistics consumer, progress updater, Ctrl-C handler).
pub struct Task {
    /// Run configuration supplied by the caller; it selects the dispatcher,
    /// the progress-bar mode and the number of workers.
    arg: Arg,
    /// HTTP client shared by all workers (built by `build_client`).
    client: Client,
    /// Collects the per-request messages sent by the workers.
    statistics: Statistics,
    /// Set by the Ctrl-C handler; read when finishing the progress bar.
    is_canceled: AtomicBool,
    /// Optional progress bar; `None` disables all progress reporting.
    progress_bar: Option<ProgressBar>,
    /// Set by `run` once every worker has been joined; stops the
    /// progress-bar updaters.
    is_workers_done: AtomicBool,
    /// Decides whether a worker may start another request. Workers take the
    /// read lock; cancellation takes the write lock.
    dispatcher: Arc<tsync::RwLock<Box<dyn Dispatcher<Limiter = Limiter>>>>,
}
/// Boxes a `CountDispatcher` built from `total` and `rate` as a
/// `Dispatcher` trait object.
///
/// Used by `create_dispatcher` for runs bounded by a request count.
fn create_count_dispatcher(
    total: u64,
    rate: &Option<u16>,
) -> Box<dyn Dispatcher<Limiter = Limiter>> {
    Box::new(CountDispatcher::new(total, rate))
}
/// Boxes a `DurationDispatcher` built from `duration` and `rate` as a
/// `Dispatcher` trait object.
///
/// Used by `create_dispatcher` for runs bounded by wall-clock time.
fn create_duration_dispatcher(
    duration: Duration,
    rate: &Option<u16>,
) -> Box<dyn Dispatcher<Limiter = Limiter>> {
    Box::new(DurationDispatcher::new(duration, rate))
}
/// Builds the shared, lock-protected dispatcher for the run described by
/// `arg`.
///
/// A request count (`arg.requests`) takes precedence. Otherwise the run is
/// bounded by `arg.duration`. Either dispatcher receives `arg.rate`.
///
/// # Panics
///
/// Panics if neither `arg.requests` nor `arg.duration` is set.
fn create_dispatcher(
    arg: &Arg,
) -> Arc<tsync::RwLock<Box<dyn Dispatcher<Limiter = Limiter>>>> {
    let dispatcher = match (arg.requests, arg.duration) {
        (Some(requests), _) => create_count_dispatcher(requests, &arg.rate),
        (None, Some(duration)) => {
            create_duration_dispatcher(duration, &arg.rate)
        }
        (None, None) => panic!(
            "invalid arguments: either a request count or a duration must be set"
        ),
    };
    Arc::new(tsync::RwLock::new(dispatcher))
}
impl Task {
    /// Creates a task for the run described by `arg`.
    ///
    /// Builds the HTTP client with `build_client` and the dispatcher with
    /// `create_dispatcher`. Pass `None` as `progress_bar` to disable progress
    /// reporting.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `build_client`.
    pub fn new(
        arg: Arg,
        progress_bar: Option<ProgressBar>,
    ) -> anyhow::Result<Self> {
        let client = build_client(&arg)?;
        let dispatcher = create_dispatcher(&arg);
        Ok(Self {
            arg,
            client,
            dispatcher,
            progress_bar,
            statistics: Statistics::new(),
            is_canceled: AtomicBool::new(false),
            is_workers_done: AtomicBool::new(false),
        })
    }

    /// Keeps the progress bar (if any) in sync with the run until all workers
    /// have finished.
    ///
    /// Count-bounded runs mirror `Statistics::get_total`. Duration-bounded
    /// runs show elapsed whole seconds.
    async fn update_progress_bar(self: Arc<Self>) {
        if self.progress_bar.is_none() {
            return;
        }
        if self.arg.requests.is_some() {
            self.update_count_progress_bar().await;
        } else if self.arg.duration.is_some() {
            self.update_duration_progress_bar().await;
        }
    }

    /// Every 10 ms, copies the statistics total onto the progress bar, capped
    /// at the requested count, until the workers are done.
    async fn update_count_progress_bar(self: Arc<Self>) {
        let (Some(progress_bar), Some(total)) =
            (&self.progress_bar, self.arg.requests)
        else {
            return;
        };
        loop {
            // Sample the flag *before* updating so the last position written
            // is taken after the workers are known to be done.
            let workers_done = self.is_workers_done.load(Ordering::Acquire);
            progress_bar.set_position(min(self.statistics.get_total(), total));
            if workers_done {
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }

    /// Shows the whole seconds elapsed since this updater started, capped at
    /// the configured duration, until the workers are done.
    async fn update_duration_progress_bar(self: Arc<Self>) {
        let (Some(progress_bar), Some(duration)) =
            (&self.progress_bar, self.arg.duration)
        else {
            return;
        };
        let total = duration.as_secs();
        let started_at = Instant::now();
        loop {
            let workers_done = self.is_workers_done.load(Ordering::Acquire);
            // Derived from the clock rather than a tick counter: the bar
            // reads 0 at start, reaches `total` only once `total` seconds
            // have really passed, and does not drift with sleep jitter.
            progress_bar
                .set_position(min(started_at.elapsed().as_secs(), total));
            if workers_done {
                break;
            }
            // Poll well below one second so this job ends promptly once the
            // workers are done, instead of sleeping up to a full second.
            tokio::time::sleep(Duration::from_millis(100)).await;
        }
    }

    /// Finishes the progress bar if it is still active.
    ///
    /// A canceled run abandons the bar with a red, bold "(CANCELED!!!)"
    /// message. Otherwise the bar is finished normally.
    fn finish_progress_bar(self: Arc<Self>) {
        if let Some(progress_bar) = &self.progress_bar
            && !progress_bar.is_finished()
        {
            if self.is_canceled.load(Ordering::Acquire) {
                progress_bar.abandon_with_message(
                    "(canceled!!!)".to_uppercase().red().bold().to_string(),
                );
            } else {
                progress_bar.finish();
            }
        }
    }

    /// Issues requests for as long as the dispatcher grants jobs.
    ///
    /// For each granted job it builds and executes a request and tells the
    /// dispatcher the job is complete. It then sends the timed outcome to the
    /// statistics consumer over `sender`.
    ///
    /// # Errors
    ///
    /// Returns an error if building a request fails, or if the receiving end
    /// of `sender` has been closed.
    async fn worker(
        self: Arc<Self>,
        sender: mpsc::Sender<Message>,
    ) -> anyhow::Result<()> {
        // The read guard is a temporary of the loop condition, so it is
        // released before the body runs.
        while self.dispatcher.read().await.try_apply_job().await {
            let request = build_request(&self.arg, &self.client).await?;
            let req_at = Instant::now();
            let response = self.client.execute(request).await;
            self.dispatcher.read().await.complete_job();
            let message = Message::new(response, req_at, Instant::now());
            sender.send(message).await?;
        }
        Ok(())
    }

    /// Renders the collected statistics as text via `sync_text_output`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `sync_text_output`.
    pub fn text_output(self: Arc<Self>) -> anyhow::Result<String> {
        sync_text_output(&self.statistics, &self.arg)
    }

    /// Builds a structured `Output` from the collected statistics.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Output::sync_from_statistics`.
    pub fn json_output(self: Arc<Self>) -> anyhow::Result<Output> {
        Output::sync_from_statistics(&self.statistics)
    }

    /// Feeds every worker message into the statistics.
    ///
    /// `recv` yields `None` only after every sender has been dropped and the
    /// buffer is empty. Every message a worker managed to send is therefore
    /// handled before this returns; polling a "done" flag would leave a
    /// window where the final messages are skipped. `run` drops its own
    /// sender so the channel closes when the last worker exits.
    async fn rcv_worker_message(
        self: Arc<Self>,
        mut receiver: mpsc::Receiver<Message>,
    ) {
        while let Some(message) = receiver.recv().await {
            self.statistics.handle_message(message).await;
        }
    }

    /// Cancels the run on every Ctrl-C.
    ///
    /// Loops until listening for the signal fails. It never returns `Ok`.
    async fn handle_ctrl_c_signal(self: Arc<Self>) -> anyhow::Result<()> {
        loop {
            tokio::signal::ctrl_c().await?;
            // Record the cancellation before stopping the dispatcher, so the
            // flag is already set when the workers wind down and the progress
            // bar is finished.
            self.is_canceled.store(true, Ordering::SeqCst);
            self.dispatcher.write().await.cancel();
        }
    }

    /// Runs the benchmark to completion on a dedicated multi-threaded Tokio
    /// runtime.
    ///
    /// Spawns these jobs:
    /// - the Ctrl-C handler,
    /// - the progress updater,
    /// - the per-second statistics timer,
    /// - the statistics consumer,
    /// - `arg.connections` workers.
    ///
    /// It then joins them, finishes the progress bar and runs
    /// `Statistics::summary`.
    ///
    /// # Errors
    ///
    /// Returns an error if the runtime cannot be built, a spawned job panics
    /// or is cancelled, or any worker returns an error.
    pub fn run(self: Arc<Self>) -> anyhow::Result<Arc<Self>> {
        let rt = runtime::Builder::new_multi_thread()
            .worker_threads(num_cpus::get())
            .thread_name("rsb-tokio-runtime-worker")
            .enable_all()
            .build()?;
        rt.block_on(async {
            let (tx, rx) = mpsc::channel::<Message>(500);
            let task = self.clone();
            // The `async move` block is kept deliberately: it moves `task`
            // into the spawned future, which must be `'static`.
            #[allow(clippy::redundant_async_block)]
            tokio::spawn(
                async move { task.statistics.reset_start_time().await },
            )
            .await?;
            tokio::spawn(self.clone().handle_ctrl_c_signal());
            let update_pb_job =
                tokio::spawn(self.clone().update_progress_bar());
            let task = self.clone();
            let stat_timer = tokio::spawn(async move {
                task.statistics.timer_per_second().await;
            });
            let jobs: Vec<_> = (0..self.arg.connections)
                .map(|_| tokio::spawn(self.clone().worker(tx.clone())))
                .collect();
            // Release the original sender: once every worker has exited, the
            // channel closes and the consumer drains what is left and returns.
            drop(tx);
            let statistics_job =
                tokio::spawn(self.clone().rcv_worker_message(rx));
            for worker in jobs {
                worker.await??;
            }
            self.is_workers_done.store(true, Ordering::SeqCst);
            let task = self.clone();
            // See above: the block moves `task` into the spawned future.
            #[allow(clippy::redundant_async_block)]
            tokio::spawn(async move { task.statistics.stop_timer().await })
                .await?;
            statistics_job.await?;
            update_pb_job.await?;
            stat_timer.await?;
            self.clone().finish_progress_bar();
            let task = self.clone();
            tokio::spawn(async move {
                task.statistics
                    .summary(task.arg.connections, task.arg.percentiles.clone())
                    .await;
            })
            .await?;
            Ok::<(), anyhow::Error>(())
        })?;
        Ok(self)
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::arg::{Method, OutputFormat};

    /// Builds an `Arg` aimed at example.com with the given run bounds.
    /// Every other option is left at a neutral value.
    fn arg_with_bounds(
        requests: Option<u64>,
        duration: Option<Duration>,
    ) -> Arg {
        Arg {
            url: Some("http://example.com".to_string()),
            requests,
            duration,
            connections: 1,
            timeout: Duration::from_secs(30),
            latencies: false,
            percentiles: vec![],
            method: Method::Get,
            disable_keep_alive: false,
            headers: vec![],
            rate: None,
            cert: None,
            key: None,
            insecure: false,
            text_file: None,
            text_body: None,
            json_file: None,
            json_body: None,
            json_command: None,
            form: vec![],
            mp: vec![],
            mp_file: vec![],
            output_format: OutputFormat::Text,
            completions: None,
        }
    }

    #[test]
    fn test_create_count_dispatcher() {
        let _dispatcher = create_count_dispatcher(100, &Some(10));
    }

    #[test]
    fn test_create_duration_dispatcher() {
        let duration = Duration::from_secs(60);
        let _dispatcher = create_duration_dispatcher(duration, &Some(10));
    }

    #[test]
    fn test_create_dispatcher_with_requests() {
        let arg = arg_with_bounds(Some(100), None);
        let _dispatcher = create_dispatcher(&arg);
    }

    #[test]
    fn test_create_dispatcher_with_duration() {
        let arg = arg_with_bounds(None, Some(Duration::from_secs(60)));
        let _dispatcher = create_dispatcher(&arg);
    }

    #[test]
    fn test_create_dispatcher_with_both_bounds() {
        let arg =
            arg_with_bounds(Some(100), Some(Duration::from_secs(60)));
        let _dispatcher = create_dispatcher(&arg);
    }

    #[test]
    #[should_panic]
    fn test_create_dispatcher_without_bounds_panics() {
        let arg = arg_with_bounds(None, None);
        let _dispatcher = create_dispatcher(&arg);
    }
}