1use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::{Arc, Mutex};
12use std::time::{Duration, Instant};
13
14use crate::error::Result;
15
/// Tunable parameters for a single load-test run.
#[derive(Debug, Clone)]
pub struct LoadTestConfig {
    /// Number of concurrent simulated users (one Tokio task each).
    pub concurrent_users: u32,
    /// Window over which worker start times are staggered evenly.
    pub ramp_up: Duration,
    /// How long the test keeps running after the ramp-up window ends.
    pub duration: Duration,
    /// Optional cap on requests issued per user; `None` means "until the deadline".
    pub requests_per_user: Option<u32>,
    /// Optional overall target throughput; workers split it evenly and pace themselves.
    pub target_rps: Option<f64>,
}
30
31impl Default for LoadTestConfig {
32 fn default() -> Self {
33 Self {
34 concurrent_users: 10,
35 ramp_up: Duration::from_secs(10),
36 duration: Duration::from_secs(60),
37 requests_per_user: None,
38 target_rps: None,
39 }
40 }
41}
42
43impl LoadTestConfig {
44 #[must_use]
46 pub fn light() -> Self {
47 Self {
48 concurrent_users: 5,
49 ramp_up: Duration::from_secs(5),
50 duration: Duration::from_secs(30),
51 ..Default::default()
52 }
53 }
54
55 #[must_use]
57 pub fn moderate() -> Self {
58 Self {
59 concurrent_users: 50,
60 ramp_up: Duration::from_secs(30),
61 duration: Duration::from_secs(120),
62 ..Default::default()
63 }
64 }
65
66 #[must_use]
68 pub fn heavy() -> Self {
69 Self {
70 concurrent_users: 200,
71 ramp_up: Duration::from_secs(60),
72 duration: Duration::from_secs(300),
73 ..Default::default()
74 }
75 }
76
77 #[must_use]
79 pub fn quick() -> Self {
80 Self {
81 concurrent_users: 4,
82 ramp_up: Duration::from_millis(100),
83 duration: Duration::from_millis(500),
84 requests_per_user: Some(10),
85 target_rps: None,
86 }
87 }
88}
89
/// Thread-safe metrics shared by all load-test workers.
///
/// Counters are lock-free atomics; per-request latencies are pushed under a
/// mutex and aggregated once after the run.
///
/// `Default` is derived: every counter starts at 0 and the latency vector
/// starts empty — the hand-written impl this replaces did exactly that.
#[derive(Default)]
struct LoadMetrics {
    /// Total requests issued (successes + failures).
    total_requests: AtomicU64,
    /// Requests whose handler reported success.
    successful: AtomicU64,
    /// Requests whose handler reported failure.
    failed: AtomicU64,
    /// Observed latencies in microseconds, one entry per successful request.
    latencies_us: Mutex<Vec<u64>>,
}

impl LoadMetrics {
    /// Records one successful request and its latency in microseconds.
    fn record_success(&self, latency_us: u64) {
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        self.successful.fetch_add(1, Ordering::Relaxed);
        // Recover the guard from a poisoned mutex instead of silently dropping
        // the sample: the Vec itself is still valid even if another thread
        // panicked while holding the lock. The previous `if let Ok(..)` form
        // lost every latency sample after the first poisoning.
        self.latencies_us
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .push(latency_us);
    }

    /// Records one failed request (no latency is kept for failures).
    fn record_failure(&self) {
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        self.failed.fetch_add(1, Ordering::Relaxed);
    }

    /// Returns a snapshot copy of all recorded latencies.
    fn get_latencies(&self) -> Vec<u64> {
        // Same poison recovery as `record_success`: a poisoned lock previously
        // made this return an empty Vec, zeroing out all percentiles.
        self.latencies_us
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }
}
127
/// Per-request callback: given `(user_id, request_id)`, returns `true` on
/// success. Cloned into every worker task, hence the `Arc` and `Send + Sync`.
pub type RequestHandler = Arc<dyn Fn(u32, u64) -> bool + Send + Sync>;
131
/// Drives a configurable multi-worker load test.
pub struct LoadTester {
    /// Parameters governing the run.
    config: LoadTestConfig,
    /// Optional per-request callback; when `None`, `run` uses a built-in
    /// simulated request (short sleep with a deterministic failure pattern).
    handler: Option<RequestHandler>,
}
137
impl LoadTester {
    /// Creates a tester for `config`; no request handler is installed yet.
    #[must_use]
    pub const fn new(config: LoadTestConfig) -> Self {
        Self {
            config,
            handler: None,
        }
    }

