use std::collections::VecDeque;
/// One remembered item: its text, a caller-supplied score, and a caller-supplied timestamp.
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryEntry {
    /// Free-form text of the memory.
    pub content: String,
    /// Caller-supplied weight attached to this memory.
    pub score: f64,
    /// Caller-defined timestamp; its units are not interpreted by this type.
    pub timestamp: usize,
}

impl MemoryEntry {
    /// Builds an entry from its parts, taking ownership of `content`.
    pub fn new(content: String, score: f64, timestamp: usize) -> Self {
        MemoryEntry {
            timestamp,
            score,
            content,
        }
    }
}
/// Bounded, insertion-ordered buffer of recent [`MemoryEntry`] values.
///
/// When full, each [`HotStore::push`] evicts the oldest entry to make room.
#[derive(Debug, Clone)]
pub struct HotStore {
    /// Maximum number of entries kept. This field is public and may be changed after
    /// construction; the next `push` evicts as many old entries as needed to honour it.
    pub capacity: usize,
    /// Stored entries, oldest at the front and newest at the back.
    entries: VecDeque<MemoryEntry>,
}

impl HotStore {
    /// Creates an empty store that keeps at most `capacity` entries.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    pub fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "HotStore capacity must be > 0");
        Self {
            capacity,
            entries: VecDeque::with_capacity(capacity),
        }
    }

    /// Appends `entry` as the newest item, first evicting the oldest entries so that the
    /// store holds at most `capacity` items afterwards.
    ///
    /// Because `capacity` is a public field it may have been lowered since the last push,
    /// so eviction loops instead of dropping a single entry. A `capacity` of zero (only
    /// reachable by writing the field, since [`HotStore::new`] rejects it) behaves like a
    /// capacity of one: only the newest entry is kept.
    pub fn push(&mut self, entry: MemoryEntry) {
        let limit = self.capacity.max(1);
        // Evict before inserting so the deque never exceeds `limit`, even transiently.
        while self.entries.len() >= limit {
            self.entries.pop_front();
        }
        self.entries.push_back(entry);
    }

    /// Number of entries currently stored.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Borrows the stored entries, oldest first.
    pub fn entries(&self) -> &VecDeque<MemoryEntry> {
        &self.entries
    }

    /// Returns up to `top_n` entries ranked by how many distinct words they share with
    /// `query`.
    ///
    /// Words are whitespace-separated and compared ASCII-case-insensitively. Entries with
    /// no shared words can still fill the result. Ties keep insertion order (oldest first)
    /// because the sort is stable.
    pub fn relevant(&self, query: &str, top_n: usize) -> Vec<&MemoryEntry> {
        let query_tokens = Self::tokenize(query);
        let mut scored: Vec<(&MemoryEntry, usize)> = self
            .entries
            .iter()
            .map(|e| {
                let overlap = query_tokens.intersection(&Self::tokenize(&e.content)).count();
                (e, overlap)
            })
            .collect();
        // Stable sort: entries with equal overlap stay oldest-first.
        scored.sort_by_key(|&(_, overlap)| std::cmp::Reverse(overlap));
        scored.into_iter().take(top_n).map(|(e, _)| e).collect()
    }

    /// Moves every entry whose `score >= threshold` into `cold`, oldest first, and keeps
    /// the remaining entries here in their original order.
    pub fn drain_to_cold(&mut self, cold: &mut ColdStore, threshold: f64) {
        let mut retain = VecDeque::new();
        while let Some(entry) = self.entries.pop_front() {
            // A NaN score (or NaN threshold) fails `>=`, so such entries stay hot.
            if entry.score >= threshold {
                cold.promote(entry);
            } else {
                retain.push_back(entry);
            }
        }
        self.entries = retain;
    }

    /// Removes all entries. `capacity` is unchanged.
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Returns owned copies of every entry's `content`, newest first.
    pub fn snapshot(&self) -> Vec<String> {
        self.entries
            .iter()
            .rev()
            .map(|e| e.content.clone())
            .collect()
    }

    /// Splits `text` on whitespace and ASCII-lowercases each word, deduplicating the result.
    fn tokenize(text: &str) -> std::collections::HashSet<String> {
        text.split_whitespace()
            .map(|w| w.to_ascii_lowercase())
            .collect()
    }
}
/// Unbounded archive of [`MemoryEntry`] values, kept in promotion order.
#[derive(Debug, Clone, Default)]
pub struct ColdStore {
    /// Entries ordered from first promoted (front) to most recently promoted (back).
    entries: Vec<MemoryEntry>,
}

impl ColdStore {
    /// Appends `entry` as the most recently promoted item.
    pub fn promote(&mut self, entry: MemoryEntry) {
        self.entries.push(entry);
    }

    /// Number of archived entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when nothing has been archived.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Borrows all archived entries in promotion order.
    pub fn all(&self) -> &[MemoryEntry] {
        &self.entries
    }

    /// Returns up to `top_n` entries ranked by a blend of stored score, query overlap and
    /// recency.
    ///
    /// For the entry at position `i` (0 = first promoted) out of `total` entries:
    ///
    /// `blended = (score + 0.1 * overlap) * (0.7 + 0.3 * (i + 1) / total)`
    ///
    /// Here `overlap` is the number of distinct whitespace-separated words shared with
    /// `query`, compared ASCII-case-insensitively. Higher `blended` ranks first. Ties keep
    /// promotion order (the sort is stable). Entries whose blended value is NaN (i.e. a NaN
    /// `score`) are ranked after all others instead of corrupting the sort.
    pub fn recall(&self, query: &str, top_n: usize) -> Vec<&MemoryEntry> {
        let query_tokens = Self::tokenize(query);
        let total = self.entries.len() as f64;
        let mut scored: Vec<(&MemoryEntry, f64)> = self
            .entries
            .iter()
            .enumerate()
            .map(|(i, e)| {
                let overlap =
                    query_tokens.intersection(&Self::tokenize(&e.content)).count() as f64;
                // The first-promoted entry gets 1/total, the most recent gets 1.0.
                let recency = (i + 1) as f64 / total;
                let blended = (e.score + overlap * 0.1) * (0.7 + 0.3 * recency);
                (e, blended)
            })
            .collect();
        scored.sort_by(|a, b| Self::descending_nan_last(a.1, b.1));
        scored.into_iter().take(top_n).map(|(e, _)| e).collect()
    }

    /// Returns owned copies of every entry's `content`, most recently promoted first.
    pub fn snapshot(&self) -> Vec<String> {
        self.entries
            .iter()
            .rev()
            .map(|e| e.content.clone())
            .collect()
    }

    /// Splits `text` on whitespace and ASCII-lowercases each word, deduplicating the result.
    fn tokenize(text: &str) -> std::collections::HashSet<String> {
        text.split_whitespace()
            .map(|w| w.to_ascii_lowercase())
            .collect()
    }

    /// Orders two blended scores highest-first and places NaN after every real number.
    ///
    /// `slice::sort_by` requires its comparator to be a total order, and may panic or
    /// misorder otherwise. Treating NaN as "equal to everything" breaks transitivity, so
    /// NaNs are grouped explicitly instead. Non-NaN values use `partial_cmp`, which keeps
    /// `0.0` and `-0.0` equal.
    fn descending_nan_last(a: f64, b: f64) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (false, false) => b
                .partial_cmp(&a)
                .expect("non-NaN floats are always comparable"),
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
        }
    }
}