brainos-hippocampus 0.3.0

Episodic and semantic memory engine with hybrid search for Brain OS
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
//! Recall engine — hybrid search with RRF fusion.
//!
//! Combines episodic BM25 search and semantic vector search
//! using Reciprocal Rank Fusion (RRF), then applies importance
//! and recency reranking with a forgetting curve.

use std::collections::HashMap;

use crate::episodic::{EpisodicStore, FtsResult};
use crate::semantic::{SemanticResult, SemanticStore};

/// A unified memory result from the recall engine.
///
/// Both episodic search hits and semantic facts are converted into this
/// shape by [`RecallEngine::recall`] so callers can treat them uniformly.
#[derive(Debug, Clone)]
pub struct Memory {
    /// Identifier of the underlying record: the episode id for episodic hits,
    /// the fact id for semantic hits.
    pub id: String,
    /// Text of the memory. Episodic hits carry the stored content; semantic
    /// hits are rendered as `"subject predicate object"`.
    pub content: String,
    /// Which store produced this memory.
    pub source: MemorySource,
    /// Final ranking score (fused RRF score plus the importance and recency
    /// terms). Higher is more relevant.
    pub score: f64,
    /// Importance used for reranking: the episode's importance for episodic
    /// hits, the fact's confidence for semantic hits.
    pub importance: f64,
    /// Timestamp string of the underlying record, copied through unparsed.
    pub timestamp: String,
    /// Originating agent that stored this memory (if known).
    pub agent: Option<String>,
}

/// Where this memory came from.
#[derive(Debug, Clone, PartialEq)]
pub enum MemorySource {
    /// Produced by the BM25 full-text search over the episodic store.
    Episodic,
    /// Produced by the vector (ANN) search over the semantic store.
    Semantic,
}

/// Configuration for the recall engine.
///
/// None of the documented ranges are validated here; callers are expected to
/// supply sensible values.
#[derive(Debug, Clone)]
pub struct RecallConfig {
    /// RRF constant `k` in `1 / (k + rank)` (default: 60).
    pub rrf_k: f64,
    /// How many candidates to fetch from each source before fusion.
    pub pre_fusion_limit: usize,
    /// Weight for importance in final reranking (intended range 0.0–1.0).
    pub importance_weight: f64,
    /// Weight for recency (forgetting-curve retention) in final reranking
    /// (intended range 0.0–1.0).
    pub recency_weight: f64,
    /// Decay rate for the forgetting curve, applied per hour of age
    /// (higher = faster decay).
    pub decay_rate: f64,
    /// Minimum similarity score for semantic results (0.0–1.0).
    /// Similarity is derived from the ANN distance as `1 / (1 + distance)`;
    /// results with similarity below this threshold are discarded before fusion.
    pub similarity_threshold: f64,
}

impl RecallConfig {
    /// Assemble a [`RecallConfig`] from plain scalar settings.
    ///
    /// Taking primitives (rather than a config struct from `brain_core`)
    /// keeps this crate free of a cross-crate dependency. The two integer
    /// settings are widened to the types the engine works with internally;
    /// the floating-point settings are stored as given, without validation.
    pub fn from_config(
        rrf_k: u32,
        pre_fusion_limit: u32,
        importance_weight: f64,
        recency_weight: f64,
        decay_rate: f64,
        similarity_threshold: f64,
    ) -> Self {
        // Shadow the integer inputs with their internal representations so
        // the struct literal below can use field-init shorthand throughout.
        let rrf_k = f64::from(rrf_k);
        let pre_fusion_limit = pre_fusion_limit as usize;

        Self {
            rrf_k,
            pre_fusion_limit,
            importance_weight,
            recency_weight,
            decay_rate,
            similarity_threshold,
        }
    }
}

impl Default for RecallConfig {
    /// Default tuning: RRF `k` of 60, 50 candidates per source, importance
    /// weight 0.3, recency weight 0.2, decay rate 0.01 and a similarity
    /// threshold of 0.65.
    fn default() -> Self {
        let rrf_k = 60;
        let pre_fusion_limit = 50;
        let importance_weight = 0.3;
        let recency_weight = 0.2;
        let decay_rate = 0.01;
        let similarity_threshold = 0.65;

        // Route through the scalar constructor so there is a single place
        // that knows how the fields are populated.
        Self::from_config(
            rrf_k,
            pre_fusion_limit,
            importance_weight,
            recency_weight,
            decay_rate,
            similarity_threshold,
        )
    }
}

