Skip to main content

mabi_modbus/testing/
load_generator.rs

1//! Load generation utilities for stress testing.
2//!
3//! Provides configurable load generators for simulating various
4//! traffic patterns and connection behaviors.
5
6use std::net::SocketAddr;
7use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use parking_lot::RwLock;
12use tokio::io::AsyncWriteExt;
13use tokio::net::TcpStream;
14use tokio::sync::Semaphore;
15
16/// Load generation configuration.
17#[derive(Debug, Clone)]
18pub struct LoadConfig {
19    /// Target server address.
20    pub server_addr: SocketAddr,
21    /// Number of concurrent connections.
22    pub connections: usize,
23    /// Requests per second per connection.
24    pub requests_per_second: f64,
25    /// Total test duration.
26    pub duration: Duration,
27    /// Connection ramp-up time.
28    pub ramp_up: Duration,
29    /// Load pattern.
30    pub pattern: LoadPattern,
31    /// Connection timeout.
32    pub connect_timeout: Duration,
33    /// Request timeout.
34    pub request_timeout: Duration,
35    /// Whether to keep connections alive.
36    pub keep_alive: bool,
37}
38
39impl Default for LoadConfig {
40    fn default() -> Self {
41        Self {
42            server_addr: "127.0.0.1:502".parse().unwrap(),
43            connections: 100,
44            requests_per_second: 100.0,
45            duration: Duration::from_secs(60),
46            ramp_up: Duration::from_secs(5),
47            pattern: LoadPattern::Constant,
48            connect_timeout: Duration::from_secs(10),
49            request_timeout: Duration::from_secs(5),
50            keep_alive: true,
51        }
52    }
53}
54
55impl LoadConfig {
56    /// Create config for steady load test.
57    pub fn steady(connections: usize, rps: f64) -> Self {
58        Self {
59            connections,
60            requests_per_second: rps,
61            pattern: LoadPattern::Constant,
62            ..Default::default()
63        }
64    }
65
66    /// Create config for spike test.
67    pub fn spike(base_connections: usize, spike_connections: usize) -> Self {
68        Self {
69            connections: spike_connections,
70            pattern: LoadPattern::Spike {
71                base_load: base_connections,
72                spike_load: spike_connections,
73                spike_duration: Duration::from_secs(30),
74                recovery_duration: Duration::from_secs(30),
75            },
76            ..Default::default()
77        }
78    }
79
80    /// Create config for ramp test.
81    pub fn ramp(start: usize, end: usize, step_duration: Duration) -> Self {
82        Self {
83            connections: end,
84            pattern: LoadPattern::Ramp {
85                start_connections: start,
86                end_connections: end,
87                step_duration,
88            },
89            ..Default::default()
90        }
91    }
92
93    /// Set server address.
94    pub fn with_server(mut self, addr: SocketAddr) -> Self {
95        self.server_addr = addr;
96        self
97    }
98
99    /// Set duration.
100    pub fn with_duration(mut self, duration: Duration) -> Self {
101        self.duration = duration;
102        self
103    }
104}
105
/// Shape of the offered load over the course of a test.
#[derive(Debug, Clone)]
pub enum LoadPattern {
    /// Same load for the whole test.
    Constant,
    /// Load that moves linearly (up or down) between two connection counts.
    Ramp {
        /// Connection count at the beginning of the ramp.
        start_connections: usize,
        /// Connection count at the end of the ramp.
        end_connections: usize,
        /// Time spent on each step of the ramp.
        step_duration: Duration,
    },
    /// Baseline load with a sudden burst.
    Spike {
        /// Connection count outside the spike.
        base_load: usize,
        /// Connection count during the spike.
        spike_load: usize,
        /// How long the spike lasts.
        spike_duration: Duration,
        /// How long the recovery phase after the spike lasts.
        recovery_duration: Duration,
    },
    /// Sinusoidal oscillation between two connection counts.
    Wave {
        /// Lowest connection count of the wave.
        min_connections: usize,
        /// Highest connection count of the wave.
        max_connections: usize,
        /// Length of one full oscillation.
        period: Duration,
    },
    /// Random fluctuation within a range of connection counts.
    Random {
        /// Lower bound on the connection count.
        min_connections: usize,
        /// Upper bound on the connection count.
        max_connections: usize,
    },
}
136
137/// Load generator for creating traffic.
138pub struct LoadGenerator {
139    config: LoadConfig,
140    running: Arc<AtomicBool>,
141    stats: Arc<LoadStats>,
142}
143
144/// Statistics collected during load generation.
145pub struct LoadStats {
146    pub requests_sent: AtomicU64,
147    pub requests_success: AtomicU64,
148    pub requests_failed: AtomicU64,
149    pub connections_opened: AtomicU64,
150    pub connections_failed: AtomicU64,
151    pub bytes_sent: AtomicU64,
152    pub bytes_received: AtomicU64,
153    latencies: RwLock<Vec<Duration>>,
154}
155
156impl LoadStats {
157    fn new() -> Self {
158        Self {
159            requests_sent: AtomicU64::new(0),
160            requests_success: AtomicU64::new(0),
161            requests_failed: AtomicU64::new(0),
162            connections_opened: AtomicU64::new(0),
163            connections_failed: AtomicU64::new(0),
164            bytes_sent: AtomicU64::new(0),
165            bytes_received: AtomicU64::new(0),
166            latencies: RwLock::new(Vec::with_capacity(10000)),
167        }
168    }
169
170    fn record_latency(&self, latency: Duration) {
171        let count = self.requests_sent.load(Ordering::Relaxed);
172        // Sample to avoid memory issues
173        if count % 10 == 0 {
174            self.latencies.write().push(latency);
175        }
176    }
177
178    /// Get percentile latency.
179    pub fn percentile(&self, p: f64) -> Option<Duration> {
180        let mut latencies = self.latencies.read().clone();
181        if latencies.is_empty() {
182            return None;
183        }
184        latencies.sort();
185        let idx = ((latencies.len() as f64 * p / 100.0) as usize).min(latencies.len() - 1);
186        Some(latencies[idx])
187    }
188
189    /// Get success rate.
190    pub fn success_rate(&self) -> f64 {
191        let total = self.requests_sent.load(Ordering::Relaxed);
192        let success = self.requests_success.load(Ordering::Relaxed);
193        if total > 0 {
194            success as f64 / total as f64
195        } else {
196            0.0
197        }
198    }
199}
200
201impl LoadGenerator {
202    /// Create a new load generator.
203    pub fn new(config: LoadConfig) -> Self {
204        Self {
205            config,
206            running: Arc::new(AtomicBool::new(false)),
207            stats: Arc::new(LoadStats::new()),
208        }
209    }
210
211    /// Run the load generator.
212    pub async fn run(&self) -> LoadGeneratorResult {
213        self.running.store(true, Ordering::SeqCst);
214        let start = Instant::now();
215
216        let semaphore = Arc::new(Semaphore::new(self.config.connections));
217        let mut handles = Vec::new();
218
219        // Calculate interval between requests per connection
220        let request_interval = Duration::from_secs_f64(
221            self.config.connections as f64 / self.config.requests_per_second
222        );
223
224        // Ramp up connections gradually
225        let ramp_interval = self.config.ramp_up / self.config.connections as u32;
226
227        for i in 0..self.config.connections {
228            if !self.running.load(Ordering::Relaxed) {
229                break;
230            }
231
232            // Ramp-up delay
233            if i > 0 {
234                tokio::time::sleep(ramp_interval).await;
235            }
236
237            let permit = semaphore.clone().acquire_owned().await.unwrap();
238            let config = self.config.clone();
239            let stats = self.stats.clone();
240            let running = self.running.clone();
241            let test_duration = self.config.duration;
242
243            let handle = tokio::spawn(async move {
244                let _permit = permit;
245
246                // Connect
247                let stream = match tokio::time::timeout(
248                    config.connect_timeout,
249                    TcpStream::connect(config.server_addr),
250                )
251                .await
252                {
253                    Ok(Ok(s)) => {
254                        stats.connections_opened.fetch_add(1, Ordering::Relaxed);
255                        s
256                    }
257                    _ => {
258                        stats.connections_failed.fetch_add(1, Ordering::Relaxed);
259                        return;
260                    }
261                };
262
263                let conn_start = Instant::now();
264
265                // Send requests
266                while running.load(Ordering::Relaxed) && conn_start.elapsed() < test_duration {
267                    let req_start = Instant::now();
268
269                    let result = Self::send_request(&stream, &config, &stats).await;
270
271                    let latency = req_start.elapsed();
272                    stats.record_latency(latency);
273                    stats.requests_sent.fetch_add(1, Ordering::Relaxed);
274
275                    if result {
276                        stats.requests_success.fetch_add(1, Ordering::Relaxed);
277                    } else {
278                        stats.requests_failed.fetch_add(1, Ordering::Relaxed);
279                    }
280
281                    // Throttle to maintain target rate
282                    if latency < request_interval {
283                        tokio::time::sleep(request_interval - latency).await;
284                    }
285                }
286            });
287
288            handles.push(handle);
289        }
290
291        // Wait for test duration
292        tokio::time::sleep(self.config.duration).await;
293        self.running.store(false, Ordering::SeqCst);
294
295        // Wait for all workers to finish
296        for handle in handles {
297            let _ = handle.await;
298        }
299
300        let duration = start.elapsed();
301
302        LoadGeneratorResult {
303            duration,
304            requests_sent: self.stats.requests_sent.load(Ordering::Relaxed),
305            requests_success: self.stats.requests_success.load(Ordering::Relaxed),
306            requests_failed: self.stats.requests_failed.load(Ordering::Relaxed),
307            connections_opened: self.stats.connections_opened.load(Ordering::Relaxed),
308            connections_failed: self.stats.connections_failed.load(Ordering::Relaxed),
309            bytes_sent: self.stats.bytes_sent.load(Ordering::Relaxed),
310            bytes_received: self.stats.bytes_received.load(Ordering::Relaxed),
311            p50_latency: self.stats.percentile(50.0),
312            p95_latency: self.stats.percentile(95.0),
313            p99_latency: self.stats.percentile(99.0),
314            success_rate: self.stats.success_rate(),
315        }
316    }
317
318    async fn send_request(stream: &TcpStream, config: &LoadConfig, stats: &LoadStats) -> bool {
319        // Modbus TCP read holding registers request
320        let request: [u8; 12] = [
321            0x00, 0x01,             // Transaction ID
322            0x00, 0x00,             // Protocol ID
323            0x00, 0x06,             // Length
324            0x01,                   // Unit ID
325            0x03,                   // Function code: Read Holding Registers
326            0x00, 0x00,             // Starting address
327            0x00, 0x0A,             // Quantity (10 registers)
328        ];
329
330        // Try to send request
331        if stream.try_write(&request).is_err() {
332            return false;
333        }
334        stats.bytes_sent.fetch_add(request.len() as u64, Ordering::Relaxed);
335
336        // Small delay for response
337        tokio::time::sleep(Duration::from_micros(500)).await;
338
339        // Try to read response
340        let mut response = [0u8; 256];
341        match stream.try_read(&mut response) {
342            Ok(n) if n > 0 => {
343                stats.bytes_received.fetch_add(n as u64, Ordering::Relaxed);
344                // Check for valid Modbus response
345                n >= 9 && response[7] == 0x03
346            }
347            _ => false,
348        }
349    }
350
351    /// Stop the load generator.
352    pub fn stop(&self) {
353        self.running.store(false, Ordering::SeqCst);
354    }
355
356    /// Get current statistics.
357    pub fn stats(&self) -> &LoadStats {
358        &self.stats
359    }
360}
361
/// Summary produced by a finished load-generation run.
#[derive(Debug, Clone)]
pub struct LoadGeneratorResult {
    /// Wall-clock length of the run.
    pub duration: Duration,
    /// Requests attempted.
    pub requests_sent: u64,
    /// Requests that received a valid response.
    pub requests_success: u64,
    /// Requests that did not receive a valid response.
    pub requests_failed: u64,
    /// Connections successfully established.
    pub connections_opened: u64,
    /// Connection attempts that failed.
    pub connections_failed: u64,
    /// Request bytes written.
    pub bytes_sent: u64,
    /// Response bytes read.
    pub bytes_received: u64,
    /// Median sampled latency, if any samples exist.
    pub p50_latency: Option<Duration>,
    /// 95th-percentile sampled latency, if any samples exist.
    pub p95_latency: Option<Duration>,
    /// 99th-percentile sampled latency, if any samples exist.
    pub p99_latency: Option<Duration>,
    /// Fraction (0.0–1.0) of sent requests that succeeded.
    pub success_rate: f64,
}

