mod ops;
mod store;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Result;
use parking_lot::RwLock;
use crate::memory::embedding::{EmbeddingProvider, EmbeddingVector, TfIdfEmbeddingProvider};
use crate::memory::hnsw_memory_index::HnswMemoryIndex;
use crate::memory::sona::SonaEngine;
use crate::memory::storage::{MemoryGit, MemoryStorage};
use crate::memory::types::{MemoryEntry, MemoryType};
use super::{CurationCandidate, CurationReport, MemoryBudget};
/// Coordinates memory storage, indexing and curation.
///
/// Holds the storage backend, an in-memory vector index, the embedding
/// provider, and optional git / HNSW / SONA (and, with the `sqlite-memory`
/// feature, SQLite) components that can be attached after construction.
//
// NOTE: struct fields are dropped in declaration order; keep this order stable
// unless the teardown order of these components has been checked.
pub struct MemoryManager {
    /// Storage backend (a `MemoryStorage` trait object).
    pub(crate) storage: Arc<dyn MemoryStorage>,
    /// Recall limit; `MemoryManager::new` initialises it to 10.
    /// Presumably caps how many entries a recall returns — confirm against its users.
    pub(crate) max_recall: usize,
    /// In-memory embedding index keyed by string (presumably memory entry ids —
    /// TODO confirm). Its length is what `vector_index_size` reports.
    pub(crate) vector_index: RwLock<HashMap<String, EmbeddingVector>>,
    /// Embedding provider; `new` installs `TfIdfEmbeddingProvider`.
    pub(crate) embedding: Arc<dyn EmbeddingProvider>,
    /// Optional git layer; `git_commit` only uses it while `is_enabled()` is true.
    pub(crate) git: Option<Arc<dyn MemoryGit>>,
    /// Optional HNSW index. Kept behind an `RwLock` so `set_hnsw_index` can
    /// install it through `&self`.
    pub(crate) hnsw_index: RwLock<Option<Arc<HnswMemoryIndex>>>,
    /// Optional SONA engine; its presence is reported as `sona_enabled` by `Debug`.
    pub(crate) sona_engine: Option<Arc<SonaEngine>>,
    /// Optional SQLite store; this field only exists with the `sqlite-memory` feature.
    #[cfg(feature = "sqlite-memory")]
    pub(crate) sqlite_store: Option<Arc<crate::memory::sqlite::SqliteMemoryStore>>,
}
impl std::fmt::Debug for MemoryManager {
    /// Compact diagnostic view: the recall limit, the number of vectors held in
    /// the in-memory index, and whether a SONA engine is attached. Storage and
    /// the other components are deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Snapshot the index size up front so the read guard is released
        // before any formatting happens.
        let indexed_vectors = self.vector_index.read().len();
        let sona_attached = self.sona_engine.is_some();

        let mut view = f.debug_struct("MemoryManager");
        view.field("max_recall", &self.max_recall);
        view.field("index_size", &indexed_vectors);
        view.field("sona_enabled", &sona_attached);
        view.finish()
    }
}
impl MemoryManager {
    /// Recall limit installed by [`MemoryManager::new`].
    const DEFAULT_MAX_RECALL: usize = 10;

    /// Creates a manager over `storage` with default settings.
    ///
    /// The new manager has `max_recall = 10`, an empty vector index and the
    /// [`TfIdfEmbeddingProvider`]. The git layer, HNSW index and SONA engine
    /// (and the SQLite store when the `sqlite-memory` feature is enabled) all
    /// start unset and can be attached with the corresponding setters.
    pub fn new(storage: Arc<dyn MemoryStorage>) -> Self {
        Self {
            storage,
            max_recall: Self::DEFAULT_MAX_RECALL,
            vector_index: RwLock::new(HashMap::new()),
            embedding: Arc::new(TfIdfEmbeddingProvider),
            git: None,
            hnsw_index: RwLock::new(None),
            sona_engine: None,
            #[cfg(feature = "sqlite-memory")]
            sqlite_store: None,
        }
    }

    /// Attaches (or replaces) the git layer used by `git_commit`.
    pub fn set_git_layer(&mut self, gl: Arc<dyn MemoryGit>) {
        self.git = Some(gl);
    }

    /// Attaches (or replaces) the SQLite store.
    #[cfg(feature = "sqlite-memory")]
    pub fn set_sqlite_store(&mut self, store: Arc<crate::memory::sqlite::SqliteMemoryStore>) {
        self.sqlite_store = Some(store);
    }

    /// Returns the attached SQLite store, if any.
    #[cfg(feature = "sqlite-memory")]
    pub fn sqlite_store(&self) -> &Option<Arc<crate::memory::sqlite::SqliteMemoryStore>> {
        &self.sqlite_store
    }

    /// Attaches (or replaces) the SONA engine.
    pub fn set_sona_engine(&mut self, engine: Arc<SonaEngine>) {
        self.sona_engine = Some(engine);
    }

    /// Returns the attached SONA engine, if any.
    pub fn sona_engine(&self) -> Option<&Arc<SonaEngine>> {
        self.sona_engine.as_ref()
    }

    /// Installs (or replaces) the HNSW index.
    ///
    /// Takes `&self` because the slot lives behind an `RwLock`, so this can be
    /// called on a shared (e.g. `Arc`-wrapped) manager.
    pub fn set_hnsw_index(&self, index: Arc<HnswMemoryIndex>) {
        *self.hnsw_index.write() = Some(index);
    }