    /// Installs the per-request callback (builder style).
    ///
    /// Without a handler, `run` falls back to a built-in simulation: ~100 µs
    /// of sleep per request, with every 100th request counted as a failure.
    #[must_use]
    pub fn with_handler(mut self, handler: RequestHandler) -> Self {
        self.handler = Some(handler);
        self
    }

    /// Runs the load test to completion and returns the aggregated report.
    ///
    /// Spawns one Tokio task per configured user, staggering starts evenly
    /// across the ramp-up window, then joins all workers and computes
    /// counters, latency percentiles, throughput, and error rate.
    ///
    /// # Errors
    ///
    /// As written this always returns `Ok`; the `Result` return keeps the
    /// signature forward-compatible with fallible setups.
    #[allow(clippy::too_many_lines)]
    pub async fn run(&self) -> Result<LoadTestReport> {
        tracing::info!(
            users = self.config.concurrent_users,
            duration = ?self.config.duration,
            ramp_up = ?self.config.ramp_up,
            "starting load test"
        );

        let metrics = Arc::new(LoadMetrics::default());
        let start_time = Instant::now();
        // The shared deadline includes ramp-up, so even the last-started
        // worker gets roughly the full `duration` of run time.
        let test_end = start_time + self.config.ramp_up + self.config.duration;

        let mut handles = Vec::with_capacity(self.config.concurrent_users as usize);
        // Per-worker start offset (ms): ramp-up divided over the gaps between
        // the N workers (N-1 gaps; `.max(1)` guards the N==1 edge, which the
        // outer `if` already excludes but keeps the division safe).
        // NOTE(review): `as_millis() as u64` truncates only for absurdly long
        // ramp-ups (> u64::MAX ms); harmless in practice.
        let ramp_delay = if self.config.concurrent_users > 1 {
            self.config.ramp_up.as_millis() as u64 / (u64::from(self.config.concurrent_users) - 1).max(1)
        } else {
            0
        };

        // Copied out so the spawned tasks can own a plain u32.
        let concurrent_users = self.config.concurrent_users;

        for user_id in 0..self.config.concurrent_users {
            let metrics = Arc::clone(&metrics);
            let requests_per_user = self.config.requests_per_user;
            let target_rps = self.config.target_rps;
            let request_handler = self.handler.clone();

            // Worker i starts i * ramp_delay ms into the ramp-up window.
            let worker_start_delay = Duration::from_millis(ramp_delay * u64::from(user_id));

            handles.push(tokio::spawn(async move {
                tokio::time::sleep(worker_start_delay).await;

                let mut request_id = 0u64;
                // To approximate the overall target RPS, each of the N workers
                // paces itself at target_rps / N requests per second.
                let interval = target_rps.map(|rps| {
                    Duration::from_secs_f64(1.0 / rps * f64::from(concurrent_users))
                });

                loop {
                    // Stop at the shared deadline...
                    if Instant::now() >= test_end {
                        break;
                    }
                    // ...or once this worker has used up its request quota.
                    if requests_per_user.is_some_and(|max| request_id >= u64::from(max)) {
                        break;
                    }

                    let req_start = Instant::now();
                    let success = if let Some(ref h) = request_handler {
                        h(user_id, request_id)
                    } else {
                        // Built-in simulation: ~100 µs of "work"; request ids
                        // 0, 100, 200, ... are counted as failures (1% rate).
                        tokio::time::sleep(Duration::from_micros(100)).await;
                        !request_id.is_multiple_of(100)
                    };
                    let latency_us = req_start.elapsed().as_micros() as u64;

                    if success {
                        metrics.record_success(latency_us);
                    } else {
                        metrics.record_failure();
                    }

                    request_id += 1;

                    // Pacing sleep sits outside the measured latency window.
                    if let Some(delay) = interval {
                        tokio::time::sleep(delay).await;
                    }
                }
            }));
        }

        // Join every worker; a panicked task is deliberately ignored — its
        // partial counts are already in `metrics`.
        for handle in handles {
            let _ = handle.await;
        }

        let elapsed = start_time.elapsed();

        let total_requests = metrics.total_requests.load(Ordering::Relaxed);
        let successful = metrics.successful.load(Ordering::Relaxed);
        let failed = metrics.failed.load(Ordering::Relaxed);

        let mut latencies = metrics.get_latencies();
        latencies.sort_unstable();

        let (p50, p95, p99) = if latencies.is_empty() {
            (0, 0, 0)
        } else {
            (
                percentile(&latencies, 50),
                percentile(&latencies, 95),
                percentile(&latencies, 99),
            )
        };

        // Throughput is measured over the whole wall-clock run, ramp-up
        // included, so it understates steady-state RPS slightly.
        let throughput_rps = if elapsed.as_secs_f64() > 0.0 {
            total_requests as f64 / elapsed.as_secs_f64()
        } else {
            0.0
        };

        let error_rate = if total_requests > 0 {
            failed as f64 / total_requests as f64
        } else {
            0.0
        };

        tracing::info!(
            total = total_requests,
            successful = successful,
            failed = failed,
            throughput_rps = format!("{throughput_rps:.2}"),
            p50_us = p50,
            p95_us = p95,
            p99_us = p99,
            "load test completed"
        );

        Ok(LoadTestReport {
            total_requests,
            successful,
            failed,
            latency_p50_us: p50,
            latency_p95_us: p95,
            latency_p99_us: p99,
            throughput_rps,
            error_rate,
        })
    }

