1use std::collections::HashMap;
7use std::time::Instant;
8
9use super::{Benchmark, BenchmarkResult};
10use crate::storage::queries::create_memory;
11use crate::storage::Storage;
12use crate::types::{CreateMemoryInput, MemoryType, StorageConfig, StorageMode};
13
/// LoCoMo-style benchmark configuration.
///
/// It stores facts from several synthetic conversation sessions and reports
/// precision, recall and F1 for retrieving them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocomoBenchmark {
    /// Number of synthetic conversation sessions to generate.
    pub num_conversations: usize,
    /// Queries issued per session; capped at the number of facts in a session.
    pub queries_per_conversation: usize,
}
21
/// One synthetic session: the facts stored for it and the queries scored against them.
#[derive(Debug, Clone, PartialEq, Eq)]
struct SyntheticConversation {
    /// Session number; also embedded in every fact as a `"Session {id}: "` prefix.
    session_id: usize,
    /// Fact texts that are stored as memories for this session.
    facts: Vec<String>,
    /// Queries whose expected answers index into `facts`.
    queries: Vec<ConversationQuery>,
}

/// A retrieval query, described only by which facts of its session it should return.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ConversationQuery {
    /// Indices into the owning conversation's `facts`.
    expected_fact_indices: Vec<usize>,
}
33
34impl LocomoBenchmark {
35 fn generate_conversations(&self) -> Vec<SyntheticConversation> {
36 let templates = [
37 (
38 "Alice works at Acme Corp as a software engineer",
39 "Bob is studying machine learning at MIT",
40 "Carol prefers Python over Rust for scripting",
41 ),
42 (
43 "David's favorite color is blue and he lives in London",
44 "Eve is learning Japanese and visits Tokyo each year",
45 "Frank is allergic to peanuts and avoids Thai food",
46 ),
47 (
48 "Grace runs marathons every spring in Boston",
49 "Henry has two cats named Luna and Mochi",
50 "Iris is a vegetarian who loves Italian cuisine",
51 ),
52 (
53 "Jack recently moved from New York to San Francisco",
54 "Karen plays the piano and violin professionally",
55 "Leo is a night owl who does his best work after midnight",
56 ),
57 (
58 "Mia has a PhD in quantum computing from Caltech",
59 "Noah volunteers at the local animal shelter every weekend",
60 "Olivia runs a small bakery specializing in sourdough bread",
61 ),
62 ];
63
64 (0..self.num_conversations)
65 .map(|i| {
66 let tpl = &templates[i % templates.len()];
67
68 let facts = vec![
69 format!("Session {}: {}", i, tpl.0),
70 format!("Session {}: {}", i, tpl.1),
71 format!("Session {}: {}", i, tpl.2),
72 ];
73
74 let num_queries = self.queries_per_conversation.min(facts.len());
75 let queries = (0..num_queries)
76 .map(|fi| ConversationQuery {
77 expected_fact_indices: vec![fi],
78 })
79 .collect();
80
81 SyntheticConversation {
82 session_id: i,
83 facts,
84 queries,
85 }
86 })
87 .collect()
88 }
89}
90
91impl Benchmark for LocomoBenchmark {
92 fn name(&self) -> &str {
93 "locomo"
94 }
95
96 fn description(&self) -> &str {
97 "Multi-session conversation memory benchmark. Measures precision, recall, and F1 \
98 for retrieving facts stored across multiple synthetic conversation sessions."
99 }
100
101 fn run(&self, db_path: &str) -> Result<BenchmarkResult, Box<dyn std::error::Error>> {
102 let start = Instant::now();
103
104 let storage = if db_path == ":memory:" {
106 Storage::open_in_memory()?
107 } else {
108 let bench_path = format!("{}.locomo_bench.db", db_path);
109 Storage::open(StorageConfig {
110 db_path: bench_path,
111 storage_mode: StorageMode::Local,
112 cloud_uri: None,
113 encrypt_cloud: false,
114 confidence_half_life_days: 30.0,
115 auto_sync: false,
116 sync_debounce_ms: 5000,
117 })?
118 };
119
120 let conversations = self.generate_conversations();
122 let mut memory_ids: Vec<Vec<i64>> = Vec::new();
123
124 for conv in &conversations {
125 let mut ids = Vec::new();
126 for fact in &conv.facts {
127 let content = fact.clone();
128 let session_tag = format!("session:{}", conv.session_id);
129 let mem = storage.with_connection(|conn| {
130 create_memory(
131 conn,
132 &CreateMemoryInput {
133 content,
134 memory_type: MemoryType::Episodic,
135 tags: vec![session_tag],
136 workspace: Some("locomo-bench".to_string()),
137 ..Default::default()
138 },
139 )
140 })?;
141 ids.push(mem.id);
142 }
143 memory_ids.push(ids);
144 }
145
146 let mut true_positives = 0usize;
148 let mut false_positives = 0usize;
149 let mut false_negatives = 0usize;
150
151 for (conv_idx, conv) in conversations.iter().enumerate() {
152 let conv_ids = &memory_ids[conv_idx];
153
154 for query in &conv.queries {
155 let keyword = format!("%Session {}%", conv.session_id);
157 let retrieved_ids: Vec<i64> = storage.with_connection(|conn| {
158 let mut stmt =
159 conn.prepare("SELECT id FROM memories WHERE content LIKE ?1 LIMIT 10")?;
160 let ids: Vec<i64> = stmt
161 .query_map([&keyword], |row| row.get(0))?
162 .filter_map(|r| r.ok())
163 .collect();
164 Ok(ids)
165 })?;
166
167 let expected_ids: Vec<i64> = query
169 .expected_fact_indices
170 .iter()
171 .filter_map(|&fi| conv_ids.get(fi).copied())
172 .collect();
173
174 for &rid in &retrieved_ids {
175 if expected_ids.contains(&rid) {
176 true_positives += 1;
177 } else {
178 false_positives += 1;
179 }
180 }
181
182 for &eid in &expected_ids {
183 if !retrieved_ids.contains(&eid) {
184 false_negatives += 1;
185 }
186 }
187 }
188 }
189
190 let precision = if true_positives + false_positives > 0 {
192 true_positives as f64 / (true_positives + false_positives) as f64
193 } else {
194 0.0
195 };
196
197 let recall = if true_positives + false_negatives > 0 {
198 true_positives as f64 / (true_positives + false_negatives) as f64
199 } else {
200 0.0
201 };
202
203 let f1 = if precision + recall > 0.0 {
204 2.0 * precision * recall / (precision + recall)
205 } else {
206 0.0
207 };
208
209 let duration_ms = start.elapsed().as_millis() as u64;
210
211 let mut metrics = HashMap::new();
212 metrics.insert("precision".to_string(), precision);
213 metrics.insert("recall".to_string(), recall);
214 metrics.insert("f1".to_string(), f1);
215 metrics.insert(
216 "num_conversations".to_string(),
217 self.num_conversations as f64,
218 );
219 metrics.insert(
220 "queries_per_conversation".to_string(),
221 self.queries_per_conversation as f64,
222 );
223 metrics.insert("true_positives".to_string(), true_positives as f64);
224 metrics.insert("false_positives".to_string(), false_positives as f64);
225 metrics.insert("false_negatives".to_string(), false_negatives as f64);
226
227 if db_path != ":memory:" {
229 let bench_path = format!("{}.locomo_bench.db", db_path);
230 drop(storage);
232 let _ = std::fs::remove_file(&bench_path);
233 let _ = std::fs::remove_file(format!("{}-wal", bench_path));
234 let _ = std::fs::remove_file(format!("{}-shm", bench_path));
235 }
236
237 Ok(BenchmarkResult {
238 name: self.name().to_string(),
239 metrics,
240 duration_ms,
241 timestamp: chrono::Utc::now().to_rfc3339(),
242 })
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a benchmark with the given shape.
    fn bench(num_conversations: usize, queries_per_conversation: usize) -> LocomoBenchmark {
        LocomoBenchmark {
            num_conversations,
            queries_per_conversation,
        }
    }

    #[test]
    fn test_locomo_runs_in_memory() {
        let result = bench(3, 2)
            .run(":memory:")
            .expect("benchmark should succeed");
        assert_eq!(result.name, "locomo");
        for key in ["precision", "recall", "f1"] {
            assert!(result.metrics.contains_key(key), "missing metric: {}", key);
        }
    }

    #[test]
    fn test_locomo_metrics_range() {
        let result = bench(2, 1)
            .run(":memory:")
            .expect("benchmark should succeed");
        for key in ["precision", "recall", "f1"] {
            let value = result.metrics[key];
            assert!(
                (0.0..=1.0).contains(&value),
                "{} out of range: {}",
                key,
                value
            );
        }
    }

    #[test]
    fn test_locomo_generates_correct_conversation_count() {
        let conversations = bench(5, 2).generate_conversations();
        assert_eq!(conversations.len(), 5);
        for conv in &conversations {
            assert_eq!(conv.queries.len(), 2);
            assert_eq!(conv.facts.len(), 3);
        }
    }

    #[test]
    fn test_locomo_caps_queries_at_fact_count() {
        for conv in bench(2, 10).generate_conversations() {
            assert_eq!(conv.queries.len(), conv.facts.len());
            for (idx, query) in conv.queries.iter().enumerate() {
                assert_eq!(query.expected_fact_indices, vec![idx]);
            }
        }
    }

    #[test]
    fn test_locomo_facts_are_prefixed_with_session() {
        // More than 10 sessions, so two-digit session ids are exercised too.
        let conversations = bench(12, 1).generate_conversations();
        for (i, conv) in conversations.iter().enumerate() {
            assert_eq!(conv.session_id, i);
            let prefix = format!("Session {}: ", i);
            assert!(conv.facts.iter().all(|fact| fact.starts_with(&prefix)));
        }
    }

    #[test]
    fn test_locomo_zero_conversations_yields_zero_metrics() {
        let benchmark = bench(0, 3);
        assert!(benchmark.generate_conversations().is_empty());

        let result = benchmark
            .run(":memory:")
            .expect("benchmark should succeed");
        assert_eq!(result.metrics["precision"], 0.0);
        assert_eq!(result.metrics["recall"], 0.0);
        assert_eq!(result.metrics["f1"], 0.0);
    }
}