use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::memory::auto_protect::AutoProtector;
use crate::memory::embedding::EmbeddingVector;
use crate::memory::storage::MemoryStorageExt;
#[cfg(feature = "sqlite-memory")]
use crate::memory::types::MemoryTier;
use crate::memory::types::{content_hash, dedup_by_id, extract_keywords, MemoryEntry, MemoryType};
use super::MemoryManager;
/// Serializable copy of the in-memory vector index.
///
/// Written by `MemoryManager::save_index_snapshot` and read back by
/// `MemoryManager::load_index_snapshot`, stored under the `memory` category
/// with the name `vector_index_snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct VectorIndexSnapshot {
/// Time (UTC) at which the snapshot was taken.
created_at: DateTime<Utc>,
/// Number of entries the index held when the snapshot was taken.
entry_count: usize,
/// Memory-entry id -> embedding vector, cloned from the live index.
entries: HashMap<String, EmbeddingVector>,
}
impl MemoryManager {
    /// Returns the number of entries stored across every [`MemoryType`].
    ///
    /// Each type is counted through [`MemoryManager::list`] with a limit of
    /// 1,000,000, so a type holding more entries than that is under-counted.
    /// A type whose listing fails contributes 0 instead of surfacing an error.
    pub async fn total_entries(&self) -> usize {
        let mut total = 0;
        for mt in MemoryType::all() {
            if let Ok(entries) = self.list(*mt, 1_000_000).await {
                total += entries.len();
            }
        }
        total
    }

    /// Re-embeds every stored entry and replaces the in-memory vector index.
    ///
    /// Entries are read directly from file storage for every [`MemoryType`].
    /// Categories that fail to list and entries that fail to load are skipped.
    /// All embeddings are computed before the index is touched, so an
    /// embedding error aborts the rebuild and leaves the current index intact.
    /// The HNSW index is not modified here.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the embedding backend.
    pub async fn rebuild_index(&self) -> anyhow::Result<()> {
        let mut entries_to_index: Vec<(String, EmbeddingVector)> = Vec::new();
        for mt in MemoryType::all() {
            if let Ok(names) = self.storage.list_category(mt.category()).await {
                for name in names {
                    if let Ok(Some(entry)) = self
                        .storage
                        .load_json::<MemoryEntry>(mt.category(), &name)
                        .await
                    {
                        let vector = self.embedding.embed(&entry.content).await?;
                        entries_to_index.push((entry.id, vector));
                    }
                }
            }
        }
        // Replace the contents under a single write lock. The guard is
        // dropped at the end of this scope, before the read lock below.
        {
            let mut index = self.vector_index.write();
            index.clear();
            index.extend(entries_to_index);
        }
        tracing::info!(
            entries = self.vector_index.read().len(),
            "Memory vector index rebuilt"
        );
        Ok(())
    }

    /// Persists a copy of the in-memory vector index and commits it to git.
    ///
    /// The index is cloned under a read lock that is released before any
    /// `.await`.
    ///
    /// # Errors
    ///
    /// Returns an error if the snapshot cannot be written to storage.
    pub async fn save_index_snapshot(&self) -> anyhow::Result<()> {
        let snapshot = {
            let index = self.vector_index.read();
            VectorIndexSnapshot {
                created_at: chrono::Utc::now(),
                entry_count: index.len(),
                entries: index.clone(),
            }
        };
        self.storage
            .save_json("memory", "vector_index_snapshot", &snapshot)
            .await?;
        self.git_commit("memory/vector_index_snapshot.json", "memory: snapshot save")
            .await;
        tracing::debug!(
            entries = snapshot.entry_count,
            "Vector index snapshot saved"
        );
        Ok(())
    }

