1use std::time::{Duration, Instant};
39use serde::{Serialize, Deserialize};
40
41mod test_node;
42mod metrics;
43pub mod scenarios;
44
45pub use test_node::TestNode;
46
/// Parameters for a single load-test run.
///
/// Call [`LoadTestConfig::validate`] before use; [`LoadTestScenario::run`] validates the
/// configuration itself before doing any work.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadTestConfig {
    /// Number of test nodes to spawn. Must be greater than 0.
    pub num_nodes: usize,
    /// Maximum number of peers each node is linked to during ramp-up. During a run this is
    /// capped at one less than the number of nodes that actually spawned.
    pub num_connections_per_node: usize,
    /// Target aggregate send rate, in messages per second. Must be greater than 0.
    pub message_rate_per_second: usize,
    /// Length of the load-generation phase. Must be strictly greater than `ramp_up_duration`.
    pub test_duration: Duration,
    /// Time over which peer links are established before load generation starts.
    pub ramp_up_duration: Duration,
}
61
62impl LoadTestConfig {
63 pub fn validate(&self) -> Result<(), String> {
65 if self.num_nodes == 0 {
66 return Err("num_nodes must be greater than 0".to_string());
67 }
68 if self.test_duration <= self.ramp_up_duration {
69 return Err("test_duration must be greater than ramp_up_duration".to_string());
70 }
71 if self.message_rate_per_second == 0 {
72 return Err("message_rate_per_second must be greater than 0".to_string());
73 }
74 Ok(())
75 }
76}
77
/// A configured load test, executed with [`LoadTestScenario::run`].
pub struct LoadTestScenario {
    /// Settings for the run; exposed read-only through `config()`.
    config: LoadTestConfig,
}
82
83impl LoadTestScenario {
84 pub fn new(config: LoadTestConfig) -> Self {
86 Self { config }
87 }
88
89 pub fn config(&self) -> &LoadTestConfig {
91 &self.config
92 }
93
94 pub async fn run(&mut self) -> Result<LoadTestResult, Box<dyn std::error::Error>> {
106 use elara_core::SessionId;
107 use crate::metrics::LoadTestMetrics;
108 use crate::test_node::generate_test_message;
109
110 self.config.validate()?;
112
113 let mut metrics = LoadTestMetrics::new();
114
115 println!("Spawning {} nodes...", self.config.num_nodes);
117 let mut nodes = Vec::new();
118 for i in 0..self.config.num_nodes {
119 match TestNode::spawn_default() {
120 Ok(node) => nodes.push(node),
121 Err(e) => {
122 metrics.record_failure(format!("Failed to spawn node {}: {}", i, e));
123 }
124 }
125 }
126
127 if nodes.is_empty() {
128 return Err("Failed to spawn any nodes".into());
129 }
130
131 let session_id = SessionId::new(1);
133 for node in &mut nodes {
134 node.join_session_unsecured(session_id);
135 }
136
137 println!("Ramping up connections...");
139 let ramp_up_interval = self.config.ramp_up_duration / self.config.num_nodes as u32;
140
141 for i in 0..nodes.len() {
142 tokio::time::sleep(ramp_up_interval).await;
143
144 for j in 1..=self.config.num_connections_per_node.min(nodes.len() - 1) {
146 let peer_idx = (i + j) % nodes.len();
147
148 let peer_node_id = nodes[peer_idx].node_id();
150 let peer_index = nodes[i].peers.len();
151 nodes[i].peers.insert(peer_node_id, peer_index);
152 }
153 }
154
155 println!("Connections established. Starting load generation...");
156
157 let test_end = Instant::now() + self.config.test_duration;
159 let messages_per_tick = self.config.message_rate_per_second / 10; let tick_interval = Duration::from_millis(100);
161
162 let mut tick_count = 0;
163 while Instant::now() < test_end {
164 tick_count += 1;
165
166 for _ in 0..messages_per_tick {
168 let node_idx = tick_count % nodes.len();
169 let payload = generate_test_message(64);
170
171 match nodes[node_idx].send_message(payload) {
172 Ok(start_time) => {
173 let latency = start_time.elapsed();
174 metrics.record_success(latency);
175 }
176 Err(e) => {
177 metrics.record_failure(format!("Send failed: {}", e));
178 }
179 }
180 }
181
182 let mut all_frames: Vec<(usize, Vec<u8>)> = Vec::new();
185
186 for i in 0..nodes.len() {
187 while let Some(_frame) = nodes[i].node_mut().pop_outgoing() {
188 all_frames.push((i, vec![]));
191 }
192 }
193
194 for node in &mut nodes {
196 node.tick();
197 }
198
199 tokio::time::sleep(tick_interval).await;
200 }
201
202 println!("Load generation complete. Collecting final metrics...");
203
204 for node in nodes {
206 node.shutdown();
207 }
208
209 println!("Test complete!");
210
211 Ok(metrics.into_result())
212 }
213}
214
/// Aggregated outcome of a load-test run; render it with `report()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadTestResult {
    /// Total number of messages counted (reported as "Total Messages").
    pub total_messages: u64,
    /// Number of messages counted as successful.
    pub successful_messages: u64,
    /// Number of messages counted as failed.
    pub failed_messages: u64,
    /// Average latency; `report()` prints it with an "ms" suffix.
    pub avg_latency_ms: f64,
    /// 50th-percentile (median) latency, printed with an "ms" suffix.
    pub p50_latency_ms: f64,
    /// 95th-percentile latency, printed with an "ms" suffix.
    pub p95_latency_ms: f64,
    /// 99th-percentile latency, printed with an "ms" suffix.
    pub p99_latency_ms: f64,
    /// Maximum observed latency, printed with an "ms" suffix.
    pub max_latency_ms: f64,
    /// Throughput, printed as "msg/sec".
    pub throughput_msg_per_sec: f64,
    /// Errors recorded during the run; `report()` prints only their count.
    pub errors: Vec<LoadTestError>,
}
239
240impl LoadTestResult {
241 pub fn report(&self) -> String {
243 let success_rate = if self.total_messages > 0 {
244 (self.successful_messages as f64 / self.total_messages as f64) * 100.0
245 } else {
246 0.0
247 };
248
249 format!(
250 r#"Load Test Results
251==================
252Total Messages: {}
253Successful: {} ({:.2}%)
254Failed: {}
255
256Throughput: {:.2} msg/sec
257
258Latency Statistics:
259 Average: {:.2}ms
260 P50 (median): {:.2}ms
261 P95: {:.2}ms
262 P99: {:.2}ms
263 Max: {:.2}ms
264
265Errors: {}
266"#,
267 self.total_messages,
268 self.successful_messages,
269 success_rate,
270 self.failed_messages,
271 self.throughput_msg_per_sec,
272 self.avg_latency_ms,
273 self.p50_latency_ms,
274 self.p95_latency_ms,
275 self.p99_latency_ms,
276 self.max_latency_ms,
277 self.errors.len()
278 )
279 }
280}
281
/// An error recorded during a load test, together with when it was created.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadTestError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Creation time as text; `LoadTestError::new` fills it with the current UTC time in
    /// RFC 3339 form.
    pub timestamp: String,
}
290
291impl LoadTestError {
292 pub fn new(message: String) -> Self {
294 Self {
295 message,
296 timestamp: chrono::Utc::now().to_rfc3339(),
297 }
298 }
299}
300
301impl std::fmt::Display for LoadTestError {
302 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303 write!(f, "[{}] {}", self.timestamp, self.message)
304 }
305}
306
// Empty marker impl: `std::error::Error` only requires `Debug` and `Display`, which
// `LoadTestError` provides, so it can be used as a standard error (e.g. boxed as
// `Box<dyn std::error::Error>`).
impl std::error::Error for LoadTestError {}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// A configuration that satisfies every validation rule; tests derive invalid
    /// variants from it with struct-update syntax.
    fn valid_config() -> LoadTestConfig {
        LoadTestConfig {
            num_nodes: 10,
            num_connections_per_node: 5,
            message_rate_per_second: 100,
            test_duration: Duration::from_secs(60),
            ramp_up_duration: Duration::from_secs(10),
        }
    }

    #[test]
    fn test_config_validation() {
        assert!(valid_config().validate().is_ok());

        let invalid_nodes = LoadTestConfig {
            num_nodes: 0,
            ..valid_config()
        };
        assert_eq!(
            invalid_nodes.validate(),
            Err("num_nodes must be greater than 0".to_string())
        );

        let invalid_duration = LoadTestConfig {
            test_duration: Duration::from_secs(5),
            ramp_up_duration: Duration::from_secs(10),
            ..valid_config()
        };
        assert_eq!(
            invalid_duration.validate(),
            Err("test_duration must be greater than ramp_up_duration".to_string())
        );

        // The duration rule is strict: equal durations are rejected too.
        let equal_durations = LoadTestConfig {
            test_duration: Duration::from_secs(10),
            ramp_up_duration: Duration::from_secs(10),
            ..valid_config()
        };
        assert!(equal_durations.validate().is_err());

        let invalid_rate = LoadTestConfig {
            message_rate_per_second: 0,
            ..valid_config()
        };
        assert_eq!(
            invalid_rate.validate(),
            Err("message_rate_per_second must be greater than 0".to_string())
        );
    }

    #[test]
    fn test_result_report_generation() {
        let result = LoadTestResult {
            total_messages: 1000,
            successful_messages: 950,
            failed_messages: 50,
            avg_latency_ms: 42.5,
            p50_latency_ms: 38.0,
            p95_latency_ms: 85.0,
            p99_latency_ms: 120.0,
            max_latency_ms: 250.0,
            throughput_msg_per_sec: 16.67,
            errors: vec![],
        };

        let report = result.report();
        assert!(report.contains("1000"));
        assert!(report.contains("950"));
        assert!(report.contains("42.5"));
        assert!(report.contains("Successful: 950 (95.00%)"));
    }

    #[test]
    fn test_report_with_no_messages_has_zero_success_rate() {
        let result = LoadTestResult {
            total_messages: 0,
            successful_messages: 0,
            failed_messages: 0,
            avg_latency_ms: 0.0,
            p50_latency_ms: 0.0,
            p95_latency_ms: 0.0,
            p99_latency_ms: 0.0,
            max_latency_ms: 0.0,
            throughput_msg_per_sec: 0.0,
            errors: vec![],
        };

        let report = result.report();
        // The zero-total branch must not divide by zero (which would print NaN).
        assert!(report.contains("Successful: 0 (0.00%)"));
        assert!(!report.contains("NaN"));
    }

    #[test]
    fn test_error_display_format() {
        let err = LoadTestError {
            message: "boom".to_string(),
            timestamp: "T".to_string(),
        };
        assert_eq!(err.to_string(), "[T] boom");
    }

    #[test]
    fn test_error_new_sets_message_and_timestamp() {
        let err = LoadTestError::new("boom".to_string());
        assert_eq!(err.message, "boom");
        assert!(!err.timestamp.is_empty());
        assert!(err.to_string().ends_with("] boom"));
    }
}