// engram/bench/mod.rs

1//! Standardized benchmark suite for Engram
2//!
3//! Provides implementations of standard AI memory benchmarks:
4//! - LOCOMO: Multi-session conversation memory
5//! - LongMemEval: 5-dimension memory evaluation
6//! - MemBench: CRUD throughput and search quality
7
8pub mod locomo;
9pub mod longmemeval;
10pub mod membench;
11
12use std::collections::HashMap;
13
14use chrono::Utc;
15use serde::{Deserialize, Serialize};
16
/// Result of running a single benchmark.
///
/// Metrics are numeric only (`f64`); any textual context must live in
/// `name` or `timestamp`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// Benchmark name, as reported by `Benchmark::name`.
    pub name: String,
    /// Metric name -> numeric value (e.g. "f1" -> 0.72).
    pub metrics: HashMap<String, f64>,
    /// Wall-clock duration of the run, in milliseconds.
    pub duration_ms: u64,
    /// RFC 3339 timestamp of when the result was produced.
    pub timestamp: String,
}
25
/// A benchmark that can be run against an Engram database.
///
/// `Send + Sync` so implementations can be boxed and shared across threads.
pub trait Benchmark: Send + Sync {
    /// Short, stable identifier used in reports.
    fn name(&self) -> &str;
    /// Human-readable description of what the benchmark measures.
    fn description(&self) -> &str;
    /// Execute the benchmark against the database at `db_path`, returning
    /// the collected metrics or an error if the run could not complete.
    fn run(&self, db_path: &str) -> Result<BenchmarkResult, Box<dyn std::error::Error>>;
}
32
/// Suite that manages and runs multiple benchmarks.
pub struct BenchmarkSuite {
    // Trait objects allow heterogeneous benchmark types in one list.
    benchmarks: Vec<Box<dyn Benchmark>>,
}
37
38impl BenchmarkSuite {
39    /// Create an empty benchmark suite
40    pub fn new() -> Self {
41        Self {
42            benchmarks: Vec::new(),
43        }
44    }
45
46    /// Add a benchmark to the suite
47    pub fn add(&mut self, benchmark: Box<dyn Benchmark>) {
48        self.benchmarks.push(benchmark);
49    }
50
51    /// Run all benchmarks and return results
52    pub fn run_all(&self, db_path: &str) -> Vec<BenchmarkResult> {
53        self.benchmarks
54            .iter()
55            .map(|b| {
56                b.run(db_path).unwrap_or_else(|e| BenchmarkResult {
57                    name: b.name().to_string(),
58                    metrics: {
59                        let mut m = HashMap::new();
60                        m.insert("error".to_string(), 0.0);
61                        m.insert("error_message".to_string(), 0.0);
62                        let _ = e;
63                        m
64                    },
65                    duration_ms: 0,
66                    timestamp: Utc::now().to_rfc3339(),
67                })
68            })
69            .collect()
70    }
71
72    /// Format results as JSON
73    pub fn report_json(results: &[BenchmarkResult]) -> String {
74        serde_json::to_string_pretty(results).unwrap_or_else(|_| "[]".to_string())
75    }
76
77    /// Format results as Markdown table
78    pub fn report_markdown(results: &[BenchmarkResult]) -> String {
79        if results.is_empty() {
80            return "No benchmark results.\n".to_string();
81        }
82
83        let mut out = String::new();
84        out.push_str("# Engram Benchmark Results\n\n");
85        out.push_str(&format!("*Run at: {}*\n\n", Utc::now().to_rfc3339()));
86
87        for result in results {
88            out.push_str(&format!("## {}\n\n", result.name));
89            out.push_str(&format!("Duration: {}ms\n\n", result.duration_ms));
90            out.push_str("| Metric | Value |\n");
91            out.push_str("|--------|-------|\n");
92
93            let mut metrics: Vec<_> = result.metrics.iter().collect();
94            metrics.sort_by_key(|(k, _)| k.as_str());
95            for (key, value) in metrics {
96                out.push_str(&format!("| {} | {:.4} |\n", key, value));
97            }
98            out.push('\n');
99        }
100
101        out
102    }
103
104    /// Format results as CSV
105    pub fn report_csv(results: &[BenchmarkResult]) -> String {
106        if results.is_empty() {
107            return "benchmark,metric,value,duration_ms,timestamp\n".to_string();
108        }
109
110        let mut out = String::from("benchmark,metric,value,duration_ms,timestamp\n");
111
112        for result in results {
113            let mut metrics: Vec<_> = result.metrics.iter().collect();
114            metrics.sort_by_key(|(k, _)| k.as_str());
115            for (key, value) in metrics {
116                out.push_str(&format!(
117                    "{},{},{:.6},{},{}\n",
118                    result.name, key, value, result.duration_ms, result.timestamp
119                ));
120            }
121        }
122
123        out
124    }
125}
126
127impl Default for BenchmarkSuite {
128    fn default() -> Self {
129        Self::new()
130    }
131}
132
/// Build the default suite with all benchmarks.
///
/// Workload sizes are small, fixed constants — presumably chosen so the
/// full suite finishes quickly; callers wanting heavier runs can assemble
/// their own suite via `BenchmarkSuite::new` + `add`.
pub fn default_suite() -> BenchmarkSuite {
    let mut suite = BenchmarkSuite::new();
    // LOCOMO: multi-session conversation memory.
    suite.add(Box::new(locomo::LocomoBenchmark {
        num_conversations: 10,
        queries_per_conversation: 3,
    }));
    // LongMemEval: 5-dimension memory evaluation (module-defined defaults).
    suite.add(Box::new(longmemeval::LongMemEvalBenchmark::default()));
    // MemBench: CRUD throughput and search quality.
    suite.add(Box::new(membench::MemBenchmark {
        num_memories: 100,
        num_queries: 20,
    }));
    suite
}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal benchmark used to exercise the suite plumbing.
    struct DummyBenchmark {
        name: String,
    }

    impl Benchmark for DummyBenchmark {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn description(&self) -> &str {
            "A dummy benchmark for testing"
        }

        fn run(&self, _db_path: &str) -> Result<BenchmarkResult, Box<dyn std::error::Error>> {
            // Fixed metrics so assertions can pin exact values.
            let metrics = HashMap::from([
                ("score".to_string(), 0.95),
                ("latency_ms".to_string(), 12.5),
            ]);

            Ok(BenchmarkResult {
                name: self.name.clone(),
                metrics,
                duration_ms: 42,
                timestamp: Utc::now().to_rfc3339(),
            })
        }
    }

    /// Build a single-metric result fixture with a fixed timestamp.
    fn fixture(name: &str, metric: &str, value: f64, duration_ms: u64) -> BenchmarkResult {
        BenchmarkResult {
            name: name.to_string(),
            metrics: HashMap::from([(metric.to_string(), value)]),
            duration_ms,
            timestamp: "2026-01-01T00:00:00Z".to_string(),
        }
    }

    #[test]
    fn test_suite_creation() {
        let suite = BenchmarkSuite::new();
        assert_eq!(suite.benchmarks.len(), 0);
    }

    #[test]
    fn test_suite_add_and_run() {
        let mut suite = BenchmarkSuite::new();
        let bench = DummyBenchmark {
            name: "test-bench".to_string(),
        };
        suite.add(Box::new(bench));
        assert_eq!(suite.benchmarks.len(), 1);

        let results = suite.run_all(":memory:");
        assert_eq!(results.len(), 1);
        let first = &results[0];
        assert_eq!(first.name, "test-bench");
        assert!(first.metrics.contains_key("score"));
        assert_eq!(first.duration_ms, 42);
    }

    #[test]
    fn test_report_json() {
        let results = vec![fixture("test", "precision", 0.85, 100)];

        let json = BenchmarkSuite::report_json(&results);
        assert!(json.contains("\"test\""));
        assert!(json.contains("precision"));
        assert!(json.contains("0.85"));
    }

    #[test]
    fn test_report_markdown() {
        let results = vec![fixture("locomo", "f1", 0.72, 200)];

        let md = BenchmarkSuite::report_markdown(&results);
        assert!(md.contains("## locomo"));
        assert!(md.contains("200ms"));
        assert!(md.contains("f1"));
    }

    #[test]
    fn test_report_csv() {
        let results = vec![fixture("membench", "create_per_sec", 500.0, 150)];

        let csv = BenchmarkSuite::report_csv(&results);
        assert!(csv.starts_with("benchmark,metric,value"));
        assert!(csv.contains("membench"));
        assert!(csv.contains("create_per_sec"));
    }

    #[test]
    fn test_result_serialization() {
        let result = fixture("roundtrip", "recall", 0.9, 77);

        let serialized = serde_json::to_string(&result).expect("should serialize");
        let deserialized: BenchmarkResult =
            serde_json::from_str(&serialized).expect("should deserialize");
        assert_eq!(deserialized.name, result.name);
        assert_eq!(deserialized.duration_ms, result.duration_ms);
        assert!((deserialized.metrics["recall"] - 0.9).abs() < 1e-9);
    }
}