use std::collections::HashMap;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::config::Config;
use crate::error::{Result, ZeptoError};
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before 1970, `0` is returned instead of
/// an error.
fn now_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
/// A single record stored in long-term memory.
///
/// [`LongTermMemory`] persists these as the values of a JSON object whose keys
/// are the entries' [`key`](MemoryEntry::key)s.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Lookup key; `LongTermMemory::set` also uses it as the map key.
    pub key: String,
    /// The stored content.
    pub value: String,
    /// Free-form grouping label, matched case-insensitively by
    /// [`LongTermMemory::list_by_category`].
    pub category: String,
    /// When the entry was first created by `LongTermMemory::set`, in seconds
    /// since the Unix epoch.
    pub created_at: u64,
    /// Time of the latest `set` or counted `get`, in seconds since the Unix
    /// epoch.
    pub last_accessed: u64,
    /// How many times the entry has been read via [`LongTermMemory::get`].
    pub access_count: u64,
    /// Free-form labels, also matched by [`LongTermMemory::search`].
    pub tags: Vec<String>,
}
/// An in-memory map of [`MemoryEntry`] records backed by a JSON file on disk.
///
/// The whole map is loaded when the store is opened. Each `save` rewrites the
/// entire file from the in-memory map.
#[derive(Debug)]
pub struct LongTermMemory {
    /// All entries, keyed by their `key`.
    entries: HashMap<String, MemoryEntry>,
    /// JSON file the entries are loaded from and saved to.
    storage_path: PathBuf,
}
impl LongTermMemory {
    /// Opens the store at the default location,
    /// `Config::dir()/memory/longterm.json`.
    ///
    /// # Errors
    ///
    /// Returns an error under the same conditions as
    /// [`with_path`](Self::with_path).
    pub fn new() -> Result<Self> {
        let path = Config::dir().join("memory").join("longterm.json");
        Self::with_path(path)
    }

    /// Opens the store backed by the JSON file at `path`.
    ///
    /// A missing or whitespace-only file yields an empty store. This
    /// constructor itself never writes to disk.
    ///
    /// # Errors
    ///
    /// Returns an error if the file exists but cannot be read, or if its
    /// contents are not valid long-term memory JSON.
    pub fn with_path(path: PathBuf) -> Result<Self> {
        let entries = Self::load(&path)?;
        Ok(Self {
            entries,
            storage_path: path,
        })
    }

    /// Inserts a new entry under `key`, or updates the existing one, then
    /// persists the store.
    ///
    /// Updating an existing key replaces its value, category and tags and
    /// bumps `last_accessed`. It keeps `created_at` and `access_count`.
    ///
    /// # Errors
    ///
    /// Returns an error if saving fails. The in-memory change has already been
    /// applied in that case.
    pub fn set(&mut self, key: &str, value: &str, category: &str, tags: Vec<String>) -> Result<()> {
        let now = now_timestamp();
        if let Some(existing) = self.entries.get_mut(key) {
            existing.value = value.to_string();
            existing.category = category.to_string();
            existing.tags = tags;
            existing.last_accessed = now;
        } else {
            let entry = MemoryEntry {
                key: key.to_string(),
                value: value.to_string(),
                category: category.to_string(),
                created_at: now,
                last_accessed: now,
                access_count: 0,
                tags,
            };
            self.entries.insert(key.to_string(), entry);
        }
        self.save()
    }

    /// Looks up `key` and records the read in the entry's access statistics.
    ///
    /// On a hit, `last_accessed` is set to the current time and `access_count`
    /// is incremented. These updates are held in memory only. They reach disk
    /// with the next successful [`save`](Self::save), for example from `set`,
    /// `delete` or `cleanup_least_used`.
    ///
    /// Use [`get_readonly`](Self::get_readonly) to look up an entry without
    /// touching its statistics.
    pub fn get(&mut self, key: &str) -> Option<&MemoryEntry> {
        let entry = self.entries.get_mut(key)?;
        entry.last_accessed = now_timestamp();
        entry.access_count = entry.access_count.saturating_add(1);
        Some(&*entry)
    }

    /// Looks up `key` without modifying any access statistics.
    pub fn get_readonly(&self, key: &str) -> Option<&MemoryEntry> {
        self.entries.get(key)
    }

    /// Removes the entry stored under `key`.
    ///
    /// Returns whether an entry was removed. The store is saved only when
    /// something was actually removed.
    ///
    /// # Errors
    ///
    /// Returns an error if saving fails after a removal.
    pub fn delete(&mut self, key: &str) -> Result<bool> {
        let existed = self.entries.remove(key).is_some();
        if existed {
            self.save()?;
        }
        Ok(existed)
    }