    /// Returns the configuration this tester was built with.
    #[must_use]
    pub const fn config(&self) -> &LoadTestConfig {
        &self.config
    }
}
310
311impl Default for LoadTester {
312 fn default() -> Self {
313 Self::new(LoadTestConfig::default())
314 }
315}
316
/// Nearest-rank percentile lookup over an already-sorted slice.
///
/// Returns 0 for an empty slice; for `p >= 100` the index clamps to the
/// last (largest) element.
fn percentile(sorted: &[u64], p: usize) -> u64 {
    let Some(&last) = sorted.last() else {
        return 0;
    };
    let idx = sorted.len() * p / 100;
    // Out-of-range index (p >= 100) falls back to the maximum value,
    // matching a `min(len - 1)` clamp.
    sorted.get(idx).copied().unwrap_or(last)
}
325
/// Aggregated results of a completed load test.
#[derive(Debug, Clone)]
pub struct LoadTestReport {
    /// Requests issued in total (successes + failures).
    pub total_requests: u64,
    /// Requests that succeeded.
    pub successful: u64,
    /// Requests that failed.
    pub failed: u64,
    /// Median latency, microseconds.
    pub latency_p50_us: u64,
    /// 95th-percentile latency, microseconds.
    pub latency_p95_us: u64,
    /// 99th-percentile latency, microseconds.
    pub latency_p99_us: u64,
    /// Overall throughput in requests per second.
    pub throughput_rps: f64,
    /// Fraction of requests that failed, in `[0.0, 1.0]`.
    pub error_rate: f64,
}

impl LoadTestReport {
    /// Pass/fail verdict: the run passes when strictly fewer than 1% of
    /// requests failed.
    #[must_use]
    pub fn passed(&self) -> bool {
        self.error_rate < 0.01
    }

    /// Fraction of requests that succeeded; 0.0 when nothing was issued.
    #[must_use]
    pub fn success_rate(&self) -> f64 {
        match self.total_requests {
            0 => 0.0,
            total => self.successful as f64 / total as f64,
        }
    }
}
364
#[cfg(test)]
mod tests {
    use super::*;

    // The preset constructors expose their headline user counts unchanged.
    #[test]
    fn test_load_test_config_presets() {
        let light = LoadTestConfig::light();
        assert_eq!(light.concurrent_users, 5);

        let moderate = LoadTestConfig::moderate();
        assert_eq!(moderate.concurrent_users, 50);

        let heavy = LoadTestConfig::heavy();
        assert_eq!(heavy.concurrent_users, 200);

        let quick = LoadTestConfig::quick();
        assert_eq!(quick.concurrent_users, 4);
        assert_eq!(quick.requests_per_user, Some(10));
    }

    // End-to-end run with the built-in simulated handler: 4 users x 10
    // requests; request_id 0 fails per user (every 100th id), so exactly
    // 4 failures and 36 successes are expected.
    #[tokio::test]
    async fn test_load_tester_run_quick() {
        let tester = LoadTester::new(LoadTestConfig::quick());
        let report = tester.run().await;

        assert!(report.is_ok());
        let report = report.unwrap();

        assert_eq!(report.total_requests, 40);
        assert_eq!(report.failed, 4);
        assert_eq!(report.successful, 36);
        assert!(report.throughput_rps > 0.0);
        assert!(report.latency_p50_us > 0);
    }

    // Custom handler failing every 5th request id: per user ids 0 and 5
    // fail, so 2 users x 10 requests give 4 failures / 16 successes.
    #[tokio::test]
    async fn test_load_tester_with_custom_handler() {
        let handler: RequestHandler = Arc::new(|_user_id, request_id| {
            request_id % 5 != 0
        });

        let config = LoadTestConfig {
            concurrent_users: 2,
            ramp_up: Duration::from_millis(10),
            duration: Duration::from_millis(100),
            requests_per_user: Some(10),
            target_rps: None,
        };

        let tester = LoadTester::new(config).with_handler(handler);
        let report = tester.run().await.unwrap();

        assert_eq!(report.total_requests, 20);
        assert_eq!(report.failed, 4);
        assert_eq!(report.successful, 16);
    }

