// engram/bench/membench.rs

//! MemBench — CRUD throughput and search quality benchmark
//!
//! Measures:
//! - create_per_sec: how many memories can be created per second
//! - get_per_sec: how many memories can be retrieved by ID per second
//! - search_per_sec: how many keyword searches can run per second
//! - ndcg_at_10: Normalized Discounted Cumulative Gain@10 for search quality
//! - mrr: Mean Reciprocal Rank for search quality

10use std::collections::HashMap;
11use std::time::Instant;
12
13use super::{Benchmark, BenchmarkResult};
14use crate::storage::queries::{create_memory, get_memory};
15use crate::storage::Storage;
16use crate::types::{CreateMemoryInput, MemoryType, StorageConfig, StorageMode};
17
/// MemBench configuration: how much work each phase of the benchmark performs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemBenchmark {
    /// Number of memories created, and then fetched back by id, in the CRUD throughput phases.
    pub num_memories: usize,
    /// Number of keyword searches run in the search phase; the topic keywords are cycled
    /// until this many queries have been issued.
    pub num_queries: usize,
}

/// One search topic used to score retrieval quality.
///
/// Every memory created from `relevant_phrases` is judged relevant to a search
/// for `keyword`; all other memories are judged irrelevant.
struct SyntheticTopic {
    /// The query string issued for this topic.
    keyword: &'static str,
    /// Contents of the memories that count as relevant hits for `keyword`.
    relevant_phrases: &'static [&'static str],
}

32const TOPICS: &[SyntheticTopic] = &[
33    SyntheticTopic {
34        keyword: "machine learning",
35        relevant_phrases: &[
36            "machine learning model architecture",
37            "deep learning neural network training",
38            "gradient descent optimizer convergence",
39            "training loss accuracy metrics",
40        ],
41    },
42    SyntheticTopic {
43        keyword: "database",
44        relevant_phrases: &[
45            "SQL query optimization plan",
46            "database index scan performance",
47            "transaction isolation level committed",
48            "PostgreSQL connection pool management",
49        ],
50    },
51    SyntheticTopic {
52        keyword: "security",
53        relevant_phrases: &[
54            "authentication token JWT verification",
55            "SQL injection vulnerability prevention",
56            "HTTPS TLS certificate renewal",
57            "password hashing bcrypt salt",
58        ],
59    },
60    SyntheticTopic {
61        keyword: "performance",
62        relevant_phrases: &[
63            "latency p99 benchmark test results",
64            "throughput requests per second measurement",
65            "memory allocation profiling heap",
66            "cache hit ratio optimization",
67        ],
68    },
69];
70
/// Content templates for the CRUD-throughput phase, used round-robin.
///
/// Only the first template contains a `{}` placeholder (substituted by the create loop),
/// so every other template produces identical content on each pass through the list.
const CORPUS_TEMPLATES: &[&str] = &[
    // Technical notes touching the search topics.
    "machine learning model trained on {} dataset with Adam optimizer",
    "deep learning neural network achieved 95% accuracy on image classification",
    "gradient descent optimizer converged after 1000 epochs of training",
    "SQL query optimization reduced latency by 40% after index tuning",
    "database index scan improved search performance significantly",
    "transaction isolation level set to READ COMMITTED for consistency",
    "authentication token JWT expires after 1 hour session timeout",
    "HTTPS TLS certificate renewed for production domain hosting",
    "latency p99 benchmark shows 12ms for 100k RPS under load",
    "memory allocation profiling revealed 200MB footprint in production",
    // Off-topic filler.
    "unrelated fact about cooking: pasta needs salted boiling water",
    "weather today is sunny with 25 degrees Celsius temperature",
    "team meeting scheduled for next Tuesday at 2pm in conference room",
    "coffee machine on floor 3 needs maintenance and refill",
    "quarterly report submitted to finance department for review",
    "new joiner onboarding checklist completed successfully",
    "vacation request approved for two weeks in August holidays",
    "parking permit renewed for building B underground garage",
    "printer on floor 2 is out of paper and toner cartridge",
    "lunch order: 5 sandwiches and 3 salads for the engineering team",
];

