1use std::collections::BTreeMap;
2use std::env;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use anyhow::Error;
9use futures::future::join_all;
10use handlebars::Handlebars;
11use histogram::Histogram;
12use reqwest::header::{HeaderMap, HeaderValue, USER_AGENT};
13use reqwest::Client;
14use serde_json::{json, Value};
15use tokio::sync::{mpsc, oneshot, Mutex};
16use tokio::task::JoinHandle;
17use url::Url;
18
19use crate::core::check_endpoints_names::check_endpoints_names;
20use crate::core::concurrency_controller::ConcurrencyController;
21use crate::core::fixed_size_queue;
22use crate::core::sleep_guard::SleepGuard;
23use crate::core::{listening_assert, setup, share_result, start_task};
24use crate::models::api_endpoint::ApiEndpoint;
25use crate::models::assert_error_stats::AssertErrorStats;
26use crate::models::http_error_stats::HttpErrorStats;
27use crate::models::result::{ApiResult, BatchResult};
28use crate::models::setup::SetupApiEndpoint;
29use crate::models::step_option::{InnerStepOption, StepOption};
30
31pub async fn batch(
32 result_sender: mpsc::Sender<Option<BatchResult>>,
33 test_duration_secs: u64,
34 concurrent_requests: usize,
35 timeout_secs: u64,
36 cookie_store_enable: bool,
37 verbose: bool,
38 should_prevent: bool,
39 api_endpoints: Vec<ApiEndpoint>,
40 step_option: Option<StepOption>,
41 setup_options: Option<Vec<SetupApiEndpoint>>,
42 mut assert_channel_buffer_size: usize,
43 ema_alpha: f64,
44) -> anyhow::Result<BatchResult> {
45 let _guard = SleepGuard::new(should_prevent);
47 if let Some(step_option) = step_option.clone() {
49 let total_steps = test_duration_secs / step_option.increase_interval;
51 let total_concurrency_increase =
53 step_option.increase_step as u64 * total_steps * (total_steps + 1) / 2;
54 if total_concurrency_increase < concurrent_requests as u64 {
55 return Err(Error::msg(
56 "阶梯加压总并发数在设置的时间内无法增加到预设的结束并发数",
57 ));
58 }
59 };
60 if let Err(e) = check_endpoints_names(api_endpoints.clone()) {
62 return Err(Error::msg(e));
63 }
64 let histogram = match Histogram::new(14, 20) {
66 Ok(h) => Arc::new(Mutex::new(h)),
67 Err(e) => {
68 return Err(Error::msg(format!("获取存储桶失败::{:?}", e.to_string())));
69 }
70 };
71 let successful_requests = Arc::new(AtomicUsize::new(0));
73 let total_requests = Arc::new(AtomicUsize::new(0));
75 let max_response_time = Arc::new(Mutex::new(0u64));
77 let min_response_time = Arc::new(Mutex::new(u64::MAX));
79 let err_count = Arc::new(AtomicUsize::new(0));
81 let number_of_last_errors = Arc::new(AtomicUsize::new(0));
83 let dura = Arc::new(Mutex::new(0f64));
84 let number_of_last_requests = Arc::new(AtomicUsize::new(0));
86 let queue_cap = match step_option.clone() {
89 None => 10usize,
91 Some(step_option) => {
93 let steps_to_max_concurrency = concurrent_requests / step_option.increase_step;
95 let time_to_max_concurrency =
96 steps_to_max_concurrency as u64 * step_option.increase_interval;
97 let remaining_time = test_duration_secs.saturating_sub(time_to_max_concurrency);
99 remaining_time as usize
100 }
101 };
102 let rps_queue_arc = Arc::new(Mutex::new(fixed_size_queue::FixedSizeQueue::new(queue_cap)));
103 let api_rps_queue_arc: Arc<Mutex<BTreeMap<String, fixed_size_queue::FixedSizeQueue<f64>>>> =
105 Arc::new(Mutex::new(BTreeMap::new()));
106 let concurrent_number = Arc::new(AtomicUsize::new(0));
108 let mut handles: Vec<JoinHandle<Result<(), Error>>> = Vec::new();
110 let total_response_size = Arc::new(AtomicUsize::new(0));
112 let http_errors = Arc::new(Mutex::new(HttpErrorStats::new()));
114 let assert_errors = Arc::new(Mutex::new(AssertErrorStats::new()));
116 let total_weight: u32 = api_endpoints.iter().map(|e| e.weight).sum();
118 let (should_stop_tx, should_stop_rx) = oneshot::channel();
120 if assert_channel_buffer_size <= 0 {
122 assert_channel_buffer_size = 1024
123 }
124 let (tx_assert, rx_assert) = mpsc::channel(assert_channel_buffer_size);
125 if api_endpoints
127 .clone()
128 .into_iter()
129 .any(|item| item.assert_options.is_some())
130 {
131 if verbose {
132 println!("开启断言消费任务");
133 };
134 tokio::spawn(listening_assert::listening_assert(rx_assert));
135 };
136 let api_endpoints_arc: Vec<Arc<Mutex<ApiEndpoint>>> = api_endpoints
138 .into_iter()
139 .map(|endpoint| Arc::new(Mutex::new(endpoint)))
140 .collect();
141 let test_start = Instant::now();
143 let test_end = test_start + Duration::from_secs(test_duration_secs);
145 let results: Vec<ApiResult> = Vec::new();
147 let results_arc = Arc::new(Mutex::new(results));
148 let info = os_info::get();
150 let os_type = info.os_type();
151 let os_version = info.version().to_string();
152 let app_name = env!("CARGO_PKG_NAME");
153 let app_version = env!("CARGO_PKG_VERSION");
154 let user_agent_value =
155 match format!("{} {} ({}; {})", app_name, app_version, os_type, os_version)
156 .parse::<HeaderValue>()
157 {
158 Ok(v) => v,
159 Err(e) => {
160 return Err(Error::msg(format!(
161 "解析user agent失败::{:?}",
162 e.to_string()
163 )));
164 }
165 };
166 let mut is_need_render_template = false;
167 let mut extract_map: BTreeMap<String, Value> = BTreeMap::new();
169 let builder = Client::builder()
171 .cookie_store(cookie_store_enable)
172 .default_headers({
173 let mut headers = HeaderMap::new();
174 headers.insert(USER_AGENT, user_agent_value);
175 headers
176 });
177 let client = match timeout_secs > 0 {
178 true => match builder.timeout(Duration::from_secs(timeout_secs)).build() {
179 Ok(cli) => cli,
180 Err(e) => return Err(Error::msg(format!("构建含有超时的http客户端失败: {:?}", e))),
181 },
182 false => match builder.build() {
183 Ok(cli) => cli,
184 Err(e) => return Err(Error::msg(format!("构建http客户端失败: {:?}", e))),
185 },
186 };
187 if let Some(setup_options) = setup_options {
189 is_need_render_template = true;
190 match setup::start_setup(setup_options, extract_map.clone(), client.clone()).await {
191 Ok(res) => {
192 if let Some(extract) = res {
193 extract_map.extend(extract);
194 };
195 }
196 Err(e) => return Err(Error::msg(format!("全局初始化失败: {:?}", e))),
197 };
198 };
199 let extract_map_arc = Arc::new(Mutex::new(extract_map));
202 for (index, endpoint_arc) in api_endpoints_arc.clone().into_iter().enumerate() {
204 let endpoint = endpoint_arc.lock().await;
205 let weight = endpoint.weight.clone();
206 let name = endpoint.name.clone();
207 let api_url = match is_need_render_template {
208 true => {
209 let handlebars = Handlebars::new();
211 match handlebars.render_template(
212 &*endpoint.url.clone(),
213 &json!(*extract_map_arc.lock().await),
214 ) {
215 Ok(c) => c,
216 Err(e) => {
217 eprintln!("{:?}", e);
218 endpoint.url.clone()
219 }
220 }
221 }
222 false => endpoint.url.clone(),
223 };
224 drop(endpoint);
225 let weight_ratio = weight as f64 / total_weight as f64;
227 let mut concurrency_for_endpoint =
229 ((concurrent_requests as f64) * weight_ratio).round() as usize;
230 if concurrency_for_endpoint == 0 {
232 concurrency_for_endpoint = 1
233 }
234 let api_histogram = match Histogram::new(14, 20) {
236 Ok(h) => Arc::new(Mutex::new(h)),
237 Err(e) => return Err(Error::msg(format!("获取存储桶失败::{:?}", e.to_string()))),
238 };
239 let api_successful_requests = Arc::new(AtomicUsize::new(0));
241 let api_total_requests = Arc::new(AtomicUsize::new(0));
243 let api_max_response_time = Arc::new(Mutex::new(0u64));
245 let api_min_response_time = Arc::new(Mutex::new(u64::MAX));
247 let api_err_count = Arc::new(AtomicUsize::new(0));
249 let api_concurrent_number = Arc::new(AtomicUsize::new(0));
251 let api_total_response_size = Arc::new(AtomicUsize::new(0));
253 let mut init_api_res = ApiResult::new();
255 init_api_res.name = name.clone();
256 init_api_res.url = api_url.clone();
257 init_api_res.method = endpoint_arc.lock().await.method.clone().to_uppercase();
258 let api_result = Arc::new(Mutex::new(init_api_res.clone()));
260 results_arc.lock().await.push(init_api_res);
262 let controller = match step_option.clone() {
264 None => Arc::new(ConcurrencyController::new(concurrency_for_endpoint, None)),
265 Some(option) => {
266 let step = option.increase_step as f64 * weight_ratio;
268 Arc::new(ConcurrencyController::new(
269 concurrency_for_endpoint,
270 Option::from(InnerStepOption {
271 increase_step: step,
272 increase_interval: option.increase_interval,
273 }),
274 ))
275 }
276 };
277 tokio::spawn({
279 let controller_clone = Arc::clone(&controller);
280 async move {
281 controller_clone.distribute_permits().await;
282 }
283 });
284 endpoint_arc.lock().await.url = api_url.clone();
286 for _ in 0..concurrency_for_endpoint {
287 let handle: JoinHandle<Result<(), Error>> =
289 tokio::spawn(start_task::start_concurrency(
290 client.clone(), Arc::clone(&controller), Arc::clone(&api_concurrent_number), Arc::clone(&concurrent_number), Arc::clone(&extract_map_arc), Arc::clone(&endpoint_arc), Arc::clone(&total_requests), Arc::clone(&api_total_requests), Arc::clone(&histogram), Arc::clone(&api_histogram), Arc::clone(&max_response_time), Arc::clone(&api_max_response_time), Arc::clone(&min_response_time), Arc::clone(&api_min_response_time), Arc::clone(&total_response_size), Arc::clone(&api_total_response_size), Arc::clone(&api_err_count), Arc::clone(&successful_requests), Arc::clone(&err_count), Arc::clone(&http_errors), Arc::clone(&assert_errors), Arc::clone(&api_successful_requests), Arc::clone(&api_result), Arc::clone(&results_arc), tx_assert.clone(), test_start, test_end, is_need_render_template, verbose, index, ));
321 handles.push(handle);
322 }
323 }
325
326 tokio::spawn(share_result::collect_results(
328 result_sender,
329 should_stop_rx,
330 Arc::clone(&total_requests),
331 Arc::clone(&successful_requests),
332 Arc::clone(&histogram),
333 Arc::clone(&total_response_size),
334 Arc::clone(&http_errors),
335 Arc::clone(&err_count),
336 Arc::clone(&max_response_time),
337 Arc::clone(&min_response_time),
338 Arc::clone(&assert_errors),
339 Arc::clone(&results_arc),
340 Arc::clone(&concurrent_number),
341 Arc::clone(&dura),
342 Arc::clone(&number_of_last_requests),
343 Arc::clone(&number_of_last_errors),
344 Arc::clone(&rps_queue_arc),
345 Arc::clone(&api_rps_queue_arc),
346 queue_cap,
347 verbose,
348 test_start,
349 ema_alpha,
350 ));
351
352 let task_results = join_all(handles).await;
354 for task_result in task_results {
355 match task_result {
356 Ok(res) => {
357 match res {
358 Ok(_) => {
359 if verbose {
360 println!("任务完成")
361 }
362 }
363 Err(e) => {
364 eprintln!("异步任务内部错误::{:?}", e)
365 }
366 };
367 }
368 Err(err) => {
369 eprintln!("协程被取消或意外停止::{:?}", err);
370 }
371 };
372 }
373
374 let err_count_clone = Arc::clone(&err_count);
376 let err_count = err_count_clone.load(Ordering::SeqCst);
377 let total_duration = (Instant::now() - test_start).as_secs_f64();
378 let total_requests = total_requests.load(Ordering::SeqCst) as u64;
379 let successful_requests = successful_requests.load(Ordering::SeqCst) as f64;
380 let success_rate = successful_requests / total_requests as f64 * 100.0;
381 let histogram = histogram.lock().await;
382 let total_response_size_kb = total_response_size.load(Ordering::SeqCst) as f64 / 1024.0;
383 let throughput_kb_s = total_response_size_kb / test_duration_secs as f64;
384 let http_errors = http_errors.lock().await.errors.clone();
385 let assert_errors = assert_errors.lock().await.errors.clone();
386 let timestamp = match SystemTime::now().duration_since(UNIX_EPOCH) {
387 Ok(n) => n.as_millis(),
388 Err(_) => 0,
389 };
390 let mut api_results = results_arc.lock().await;
391 for (index, res) in api_results.clone().into_iter().enumerate() {
392 let api_res = match api_rps_queue_arc.lock().await.clone().get(&res.name) {
393 None => 0f64,
394 Some(v) => v.average().await.unwrap_or_else(|| 0f64),
395 };
396 api_results[index].rps = api_res;
397 }
398 for (index, res) in api_results.clone().into_iter().enumerate() {
400 let rps = res.total_requests as f64 / total_duration;
402 api_results[index].rps = rps;
403 if let Ok(url) = Url::parse(&*res.url) {
405 if let Some(host) = url.host() {
406 api_results[index].host = host.to_string();
407 };
408 api_results[index].path = url.path().to_string();
409 };
410 }
411 let error_rate = err_count as f64 / total_requests as f64 * 100.0;
412 let total_concurrent_number_clone = concurrent_number.load(Ordering::SeqCst) as i32;
413 let errors_per_second = err_count - number_of_last_errors.load(Ordering::SeqCst);
415 number_of_last_errors.fetch_add(errors_per_second, Ordering::Relaxed);
417 let rps = rps_queue_arc
418 .lock()
419 .await
420 .average()
421 .await
422 .unwrap_or_else(|| 0f64);
423 number_of_last_requests.fetch_add(rps as usize, Ordering::Relaxed);
425 let result = Ok(BatchResult {
427 total_duration,
428 success_rate,
429 error_rate,
430 median_response_time: match histogram.percentile(50.0) {
431 Ok(b) => *b.range().start(),
432 Err(e) => {
433 return Err(Error::msg(format!("获取50线失败::{:?}", e.to_string())));
434 }
435 },
436 response_time_95: match histogram.percentile(95.0) {
437 Ok(b) => *b.range().start(),
438 Err(e) => {
439 return Err(Error::msg(format!("获取95线失败::{:?}", e.to_string())));
440 }
441 },
442 response_time_99: match histogram.percentile(99.0) {
443 Ok(b) => *b.range().start(),
444 Err(e) => {
445 return Err(Error::msg(format!("获取99线失败::{:?}", e.to_string())));
446 }
447 },
448 total_requests,
449 rps,
450 max_response_time: *max_response_time.lock().await,
451 min_response_time: *min_response_time.lock().await,
452 err_count: err_count_clone.load(Ordering::SeqCst) as i32,
453 total_data_kb: total_response_size_kb,
454 throughput_per_second_kb: throughput_kb_s,
455 http_errors: http_errors.lock().await.clone(),
456 timestamp,
457 assert_errors: assert_errors.lock().await.clone(),
458 total_concurrent_number: total_concurrent_number_clone,
459 api_results: api_results.to_vec().clone(),
460 errors_per_second,
461 });
462 should_stop_tx.send(()).unwrap();
463 eprintln!("测试完成!");
464 result
465}