    // A handler that always succeeds: zero failures, 100% success rate,
    // and the run passes the <1% error-rate bar.
    #[tokio::test]
    async fn test_load_tester_all_success() {
        let handler: RequestHandler = Arc::new(|_, _| true);

        let config = LoadTestConfig {
            concurrent_users: 2,
            ramp_up: Duration::from_millis(10),
            duration: Duration::from_millis(100),
            requests_per_user: Some(5),
            target_rps: None,
        };

        let tester = LoadTester::new(config).with_handler(handler);
        let report = tester.run().await.unwrap();

        assert_eq!(report.total_requests, 10);
        assert_eq!(report.failed, 0);
        assert_eq!(report.successful, 10);
        assert!(report.passed());
        assert!((report.success_rate() - 1.0).abs() < 0.001);
        assert!((report.error_rate - 0.0).abs() < 0.001);
    }

    // A handler that always fails: 100% error rate and a failed verdict.
    #[tokio::test]
    async fn test_load_tester_all_failure() {
        let handler: RequestHandler = Arc::new(|_, _| false);

        let config = LoadTestConfig {
            concurrent_users: 2,
            ramp_up: Duration::from_millis(10),
            duration: Duration::from_millis(100),
            requests_per_user: Some(5),
            target_rps: None,
        };

        let tester = LoadTester::new(config).with_handler(handler);
        let report = tester.run().await.unwrap();

        assert_eq!(report.total_requests, 10);
        assert_eq!(report.failed, 10);
        assert_eq!(report.successful, 0);
        assert!(!report.passed());
        assert!((report.success_rate() - 0.0).abs() < 0.001);
        assert!((report.error_rate - 1.0).abs() < 0.001);
    }

    // 0.5% error rate sits below the 1% threshold => passed.
    #[test]
    fn test_load_test_report_passed() {
        let report = LoadTestReport {
            total_requests: 1000,
            successful: 995,
            failed: 5,
            latency_p50_us: 1000,
            latency_p95_us: 5000,
            latency_p99_us: 10000,
            throughput_rps: 100.0,
            error_rate: 0.005,
        };

        assert!(report.passed());
        assert!((report.success_rate() - 0.995).abs() < 0.001);
    }

    // 10% error rate exceeds the threshold => not passed.
    #[test]
    fn test_load_test_report_failed() {
        let report = LoadTestReport {
            total_requests: 100,
            successful: 90,
            failed: 10,
            latency_p50_us: 1000,
            latency_p95_us: 5000,
            latency_p99_us: 10000,
            throughput_rps: 100.0,
            error_rate: 0.10,
        };

        assert!(!report.passed());
        assert!((report.success_rate() - 0.9).abs() < 0.001);
    }

    // Nearest-rank behavior: index = len * p / 100, clamped to the last
    // element (so p=50 over 10 items picks index 5, i.e. the value 6).
    #[test]
    fn test_percentile_calculation() {
        let sorted = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];

        assert_eq!(percentile(&sorted, 0), 1);
        assert_eq!(percentile(&sorted, 50), 6);
        assert_eq!(percentile(&sorted, 90), 10);
        assert_eq!(percentile(&sorted, 100), 10);
    }

    // Empty input must not panic and yields 0.
    #[test]
    fn test_percentile_empty() {
        let empty: Vec<u64> = vec![];
        assert_eq!(percentile(&empty, 50), 0);
    }

    // 10 OS threads x 100 requests each hammer the shared metrics; every
    // 10th request is a failure => 1000 total, 100 failed, 900 latencies.
    #[test]
    fn test_load_metrics_thread_safety() {
        let metrics = Arc::new(LoadMetrics::default());
        let mut handles = vec![];

        for _ in 0..10 {
            let m = Arc::clone(&metrics);
            handles.push(std::thread::spawn(move || {
                for i in 0..100 {
                    if i % 10 == 0 {
                        m.record_failure();
                    } else {
                        m.record_success(i * 100);
                    }
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(metrics.total_requests.load(Ordering::Relaxed), 1000);
        assert_eq!(metrics.failed.load(Ordering::Relaxed), 100);
        assert_eq!(metrics.successful.load(Ordering::Relaxed), 900);

        let latencies = metrics.get_latencies();
        assert_eq!(latencies.len(), 900);
    }
}