Skip to main content

heartbit_core/memory/
mod.rs

//! Agent memory system — `Memory` trait, in-memory and PostgreSQL stores, BM25 and vector recall, Ebbinghaus decay, reflection, and consolidation.
2
3#![allow(missing_docs)]
4pub mod bm25;
5pub mod consolidation;
6pub mod embedding;
7pub mod hybrid;
8pub mod in_memory;
9pub mod namespaced;
10pub mod pruning;
11pub mod reflection;
12pub mod scoring;
13pub mod shared_tools;
14pub mod tools;
15
16use std::future::Future;
17use std::pin::Pin;
18
19use chrono::{DateTime, Utc};
20use serde::{Deserialize, Serialize};
21
22use crate::auth::TenantScope;
23use crate::error::Error;
24
25/// Classification of a memory entry's origin and purpose.
26#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
27#[serde(rename_all = "snake_case")]
28pub enum MemoryType {
29    /// Direct experience or observation from an agent run.
30    #[default]
31    Episodic,
32    /// Generalized knowledge derived from consolidation or reflection.
33    Semantic,
34    /// Higher-order insight generated by reflecting on episodic memories.
35    Reflection,
36}
37
38/// Access classification for memory entries.
39///
40/// Ordered from least to most sensitive — `PartialOrd`/`Ord` derives use
41/// variant declaration order, so `Public < Internal < Confidential < Restricted`.
42#[derive(
43    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Default,
44)]
45#[serde(rename_all = "snake_case")]
46pub enum Confidentiality {
47    /// Shareable with anyone (public facts, general knowledge).
48    #[default]
49    Public,
50    /// Internal context (work items, project details). Shareable with Verified+ senders.
51    Internal,
52    /// Personal/sensitive (expenses, health, private conversations). Owner only.
53    Confidential,
54    /// Secrets (API keys, passwords, tokens). Never included in LLM context.
55    Restricted,
56}
57
58/// A single memory entry stored by an agent.
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct MemoryEntry {
61    pub id: String,
62    pub agent: String,
63    pub content: String,
64    pub category: String,
65    pub tags: Vec<String>,
66    pub created_at: DateTime<Utc>,
67    pub last_accessed: DateTime<Utc>,
68    pub access_count: u32,
69    /// Importance score (1-10). Default: 5. Set by agent at store time.
70    #[serde(default = "default_importance")]
71    pub importance: u8,
72    /// Classification of memory origin (episodic, semantic, reflection).
73    #[serde(default)]
74    pub memory_type: MemoryType,
75    /// LLM-generated keywords for improved retrieval.
76    #[serde(default)]
77    pub keywords: Vec<String>,
78    /// One-sentence summary providing context for the memory content.
79    #[serde(default)]
80    pub summary: Option<String>,
81    /// Ebbinghaus strength score. Starts at 1.0, decays over time,
82    /// reinforced on access. Entries with low strength may be pruned.
83    ///
84    /// `Memory::store` preserves whatever value the caller supplies — no
85    /// normalisation or clamping at insert time. `Memory::recall` reinforces
86    /// the stored value by `+0.2` per access (capped at 1.0) unless the
87    /// caller opts out via [`MemoryQuery::reinforce`] = `false`. Decay is
88    /// applied lazily at read time via `effective_strength`.
89    #[serde(default = "default_strength")]
90    pub strength: f64,
91    /// Bidirectional links to related memory entries.
92    #[serde(default)]
93    pub related_ids: Vec<String>,
94    /// IDs of source entries that were consolidated into this one.
95    #[serde(default)]
96    pub source_ids: Vec<String>,
97    /// Optional vector embedding for semantic search (hybrid retrieval).
98    #[serde(default, skip_serializing_if = "Option::is_none")]
99    pub embedding: Option<Vec<f32>>,
100    /// Access classification — controls which trust levels may read this entry.
101    #[serde(default)]
102    pub confidentiality: Confidentiality,
103    /// User ID of the agent/user who authored this entry (multi-tenant authorship).
104    #[serde(default, skip_serializing_if = "Option::is_none")]
105    pub author_user_id: Option<String>,
106    /// Tenant ID of the agent/user who authored this entry (multi-tenant authorship).
107    #[serde(default, skip_serializing_if = "Option::is_none")]
108    pub author_tenant_id: Option<String>,
109}
110
/// Importance given to a [`MemoryEntry`] whose serialized form omits the
/// `importance` field (wired in through `#[serde(default = "...")]`).
///
/// Returns `5`, the midpoint of the documented 1-10 scale.
pub(crate) const fn default_importance() -> u8 {
    5
}
114
/// Strength given to a [`MemoryEntry`] whose serialized form omits the
/// `strength` field (wired in through `#[serde(default = "...")]`).
///
/// Returns `1.0`, i.e. a fully reinforced entry.
pub(crate) const fn default_strength() -> f64 {
    1.0
}
118
/// Default category label, `"fact"`.
///
/// NOTE(review): not referenced within this module — presumably consumed by
/// sibling modules in the crate; confirm before changing the value.
pub(crate) fn default_category() -> String {
    String::from("fact")
}
122
/// Default recall limit, `10`.
///
/// This is *not* what `MemoryQuery::default()` uses — that sets `limit` to
/// `0` (unlimited).
///
/// NOTE(review): not referenced within this module — presumably consumed by
/// sibling modules in the crate; confirm before changing the value.
pub(crate) const fn default_recall_limit() -> usize {
    10
}
126
127/// Query parameters for recalling memories.
128///
129/// `limit` controls the maximum number of results returned. A value of `0`
130/// means no limit (return all matching entries). This is the default.
131#[derive(Debug, Clone)]
132pub struct MemoryQuery {
133    pub text: Option<String>,
134    pub category: Option<String>,
135    pub tags: Vec<String>,
136    pub agent: Option<String>,
137    /// Filter entries whose `agent` field starts with this prefix.
138    /// Useful for cross-agent recall within a user namespace (e.g. `"tg:123"`
139    /// matches `"tg:123:assistant"`, `"tg:123:researcher"`, etc.).
140    /// Mutually exclusive with `agent` — if both are set, `agent` takes precedence.
141    pub agent_prefix: Option<String>,
142    /// Maximum number of results. `0` means unlimited.
143    pub limit: usize,
144    /// Filter by memory type.
145    pub memory_type: Option<MemoryType>,
146    /// Minimum strength threshold. Entries below this are excluded.
147    pub min_strength: Option<f64>,
148    /// Optional query embedding for hybrid (BM25 + vector) retrieval.
149    /// When present and entries have stored embeddings, cosine similarity
150    /// is computed and fused with BM25 via Reciprocal Rank Fusion.
151    /// Populated automatically by `EmbeddingMemory::recall()`.
152    pub query_embedding: Option<Vec<f32>>,
153    /// When set, recall excludes entries with confidentiality above this level.
154    /// `None` means no restriction (all levels returned).
155    pub max_confidentiality: Option<Confidentiality>,
156    /// Whether to reinforce the `strength` of returned entries on this read
157    /// (Ebbinghaus reinforcement, +0.2 per access, capped at 1.0).
158    ///
159    /// Defaults to `true` to preserve historical recall semantics. Set to
160    /// `false` for a pure read — useful when surfacing strength to a UI,
161    /// driving deterministic decay tests, or letting `prune_weak_entries`
162    /// observe a freshly-stored low-strength entry without first promoting
163    /// it above the prune threshold. `last_accessed` and `access_count`
164    /// are still updated regardless.
165    pub reinforce: bool,
166
167    /// Opt in to **exact-word** text matching instead of the default
168    /// substring (`word.contains(token)`) semantics.
169    ///
170    /// When `true`, `InMemoryStore::recall` short-circuits to entries
171    /// whose lowercased content / keyword tokens **exactly** equal at
172    /// least one query token, looked up via the in-memory inverted
173    /// index built at store time. Estimated gain at N=10k entries:
174    /// 12.69 ms → 1–3 ms text-query recall (Phase 8 in
175    /// `tasks/perf-audit-v2-2026-05-07.md`).
176    ///
177    /// Trade-off: queries whose tokens are *prefixes / substrings*
178    /// of indexed words ("perf" matching "performance") will no
179    /// longer match. Default is `false` — substring semantics
180    /// preserved; opt-in when callers know their queries are full
181    /// words. The Postgres path ignores this flag (it doesn't
182    /// implement substring matching the same way).
183    pub exact_words: bool,
184}
185
186impl Default for MemoryQuery {
187    fn default() -> Self {
188        Self {
189            text: None,
190            category: None,
191            tags: Vec::new(),
192            agent: None,
193            agent_prefix: None,
194            limit: 0,
195            memory_type: None,
196            min_strength: None,
197            query_embedding: None,
198            max_confidentiality: None,
199            reinforce: true,
200            exact_words: false,
201        }
202    }
203}
204
205/// Trait for persistent memory stores.
206///
207/// Every method requires a `&TenantScope` as the first parameter so the
208/// compiler rejects code that accidentally drops tenant context. Single-tenant
209/// deployments pass `&TenantScope::default()` (the empty-string sentinel).
210///
211/// Uses `Pin<Box<dyn Future>>` for dyn-compatibility, matching the `Tool` trait pattern.
212///
213/// # Example
214///
215/// Recalling memory entries with the in-memory backend:
216///
217/// ```rust,no_run
218/// use heartbit_core::auth::TenantScope;
219/// use heartbit_core::{InMemoryStore, Memory, MemoryQuery};
220///
221/// # async fn run() -> Result<(), heartbit_core::Error> {
222/// let store = InMemoryStore::new();
223/// let scope = TenantScope::default();
224/// let hits = store
225///     .recall(&scope, MemoryQuery {
226///         agent: Some("assistant".into()),
227///         text: Some("preferences".into()),
228///         ..MemoryQuery::default()
229///     })
230///     .await?;
231/// for entry in hits {
232///     println!("{}: {}", entry.id, entry.content);
233/// }
234/// # Ok(()) }
235/// ```
236pub trait Memory: Send + Sync {
237    /// Persist `entry` under `scope`.
238    ///
239    /// The caller-supplied [`MemoryEntry::strength`] is preserved verbatim;
240    /// implementations must not normalise or clamp it at insert time.
241    /// Reinforcement happens on read via [`Memory::recall`], not on write.
242    fn store(
243        &self,
244        scope: &TenantScope,
245        entry: MemoryEntry,
246    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>>;
247
248    /// Recall entries matching `query` from `scope`.
249    ///
250    /// By default, recall reinforces the `strength` of returned entries
251    /// (Ebbinghaus reinforcement, +0.2 per access, capped at 1.0). To
252    /// observe strength without modifying it, set
253    /// [`MemoryQuery::reinforce`] to `false`.
254    fn recall(
255        &self,
256        scope: &TenantScope,
257        query: MemoryQuery,
258    ) -> Pin<Box<dyn Future<Output = Result<Vec<MemoryEntry>, Error>> + Send + '_>>;
259
260    fn update(
261        &self,
262        scope: &TenantScope,
263        id: &str,
264        content: String,
265    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>>;
266
267    fn forget(
268        &self,
269        scope: &TenantScope,
270        id: &str,
271    ) -> Pin<Box<dyn Future<Output = Result<bool, Error>> + Send + '_>>;
272
273    /// Add a bidirectional link between two memory entries.
274    /// Default implementation is a no-op for backward compatibility.
275    fn add_link(
276        &self,
277        _scope: &TenantScope,
278        _id: &str,
279        _related_id: &str,
280    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
281        Box::pin(async { Ok(()) })
282    }
283
284    /// Remove entries whose strength has decayed below `min_strength`
285    /// and are older than `min_age`.
286    ///
287    /// When `agent_prefix` is `Some`, only entries whose `agent` field
288    /// starts with the given prefix are candidates for pruning. This
289    /// enables namespace-scoped pruning in multi-tenant setups where
290    /// `NamespacedMemory` must not delete entries from other namespaces.
291    ///
292    /// Returns the number of entries pruned.
293    /// Default implementation is a no-op for backward compatibility.
294    fn prune(
295        &self,
296        _scope: &TenantScope,
297        _min_strength: f64,
298        _min_age: chrono::Duration,
299        _agent_prefix: Option<&str>,
300    ) -> Pin<Box<dyn Future<Output = Result<usize, Error>> + Send + '_>> {
301        Box::pin(async { Ok(0) })
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline entry with neutral values; tests override individual fields
    /// with struct-update syntax (`MemoryEntry { x, ..make_entry(..) }`).
    fn make_entry(id: &str, content: &str) -> MemoryEntry {
        let now = Utc::now();
        MemoryEntry {
            id: id.into(),
            agent: "a".into(),
            content: content.into(),
            category: "fact".into(),
            tags: vec![],
            created_at: now,
            last_accessed: now,
            access_count: 0,
            importance: 5,
            memory_type: MemoryType::default(),
            keywords: vec![],
            summary: None,
            strength: 1.0,
            related_ids: vec![],
            source_ids: vec![],
            embedding: None,
            confidentiality: Confidentiality::default(),
            author_user_id: None,
            author_tenant_id: None,
        }
    }

    /// JSON for an entry that carries only the fields without a serde
    /// default, plus `importance` when one is supplied.
    fn legacy_json(importance: Option<u8>) -> String {
        let importance = importance
            .map(|i| format!(r#","importance":{i}"#))
            .unwrap_or_default();
        format!(
            r#"{{"id":"m1","agent":"a","content":"test","category":"fact","tags":[],"created_at":"2024-01-01T00:00:00Z","last_accessed":"2024-01-01T00:00:00Z","access_count":0{importance}}}"#
        )
    }

    #[test]
    fn memory_entry_serializes() {
        let entry = MemoryEntry {
            agent: "researcher".into(),
            tags: vec!["rust".into()],
            importance: 7,
            ..make_entry("m1", "Rust is fast")
        };
        let json = serde_json::to_string(&entry).unwrap();
        let parsed: MemoryEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "m1");
        assert_eq!(parsed.agent, "researcher");
        assert_eq!(parsed.content, "Rust is fast");
        assert_eq!(parsed.importance, 7);
    }

    #[test]
    fn memory_entry_serializes_new_fields() {
        let entry = MemoryEntry {
            importance: 7,
            memory_type: MemoryType::Reflection,
            keywords: vec!["rust".into(), "performance".into()],
            summary: Some("Rust is fast for systems programming".into()),
            strength: 0.85,
            related_ids: vec!["m2".into(), "m3".into()],
            source_ids: vec!["m0".into()],
            ..make_entry("m1", "test")
        };
        let json = serde_json::to_string(&entry).unwrap();
        let parsed: MemoryEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.memory_type, MemoryType::Reflection);
        assert_eq!(parsed.keywords, vec!["rust", "performance"]);
        assert_eq!(
            parsed.summary.as_deref(),
            Some("Rust is fast for systems programming")
        );
        assert!((parsed.strength - 0.85).abs() < f64::EPSILON);
        assert_eq!(parsed.related_ids, vec!["m2", "m3"]);
        assert_eq!(parsed.source_ids, vec!["m0"]);
    }

    #[test]
    fn memory_entry_deserialize_without_new_fields() {
        // Existing JSON without the new fields — backward compat
        let entry: MemoryEntry = serde_json::from_str(&legacy_json(Some(9))).unwrap();
        assert_eq!(entry.importance, 9);
        assert_eq!(entry.memory_type, MemoryType::Episodic);
        assert!(entry.keywords.is_empty());
        assert!(entry.summary.is_none());
        assert!((entry.strength - 1.0).abs() < f64::EPSILON);
        assert!(entry.related_ids.is_empty());
        assert!(entry.source_ids.is_empty());
        assert!(entry.author_user_id.is_none());
        assert!(entry.author_tenant_id.is_none());
    }

    #[test]
    fn memory_type_default_is_episodic() {
        assert_eq!(MemoryType::default(), MemoryType::Episodic);
    }

    #[test]
    fn strength_default_is_one() {
        assert!((default_strength() - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn memory_type_serialization_roundtrip() {
        for mt in [
            MemoryType::Episodic,
            MemoryType::Semantic,
            MemoryType::Reflection,
        ] {
            let json = serde_json::to_string(&mt).unwrap();
            let parsed: MemoryType = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, mt);
        }
    }

    #[test]
    fn memory_type_serializes_as_snake_case() {
        assert_eq!(
            serde_json::to_string(&MemoryType::Episodic).unwrap(),
            "\"episodic\""
        );
        assert_eq!(
            serde_json::to_string(&MemoryType::Semantic).unwrap(),
            "\"semantic\""
        );
        assert_eq!(
            serde_json::to_string(&MemoryType::Reflection).unwrap(),
            "\"reflection\""
        );
    }

    #[test]
    fn memory_entry_default_importance() {
        // Exercise the serde default itself, not just the fixture's literal.
        assert_eq!(default_importance(), 5);
        assert_eq!(make_entry("m1", "test").importance, default_importance());
    }

    #[test]
    fn memory_entry_deserialize_without_importance() {
        let entry: MemoryEntry = serde_json::from_str(&legacy_json(None)).unwrap();
        assert_eq!(entry.importance, 5); // default
    }

    #[test]
    fn memory_entry_deserialize_with_importance() {
        let entry: MemoryEntry = serde_json::from_str(&legacy_json(Some(9))).unwrap();
        assert_eq!(entry.importance, 9);
    }

    #[test]
    fn memory_query_default() {
        let q = MemoryQuery::default();
        assert!(q.text.is_none());
        assert!(q.category.is_none());
        assert!(q.tags.is_empty());
        assert!(q.agent.is_none());
        assert!(q.agent_prefix.is_none());
        assert_eq!(q.limit, 0);
        assert!(q.memory_type.is_none());
        assert!(q.min_strength.is_none());
        assert!(q.query_embedding.is_none());
        assert!(q.max_confidentiality.is_none());
        // Reinforcement is on and exact-word matching is off by default.
        assert!(q.reinforce);
        assert!(!q.exact_words);
    }

    #[test]
    fn memory_trait_is_object_safe() {
        // Verify Memory can be used as dyn trait (including new default methods)
        fn _accepts_dyn(_m: &dyn Memory) {}
    }

    #[test]
    fn memory_entry_embedding_serde_roundtrip() {
        let entry = MemoryEntry {
            embedding: Some(vec![0.1, 0.2, 0.3]),
            ..make_entry("m1", "test")
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"embedding\""));
        let parsed: MemoryEntry = serde_json::from_str(&json).unwrap();
        let emb = parsed.embedding.expect("embedding should survive the round trip");
        assert_eq!(emb.len(), 3);
        assert!((emb[0] - 0.1).abs() < f32::EPSILON);
    }

    #[test]
    fn memory_entry_backward_compat_no_embedding() {
        // Old JSON without embedding field — should deserialize to None
        let entry: MemoryEntry = serde_json::from_str(&legacy_json(Some(5))).unwrap();
        assert!(entry.embedding.is_none());
    }

    #[test]
    fn memory_entry_none_embedding_not_serialized() {
        // When embedding is None, field should be omitted from JSON
        let entry = make_entry("m1", "test");
        let json = serde_json::to_string(&entry).unwrap();
        assert!(!json.contains("embedding"));
    }

    #[test]
    fn memory_entry_none_author_fields_not_serialized() {
        // Both author fields use `skip_serializing_if = "Option::is_none"`.
        let entry = make_entry("m1", "test");
        let json = serde_json::to_string(&entry).unwrap();
        assert!(!json.contains("author_user_id"));
        assert!(!json.contains("author_tenant_id"));
    }

    #[test]
    fn confidentiality_default_is_public() {
        assert_eq!(Confidentiality::default(), Confidentiality::Public);
    }

    #[test]
    fn confidentiality_ordering() {
        assert!(Confidentiality::Public < Confidentiality::Internal);
        assert!(Confidentiality::Internal < Confidentiality::Confidential);
        assert!(Confidentiality::Confidential < Confidentiality::Restricted);
    }

    #[test]
    fn confidentiality_serde_roundtrip() {
        for c in [
            Confidentiality::Public,
            Confidentiality::Internal,
            Confidentiality::Confidential,
            Confidentiality::Restricted,
        ] {
            let json = serde_json::to_string(&c).unwrap();
            let parsed: Confidentiality = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, c);
        }
    }

    #[test]
    fn confidentiality_serializes_as_snake_case() {
        assert_eq!(
            serde_json::to_string(&Confidentiality::Public).unwrap(),
            "\"public\""
        );
        assert_eq!(
            serde_json::to_string(&Confidentiality::Internal).unwrap(),
            "\"internal\""
        );
        assert_eq!(
            serde_json::to_string(&Confidentiality::Confidential).unwrap(),
            "\"confidential\""
        );
        assert_eq!(
            serde_json::to_string(&Confidentiality::Restricted).unwrap(),
            "\"restricted\""
        );
    }

    #[test]
    fn memory_entry_backward_compat_no_confidentiality() {
        // Old JSON without confidentiality field — should deserialize as Public
        let entry: MemoryEntry = serde_json::from_str(&legacy_json(Some(5))).unwrap();
        assert_eq!(entry.confidentiality, Confidentiality::Public);
    }

    #[test]
    fn memory_query_max_confidentiality_default_is_none() {
        let q = MemoryQuery::default();
        assert!(q.max_confidentiality.is_none());
    }
}