/// Reciprocal Rank Fusion (RRF) algorithm.
///
/// Given multiple ranked lists, produces a single fused ranking.
/// Score for item i = Σ (1 / (k + rank_i)) across all lists, where `rank_i`
/// is the 1-based position of the item in a list.
///
/// # Parameters
/// - `ranked_lists`: one list per source, each already ordered best-first.
///   Only the position of an entry matters; the `f64` carried alongside each
///   id (the source's own score) is ignored. An id that occurs more than once
///   contributes once per occurrence.
/// - `k`: the RRF smoothing constant.
///
/// # Returns
/// `(id, fused_score)` pairs sorted by fused score, highest first. Entries
/// with equal scores are ordered by id ascending, so the output is fully
/// deterministic regardless of hash-map iteration order.
pub fn rrf_fuse(ranked_lists: &[Vec<(String, f64)>], k: f64) -> Vec<(String, f64)> {
    // Borrowed keys: an id is only cloned once, when the result is built,
    // instead of once per occurrence.
    let mut scores: HashMap<&str, f64> = HashMap::new();

    for list in ranked_lists {
        for (rank, (id, _original_score)) in list.iter().enumerate() {
            let rrf_score = 1.0 / (k + (rank as f64 + 1.0));
            *scores.entry(id.as_str()).or_default() += rrf_score;
        }
    }

    let mut fused: Vec<(String, f64)> = scores
        .into_iter()
        .map(|(id, score)| (id.to_owned(), score))
        .collect();

    // `total_cmp` is a total order even if a score is NaN (e.g. a NaN `k`),
    // and the id tie-break removes any dependence on HashMap iteration order.
    // With a total order and unique ids an unstable sort is safe.
    fused.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    fused
}

/// Retention of a memory under a simplified exponential forgetting curve.
///
/// `retention = importance * e^(-decay_rate * hours_since_access)`
///
/// Retention equals `importance` at zero elapsed hours and shrinks
/// exponentially as `hours_since_access` grows (for a positive `decay_rate`).
/// The result is linear in `importance`.
pub fn forgetting_curve(importance: f64, hours_since_access: f64, decay_rate: f64) -> f64 {
    let decay_factor = f64::exp(-decay_rate * hours_since_access);
    importance * decay_factor
}

/// The recall engine orchestrates memory retrieval.
///
/// It queries both episodic (BM25) and semantic (vector) stores,
/// fuses results with RRF, and reranks by importance + recency.
/// The engine itself only holds configuration; the stores are passed in
/// on every call to `recall`.
pub struct RecallEngine {
    /// Tuning parameters applied to every recall query.
    config: RecallConfig,
}

impl RecallEngine {
    /// Create a recall engine that applies `config` to every query.
    pub fn new(config: RecallConfig) -> Self {
        Self { config }
    }

    /// Create a recall engine using [`RecallConfig::default`].
    pub fn with_defaults() -> Self {
        Self::new(RecallConfig::default())
    }

