// codex_memory/performance/load_testing.rs

1//! Load testing implementation for performance validation
2
3use super::{LoadTestConfig, PerformanceMetrics, PerformanceTestResult, TestType};
4use crate::memory::models::{CreateMemoryRequest, SearchRequest, UpdateMemoryRequest};
5use crate::memory::{MemoryRepository, MemoryTier};
6use anyhow::Result;
7use chrono::Utc;
8use std::collections::VecDeque;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::RwLock;
13use tokio::time;
14use tracing::{debug, info};
15
/// Load testing orchestrator
///
/// Drives a configurable number of concurrent virtual users against the
/// memory repository and aggregates latency/throughput metrics into a
/// `PerformanceTestResult`.
pub struct LoadTester {
    /// Test parameters (concurrent users, target RPS, ramp-up and test durations)
    config: LoadTestConfig,
    /// Repository under test; cloned (via `Arc`) into each spawned user task
    repository: Arc<MemoryRepository>,
    /// Shared counters and samples updated concurrently by all user tasks
    metrics: Arc<LoadTestMetrics>,
}
22
/// Metrics collected during load testing
///
/// Shared across all virtual-user tasks via `Arc`: the counters are
/// lock-free atomics, while the latency/error samples sit behind async
/// `RwLock`s.
struct LoadTestMetrics {
    /// Every request issued, successful or not
    total_requests: AtomicU64,
    /// Requests that completed without error
    successful_requests: AtomicU64,
    /// Requests that returned an error
    failed_requests: AtomicU64,
    /// Recent latency samples in milliseconds (bounded to the most recent
    /// 10,000 entries by the writers in `run_load_test`)
    latencies: RwLock<VecDeque<u64>>,
    /// Error messages captured from failed requests
    errors: RwLock<Vec<String>>,
    /// Captured at construction time; used as the throughput denominator
    start_time: Instant,
}
32
33impl LoadTester {
34    pub fn new(config: LoadTestConfig, repository: Arc<MemoryRepository>) -> Self {
35        Self {
36            config,
37            repository,
38            metrics: Arc::new(LoadTestMetrics {
39                total_requests: AtomicU64::new(0),
40                successful_requests: AtomicU64::new(0),
41                failed_requests: AtomicU64::new(0),
42                latencies: RwLock::new(VecDeque::new()),
43                errors: RwLock::new(Vec::new()),
44                start_time: Instant::now(),
45            }),
46        }
47    }
48
49    /// Run load test with specified configuration
50    pub async fn run_load_test(&self) -> Result<PerformanceTestResult> {
51        info!(
52            "Starting load test with {} concurrent users",
53            self.config.concurrent_users
54        );
55
56        let test_start = Utc::now();
57        let start_time = Instant::now();
58
59        // Create a pool of virtual users
60        let mut handles = Vec::new();
61
62        // Calculate requests per user
63        let requests_per_user = self.config.target_rps as usize / self.config.concurrent_users;
64        let request_interval = Duration::from_secs(1) / requests_per_user as u32;
65
66        // Ramp up users gradually
67        let ramp_up_interval = self.config.ramp_up_time / self.config.concurrent_users as u32;
68
69        for user_id in 0..self.config.concurrent_users {
70            let repository = Arc::clone(&self.repository);
71            let metrics = Arc::clone(&self.metrics);
72            let test_duration = self.config.test_duration;
73            let interval = request_interval;
74
75            let handle = tokio::spawn(async move {
76                // Wait for ramp-up
77                time::sleep(ramp_up_interval * user_id as u32).await;
78
79                let user_start = Instant::now();
80
81                while user_start.elapsed() < test_duration {
82                    let request_start = Instant::now();
83
84                    // Simulate user operations
85                    let result = Self::simulate_user_operation(&repository, user_id).await;
86
87                    let latency_ms = request_start.elapsed().as_millis() as u64;
88
89                    // Record metrics
90                    metrics.total_requests.fetch_add(1, Ordering::Relaxed);
91
92                    match result {
93                        Ok(_) => {
94                            metrics.successful_requests.fetch_add(1, Ordering::Relaxed);
95                            let mut latencies = metrics.latencies.write().await;
96                            latencies.push_back(latency_ms);
97
98                            // Keep only last 10000 samples for percentile calculation
99                            if latencies.len() > 10000 {
100                                latencies.pop_front();
101                            }
102                        }
103                        Err(e) => {
104                            metrics.failed_requests.fetch_add(1, Ordering::Relaxed);
105                            let mut errors = metrics.errors.write().await;
106                            errors.push(e.to_string());
107                        }
108                    }
109
110                    // Wait for next request interval
111                    if request_start.elapsed() < interval {
112                        time::sleep(interval - request_start.elapsed()).await;
113                    }
114                }
115
116                debug!("User {} completed load test", user_id);
117            });
118
119            handles.push(handle);
120        }
121
122        // Wait for all users to complete
123        for handle in handles {
124            handle.await?;
125        }
126
127        let test_end = Utc::now();
128        let duration = start_time.elapsed();
129
130        // Calculate final metrics
131        let metrics = self.calculate_metrics().await?;
132
133        // Check for SLA violations
134        let sla_violations = self.check_sla_violations(&metrics);
135        let passed = sla_violations.is_empty();
136
137        let result = PerformanceTestResult {
138            test_name: "Load Test".to_string(),
139            test_type: TestType::Load,
140            start_time: test_start,
141            end_time: test_end,
142            duration,
143            metrics,
144            sla_violations,
145            passed,
146        };
147
148        info!("Load test completed. Result: {:?}", result.passed);
149
150        Ok(result)
151    }
152
153    /// Simulate a user operation (mix of reads, writes, searches)
154    async fn simulate_user_operation(
155        repository: &Arc<MemoryRepository>,
156        user_id: usize,
157    ) -> Result<()> {
158        use rand::Rng;
159
160        // Generate all random values before any await
161        let (operation, importance_score1, query_num, importance_score2) = {
162            let mut rng = rand::thread_rng();
163            (
164                rng.gen_range(0..100),
165                rng.gen_range(0.0..1.0),
166                rng.gen_range(0..100),
167                rng.gen_range(0.0..1.0),
168            )
169        };
170
171        // Create a realistic workload mix
172        // 60% reads, 20% writes, 15% searches, 5% updates
173        match operation {
174            0..=59 => {
175                // Read operation - try to read a random UUID (may not exist)
176                let memory_id = uuid::Uuid::new_v4();
177                let _ = repository.get_memory(memory_id).await;
178            }
179            60..=79 => {
180                // Write operation
181                let content = format!("Test content from user {} at {}", user_id, Utc::now());
182                let request = CreateMemoryRequest {
183                    content,
184                    embedding: None,
185                    tier: Some(MemoryTier::Working),
186                    importance_score: Some(importance_score1),
187                    metadata: Some(serde_json::json!({
188                        "user_id": user_id,
189                        "test": true
190                    })),
191                    parent_id: None,
192                    expires_at: None,
193                };
194                repository.create_memory(request).await?;
195            }
196            80..=94 => {
197                // Search operation
198                let query = format!("test query {query_num}");
199                let search_request = SearchRequest {
200                    query_text: Some(query),
201                    query_embedding: None,
202                    search_type: None,
203                    hybrid_weights: None,
204                    tier: None,
205                    date_range: None,
206                    importance_range: None,
207                    metadata_filters: None,
208                    tags: None,
209                    limit: Some(10),
210                    offset: None,
211                    cursor: None,
212                    similarity_threshold: None,
213                    include_metadata: None,
214                    include_facets: None,
215                    ranking_boost: None,
216                    explain_score: None,
217                };
218                repository.search_memories_simple(search_request).await?;
219            }
220            95..=99 => {
221                // Update operation - try to update a random UUID (may not exist)
222                let memory_id = uuid::Uuid::new_v4();
223                if let Ok(_memory) = repository.get_memory(memory_id).await {
224                    let update_request = UpdateMemoryRequest {
225                        content: Some(format!("Updated content at {}", Utc::now())),
226                        embedding: None,
227                        tier: None,
228                        importance_score: Some(importance_score2),
229                        metadata: None,
230                        expires_at: None,
231                    };
232                    repository.update_memory(memory_id, update_request).await?;
233                }
234            }
235            _ => unreachable!(),
236        }
237
238        Ok(())
239    }
240
241    /// Calculate performance metrics from collected data
242    async fn calculate_metrics(&self) -> Result<PerformanceMetrics> {
243        let total_requests = self.metrics.total_requests.load(Ordering::Relaxed);
244        let successful_requests = self.metrics.successful_requests.load(Ordering::Relaxed);
245        let failed_requests = self.metrics.failed_requests.load(Ordering::Relaxed);
246
247        let duration_secs = self.metrics.start_time.elapsed().as_secs_f64();
248        let throughput_rps = total_requests as f64 / duration_secs;
249
250        let error_rate = if total_requests > 0 {
251            (failed_requests as f64 / total_requests as f64) * 100.0
252        } else {
253            0.0
254        };
255
256        // Calculate latency percentiles
257        let latencies = self.metrics.latencies.read().await;
258        let mut sorted_latencies: Vec<u64> = latencies.iter().cloned().collect();
259        sorted_latencies.sort();
260
261        let latency_p50_ms = Self::calculate_percentile(&sorted_latencies, 50.0);
262        let latency_p95_ms = Self::calculate_percentile(&sorted_latencies, 95.0);
263        let latency_p99_ms = Self::calculate_percentile(&sorted_latencies, 99.0);
264        let latency_max_ms = sorted_latencies.last().cloned().unwrap_or(0);
265
266        // TODO: Get actual CPU/memory metrics from system
267        let cpu_usage_avg = 0.0;
268        let memory_usage_avg = 0.0;
269
270        // TODO: Get actual cache hit ratio from cache implementation
271        let cache_hit_ratio = 0.0;
272
273        // TODO: Get actual DB connection metrics
274        let db_connections_used = 0;
275
276        Ok(PerformanceMetrics {
277            total_requests,
278            successful_requests,
279            failed_requests,
280            throughput_rps,
281            latency_p50_ms,
282            latency_p95_ms,
283            latency_p99_ms,
284            latency_max_ms,
285            error_rate,
286            cpu_usage_avg,
287            memory_usage_avg,
288            cache_hit_ratio,
289            db_connections_used,
290            network_bytes_sent: 0,
291            network_bytes_received: 0,
292        })
293    }
294
295    /// Calculate percentile from sorted latencies
296    fn calculate_percentile(sorted_latencies: &[u64], percentile: f64) -> u64 {
297        if sorted_latencies.is_empty() {
298            return 0;
299        }
300
301        let index = ((percentile / 100.0) * sorted_latencies.len() as f64) as usize;
302        let index = index.min(sorted_latencies.len() - 1);
303
304        sorted_latencies[index]
305    }
306
307    /// Check for SLA violations
308    fn check_sla_violations(&self, metrics: &PerformanceMetrics) -> Vec<super::SlaViolation> {
309        let mut violations = Vec::new();
310
311        // Check latency SLAs
312        if metrics.latency_p50_ms > 10 {
313            violations.push(super::SlaViolation {
314                metric: "P50 Latency".to_string(),
315                threshold: 10.0,
316                actual_value: metrics.latency_p50_ms as f64,
317                severity: super::ViolationSeverity::Warning,
318                timestamp: Utc::now(),
319            });
320        }
321
322        if metrics.latency_p95_ms > 100 {
323            violations.push(super::SlaViolation {
324                metric: "P95 Latency".to_string(),
325                threshold: 100.0,
326                actual_value: metrics.latency_p95_ms as f64,
327                severity: super::ViolationSeverity::Critical,
328                timestamp: Utc::now(),
329            });
330        }
331
332        if metrics.latency_p99_ms > 500 {
333            violations.push(super::SlaViolation {
334                metric: "P99 Latency".to_string(),
335                threshold: 500.0,
336                actual_value: metrics.latency_p99_ms as f64,
337                severity: super::ViolationSeverity::Critical,
338                timestamp: Utc::now(),
339            });
340        }
341
342        // Check throughput SLA
343        if metrics.throughput_rps < 100.0 {
344            violations.push(super::SlaViolation {
345                metric: "Throughput".to_string(),
346                threshold: 100.0,
347                actual_value: metrics.throughput_rps,
348                severity: super::ViolationSeverity::Critical,
349                timestamp: Utc::now(),
350            });
351        }
352
353        // Check error rate SLA
354        if metrics.error_rate > 1.0 {
355            violations.push(super::SlaViolation {
356                metric: "Error Rate".to_string(),
357                threshold: 1.0,
358                actual_value: metrics.error_rate,
359                severity: super::ViolationSeverity::Critical,
360                timestamp: Utc::now(),
361            });
362        }
363
364        violations
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calculate_percentile() {
        let latencies: Vec<u64> = (1..=10).collect();

        // (percentile, expected) pairs for a 10-element ladder.
        let cases = [
            (50.0, 6),  // index 5 -> 6th element
            (90.0, 10), // index 9 -> 10th element
            (99.0, 10), // index clamped to the last element
        ];
        for (percentile, expected) in cases {
            assert_eq!(
                LoadTester::calculate_percentile(&latencies, percentile),
                expected,
                "percentile {percentile}"
            );
        }
    }

    #[test]
    fn test_calculate_percentile_empty() {
        assert_eq!(LoadTester::calculate_percentile(&[], 50.0), 0);
    }
}