1use std::collections::HashMap;
11use std::time::Instant;
12
13use super::{Benchmark, BenchmarkResult};
14use crate::storage::queries::{create_memory, get_memory};
15use crate::storage::Storage;
16use crate::types::{CreateMemoryInput, MemoryType, StorageConfig, StorageMode};
17
/// Benchmark that measures memory-store CRUD throughput and keyword-search quality
/// using synthetic memories.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemBenchmark {
    /// Number of synthetic memories created (and then fetched back by id) in the
    /// throughput phases.
    pub num_memories: usize,
    /// Number of search queries issued; the topic keywords are cycled until this many
    /// queries have been produced.
    pub num_queries: usize,
}
25
/// A labelled search topic used to score retrieval quality.
#[derive(Debug, Clone, Copy)]
struct SyntheticTopic {
    /// Query string issued for this topic; also the key under which its relevance
    /// labels are stored.
    keyword: &'static str,
    /// Phrases that are each stored as a memory and labelled relevant to `keyword`.
    relevant_phrases: &'static [&'static str],
}
31
/// Labelled topics for the search-quality phase.
///
/// For each topic, every entry of `relevant_phrases` is stored as its own memory, and
/// those memories form the relevance set for that topic. `keyword` is the query string,
/// matched against memory content with `content LIKE '%keyword%'`.
///
/// NOTE(review): few labelled phrases contain their topic's keyword verbatim. There is
/// one each for "machine learning" and "database", and none for "security" or
/// "performance". Lexical `LIKE` matching can therefore surface only a small part of each
/// relevance set, and the last two topics can never score above 0. Confirm this is
/// intended (e.g. to expose the limits of keyword search) rather than a labelling
/// oversight.
const TOPICS: &[SyntheticTopic] = &[
    SyntheticTopic {
        keyword: "machine learning",
        relevant_phrases: &[
            "machine learning model architecture",
            "deep learning neural network training",
            "gradient descent optimizer convergence",
            "training loss accuracy metrics",
        ],
    },
    SyntheticTopic {
        keyword: "database",
        relevant_phrases: &[
            "SQL query optimization plan",
            "database index scan performance",
            "transaction isolation level committed",
            "PostgreSQL connection pool management",
        ],
    },
    SyntheticTopic {
        keyword: "security",
        relevant_phrases: &[
            "authentication token JWT verification",
            "SQL injection vulnerability prevention",
            "HTTPS TLS certificate renewal",
            "password hashing bcrypt salt",
        ],
    },
    SyntheticTopic {
        keyword: "performance",
        relevant_phrases: &[
            "latency p99 benchmark test results",
            "throughput requests per second measurement",
            "memory allocation profiling heap",
            "cache hit ratio optimization",
        ],
    },
];
70
/// Content templates for the create-throughput phase. Memory `i` is built from template
/// `i % CORPUS_TEMPLATES.len()`.
///
/// Only the first template contains a `{}` placeholder, which is filled with `batch_{i}`.
/// The others are stored verbatim, so the corpus repeats them.
///
/// None of these memories is labelled relevant to any topic. Entries that contain a topic
/// keyword (e.g. the "machine learning" and "database" lines) can therefore compete with
/// the labelled phrases for the top-10 search results.
const CORPUS_TEMPLATES: &[&str] = &[
    // Technical entries that overlap the vocabulary of `TOPICS`.
    "machine learning model trained on {} dataset with Adam optimizer",
    "deep learning neural network achieved 95% accuracy on image classification",
    "gradient descent optimizer converged after 1000 epochs of training",
    "SQL query optimization reduced latency by 40% after index tuning",
    "database index scan improved search performance significantly",
    "transaction isolation level set to READ COMMITTED for consistency",
    "authentication token JWT expires after 1 hour session timeout",
    "HTTPS TLS certificate renewed for production domain hosting",
    "latency p99 benchmark shows 12ms for 100k RPS under load",
    "memory allocation profiling revealed 200MB footprint in production",
    // Off-topic filler.
    "unrelated fact about cooking: pasta needs salted boiling water",
    "weather today is sunny with 25 degrees Celsius temperature",
    "team meeting scheduled for next Tuesday at 2pm in conference room",
    "coffee machine on floor 3 needs maintenance and refill",
    "quarterly report submitted to finance department for review",
    "new joiner onboarding checklist completed successfully",
    "vacation request approved for two weeks in August holidays",
    "parking permit renewed for building B underground garage",
    "printer on floor 2 is out of paper and toner cartridge",
    "lunch order: 5 sandwiches and 3 salads for the engineering team",
];
94
95impl MemBenchmark {
96 pub fn ndcg_at_k(retrieved: &[i64], relevant_ids: &[i64], k: usize) -> f64 {
98 let top_k: Vec<_> = retrieved.iter().take(k).collect();
99
100 let dcg: f64 = top_k
102 .iter()
103 .enumerate()
104 .map(|(i, &&id)| {
105 let rel = if relevant_ids.contains(&id) { 1.0 } else { 0.0 };
106 rel / (i as f64 + 2.0).log2()
107 })
108 .sum();
109
110 let num_relevant = relevant_ids.len().min(k);
112 let idcg: f64 = (0..num_relevant)
113 .map(|i| 1.0 / (i as f64 + 2.0).log2())
114 .sum();
115
116 if idcg == 0.0 {
117 0.0
118 } else {
119 dcg / idcg
120 }
121 }
122
123 pub fn mrr(retrieved: &[i64], relevant_ids: &[i64]) -> f64 {
125 for (i, &id) in retrieved.iter().enumerate() {
126 if relevant_ids.contains(&id) {
127 return 1.0 / (i as f64 + 1.0);
128 }
129 }
130 0.0
131 }
132}
133
134impl Benchmark for MemBenchmark {
135 fn name(&self) -> &str {
136 "membench"
137 }
138
139 fn description(&self) -> &str {
140 "CRUD throughput and search quality benchmark. Measures create_per_sec, get_per_sec, \
141 search_per_sec, NDCG@10, and MRR using synthetic memories."
142 }
143
144 fn run(&self, db_path: &str) -> Result<BenchmarkResult, Box<dyn std::error::Error>> {
145 let storage = if db_path == ":memory:" {
146 Storage::open_in_memory()?
147 } else {
148 let bench_path = format!("{}.membench.db", db_path);
149 Storage::open(StorageConfig {
150 db_path: bench_path,
151 storage_mode: StorageMode::Local,
152 cloud_uri: None,
153 encrypt_cloud: false,
154 confidence_half_life_days: 30.0,
155 auto_sync: false,
156 sync_debounce_ms: 5000,
157 })?
158 };
159
160 let create_start = Instant::now();
162 let mut created_ids: Vec<i64> = Vec::with_capacity(self.num_memories);
163
164 for i in 0..self.num_memories {
165 let template = CORPUS_TEMPLATES[i % CORPUS_TEMPLATES.len()];
166 let content = template.replace("{}", &format!("batch_{}", i));
167 let mem = storage.with_connection(|conn| {
168 create_memory(
169 conn,
170 &CreateMemoryInput {
171 content,
172 memory_type: MemoryType::Note,
173 workspace: Some("membench".to_string()),
174 ..Default::default()
175 },
176 )
177 })?;
178 created_ids.push(mem.id);
179 }
180 let create_elapsed = create_start.elapsed();
181 let create_per_sec = if create_elapsed.as_secs_f64() > 0.0 {
182 self.num_memories as f64 / create_elapsed.as_secs_f64()
183 } else {
184 self.num_memories as f64 * 1_000_000.0
185 };
186
187 let get_start = Instant::now();
189 let mut get_hits = 0usize;
190 for &id in &created_ids {
191 if storage.with_connection(|conn| get_memory(conn, id)).is_ok() {
192 get_hits += 1;
193 }
194 }
195 let get_elapsed = get_start.elapsed();
196 let get_per_sec = if get_elapsed.as_secs_f64() > 0.0 {
197 get_hits as f64 / get_elapsed.as_secs_f64()
198 } else {
199 get_hits as f64 * 1_000_000.0
200 };
201
202 let mut topic_relevant_ids: HashMap<&str, Vec<i64>> = HashMap::new();
205
206 for topic in TOPICS {
207 let mut relevant = Vec::new();
208 for phrase in topic.relevant_phrases {
209 let mem = storage.with_connection(|conn| {
210 create_memory(
211 conn,
212 &CreateMemoryInput {
213 content: phrase.to_string(),
214 memory_type: MemoryType::Note,
215 workspace: Some("membench-quality".to_string()),
216 ..Default::default()
217 },
218 )
219 })?;
220 relevant.push(mem.id);
221 }
222 topic_relevant_ids.insert(topic.keyword, relevant);
223 }
224
225 let search_start = Instant::now();
226 let mut ndcg_sum = 0.0f64;
227 let mut mrr_sum = 0.0f64;
228 let mut search_count = 0usize;
229
230 let queries: Vec<&str> = TOPICS
231 .iter()
232 .map(|t| t.keyword)
233 .cycle()
234 .take(self.num_queries)
235 .collect();
236
237 for query in &queries {
238 let keyword_pattern = format!("%{}%", query);
239 let retrieved_ids: Vec<i64> = storage.with_connection(|conn| {
240 let mut stmt = conn.prepare(
241 "SELECT id FROM memories WHERE content LIKE ?1 \
242 ORDER BY created_at DESC LIMIT 10",
243 )?;
244 let ids: Vec<i64> = stmt
245 .query_map([&keyword_pattern], |row| row.get(0))?
246 .filter_map(|r| r.ok())
247 .collect();
248 Ok(ids)
249 })?;
250
251 if let Some(relevant_ids) = topic_relevant_ids.get(query) {
252 ndcg_sum += Self::ndcg_at_k(&retrieved_ids, relevant_ids, 10);
253 mrr_sum += Self::mrr(&retrieved_ids, relevant_ids);
254 search_count += 1;
255 }
256 }
257
258 let search_elapsed = search_start.elapsed();
259 let search_per_sec = if search_elapsed.as_secs_f64() > 0.0 {
260 self.num_queries as f64 / search_elapsed.as_secs_f64()
261 } else {
262 self.num_queries as f64 * 1_000_000.0
263 };
264
265 let ndcg_at_10 = if search_count > 0 {
266 ndcg_sum / search_count as f64
267 } else {
268 0.0
269 };
270
271 let mrr = if search_count > 0 {
272 mrr_sum / search_count as f64
273 } else {
274 0.0
275 };
276
277 let duration_ms = create_elapsed.as_millis() as u64
278 + get_elapsed.as_millis() as u64
279 + search_elapsed.as_millis() as u64;
280
281 let mut metrics = HashMap::new();
282 metrics.insert(
283 "create_per_sec".to_string(),
284 create_per_sec.min(1_000_000.0),
285 );
286 metrics.insert("get_per_sec".to_string(), get_per_sec.min(1_000_000.0));
287 metrics.insert(
288 "search_per_sec".to_string(),
289 search_per_sec.min(1_000_000.0),
290 );
291 metrics.insert("ndcg_at_10".to_string(), ndcg_at_10);
292 metrics.insert("mrr".to_string(), mrr);
293 metrics.insert("num_memories".to_string(), self.num_memories as f64);
294 metrics.insert("num_queries".to_string(), self.num_queries as f64);
295
296 if db_path != ":memory:" {
298 let bench_path = format!("{}.membench.db", db_path);
299 drop(storage);
300 let _ = std::fs::remove_file(&bench_path);
301 let _ = std::fs::remove_file(format!("{}-wal", bench_path));
302 let _ = std::fs::remove_file(format!("{}-shm", bench_path));
303 }
304
305 Ok(BenchmarkResult {
306 name: self.name().to_string(),
307 metrics,
308 duration_ms,
309 timestamp: chrono::Utc::now().to_rfc3339(),
310 })
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_membench_runs() {
        let bench = MemBenchmark {
            num_memories: 20,
            num_queries: 5,
        };
        let result = bench.run(":memory:").expect("benchmark should succeed");
        assert_eq!(result.name, "membench");
        assert_eq!(result.metrics["num_memories"], 20.0);
        assert_eq!(result.metrics["num_queries"], 5.0);
    }

    #[test]
    fn test_membench_metrics_present() {
        let bench = MemBenchmark {
            num_memories: 10,
            num_queries: 4,
        };
        let result = bench.run(":memory:").expect("benchmark should succeed");

        let expected_keys = [
            "create_per_sec",
            "get_per_sec",
            "search_per_sec",
            "ndcg_at_10",
            "mrr",
            "num_memories",
            "num_queries",
        ];
        for key in &expected_keys {
            assert!(result.metrics.contains_key(*key), "missing metric: {}", key);
        }
    }

    #[test]
    fn test_throughput_positive() {
        let bench = MemBenchmark {
            num_memories: 50,
            num_queries: 10,
        };
        let result = bench.run(":memory:").expect("benchmark should succeed");
        for key in ["create_per_sec", "get_per_sec", "search_per_sec"] {
            assert!(result.metrics[key] > 0.0, "{} should be positive", key);
        }
    }

    #[test]
    fn test_quality_metrics_in_range() {
        let bench = MemBenchmark {
            num_memories: 30,
            num_queries: 8,
        };
        let result = bench.run(":memory:").expect("benchmark should succeed");
        let ndcg = result.metrics["ndcg_at_10"];
        assert!((0.0..=1.0).contains(&ndcg), "NDCG@10 = {} out of range", ndcg);
        let mrr = result.metrics["mrr"];
        assert!((0.0..=1.0).contains(&mrr), "MRR = {} out of range", mrr);
    }

    #[test]
    fn test_file_backed_run_removes_scratch_db() {
        let base = std::env::temp_dir().join(format!("membench_test_{}", std::process::id()));
        let base = base
            .to_str()
            .expect("temp dir path should be valid UTF-8")
            .to_string();
        let bench = MemBenchmark {
            num_memories: 5,
            num_queries: 4,
        };
        bench.run(&base).expect("file-backed benchmark should succeed");

        let scratch = format!("{}.membench.db", base);
        assert!(
            !std::path::Path::new(&scratch).exists(),
            "scratch db {} was not removed",
            scratch
        );
    }

    #[test]
    fn test_ndcg_at_k_computation() {
        // Relevant ids at 0-based ranks 0, 2 and 4.
        let relevant = vec![1i64, 2, 3];
        let retrieved = vec![1i64, 4, 2, 5, 3];
        let ndcg = MemBenchmark::ndcg_at_k(&retrieved, &relevant, 5);
        let dcg = 1.0 + 1.0 / 4f64.log2() + 1.0 / 6f64.log2();
        let idcg = 1.0 + 1.0 / 3f64.log2() + 1.0 / 4f64.log2();
        assert!((ndcg - dcg / idcg).abs() < 1e-9, "ndcg={}", ndcg);

        // Nothing retrieved, or a zero cutoff, scores zero.
        assert_eq!(MemBenchmark::ndcg_at_k(&[], &relevant, 10), 0.0);
        assert_eq!(MemBenchmark::ndcg_at_k(&retrieved, &relevant, 0), 0.0);

        // All relevant ids at the top of the ranking is a perfect score.
        let perfect = vec![1i64, 2, 3, 4, 5];
        let ndcg_perfect = MemBenchmark::ndcg_at_k(&perfect, &[1, 2, 3], 3);
        assert!(
            (ndcg_perfect - 1.0).abs() < 1e-9,
            "perfect ndcg={}",
            ndcg_perfect
        );
    }

    #[test]
    fn test_mrr_computation() {
        // First relevant hit at 0-based rank 2 -> 1/3.
        let relevant = vec![3i64];
        let retrieved = vec![1i64, 2, 3, 4, 5];
        let mrr = MemBenchmark::mrr(&retrieved, &relevant);
        assert!((mrr - 1.0 / 3.0).abs() < 1e-9, "mrr={}", mrr);

        // No relevant hit at all.
        assert_eq!(MemBenchmark::mrr(&[10, 11, 12], &[99]), 0.0);
        assert_eq!(MemBenchmark::mrr(&[], &[1]), 0.0);

        // Relevant hit in first position.
        let mrr_first = MemBenchmark::mrr(&[5, 6, 7], &[5]);
        assert!((mrr_first - 1.0).abs() < 1e-9, "mrr_first={}", mrr_first);
    }
}