    /// Replaces the in-memory vector index with the persisted snapshot, if any.
    ///
    /// Returns the number of entries loaded into the index, or `0` when no
    /// snapshot exists. The count comes from the loaded entries themselves,
    /// not from the snapshot's recorded `entry_count`. A mismatch between
    /// the two is logged as a warning.
    ///
    /// # Errors
    ///
    /// Returns an error if the snapshot cannot be read or deserialized.
    pub async fn load_index_snapshot(&self) -> anyhow::Result<usize> {
        let snapshot: Option<VectorIndexSnapshot> = self
            .storage
            .load_json("memory", "vector_index_snapshot")
            .await?;
        match snapshot {
            Some(snap) => {
                let count = snap.entries.len();
                if count != snap.entry_count {
                    tracing::warn!(
                        recorded = snap.entry_count,
                        actual = count,
                        "Vector index snapshot entry_count does not match its entries"
                    );
                }
                *self.vector_index.write() = snap.entries;
                tracing::info!(entries = count, "Vector index snapshot loaded");
                Ok(count)
            }
            None => {
                tracing::debug!("No vector index snapshot found");
                Ok(0)
            }
        }
    }

    /// Stores `entry` and indexes its embedding; returns the entry id.
    ///
    /// When the `sqlite-memory` feature is enabled and a SQLite store is
    /// configured, the call is delegated to it entirely. Otherwise the steps are:
    /// 1. Embed the content.
    /// 2. Write the entry to file storage.
    /// 3. Commit the file to git.
    /// 4. Insert the vector into the in-memory index.
    /// 5. Add it to the HNSW index, if present. An HNSW failure is only logged.
    ///
    /// # Errors
    ///
    /// Returns an error if embedding or persisting the entry fails.
    pub async fn remember(&self, entry: MemoryEntry) -> anyhow::Result<String> {
        #[cfg(feature = "sqlite-memory")]
        if let Some(ref sqlite) = self.sqlite_store {
            return sqlite.remember(&entry).await;
        }
        let id = entry.id.clone();
        let vector = self.embedding.embed(&entry.content).await?;
        let category = entry.memory_type.category();
        self.storage.save_json(category, &id, &entry).await?;
        self.git_commit(
            &format!("{category}/{id}.json"),
            &format!("memory: store {id}"),
        )
        .await;
        {
            let mut index = self.vector_index.write();
            index.insert(id.clone(), vector.clone());
        }
        if let Some(f32_vec) = vector.to_f32_dense() {
            let hnsw = self.hnsw_index.read();
            if let Some(ref hnsw) = *hnsw {
                if let Err(e) = hnsw.add_entry(&id, &f32_vec) {
                    tracing::warn!(id = %id, error = %e, "Failed to update HNSW index on remember");
                }
            }
        }
        tracing::debug!(id = %id, ty = entry.memory_type.label(), "Memory stored");
        Ok(id)
    }

    /// Loads the entry `id` of the given `memory_type`, if it exists.
    ///
    /// On the file-storage path, the access is recorded on the returned copy
    /// via `AutoProtector::record_access`. This method does not write the
    /// entry back to storage.
    ///
    /// # Errors
    ///
    /// Returns an error if the storage read fails.
    pub async fn get(
        &self,
        id: &str,
        memory_type: MemoryType,
    ) -> anyhow::Result<Option<MemoryEntry>> {
        #[cfg(feature = "sqlite-memory")]
        if let Some(ref sqlite) = self.sqlite_store {
            return sqlite.get(id, memory_type);
        }
        let result: Option<MemoryEntry> =
            self.storage.load_json(memory_type.category(), id).await?;
        if let Some(mut entry) = result {
            AutoProtector::record_access(&mut entry, "");
            Ok(Some(entry))
        } else {
            Ok(None)
        }
    }

