#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::cast_sign_loss)]
use crate::Database;
use std::sync::Arc;
pub use super::episodic_memory::EpisodicMemory;
pub use super::error::AgentMemoryError;
pub use super::procedural_memory::{ProceduralMemory, ProcedureMatch};
pub use super::semantic_memory::SemanticMemory;
pub use super::snapshot::{MemoryState, SnapshotManager};
pub use super::temporal_index::TemporalIndex;
pub use super::ttl::{EvictionConfig, ExpireResult, MemoryTtl};
/// Embedding dimension used by [`AgentMemory::new`] when no explicit
/// dimension is supplied.
pub const DEFAULT_DIMENSION: usize = 384;

/// Facade over an agent's semantic, episodic, and procedural memories.
///
/// All three subsystems are built on the same [`Database`] handle and share a
/// single [`MemoryTtl`] registry whose entries are keyed by id alone (not by
/// subsystem). Versioned snapshots become available once a
/// [`SnapshotManager`] is configured via [`AgentMemory::with_snapshots`].
pub struct AgentMemory {
    // Database handle; clones are handed to every subsystem, and the
    // `query_*` methods use it directly.
    db: Arc<Database>,
    semantic: SemanticMemory,
    episodic: EpisodicMemory,
    procedural: ProceduralMemory,
    // The one TTL registry that every subsystem receives at construction.
    ttl: Arc<MemoryTtl>,
    // Shared with `EpisodicMemory`; never read by `AgentMemory` itself, which
    // is why the `dead_code` lint is silenced.
    #[allow(dead_code)]
    temporal_index: Arc<TemporalIndex>,
    // Policy consulted by `auto_expire` (e.g. the consolidation age threshold).
    eviction_config: EvictionConfig,
    // `None` until `with_snapshots` is called; snapshot methods return an
    // error while it is unset.
    snapshot_manager: Option<SnapshotManager>,
}
impl AgentMemory {
pub fn new(db: Arc<Database>) -> Result<Self, AgentMemoryError> {
Self::with_dimension(db, DEFAULT_DIMENSION)
}
pub fn with_dimension(db: Arc<Database>, dimension: usize) -> Result<Self, AgentMemoryError> {
let ttl = Arc::new(MemoryTtl::new());
let temporal_index = Arc::new(TemporalIndex::new());
let semantic = SemanticMemory::new(Arc::clone(&db), dimension, Arc::clone(&ttl))?;
let episodic = EpisodicMemory::new(
Arc::clone(&db),
dimension,
Arc::clone(&ttl),
Arc::clone(&temporal_index),
)?;
let procedural = ProceduralMemory::new(Arc::clone(&db), dimension, Arc::clone(&ttl))?;
Ok(Self {
db,
semantic,
episodic,
procedural,
ttl,
temporal_index,
eviction_config: EvictionConfig::default(),
snapshot_manager: None,
})
}
#[must_use]
pub fn with_eviction_config(mut self, config: EvictionConfig) -> Self {
self.eviction_config = config;
self
}
#[must_use]
pub fn with_snapshots(mut self, snapshot_dir: &str, max_snapshots: usize) -> Self {
self.snapshot_manager = Some(SnapshotManager::new(snapshot_dir, max_snapshots));
self
}
#[must_use]
pub fn semantic(&self) -> &SemanticMemory {
&self.semantic
}
#[must_use]
pub fn episodic(&self) -> &EpisodicMemory {
&self.episodic
}
#[must_use]
pub fn procedural(&self) -> &ProceduralMemory {
&self.procedural
}
pub fn set_semantic_ttl(&self, id: u64, ttl_seconds: u64) {
self.ttl.set_ttl(id, ttl_seconds);
}
pub fn set_episodic_ttl(&self, id: u64, ttl_seconds: u64) {
self.ttl.set_ttl(id, ttl_seconds);
}
pub fn set_procedural_ttl(&self, id: u64, ttl_seconds: u64) {
self.ttl.set_ttl(id, ttl_seconds);
}
pub fn auto_expire(&self) -> Result<ExpireResult, AgentMemoryError> {
let expired_ids = self.ttl.expire();
let mut result = ExpireResult::default();
for id in &expired_ids {
if self.semantic.delete(*id).is_ok() {
result.semantic_expired += 1;
}
if self.episodic.delete(*id).is_ok() {
result.episodic_expired += 1;
}
if self.procedural.delete(*id).is_ok() {
result.procedural_expired += 1;
}
}
if self.eviction_config.consolidation_age_threshold > 0 {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_or(0, |d| d.as_secs() as i64);
let cutoff = now - self.eviction_config.consolidation_age_threshold as i64;
result.episodic_consolidated = self.consolidate_old_episodes(cutoff)?;
}
Ok(result)
}
pub fn evict_low_confidence_procedures(
&self,
min_confidence: f32,
) -> Result<usize, AgentMemoryError> {
let all_procedures = self.procedural.list_all()?;
let mut evicted = 0;
for proc in all_procedures {
if proc.confidence < min_confidence {
self.procedural.delete(proc.id)?;
evicted += 1;
}
}
Ok(evicted)
}
fn require_snapshot_manager(&self) -> Result<&SnapshotManager, AgentMemoryError> {
self.snapshot_manager.as_ref().ok_or_else(|| {
AgentMemoryError::SnapshotError("Snapshot manager not configured".to_string())
})
}
pub fn snapshot(&self) -> Result<u64, AgentMemoryError> {
let manager = self.require_snapshot_manager()?;
let state = MemoryState {
semantic: self.semantic.serialize()?,
episodic: self.episodic.serialize()?,
procedural: self.procedural.serialize()?,
ttl: self.ttl.serialize(),
};
Ok(manager.create_versioned_snapshot(&state)?)
}
pub fn load_latest_snapshot(&self) -> Result<u64, AgentMemoryError> {
let manager = self.require_snapshot_manager()?;
let (version, state) = manager.load_latest()?;
self.restore_state(&state)?;
Ok(version)
}
pub fn load_snapshot_version(&self, version: u64) -> Result<(), AgentMemoryError> {
let manager = self.require_snapshot_manager()?;
let state = manager.load_version(version)?;
self.restore_state(&state)?;
Ok(())
}
pub fn list_snapshot_versions(&self) -> Result<Vec<u64>, AgentMemoryError> {
let manager = self.require_snapshot_manager()?;
Ok(manager.list_versions()?)
}
pub fn query_semantic(
&self,
sql: &str,
params: &std::collections::HashMap<String, serde_json::Value>,
) -> Result<Vec<crate::SearchResult>, AgentMemoryError> {
super::memory_helpers::execute_velesql(
&self.db,
self.semantic.collection_name(),
sql,
params,
)
}
pub fn query_episodic(
&self,
sql: &str,
params: &std::collections::HashMap<String, serde_json::Value>,
) -> Result<Vec<crate::SearchResult>, AgentMemoryError> {
super::memory_helpers::execute_velesql(
&self.db,
self.episodic.collection_name(),
sql,
params,
)
}
pub fn query_procedural(
&self,
sql: &str,
params: &std::collections::HashMap<String, serde_json::Value>,
) -> Result<Vec<crate::SearchResult>, AgentMemoryError> {
super::memory_helpers::execute_velesql(
&self.db,
self.procedural.collection_name(),
sql,
params,
)
}
fn restore_state(&self, state: &MemoryState) -> Result<(), AgentMemoryError> {
self.semantic.deserialize(&state.semantic)?;
self.episodic.deserialize(&state.episodic)?;
self.procedural.deserialize(&state.procedural)?;
if let Some(ttl) = MemoryTtl::deserialize(&state.ttl) {
self.ttl.replace_from(&ttl);
} else {
self.ttl.clear();
}
Ok(())
}
fn consolidate_old_episodes(&self, cutoff_timestamp: i64) -> Result<usize, AgentMemoryError> {
let old_events = self.episodic.older_than(cutoff_timestamp, 1000)?;
let mut consolidated = 0;
for (id, _description, _timestamp) in old_events {
if let Some((description, _ts, embedding)) = self.episodic.get_with_embedding(id)? {
self.semantic.store(id, &description, &embedding)?;
self.episodic.delete(id)?;
consolidated += 1;
}
}
Ok(consolidated)
}
}