95impl MemBenchmark {
96    /// Compute NDCG@k given ordered retrieved IDs and a set of relevant IDs
97    pub fn ndcg_at_k(retrieved: &[i64], relevant_ids: &[i64], k: usize) -> f64 {
98        let top_k: Vec<_> = retrieved.iter().take(k).collect();
99
100        // DCG
101        let dcg: f64 = top_k
102            .iter()
103            .enumerate()
104            .map(|(i, &&id)| {
105                let rel = if relevant_ids.contains(&id) { 1.0 } else { 0.0 };
106                rel / (i as f64 + 2.0).log2()
107            })
108            .sum();
109
110        // Ideal DCG: assume all relevant docs at the top
111        let num_relevant = relevant_ids.len().min(k);
112        let idcg: f64 = (0..num_relevant)
113            .map(|i| 1.0 / (i as f64 + 2.0).log2())
114            .sum();
115
116        if idcg == 0.0 {
117            0.0
118        } else {
119            dcg / idcg
120        }
121    }
122
123    /// Compute MRR given ordered retrieved IDs and a set of relevant IDs
124    pub fn mrr(retrieved: &[i64], relevant_ids: &[i64]) -> f64 {
125        for (i, &id) in retrieved.iter().enumerate() {
126            if relevant_ids.contains(&id) {
127                return 1.0 / (i as f64 + 1.0);
128            }
129        }
130        0.0
131    }
132}
133
134impl Benchmark for MemBenchmark {
135    fn name(&self) -> &str {
136        "membench"
137    }
138
139    fn description(&self) -> &str {
140        "CRUD throughput and search quality benchmark. Measures create_per_sec, get_per_sec, \
141         search_per_sec, NDCG@10, and MRR using synthetic memories."
142    }
143
144    fn run(&self, db_path: &str) -> Result<BenchmarkResult, Box<dyn std::error::Error>> {
145        let storage = if db_path == ":memory:" {
146            Storage::open_in_memory()?
147        } else {
148            let bench_path = format!("{}.membench.db", db_path);
149            Storage::open(StorageConfig {
150                db_path: bench_path,
151                storage_mode: StorageMode::Local,
152                cloud_uri: None,
153                encrypt_cloud: false,
154                confidence_half_life_days: 30.0,
155                auto_sync: false,
156                sync_debounce_ms: 5000,
157            })?
158        };
159
160        // ===== Phase 1: CREATE throughput =====
161        let create_start = Instant::now();
162        let mut created_ids: Vec<i64> = Vec::with_capacity(self.num_memories);
163
164        for i in 0..self.num_memories {
165            let template = CORPUS_TEMPLATES[i % CORPUS_TEMPLATES.len()];
166            let content = template.replace("{}", &format!("batch_{}", i));
167            let mem = storage.with_connection(|conn| {
168                create_memory(
169                    conn,
170                    &CreateMemoryInput {
171                        content,
172                        memory_type: MemoryType::Note,
173                        workspace: Some("membench".to_string()),
174                        ..Default::default()
175                    },
176                )
177            })?;
178            created_ids.push(mem.id);
179        }
180        let create_elapsed = create_start.elapsed();
181        let create_per_sec = if create_elapsed.as_secs_f64() > 0.0 {
182            self.num_memories as f64 / create_elapsed.as_secs_f64()
183        } else {
184            self.num_memories as f64 * 1_000_000.0
185        };
186
187        // ===== Phase 2: GET throughput =====
188        let get_start = Instant::now();
189        let mut get_hits = 0usize;
190        for &id in &created_ids {
191            if storage.with_connection(|conn| get_memory(conn, id)).is_ok() {
192                get_hits += 1;
193            }
194        }
195        let get_elapsed = get_start.elapsed();
196        let get_per_sec = if get_elapsed.as_secs_f64() > 0.0 {
197            get_hits as f64 / get_elapsed.as_secs_f64()
198        } else {
199            get_hits as f64 * 1_000_000.0
200        };
201
202        // ===== Phase 3: SEARCH throughput + quality =====
203        // Create topic-specific memories and track which IDs are relevant
204        let mut topic_relevant_ids: HashMap<&str, Vec<i64>> = HashMap::new();
205
206        for topic in TOPICS {
207            let mut relevant = Vec::new();
208            for phrase in topic.relevant_phrases {
209                let mem = storage.with_connection(|conn| {
210                    create_memory(
211                        conn,
212                        &CreateMemoryInput {
213                            content: phrase.to_string(),
214                            memory_type: MemoryType::Note,
215                            workspace: Some("membench-quality".to_string()),
216                            ..Default::default()
217                        },
218                    )
219                })?;
220                relevant.push(mem.id);
221            }
222            topic_relevant_ids.insert(topic.keyword, relevant);
223        }
224
225        let search_start = Instant::now();
226        let mut ndcg_sum = 0.0f64;
227        let mut mrr_sum = 0.0f64;
228        let mut search_count = 0usize;
229
230        let queries: Vec<&str> = TOPICS
231            .iter()
232            .map(|t| t.keyword)
233            .cycle()
234            .take(self.num_queries)
235            .collect();
236
237        for query in &queries {
238            let keyword_pattern = format!("%{}%", query);
239            let retrieved_ids: Vec<i64> = storage.with_connection(|conn| {
240                let mut stmt = conn.prepare(
241                    "SELECT id FROM memories WHERE content LIKE ?1 \
242                     ORDER BY created_at DESC LIMIT 10",
243                )?;
244                let ids: Vec<i64> = stmt
245                    .query_map([&keyword_pattern], |row| row.get(0))?
246                    .filter_map(|r| r.ok())
247                    .collect();
248                Ok(ids)
249            })?;
250
251            if let Some(relevant_ids) = topic_relevant_ids.get(query) {
252                ndcg_sum += Self::ndcg_at_k(&retrieved_ids, relevant_ids, 10);
253                mrr_sum += Self::mrr(&retrieved_ids, relevant_ids);
254                search_count += 1;
255            }
256        }
257
258        let search_elapsed = search_start.elapsed();
259        let search_per_sec = if search_elapsed.as_secs_f64() > 0.0 {
260            self.num_queries as f64 / search_elapsed.as_secs_f64()
261        } else {
262            self.num_queries as f64 * 1_000_000.0
263        };
264
265        let ndcg_at_10 = if search_count > 0 {
266            ndcg_sum / search_count as f64
267        } else {
268            0.0
269        };
270
271        let mrr = if search_count > 0 {
272            mrr_sum / search_count as f64
273        } else {
274            0.0
275        };
276
277        let duration_ms = create_elapsed.as_millis() as u64
278            + get_elapsed.as_millis() as u64
279            + search_elapsed.as_millis() as u64;
280
281        let mut metrics = HashMap::new();
282        metrics.insert(
283            "create_per_sec".to_string(),
284            create_per_sec.min(1_000_000.0),
285        );
286        metrics.insert("get_per_sec".to_string(), get_per_sec.min(1_000_000.0));
287        metrics.insert(
288            "search_per_sec".to_string(),
289            search_per_sec.min(1_000_000.0),
290        );
291        metrics.insert("ndcg_at_10".to_string(), ndcg_at_10);
292        metrics.insert("mrr".to_string(), mrr);
293        metrics.insert("num_memories".to_string(), self.num_memories as f64);
294        metrics.insert("num_queries".to_string(), self.num_queries as f64);
295
296        // Clean up temporary file
297        if db_path != ":memory:" {
298            let bench_path = format!("{}.membench.db", db_path);
299            drop(storage);
300            let _ = std::fs::remove_file(&bench_path);
301            let _ = std::fs::remove_file(format!("{}-wal", bench_path));
302            let _ = std::fs::remove_file(format!("{}-shm", bench_path));
303        }
304
305        Ok(BenchmarkResult {
306            name: self.name().to_string(),
307            metrics,
308            duration_ms,
309            timestamp: chrono::Utc::now().to_rfc3339(),
310        })
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs MemBench against a fresh in-memory store, failing the test on any error.
    fn run_in_memory(num_memories: usize, num_queries: usize) -> BenchmarkResult {
        let bench = MemBenchmark {
            num_memories,
            num_queries,
        };
        bench.run(":memory:").expect("benchmark should succeed")
    }

    #[test]
    fn test_membench_runs() {
        let result = run_in_memory(20, 5);
        assert_eq!(result.name, "membench");
    }

    #[test]
    fn test_membench_metrics_present() {
        let result = run_in_memory(10, 4);

        let expected_keys = [
            "create_per_sec",
            "get_per_sec",
            "search_per_sec",
            "ndcg_at_10",
            "mrr",
            "num_memories",
            "num_queries",
        ];
        for key in &expected_keys {
            assert!(result.metrics.contains_key(*key), "missing metric: {}", key);
        }

        // The configuration is echoed back verbatim.
        assert_eq!(result.metrics["num_memories"], 10.0);
        assert_eq!(result.metrics["num_queries"], 4.0);
    }

    #[test]
    fn test_throughput_positive() {
        let result = run_in_memory(50, 10);
        for key in &["create_per_sec", "get_per_sec", "search_per_sec"] {
            assert!(result.metrics[*key] > 0.0, "{} should be positive", key);
        }
    }

    #[test]
    fn test_ndcg_range() {
        let result = run_in_memory(30, 8);
        for key in &["ndcg_at_10", "mrr"] {
            let value = result.metrics[*key];
            assert!(
                (0.0..=1.0).contains(&value),
                "{} = {} out of range",
                key,
                value
            );
        }
    }

    #[test]
    fn test_ndcg_at_k_computation() {
        // Relevant docs at positions 0, 2, 4; the ideal ranking puts all three first.
        let relevant = [1i64, 2, 3];
        let retrieved = [1i64, 4, 2, 5, 3];
        let expected = (1.0 + 1.0 / 4f64.log2() + 1.0 / 6f64.log2())
            / (1.0 + 1.0 / 3f64.log2() + 1.0 / 4f64.log2());
        let ndcg = MemBenchmark::ndcg_at_k(&retrieved, &relevant, 5);
        assert!(
            (ndcg - expected).abs() < 1e-9,
            "ndcg={} expected={}",
            ndcg,
            expected
        );

        // Empty retrieval → NDCG = 0
        assert_eq!(MemBenchmark::ndcg_at_k(&[], &relevant, 10), 0.0);

        // Nothing relevant → no ideal gain to normalize by → 0
        assert_eq!(MemBenchmark::ndcg_at_k(&retrieved, &[], 10), 0.0);

        // A relevant doc just outside the top-k earns nothing.
        assert_eq!(MemBenchmark::ndcg_at_k(&[4, 1], &[1], 1), 0.0);

        // Perfect ranking: relevant docs at the top → NDCG = 1.0
        let ndcg_perfect = MemBenchmark::ndcg_at_k(&[1, 2, 3, 4, 5], &relevant, 3);
        assert!(
            (ndcg_perfect - 1.0).abs() < 1e-9,
            "perfect ndcg={}",
            ndcg_perfect
        );
    }

    #[test]
    fn test_mrr_computation() {
        // First hit at position 2 (0-indexed) → MRR = 1/3
        let mrr = MemBenchmark::mrr(&[1, 2, 3, 4, 5], &[3]);
        assert!((mrr - 1.0 / 3.0).abs() < 1e-9, "mrr={}", mrr);

        // No hit → MRR = 0
        assert_eq!(MemBenchmark::mrr(&[10, 11, 12], &[99]), 0.0);

        // Empty retrieval → MRR = 0
        assert_eq!(MemBenchmark::mrr(&[], &[1]), 0.0);

        // First position hit → MRR = 1.0
        let mrr_first = MemBenchmark::mrr(&[5, 6, 7], &[5]);
        assert!((mrr_first - 1.0).abs() < 1e-9, "mrr_first={}", mrr_first);
    }
}
418}