    /// Deletes the entry `id` of the given `memory_type`.
    ///
    /// Returns the flag reported by the storage delete. When that flag is
    /// `true`, the id is also dropped from the in-memory vector index. The id
    /// is removed from the HNSW index (if present) in either case; an HNSW
    /// failure is only logged.
    ///
    /// # Errors
    ///
    /// Returns an error if the storage delete fails.
    pub async fn forget(&self, id: &str, memory_type: MemoryType) -> anyhow::Result<bool> {
        #[cfg(feature = "sqlite-memory")]
        if let Some(ref sqlite) = self.sqlite_store {
            return sqlite.forget(id, memory_type);
        }
        let result = self.storage.delete_file(memory_type.category(), id).await?;
        if result {
            // Without this, the stale embedding keeps matching in `search`
            // and `is_duplicate` after the entry itself is gone.
            self.vector_index.write().remove(id);
        }
        {
            let hnsw = self.hnsw_index.read();
            if let Some(ref hnsw) = *hnsw {
                if let Err(e) = hnsw.remove_entry(id) {
                    tracing::warn!(id = %id, error = %e, "Failed to remove from HNSW index on forget");
                }
            }
        }
        Ok(result)
    }

    /// Returns up to `limit` entries of `memory_type`, newest `created_at` first.
    ///
    /// Entries that fail to load are skipped.
    ///
    /// NOTE(review): only the first `2 * limit` names yielded by
    /// `list_category` are loaded before sorting. The result is therefore the
    /// newest among those names. It is the globally newest only if
    /// `list_category` yields names in recency order — confirm against the
    /// storage implementation.
    ///
    /// # Errors
    ///
    /// Returns an error if the category cannot be listed.
    pub async fn list(
        &self,
        memory_type: MemoryType,
        limit: usize,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        #[cfg(feature = "sqlite-memory")]
        if let Some(ref sqlite) = self.sqlite_store {
            return sqlite.list(memory_type, limit);
        }
        let category = memory_type.category();
        let names = self.storage.list_category(category).await?;
        let mut entries = Vec::new();
        for name in names.into_iter().take(limit.saturating_mul(2)) {
            if let Ok(Some(entry)) = self.storage.load_json::<MemoryEntry>(category, &name).await {
                entries.push(entry);
            }
        }
        entries.sort_by_key(|b| std::cmp::Reverse(b.created_at));
        entries.truncate(limit);
        Ok(entries)
    }

    /// Returns up to `limit` entries most similar to `query`.
    ///
    /// Every indexed vector scoring above 0.1 cosine similarity is ranked.
    /// Entries are loaded in score order (restricted to `memory_type` when
    /// given) until `limit` have been found. Ids that are stale or belong to
    /// another type are skipped without consuming the limit. Falls back to
    /// [`MemoryManager::keyword_search`] when no vector hit loads.
    ///
    /// # Errors
    ///
    /// Returns an error if embedding the query fails, or if the keyword
    /// fallback fails.
    pub async fn search(
        &self,
        query: &str,
        memory_type: Option<MemoryType>,
        limit: usize,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        #[cfg(feature = "sqlite-memory")]
        if let Some(ref sqlite) = self.sqlite_store {
            return sqlite.search(query, memory_type, limit).await;
        }
        let query_vector = self.embedding.embed(query).await?;
        // Score under the read lock; the guard is released at the end of
        // this scope, before any `.await` below.
        let scored: Vec<(String, f64)> = {
            let index = self.vector_index.read();
            let mut scored: Vec<(String, f64)> = index
                .iter()
                .map(|(id, vector)| (id.clone(), query_vector.cosine_similarity(vector)))
                .filter(|(_, score)| *score > 0.1)
                .collect();
            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            scored
        };
        if scored.is_empty() {
            return self.keyword_search(query, memory_type, limit).await;
        }
        let types: &[MemoryType] = match memory_type {
            Some(ref t) => std::slice::from_ref(t),
            None => MemoryType::all(),
        };
        let mut results: Vec<MemoryEntry> = Vec::with_capacity(limit.min(scored.len()));
        for (id, score) in scored {
            // Truncating before loading would let misses crowd out real hits.
            if results.len() >= limit {
                break;
            }
            for mt in types {
                if let Ok(Some(mut entry)) = self
                    .storage
                    .load_json::<MemoryEntry>(mt.category(), &id)
                    .await
                {
                    AutoProtector::record_access(&mut entry, "");
                    tracing::debug!(id = %id, score, "Vector search hit");
                    results.push(entry);
                    break;
                }
            }
        }
        if results.is_empty() {
            return self.keyword_search(query, memory_type, limit).await;
        }
        Ok(results)
    }

