use std::collections::HashMap;
use std::sync::Mutex;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Storage backend for [`MemoryEntry`] values: add entries, look them up by a text
/// query, and delete them by id.
///
/// The `Send + Sync` supertraits allow a single backend instance to be shared across
/// threads (e.g. behind an `Arc`).
pub trait Memory: Send + Sync {
/// Adds `entry` to the backend.
///
/// # Errors
/// Returns an error if the backend cannot store the entry.
fn store(&self, entry: MemoryEntry) -> anyhow::Result<()>;
/// Returns up to `limit` entries that match `query`.
///
/// Matching and ranking are backend-specific; [`InMemoryMemory`] uses case-insensitive
/// keyword matching and orders results by descending score.
///
/// # Errors
/// Returns an error if the backend cannot be read.
fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>>;
/// Removes the entry whose `id` field equals `id`.
///
/// # Errors
/// Returns an error if the backend cannot be modified.
fn forget(&self, id: &str) -> anyhow::Result<()>;
}
/// One remembered item: its text, category, free-form metadata, and timestamps.
///
/// Serializable with serde, so entries can be persisted or sent over the wire as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
/// Identifier used by [`Memory::forget`]; [`MemoryEntry::new`] sets it to a random
/// UUID v4 rendered as a string.
pub id: String,
/// The remembered text. [`InMemoryMemory`]'s recall searches this field.
pub content: String,
/// Category of this entry.
pub memory_type: MemoryType,
/// Relevance score. `new` sets it to 0.0; [`InMemoryMemory`]'s recall sets it on the
/// returned copies to the fraction of query words found in `content`, a value in (0, 1].
pub relevance: f32,
/// Arbitrary string key/value annotations (see [`MemoryEntry::with_metadata`]).
pub metadata: HashMap<String, String>,
/// Creation time, in UTC.
pub created_at: DateTime<Utc>,
/// Last-access time, in UTC. `new` sets it equal to `created_at`.
/// NOTE(review): nothing visible in this file updates it afterwards (recall does not
/// touch it) — confirm whether that is intended.
pub last_accessed: DateTime<Utc>,
}
impl MemoryEntry {
    /// Creates an entry holding `content` with the given `memory_type`.
    ///
    /// The new entry gets a random UUID v4 string as its `id`, a `relevance` of 0.0 and
    /// empty metadata. `created_at` and `last_accessed` are both set to the same
    /// current UTC instant.
    pub fn new(content: impl Into<String>, memory_type: MemoryType) -> Self {
        // Read the clock once so both timestamps are exactly equal.
        let timestamp = Utc::now();
        let id = Uuid::new_v4().to_string();
        Self {
            id,
            content: content.into(),
            memory_type,
            relevance: 0.0,
            metadata: HashMap::new(),
            created_at: timestamp,
            last_accessed: timestamp,
        }
    }

    /// Builder-style helper that records `key` → `value` in `metadata` and returns the entry.
    ///
    /// An existing value for the same key is overwritten.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (owned_key, owned_value) = (key.into(), value.into());
        self.metadata.insert(owned_key, owned_value);
        self
    }
}
/// Category attached to every [`MemoryEntry`].
///
/// With serde's default enum representation each variant serializes as its bare name,
/// e.g. `"ShortTerm"`. [`InMemoryMemory`] stores the category but never reads it when
/// recalling or forgetting entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryType {
ShortTerm,
LongTerm,
Entity,
}
/// A process-local [`Memory`] backend that keeps all entries in a `Vec`.
///
/// The `Vec` sits behind a `Mutex` so the `&self` trait methods can mutate it. Recall is a
/// linear scan over every stored entry, so this backend suits small stores and tests.
#[derive(Default)]
pub struct InMemoryMemory {
    /// Stored entries, in insertion order.
    entries: Mutex<Vec<MemoryEntry>>,
}

