use anyhow::Result;
use crate::memory::auto_protect::AutoProtector;
use crate::memory::hnsw_memory_index::{HnswMemoryIndex, SemanticHit};
use crate::memory::storage::MemoryStorageExt;
use crate::memory::types::{MemoryEntry, MemoryTier, MemoryType};
use super::MemoryManager;
impl MemoryManager {
pub async fn semantic_search(
&self,
query: &str,
memory_type: Option<MemoryType>,
limit: usize,
hnsw_index: &HnswMemoryIndex,
) -> Result<Vec<SemanticHit>> {
if hnsw_index.is_empty() {
tracing::debug!("HNSW index empty, falling back to keyword search");
return self
.keyword_search(query, memory_type, limit)
.await
.map(|entries| {
entries
.into_iter()
.map(|entry| SemanticHit {
entry,
distance: 0.0,
similarity: 0.0,
})
.collect()
});
}
let query_vector = self.embedding.embed(query).await?;
let query_f32 = match query_vector.to_f32_dense() {
Some(v) => v,
None => {
tracing::debug!("Query embedding is sparse, falling back to keyword search");
return self
.keyword_search(query, memory_type, limit)
.await
.map(|entries| {
entries
.into_iter()
.map(|entry| SemanticHit {
entry,
distance: 0.0,
similarity: 0.0,
})
.collect()
});
}
};
let raw_hits = hnsw_index.search(&query_f32, limit * 2)?;
let types: &[MemoryType] = match memory_type {
Some(ref t) => std::slice::from_ref(t),
None => MemoryType::all(),
};
let mut results = Vec::new();
for (id, distance) in raw_hits {
for mt in types {
if let Ok(Some(mut entry)) = self
.storage
.load_json::<MemoryEntry>(mt.category(), &id)
.await
{
AutoProtector::record_access(&mut entry, "");
let similarity = 1.0 - distance;
results.push(SemanticHit {
entry,
distance,
similarity,
});
break;
}
}
if results.len() >= limit {
break;
}
}
results.sort_by(|a, b| {
b.similarity
.partial_cmp(&a.similarity)
.unwrap_or(std::cmp::Ordering::Equal)
});
tracing::debug!(
query = %query,
hits = results.len(),
"Semantic search complete"
);
if results.is_empty() {
return self
.keyword_search(query, memory_type, limit)
.await
.map(|entries| {
entries
.into_iter()
.map(|entry| SemanticHit {
entry,
distance: 0.0,
similarity: 0.0,
})
.collect()
});
}
Ok(results)
}
pub async fn rebuild_hnsw_index(&self, hnsw_index: &HnswMemoryIndex) -> Result<usize> {
let mut count = 0;
for mt in MemoryType::all() {
if let Ok(names) = self.storage.list_category(mt.category()).await {
for name in names {
if let Ok(Some(entry)) = self
.storage
.load_json::<MemoryEntry>(mt.category(), &name)
.await
{
let vector = self.embedding.embed(&entry.content).await?;
if let Some(f32_vec) = vector.to_f32_dense() {
if let Err(e) = hnsw_index.add_entry(&entry.id, &f32_vec) {
tracing::warn!(
id = %entry.id,
error = %e,
"Failed to add entry to HNSW index"
);
continue;
}
count += 1;
}
}
}
}
}
tracing::info!(entries = count, "HNSW index rebuilt");
Ok(count)
}
pub async fn list_by_tier(&self, tier: MemoryTier, limit: usize) -> Result<Vec<MemoryEntry>> {
#[cfg(feature = "sqlite-memory")]
if let Some(ref sqlite) = self.sqlite_store {
return sqlite.list_by_tier(tier, limit);
}
let mut results = Vec::new();
for mt in MemoryType::all() {
if let Ok(entries) = self.list(*mt, limit).await {
for entry in entries {
if entry.tier == tier {
results.push(entry);
}
}
}
if results.len() >= limit {
break;
}
}
results.truncate(limit);
Ok(results)
}
pub async fn get_by_id(&self, id: &str) -> Result<Option<MemoryEntry>> {
for mt in MemoryType::all() {
if let Ok(Some(entry)) = self.get(id, *mt).await {
return Ok(Some(entry));
}
}
Ok(None)
}
pub async fn load_by_reference(&self, reference: &str) -> Result<Option<MemoryEntry>> {
if let Ok(Some(entry)) = self.get_by_id(reference).await {
return Ok(Some(entry));
}
if let Some((cat, name)) = reference.split_once('/') {
if let Ok(Some(entry)) = self.storage.load_json::<MemoryEntry>(cat, name).await {
return Ok(Some(entry));
}
}
Ok(None)
}
pub async fn select_by_manifest(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
self.keyword_search(query, None, limit).await
}
pub async fn build_hot_context(&self, token_budget: usize) -> Result<String> {
let hot_entries = self.list_by_tier(MemoryTier::Hot, 50).await?;
let mut context_parts = Vec::new();
let mut char_budget = token_budget * 4;
for entry in &hot_entries {
let line = format!("- [{}] {}", entry.memory_type.label(), entry.content);
if line.len() > char_budget {
break;
}
char_budget -= line.len();
context_parts.push(line);
}
if context_parts.is_empty() {
Ok(String::new())
} else {
Ok(format!("## Active Context\n\n{}", context_parts.join("\n")))
}
}
pub async fn build_full_context(
&self,
_query: &str,
system_prompt: &str,
token_budget: usize,
) -> Result<String> {
let hot_ctx = self.build_hot_context(token_budget).await?;
if hot_ctx.is_empty() {
return Ok(system_prompt.to_string());
}
Ok(format!("{system_prompt}\n\n{hot_ctx}"))
}
pub async fn shift_tier(&self, id: &str, from: MemoryTier, to: MemoryTier) -> Result<()> {
if let Ok(Some(mut entry)) = self.get_by_id(id).await {
if entry.tier == from {
entry.tier = to;
self.remember(entry).await?;
}
}
Ok(())
}
pub async fn pin(&self, id: &str) -> Result<()> {
if let Ok(Some(mut entry)) = self.get_by_id(id).await {
entry.pinned = true;
entry.protection = crate::memory::types::ProtectionLevel::Permanent;
self.remember(entry).await?;
}
Ok(())
}
pub async fn unpin(&self, id: &str) -> Result<()> {
if let Ok(Some(mut entry)) = self.get_by_id(id).await {
entry.pinned = false;
let protector = crate::memory::auto_protect::AutoProtector::default_protector();
entry.protection = protector.compute_protection(&entry);
self.remember(entry).await?;
}
Ok(())
}
pub async fn set_importance(&self, id: &str, importance: f32) -> Result<()> {
if let Ok(Some(mut entry)) = self.get_by_id(id).await {
entry.importance = importance.clamp(0.0, 1.0);
self.remember(entry).await?;
}
Ok(())
}
pub async fn recompute_all_decay(&self, multiplier: f32) -> Result<usize> {
let engine = crate::memory::decay::DecayEngine::new(multiplier);
let now = chrono::Utc::now();
let mut count = 0;
for mt in MemoryType::all() {
if let Ok(entries) = self.list(*mt, 1_000_000).await {
for mut entry in entries {
let new_decay = engine.compute_decay(&entry, now);
if (entry.decay_score - new_decay).abs() > 0.001 {
entry.decay_score = new_decay;
self.remember(entry).await?;
count += 1;
}
}
}
}
Ok(count)
}
pub async fn immediate_hot_overflow(&self, hot_max: usize) -> Result<usize> {
let hot_entries = self.list_by_tier(MemoryTier::Hot, hot_max * 2).await?;
if hot_entries.len() <= hot_max {
return Ok(0);
}
let overflow = hot_entries.len() - hot_max;
let mut candidates: Vec<MemoryEntry> = hot_entries
.into_iter()
.filter(|e| e.protection < crate::memory::types::ProtectionLevel::High && !e.pinned)
.collect();
candidates.sort_by(|a, b| {
a.protection.cmp(&b.protection).then(
a.decay_score
.partial_cmp(&b.decay_score)
.unwrap_or(std::cmp::Ordering::Equal),
)
});
let mut demoted = 0;
for entry in candidates.into_iter().take(overflow) {
self.shift_tier(&entry.id, MemoryTier::Hot, MemoryTier::Warm)
.await?;
demoted += 1;
}
Ok(demoted)
}
}