Skip to main content

engram/bench/
locomo.rs

//! LOCOMO benchmark — multi-session conversation memory.
//!
//! Measures precision, recall, and F1 for retrieving facts that were
//! discussed across multiple synthetic conversation sessions.
5
6use std::collections::HashMap;
7use std::time::Instant;
8
9use super::{Benchmark, BenchmarkResult};
10use crate::storage::queries::create_memory;
11use crate::storage::Storage;
12use crate::types::{CreateMemoryInput, MemoryType, StorageConfig, StorageMode};
13
/// Configuration for the LOCOMO multi-session recall benchmark.
pub struct LocomoBenchmark {
    /// How many synthetic conversations (sessions) to generate and index.
    pub num_conversations: usize,
    /// Recall queries issued per conversation; effectively capped at the
    /// number of facts each conversation carries.
    pub queries_per_conversation: usize,
}

/// One synthetic session: the facts stored for it plus the queries whose
/// ground truth points back into those facts.
struct SyntheticConversation {
    /// Zero-based session number, also embedded in every fact as `"Session {id}: "`.
    session_id: usize,
    /// Fact strings, each prefixed with `"Session {session_id}: "`.
    facts: Vec<String>,
    /// Queries whose expected answers are indices into `facts`.
    queries: Vec<ConversationQuery>,
}

/// A single recall query, described only by its ground truth.
struct ConversationQuery {
    /// Indices into `SyntheticConversation::facts` that count as relevant.
    expected_fact_indices: Vec<usize>,
}

