use std::time::{Duration, Instant};
use std::sync::{Arc, atomic::{AtomicUsize, Ordering}};
use tokio::time::sleep;
use crate::{TestResult, client::TestClient};

/// Configuration for a load test run.
#[derive(Debug, Clone)]
pub struct LoadTestConfig {
    /// Number of concurrent workers issuing requests.
    pub concurrent_users: usize,
    /// Total wall-clock duration of the test.
    pub duration: Duration,
    /// Aggregate target requests per second across all workers (0 disables throttling).
    pub target_rps: usize,
    /// Window over which worker start-up is staggered.
    pub ramp_up: Duration,
    /// Timeout for individual requests.
    pub timeout: Duration,
}

impl LoadTestConfig {
    pub fn basic() -> Self {
        Self {
            concurrent_users: 10,
            duration: Duration::from_secs(30),
            target_rps: 0,
            ramp_up: Duration::from_secs(5),
            timeout: Duration::from_secs(30),
        }
    }

    pub fn light() -> Self {
        Self {
            concurrent_users: 5,
            duration: Duration::from_secs(10),
            target_rps: 50,
            ramp_up: Duration::from_secs(2),
            timeout: Duration::from_secs(10),
        }
    }

    pub fn heavy() -> Self {
        Self {
            concurrent_users: 100,
            duration: Duration::from_secs(120),
            target_rps: 1000,
            ramp_up: Duration::from_secs(30),
            timeout: Duration::from_secs(30),
        }
    }

    pub fn with_concurrent_users(mut self, users: usize) -> Self {
        self.concurrent_users = users;
        self
    }

    pub fn with_duration(mut self, duration: Duration) -> Self {
        self.duration = duration;
        self
    }

    pub fn with_target_rps(mut self, rps: usize) -> Self {
        self.target_rps = rps;
        self
    }
}

/// Aggregated results of a load test run.
#[derive(Debug, Clone)]
pub struct LoadTestResults {
    /// Total number of requests issued.
    pub total_requests: usize,
    /// Requests that completed successfully.
    pub successful_requests: usize,
    /// Requests that returned an error.
    pub failed_requests: usize,
    /// Actual wall-clock duration of the run.
    pub duration: Duration,
    /// Average achieved requests per second.
    pub avg_rps: f64,
    /// Response-time statistics for successful requests.
    pub response_times: ResponseTimeStats,
    /// Distinct error messages and how often each occurred.
    pub errors: Vec<(String, usize)>,
}

impl LoadTestResults {
    pub fn success_rate(&self) -> f64 {
        if self.total_requests == 0 {
            0.0
        } else {
            (self.successful_requests as f64 / self.total_requests as f64) * 100.0
        }
    }

    pub fn passes_criteria(&self, min_success_rate: f64, max_avg_response_time: Duration) -> bool {
        self.success_rate() >= min_success_rate &&
            self.response_times.avg <= max_avg_response_time
    }
}

#[derive(Debug, Clone)]
pub struct ResponseTimeStats {
    pub min: Duration,
    pub max: Duration,
    pub avg: Duration,
    pub p50: Duration,
    pub p95: Duration,
    pub p99: Duration,
}

impl ResponseTimeStats {
    fn new() -> Self {
        Self {
            min: Duration::ZERO,
            max: Duration::ZERO,
            avg: Duration::ZERO,
            p50: Duration::ZERO,
            p95: Duration::ZERO,
            p99: Duration::ZERO,
        }
    }

    fn from_times(mut times: Vec<Duration>) -> Self {
        if times.is_empty() {
            return Self::new();
        }

        times.sort();
        let len = times.len();

        let min = times[0];
        let max = times[len - 1];
        let avg = Duration::from_nanos(
            times.iter().map(|d| d.as_nanos() as u64).sum::<u64>() / len as u64
        );

        // Simple index-based percentile lookups over the sorted samples;
        // `len * 99 / 100` is always less than `len`, so these stay in bounds.
        let p50 = times[len * 50 / 100];
        let p95 = times[len * 95 / 100];
        let p99 = times[len * 99 / 100];

        Self { min, max, avg, p50, p95, p99 }
    }
}

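/// Runs a load-test scenario against a target service using a pool of
/// concurrent workers.
///
/// A minimal usage sketch (not compiled as a doctest; the base URL and
/// endpoint below are placeholders):
///
/// ```ignore
/// let results = LoadTestRunner::new(LoadTestConfig::light())
///     .with_base_url("http://localhost:8080")
///     .run_get_test("/health")
///     .await?;
/// assert!(results.passes_criteria(99.0, Duration::from_millis(200)));
/// ```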
pub struct LoadTestRunner {
    config: LoadTestConfig,
    base_client: TestClient,
}

impl LoadTestRunner {
    pub fn new(config: LoadTestConfig) -> Self {
        Self {
            config,
            base_client: TestClient::new(),
        }
    }

    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_client = TestClient::with_base_url(url);
        self
    }

    pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
        self.base_client = self.base_client.authenticated_with_token(token);
        self
    }

    pub async fn run_scenario<F, Fut>(&self, scenario: F) -> TestResult<LoadTestResults>
    where
        F: Fn(TestClient) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = TestResult<Duration>> + Send + 'static,
    {
        let start_time = Instant::now();
        let end_time = start_time + self.config.duration;

        let total_requests = Arc::new(AtomicUsize::new(0));
        let successful_requests = Arc::new(AtomicUsize::new(0));
        let failed_requests = Arc::new(AtomicUsize::new(0));
        let response_times = Arc::new(tokio::sync::Mutex::new(Vec::<Duration>::new()));
        let errors = Arc::new(tokio::sync::Mutex::new(Vec::<String>::new()));

        // Stagger worker start-up evenly across the ramp-up window.
        let ramp_up_delay = if self.config.concurrent_users > 0 {
            self.config.ramp_up.as_millis() / self.config.concurrent_users as u128
        } else {
            0
        };

        let target_rps = self.config.target_rps;
        let concurrent_users = self.config.concurrent_users;

        let mut handles = Vec::new();
        for i in 0..self.config.concurrent_users {
            let scenario = scenario.clone();
            let client = self.base_client.clone();
            let total_requests = total_requests.clone();
            let successful_requests = successful_requests.clone();
            let failed_requests = failed_requests.clone();
            let response_times = response_times.clone();
            let errors = errors.clone();
            let end_time = end_time;

            let handle = tokio::spawn(async move {
                if ramp_up_delay > 0 {
                    sleep(Duration::from_millis((i as u128 * ramp_up_delay) as u64)).await;
                }

                while Instant::now() < end_time {
                    total_requests.fetch_add(1, Ordering::Relaxed);

                    match scenario(client.clone()).await {
                        Ok(duration) => {
                            successful_requests.fetch_add(1, Ordering::Relaxed);
                            response_times.lock().await.push(duration);
                        },
                        Err(e) => {
                            failed_requests.fetch_add(1, Ordering::Relaxed);
                            errors.lock().await.push(e.to_string());
                        }
                    }

                    if target_rps > 0 {
                        // Pace each worker so the combined rate approximates the
                        // aggregate target RPS rather than target RPS per worker.
                        let delay = Duration::from_millis(
                            (1000 * concurrent_users as u64) / target_rps as u64,
                        );
                        sleep(delay).await;
                    }
                }
            });

            handles.push(handle);
        }

        for handle in handles {
            let _ = handle.await;
        }

        let actual_duration = start_time.elapsed();
        let total = total_requests.load(Ordering::Relaxed);
        let successful = successful_requests.load(Ordering::Relaxed);
        let failed = failed_requests.load(Ordering::Relaxed);

        let avg_rps = if actual_duration.as_secs_f64() > 0.0 {
            total as f64 / actual_duration.as_secs_f64()
        } else {
            0.0
        };

        let times = response_times.lock().await.clone();
        let response_time_stats = ResponseTimeStats::from_times(times);

        // Collapse duplicate error messages into (message, count) pairs.
        let error_list = errors.lock().await.clone();
        let mut error_counts = std::collections::HashMap::new();
        for error in error_list {
            *error_counts.entry(error).or_insert(0) += 1;
        }
        let error_distribution: Vec<(String, usize)> = error_counts.into_iter().collect();

        Ok(LoadTestResults {
            total_requests: total,
            successful_requests: successful,
            failed_requests: failed,
            duration: actual_duration,
            avg_rps,
            response_times: response_time_stats,
            errors: error_distribution,
        })
    }

    pub async fn run_get_test(&self, path: impl Into<String>) -> TestResult<LoadTestResults> {
        let path = path.into();
        self.run_scenario(move |client| {
            let path = path.clone();
            async move {
                let start = Instant::now();
                client.get(path).send().await?;
                Ok(start.elapsed())
            }
        }).await
    }

    pub async fn run_post_test<T: serde::Serialize + Clone + Send + Sync + 'static>(
        &self,
        path: impl Into<String>,
        data: T,
    ) -> TestResult<LoadTestResults> {
        let path = path.into();
        self.run_scenario(move |client| {
            let path = path.clone();
            let data = data.clone();
            async move {
                let start = Instant::now();
                client.post(path).json(&data).send().await?;
                Ok(start.elapsed())
            }
        }).await
    }
}

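/// A simple fixed-iteration micro-benchmark for timing a closure.
///
/// A minimal usage sketch (illustrative only; the operation name and body
/// are arbitrary):
///
/// ```ignore
/// let result = Benchmark::new("format_number", 1_000)
///     .run_sync(|| {
///         let _ = format!("{}", 12_345);
///     });
/// result.print();
/// ```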
pub struct Benchmark {
    name: String,
    iterations: usize,
}

impl Benchmark {
    pub fn new(name: impl Into<String>, iterations: usize) -> Self {
        Self {
            name: name.into(),
            iterations,
        }
    }

    pub fn run_sync<F>(&self, mut operation: F) -> BenchmarkResult
    where
        F: FnMut(),
    {
        let mut times = Vec::with_capacity(self.iterations);

        for _ in 0..self.iterations {
            let start = Instant::now();
            operation();
            times.push(start.elapsed());
        }

        BenchmarkResult {
            name: self.name.clone(),
            iterations: self.iterations,
            stats: ResponseTimeStats::from_times(times),
        }
    }

    pub async fn run_async<F, Fut>(&self, operation: F) -> BenchmarkResult
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = ()>,
    {
        let mut times = Vec::with_capacity(self.iterations);

        for _ in 0..self.iterations {
            let start = Instant::now();
            operation().await;
            times.push(start.elapsed());
        }

        BenchmarkResult {
            name: self.name.clone(),
            iterations: self.iterations,
            stats: ResponseTimeStats::from_times(times),
        }
    }
}

#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    pub name: String,
    pub iterations: usize,
    pub stats: ResponseTimeStats,
}

impl BenchmarkResult {
    pub fn ops_per_second(&self) -> f64 {
        if self.stats.avg.as_nanos() > 0 {
            1_000_000_000.0 / self.stats.avg.as_nanos() as f64
        } else {
            0.0
        }
    }

    pub fn print(&self) {
        println!("Benchmark: {}", self.name);
        println!("Iterations: {}", self.iterations);
        println!("Average time: {:?}", self.stats.avg);
        println!("Min time: {:?}", self.stats.min);
        println!("Max time: {:?}", self.stats.max);
        println!("95th percentile: {:?}", self.stats.p95);
        println!("99th percentile: {:?}", self.stats.p99);
        println!("Ops/sec: {:.2}", self.ops_per_second());
        println!("---");
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_load_test_config() {
        let config = LoadTestConfig::basic()
            .with_concurrent_users(20)
            .with_duration(Duration::from_secs(60));

        assert_eq!(config.concurrent_users, 20);
        assert_eq!(config.duration, Duration::from_secs(60));
    }

    #[test]
    fn test_response_time_stats() {
        let times = vec![
            Duration::from_millis(10),
            Duration::from_millis(20),
            Duration::from_millis(30),
            Duration::from_millis(40),
            Duration::from_millis(50),
        ];

        let stats = ResponseTimeStats::from_times(times);
        assert_eq!(stats.min, Duration::from_millis(10));
        assert_eq!(stats.max, Duration::from_millis(50));
        assert_eq!(stats.avg, Duration::from_millis(30));
    }

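    // Check the index-based percentile lookups on a dataset that is easy to
    // verify by hand (1..=100 ms).
    #[test]
    fn test_percentile_stats() {
        let times: Vec<Duration> = (1..=100u64).map(Duration::from_millis).collect();
        let stats = ResponseTimeStats::from_times(times);

        assert_eq!(stats.p50, Duration::from_millis(51));
        assert_eq!(stats.p95, Duration::from_millis(96));
        assert_eq!(stats.p99, Duration::from_millis(100));
    }
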
    #[test]
    fn test_load_test_results() {
        let results = LoadTestResults {
            total_requests: 100,
            successful_requests: 95,
            failed_requests: 5,
            duration: Duration::from_secs(10),
            avg_rps: 10.0,
            response_times: ResponseTimeStats::new(),
            errors: vec![],
        };

        assert_eq!(results.success_rate(), 95.0);
    }

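    // Exercise the pass/fail gate with one threshold the results meet and
    // one they do not.
    #[test]
    fn test_passes_criteria() {
        let results = LoadTestResults {
            total_requests: 100,
            successful_requests: 99,
            failed_requests: 1,
            duration: Duration::from_secs(10),
            avg_rps: 10.0,
            response_times: ResponseTimeStats::from_times(vec![
                Duration::from_millis(10),
                Duration::from_millis(20),
            ]),
            errors: vec![],
        };

        // 99% success with a 15ms average passes a 95% / 100ms bar...
        assert!(results.passes_criteria(95.0, Duration::from_millis(100)));
        // ...but fails a 99.5% success-rate requirement.
        assert!(!results.passes_criteria(99.5, Duration::from_millis(100)));
    }
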
    #[tokio::test]
    async fn test_benchmark() {
        let benchmark = Benchmark::new("test_operation", 100);

        let result = benchmark.run_sync(|| {
            std::thread::sleep(Duration::from_micros(1));
        });

        assert_eq!(result.name, "test_operation");
        assert_eq!(result.iterations, 100);
        assert!(result.ops_per_second() > 0.0);
    }

    #[tokio::test]
    async fn test_async_benchmark() {
        let benchmark = Benchmark::new("async_test_operation", 10);

        let result = benchmark.run_async(|| async {
            tokio::time::sleep(Duration::from_millis(1)).await;
        }).await;

        assert_eq!(result.iterations, 10);
    }

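    // End-to-end sketch against a live service; the base URL and endpoint
    // are placeholders, so this test is ignored by default.
    #[tokio::test]
    #[ignore]
    async fn test_load_test_runner_live_sketch() {
        let results = LoadTestRunner::new(LoadTestConfig::light())
            .with_base_url("http://localhost:8080")
            .run_get_test("/health")
            .await;

        assert!(results.is_ok());
    }
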
    #[tokio::test]
    async fn test_load_test_runner_creation() {
        let config = LoadTestConfig::light();
        let runner = LoadTestRunner::new(config.clone());

        assert_eq!(runner.config.concurrent_users, config.concurrent_users);
    }
}