    /// Returns every entry whose key, value, category or any tag contains
    /// `query`, compared case-insensitively.
    ///
    /// Results are ordered as follows:
    /// 1. exact (case-insensitive) key matches first;
    /// 2. then by descending `access_count`;
    /// 3. then by key, so the order is deterministic.
    ///
    /// An empty `query` matches every entry.
    pub fn search(&self, query: &str) -> Vec<&MemoryEntry> {
        let query_lower = query.to_lowercase();
        let contains_query = |text: &str| text.to_lowercase().contains(&query_lower);
        let mut results: Vec<&MemoryEntry> = self
            .entries
            .values()
            .filter(|entry| {
                contains_query(&entry.key)
                    || contains_query(&entry.value)
                    || contains_query(&entry.category)
                    || entry.tags.iter().any(|tag| contains_query(tag))
            })
            .collect();
        // The cached key lowercases each entry's key once instead of on every
        // comparison. `false < true`, so exact matches sort first. The key
        // itself is the last tie-breaker, replacing arbitrary HashMap order.
        results.sort_by_cached_key(|entry| {
            (
                entry.key.to_lowercase() != query_lower,
                std::cmp::Reverse(entry.access_count),
                entry.key.clone(),
            )
        });
        results
    }

    /// Returns the entries whose category equals `category`
    /// case-insensitively.
    ///
    /// Entries are ordered most recently accessed first, with ties broken by
    /// key.
    pub fn list_by_category(&self, category: &str) -> Vec<&MemoryEntry> {
        let cat_lower = category.to_lowercase();
        let mut results: Vec<&MemoryEntry> = self
            .entries
            .values()
            .filter(|entry| entry.category.to_lowercase() == cat_lower)
            .collect();
        Self::sort_most_recent_first(&mut results);
        results
    }

    /// Returns all entries, most recently accessed first, with ties broken by
    /// key.
    pub fn list_all(&self) -> Vec<&MemoryEntry> {
        let mut results: Vec<&MemoryEntry> = self.entries.values().collect();
        Self::sort_most_recent_first(&mut results);
        results
    }

    /// Sorts `entries` by descending `last_accessed`.
    ///
    /// Ties are broken by key, so entries touched within the same second still
    /// come back in a stable order. Keys are unique, so the order is total.
    fn sort_most_recent_first(entries: &mut [&MemoryEntry]) {
        entries.sort_unstable_by(|a, b| {
            b.last_accessed
                .cmp(&a.last_accessed)
                .then_with(|| a.key.cmp(&b.key))
        });
    }

    /// Returns the number of stored entries.
    pub fn count(&self) -> usize {
        self.entries.len()
    }

    /// Returns the distinct categories in ascending (byte-wise) order.
    ///
    /// Categories are compared exactly here (case-sensitive), unlike in
    /// [`list_by_category`](Self::list_by_category).
    pub fn categories(&self) -> Vec<String> {
        let unique: std::collections::BTreeSet<&str> = self
            .entries
            .values()
            .map(|e| e.category.as_str())
            .collect();
        unique.into_iter().map(str::to_string).collect()
    }

    /// Evicts the least-used entries so that at most `keep_count` remain.
    ///
    /// Returns how many entries were removed.
    ///
    /// Entries are evicted in order of ascending `access_count`. Ties are
    /// broken by the older `last_accessed` first, then by key, so the choice
    /// is deterministic. The store is saved only if something was removed.
    ///
    /// # Errors
    ///
    /// Returns an error if saving fails. The entries have already been removed
    /// from memory in that case.
    pub fn cleanup_least_used(&mut self, keep_count: usize) -> Result<usize> {
        let to_remove = self.entries.len().saturating_sub(keep_count);
        if to_remove == 0 {
            return Ok(0);
        }
        // Tuple order encodes the eviction priority: usage count, then
        // recency, then key.
        let mut ranked: Vec<(u64, u64, String)> = self
            .entries
            .iter()
            .map(|(key, entry)| (entry.access_count, entry.last_accessed, key.clone()))
            .collect();
        ranked.sort_unstable();
        for (_, _, key) in ranked.into_iter().take(to_remove) {
            self.entries.remove(&key);
        }
        self.save()?;
        Ok(to_remove)
    }

    /// Returns a one-line human-readable summary, e.g.
    /// `Long-term memory: 3 entries (2 categories)`.
    pub fn summary(&self) -> String {
        let count = self.count();
        let cat_count = self.categories().len();
        format!(
            "Long-term memory: {} entries ({} categories)",
            count, cat_count
        )
    }