impl InMemoryMemory {
    /// Creates an empty store (same as [`InMemoryMemory::default`]).
    pub fn new() -> Self {
        Self::default()
    }
}
impl Memory for InMemoryMemory {
fn store(&self, entry: MemoryEntry) -> anyhow::Result<()> {
let mut entries = self.entries.lock().map_err(|e| anyhow::anyhow!("lock poisoned: {e}"))?;
entries.push(entry);
Ok(())
}
fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
let entries = self.entries.lock().map_err(|e| anyhow::anyhow!("lock poisoned: {e}"))?;
let query_words: Vec<String> = query.split_whitespace().map(str::to_lowercase).collect();
if query_words.is_empty() {
return Ok(Vec::new());
}
let mut scored: Vec<(f32, MemoryEntry)> = entries
.iter()
.filter_map(|entry| {
let content_lower = entry.content.to_lowercase();
let matching = query_words.iter().filter(|w| content_lower.contains(w.as_str())).count();
if matching > 0 {
#[allow(clippy::cast_precision_loss)]
let score = matching as f32 / query_words.len() as f32;
let mut recalled = entry.clone();
recalled.relevance = score;
Some((score, recalled))
} else {
None
}
})
.collect();
scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(limit);
Ok(scored.into_iter().map(|(_, entry)| entry).collect())
}
fn forget(&self, id: &str) -> anyhow::Result<()> {
let mut entries = self.entries.lock().map_err(|e| anyhow::anyhow!("lock poisoned: {e}"))?;
entries.retain(|e| e.id != id);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn memory_entry_creation_and_serialization() {
        let entry =
            MemoryEntry::new("test content", MemoryType::ShortTerm).with_metadata("key", "value");
        assert_eq!(entry.content, "test content");
        assert_eq!(entry.memory_type, MemoryType::ShortTerm);
        assert_eq!(entry.metadata.get("key"), Some(&"value".to_string()));
        assert_eq!(entry.relevance, 0.0);
        let json = serde_json::to_string(&entry).expect("serialize");
        let parsed: MemoryEntry = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.id, entry.id);
        assert_eq!(parsed.content, "test content");
        assert_eq!(parsed.memory_type, MemoryType::ShortTerm);
        assert_eq!(parsed.metadata.get("key"), Some(&"value".to_string()));
    }

    #[test]
    fn in_memory_store_and_recall() {
        let mem = InMemoryMemory::new();
        mem.store(MemoryEntry::new("rust programming language", MemoryType::LongTerm))
            .expect("store");
        mem.store(MemoryEntry::new("python data science", MemoryType::LongTerm))
            .expect("store");
        let results = mem.recall("rust", 10).expect("recall");
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("rust"));
    }

    #[test]
    fn recall_keyword_matching_returns_relevant() {
        let mem = InMemoryMemory::new();
        mem.store(MemoryEntry::new(
            "the quick brown fox jumps over the lazy dog",
            MemoryType::ShortTerm,
        ))
        .expect("store");
        mem.store(MemoryEntry::new("hello world program in rust", MemoryType::ShortTerm))
            .expect("store");
        mem.store(MemoryEntry::new("the fox is quick and clever", MemoryType::ShortTerm))
            .expect("store");
        let results = mem.recall("quick fox", 5).expect("recall");
        // Both matching entries contain both words (score 1.0); ranking with distinct
        // scores is covered by `recall_ranks_by_fraction_of_query_words_matched`.
        assert_eq!(results.len(), 2);
        assert!(results[0].relevance >= results[1].relevance);
        assert!(results[0].content.contains("quick"));
    }

    #[test]
    fn recall_ranks_by_fraction_of_query_words_matched() {
        let mem = InMemoryMemory::new();
        mem.store(MemoryEntry::new("rust is fun", MemoryType::LongTerm)).expect("store");
        mem.store(MemoryEntry::new("rust is fast and safe", MemoryType::LongTerm))
            .expect("store");
        mem.store(MemoryEntry::new("python is slow", MemoryType::LongTerm)).expect("store");
        let results = mem.recall("rust fast safe", 10).expect("recall");
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].content, "rust is fast and safe");
        assert!((results[0].relevance - 1.0).abs() < f32::EPSILON);
        assert_eq!(results[1].content, "rust is fun");
        assert!((results[1].relevance - 1.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn recall_respects_limit() {
        let mem = InMemoryMemory::new();
        for text in ["first note", "second note", "third note"] {
            mem.store(MemoryEntry::new(text, MemoryType::ShortTerm)).expect("store");
        }
        assert_eq!(mem.recall("note", 2).expect("recall").len(), 2);
        assert!(mem.recall("note", 0).expect("recall").is_empty());
    }

    #[test]
    fn recall_blank_query_returns_empty() {
        let mem = InMemoryMemory::new();
        mem.store(MemoryEntry::new("anything at all", MemoryType::Entity)).expect("store");
        assert!(mem.recall("", 10).expect("recall").is_empty());
        assert!(mem.recall("   \t ", 10).expect("recall").is_empty());
    }

    #[test]
    fn recall_is_case_insensitive() {
        let mem = InMemoryMemory::new();
        mem.store(MemoryEntry::new("Rust Programming", MemoryType::LongTerm)).expect("store");
        let results = mem.recall("RUST programming", 10).expect("recall");
        assert_eq!(results.len(), 1);
        assert!((results[0].relevance - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn recall_no_matches_returns_empty() {
        let mem = InMemoryMemory::new();
        mem.store(MemoryEntry::new("rust programming", MemoryType::ShortTerm)).expect("store");
        let results = mem.recall("javascript", 10).expect("recall");
        assert!(results.is_empty());
    }

    #[test]
    fn forget_removes_entry() {
        let mem = InMemoryMemory::new();
        let entry = MemoryEntry::new("to be forgotten", MemoryType::ShortTerm);
        let id = entry.id.clone();
        mem.store(entry).expect("store");
        assert_eq!(mem.recall("forgotten", 10).expect("recall").len(), 1);
        mem.forget(&id).expect("forget");
        assert!(mem.recall("forgotten", 10).expect("recall").is_empty());
    }

    #[test]
    fn forget_unknown_id_is_ok_and_keeps_entries() {
        let mem = InMemoryMemory::new();
        mem.store(MemoryEntry::new("keep me", MemoryType::ShortTerm)).expect("store");
        mem.forget("no-such-id").expect("forget");
        assert_eq!(mem.recall("keep", 10).expect("recall").len(), 1);
    }

    #[test]
    fn memory_type_variants_serialize_correctly() {
        let types = [MemoryType::ShortTerm, MemoryType::LongTerm, MemoryType::Entity];
        for mt in &types {
            let json = serde_json::to_string(mt).expect("serialize");
            let parsed: MemoryType = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(*mt, parsed);
        }
        let json = serde_json::to_string(&MemoryType::ShortTerm).expect("serialize");
        assert!(json.contains("ShortTerm"));
        let json = serde_json::to_string(&MemoryType::LongTerm).expect("serialize");
        assert!(json.contains("LongTerm"));
        let json = serde_json::to_string(&MemoryType::Entity).expect("serialize");
        assert!(json.contains("Entity"));
    }
}