//! Stress and endurance testing utilities for the semantic router.

use crate::router::{RouterConfig, SemanticRouter};
use ipfrs_core::{Cid, Result};
use rand::Rng;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::task;

/// Configuration for a concurrent stress-test run.
#[derive(Debug, Clone)]
pub struct StressTestConfig {
    /// Number of concurrent worker tasks to spawn.
    pub num_threads: usize,
    /// Number of operations issued by each worker.
    pub operations_per_thread: usize,
    /// Number of entries pre-loaded into the index before the run starts.
    pub index_size: usize,
    /// Dimensionality of the embedding vectors.
    pub dimension: usize,
    /// Fraction of operations that are inserts.
    pub insert_ratio: f64,
    /// Fraction of operations that are queries.
    pub query_ratio: f64,
    /// Number of nearest neighbors requested per query.
    pub k: usize,
}

impl Default for StressTestConfig {
    fn default() -> Self {
        Self {
            num_threads: 10,
            operations_per_thread: 100,
            index_size: 1000,
            dimension: 768,
            insert_ratio: 0.3,
            query_ratio: 0.7,
            k: 10,
        }
    }
}
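
// Illustrative note (not from the original file): a heavier run can be configured by
// overriding individual fields with struct update syntax, e.g.
//
//     let heavy = StressTestConfig { num_threads: 32, ..Default::default() };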

/// Aggregate results from a stress-test run.
#[derive(Debug, Clone)]
pub struct StressTestResults {
    /// Total number of operations attempted.
    pub total_ops: usize,
    /// Operations that completed successfully.
    pub successful_ops: usize,
    /// Operations that returned an error.
    pub failed_ops: usize,
    /// Wall-clock duration of the whole run.
    pub total_duration: Duration,
    /// Overall throughput in operations per second.
    pub ops_per_second: f64,
    /// Mean per-operation latency.
    pub avg_latency: Duration,
    /// 50th-percentile (median) latency.
    pub p50_latency: Duration,
    /// 90th-percentile latency.
    pub p90_latency: Duration,
    /// 99th-percentile latency.
    pub p99_latency: Duration,
    /// Ratio of successful operations to total operations.
    pub success_rate: f64,
    /// Maximum number of concurrent workers (equal to `num_threads`).
    pub max_concurrent: usize,
}

/// Concurrent stress test that mixes inserts and queries against a shared router.
pub struct StressTest {
    config: StressTestConfig,
    router: Arc<SemanticRouter>,
}

impl StressTest {
    /// Builds the router and pre-populates it with `index_size` random entries.
    pub fn new(config: StressTestConfig) -> Result<Self> {
        let router_config =
            RouterConfig::balanced(config.dimension).with_cache_size(config.index_size * 2);

        let router = SemanticRouter::new(router_config)?;

        // Pre-populate the index so queries have something to hit.
        if config.index_size > 0 {
            for i in 0..config.index_size {
                let cid = generate_test_cid(i);
                let embedding = generate_random_embedding(config.dimension);
                router.add(&cid, &embedding)?;
            }
        }

        Ok(Self {
            config,
            router: Arc::new(router),
        })
    }

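    /// Runs the full operation mix across all workers and aggregates the results.
    ///
    /// A minimal usage sketch (not from the original file; marked `ignore` because it
    /// assumes a tokio runtime and that `StressTest`/`StressTestConfig` are in scope):
    ///
    /// ```ignore
    /// let mut test = StressTest::new(StressTestConfig {
    ///     num_threads: 8,
    ///     operations_per_thread: 1_000,
    ///     ..StressTestConfig::default()
    /// })?;
    /// let results = test.run().await?;
    /// println!("{:.0} ops/s, p99 = {:?}", results.ops_per_second, results.p99_latency);
    /// assert!(results.success_rate > 0.99);
    /// ```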
    pub async fn run(&mut self) -> Result<StressTestResults> {
        let start = Instant::now();
        let mut handles = Vec::new();
        let mut all_latencies = Vec::new();

        let total_ops = self.config.num_threads * self.config.operations_per_thread;
        let successful_ops = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let failed_ops = Arc::new(std::sync::atomic::AtomicUsize::new(0));

        // Spawn one task per simulated thread; each runs its own operation loop.
        for thread_id in 0..self.config.num_threads {
            let router = Arc::clone(&self.router);
            let config = self.config.clone();
            let successful = Arc::clone(&successful_ops);
            let failed = Arc::clone(&failed_ops);

            let handle = task::spawn(async move {
                let mut latencies = Vec::new();

                for op_id in 0..config.operations_per_thread {
                    let op_start = Instant::now();

                    // Deterministically interleave inserts and queries according to insert_ratio.
                    let should_insert =
                        ((thread_id + op_id) % 10) as f64 / 10.0 < config.insert_ratio;

                    let result = if should_insert {
                        // Offset by thread_id so CIDs do not collide across workers.
                        let cid = generate_test_cid(thread_id * 1000000 + op_id);
                        let embedding = generate_random_embedding(config.dimension);
                        router.add(&cid, &embedding)
                    } else {
                        let query = generate_random_embedding(config.dimension);
                        match router.query(&query, config.k).await {
                            Ok(_) => Ok(()),
                            Err(e) => Err(e),
                        }
                    };

                    let latency = op_start.elapsed();
                    latencies.push(latency);

                    match result {
                        Ok(_) => {
                            successful.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                        }
                        Err(_) => {
                            failed.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                        }
                    }
                }

                latencies
            });

            handles.push(handle);
        }

        // Wait for all workers and collect their per-operation latencies.
        for handle in handles {
            let latencies = handle
                .await
                .map_err(|e| ipfrs_core::Error::InvalidInput(format!("Task join error: {}", e)))?;
            all_latencies.extend(latencies);
        }

        let total_duration = start.elapsed();

        // Sort latencies so percentiles can be read off directly.
        all_latencies.sort();
        let avg_latency = if !all_latencies.is_empty() {
            all_latencies.iter().sum::<Duration>() / all_latencies.len() as u32
        } else {
            Duration::from_secs(0)
        };

        let p50_latency = percentile(&all_latencies, 0.50);
        let p90_latency = percentile(&all_latencies, 0.90);
        let p99_latency = percentile(&all_latencies, 0.99);

        let successful = successful_ops.load(std::sync::atomic::Ordering::Relaxed);
        let failed = failed_ops.load(std::sync::atomic::Ordering::Relaxed);

        Ok(StressTestResults {
            total_ops,
            successful_ops: successful,
            failed_ops: failed,
            total_duration,
            ops_per_second: total_ops as f64 / total_duration.as_secs_f64(),
            avg_latency,
            p50_latency,
            p90_latency,
            p99_latency,
            success_rate: successful as f64 / total_ops as f64,
            max_concurrent: self.config.num_threads,
        })
    }
}

/// Configuration for a long-running endurance (soak) test.
#[derive(Debug, Clone)]
pub struct EnduranceTestConfig {
    /// Total wall-clock time the test should run for.
    pub duration: Duration,
    /// Target number of operations to issue per second.
    pub target_ops_per_second: f64,
    /// Dimensionality of the embedding vectors.
    pub dimension: usize,
    /// How often to sample process memory usage.
    pub memory_check_interval: Duration,
}

impl Default for EnduranceTestConfig {
    fn default() -> Self {
        Self {
            duration: Duration::from_secs(300),
            target_ops_per_second: 100.0,
            dimension: 768,
            memory_check_interval: Duration::from_secs(10),
        }
    }
}
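
// Illustrative note (not from the original file): a short CI-friendly soak can reuse
// the defaults and only shrink the duration, e.g.
//
//     let smoke = EnduranceTestConfig { duration: Duration::from_secs(5), ..Default::default() };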

/// Aggregate results from an endurance-test run.
#[derive(Debug, Clone)]
pub struct EnduranceTestResults {
    /// Number of operations that completed successfully.
    pub total_ops: usize,
    /// Wall-clock duration the test actually ran for.
    pub actual_duration: Duration,
    /// Average throughput over the whole run.
    pub avg_ops_per_second: f64,
    /// Highest resident memory observed during the run.
    pub peak_memory_bytes: usize,
    /// Resident memory at the start of the run.
    pub initial_memory_bytes: usize,
    /// Difference between peak and initial memory (a rough leak indicator).
    pub memory_growth_bytes: isize,
    /// Number of operations that returned an error.
    pub error_count: usize,
}

/// Long-running test that drives the router at a fixed rate while watching memory.
pub struct EnduranceTest {
    config: EnduranceTestConfig,
    router: Arc<SemanticRouter>,
}

impl EnduranceTest {
    /// Creates an endurance test backed by a router with default settings.
    pub fn new(config: EnduranceTestConfig) -> Result<Self> {
        let router = SemanticRouter::with_defaults()?;

        Ok(Self {
            config,
            router: Arc::new(router),
        })
    }

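    /// Drives the router at the configured rate until `duration` elapses, sampling
    /// process memory along the way.
    ///
    /// A minimal usage sketch (not from the original file; marked `ignore` because it
    /// assumes a tokio runtime and that the endurance types are in scope):
    ///
    /// ```ignore
    /// let mut soak = EnduranceTest::new(EnduranceTestConfig {
    ///     duration: Duration::from_secs(60),
    ///     ..EnduranceTestConfig::default()
    /// })?;
    /// let results = soak.run().await?;
    /// println!(
    ///     "{} ops, memory growth: {} bytes",
    ///     results.total_ops, results.memory_growth_bytes
    /// );
    /// ```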
    pub async fn run(&mut self) -> Result<EnduranceTestResults> {
        let start = Instant::now();
        let target_interval = Duration::from_secs_f64(1.0 / self.config.target_ops_per_second);

        let initial_memory = estimate_process_memory();
        let mut peak_memory = initial_memory;
        let mut last_memory_check = Instant::now();

        let mut total_ops = 0;
        let mut error_count = 0;
        let mut op_counter = 0;

        while start.elapsed() < self.config.duration {
            let op_start = Instant::now();

            // Every iteration performs an insert ...
            let cid = generate_test_cid(op_counter);
            let embedding = generate_random_embedding(self.config.dimension);

            match self.router.add(&cid, &embedding) {
                Ok(_) => total_ops += 1,
                Err(_) => error_count += 1,
            }

            // ... and every fifth iteration also issues a query.
            if op_counter % 5 == 0 {
                let query = generate_random_embedding(self.config.dimension);
                match self.router.query(&query, 10).await {
                    Ok(_) => total_ops += 1,
                    Err(_) => error_count += 1,
                }
            }

            op_counter += 1;

            // Periodically sample process memory to track the peak.
            if last_memory_check.elapsed() >= self.config.memory_check_interval {
                let current_memory = estimate_process_memory();
                if current_memory > peak_memory {
                    peak_memory = current_memory;
                }
                last_memory_check = Instant::now();
            }

            // Throttle to the target operation rate.
            let elapsed = op_start.elapsed();
            if elapsed < target_interval {
                tokio::time::sleep(target_interval - elapsed).await;
            }
        }

        let actual_duration = start.elapsed();

        Ok(EnduranceTestResults {
            total_ops,
            actual_duration,
            avg_ops_per_second: total_ops as f64 / actual_duration.as_secs_f64(),
            peak_memory_bytes: peak_memory,
            initial_memory_bytes: initial_memory,
            memory_growth_bytes: peak_memory as isize - initial_memory as isize,
            error_count,
        })
    }
}
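
/// Hypothetical helper (not part of the original module): one simple way to read
/// `memory_growth_bytes` when hunting for leaks is to compare it against a fixed
/// threshold after an endurance run.
#[allow(dead_code)]
fn exceeds_memory_growth(results: &EnduranceTestResults, threshold_bytes: isize) -> bool {
    results.memory_growth_bytes > threshold_bytes
}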

/// Generates a deterministic test CID from an index by hashing it into a 32-byte digest.
fn generate_test_cid(index: usize) -> Cid {
    use multihash::Multihash;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    index.hash(&mut hasher);
    let hash_value = hasher.finish();

    // Spread the 64-bit hash across a 32-byte buffer.
    let mut hash_bytes = [0u8; 32];
    hash_bytes[..8].copy_from_slice(&hash_value.to_le_bytes());
    for i in 1..4 {
        let val = (hash_value.wrapping_mul(i as u64)).to_le_bytes();
        hash_bytes[i * 8..(i + 1) * 8].copy_from_slice(&val);
    }

    // 0x12 = sha2-256 multihash code, 0x55 = raw codec.
    let mh = Multihash::wrap(0x12, &hash_bytes).unwrap();
    Cid::new_v1(0x55, mh)
}

/// Generates a random embedding with values uniformly distributed in [0, 1).
fn generate_random_embedding(dim: usize) -> Vec<f32> {
    let mut rng = rand::rng();
    (0..dim).map(|_| rng.random_range(0.0..1.0)).collect()
}

/// Returns the p-th percentile of an already-sorted slice of latencies.
fn percentile(sorted_data: &[Duration], p: f64) -> Duration {
    if sorted_data.is_empty() {
        return Duration::from_secs(0);
    }
    let index = ((p * sorted_data.len() as f64) as usize).min(sorted_data.len() - 1);
    sorted_data[index]
}

/// Best-effort estimate of the process's resident memory in bytes.
#[allow(dead_code)]
fn estimate_process_memory() -> usize {
    // On Linux, read VmRSS from /proc/self/status.
    #[cfg(target_os = "linux")]
    {
        use std::fs;
        if let Ok(status) = fs::read_to_string("/proc/self/status") {
            for line in status.lines() {
                if line.starts_with("VmRSS:") {
                    if let Some(kb_str) = line.split_whitespace().nth(1) {
                        if let Ok(kb) = kb_str.parse::<usize>() {
                            return kb * 1024;
                        }
                    }
                }
            }
        }
    }

    // Fallback for platforms without a supported memory probe.
    0
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_stress_test_creation() {
        let config = StressTestConfig {
            num_threads: 2,
            operations_per_thread: 5,
            index_size: 20,
            dimension: 64,
            insert_ratio: 0.5,
            query_ratio: 0.5,
            k: 3,
        };

        let stress_test = StressTest::new(config.clone());
        if let Err(e) = &stress_test {
            eprintln!("Error creating stress test: {:?}", e);
        }
        assert!(stress_test.is_ok());

        let test = stress_test.unwrap();
        assert_eq!(test.config.num_threads, 2);
    }

    #[tokio::test]
    async fn test_endurance_test_creation() {
        let config = EnduranceTestConfig {
            duration: Duration::from_millis(100),
            target_ops_per_second: 10.0,
            dimension: 64,
            memory_check_interval: Duration::from_millis(50),
        };

        let endurance_test = EnduranceTest::new(config.clone());
        assert!(endurance_test.is_ok());

        assert_eq!(endurance_test.unwrap().config.dimension, 64);
    }

    #[test]
    fn test_generate_test_cid() {
        let cid1 = generate_test_cid(0);
        let cid2 = generate_test_cid(1);
        let cid3 = generate_test_cid(5);

        // Distinct indices should yield distinct CIDs.
        assert_ne!(cid1, cid2);
        assert_ne!(cid1, cid3);
        assert_ne!(cid2, cid3);

        // The same index should always yield the same CID.
        let cid1_again = generate_test_cid(0);
        assert_eq!(cid1, cid1_again);
    }

    #[test]
    fn test_percentile_calculation() {
        let data = vec![
            Duration::from_millis(1),
            Duration::from_millis(2),
            Duration::from_millis(3),
            Duration::from_millis(4),
            Duration::from_millis(5),
        ];

        let p50 = percentile(&data, 0.5);
        let p90 = percentile(&data, 0.9);

        assert_eq!(p50, Duration::from_millis(3));
        assert_eq!(p90, Duration::from_millis(5));
    }

    #[test]
    fn test_percentile_empty() {
        let data: Vec<Duration> = vec![];
        let p50 = percentile(&data, 0.5);
        assert_eq!(p50, Duration::from_secs(0));
    }
}