    /// Writes all entries to the storage file as pretty-printed JSON,
    /// creating parent directories as needed.
    ///
    /// The JSON is written and fsynced to a sibling `<file>.tmp`, which is then
    /// renamed over the storage file. An interrupted save therefore leaves the
    /// previous file intact, rather than a truncated one that could no longer
    /// be loaded. This relies on `rename` replacing the target atomically,
    /// which is the case on POSIX filesystems.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the directory cannot be created;
    /// - serialization fails;
    /// - the temporary file cannot be written;
    /// - the temporary file cannot be renamed into place.
    pub fn save(&self) -> Result<()> {
        if let Some(parent) = self.storage_path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                ZeptoError::Config(format!(
                    "Failed to create memory directory {}: {}",
                    parent.display(),
                    e
                ))
            })?;
        }
        let json = serde_json::to_string_pretty(&self.entries).map_err(|e| {
            ZeptoError::Config(format!("Failed to serialize long-term memory: {}", e))
        })?;

        let tmp_path = self.temp_path();
        Self::write_synced(&tmp_path, json.as_bytes()).map_err(|e| {
            // Best effort: don't leave a partially written temp file behind.
            let _ = std::fs::remove_file(&tmp_path);
            ZeptoError::Config(format!(
                "Failed to write long-term memory to {}: {}",
                self.storage_path.display(),
                e
            ))
        })?;
        std::fs::rename(&tmp_path, &self.storage_path).map_err(|e| {
            let _ = std::fs::remove_file(&tmp_path);
            ZeptoError::Config(format!(
                "Failed to replace long-term memory file {}: {}",
                self.storage_path.display(),
                e
            ))
        })?;
        Ok(())
    }

    /// Returns the scratch path used by `save`.
    ///
    /// This is the storage path with `.tmp` appended, e.g.
    /// `longterm.json.tmp`.
    fn temp_path(&self) -> PathBuf {
        let mut name = self.storage_path.clone().into_os_string();
        name.push(".tmp");
        PathBuf::from(name)
    }

    /// Creates (or truncates) `path`, writes `bytes`, and flushes the file to
    /// stable storage before returning.
    fn write_synced(path: &std::path::Path, bytes: &[u8]) -> std::io::Result<()> {
        let mut file = std::fs::File::create(path)?;
        std::io::Write::write_all(&mut file, bytes)?;
        file.sync_all()
    }

    /// Reads the entry map from `path`.
    ///
    /// A file that does not exist, or that contains only whitespace, yields an
    /// empty map. Any other read failure is reported, not treated as "no
    /// memory yet". This prevents an unreadable file from being silently
    /// replaced by an empty one on the next save.
    fn load(path: &std::path::Path) -> Result<HashMap<String, MemoryEntry>> {
        let content = match std::fs::read_to_string(path) {
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(HashMap::new()),
            read => read.map_err(|e| {
                ZeptoError::Config(format!(
                    "Failed to read long-term memory from {}: {}",
                    path.display(),
                    e
                ))
            })?,
        };
        if content.trim().is_empty() {
            return Ok(HashMap::new());
        }
        let entries: HashMap<String, MemoryEntry> =
            serde_json::from_str(&content).map_err(|e| {
                ZeptoError::Config(format!("Failed to parse long-term memory JSON: {}", e))
            })?;
        Ok(entries)
    }
}
// Unit tests for `LongTermMemory`. Each test uses a store whose backing file
// lives inside its own temporary directory.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a store whose backing file sits in a freshly created temp
    /// directory, so it starts out empty.
    ///
    /// The `TempDir` is returned alongside the store. Callers keep it bound
    /// (as `_dir`) so the directory stays alive for the whole test.
    fn temp_memory() -> (LongTermMemory, TempDir) {
        let dir = TempDir::new().expect("failed to create temp dir");
        let path = dir.path().join("longterm.json");
        let mem = LongTermMemory::with_path(path).expect("failed to create memory");
        (mem, dir)
    }

    #[test]
    fn test_memory_entry_creation() {
        let entry = MemoryEntry {
            key: "user:name".to_string(),
            value: "Alice".to_string(),
            category: "user".to_string(),
            created_at: 1000,
            last_accessed: 2000,
            access_count: 5,
            tags: vec!["identity".to_string()],
        };
        assert_eq!(entry.key, "user:name");
        assert_eq!(entry.value, "Alice");
        assert_eq!(entry.category, "user");
        assert_eq!(entry.created_at, 1000);
        assert_eq!(entry.last_accessed, 2000);
        assert_eq!(entry.access_count, 5);
        assert_eq!(entry.tags, vec!["identity"]);
    }

    #[test]
    fn test_longterm_memory_new_empty() {
        let (mem, _dir) = temp_memory();
        assert_eq!(mem.count(), 0);
    }

    #[test]
    fn test_set_and_get() {
        let (mut mem, _dir) = temp_memory();
        mem.set("user:name", "Alice", "user", vec!["identity".to_string()])
            .unwrap();
        let entry = mem.get("user:name").unwrap();
        assert_eq!(entry.value, "Alice");
        assert_eq!(entry.category, "user");
    }

    // A second `set` on the same key must overwrite, not add a new entry.
    #[test]
    fn test_set_upsert() {
        let (mut mem, _dir) = temp_memory();
        mem.set("user:name", "Alice", "user", vec![]).unwrap();
        mem.set("user:name", "Bob", "user", vec!["updated".to_string()])
            .unwrap();
        let entry = mem.get("user:name").unwrap();
        assert_eq!(entry.value, "Bob");
        assert_eq!(entry.tags, vec!["updated"]);
        assert_eq!(mem.count(), 1);
    }

    // Snapshots are taken with `get_readonly` so that reading them does not
    // itself bump the counters being measured.
    #[test]
    fn test_get_updates_access_stats() {
        let (mut mem, _dir) = temp_memory();
        mem.set("key1", "value1", "test", vec![]).unwrap();
        let before_access = mem.get_readonly("key1").unwrap().last_accessed;
        let before_count = mem.get_readonly("key1").unwrap().access_count;
        let _ = mem.get("key1");
        let _ = mem.get("key1");
        let entry = mem.get_readonly("key1").unwrap();
        assert_eq!(entry.access_count, before_count + 2);
        assert!(entry.last_accessed >= before_access);
    }

    #[test]
    fn test_get_readonly_no_update() {
        let (mut mem, _dir) = temp_memory();
        mem.set("key1", "value1", "test", vec![]).unwrap();
        let before = mem.get_readonly("key1").unwrap().access_count;
        let _ = mem.get_readonly("key1");
        let _ = mem.get_readonly("key1");
        let after = mem.get_readonly("key1").unwrap().access_count;
        assert_eq!(before, after);
    }

    #[test]
    fn test_get_nonexistent() {
        let (mut mem, _dir) = temp_memory();
        assert!(mem.get("nonexistent").is_none());
    }

    #[test]
    fn test_delete_existing() {
        let (mut mem, _dir) = temp_memory();
        mem.set("key1", "value1", "test", vec![]).unwrap();
        assert_eq!(mem.count(), 1);
        let existed = mem.delete("key1").unwrap();
        assert!(existed);
        assert_eq!(mem.count(), 0);
        assert!(mem.get("key1").is_none());
    }

    #[test]
    fn test_delete_nonexistent() {
        let (mut mem, _dir) = temp_memory();
        let existed = mem.delete("nonexistent").unwrap();
        assert!(!existed);
    }

    #[test]
    fn test_search_by_key() {
        let (mut mem, _dir) = temp_memory();
        mem.set("user:name", "Alice", "user", vec![]).unwrap();
        mem.set("project:name", "ZeptoClaw", "project", vec![])
            .unwrap();
        let results = mem.search("user");
        assert!(!results.is_empty());
        assert!(results.iter().any(|e| e.key == "user:name"));
    }

    #[test]
    fn test_search_by_value() {
        let (mut mem, _dir) = temp_memory();
        mem.set("key1", "Rust programming language", "fact", vec![])
            .unwrap();
        mem.set("key2", "Python scripting", "fact", vec![]).unwrap();
        let results = mem.search("Rust");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "key1");
    }

    #[test]
    fn test_search_by_tag() {
        let (mut mem, _dir) = temp_memory();
        mem.set(
            "key1",
            "some value",
            "test",
            vec!["important".to_string(), "work".to_string()],
        )
        .unwrap();
        mem.set("key2", "other value", "test", vec!["personal".to_string()])
            .unwrap();
        let results = mem.search("important");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "key1");
    }

    // Exercises case-insensitive matching on every searched field: value,
    // key, tag and category.
    #[test]
    fn test_search_case_insensitive() {
        let (mut mem, _dir) = temp_memory();
        mem.set("Key1", "Hello World", "Test", vec!["MyTag".to_string()])
            .unwrap();
        assert!(!mem.search("hello").is_empty());
        assert!(!mem.search("HELLO").is_empty());
        assert!(!mem.search("key1").is_empty());
        assert!(!mem.search("KEY1").is_empty());
        assert!(!mem.search("mytag").is_empty());
        assert!(!mem.search("test").is_empty());
    }

    #[test]
    fn test_list_by_category() {
        let (mut mem, _dir) = temp_memory();
        mem.set("k1", "v1", "user", vec![]).unwrap();
        mem.set("k2", "v2", "user", vec![]).unwrap();
        mem.set("k3", "v3", "project", vec![]).unwrap();
        let user_entries = mem.list_by_category("user");
        assert_eq!(user_entries.len(), 2);
        assert!(user_entries.iter().all(|e| e.category == "user"));
        let project_entries = mem.list_by_category("project");
        assert_eq!(project_entries.len(), 1);
    }

    #[test]
    fn test_list_all() {
        let (mut mem, _dir) = temp_memory();
        mem.set("k1", "v1", "a", vec![]).unwrap();
        mem.set("k2", "v2", "b", vec![]).unwrap();
        mem.set("k3", "v3", "c", vec![]).unwrap();
        let all = mem.list_all();
        assert_eq!(all.len(), 3);
    }

    #[test]
    fn test_count() {
        let (mut mem, _dir) = temp_memory();
        assert_eq!(mem.count(), 0);
        mem.set("k1", "v1", "test", vec![]).unwrap();
        assert_eq!(mem.count(), 1);
        mem.set("k2", "v2", "test", vec![]).unwrap();
        assert_eq!(mem.count(), 2);
        mem.delete("k1").unwrap();
        assert_eq!(mem.count(), 1);
    }

    // Duplicate categories collapse and the result comes back sorted.
    #[test]
    fn test_categories() {
        let (mut mem, _dir) = temp_memory();
        mem.set("k1", "v1", "user", vec![]).unwrap();
        mem.set("k2", "v2", "fact", vec![]).unwrap();
        mem.set("k3", "v3", "user", vec![]).unwrap();
        mem.set("k4", "v4", "preference", vec![]).unwrap();
        let cats = mem.categories();
        assert_eq!(cats, vec!["fact", "preference", "user"]);
    }

    // k3 is read three times, k1 once and k2 never. Keeping one entry must
    // therefore evict k2 and k1 and retain k3.
    #[test]
    fn test_cleanup_least_used() {
        let (mut mem, _dir) = temp_memory();
        mem.set("k1", "v1", "test", vec![]).unwrap();
        mem.set("k2", "v2", "test", vec![]).unwrap();
        mem.set("k3", "v3", "test", vec![]).unwrap();
        let _ = mem.get("k3");
        let _ = mem.get("k3");
        let _ = mem.get("k3");
        let _ = mem.get("k1");
        let removed = mem.cleanup_least_used(1).unwrap();
        assert_eq!(removed, 2);
        assert_eq!(mem.count(), 1);
        assert!(mem.get_readonly("k3").is_some());
    }

    // The first scope writes through one instance. The second opens a fresh
    // instance on the same path, so its assertions can only pass if the data
    // was persisted to the file.
    #[test]
    fn test_persistence_roundtrip() {
        let dir = TempDir::new().expect("failed to create temp dir");
        let path = dir.path().join("longterm.json");
        {
            let mut mem = LongTermMemory::with_path(path.clone()).unwrap();
            mem.set("user:name", "Alice", "user", vec!["identity".to_string()])
                .unwrap();
            mem.set("fact:lang", "Rust", "fact", vec!["tech".to_string()])
                .unwrap();
        }
        {
            let mem = LongTermMemory::with_path(path).unwrap();
            assert_eq!(mem.count(), 2);
            let entry = mem.get_readonly("user:name").unwrap();
            assert_eq!(entry.value, "Alice");
            assert_eq!(entry.tags, vec!["identity"]);
            let entry2 = mem.get_readonly("fact:lang").unwrap();
            assert_eq!(entry2.value, "Rust");
        }
    }

    #[test]
    fn test_summary() {
        let (mut mem, _dir) = temp_memory();
        assert_eq!(mem.summary(), "Long-term memory: 0 entries (0 categories)");
        mem.set("k1", "v1", "user", vec![]).unwrap();
        mem.set("k2", "v2", "fact", vec![]).unwrap();
        mem.set("k3", "v3", "fact", vec![]).unwrap();
        assert_eq!(mem.summary(), "Long-term memory: 3 entries (2 categories)");
    }
}