impl LoadGeneratorResult {
    /// Average achieved request rate: `requests_sent / duration`, or 0.0 for a zero-length
    /// run.
    pub fn rps(&self) -> f64 {
        let elapsed_secs = self.duration.as_secs_f64();
        if elapsed_secs <= 0.0 {
            return 0.0;
        }
        self.requests_sent as f64 / elapsed_secs
    }

    /// Render the result as a multi-line, human-readable report (no trailing newline).
    pub fn format(&self) -> String {
        format!(
            "Load Generation Results:\n\
             Duration: {duration:?}\n\
             Requests: {sent} sent, {succeeded} success, {failed} failed\n\
             RPS: {rate:.2}\n\
             Success Rate: {success_pct:.2}%\n\
             Connections: {opened} opened, {conn_failed} failed\n\
             P50 Latency: {p50:?}\n\
             P95 Latency: {p95:?}\n\
             P99 Latency: {p99:?}\n\
             Bytes: {bytes_out} sent, {bytes_in} received",
            duration = self.duration,
            sent = self.requests_sent,
            succeeded = self.requests_success,
            failed = self.requests_failed,
            rate = self.rps(),
            success_pct = self.success_rate * 100.0,
            opened = self.connections_opened,
            conn_failed = self.connections_failed,
            p50 = self.p50_latency,
            p95 = self.p95_latency,
            p99 = self.p99_latency,
            bytes_out = self.bytes_sent,
            bytes_in = self.bytes_received,
        )
    }
}
418
419/// Connection simulator for testing connection handling.
420pub struct ConnectionSimulator {
421    server_addr: SocketAddr,
422    connections: Vec<TcpStream>,
423}
424
425impl ConnectionSimulator {
426    /// Create a new connection simulator.
427    pub fn new(server_addr: SocketAddr) -> Self {
428        Self {
429            server_addr,
430            connections: Vec::new(),
431        }
432    }
433
434    /// Open N connections.
435    pub async fn open_connections(&mut self, count: usize) -> Result<usize, String> {
436        let mut opened = 0;
437
438        for _ in 0..count {
439            match tokio::time::timeout(
440                Duration::from_secs(5),
441                TcpStream::connect(self.server_addr),
442            )
443            .await
444            {
445                Ok(Ok(stream)) => {
446                    self.connections.push(stream);
447                    opened += 1;
448                }
449                Ok(Err(e)) => {
450                    tracing::debug!("Connection failed: {}", e);
451                }
452                Err(_) => {
453                    tracing::debug!("Connection timeout");
454                }
455            }
456        }
457
458        Ok(opened)
459    }
460
461    /// Get current connection count.
462    pub fn connection_count(&self) -> usize {
463        self.connections.len()
464    }
465
466    /// Close all connections.
467    pub fn close_all(&mut self) {
468        self.connections.clear();
469    }
470
471    /// Test that all connections are still alive.
472    pub async fn verify_connections(&self) -> (usize, usize) {
473        let mut alive = 0;
474        let mut dead = 0;
475
476        for stream in &self.connections {
477            // Try a simple peek to check if connection is alive
478            let mut buf = [0u8; 1];
479            let result = stream.try_read(&mut buf);
480            match &result {
481                Ok(_) => {
482                    alive += 1;
483                }
484                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
485                    // WouldBlock means connection is alive but no data
486                    alive += 1;
487                }
488                Err(_) => {
489                    dead += 1;
490                }
491            }
492        }
493
494        (alive, dead)
495    }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_load_config_default() {
        let LoadConfig {
            connections,
            requests_per_second,
            ..
        } = LoadConfig::default();
        assert_eq!(connections, 100);
        assert!((requests_per_second - 100.0).abs() < 0.01);
    }

    #[test]
    fn test_load_config_presets() {
        assert_eq!(LoadConfig::steady(500, 1000.0).connections, 500);
        assert_eq!(LoadConfig::spike(100, 1000).connections, 1000);
    }

    #[test]
    fn test_load_stats() {
        let stats = LoadStats::new();
        stats.requests_sent.fetch_add(100, Ordering::Relaxed);
        stats.requests_success.fetch_add(95, Ordering::Relaxed);

        let expected = 0.95;
        assert!((stats.success_rate() - expected).abs() < 0.01);
    }

    #[test]
    fn test_load_generator_result_rps() {
        let ms = Duration::from_millis;
        let result = LoadGeneratorResult {
            duration: Duration::from_secs(10),
            requests_sent: 10000,
            requests_success: 9900,
            requests_failed: 100,
            connections_opened: 100,
            connections_failed: 0,
            bytes_sent: 120000,
            bytes_received: 240000,
            p50_latency: Some(ms(5)),
            p95_latency: Some(ms(10)),
            p99_latency: Some(ms(20)),
            success_rate: 0.99,
        };

        let expected_rps = 1000.0;
        assert!((result.rps() - expected_rps).abs() < 0.01);
    }
}