impl LocomoBenchmark {
    /// Builds `num_conversations` synthetic sessions.
    ///
    /// Session `i` reuses fact template `i % 5` (round-robin). Each fact is
    /// prefixed with `"Session {i}: "`. Query `k` of a session expects exactly
    /// fact `k`. At most `facts.len()` queries are generated per session.
    fn generate_conversations(&self) -> Vec<SyntheticConversation> {
        const TEMPLATES: [[&str; 3]; 5] = [
            [
                "Alice works at Acme Corp as a software engineer",
                "Bob is studying machine learning at MIT",
                "Carol prefers Python over Rust for scripting",
            ],
            [
                "David's favorite color is blue and he lives in London",
                "Eve is learning Japanese and visits Tokyo each year",
                "Frank is allergic to peanuts and avoids Thai food",
            ],
            [
                "Grace runs marathons every spring in Boston",
                "Henry has two cats named Luna and Mochi",
                "Iris is a vegetarian who loves Italian cuisine",
            ],
            [
                "Jack recently moved from New York to San Francisco",
                "Karen plays the piano and violin professionally",
                "Leo is a night owl who does his best work after midnight",
            ],
            [
                "Mia has a PhD in quantum computing from Caltech",
                "Noah volunteers at the local animal shelter every weekend",
                "Olivia runs a small bakery specializing in sourdough bread",
            ],
        ];

        let mut conversations = Vec::with_capacity(self.num_conversations);
        for session_id in 0..self.num_conversations {
            let template = &TEMPLATES[session_id % TEMPLATES.len()];
            let facts: Vec<String> = template
                .iter()
                .map(|fact| format!("Session {}: {}", session_id, fact))
                .collect();

            // One single-fact query per fact, never more than the facts available.
            let query_count = facts.len().min(self.queries_per_conversation);
            let queries = (0..query_count)
                .map(|fact_index| ConversationQuery {
                    expected_fact_indices: vec![fact_index],
                })
                .collect();

            conversations.push(SyntheticConversation {
                session_id,
                facts,
                queries,
            });
        }
        conversations
    }
}
90
91impl Benchmark for LocomoBenchmark {
92    fn name(&self) -> &str {
93        "locomo"
94    }
95
96    fn description(&self) -> &str {
97        "Multi-session conversation memory benchmark. Measures precision, recall, and F1 \
98         for retrieving facts stored across multiple synthetic conversation sessions."
99    }
100
101    fn run(&self, db_path: &str) -> Result<BenchmarkResult, Box<dyn std::error::Error>> {
102        let start = Instant::now();
103
104        // Open an isolated Storage
105        let storage = if db_path == ":memory:" {
106            Storage::open_in_memory()?
107        } else {
108            let bench_path = format!("{}.locomo_bench.db", db_path);
109            Storage::open(StorageConfig {
110                db_path: bench_path,
111                storage_mode: StorageMode::Local,
112                cloud_uri: None,
113                encrypt_cloud: false,
114                confidence_half_life_days: 30.0,
115                auto_sync: false,
116                sync_debounce_ms: 5000,
117            })?
118        };
119
120        // Phase 1: Index synthetic conversations
121        let conversations = self.generate_conversations();
122        let mut memory_ids: Vec<Vec<i64>> = Vec::new();
123
124        for conv in &conversations {
125            let mut ids = Vec::new();
126            for fact in &conv.facts {
127                let content = fact.clone();
128                let session_tag = format!("session:{}", conv.session_id);
129                let mem = storage.with_connection(|conn| {
130                    create_memory(
131                        conn,
132                        &CreateMemoryInput {
133                            content,
134                            memory_type: MemoryType::Episodic,
135                            tags: vec![session_tag],
136                            workspace: Some("locomo-bench".to_string()),
137                            ..Default::default()
138                        },
139                    )
140                })?;
141                ids.push(mem.id);
142            }
143            memory_ids.push(ids);
144        }
145
146        // Phase 2: Run recall queries using session-scoped LIKE search
147        let mut true_positives = 0usize;
148        let mut false_positives = 0usize;
149        let mut false_negatives = 0usize;
150
151        for (conv_idx, conv) in conversations.iter().enumerate() {
152            let conv_ids = &memory_ids[conv_idx];
153
154            for query in &conv.queries {
155                // Retrieve all memories for this session
156                let keyword = format!("%Session {}%", conv.session_id);
157                let retrieved_ids: Vec<i64> = storage.with_connection(|conn| {
158                    let mut stmt =
159                        conn.prepare("SELECT id FROM memories WHERE content LIKE ?1 LIMIT 10")?;
160                    let ids: Vec<i64> = stmt
161                        .query_map([&keyword], |row| row.get(0))?
162                        .filter_map(|r| r.ok())
163                        .collect();
164                    Ok(ids)
165                })?;
166
167                // Compute expected IDs
168                let expected_ids: Vec<i64> = query
169                    .expected_fact_indices
170                    .iter()
171                    .filter_map(|&fi| conv_ids.get(fi).copied())
172                    .collect();
173
174                for &rid in &retrieved_ids {
175                    if expected_ids.contains(&rid) {
176                        true_positives += 1;
177                    } else {
178                        false_positives += 1;
179                    }
180                }
181
182                for &eid in &expected_ids {
183                    if !retrieved_ids.contains(&eid) {
184                        false_negatives += 1;
185                    }
186                }
187            }
188        }
189
190        // Compute precision, recall, F1
191        let precision = if true_positives + false_positives > 0 {
192            true_positives as f64 / (true_positives + false_positives) as f64
193        } else {
194            0.0
195        };
196
197        let recall = if true_positives + false_negatives > 0 {
198            true_positives as f64 / (true_positives + false_negatives) as f64
199        } else {
200            0.0
201        };
202
203        let f1 = if precision + recall > 0.0 {
204            2.0 * precision * recall / (precision + recall)
205        } else {
206            0.0
207        };
208
209        let duration_ms = start.elapsed().as_millis() as u64;
210
211        let mut metrics = HashMap::new();
212        metrics.insert("precision".to_string(), precision);
213        metrics.insert("recall".to_string(), recall);
214        metrics.insert("f1".to_string(), f1);
215        metrics.insert(
216            "num_conversations".to_string(),
217            self.num_conversations as f64,
218        );
219        metrics.insert(
220            "queries_per_conversation".to_string(),
221            self.queries_per_conversation as f64,
222        );
223        metrics.insert("true_positives".to_string(), true_positives as f64);
224        metrics.insert("false_positives".to_string(), false_positives as f64);
225        metrics.insert("false_negatives".to_string(), false_negatives as f64);
226
227        // Clean up temporary database file if not in-memory
228        if db_path != ":memory:" {
229            let bench_path = format!("{}.locomo_bench.db", db_path);
230            // Drop storage first to release file handles
231            drop(storage);
232            let _ = std::fs::remove_file(&bench_path);
233            let _ = std::fs::remove_file(format!("{}-wal", bench_path));
234            let _ = std::fs::remove_file(format!("{}-shm", bench_path));
235        }
236
237        Ok(BenchmarkResult {
238            name: self.name().to_string(),
239            metrics,
240            duration_ms,
241            timestamp: chrono::Utc::now().to_rfc3339(),
242        })
243    }
244}
245
246#[cfg(test)]
247mod tests {
248    use super::*;
249
250    #[test]
251    fn test_locomo_runs_in_memory() {
252        let bench = LocomoBenchmark {
253            num_conversations: 3,
254            queries_per_conversation: 2,
255        };
256        let result = bench.run(":memory:").expect("benchmark should succeed");
257        assert_eq!(result.name, "locomo");
258        assert!(result.metrics.contains_key("precision"));
259        assert!(result.metrics.contains_key("recall"));
260        assert!(result.metrics.contains_key("f1"));
261    }
262
263    #[test]
264    fn test_locomo_metrics_range() {
265        let bench = LocomoBenchmark {
266            num_conversations: 2,
267            queries_per_conversation: 1,
268        };
269        let result = bench.run(":memory:").expect("benchmark should succeed");
270        let precision = result.metrics["precision"];
271        let recall = result.metrics["recall"];
272        let f1 = result.metrics["f1"];
273
274        assert!(
275            (0.0..=1.0).contains(&precision),
276            "precision out of range: {}",
277            precision
278        );
279        assert!(
280            (0.0..=1.0).contains(&recall),
281            "recall out of range: {}",
282            recall
283        );
284        assert!((0.0..=1.0).contains(&f1), "f1 out of range: {}", f1);
285    }
286
287    #[test]
288    fn test_locomo_generates_correct_conversation_count() {
289        let bench = LocomoBenchmark {
290            num_conversations: 5,
291            queries_per_conversation: 2,
292        };
293        let conversations = bench.generate_conversations();
294        assert_eq!(conversations.len(), 5);
295        for conv in &conversations {
296            assert_eq!(conv.queries.len(), 2);
297            assert_eq!(conv.facts.len(), 3);
298        }
299    }
300}