use std::collections::BTreeMap;
use std::env;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::time::{SystemTime, UNIX_EPOCH};
use anyhow::Error;
use futures::future::join_all;
use handlebars::Handlebars;
use histogram::Histogram;
use reqwest::header::{HeaderMap, HeaderValue, USER_AGENT};
use reqwest::Client;
use serde_json::{json, Value};
use tokio::sync::{mpsc, oneshot, Mutex};
use tokio::task::JoinHandle;
use url::Url;
use crate::core::check_endpoints_names::check_endpoints_names;
use crate::core::concurrency_controller::ConcurrencyController;
use crate::core::fixed_size_queue;
use crate::core::sleep_guard::SleepGuard;
use crate::core::{listening_assert, setup, share_result, start_task};
use crate::models::api_endpoint::ApiEndpoint;
use crate::models::assert_error_stats::AssertErrorStats;
use crate::models::http_error_stats::HttpErrorStats;
use crate::models::result::{ApiResult, BatchResult};
use crate::models::setup::SetupApiEndpoint;
use crate::models::step_option::{InnerStepOption, StepOption};
/// Runs a concurrent HTTP load test ("batch") against the supplied endpoints
/// and returns the aggregated result once all workers have finished.
///
/// Per endpoint, a group of worker tasks is spawned whose size is proportional
/// to the endpoint's weight. A background collector streams intermediate
/// results through `result_sender`, and an assertion-consumer task is spawned
/// only when at least one endpoint declares assertions.
///
/// # Arguments
/// * `result_sender` - channel used to publish periodic intermediate results.
/// * `test_duration_secs` - total test duration in seconds.
/// * `concurrent_requests` - target total concurrency across all endpoints.
/// * `timeout_secs` - per-request timeout; `0` disables the client timeout.
/// * `cookie_store_enable` - enables the reqwest cookie store.
/// * `verbose` - prints progress diagnostics when true.
/// * `should_prevent` - keeps the OS awake for the test duration (via guard).
/// * `api_endpoints` - endpoints under test; weights drive the concurrency split.
/// * `step_option` - optional stepped ramp-up configuration.
/// * `setup_options` - optional one-shot setup requests whose extracted values
///   become template variables for endpoint URLs.
/// * `assert_channel_buffer_size` - assertion-channel buffer; `0` falls back to 1024.
/// * `ema_alpha` - smoothing factor forwarded to the result collector.
///
/// # Errors
/// Returns an error when a stepped ramp-up cannot reach the target concurrency
/// in time, endpoint names fail validation, the HTTP client or histograms
/// cannot be built, or percentile extraction from the histogram fails.
pub async fn batch(
    result_sender: mpsc::Sender<Option<BatchResult>>,
    test_duration_secs: u64,
    concurrent_requests: usize,
    timeout_secs: u64,
    cookie_store_enable: bool,
    verbose: bool,
    should_prevent: bool,
    api_endpoints: Vec<ApiEndpoint>,
    step_option: Option<StepOption>,
    setup_options: Option<Vec<SetupApiEndpoint>>,
    mut assert_channel_buffer_size: usize,
    ema_alpha: f64,
) -> anyhow::Result<BatchResult> {
    // Keep the machine awake while the test runs; released on drop.
    let _guard = SleepGuard::new(should_prevent);
    // With stepped ramp-up, verify the configured steps can actually reach the
    // target concurrency within the window (sum of an arithmetic series:
    // step * (1 + 2 + ... + total_steps)).
    if let Some(step_option) = step_option.clone() {
        let total_steps = test_duration_secs / step_option.increase_interval;
        let total_concurrency_increase =
            step_option.increase_step as u64 * total_steps * (total_steps + 1) / 2;
        if total_concurrency_increase < concurrent_requests as u64 {
            return Err(Error::msg(
                "阶梯加压总并发数在设置的时间内无法增加到预设的结束并发数",
            ));
        }
    };
    // Reject invalid/duplicate endpoint names up front.
    if let Err(e) = check_endpoints_names(api_endpoints.clone()) {
        return Err(Error::msg(e));
    }
    // Global latency histogram shared by all workers.
    let histogram = match Histogram::new(14, 20) {
        Ok(h) => Arc::new(Mutex::new(h)),
        Err(e) => {
            return Err(Error::msg(format!("获取存储桶失败::{:?}", e.to_string())));
        }
    };
    // Shared counters/statistics updated by the worker tasks.
    let successful_requests = Arc::new(AtomicUsize::new(0));
    let total_requests = Arc::new(AtomicUsize::new(0));
    let max_response_time = Arc::new(Mutex::new(0u64));
    let min_response_time = Arc::new(Mutex::new(u64::MAX));
    let err_count = Arc::new(AtomicUsize::new(0));
    let number_of_last_errors = Arc::new(AtomicUsize::new(0));
    let dura = Arc::new(Mutex::new(0f64));
    let number_of_last_requests = Arc::new(AtomicUsize::new(0));
    // Sliding RPS window size: fixed 10 samples without stepping; with
    // stepping, the seconds remaining after full concurrency is reached.
    let queue_cap = match step_option.clone() {
        None => 10usize,
        Some(step_option) => {
            let steps_to_max_concurrency = concurrent_requests / step_option.increase_step;
            let time_to_max_concurrency =
                steps_to_max_concurrency as u64 * step_option.increase_interval;
            let remaining_time = test_duration_secs.saturating_sub(time_to_max_concurrency);
            remaining_time as usize
        }
    };
    let rps_queue_arc = Arc::new(Mutex::new(fixed_size_queue::FixedSizeQueue::new(queue_cap)));
    // Per-endpoint RPS windows keyed by endpoint name (filled by the collector).
    let api_rps_queue_arc: Arc<Mutex<BTreeMap<String, fixed_size_queue::FixedSizeQueue<f64>>>> =
        Arc::new(Mutex::new(BTreeMap::new()));
    let concurrent_number = Arc::new(AtomicUsize::new(0));
    let mut handles: Vec<JoinHandle<Result<(), Error>>> = Vec::new();
    let total_response_size = Arc::new(AtomicUsize::new(0));
    let http_errors = Arc::new(Mutex::new(HttpErrorStats::new()));
    let assert_errors = Arc::new(Mutex::new(AssertErrorStats::new()));
    // Weights decide how the total concurrency is split across endpoints.
    let total_weight: u32 = api_endpoints.iter().map(|e| e.weight).sum();
    let (should_stop_tx, should_stop_rx) = oneshot::channel();
    // `usize` cannot be negative, so only the zero case needs a fallback
    // (the original `<= 0` comparison was vacuous for the negative half).
    if assert_channel_buffer_size == 0 {
        assert_channel_buffer_size = 1024
    }
    let (tx_assert, rx_assert) = mpsc::channel(assert_channel_buffer_size);
    // Spawn the assertion consumer only when some endpoint declares assertions.
    // Borrowing `iter()` avoids cloning the whole endpoint list just for this check.
    if api_endpoints.iter().any(|item| item.assert_options.is_some()) {
        if verbose {
            println!("开启断言消费任务");
        };
        tokio::spawn(listening_assert::listening_assert(rx_assert));
    };
    let api_endpoints_arc: Vec<Arc<Mutex<ApiEndpoint>>> = api_endpoints
        .into_iter()
        .map(|endpoint| Arc::new(Mutex::new(endpoint)))
        .collect();
    let test_start = Instant::now();
    let test_end = test_start + Duration::from_secs(test_duration_secs);
    let results: Vec<ApiResult> = Vec::new();
    let results_arc = Arc::new(Mutex::new(results));
    // Build a User-Agent of the form "<name> <version> (<os>; <os-version>)".
    let info = os_info::get();
    let os_type = info.os_type();
    let os_version = info.version().to_string();
    let app_name = env!("CARGO_PKG_NAME");
    let app_version = env!("CARGO_PKG_VERSION");
    let user_agent_value =
        match format!("{} {} ({}; {})", app_name, app_version, os_type, os_version)
            .parse::<HeaderValue>()
        {
            Ok(v) => v,
            Err(e) => {
                return Err(Error::msg(format!(
                    "解析user agent失败::{:?}",
                    e.to_string()
                )));
            }
        };
    // URL templating is enabled only when setup steps provide variables.
    let mut is_need_render_template = false;
    let mut extract_map: BTreeMap<String, Value> = BTreeMap::new();
    let builder = Client::builder()
        .cookie_store(cookie_store_enable)
        .default_headers({
            let mut headers = HeaderMap::new();
            headers.insert(USER_AGENT, user_agent_value);
            headers
        });
    // A timeout of 0 means "no client-side timeout".
    let client = match timeout_secs > 0 {
        true => match builder.timeout(Duration::from_secs(timeout_secs)).build() {
            Ok(cli) => cli,
            Err(e) => return Err(Error::msg(format!("构建含有超时的http客户端失败: {:?}", e))),
        },
        false => match builder.build() {
            Ok(cli) => cli,
            Err(e) => return Err(Error::msg(format!("构建http客户端失败: {:?}", e))),
        },
    };
    // Run one-shot setup requests and merge their extracted template variables.
    if let Some(setup_options) = setup_options {
        is_need_render_template = true;
        match setup::start_setup(setup_options, extract_map.clone(), client.clone()).await {
            Ok(res) => {
                if let Some(extract) = res {
                    extract_map.extend(extract);
                };
            }
            Err(e) => return Err(Error::msg(format!("全局初始化失败: {:?}", e))),
        };
    };
    let extract_map_arc = Arc::new(Mutex::new(extract_map));
    // Spawn worker groups per endpoint, sized proportionally to its weight.
    // `api_endpoints_arc` is not used after this loop, so it is consumed
    // directly instead of being cloned first.
    for (index, endpoint_arc) in api_endpoints_arc.into_iter().enumerate() {
        let endpoint = endpoint_arc.lock().await;
        let weight = endpoint.weight; // Copy type; no clone needed.
        let name = endpoint.name.clone();
        // Render the URL template when setup variables exist; on render
        // failure, log the error and fall back to the raw URL.
        let api_url = match is_need_render_template {
            true => {
                let handlebars = Handlebars::new();
                match handlebars.render_template(
                    &*endpoint.url.clone(),
                    &json!(*extract_map_arc.lock().await),
                ) {
                    Ok(c) => c,
                    Err(e) => {
                        eprintln!("{:?}", e);
                        endpoint.url.clone()
                    }
                }
            }
            false => endpoint.url.clone(),
        };
        // Release the endpoint lock before anything else locks it again below.
        drop(endpoint);
        let weight_ratio = weight as f64 / total_weight as f64;
        // Every endpoint gets at least one worker, even with a tiny weight.
        let mut concurrency_for_endpoint =
            ((concurrent_requests as f64) * weight_ratio).round() as usize;
        if concurrency_for_endpoint == 0 {
            concurrency_for_endpoint = 1
        }
        // Per-endpoint statistics mirroring the global counters above.
        let api_histogram = match Histogram::new(14, 20) {
            Ok(h) => Arc::new(Mutex::new(h)),
            Err(e) => return Err(Error::msg(format!("获取存储桶失败::{:?}", e.to_string()))),
        };
        let api_successful_requests = Arc::new(AtomicUsize::new(0));
        let api_total_requests = Arc::new(AtomicUsize::new(0));
        let api_max_response_time = Arc::new(Mutex::new(0u64));
        let api_min_response_time = Arc::new(Mutex::new(u64::MAX));
        let api_err_count = Arc::new(AtomicUsize::new(0));
        let api_concurrent_number = Arc::new(AtomicUsize::new(0));
        let api_total_response_size = Arc::new(AtomicUsize::new(0));
        let mut init_api_res = ApiResult::new();
        init_api_res.name = name.clone();
        init_api_res.url = api_url.clone();
        init_api_res.method = endpoint_arc.lock().await.method.clone().to_uppercase();
        let api_result = Arc::new(Mutex::new(init_api_res.clone()));
        results_arc.lock().await.push(init_api_res);
        // The controller either grants all permits at once (no step option)
        // or ramps them up in weight-scaled steps.
        let controller = match step_option.clone() {
            None => Arc::new(ConcurrencyController::new(concurrency_for_endpoint, None)),
            Some(option) => {
                let step = option.increase_step as f64 * weight_ratio;
                Arc::new(ConcurrencyController::new(
                    concurrency_for_endpoint,
                    Option::from(InnerStepOption {
                        increase_step: step,
                        increase_interval: option.increase_interval,
                    }),
                ))
            }
        };
        tokio::spawn({
            let controller_clone = Arc::clone(&controller);
            async move {
                controller_clone.distribute_permits().await;
            }
        });
        // Persist the rendered URL so the workers see its final form.
        endpoint_arc.lock().await.url = api_url.clone();
        for _ in 0..concurrency_for_endpoint {
            let handle: JoinHandle<Result<(), Error>> =
                tokio::spawn(start_task::start_concurrency(
                    client.clone(),
                    Arc::clone(&controller),
                    Arc::clone(&api_concurrent_number),
                    Arc::clone(&concurrent_number),
                    Arc::clone(&extract_map_arc),
                    Arc::clone(&endpoint_arc),
                    Arc::clone(&total_requests),
                    Arc::clone(&api_total_requests),
                    Arc::clone(&histogram),
                    Arc::clone(&api_histogram),
                    Arc::clone(&max_response_time),
                    Arc::clone(&api_max_response_time),
                    Arc::clone(&min_response_time),
                    Arc::clone(&api_min_response_time),
                    Arc::clone(&total_response_size),
                    Arc::clone(&api_total_response_size),
                    Arc::clone(&api_err_count),
                    Arc::clone(&successful_requests),
                    Arc::clone(&err_count),
                    Arc::clone(&http_errors),
                    Arc::clone(&assert_errors),
                    Arc::clone(&api_successful_requests),
                    Arc::clone(&api_result),
                    Arc::clone(&results_arc),
                    tx_assert.clone(),
                    test_start,
                    test_end,
                    is_need_render_template,
                    verbose,
                    index,
                ));
            handles.push(handle);
        }
    }
    // Background collector that periodically publishes intermediate results.
    tokio::spawn(share_result::collect_results(
        result_sender,
        should_stop_rx,
        Arc::clone(&total_requests),
        Arc::clone(&successful_requests),
        Arc::clone(&histogram),
        Arc::clone(&total_response_size),
        Arc::clone(&http_errors),
        Arc::clone(&err_count),
        Arc::clone(&max_response_time),
        Arc::clone(&min_response_time),
        Arc::clone(&assert_errors),
        Arc::clone(&results_arc),
        Arc::clone(&concurrent_number),
        Arc::clone(&dura),
        Arc::clone(&number_of_last_requests),
        Arc::clone(&number_of_last_errors),
        Arc::clone(&rps_queue_arc),
        Arc::clone(&api_rps_queue_arc),
        queue_cap,
        verbose,
        test_start,
        ema_alpha,
    ));
    // Wait for every worker; individual failures are logged, not propagated.
    let task_results = join_all(handles).await;
    for task_result in task_results {
        match task_result {
            Ok(res) => {
                match res {
                    Ok(_) => {
                        if verbose {
                            println!("任务完成")
                        }
                    }
                    Err(e) => {
                        eprintln!("异步任务内部错误::{:?}", e)
                    }
                };
            }
            Err(err) => {
                eprintln!("协程被取消或意外停止::{:?}", err);
            }
        };
    }
    // ---- Final aggregation ----
    let err_count_clone = Arc::clone(&err_count);
    let err_count = err_count_clone.load(Ordering::SeqCst);
    let total_duration = (Instant::now() - test_start).as_secs_f64();
    let total_requests = total_requests.load(Ordering::SeqCst) as u64;
    let successful_requests = successful_requests.load(Ordering::SeqCst) as f64;
    let success_rate = successful_requests / total_requests as f64 * 100.0;
    let histogram = histogram.lock().await;
    let total_response_size_kb = total_response_size.load(Ordering::SeqCst) as f64 / 1024.0;
    // NOTE(review): throughput divides by the configured duration rather than
    // the measured `total_duration` — confirm this is intended.
    let throughput_kb_s = total_response_size_kb / test_duration_secs as f64;
    let http_errors = http_errors.lock().await.errors.clone();
    let assert_errors = assert_errors.lock().await.errors.clone();
    let timestamp = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(n) => n.as_millis(),
        Err(_) => 0,
    };
    let mut api_results = results_arc.lock().await;
    // Finalize per-endpoint RPS (requests over measured duration) and derive
    // host/path from the rendered URL. An earlier pass that copied windowed
    // RPS averages here was removed as dead code: every value it wrote was
    // unconditionally overwritten by this loop.
    for (index, res) in api_results.clone().into_iter().enumerate() {
        let rps = res.total_requests as f64 / total_duration;
        api_results[index].rps = rps;
        if let Ok(url) = Url::parse(&*res.url) {
            if let Some(host) = url.host() {
                api_results[index].host = host.to_string();
            };
            api_results[index].path = url.path().to_string();
        };
    }
    let error_rate = err_count as f64 / total_requests as f64 * 100.0;
    let total_concurrent_number_clone = concurrent_number.load(Ordering::SeqCst) as i32;
    let errors_per_second = err_count - number_of_last_errors.load(Ordering::SeqCst);
    number_of_last_errors.fetch_add(errors_per_second, Ordering::Relaxed);
    // Overall RPS: average of the sliding window maintained by the collector.
    let rps = rps_queue_arc.lock().await.average().await.unwrap_or(0f64);
    number_of_last_requests.fetch_add(rps as usize, Ordering::Relaxed);
    let result = Ok(BatchResult {
        total_duration,
        success_rate,
        error_rate,
        median_response_time: match histogram.percentile(50.0) {
            Ok(b) => *b.range().start(),
            Err(e) => {
                return Err(Error::msg(format!("获取50线失败::{:?}", e.to_string())));
            }
        },
        response_time_95: match histogram.percentile(95.0) {
            Ok(b) => *b.range().start(),
            Err(e) => {
                return Err(Error::msg(format!("获取95线失败::{:?}", e.to_string())));
            }
        },
        response_time_99: match histogram.percentile(99.0) {
            Ok(b) => *b.range().start(),
            Err(e) => {
                return Err(Error::msg(format!("获取99线失败::{:?}", e.to_string())));
            }
        },
        total_requests,
        rps,
        max_response_time: *max_response_time.lock().await,
        min_response_time: *min_response_time.lock().await,
        err_count: err_count_clone.load(Ordering::SeqCst) as i32,
        total_data_kb: total_response_size_kb,
        throughput_per_second_kb: throughput_kb_s,
        http_errors: http_errors.lock().await.clone(),
        timestamp,
        assert_errors: assert_errors.lock().await.clone(),
        total_concurrent_number: total_concurrent_number_clone,
        api_results: api_results.to_vec(), // single clone; `.clone()` after `.to_vec()` was redundant
        errors_per_second,
    });
    // Best effort: the collector may already have exited and dropped the
    // receiver; a failed send must not panic after the result is computed.
    let _ = should_stop_tx.send(());
    eprintln!("测试完成!");
    result
}