1use std::sync::Arc;
2use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
25use std::time::{Duration, Instant};
26
27use crate::error::WorkerResult;
28use crate::message::{AckHandle, Message, MessageMetadata, ReceivedMessage};
29use crate::worker::Worker;
30use async_trait::async_trait;
31
/// Parameters controlling a single [`run_stress_test`] run.
#[derive(Debug, Clone)]
pub struct StressTestConfig {
    /// Number of messages dispatched to the pool.
    pub message_count: usize,
    /// Number of workers added to the pool; also passed as the pool's concurrency.
    pub concurrency: usize,
    /// Length of the `"data"` filler string placed in every message payload.
    pub message_size_bytes: usize,
    /// When `true`, each worker sleeps for a random delay before processing a message.
    pub simulate_processing_delay: bool,
    /// Inclusive `(min, max)` bounds, in milliseconds, of the simulated delay.
    pub processing_delay_range_ms: (u64, u64),
    /// Upper bound, in seconds, on the whole run (dispatch plus drain).
    pub test_timeout_secs: u64,
}

impl Default for StressTestConfig {
    /// Ten thousand 256-byte messages across 50 workers, with no simulated delay
    /// (1–10 ms if it is switched on) and a five-minute timeout.
    fn default() -> Self {
        let (min_delay_ms, max_delay_ms) = (1, 10);
        StressTestConfig {
            test_timeout_secs: 5 * 60,
            processing_delay_range_ms: (min_delay_ms, max_delay_ms),
            simulate_processing_delay: false,
            message_size_bytes: 256,
            concurrency: 50,
            message_count: 10_000,
        }
    }
}
61
/// Summary metrics produced by a stress-test or stability-test run.
#[derive(Debug, Clone)]
pub struct StressTestResults {
    /// Messages counted as successful.
    pub total_messages: usize,
    /// Wall-clock time of the whole run.
    pub total_duration: Duration,
    /// Successful messages per second of `total_duration`.
    pub throughput: f64,
    /// Mean processing time, in milliseconds.
    pub avg_processing_time_ms: f64,
    /// 95th-percentile processing time, in milliseconds.
    pub p95_processing_time_ms: f64,
    /// 99th-percentile processing time, in milliseconds.
    pub p99_processing_time_ms: f64,
    /// Memory figure in MiB; an estimate, not a measurement.
    pub peak_memory_mb: f64,
    /// Messages counted as failed.
    pub error_count: usize,
    /// Share of messages that succeeded, as a percentage (0–100).
    pub success_rate: f64,
}

