1pub mod locomo;
9pub mod longmemeval;
10pub mod membench;
11
12use std::collections::HashMap;
13
14use chrono::Utc;
15use serde::{Deserialize, Serialize};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct BenchmarkResult {
20 pub name: String,
21 pub metrics: HashMap<String, f64>,
22 pub duration_ms: u64,
23 pub timestamp: String,
24}
25
26pub trait Benchmark: Send + Sync {
28 fn name(&self) -> &str;
29 fn description(&self) -> &str;
30 fn run(&self, db_path: &str) -> Result<BenchmarkResult, Box<dyn std::error::Error>>;
31}
32
/// An ordered collection of boxed [`Benchmark`] trait objects that are run together.
pub struct BenchmarkSuite {
    // Benchmarks in the order they were added; `add` appends here and `run_all`
    // iterates this vector front to back.
    benchmarks: Vec<Box<dyn Benchmark>>,
}
37
38impl BenchmarkSuite {
39 pub fn new() -> Self {
41 Self {
42 benchmarks: Vec::new(),
43 }
44 }
45
46 pub fn add(&mut self, benchmark: Box<dyn Benchmark>) {
48 self.benchmarks.push(benchmark);
49 }
50
51 pub fn run_all(&self, db_path: &str) -> Vec<BenchmarkResult> {
53 self.benchmarks
54 .iter()
55 .map(|b| {
56 b.run(db_path).unwrap_or_else(|e| BenchmarkResult {
57 name: b.name().to_string(),
58 metrics: {
59 let mut m = HashMap::new();
60 m.insert("error".to_string(), 0.0);
61 m.insert("error_message".to_string(), 0.0);
62 let _ = e;
63 m
64 },
65 duration_ms: 0,
66 timestamp: Utc::now().to_rfc3339(),
67 })
68 })
69 .collect()
70 }
71
72 pub fn report_json(results: &[BenchmarkResult]) -> String {
74 serde_json::to_string_pretty(results).unwrap_or_else(|_| "[]".to_string())
75 }
76
77 pub fn report_markdown(results: &[BenchmarkResult]) -> String {
79 if results.is_empty() {
80 return "No benchmark results.\n".to_string();
81 }
82
83 let mut out = String::new();
84 out.push_str("# Engram Benchmark Results\n\n");
85 out.push_str(&format!("*Run at: {}*\n\n", Utc::now().to_rfc3339()));
86
87 for result in results {
88 out.push_str(&format!("## {}\n\n", result.name));
89 out.push_str(&format!("Duration: {}ms\n\n", result.duration_ms));
90 out.push_str("| Metric | Value |\n");
91 out.push_str("|--------|-------|\n");
92
93 let mut metrics: Vec<_> = result.metrics.iter().collect();
94 metrics.sort_by_key(|(k, _)| k.as_str());
95 for (key, value) in metrics {
96 out.push_str(&format!("| {} | {:.4} |\n", key, value));
97 }
98 out.push('\n');
99 }
100
101 out
102 }
103
104 pub fn report_csv(results: &[BenchmarkResult]) -> String {
106 if results.is_empty() {
107 return "benchmark,metric,value,duration_ms,timestamp\n".to_string();
108 }
109
110 let mut out = String::from("benchmark,metric,value,duration_ms,timestamp\n");
111
112 for result in results {
113 let mut metrics: Vec<_> = result.metrics.iter().collect();
114 metrics.sort_by_key(|(k, _)| k.as_str());
115 for (key, value) in metrics {
116 out.push_str(&format!(
117 "{},{},{:.6},{},{}\n",
118 result.name, key, value, result.duration_ms, result.timestamp
119 ));
120 }
121 }
122
123 out
124 }
125}
126
127impl Default for BenchmarkSuite {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133pub fn default_suite() -> BenchmarkSuite {
135 let mut suite = BenchmarkSuite::new();
136 suite.add(Box::new(locomo::LocomoBenchmark {
137 num_conversations: 10,
138 queries_per_conversation: 3,
139 }));
140 suite.add(Box::new(longmemeval::LongMemEvalBenchmark::default()));
141 suite.add(Box::new(membench::MemBenchmark {
142 num_memories: 100,
143 num_queries: 20,
144 }));
145 suite
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed timestamp used by report fixtures so their contents are reproducible.
    const FIXED_TIMESTAMP: &str = "2026-01-01T00:00:00Z";

    /// Builds a result holding exactly one metric, stamped with [`FIXED_TIMESTAMP`].
    fn single_metric_result(
        name: &str,
        metric: &str,
        value: f64,
        duration_ms: u64,
    ) -> BenchmarkResult {
        BenchmarkResult {
            name: name.to_owned(),
            metrics: std::iter::once((metric.to_owned(), value)).collect(),
            duration_ms,
            timestamp: FIXED_TIMESTAMP.to_owned(),
        }
    }

    /// Always-succeeding benchmark reporting two fixed metrics and a 42 ms duration.
    struct DummyBenchmark {
        name: String,
    }

    impl Benchmark for DummyBenchmark {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn description(&self) -> &str {
            "A dummy benchmark for testing"
        }

        fn run(&self, _db_path: &str) -> Result<BenchmarkResult, Box<dyn std::error::Error>> {
            let metrics: HashMap<String, f64> = vec![
                ("score".to_owned(), 0.95),
                ("latency_ms".to_owned(), 12.5),
            ]
            .into_iter()
            .collect();

            Ok(BenchmarkResult {
                name: self.name.clone(),
                metrics,
                duration_ms: 42,
                timestamp: Utc::now().to_rfc3339(),
            })
        }
    }

    #[test]
    fn test_suite_creation() {
        assert!(BenchmarkSuite::new().benchmarks.is_empty());
    }

    #[test]
    fn test_suite_add_and_run() {
        let mut suite = BenchmarkSuite::new();
        suite.add(Box::new(DummyBenchmark {
            name: "test-bench".to_owned(),
        }));
        assert_eq!(suite.benchmarks.len(), 1);

        let results = suite.run_all(":memory:");
        assert_eq!(results.len(), 1);
        let only = &results[0];
        assert_eq!(only.name, "test-bench");
        assert!(only.metrics.contains_key("score"));
        assert_eq!(only.duration_ms, 42);
    }

    #[test]
    fn test_report_json() {
        let fixture = [single_metric_result("test", "precision", 0.85, 100)];
        let json = BenchmarkSuite::report_json(&fixture);
        assert!(json.contains("\"test\""));
        assert!(json.contains("precision"));
        assert!(json.contains("0.85"));
    }

    #[test]
    fn test_report_markdown() {
        let fixture = [single_metric_result("locomo", "f1", 0.72, 200)];
        let md = BenchmarkSuite::report_markdown(&fixture);
        assert!(md.contains("## locomo"));
        assert!(md.contains("200ms"));
        assert!(md.contains("f1"));
    }

    #[test]
    fn test_report_csv() {
        let fixture = [single_metric_result("membench", "create_per_sec", 500.0, 150)];
        let csv = BenchmarkSuite::report_csv(&fixture);
        assert!(csv.starts_with("benchmark,metric,value"));
        assert!(csv.contains("membench"));
        assert!(csv.contains("create_per_sec"));
    }

    #[test]
    fn test_result_serialization() {
        let original = single_metric_result("roundtrip", "recall", 0.9, 77);

        let encoded = serde_json::to_string(&original).expect("should serialize");
        let decoded: BenchmarkResult = serde_json::from_str(&encoded).expect("should deserialize");

        assert_eq!(decoded.name, original.name);
        assert_eq!(decoded.duration_ms, original.duration_ms);
        assert!((decoded.metrics["recall"] - 0.9).abs() < 1e-9);
    }
}