    /// Recall memories relevant to a query.
    ///
    /// Pipeline:
    /// 1. Query the episodic store (BM25 full-text search), fetching up to
    ///    `pre_fusion_limit` candidates.
    /// 2. Query the semantic store (ANN vector search) with the same limit,
    ///    convert each distance to a similarity `1 / (1 + distance)` and drop
    ///    candidates below `similarity_threshold`.
    /// 3. Fuse both rankings with Reciprocal Rank Fusion using the configured
    ///    `rrf_k`.
    /// 4. Rerank: `score = rrf + importance_weight * importance
    ///    + recency_weight * retention`, where `retention` comes from
    ///    [`forgetting_curve`] applied to the record's age in hours.
    /// 5. Sort by score (highest first, ties broken by id) and return the
    ///    first `top_k` results.
    ///
    /// `namespace` and `agent` are forwarded unchanged to both stores as
    /// optional filters. If an id appears in both stores' results, the
    /// episodic record is the one returned.
    ///
    /// # Errors
    /// Returns [`RecallError::Episodic`] if the BM25 search fails and
    /// [`RecallError::Semantic`] if the vector search fails.
    // The wide parameter list is part of the public call signature and is
    // kept as-is for existing callers.
    #[allow(clippy::too_many_arguments)]
    pub async fn recall(
        &self,
        query: &str,
        query_vector: Vec<f32>,
        episodic: &EpisodicStore,
        semantic: &SemanticStore,
        top_k: usize,
        namespace: Option<&str>,
        agent: Option<&str>,
    ) -> Result<Vec<Memory>, RecallError> {
        let limit = self.config.pre_fusion_limit;

        // 1. BM25 search on episodic store. The results are used in the
        //    order returned; only that order matters to the fusion step.
        let bm25_results = episodic
            .search_bm25(query, limit, namespace, agent)
            .map_err(RecallError::Episodic)?;

        let bm25_ranked: Vec<(String, f64)> = bm25_results
            .iter()
            .map(|r| (r.episode_id.clone(), r.rank))
            .collect();

        // 2. ANN search on semantic store (filtered by namespace and/or agent if provided)
        let ann_results = semantic
            .search_similar(query_vector, limit, namespace, agent)
            .await
            .map_err(RecallError::Semantic)?;

        // Convert distance to similarity and filter by threshold.
        // distance is L2; similarity = 1/(1+d). Higher = more similar.
        let threshold = self.config.similarity_threshold;
        let ann_ranked: Vec<(String, f64)> = ann_results
            .iter()
            .map(|r| (r.fact.id.clone(), 1.0 / (1.0 + r.distance as f64)))
            .filter(|(_, sim)| *sim >= threshold)
            .collect();

        // 3. RRF fusion
        let fused = rrf_fuse(&[bm25_ranked, ann_ranked], self.config.rrf_k);

        // Lookup maps so each fused id is resolved in O(1) instead of by a
        // linear scan over the raw results.
        let bm25_map: HashMap<&str, &FtsResult> = bm25_results
            .iter()
            .map(|r| (r.episode_id.as_str(), r))
            .collect();
        let ann_map: HashMap<&str, &SemanticResult> = ann_results
            .iter()
            .map(|r| (r.fact.id.as_str(), r))
            .collect();

        // 4. Build Memory objects with their reranked scores. A single `now`
        //    is used so every candidate is aged against the same instant.
        let now = chrono::Utc::now();
        let mut memories: Vec<Memory> = Vec::with_capacity(fused.len());

        // `fused` is consumed so each id is moved into its Memory, not cloned.
        for (id, rrf_score) in fused {
            if let Some(fts) = bm25_map.get(id.as_str()) {
                // Episodic takes precedence when both stores know this id.
                let importance = fts.importance;
                let score = self.rerank_score(rrf_score, importance, &fts.timestamp, &now);

                memories.push(Memory {
                    id,
                    content: fts.content.clone(),
                    source: MemorySource::Episodic,
                    score,
                    importance,
                    timestamp: fts.timestamp.clone(),
                    agent: fts.agent.clone(),
                });
            } else if let Some(sr) = ann_map.get(id.as_str()) {
                // Semantic facts use their confidence as importance.
                let importance = sr.fact.confidence;
                let score = self.rerank_score(rrf_score, importance, &sr.created_at, &now);

                let content = format!(
                    "{} {} {}",
                    sr.fact.subject, sr.fact.predicate, sr.fact.object
                );

                memories.push(Memory {
                    id,
                    content,
                    source: MemorySource::Semantic,
                    score,
                    importance,
                    timestamp: sr.created_at.clone(),
                    agent: sr.fact.agent.clone(),
                });
            }
        }

        // 5. Sort by final score descending. `total_cmp` gives a total order
        //    even if a score is NaN, and the id tie-break makes the result
        //    independent of the order in which tied candidates were fused.
        memories.sort_by(|a, b| {
            b.score
                .total_cmp(&a.score)
                .then_with(|| a.id.cmp(&b.id))
        });
        memories.truncate(top_k);

        Ok(memories)
    }

    /// Final ranking score for one candidate: the fused RRF score plus the
    /// weighted importance and the weighted forgetting-curve retention for
    /// the record's age at `now`.
    fn rerank_score(
        &self,
        rrf_score: f64,
        importance: f64,
        timestamp: &str,
        now: &chrono::DateTime<chrono::Utc>,
    ) -> f64 {
        let hours = parse_elapsed_hours(timestamp, now);
        let retention = forgetting_curve(importance, hours, self.config.decay_rate);
        rrf_score
            + self.config.importance_weight * importance
            + self.config.recency_weight * retention
    }
}

