use std::collections::HashMap;
use std::sync::RwLock;
use anyhow::Result;
use uuid::Uuid;
use second_brain_core::schema::{
Conversation, Entity, Memory, MemoryType, Relation, RelationType,
};
use second_brain_core::store::Store;
/// Cache key: the node that was queried plus the optional relation-type filter
/// that accompanied the query.
type RelationKey = (Uuid, Option<RelationType>);

/// Memoized results of [`Store::get_relations`], one entry per distinct query key.
type RelationsCache = HashMap<RelationKey, Vec<Relation>>;

/// A [`Store`] wrapper that memoizes `get_relations` lookups in memory.
///
/// Other `Store` methods delegate to `inner`.
pub struct CachingStore<'a> {
    /// Backing store that answers every query the cache cannot.
    inner: &'a dyn Store,
    /// Memoized relation lookups; wrapped in an `RwLock` so the cache can be
    /// filled through `&self`.
    relations_cache: RwLock<RelationsCache>,
}
impl<'a> CachingStore<'a> {
pub fn new(inner: &'a dyn Store) -> Self {
Self {
inner,
relations_cache: RwLock::new(HashMap::new()),
}
}
pub fn prewarm(&self) -> Result<()> {
let scored_types = [
RelationType::Reinforces,
RelationType::RelatesTo,
RelationType::DistilledFrom,
RelationType::Mentions,
RelationType::DerivedFrom,
RelationType::Contradicts,
RelationType::Supersedes,
];
let ids = self.inner.all_memory_ids()?;
let mut cache = self.relations_cache.write().unwrap();
cache.reserve(ids.len() * scored_types.len());
for id in &ids {
for rt in &scored_types {
cache.entry((*id, Some(*rt))).or_default();
}
}
for rel in self.inner.all_relations()? {
if let Some(bucket) = cache.get_mut(&(rel.from_id, Some(rel.relation_type))) {
bucket.push(rel);
}
}
Ok(())
}
}
impl Store for CachingStore<'_> {
    /// Returns the relations for `node_id`, serving repeat queries from the cache.
    ///
    /// On a miss, the result of `inner.get_relations` is cached under the exact
    /// `(node_id, relation_type)` key. A `None` filter and each `Some(rt)` filter are
    /// therefore cached independently. Errors from `inner` are returned and never cached.
    fn get_relations(
        &self,
        node_id: Uuid,
        relation_type: Option<RelationType>,
    ) -> Result<Vec<Relation>> {
        let key = (node_id, relation_type);
        // The read guard is a temporary of this statement, so it is released before
        // the inner fetch and the write below.
        let cached = self
            .relations_cache
            .read()
            .expect("relations cache lock poisoned")
            .get(&key)
            .cloned();
        if let Some(hit) = cached {
            return Ok(hit);
        }
        let fetched = self.inner.get_relations(node_id, relation_type)?;
        self.relations_cache
            .write()
            .expect("relations cache lock poisoned")
            .insert(key, fetched.clone());
        Ok(fetched)
    }
    fn store_memory(&self, memory: &Memory) -> Result<()> {
        self.inner.store_memory(memory)
    }
    fn get_memory(&self, id: Uuid) -> Result<Option<Memory>> {
        self.inner.get_memory(id)
    }
    /// Deletes the memory in `inner`, then drops every cached relation lookup.
    ///
    /// This code cannot tell whether the inner store also removes relations that touch
    /// `id`, so the cache is cleared conservatively to avoid serving stale edges. The
    /// cache is cleared even if the delete fails, since a partial write is possible.
    fn delete_memory(&self, id: Uuid) -> Result<()> {
        let result = self.inner.delete_memory(id);
        self.relations_cache
            .write()
            .expect("relations cache lock poisoned")
            .clear();
        result
    }
    fn store_entity(&self, entity: &Entity) -> Result<()> {
        self.inner.store_entity(entity)
    }
    fn get_entity(&self, id: Uuid) -> Result<Option<Entity>> {
        self.inner.get_entity(id)
    }
    fn find_entity_by_name(&self, name: &str) -> Result<Option<Entity>> {
        self.inner.find_entity_by_name(name)
    }
    fn store_conversation(&self, conversation: &Conversation) -> Result<()> {
        self.inner.store_conversation(conversation)
    }
    /// Writes `relation` to `inner`, then evicts every cached lookup it could affect.
    ///
    /// Eviction happens after the write, whether or not it succeeded, so the next
    /// `get_relations` call re-reads from `inner`. Keys for `to_id` are evicted as well,
    /// in case the inner store also reports incoming edges (harmless if it does not).
    fn store_relation(&self, relation: &Relation) -> Result<()> {
        let result = self.inner.store_relation(relation);
        let mut cache = self
            .relations_cache
            .write()
            .expect("relations cache lock poisoned");
        for node in [relation.from_id, relation.to_id] {
            cache.remove(&(node, None));
            cache.remove(&(node, Some(relation.relation_type)));
        }
        result
    }
    fn vector_search(&self, embedding: &[f32], limit: usize) -> Result<Vec<(Memory, f32)>> {
        self.inner.vector_search(embedding, limit)
    }
    fn traverse(&self, start_id: Uuid, depth: u32) -> Result<Vec<(Memory, Vec<Relation>)>> {
        self.inner.traverse(start_id, depth)
    }
    fn memories_by_source(&self, source: &str) -> Result<Vec<Memory>> {
        self.inner.memories_by_source(source)
    }
    fn memories_by_type(&self, memory_type: MemoryType) -> Result<Vec<Memory>> {
        self.inner.memories_by_type(memory_type)
    }
    fn memories_needing_decay(&self, threshold_days: u32) -> Result<Vec<Memory>> {
        self.inner.memories_needing_decay(threshold_days)
    }
    fn update_memory(&self, memory: &Memory) -> Result<()> {
        self.inner.update_memory(memory)
    }
    fn record_access(&self, memory: &Memory) -> Result<()> {
        self.inner.record_access(memory)
    }
    fn text_search(&self, query: &str, limit: usize) -> Result<Vec<Memory>> {
        self.inner.text_search(query, limit)
    }
    fn memory_count(&self) -> Result<usize> {
        self.inner.memory_count()
    }
    fn all_memory_ids(&self) -> Result<Vec<Uuid>> {
        self.inner.all_memory_ids()
    }
    fn all_relations(&self) -> Result<Vec<Relation>> {
        self.inner.all_relations()
    }
}
// Equivalence tests: recall through `CachingStore` (cold, warm, and prewarmed) must
// rank and score exactly like recall against the raw store it wraps.
#[cfg(test)]
mod tests {
use super::*;
use chrono::{Duration, Utc};
use second_brain_core::query::{QueryEngine, QueryFilters, QueryRequest};
use second_brain_core::schema::MemoryType;
/// Fixture-backed `Store` stub.
///
/// Only the methods recall and prewarm need are implemented. The rest are
/// `unimplemented!()` and panic if called.
struct InMemoryStore {
// Canned results returned verbatim by `vector_search`.
vector_results: Vec<(Memory, f32)>,
// Full relation set, filtered on demand by `get_relations`.
relations: Vec<Relation>,
// Number of `get_relations` calls, so tests can check whether the cache was bypassed.
get_relations_calls: std::sync::atomic::AtomicUsize,
}
impl Store for InMemoryStore {
fn vector_search(&self, _embedding: &[f32], _limit: usize) -> Result<Vec<(Memory, f32)>> {
Ok(self.vector_results.clone())
}
// Counts the call, then returns the relations whose `from_id` matches, further
// filtered by `relation_type` when one is given.
fn get_relations(
&self,
node_id: Uuid,
relation_type: Option<RelationType>,
) -> Result<Vec<Relation>> {
self.get_relations_calls
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Ok(self
.relations
.iter()
.filter(|r| r.from_id == node_id)
.filter(|r| relation_type.map(|rt| rt == r.relation_type).unwrap_or(true))
.cloned()
.collect())
}
fn store_memory(&self, _m: &Memory) -> Result<()> {
unimplemented!()
}
fn get_memory(&self, _id: Uuid) -> Result<Option<Memory>> {
unimplemented!()
}
fn delete_memory(&self, _id: Uuid) -> Result<()> {
unimplemented!()
}
fn store_entity(&self, _e: &Entity) -> Result<()> {
unimplemented!()
}
fn get_entity(&self, _id: Uuid) -> Result<Option<Entity>> {
unimplemented!()
}
fn find_entity_by_name(&self, _name: &str) -> Result<Option<Entity>> {
unimplemented!()
}
fn store_conversation(&self, _c: &Conversation) -> Result<()> {
unimplemented!()
}
fn store_relation(&self, _r: &Relation) -> Result<()> {
unimplemented!()
}
fn traverse(&self, _id: Uuid, _depth: u32) -> Result<Vec<(Memory, Vec<Relation>)>> {
unimplemented!()
}
fn memories_by_source(&self, _s: &str) -> Result<Vec<Memory>> {
unimplemented!()
}
fn memories_by_type(&self, _mt: MemoryType) -> Result<Vec<Memory>> {
unimplemented!()
}
fn memories_needing_decay(&self, _days: u32) -> Result<Vec<Memory>> {
unimplemented!()
}
fn update_memory(&self, _m: &Memory) -> Result<()> {
unimplemented!()
}
fn record_access(&self, _memory: &Memory) -> Result<()> {
unimplemented!()
}
fn text_search(&self, _q: &str, _limit: usize) -> Result<Vec<Memory>> {
unimplemented!()
}
fn memory_count(&self) -> Result<usize> {
unimplemented!()
}
// The memory ids are exactly those of the canned vector results.
fn all_memory_ids(&self) -> Result<Vec<Uuid>> {
Ok(self.vector_results.iter().map(|(m, _)| m.id).collect())
}
// Bulk read used by `prewarm`; deliberately does not touch `get_relations_calls`.
fn all_relations(&self) -> Result<Vec<Relation>> {
Ok(self.relations.clone())
}
}
/// Builds a `Semantic` memory whose `created_at` and `last_accessed` are both
/// backdated by `days_old` days from now.
fn memory(content: &str, days_old: i64) -> Memory {
let when = Utc::now() - Duration::days(days_old);
let mut m = Memory::new(
content.to_string(),
MemoryType::Semantic,
"test".to_string(),
String::new(),
);
m.created_at = when;
m.last_accessed = when;
m
}
/// Three memories of different ages with descending vector scores, plus five
/// relations going out of them to fresh random target ids.
fn fixture() -> (Vec<(Memory, f32)>, Vec<Relation>) {
let a = memory("kuzu was chosen as the embedded graph store", 10);
let b = memory("sync runs bidirectionally over ssh", 40);
let c = memory("embeddings use the bge model", 5);
let rel = |from: Uuid, rt: RelationType, strength: f32| Relation {
from_id: from,
to_id: Uuid::new_v4(),
relation_type: rt,
strength,
context: None,
};
let relations = vec![
rel(a.id, RelationType::Reinforces, 1.0),
rel(a.id, RelationType::RelatesTo, 0.7),
rel(b.id, RelationType::Mentions, 1.0),
rel(c.id, RelationType::RelatesTo, 0.4),
rel(c.id, RelationType::Supersedes, 1.0),
];
let vector_results = vec![(a, 0.91), (b, 0.78), (c, 0.66)];
(vector_results, relations)
}
/// A fixed query: constant 384-element embedding, limit 10, default filters.
fn request() -> QueryRequest {
QueryRequest {
text: "graph store choice".to_string(),
embedding: vec![0.1_f32; 384],
limit: 10,
filters: QueryFilters::default(),
}
}
// A cold-cache recall must match the raw store exactly, and a second recall
// through the same `CachingStore` must match the first.
#[test]
fn caching_store_recall_matches_raw_store() {
let (vector_results, relations) = fixture();
let raw = InMemoryStore {
vector_results,
relations,
get_relations_calls: std::sync::atomic::AtomicUsize::new(0),
};
let baseline = QueryEngine::new(&raw).recall(&request()).unwrap();
let cached = CachingStore::new(&raw);
let first = QueryEngine::new(&cached).recall(&request()).unwrap();
let second = QueryEngine::new(&cached).recall(&request()).unwrap();
assert_eq!(baseline.len(), first.len());
assert_eq!(first.len(), second.len());
for (b, c) in baseline.iter().zip(first.iter()) {
assert_eq!(b.memory.id, c.memory.id, "result order must match");
assert!(
(b.score - c.score).abs() < 1e-6,
"scores must match: {} vs {}",
b.score,
c.score
);
}
for (a, c) in first.iter().zip(second.iter()) {
assert_eq!(a.memory.id, c.memory.id);
assert!((a.score - c.score).abs() < 1e-6);
}
}
// After `prewarm`, recall must match the raw store. Neither `prewarm` nor the
// prewarmed recall may call the inner `get_relations`; this is checked via the call counter.
#[test]
fn prewarmed_store_recall_matches_raw_and_skips_live_reads() {
use std::sync::atomic::Ordering;
let (vector_results, relations) = fixture();
let raw = InMemoryStore {
vector_results,
relations,
get_relations_calls: std::sync::atomic::AtomicUsize::new(0),
};
let baseline = QueryEngine::new(&raw).recall(&request()).unwrap();
// Snapshot the counter so later deltas isolate prewarm and cached recall.
let calls_after_baseline = raw.get_relations_calls.load(Ordering::Relaxed);
assert!(calls_after_baseline > 0, "baseline must hit the live store");
let cached = CachingStore::new(&raw);
cached.prewarm().unwrap();
let prewarm_calls = raw.get_relations_calls.load(Ordering::Relaxed);
let recalled = QueryEngine::new(&cached).recall(&request()).unwrap();
assert_eq!(
raw.get_relations_calls.load(Ordering::Relaxed),
prewarm_calls,
"prewarmed recall must not call inner.get_relations"
);
assert_eq!(prewarm_calls, calls_after_baseline, "prewarm must not read via get_relations");
assert_eq!(baseline.len(), recalled.len());
for (b, c) in baseline.iter().zip(recalled.iter()) {
assert_eq!(b.memory.id, c.memory.id, "result order must match");
assert!(
(b.score - c.score).abs() < 1e-6,
"scores must match: {} vs {}",
b.score,
c.score
);
}
}
}