// atomic_bomb_engine/core/batch.rs

1use std::collections::BTreeMap;
2use std::env;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use anyhow::Error;
9use futures::future::join_all;
10use handlebars::Handlebars;
11use histogram::Histogram;
12use reqwest::header::{HeaderMap, HeaderValue, USER_AGENT};
13use reqwest::Client;
14use serde_json::{json, Value};
15use tokio::sync::{mpsc, oneshot, Mutex};
16use tokio::task::JoinHandle;
17use url::Url;
18
19use crate::core::check_endpoints_names::check_endpoints_names;
20use crate::core::concurrency_controller::ConcurrencyController;
21use crate::core::fixed_size_queue;
22use crate::core::sleep_guard::SleepGuard;
23use crate::core::{listening_assert, setup, share_result, start_task};
24use crate::models::api_endpoint::ApiEndpoint;
25use crate::models::assert_error_stats::AssertErrorStats;
26use crate::models::http_error_stats::HttpErrorStats;
27use crate::models::result::{ApiResult, BatchResult};
28use crate::models::setup::SetupApiEndpoint;
29use crate::models::step_option::{InnerStepOption, StepOption};
30
31pub async fn batch(
32    result_sender: mpsc::Sender<Option<BatchResult>>,
33    test_duration_secs: u64,
34    concurrent_requests: usize,
35    timeout_secs: u64,
36    cookie_store_enable: bool,
37    verbose: bool,
38    should_prevent: bool,
39    api_endpoints: Vec<ApiEndpoint>,
40    step_option: Option<StepOption>,
41    setup_options: Option<Vec<SetupApiEndpoint>>,
42    mut assert_channel_buffer_size: usize,
43    ema_alpha: f64,
44) -> anyhow::Result<BatchResult> {
45    // 阻止电脑休眠
46    let _guard = SleepGuard::new(should_prevent);
47    // 检查阶梯并发量
48    if let Some(step_option) = step_option.clone() {
49        // 计算总共增加次数
50        let total_steps = test_duration_secs / step_option.increase_interval;
51        // 计算总增加并发数
52        let total_concurrency_increase =
53            step_option.increase_step as u64 * total_steps * (total_steps + 1) / 2;
54        if total_concurrency_increase < concurrent_requests as u64 {
55            return Err(Error::msg(
56                "阶梯加压总并发数在设置的时间内无法增加到预设的结束并发数",
57            ));
58        }
59    };
60    // 检查每个接口的名称
61    if let Err(e) = check_endpoints_names(api_endpoints.clone()) {
62        return Err(Error::msg(e));
63    }
64    // 总响应时间统计
65    let histogram = match Histogram::new(14, 20) {
66        Ok(h) => Arc::new(Mutex::new(h)),
67        Err(e) => {
68            return Err(Error::msg(format!("获取存储桶失败::{:?}", e.to_string())));
69        }
70    };
71    // 成功数据统计
72    let successful_requests = Arc::new(AtomicUsize::new(0));
73    // 请求总数统计
74    let total_requests = Arc::new(AtomicUsize::new(0));
75    // 统计最大响应时间
76    let max_response_time = Arc::new(Mutex::new(0u64));
77    // 统计最小响应时间
78    let min_response_time = Arc::new(Mutex::new(u64::MAX));
79    // 统计错误数量
80    let err_count = Arc::new(AtomicUsize::new(0));
81    // 统计每秒错误数
82    let number_of_last_errors = Arc::new(AtomicUsize::new(0));
83    let dura = Arc::new(Mutex::new(0f64));
84    // 统计rps
85    let number_of_last_requests = Arc::new(AtomicUsize::new(0));
86    // rps队列
87    // 计算队列长度
88    let queue_cap = match step_option.clone() {
89        // 没有阶梯加压,队列长度为10
90        None => 10usize,
91        // 有阶梯加压,计算出最大并发持续时间,变为队列长度
92        Some(step_option) => {
93            // 计算最大并发量所需时间
94            let steps_to_max_concurrency = concurrent_requests / step_option.increase_step;
95            let time_to_max_concurrency =
96                steps_to_max_concurrency as u64 * step_option.increase_interval;
97            // 计算最大并发量剩余时间
98            let remaining_time = test_duration_secs.saturating_sub(time_to_max_concurrency);
99            remaining_time as usize
100        }
101    };
102    let rps_queue_arc = Arc::new(Mutex::new(fixed_size_queue::FixedSizeQueue::new(queue_cap)));
103    // api rps队列
104    let api_rps_queue_arc: Arc<Mutex<BTreeMap<String, fixed_size_queue::FixedSizeQueue<f64>>>> =
105        Arc::new(Mutex::new(BTreeMap::new()));
106    // 已开始并发数
107    let concurrent_number = Arc::new(AtomicUsize::new(0));
108    // 接口线程池
109    let mut handles: Vec<JoinHandle<Result<(), Error>>> = Vec::new();
110    // 统计响应大小
111    let total_response_size = Arc::new(AtomicUsize::new(0));
112    // 统计http错误
113    let http_errors = Arc::new(Mutex::new(HttpErrorStats::new()));
114    // 统计断言错误
115    let assert_errors = Arc::new(Mutex::new(AssertErrorStats::new()));
116    // 总权重
117    let total_weight: u32 = api_endpoints.iter().map(|e| e.weight).sum();
118    // 是否停止通道
119    let (should_stop_tx, should_stop_rx) = oneshot::channel();
120    // 断言队列
121    if assert_channel_buffer_size <= 0 {
122        assert_channel_buffer_size = 1024
123    }
124    let (tx_assert, rx_assert) = mpsc::channel(assert_channel_buffer_size);
125    // 开启一个任务,做断言的生产消费
126    if api_endpoints
127        .clone()
128        .into_iter()
129        .any(|item| item.assert_options.is_some())
130    {
131        if verbose {
132            println!("开启断言消费任务");
133        };
134        tokio::spawn(listening_assert::listening_assert(rx_assert));
135    };
136    // 用arc包装每一个endpoint
137    let api_endpoints_arc: Vec<Arc<Mutex<ApiEndpoint>>> = api_endpoints
138        .into_iter()
139        .map(|endpoint| Arc::new(Mutex::new(endpoint)))
140        .collect();
141    // 开始测试时间
142    let test_start = Instant::now();
143    // 测试结束时间
144    let test_end = test_start + Duration::from_secs(test_duration_secs);
145    // 每个接口的测试结果
146    let results: Vec<ApiResult> = Vec::new();
147    let results_arc = Arc::new(Mutex::new(results));
148    // user_agent
149    let info = os_info::get();
150    let os_type = info.os_type();
151    let os_version = info.version().to_string();
152    let app_name = env!("CARGO_PKG_NAME");
153    let app_version = env!("CARGO_PKG_VERSION");
154    let user_agent_value =
155        match format!("{} {} ({}; {})", app_name, app_version, os_type, os_version)
156            .parse::<HeaderValue>()
157        {
158            Ok(v) => v,
159            Err(e) => {
160                return Err(Error::msg(format!(
161                    "解析user agent失败::{:?}",
162                    e.to_string()
163                )));
164            }
165        };
166    let mut is_need_render_template = false;
167    // 全局提取字典
168    let mut extract_map: BTreeMap<String, Value> = BTreeMap::new();
169    // 创建http客户端
170    let builder = Client::builder()
171        .cookie_store(cookie_store_enable)
172        .default_headers({
173            let mut headers = HeaderMap::new();
174            headers.insert(USER_AGENT, user_agent_value);
175            headers
176        });
177    let client = match timeout_secs > 0 {
178        true => match builder.timeout(Duration::from_secs(timeout_secs)).build() {
179            Ok(cli) => cli,
180            Err(e) => return Err(Error::msg(format!("构建含有超时的http客户端失败: {:?}", e))),
181        },
182        false => match builder.build() {
183            Ok(cli) => cli,
184            Err(e) => return Err(Error::msg(format!("构建http客户端失败: {:?}", e))),
185        },
186    };
187    // 开始初始化
188    if let Some(setup_options) = setup_options {
189        is_need_render_template = true;
190        match setup::start_setup(setup_options, extract_map.clone(), client.clone()).await {
191            Ok(res) => {
192                if let Some(extract) = res {
193                    extract_map.extend(extract);
194                };
195            }
196            Err(e) => return Err(Error::msg(format!("全局初始化失败: {:?}", e))),
197        };
198    };
199    // println!("extract_map:{:?}", extract_map);
200    // 并发安全的提取字典
201    let extract_map_arc = Arc::new(Mutex::new(extract_map));
202    // 针对每一个接口开始配置
203    for (index, endpoint_arc) in api_endpoints_arc.clone().into_iter().enumerate() {
204        let endpoint = endpoint_arc.lock().await;
205        let weight = endpoint.weight.clone();
206        let name = endpoint.name.clone();
207        let api_url = match is_need_render_template {
208            true => {
209                // 使用模版替换cookies
210                let handlebars = Handlebars::new();
211                match handlebars.render_template(
212                    &*endpoint.url.clone(),
213                    &json!(*extract_map_arc.lock().await),
214                ) {
215                    Ok(c) => c,
216                    Err(e) => {
217                        eprintln!("{:?}", e);
218                        endpoint.url.clone()
219                    }
220                }
221            }
222            false => endpoint.url.clone(),
223        };
224        drop(endpoint);
225        // 计算权重比例
226        let weight_ratio = weight as f64 / total_weight as f64;
227        // 计算每个接口的并发量
228        let mut concurrency_for_endpoint =
229            ((concurrent_requests as f64) * weight_ratio).round() as usize;
230        // 如果这个接口的并发量四舍五入成0了, 就把他定为1
231        if concurrency_for_endpoint == 0 {
232            concurrency_for_endpoint = 1
233        }
234        // 接口数据的统计
235        let api_histogram = match Histogram::new(14, 20) {
236            Ok(h) => Arc::new(Mutex::new(h)),
237            Err(e) => return Err(Error::msg(format!("获取存储桶失败::{:?}", e.to_string()))),
238        };
239        // 接口成功数据统计
240        let api_successful_requests = Arc::new(AtomicUsize::new(0));
241        // 接口请求总数统计
242        let api_total_requests = Arc::new(AtomicUsize::new(0));
243        // 接口统计最大响应时间
244        let api_max_response_time = Arc::new(Mutex::new(0u64));
245        // 接口统计最小响应时间
246        let api_min_response_time = Arc::new(Mutex::new(u64::MAX));
247        // 接口统计错误数量
248        let api_err_count = Arc::new(AtomicUsize::new(0));
249        // 接口并发数统计
250        let api_concurrent_number = Arc::new(AtomicUsize::new(0));
251        // 接口响应大小
252        let api_total_response_size = Arc::new(AtomicUsize::new(0));
253        // 初始化api结果
254        let mut init_api_res = ApiResult::new();
255        init_api_res.name = name.clone();
256        init_api_res.url = api_url.clone();
257        init_api_res.method = endpoint_arc.lock().await.method.clone().to_uppercase();
258        // 包装初始化好的接口信息
259        let api_result = Arc::new(Mutex::new(init_api_res.clone()));
260        // 将初始化好的接口信息添加到list中
261        results_arc.lock().await.push(init_api_res);
262        // 根据step初始化并发控制器
263        let controller = match step_option.clone() {
264            None => Arc::new(ConcurrencyController::new(concurrency_for_endpoint, None)),
265            Some(option) => {
266                // 计算每个接口的步长
267                let step = option.increase_step as f64 * weight_ratio;
268                Arc::new(ConcurrencyController::new(
269                    concurrency_for_endpoint,
270                    Option::from(InnerStepOption {
271                        increase_step: step,
272                        increase_interval: option.increase_interval,
273                    }),
274                ))
275            }
276        };
277        // 后台启动并发控制器
278        tokio::spawn({
279            let controller_clone = Arc::clone(&controller);
280            async move {
281                controller_clone.distribute_permits().await;
282            }
283        });
284        // 将新url替换到每个接口中
285        endpoint_arc.lock().await.url = api_url.clone();
286        for _ in 0..concurrency_for_endpoint {
287            // 开启并发
288            let handle: JoinHandle<Result<(), Error>> =
289                tokio::spawn(start_task::start_concurrency(
290                    client.clone(),                       // http客户端
291                    Arc::clone(&controller),              // 并发控制器
292                    Arc::clone(&api_concurrent_number),   // api并发数
293                    Arc::clone(&concurrent_number),       // 总并发数
294                    Arc::clone(&extract_map_arc),         // 断言替换字典
295                    Arc::clone(&endpoint_arc),            // 接口数据
296                    Arc::clone(&total_requests),          // 总请求数
297                    Arc::clone(&api_total_requests),      // api请求数
298                    Arc::clone(&histogram),               // 总统计桶
299                    Arc::clone(&api_histogram),           // api统计桶
300                    Arc::clone(&max_response_time),       // 最大响应时间
301                    Arc::clone(&api_max_response_time),   // 接口最大响应时间
302                    Arc::clone(&min_response_time),       // 最小响应时间
303                    Arc::clone(&api_min_response_time),   // api最小响应时间
304                    Arc::clone(&total_response_size),     // 总响应数据
305                    Arc::clone(&api_total_response_size), // api响应数据
306                    Arc::clone(&api_err_count),           // api错误数
307                    Arc::clone(&successful_requests),     // 成功数量
308                    Arc::clone(&err_count),               // 错误数量
309                    Arc::clone(&http_errors),             // http错误统计
310                    Arc::clone(&assert_errors),           // 断言错误统计
311                    Arc::clone(&api_successful_requests), // api成功数量
312                    Arc::clone(&api_result),              // 接口详细统计
313                    Arc::clone(&results_arc),             // 最终响应结果
314                    tx_assert.clone(),                    // 断言通道
315                    test_start,                           // 测试开始时间
316                    test_end,                             // 测试结束时间
317                    is_need_render_template,              // 是否需要读取模板
318                    verbose,                              // 是否打印详情
319                    index,                                // 索引
320                ));
321            handles.push(handle);
322        }
323        // println!("err count:{:?}",api_err_count.lock().await);
324    }
325
326    // 共享任务状态
327    tokio::spawn(share_result::collect_results(
328        result_sender,
329        should_stop_rx,
330        Arc::clone(&total_requests),
331        Arc::clone(&successful_requests),
332        Arc::clone(&histogram),
333        Arc::clone(&total_response_size),
334        Arc::clone(&http_errors),
335        Arc::clone(&err_count),
336        Arc::clone(&max_response_time),
337        Arc::clone(&min_response_time),
338        Arc::clone(&assert_errors),
339        Arc::clone(&results_arc),
340        Arc::clone(&concurrent_number),
341        Arc::clone(&dura),
342        Arc::clone(&number_of_last_requests),
343        Arc::clone(&number_of_last_errors),
344        Arc::clone(&rps_queue_arc),
345        Arc::clone(&api_rps_queue_arc),
346        queue_cap,
347        verbose,
348        test_start,
349        ema_alpha,
350    ));
351
352    // 等待任务完成
353    let task_results = join_all(handles).await;
354    for task_result in task_results {
355        match task_result {
356            Ok(res) => {
357                match res {
358                    Ok(_) => {
359                        if verbose {
360                            println!("任务完成")
361                        }
362                    }
363                    Err(e) => {
364                        eprintln!("异步任务内部错误::{:?}", e)
365                    }
366                };
367            }
368            Err(err) => {
369                eprintln!("协程被取消或意外停止::{:?}", err);
370            }
371        };
372    }
373
374    // 对结果进行赋值
375    let err_count_clone = Arc::clone(&err_count);
376    let err_count = err_count_clone.load(Ordering::SeqCst);
377    let total_duration = (Instant::now() - test_start).as_secs_f64();
378    let total_requests = total_requests.load(Ordering::SeqCst) as u64;
379    let successful_requests = successful_requests.load(Ordering::SeqCst) as f64;
380    let success_rate = successful_requests / total_requests as f64 * 100.0;
381    let histogram = histogram.lock().await;
382    let total_response_size_kb = total_response_size.load(Ordering::SeqCst) as f64 / 1024.0;
383    let throughput_kb_s = total_response_size_kb / test_duration_secs as f64;
384    let http_errors = http_errors.lock().await.errors.clone();
385    let assert_errors = assert_errors.lock().await.errors.clone();
386    let timestamp = match SystemTime::now().duration_since(UNIX_EPOCH) {
387        Ok(n) => n.as_millis(),
388        Err(_) => 0,
389    };
390    let mut api_results = results_arc.lock().await;
391    for (index, res) in api_results.clone().into_iter().enumerate() {
392        let api_res = match api_rps_queue_arc.lock().await.clone().get(&res.name) {
393            None => 0f64,
394            Some(v) => v.average().await.unwrap_or_else(|| 0f64),
395        };
396        api_results[index].rps = api_res;
397    }
398    // 计算每个接口的rps,host, path
399    for (index, res) in api_results.clone().into_iter().enumerate() {
400        // 计算每个接口的rps
401        let rps = res.total_requests as f64 / total_duration;
402        api_results[index].rps = rps;
403        // 计算每个接口的HOST,PATH
404        if let Ok(url) = Url::parse(&*res.url) {
405            if let Some(host) = url.host() {
406                api_results[index].host = host.to_string();
407            };
408            api_results[index].path = url.path().to_string();
409        };
410    }
411    let error_rate = err_count as f64 / total_requests as f64 * 100.0;
412    let total_concurrent_number_clone = concurrent_number.load(Ordering::SeqCst) as i32;
413    // 总错误数量减去上一次错误数量得出增量
414    let errors_per_second = err_count - number_of_last_errors.load(Ordering::SeqCst);
415    // 将增量累加到上一次错误数量
416    number_of_last_errors.fetch_add(errors_per_second, Ordering::Relaxed);
417    let rps = rps_queue_arc
418        .lock()
419        .await
420        .average()
421        .await
422        .unwrap_or_else(|| 0f64);
423    // 将增量累加
424    number_of_last_requests.fetch_add(rps as usize, Ordering::Relaxed);
425    // 最终结果
426    let result = Ok(BatchResult {
427        total_duration,
428        success_rate,
429        error_rate,
430        median_response_time: match histogram.percentile(50.0) {
431            Ok(b) => *b.range().start(),
432            Err(e) => {
433                return Err(Error::msg(format!("获取50线失败::{:?}", e.to_string())));
434            }
435        },
436        response_time_95: match histogram.percentile(95.0) {
437            Ok(b) => *b.range().start(),
438            Err(e) => {
439                return Err(Error::msg(format!("获取95线失败::{:?}", e.to_string())));
440            }
441        },
442        response_time_99: match histogram.percentile(99.0) {
443            Ok(b) => *b.range().start(),
444            Err(e) => {
445                return Err(Error::msg(format!("获取99线失败::{:?}", e.to_string())));
446            }
447        },
448        total_requests,
449        rps,
450        max_response_time: *max_response_time.lock().await,
451        min_response_time: *min_response_time.lock().await,
452        err_count: err_count_clone.load(Ordering::SeqCst) as i32,
453        total_data_kb: total_response_size_kb,
454        throughput_per_second_kb: throughput_kb_s,
455        http_errors: http_errors.lock().await.clone(),
456        timestamp,
457        assert_errors: assert_errors.lock().await.clone(),
458        total_concurrent_number: total_concurrent_number_clone,
459        api_results: api_results.to_vec().clone(),
460        errors_per_second,
461    });
462    should_stop_tx.send(()).unwrap();
463    eprintln!("测试完成!");
464    result
465}