use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::Result;
use chrono::{DateTime, Utc};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use crate::embedding::{EmbeddingProvider, EmbeddingVector, TfIdfEmbeddingProvider};
use crate::git_layer::GitLayer;
use crate::state_store::StateStore;
pub use budget::{CurationCandidate, CurationReport, MemoryBudget};
pub use store::HnswMemoryIndex;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// Returns a 64-bit hash of `content` computed with the standard library's
/// [`DefaultHasher`].
///
/// NOTE(review): the standard library does not guarantee that `DefaultHasher`
/// produces the same values across Rust releases. `MemoryEntry::content_hash`
/// is a serialized field, so confirm that persisted hashes are never compared
/// against hashes produced by a binary built with a different toolchain.
pub fn content_hash(content: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(content, &mut state);
    state.finish()
}
/// Sparse term-frequency representation of a piece of text.
///
/// Built by `TextVector::from_text` and compared with
/// `TextVector::cosine_similarity`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextVector {
/// Token -> relative frequency (occurrences divided by the total number of
/// tokens in the source text, as computed in `from_text`).
tf: HashMap<String, f64>,
}
impl TextVector {
pub fn from_text(text: &str) -> Self {
let mut tf: HashMap<String, f64> = HashMap::new();
let terms = Self::tokenize(text);
let total = terms.len() as f64;
for term in terms {
*tf.entry(term).or_insert(0.0) += 1.0;
}
if total > 0.0 {
for v in tf.values_mut() {
*v /= total;
}
}
Self { tf }
}
pub fn tokenize(text: &str) -> Vec<String> {
text.to_lowercase()
.split(|c: char| !c.is_alphanumeric() && !('\u{AC00}'..='\u{D7A3}').contains(&c))
.filter(|s| !s.is_empty() && s.len() > 1)
.map(|s| s.to_string())
.collect()
}
pub fn tf_map(&self) -> &HashMap<String, f64> {
&self.tf
}
pub fn cosine_similarity(&self, other: &TextVector) -> f64 {
let mut dot = 0.0;
let mut norm_a = 0.0;
let mut norm_b = 0.0;
for (term, &a) in &self.tf {
norm_a += a * a;
if let Some(&b) = other.tf.get(term) {
dot += a * b;
}
}
for &b in other.tf.values() {
norm_b += b * b;
}
if norm_a == 0.0 || norm_b == 0.0 {
return 0.0;
}
dot / (norm_a.sqrt() * norm_b.sqrt())
}
}
/// Kind of a memory entry.
///
/// Serialized in `snake_case` (e.g. `UserProfile` <-> `"user_profile"`).
/// Per-kind defaults (category path, label, base importance, decay rate and
/// initial tier) are provided by the inherent methods on this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
/// Category `memory/conversations`.
Conversation,
/// Category `memory/sessions`.
Session,
/// Category `memory/facts`.
Fact,
/// Category `memory/episodes`.
Episode,
/// Category `memory/knowledge`.
Knowledge,
/// Category `memory/skills`.
Skill,
/// Category `memory/preferences`; reported as auto-protected and global.
Preference,
/// Category `memory/decisions`.
Decision,
/// Category `memory/profiles`; reported as auto-protected and global.
UserProfile,
}
impl MemoryType {
pub fn category(&self) -> &'static str {
match self {
MemoryType::Conversation => "memory/conversations",
MemoryType::Session => "memory/sessions",
MemoryType::Fact => "memory/facts",
MemoryType::Episode => "memory/episodes",
MemoryType::Knowledge => "memory/knowledge",
MemoryType::Skill => "memory/skills",
MemoryType::Preference => "memory/preferences",
MemoryType::Decision => "memory/decisions",
MemoryType::UserProfile => "memory/profiles",
}
}
pub fn label(&self) -> &'static str {
match self {
MemoryType::Conversation => "conversation",
MemoryType::Session => "session",
MemoryType::Fact => "fact",
MemoryType::Episode => "episode",
MemoryType::Knowledge => "knowledge",
MemoryType::Skill => "skill",
MemoryType::Preference => "preference",
MemoryType::Decision => "decision",
MemoryType::UserProfile => "user_profile",
}
}
pub fn base_importance(&self) -> f32 {
match self {
MemoryType::UserProfile => 0.95,
MemoryType::Preference => 0.90,
MemoryType::Decision => 0.80,
MemoryType::Knowledge => 0.75,
MemoryType::Skill => 0.75,
MemoryType::Fact => 0.60,
MemoryType::Episode => 0.50,
MemoryType::Session => 0.40,
MemoryType::Conversation => 0.35,
}
}
pub fn base_decay_rate(&self) -> f32 {
match self {
MemoryType::UserProfile => 0.001,
MemoryType::Preference => 0.002,
MemoryType::Decision => 0.005,
MemoryType::Knowledge => 0.006,
MemoryType::Skill => 0.008,
MemoryType::Fact => 0.015,
MemoryType::Episode => 0.025,
MemoryType::Session => 0.040,
MemoryType::Conversation => 0.060,
}
}
pub fn initial_tier(&self) -> MemoryTier {
match self {
MemoryType::UserProfile
| MemoryType::Preference
| MemoryType::Decision
| MemoryType::Fact => MemoryTier::Hot,
MemoryType::Knowledge
| MemoryType::Skill
| MemoryType::Episode
| MemoryType::Session
| MemoryType::Conversation => MemoryTier::Warm,
}
}
pub fn is_auto_protected(&self) -> bool {
matches!(self, MemoryType::UserProfile | MemoryType::Preference)
}
pub fn is_global(&self) -> bool {
matches!(self, MemoryType::UserProfile | MemoryType::Preference)
}
pub fn all() -> &'static [MemoryType] {
&[
MemoryType::Conversation,
MemoryType::Session,
MemoryType::Fact,
MemoryType::Episode,
MemoryType::Knowledge,
MemoryType::Skill,
MemoryType::Preference,
MemoryType::Decision,
MemoryType::UserProfile,
]
}
}
/// Storage tier of a memory entry, serialized in `snake_case`.
///
/// Default capacity and token budgets per tier are given by
/// `default_max_entries` and `default_token_budget`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryTier {
/// Smallest tier: 50 entries / 3 000 tokens by default.
Hot,
/// Middle tier: 500 entries / 50 000 tokens by default.
Warm,
/// Largest tier: 10 000 entries by default, with no token limit (`usize::MAX`).
Cold,
}
impl MemoryTier {
pub fn default_max_entries(&self) -> usize {
match self {
MemoryTier::Hot => 50,
MemoryTier::Warm => 500,
MemoryTier::Cold => 10_000,
}
}
pub fn default_token_budget(&self) -> usize {
match self {
MemoryTier::Hot => 3_000,
MemoryTier::Warm => 50_000,
MemoryTier::Cold => usize::MAX,
}
}
}
/// Serde default for `MemoryEntry::tier`: serialized entries without a `tier`
/// field deserialize as `MemoryTier::Warm`.
fn default_tier() -> MemoryTier {
MemoryTier::Warm
}
/// How strongly an entry is shielded from decay.
///
/// Variants are ordered from weakest to strongest (the derived `Ord` follows
/// the explicit discriminants). Serialized in `snake_case`; the default is
/// [`ProtectionLevel::None`].
#[derive(
    Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum ProtectionLevel {
    /// No protection; decay multiplier `1.0`.
    #[default]
    None = 0,
    /// Decay multiplier `0.5`.
    Low = 1,
    /// Decay multiplier `0.2`.
    Medium = 2,
    /// Decay multiplier `0.05`.
    High = 3,
    /// Decay multiplier `0.0`.
    Permanent = 4,
}
impl ProtectionLevel {
pub fn decay_multiplier(&self) -> f32 {
match self {
ProtectionLevel::None => 1.0,
ProtectionLevel::Low => 0.5,
ProtectionLevel::Medium => 0.2,
ProtectionLevel::High => 0.05,
ProtectionLevel::Permanent => 0.0,
}
}
}
/// A single stored memory together with its bookkeeping metadata.
///
/// Most fields carry serde defaults so that serialized entries lacking them
/// still deserialize; the defaults are noted per field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
/// Identifier of the entry (used as the key for de-duplication and curation).
pub id: String,
/// Kind of memory this entry represents.
pub memory_type: MemoryType,
/// Storage tier; defaults to `MemoryTier::Warm` when absent.
#[serde(default = "default_tier")]
pub tier: MemoryTier,
/// Text content of the memory.
pub content: String,
/// Hash of the content; defaults to `0` when absent.
/// NOTE(review): presumably produced by `content_hash(&content)` — confirm at the write site.
#[serde(default)]
pub content_hash: u64,
/// Free-form tags; defaults to empty.
#[serde(default)]
pub tags: Vec<String>,
/// Origin of the entry (exact vocabulary is not visible in this file).
pub source: String,
/// Optional session identifier; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub session_id: Option<String>,
/// Importance weight; defaults to `0.5`. Scaled by access count in
/// `MemoryManager::effective_importance`.
#[serde(default = "default_importance")]
pub importance: f32,
/// Pin flag; defaults to `false`.
/// NOTE(review): presumably exempts the entry from eviction — confirm where it is read.
#[serde(default)]
pub pinned: bool,
/// Protection level; defaults to `ProtectionLevel::None`.
#[serde(default)]
pub protection: ProtectionLevel,
/// Defaults to `false`. Presumably set when the type was assigned automatically — TODO confirm.
#[serde(default)]
pub auto_classified: bool,
/// Defaults to `0`. Presumably the number of sessions this memory appeared in — TODO confirm.
#[serde(default)]
pub session_appearances: u32,
/// Defaults to `false`. Presumably marks entries the user has corrected — TODO confirm.
#[serde(default)]
pub user_corrected: bool,
/// Session identifiers associated with this entry; defaults to empty and is
/// omitted from serialized output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub seen_in_sessions: Vec<String>,
/// Creation timestamp (UTC).
pub created_at: DateTime<Utc>,
/// Last-access timestamp (UTC).
pub accessed_at: DateTime<Utc>,
/// Last-modification timestamp (UTC); defaults to the time of
/// deserialization when absent.
#[serde(default = "default_now")]
pub modified_at: DateTime<Utc>,
/// Access counter; defaults to `0`. Feeds the logarithmic boost in
/// `MemoryManager::effective_importance`.
#[serde(default)]
pub access_count: u32,
/// Decay score; defaults to `0.5` (shares `default_importance`).
#[serde(default = "default_importance")]
pub decay_score: f32,
/// Compaction level; defaults to `0`.
#[serde(default)]
pub compaction_level: u8,
/// Identifiers this entry was compacted from; defaults to empty.
#[serde(default)]
pub compacted_from: Vec<String>,
/// Identifiers of related entries; defaults to empty.
#[serde(default)]
pub related_ids: Vec<String>,
/// Optional identifier (presumably of an entry this one contradicts — TODO
/// confirm); omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub contradicts: Option<String>,
}
/// Serde default (`0.5`) shared by `MemoryEntry::importance` and
/// `MemoryEntry::decay_score`.
fn default_importance() -> f32 {
0.5
}
/// Serde default for `MemoryEntry::modified_at`: the current UTC time at the
/// moment of deserialization.
fn default_now() -> DateTime<Utc> {
Utc::now()
}
/// Central handle for the memory subsystem: owns the backing state store, the
/// in-memory vector index and the optional integrations (git, HNSW, SONA and,
/// behind the `sqlite-memory` feature, SQLite).
pub struct MemoryManager {
/// Shared backing store.
state_store: Arc<StateStore>,
/// Recall limit; `10` by default, overridable through `with_max_recall` /
/// `with_config`.
max_recall: usize,
/// In-memory embedding index keyed by string (the unit tests look entries up
/// by their id).
vector_index: RwLock<HashMap<String, EmbeddingVector>>,
/// Embedding provider; `new` installs `TfIdfEmbeddingProvider`.
embedding: Arc<dyn EmbeddingProvider>,
/// Optional git layer used by `git_commit` when it reports itself enabled.
git_layer: Option<Arc<GitLayer>>,
/// Optional HNSW index, installed via `set_hnsw_index`.
hnsw_index: RwLock<Option<Arc<HnswMemoryIndex>>>,
/// Optional SONA engine, installed via `set_sona_engine`.
sona_engine: Option<Arc<sona::SonaEngine>>,
/// Optional SQLite-backed store (only with the `sqlite-memory` feature).
#[cfg(feature = "sqlite-memory")]
sqlite_store: Option<Arc<crate::memory::sqlite_store::SqliteMemoryStore>>,
}
/// Compact `Debug` output: prints only the recall limit, the current size of
/// the vector index and whether a SONA engine is attached.
impl std::fmt::Debug for MemoryManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let index_size = self.vector_index.read().len();
        let sona_enabled = self.sona_engine.is_some();

        let mut out = f.debug_struct("MemoryManager");
        out.field("max_recall", &self.max_recall);
        out.field("index_size", &index_size);
        out.field("sona_enabled", &sona_enabled);
        out.finish()
    }
}
impl MemoryManager {
pub fn new(state_store: Arc<StateStore>) -> Self {
Self {
state_store,
max_recall: 10,
vector_index: RwLock::new(HashMap::new()),
embedding: Arc::new(TfIdfEmbeddingProvider),
git_layer: None,
hnsw_index: RwLock::new(None),
sona_engine: None,
#[cfg(feature = "sqlite-memory")]
sqlite_store: None,
}
}
pub fn set_git_layer(&mut self, gl: Arc<GitLayer>) {
self.git_layer = Some(gl);
}
#[cfg(feature = "sqlite-memory")]
pub fn set_sqlite_store(&mut self, store: Arc<crate::memory::sqlite_store::SqliteMemoryStore>) {
self.sqlite_store = Some(store);
}
#[cfg(feature = "sqlite-memory")]
pub fn sqlite_store(&self) -> &Option<Arc<crate::memory::sqlite_store::SqliteMemoryStore>> {
&self.sqlite_store
}
pub fn set_sona_engine(&mut self, engine: Arc<sona::SonaEngine>) {
self.sona_engine = Some(engine);
}
pub fn sona_engine(&self) -> Option<&Arc<sona::SonaEngine>> {
self.sona_engine.as_ref()
}
pub fn for_space(space_dir: PathBuf) -> Self {
let memory_dir = space_dir.join("memory");
let state_store = Arc::new(StateStore::new(memory_dir).unwrap_or_else(|_| {
StateStore::new(std::env::temp_dir().join("oxios-memory")).unwrap()
}));
Self::new(state_store)
}
pub fn set_hnsw_index(&self, index: Arc<HnswMemoryIndex>) {
*self.hnsw_index.write() = Some(index);
}
fn git_commit(&self, rel_path: &str, message: &str) {
if let Some(ref gl) = self.git_layer {
if gl.is_enabled() {
let _ = gl.commit_file(rel_path, message);
}
}
}
pub fn with_max_recall(mut self, n: usize) -> Self {
self.max_recall = n;
self
}
pub fn with_config(mut self, config: &crate::config::MemoryConfig) -> Self {
self.max_recall = config.max_recall;
self
}
pub fn vector_index_size(&self) -> usize {
self.vector_index.read().len()
}
pub fn effective_importance(entry: &MemoryEntry) -> f32 {
let access_boost = (1.0_f32 + entry.access_count as f32).ln();
entry.importance * (1.0 + access_boost)
}
pub async fn curate(&self, budget: &MemoryBudget) -> Result<CurationReport> {
let mut report = CurationReport::default();
for mt in &[
MemoryType::Conversation,
MemoryType::Session,
MemoryType::Fact,
MemoryType::Episode,
MemoryType::Knowledge,
] {
let entries = self.list(*mt, budget.max_per_type * 2).await?;
if entries.len() <= budget.max_per_type {
continue;
}
let total_count = entries.len();
let mut scored: Vec<_> = entries
.into_iter()
.map(|e| (e.id.clone(), e.memory_type, Self::effective_importance(&e)))
.collect();
scored.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
let to_remove = scored.len() - budget.max_per_type;
for (id, memory_type, score) in scored.into_iter().take(to_remove) {
report.candidates_for_removal.push(CurationCandidate {
id,
memory_type,
effective_importance: score,
});
}
report.total_before += total_count;
}
for candidate in &report.candidates_for_removal {
if self
.forget(&candidate.id, candidate.memory_type)
.await
.is_ok()
{
report.removed += 1;
}
}
report.total_after = report.total_before - report.removed;
Ok(report)
}
pub fn spawn_curation_task(self: &Arc<Self>, budget: MemoryBudget) {
let mgr = Arc::clone(self);
tokio::spawn(async move {
match mgr.curate(&budget).await {
Ok(report) => {
if report.removed > 0 {
tracing::info!(
removed = report.removed,
candidates = report.candidates_for_removal.len(),
"Memory curation complete"
);
}
}
Err(e) => {
tracing::warn!(error = %e, "Memory curation failed");
}
}
});
}
}
/// Extracts lower-cased search keywords from a free-text `query`.
///
/// The query is split on whitespace. ASCII punctuation is stripped from both
/// ends of every word, so quoted or parenthesised words such as `"rust"` or
/// `(agent)` produce clean keywords; punctuation inside a word is kept. A word
/// is then dropped if it is two bytes or shorter, or if it is one of the
/// English stop words listed below.
///
/// Keywords are returned in query order; duplicates are not removed.
pub(crate) fn extract_keywords(query: &str) -> Vec<String> {
    const STOP_WORDS: &[&str] = &[
        "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had",
        "do", "does", "did", "will", "would", "could", "should", "may", "might", "can", "shall",
        "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into", "through",
        "during", "before", "after", "above", "below", "between", "out", "off", "over", "under",
        "again", "further", "then", "once", "and", "but", "or", "nor", "not", "so", "yet", "both",
        "either", "neither", "each", "every", "all", "any", "few", "more", "most", "other", "some",
        "such", "no", "only", "own", "same", "than", "too", "very", "just", "because", "if",
        "when", "where", "how", "what", "which", "who", "whom", "this", "that", "these", "those",
        "i", "me", "my", "we", "our", "you", "your", "he", "him", "his", "she", "her", "it", "its",
        "they", "them", "their",
    ];

    query
        .split_whitespace()
        .map(|word| {
            // Strip punctuation on both sides, not just trailing.
            word.trim_matches(|c: char| c.is_ascii_punctuation())
                .to_lowercase()
        })
        // Byte length on purpose: short ASCII words are dropped while short
        // multi-byte (e.g. Hangul) words survive.
        .filter(|word| word.len() > 2 && !STOP_WORDS.contains(&word.as_str()))
        .collect()
}
/// Removes entries whose `id` has already been seen earlier in the list.
///
/// The first occurrence of each id is kept and the relative order of the
/// surviving entries is unchanged.
pub(crate) fn dedup_by_id(entries: &mut Vec<MemoryEntry>) {
    let mut seen_ids: std::collections::HashSet<String> =
        std::collections::HashSet::with_capacity(entries.len());
    // `insert` returns `false` for an id that is already present.
    entries.retain(|entry| seen_ids.insert(entry.id.clone()));
}
pub mod auto_classify;
pub mod auto_memory_bridge;
mod auto_protect;
mod budget;
#[cfg(feature = "sqlite-memory")]
pub mod cache;
mod chunking;
mod compaction;
#[cfg(feature = "sqlite-memory")]
pub mod database;
mod decay;
pub mod dream;
pub mod embedding_cache;
pub mod flash_attention;
mod graph;
mod hnsw;
pub mod hyperbolic;
#[cfg(feature = "sqlite-memory")]
pub mod migration;
pub mod normalizer;
mod proactive;
mod root_index;
#[cfg(feature = "sqlite-memory")]
pub mod search;
pub mod sona;
#[cfg(feature = "sqlite-memory")]
pub mod sqlite_store;
pub(crate) mod store;
pub use auto_classify::AutoClassifier;
pub use compaction::CompactionTree;
pub use decay::DecayEngine;
pub use dream::{DreamCheckpoint, DreamProcess, DreamReport};
pub use proactive::ProactiveRecall;
pub use proactive::RecallTiming;
pub use root_index::{HistoricalPeriod, RootEntry, RootIndex, TopicEntry};
pub use embedding_cache::{CacheStats, EmbeddingCache};
pub use store::SemanticHit;
pub use chunking::{chunk_fixed, chunk_paragraphs, ChunkConfig, TextChunk};
pub use graph::MemoryGraph;
pub use hnsw::HnswIndex;
pub use normalizer::{cosine_similarity_f32, l2_normalize_f32, l2_normalize_f64};
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager backed by a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the manager so the directory
    /// stays alive for the duration of the test and is cleaned up afterwards.
    fn temp_manager() -> (tempfile::TempDir, MemoryManager) {
        let temp_dir = tempfile::tempdir().unwrap();
        let store = Arc::new(StateStore::new(temp_dir.path().to_path_buf()).unwrap());
        (temp_dir, MemoryManager::new(store))
    }

    #[test]
    fn test_memory_type_category() {
        assert_eq!(MemoryType::Conversation.category(), "memory/conversations");
        assert_eq!(MemoryType::Fact.category(), "memory/facts");
        assert_eq!(MemoryType::Knowledge.category(), "memory/knowledge");
    }

    #[test]
    fn test_extract_keywords() {
        let kw = extract_keywords("How do I implement a Rust agent system?");
        assert!(kw.contains(&"implement".to_string()));
        assert!(kw.contains(&"rust".to_string()));
        assert!(kw.contains(&"agent".to_string()));
        assert!(kw.contains(&"system".to_string()));
        assert!(!kw.contains(&"how".to_string()));
        assert!(!kw.contains(&"do".to_string()));
    }

    #[test]
    fn test_dedup_by_id() {
        let mut entries = vec![
            make_entry("a", MemoryType::Fact),
            make_entry("b", MemoryType::Fact),
            make_entry("a", MemoryType::Episode),
        ];
        dedup_by_id(&mut entries);
        assert_eq!(entries.len(), 2);
        // The first occurrence of a duplicated id is the one that survives.
        assert_eq!(entries[0].id, "a");
        assert_eq!(entries[0].memory_type, MemoryType::Fact);
        assert_eq!(entries[1].id, "b");
    }

    #[test]
    fn test_blend_into_prompt_empty() {
        // Isolated temp directory instead of a fixed `<tmp>/test` path shared
        // between tests and runs.
        let (_temp_dir, mgr) = temp_manager();
        let result = mgr.blend_into_prompt(&[], "You are an agent.");
        assert_eq!(result, "You are an agent.");
    }

    #[test]
    fn test_blend_into_prompt_with_memories() {
        let (_temp_dir, mgr) = temp_manager();
        let memories = vec![make_entry("test", MemoryType::Fact)];
        let result = mgr.blend_into_prompt(&memories, "You are an agent.");
        assert!(result.contains("## Relevant Memory"));
        assert!(result.contains("[fact]"));
    }

    #[test]
    fn test_text_vector_cosine_similarity() {
        let v1 = TextVector::from_text("fix the null pointer error in main.rs");
        let v2 = TextVector::from_text("null pointer error found in rust code");
        let v3 = TextVector::from_text("update the documentation for deployment");
        assert!(
            v1.cosine_similarity(&v2) > 0.3,
            "Similar texts should have > 0.3 similarity"
        );
        assert!(
            v1.cosine_similarity(&v3) < 0.2,
            "Different texts should have < 0.2 similarity"
        );
    }

    #[test]
    fn test_text_vector_multilingual() {
        let v1 = TextVector::from_text("main.rs 파일의 null pointer 에러 수정");
        let v2 = TextVector::from_text("null pointer 오류를 수정했습니다");
        let v3 = TextVector::from_text("문서 업데이트 배포 가이드");
        assert!(v1.cosine_similarity(&v2) > 0.1, "Mixed script similarity");
        assert!(v1.cosine_similarity(&v3) < 0.1, "Different topics");
    }

    #[test]
    fn test_text_vector_empty() {
        let v1 = TextVector::from_text("");
        let v2 = TextVector::from_text("hello");
        assert_eq!(v1.cosine_similarity(&v2), 0.0);
    }

    #[test]
    fn test_text_vector_identical() {
        let v1 = TextVector::from_text("rust programming language");
        let v2 = TextVector::from_text("rust programming language");
        let sim = v1.cosine_similarity(&v2);
        assert!(
            (sim - 1.0).abs() < 1e-9,
            "Identical texts should have similarity ~1.0, got {}",
            sim
        );
    }

    #[test]
    fn test_tokenize_multilingual() {
        let terms = TextVector::tokenize("main.rs 파일의 버그를 수정");
        assert!(!terms.is_empty(), "Non-ASCII text should produce tokens");
    }

    #[tokio::test]
    async fn test_vector_search_over_keyword_fallback() {
        let temp_dir = tempfile::tempdir().unwrap();
        let store = Arc::new(StateStore::new(temp_dir.path().to_path_buf()).unwrap());
        let mgr = MemoryManager::new(store.clone());
        let entry1 = make_entry_with_content(
            "vec-test-1",
            MemoryType::Fact,
            "Rust is a systems programming language focused on safety",
        );
        let entry2 = make_entry_with_content(
            "vec-test-2",
            MemoryType::Fact,
            "Python is great for machine learning and data science",
        );
        mgr.remember(entry1).await.unwrap();
        mgr.remember(entry2).await.unwrap();
        let results = mgr
            .search("systems programming with rust", None, 5)
            .await
            .unwrap();
        assert!(!results.is_empty(), "Vector search should find results");
        assert_eq!(
            results[0].id, "vec-test-1",
            "Should find the Rust entry first"
        );
    }

    #[tokio::test]
    async fn test_rebuild_index() {
        let temp_dir = tempfile::tempdir().unwrap();
        let store = Arc::new(StateStore::new(temp_dir.path().to_path_buf()).unwrap());
        let mgr = MemoryManager::new(store.clone());
        let entry = make_entry_with_content(
            "rebuild-test-1",
            MemoryType::Fact,
            "memory for rebuild test",
        );
        // Write straight to the store so the in-memory index starts out empty.
        store
            .save_json("memory/facts", "rebuild-test-1", &entry)
            .await
            .unwrap();
        assert_eq!(mgr.vector_index.read().len(), 0);
        mgr.rebuild_index().await.unwrap();
        assert_eq!(mgr.vector_index.read().len(), 1);
        assert!(mgr.vector_index.read().contains_key("rebuild-test-1"));
    }

    /// Entry with placeholder content derived from `id`.
    fn make_entry(id: &str, ty: MemoryType) -> MemoryEntry {
        make_entry_with_content(id, ty, &format!("Test content for {}", id))
    }

    /// Entry with the given content and neutral metadata.
    fn make_entry_with_content(id: &str, ty: MemoryType, content: &str) -> MemoryEntry {
        MemoryEntry {
            id: id.to_string(),
            memory_type: ty,
            tier: MemoryTier::Warm,
            content: content.to_string(),
            content_hash: 0,
            source: "test".to_string(),
            session_id: None,
            tags: vec![],
            importance: 0.5,
            pinned: false,
            protection: ProtectionLevel::None,
            auto_classified: false,
            session_appearances: 0,
            user_corrected: false,
            seen_in_sessions: vec![],
            created_at: Utc::now(),
            accessed_at: Utc::now(),
            modified_at: Utc::now(),
            access_count: 0,
            decay_score: 1.0,
            compaction_level: 0,
            compacted_from: vec![],
            related_ids: vec![],
            contradicts: None,
        }
    }
}