impl StressTestResults {
    /// Writes a human-readable report of every metric to stdout.
    pub fn print_summary(&self) {
        println!(
            "\n=== Stress Test Results ===\n\
             Total Messages: {}\n\
             Total Duration: {:?}\n\
             Throughput: {:.2} msg/sec\n\
             Avg Processing Time: {:.2} ms\n\
             P95 Processing Time: {:.2} ms\n\
             P99 Processing Time: {:.2} ms\n\
             Peak Memory: {:.2} MB\n\
             Errors: {}\n\
             Success Rate: {:.2}%\n\
             ==========================\n",
            self.total_messages,
            self.total_duration,
            self.throughput,
            self.avg_processing_time_ms,
            self.p95_processing_time_ms,
            self.p99_processing_time_ms,
            self.peak_memory_mb,
            self.error_count,
            self.success_rate,
        );
    }
}
101
102#[derive(Debug)]
104struct StressTestAckHandle;
105
106#[async_trait]
107impl AckHandle for StressTestAckHandle {
108 async fn ack(&self) -> WorkerResult<()> {
109 Ok(())
110 }
111
112 async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
113 Ok(())
114 }
115}
116
117struct StressTestWorker {
119 id: String,
120 processing_times: Arc<Vec<AtomicU64>>, config: StressTestConfig,
122}
123
124#[async_trait]
125impl Worker for StressTestWorker {
126 fn id(&self) -> &str {
127 &self.id
128 }
129
130 async fn process(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
131 let start = Instant::now();
132
133 if self.config.simulate_processing_delay {
135 let (min_ms, max_ms) = self.config.processing_delay_range_ms;
136 let delay_ms = rand::random::<u64>() % (max_ms - min_ms + 1) + min_ms;
137 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
138 }
139
140 let _hash = calculate_hash(&message.message.payload);
142
143 let elapsed = start.elapsed();
144
145 let idx = self
147 .id
148 .split('-')
149 .next_back()
150 .unwrap_or("0")
151 .parse::<usize>()
152 .unwrap_or(0);
153 if let Some(counter) = self.processing_times.get(idx) {
154 counter.store(elapsed.as_micros() as u64, Ordering::Relaxed);
155 }
156
157 message.ack().await?;
159
160 Ok(())
161 }
162}
163
164fn calculate_hash(value: &serde_json::Value) -> u64 {
166 let serialized = serde_json::to_string(value).unwrap_or_default();
167 serialized
168 .bytes()
169 .fold(0u64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as u64))
170}
171
172pub async fn run_stress_test(config: StressTestConfig) -> StressTestResults {
177 use crate::metrics::NoOpMetrics;
178 use crate::pool::WorkerPool;
179 use crate::strategies::LoadBalancingStrategy;
180
181 println!("Starting stress test...");
182 println!(" Messages: {}", config.message_count);
183 println!(" Concurrency: {}", config.concurrency);
184 println!(" Message Size: {} bytes", config.message_size_bytes);
185 println!();
186
187 let start_time = Instant::now();
188 let processing_times = Arc::new(
189 (0..config.concurrency)
190 .map(|_| AtomicU64::new(0))
191 .collect::<Vec<_>>(),
192 );
193 let error_count = Arc::new(AtomicUsize::new(0));
194
195 let mut pool = WorkerPool::with_concurrency(
197 "stress-test-pool",
198 LoadBalancingStrategy::RoundRobin,
199 config.concurrency,
200 Arc::new(NoOpMetrics),
201 );
202
203 for i in 0..config.concurrency {
205 let worker = StressTestWorker {
206 id: format!("worker-{}", i),
207 processing_times: processing_times.clone(),
208 config: config.clone(),
209 };
210 pool.add_worker(Arc::new(worker));
211 }
212
213 let test_payload = generate_test_payload(config.message_size_bytes);
215
216 println!("Dispatching {} messages...", config.message_count);
217 let dispatch_start = Instant::now();
218
219 for i in 0..config.message_count {
221 let message = create_stress_test_message(&format!("msg-{}", i), test_payload.clone());
222
223 if let Err(e) = pool.dispatch(message).await {
224 eprintln!("Failed to dispatch message {}: {}", i, e);
225 error_count.fetch_add(1, Ordering::Relaxed);
226 }
227
228 if (i + 1) % 1000 == 0 {
230 println!(" Dispatched {} / {} messages", i + 1, config.message_count);
231 }
232 }
233
234 let dispatch_duration = dispatch_start.elapsed();
235 println!("Dispatch completed in {:?}", dispatch_duration);
236
237 println!("Waiting for processing to complete...");
239 let timeout = Duration::from_secs(config.test_timeout_secs);
240
241 loop {
242 if start_time.elapsed() > timeout {
243 eprintln!("WARNING: Test timeout reached!");
244 break;
245 }
246
247 if pool.in_flight_count() == 0 {
249 tokio::time::sleep(Duration::from_millis(100)).await;
251 break;
252 }
253
254 tokio::time::sleep(Duration::from_millis(50)).await;
255 }
256
257 let total_duration = start_time.elapsed();
258 let errors = error_count.load(Ordering::Relaxed);
259 let successful = config.message_count.saturating_sub(errors);
260
261 let throughput = successful as f64 / total_duration.as_secs_f64();
263
264 let times: Vec<u64> = processing_times
265 .iter()
266 .map(|t| t.load(Ordering::Relaxed))
267 .filter(|&t| t > 0)
268 .collect();
269
270 let avg_time = if times.is_empty() {
271 0.0
272 } else {
273 times.iter().sum::<u64>() as f64 / times.len() as f64 / 1000.0 };
275
276 let mut sorted_times = times.clone();
277 sorted_times.sort();
278
279 let p95_idx =
280 ((sorted_times.len() as f64 * 0.95) as usize).min(sorted_times.len().saturating_sub(1));
281 let p99_idx =
282 ((sorted_times.len() as f64 * 0.99) as usize).min(sorted_times.len().saturating_sub(1));
283
284 let p95_time = sorted_times.get(p95_idx).copied().unwrap_or(0) as f64 / 1000.0;
285 let p99_time = sorted_times.get(p99_idx).copied().unwrap_or(0) as f64 / 1000.0;
286
287 let peak_memory_mb = estimate_memory_usage(
289 config.message_count,
290 config.message_size_bytes,
291 config.concurrency,
292 );
293
294 let success_rate = if config.message_count > 0 {
295 (successful as f64 / config.message_count as f64) * 100.0
296 } else {
297 0.0
298 };
299
300 let results = StressTestResults {
301 total_messages: successful,
302 total_duration,
303 throughput,
304 avg_processing_time_ms: avg_time,
305 p95_processing_time_ms: p95_time,
306 p99_processing_time_ms: p99_time,
307 peak_memory_mb,
308 error_count: errors,
309 success_rate,
310 };
311
312 results.print_summary();
313
314 if let Err(e) = pool.shutdown().await {
316 eprintln!("Warning: Pool shutdown failed: {}", e);
317 }
318
319 results
320}
321
322fn generate_test_payload(size_bytes: usize) -> serde_json::Value {
324 let data = "x".repeat(size_bytes);
325 serde_json::json!({
326 "data": data,
327 "timestamp": chrono::Utc::now().to_rfc3339(),
328 "id": uuid::Uuid::new_v4().to_string()
329 })
330}
331
332fn create_stress_test_message(
334 id: &str,
335 payload: serde_json::Value,
336) -> ReceivedMessage<serde_json::Value> {
337 let message = Message {
338 id: id.to_string(),
339 payload,
340 metadata: MessageMetadata::new("stress-test-queue"),
341 };
342 ReceivedMessage::new(message, Arc::new(StressTestAckHandle))
343}
344
/// Rough heuristic estimate of a run's peak memory footprint, in MiB.
///
/// This is not a measurement. It is the sum of:
/// - one message buffer per concurrent worker (`concurrency * message_size`),
/// - queue overhead, assumed to be 10% of all message bytes,
/// - a flat 1 MiB of state per worker.
///
/// Saturating arithmetic is used, so very large inputs give a capped estimate instead
/// of overflowing (which would panic in debug builds).
fn estimate_memory_usage(message_count: usize, message_size: usize, concurrency: usize) -> f64 {
    const BYTES_PER_MIB: usize = 1_048_576;

    let in_flight_bytes = concurrency.saturating_mul(message_size);
    let queue_overhead = message_count.saturating_mul(message_size) / 10;
    let worker_state = concurrency.saturating_mul(BYTES_PER_MIB);
    let total_bytes = in_flight_bytes
        .saturating_add(queue_overhead)
        .saturating_add(worker_state);

    total_bytes as f64 / BYTES_PER_MIB as f64
}
358
359pub async fn run_stability_test(duration_secs: u64, config: StressTestConfig) -> StressTestResults {
364 println!("Starting {}-second stability test...", duration_secs);
365
366 let start = Instant::now();
367 let target_duration = Duration::from_secs(duration_secs);
368 let mut iteration = 0;
369 let mut total_processed = 0;
370 let mut total_errors = 0;
371
372 while start.elapsed() < target_duration {
373 iteration += 1;
374 println!("\n--- Iteration {} ---", iteration);
375
376 let mut iter_config = config.clone();
377 iter_config.message_count = config.message_count / 10; let results = run_stress_test(iter_config).await;
380 total_processed += results.total_messages;
381 total_errors += results.error_count;
382
383 tokio::time::sleep(Duration::from_secs(1)).await;
385 }
386
387 let total_duration = start.elapsed();
388
389 StressTestResults {
390 total_messages: total_processed,
391 total_duration,
392 throughput: total_processed as f64 / total_duration.as_secs_f64(),
393 avg_processing_time_ms: 0.0, p95_processing_time_ms: 0.0,
395 p99_processing_time_ms: 0.0,
396 peak_memory_mb: 0.0,
397 error_count: total_errors,
398 success_rate: if total_processed + total_errors > 0 {
399 (total_processed as f64 / (total_processed + total_errors) as f64) * 100.0
400 } else {
401 0.0
402 },
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end run on a small pool; every dispatched message should succeed.
    #[tokio::test]
    async fn test_small_stress_test() {
        let config = StressTestConfig {
            message_count: 100,
            concurrency: 5,
            message_size_bytes: 64,
            ..Default::default()
        };

        let results = run_stress_test(config).await;

        assert!(results.total_messages > 0);
        assert!(results.throughput > 0.0);
        // Allow a little slack rather than demanding exactly 100%.
        assert!(results.success_rate >= 99.0);
        // Every message is counted exactly once, as a success or an error.
        assert_eq!(results.total_messages + results.error_count, 100);
    }

    #[test]
    fn test_payload_generation() {
        let payload = generate_test_payload(1024);
        let serialized = serde_json::to_string(&payload).unwrap();

        assert!(serialized.len() >= 1024);
        assert_eq!(payload["data"].as_str().map(str::len), Some(1024));
        assert!(payload["timestamp"].is_string());
        assert!(payload["id"].is_string());
    }

    #[test]
    fn test_payload_generation_empty() {
        let payload = generate_test_payload(0);
        assert_eq!(payload["data"], serde_json::json!(""));
    }

    #[test]
    fn test_message_creation() {
        let msg = create_stress_test_message("msg-7", serde_json::json!({ "k": 1 }));
        assert_eq!(msg.message.id, "msg-7");
        assert_eq!(msg.message.payload, serde_json::json!({ "k": 1 }));
    }

    #[test]
    fn test_memory_estimation() {
        let mem = estimate_memory_usage(1000, 256, 10);
        assert!(mem > 0.0);
        assert!(mem < 100.0);
        // 10 MiB of worker state + 2_560 B in flight + 25_600 B queue overhead.
        assert!((mem - 10.026_855_468_75).abs() < 1e-9);
    }

    #[test]
    fn test_memory_estimation_zero() {
        assert!(estimate_memory_usage(0, 0, 0).abs() < f64::EPSILON);
    }
}