    /// Case-insensitive keyword search over entry content and tags.
    ///
    /// Scans up to `2 * limit` listed entries per type (all types when
    /// `memory_type` is `None`). Keeps entries where any keyword from
    /// `extract_keywords(query)` occurs in the content or in a tag, and
    /// returns the `limit` highest-importance matches.
    ///
    /// # Errors
    ///
    /// Returns an error if listing any searched type fails.
    pub(crate) async fn keyword_search(
        &self,
        query: &str,
        memory_type: Option<MemoryType>,
        limit: usize,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        // Lowercase each keyword once instead of once per scanned entry.
        let keywords: Vec<String> = extract_keywords(query)
            .iter()
            .map(|k| k.to_lowercase())
            .collect();
        let types = match memory_type {
            Some(t) => vec![t],
            None => MemoryType::all().to_vec(),
        };
        // Saturating, as in `list`, so a huge `limit` cannot overflow-panic.
        let per_type = limit.saturating_mul(2);
        let mut results = Vec::new();
        for ty in &types {
            for entry in self.list(*ty, per_type).await? {
                let content = entry.content.to_lowercase();
                let matches = keywords.iter().any(|k| {
                    content.contains(k.as_str())
                        || entry
                            .tags
                            .iter()
                            .any(|t| t.to_lowercase().contains(k.as_str()))
                });
                if matches {
                    results.push(entry);
                }
            }
        }
        results.sort_by(|a, b| {
            b.importance
                .partial_cmp(&a.importance)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(limit);
        Ok(results)
    }

    /// Builds the memory context for `query`, capped at `max_recall` entries.
    ///
    /// Combines, in this order:
    /// - the 3 most recent conversation entries;
    /// - the 2 most recent session entries;
    /// - up to `max_recall` search hits.
    ///
    /// The combined list is de-duplicated by id and then truncated. Failures
    /// in any of the three sources are treated as empty results.
    pub async fn recall(&self, query: &str) -> anyhow::Result<Vec<MemoryEntry>> {
        #[cfg(feature = "sqlite-memory")]
        if let Some(ref sqlite) = self.sqlite_store {
            return sqlite.recall(query, self.max_recall).await;
        }
        let limit = self.max_recall;
        let recent = self
            .list(MemoryType::Conversation, 3)
            .await
            .unwrap_or_default();
        let sessions = self.list(MemoryType::Session, 2).await.unwrap_or_default();
        let relevant = self.search(query, None, limit).await.unwrap_or_default();
        let mut combined = recent;
        combined.extend(sessions);
        combined.extend(relevant);
        dedup_by_id(&mut combined);
        combined.truncate(limit);
        Ok(combined)
    }

    /// Appends a "## Relevant Memory" section listing `memories` to `system_prompt`.
    ///
    /// Each memory is rendered as `- [label] content`. When `memories` is
    /// empty, the prompt is returned unchanged.
    pub fn blend_into_prompt(&self, memories: &[MemoryEntry], system_prompt: &str) -> String {
        #[cfg(feature = "sqlite-memory")]
        if let Some(ref sqlite) = self.sqlite_store {
            return sqlite.blend_into_prompt(memories, system_prompt);
        }
        if memories.is_empty() {
            return system_prompt.to_string();
        }
        let memory_block = memories
            .iter()
            .map(|m| format!("- [{}] {}", m.memory_type.label(), m.content))
            .collect::<Vec<_>>()
            .join("\n");
        format!("{system_prompt}\n\n## Relevant Memory\n\n{memory_block}")
    }

    /// Recall with reranking when backed by SQLite.
    ///
    /// Falls back to [`MemoryManager::recall`] when no SQLite store is configured.
    #[cfg(feature = "sqlite-memory")]
    pub async fn recall_with_rerank(&self, query: &str) -> anyhow::Result<Vec<MemoryEntry>> {
        if let Some(ref sqlite) = self.sqlite_store {
            return sqlite.recall_with_rerank(query, self.max_recall).await;
        }
        self.recall(query).await
    }

    /// Returns `true` if `content` appears to already be stored.
    ///
    /// Two checks are made:
    /// 1. Near-duplicate: any indexed vector with cosine similarity above 0.95.
    /// 2. Exact duplicate: a matching content hash among the up-to-1000
    ///    entries listed per type.
    ///
    /// An embedding failure skips only the first check. The hash check does
    /// not need an embedding and still runs.
    pub async fn is_duplicate(&self, content: &str) -> bool {
        if let Ok(query_vector) = self.embedding.embed(content).await {
            let similar = {
                let index = self.vector_index.read();
                index
                    .values()
                    .any(|vector| query_vector.cosine_similarity(vector) > 0.95)
            };
            if similar {
                return true;
            }
        }
        let hash = content_hash(content);
        for mt in MemoryType::all() {
            if let Ok(entries) = self.list(*mt, 1000).await {
                if entries
                    .iter()
                    .any(|entry| content_hash(&entry.content) == hash)
                {
                    return true;
                }
            }
        }
        false
    }

    /// Stores `entry` unless [`MemoryManager::is_duplicate`] reports its content
    /// as already present.
    ///
    /// Returns `Ok(None)` when the entry is skipped as a duplicate.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`MemoryManager::remember`].
    pub async fn remember_unique(&self, entry: MemoryEntry) -> anyhow::Result<Option<String>> {
        #[cfg(feature = "sqlite-memory")]
        if let Some(ref sqlite) = self.sqlite_store {
            return sqlite.remember_unique(&entry).await;
        }
        if self.is_duplicate(&entry.content).await {
            tracing::debug!(id = %entry.id, "Skipping duplicate memory");
            return Ok(None);
        }
        let id = self.remember(entry).await?;
        Ok(Some(id))
    }

    /// Runs [`MemoryManager::recall`], then tops the result up toward `max_recall`.
    ///
    /// Topping up happens only if `recall_timing` allows it (absent timing
    /// means "always allow"). With a SQLite store, the top-up comes from
    /// warm-tier entries not already present. Otherwise it comes from
    /// proactive recall.
    ///
    /// # Errors
    ///
    /// Propagates errors from recall, the warm-tier listing, or proactive recall.
    pub async fn recall_with_proactive(
        &self,
        query: &str,
        recall_timing: &mut Option<crate::memory::proactive::RecallTiming>,
    ) -> anyhow::Result<Vec<MemoryEntry>> {
        let mut combined = self.recall(query).await?;
        let should_recall = recall_timing
            .as_mut()
            .map(|t| t.should_recall(query))
            .unwrap_or(true);
        if !should_recall || combined.len() >= self.max_recall {
            return Ok(combined);
        }
        #[cfg(feature = "sqlite-memory")]
        if self.sqlite_store.is_some() {
            let remaining = self.max_recall - combined.len();
            let warm = self.list_by_tier(MemoryTier::Warm, remaining).await?;
            let mut seen_ids: std::collections::HashSet<String> =
                combined.iter().map(|e| e.id.clone()).collect();
            for entry in warm {
                if seen_ids.insert(entry.id.clone()) && combined.len() < self.max_recall {
                    combined.push(entry);
                }
            }
            return Ok(combined);
        }
        self.extend_with_proactive(query, &mut combined).await?;
        Ok(combined)
    }

    /// Appends proactive-recall results for `query` to `combined`,
    /// de-duplicates by id, and caps the list at `max_recall`.
    async fn extend_with_proactive(
        &self,
        query: &str,
        combined: &mut Vec<MemoryEntry>,
    ) -> anyhow::Result<()> {
        let proactive = crate::memory::proactive::ProactiveRecall::new(5, 0.6);
        let extra = proactive.recall(self, query, &*combined).await?;
        combined.extend(extra);
        dedup_by_id(combined);
        combined.truncate(self.max_recall);
        Ok(())
    }
}