/// Hours elapsed between `timestamp` and `now`.
///
/// `timestamp` may be RFC 3339 (e.g. `2025-03-01T12:00:00+00:00`) or the
/// SQLite datetime form (e.g. `2025-03-01 12:00:00`), which is interpreted
/// as UTC. The result is computed from whole seconds and is never smaller
/// than 0.01, so timestamps at or after `now` yield 0.01.
///
/// An empty or unparseable timestamp yields a fallback of 1.0 hour: mild
/// decay that neither boosts the memory (as 0.0 would) nor penalizes it
/// heavily. Both fallback paths log a warning so bad timestamps stay visible.
fn parse_elapsed_hours(timestamp: &str, now: &chrono::DateTime<chrono::Utc>) -> f64 {
    if timestamp.is_empty() {
        tracing::warn!("Empty timestamp in recall — using 1.0h fallback");
        return 1.0;
    }

    // RFC 3339 is tried first; the SQLite format is only attempted when
    // that fails. Either way the instant is normalized to UTC.
    let parsed = chrono::DateTime::parse_from_rfc3339(timestamp)
        .map(|dt| dt.with_timezone(&chrono::Utc))
        .or_else(|_| {
            chrono::NaiveDateTime::parse_from_str(timestamp, "%Y-%m-%d %H:%M:%S")
                .map(|naive| naive.and_utc())
        });

    match parsed {
        Ok(recorded_at) => {
            let elapsed_seconds = (*now - recorded_at).num_seconds();
            (elapsed_seconds as f64 / 3600.0).max(0.01)
        }
        Err(_) => {
            tracing::warn!(
                timestamp,
                "Unparseable timestamp in recall — using 1.0h fallback"
            );
            1.0
        }
    }
}

/// Errors from the recall engine.
///
/// Each variant wraps the error of the store whose search failed; the
/// wrapped error's message is included in this error's `Display` output.
#[derive(Debug, thiserror::Error)]
pub enum RecallError {
    /// The BM25 search against the episodic store failed.
    #[error("Episodic search failed: {0}")]
    Episodic(crate::episodic::EpisodicError),

    /// The vector (ANN) search against the semantic store failed.
    #[error("Semantic search failed: {0}")]
    Semantic(crate::semantic::SemanticError),
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed reference instant (2025-03-01 14:00:00 UTC) so elapsed-time
    /// assertions do not depend on the wall clock.
    fn fixed_now() -> chrono::DateTime<chrono::Utc> {
        chrono::DateTime::parse_from_rfc3339("2025-03-01T14:00:00+00:00")
            .expect("valid RFC 3339 literal")
            .with_timezone(&chrono::Utc)
    }

    #[test]
    fn test_rrf_single_list() {
        let lists = vec![vec![
            ("a".to_string(), 10.0),
            ("b".to_string(), 5.0),
            ("c".to_string(), 1.0),
        ]];

        let fused = rrf_fuse(&lists, 60.0);
        assert_eq!(fused[0].0, "a");
        assert_eq!(fused[1].0, "b");
        assert_eq!(fused[2].0, "c");

        // rank 1: 1/(60+1) ≈ 0.01639
        assert!((fused[0].1 - 1.0 / 61.0).abs() < 1e-6);
    }

    #[test]
    fn test_rrf_two_lists() {
        let lists = vec![
            vec![("a".to_string(), 10.0), ("b".to_string(), 5.0)],
            vec![("b".to_string(), 10.0), ("a".to_string(), 5.0)],
        ];

        let fused = rrf_fuse(&lists, 60.0);

        // Both a and b appear at rank 1 and rank 2 in different lists
        // Both should have score 1/61 + 1/62
        assert_eq!(fused.len(), 2);
        let score_a = fused.iter().find(|(id, _)| id == "a").unwrap().1;
        let score_b = fused.iter().find(|(id, _)| id == "b").unwrap().1;
        assert!((score_a - score_b).abs() < 1e-10);
        assert!((score_a - (1.0 / 61.0 + 1.0 / 62.0)).abs() < 1e-10);
    }

    #[test]
    fn test_rrf_disjoint_lists() {
        let lists = vec![vec![("a".to_string(), 10.0)], vec![("b".to_string(), 10.0)]];

        let fused = rrf_fuse(&lists, 60.0);
        assert_eq!(fused.len(), 2);
        // Both at rank 1 in their respective lists
        let score_a = fused.iter().find(|(id, _)| id == "a").unwrap().1;
        let score_b = fused.iter().find(|(id, _)| id == "b").unwrap().1;
        assert!((score_a - score_b).abs() < 1e-10);
    }