    /// Builder-style variant of [`set_max_recall`](Self::set_max_recall).
    pub fn with_max_recall(mut self, n: usize) -> Self {
        self.max_recall = n;
        self
    }

    /// Sets the recall limit.
    pub fn set_max_recall(&mut self, n: usize) {
        self.max_recall = n;
    }

    /// Number of entries currently held in the in-memory vector index.
    pub fn vector_index_size(&self) -> usize {
        self.vector_index.read().len()
    }

    /// Commits `rel_path` with `message` through the git layer, if one is
    /// attached and enabled.
    ///
    /// Best-effort: the commit outcome is discarded, so a failed commit never
    /// propagates to the caller.
    pub(crate) async fn git_commit(&self, rel_path: &str, message: &str) {
        if let Some(gl) = &self.git {
            if gl.is_enabled() {
                // NOTE(review): failures are dropped silently; consider logging
                // once the return type of `commit_file` is confirmed.
                let _ = gl.commit_file(rel_path, message).await;
            }
        }
    }

    /// Importance of `entry` boosted by how often it has been accessed.
    ///
    /// Computed as `importance * (1 + ln(1 + access_count))`. An entry with an
    /// access count of zero keeps its raw importance, and the boost grows
    /// logarithmically with further accesses.
    pub fn effective_importance(entry: &MemoryEntry) -> f32 {
        let access_boost = (1.0_f32 + entry.access_count as f32).ln();
        entry.importance * (1.0 + access_boost)
    }

    /// Trims each curated memory type down to `budget.max_per_type` entries.
    ///
    /// For every curated type, up to `2 * budget.max_per_type` entries are
    /// listed. If more than `max_per_type` come back, the surplus entries with
    /// the lowest [`effective_importance`](Self::effective_importance) become
    /// removal candidates. Once all types have been scanned, every candidate
    /// is passed to `forget`.
    ///
    /// The report's fields are filled as follows:
    /// - `total_before` counts every listed entry, including entries of types
    ///   already within budget.
    /// - `removed` counts successful deletions.
    /// - `total_after` is `total_before - removed`.
    ///
    /// Because listing is capped, types holding more than
    /// `2 * max_per_type` entries are only partially counted and curated.
    ///
    /// # Errors
    ///
    /// Returns an error if listing any memory type fails. Individual `forget`
    /// failures are logged and skipped instead of aborting curation.
    pub async fn curate(&self, budget: &MemoryBudget) -> Result<CurationReport> {
        let mut report = CurationReport::default();
        // Over-fetch so there is something to rank against the budget. Saturate
        // so a huge budget cannot overflow (panic in debug, tiny limit in release).
        let fetch_limit = budget.max_per_type.saturating_mul(2);

        for mt in &[
            MemoryType::Conversation,
            MemoryType::Session,
            MemoryType::Fact,
            MemoryType::Episode,
            MemoryType::Knowledge,
        ] {
            let entries = self.list(*mt, fetch_limit).await?;
            let total_count = entries.len();
            // Count every listed entry, not only those of over-budget types,
            // so `total_after` reflects what remains across all curated types.
            report.total_before += total_count;
            if total_count <= budget.max_per_type {
                continue;
            }

            let mut scored: Vec<_> = entries
                .into_iter()
                .map(|e| (e.id.clone(), e.memory_type, Self::effective_importance(&e)))
                .collect();
            // `total_cmp` is a true total order (NaN included), which keeps the
            // sort well-defined; an inconsistent comparator may make sorting panic.
            scored.sort_by(|a, b| a.2.total_cmp(&b.2));

            // Lowest-scoring entries come first; drop exactly the surplus.
            let to_remove = total_count - budget.max_per_type;
            report.candidates_for_removal.extend(
                scored
                    .into_iter()
                    .take(to_remove)
                    .map(|(id, memory_type, score)| CurationCandidate {
                        id,
                        memory_type,
                        effective_importance: score,
                    }),
            );
        }

        for candidate in &report.candidates_for_removal {
            match self.forget(&candidate.id, candidate.memory_type).await {
                Ok(_) => {
                    report.removed += 1;
                }
                Err(e) => {
                    // Best-effort: keep curating the remaining candidates.
                    tracing::warn!(
                        id = ?candidate.id,
                        error = ?e,
                        "Failed to forget memory during curation"
                    );
                }
            }
        }

        // Cannot underflow: `removed` <= candidate count <= `total_before`.
        report.total_after = report.total_before - report.removed;
        Ok(report)
    }

    /// Runs [`curate`](Self::curate) on a detached Tokio task.
    ///
    /// The task owns its own `Arc` clone of the manager. Results are reported
    /// only through `tracing`: an info event when at least one entry was
    /// removed, and a warning when curation fails.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, as `tokio::spawn` does.
    pub fn spawn_curation_task(self: &Arc<Self>, budget: MemoryBudget) {
        let mgr = Arc::clone(self);
        tokio::spawn(async move {
            match mgr.curate(&budget).await {
                Ok(report) if report.removed > 0 => {
                    tracing::info!(
                        removed = report.removed,
                        candidates = report.candidates_for_removal.len(),
                        "Memory curation complete"
                    );
                }
                Ok(_) => {}
                Err(e) => {
                    tracing::warn!(error = %e, "Memory curation failed");
                }
            }
        });
    }
}