    #[test]
    fn test_rrf_overlap_boost() {
        let lists = vec![
            vec![
                ("a".to_string(), 10.0),
                ("b".to_string(), 5.0),
                ("c".to_string(), 1.0),
            ],
            vec![("a".to_string(), 10.0), ("c".to_string(), 5.0)],
        ];

        let fused = rrf_fuse(&lists, 60.0);

        // 'a' appears at rank 1 in both lists → highest score
        assert_eq!(fused[0].0, "a");

        // 'c' appears in both lists (rank 3 + rank 2) → higher than 'b' (rank 2 only)
        let score_b = fused.iter().find(|(id, _)| id == "b").unwrap().1;
        let score_c = fused.iter().find(|(id, _)| id == "c").unwrap().1;
        assert!(score_c > score_b, "c (in both) should rank > b (in one)");
    }

    #[test]
    fn test_forgetting_curve_no_decay() {
        let retention = forgetting_curve(1.0, 0.0, 0.01);
        assert!((retention - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_forgetting_curve_decay() {
        let retention_1h = forgetting_curve(1.0, 1.0, 0.01);
        let retention_24h = forgetting_curve(1.0, 24.0, 0.01);
        let retention_168h = forgetting_curve(1.0, 168.0, 0.01); // 1 week

        // Retention should decrease over time
        assert!(retention_1h > retention_24h);
        assert!(retention_24h > retention_168h);

        // At the same age, higher importance yields higher retention
        // (importance scales the curve; it does not change the decay rate).
        let retention_high = forgetting_curve(1.0, 24.0, 0.01);
        let retention_low = forgetting_curve(0.5, 24.0, 0.01);
        assert!(retention_high > retention_low);
    }

    #[test]
    fn test_forgetting_curve_importance_scaling() {
        let ret_a = forgetting_curve(1.0, 10.0, 0.01);
        let ret_b = forgetting_curve(0.5, 10.0, 0.01);
        // ret_a should be exactly 2x ret_b (linear in importance)
        assert!((ret_a / ret_b - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_rrf_empty_lists() {
        let fused = rrf_fuse(&[], 60.0);
        assert!(fused.is_empty());

        let fused2 = rrf_fuse(&[vec![]], 60.0);
        assert!(fused2.is_empty());
    }

    #[test]
    fn test_recall_config_defaults() {
        let config = RecallConfig::default();
        assert_eq!(config.rrf_k, 60.0);
        assert_eq!(config.pre_fusion_limit, 50);
        assert!((config.importance_weight - 0.3).abs() < 1e-6);
        assert!((config.recency_weight - 0.2).abs() < 1e-6);
        assert!((config.decay_rate - 0.01).abs() < 1e-6);
        assert!((config.similarity_threshold - 0.65).abs() < 1e-6);
    }

    #[test]
    fn test_recall_config_from_config() {
        let config = RecallConfig::from_config(10, 5, 0.4, 0.1, 0.02, 0.5);
        assert_eq!(config.rrf_k, 10.0);
        assert_eq!(config.pre_fusion_limit, 5);
        assert!((config.importance_weight - 0.4).abs() < 1e-6);
        assert!((config.recency_weight - 0.1).abs() < 1e-6);
        assert!((config.decay_rate - 0.02).abs() < 1e-6);
        assert!((config.similarity_threshold - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_parse_elapsed_hours_rfc3339() {
        let now = fixed_now();
        let hours = parse_elapsed_hours("2025-03-01T12:00:00+00:00", &now);
        assert!((hours - 2.0).abs() < 1e-9);

        // A non-UTC offset is normalized: 14:00+02:00 is 12:00 UTC.
        let hours_offset = parse_elapsed_hours("2025-03-01T14:00:00+02:00", &now);
        assert!((hours_offset - 2.0).abs() < 1e-9);
    }

    #[test]
    fn test_parse_elapsed_hours_sqlite_format() {
        let now = fixed_now();
        let hours = parse_elapsed_hours("2025-03-01 12:00:00", &now);
        assert!((hours - 2.0).abs() < 1e-9);
    }

    #[test]
    fn test_parse_elapsed_hours_future_is_clamped() {
        let now = fixed_now();
        let hours = parse_elapsed_hours("2025-03-01T15:00:00+00:00", &now);
        assert!((hours - 0.01).abs() < 1e-12);
    }

    #[test]
    fn test_parse_elapsed_hours_fallback() {
        let now = fixed_now();
        assert!((parse_elapsed_hours("", &now) - 1.0).abs() < 1e-12);
        assert!((parse_elapsed_hours("not a timestamp", &now) - 1.0).abs() < 1e-12);
    }
}