use crate::config::StmConfig;
use crate::error::{Error, Result};
use crate::types::{MemoryEntry, MemoryId, MemoryQuery, MemoryResult, MemorySource, Timestamp};
use std::collections::HashMap;
/// Bounded store for recently observed memory entries.
///
/// Capacity is limited both by the number of entries and by the total byte
/// size reported by each entry's `size_bytes()`.
pub struct ShortTermMemory {
    /// Limits and decay parameters this store was built with.
    config: StmConfig,
    /// All live entries, keyed by their id.
    entries: HashMap<MemoryId, MemoryEntry>,
    /// Entry ids ordered from least to most recently stored or accessed.
    access_order: Vec<MemoryId>,
    /// Running total of `size_bytes()` over the stored entries.
    memory_usage: usize,
    /// Construction time, or the time of the most recently applied decay step.
    last_decay: Timestamp,
}
impl ShortTermMemory {
pub fn new(config: StmConfig) -> Self {
Self {
entries: HashMap::new(),
access_order: Vec::new(),
config,
last_decay: Timestamp::now(),
memory_usage: 0,
}
}
pub fn store(&mut self, entry: MemoryEntry) -> Result<MemoryId> {
let id = entry.id.clone();
let entry_size = entry.size_bytes();
if self.entries.len() >= self.config.max_entries {
self.prune_one()?;
}
while self.memory_usage + entry_size > self.config.max_memory_bytes {
if self.entries.is_empty() {
return Err(Error::capacity(
"STM memory",
entry_size,
self.config.max_memory_bytes,
));
}
self.prune_one()?;
}
self.memory_usage += entry_size;
self.entries.insert(id.clone(), entry);
self.access_order.push(id.clone());
Ok(id)
}
pub fn get(&self, id: &MemoryId) -> Result<Option<MemoryEntry>> {
Ok(self.entries.get(id).cloned())
}
pub fn get_and_access(&mut self, id: &MemoryId) -> Result<Option<MemoryEntry>> {
if let Some(entry) = self.entries.get_mut(id) {
entry.metadata.record_access();
if let Some(pos) = self.access_order.iter().position(|x| x == id) {
self.access_order.remove(pos);
self.access_order.push(id.clone());
}
Ok(Some(entry.clone()))
} else {
Ok(None)
}
}
pub fn remove(&mut self, id: &MemoryId) -> Result<()> {
if let Some(entry) = self.entries.remove(id) {
self.memory_usage = self.memory_usage.saturating_sub(entry.size_bytes());
self.access_order.retain(|x| x != id);
}
Ok(())
}
pub fn query(&self, query: &MemoryQuery) -> Result<Vec<MemoryResult>> {
let mut results = Vec::new();
for entry in self.entries.values() {
if !self.matches_query(entry, query) {
continue;
}
let relevance = self.calculate_relevance(entry, query);
results.push(MemoryResult {
entry: entry.clone(),
relevance,
source: MemorySource::ShortTerm,
});
}
results.sort_by(|a, b| {
b.relevance
.partial_cmp(&a.relevance)
.unwrap_or(std::cmp::Ordering::Equal)
});
if let Some(limit) = query.limit {
results.truncate(limit);
}
Ok(results)
}
pub fn get_recent(&self, count: usize) -> Result<Vec<MemoryResult>> {
let mut results: Vec<_> = self.access_order.iter().rev().take(count).collect();
results.reverse();
let mut memory_results = Vec::new();
for id in results {
if let Some(entry) = self.entries.get(id) {
memory_results.push(MemoryResult {
entry: entry.clone(),
relevance: entry.metadata.attention,
source: MemorySource::ShortTerm,
});
}
}
Ok(memory_results)
}
pub fn decay(&mut self) -> Result<()> {
let now = Timestamp::now();
let elapsed_secs = (now.0.saturating_sub(self.last_decay.0)) / 1_000_000;
if elapsed_secs < self.config.decay_interval.as_secs() {
return Ok(());
}
let decay_factor = self.config.decay_factor;
for entry in self.entries.values_mut() {
entry.metadata.decay(decay_factor);
}
self.last_decay = now;
Ok(())
}
pub fn prune(&mut self) -> Result<usize> {
let threshold = self.config.min_attention_threshold;
let to_remove: Vec<MemoryId> = self
.entries
.iter()
.filter(|(_, e)| e.metadata.attention < threshold && !e.metadata.consolidated)
.map(|(id, _)| id.clone())
.collect();
let count = to_remove.len();
for id in to_remove {
self.remove(&id)?;
}
Ok(count)
}
fn prune_one(&mut self) -> Result<()> {
let to_remove = self
.entries
.iter()
.filter(|(_, e)| !e.metadata.consolidated)
.min_by(|(_, a), (_, b)| {
a.metadata
.attention
.partial_cmp(&b.metadata.attention)
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(id, _)| id.clone());
if let Some(id) = to_remove {
self.remove(&id)?;
}
Ok(())
}
pub fn get_consolidation_candidates(&self, importance_threshold: f32) -> Vec<&MemoryEntry> {
self.entries
.values()
.filter(|e| {
e.metadata.importance >= importance_threshold
&& !e.metadata.consolidated
&& e.metadata.access_count >= 2
})
.collect()
}
pub fn mark_consolidated(&mut self, id: &MemoryId) -> Result<()> {
if let Some(entry) = self.entries.get_mut(id) {
entry.metadata.consolidated = true;
}
Ok(())
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn memory_usage(&self) -> usize {
self.memory_usage
}
pub fn clear(&mut self) -> Result<()> {
self.entries.clear();
self.access_order.clear();
self.memory_usage = 0;
Ok(())
}
fn matches_query(&self, entry: &MemoryEntry, query: &MemoryQuery) -> bool {
if let Some(ref entry_type) = query.entry_type {
if &entry.entry_type != entry_type {
return false;
}
}
if let Some(min_importance) = query.min_importance {
if entry.metadata.importance < min_importance {
return false;
}
}
if let Some(after) = query.after {
if entry.metadata.created_at < after {
return false;
}
}
if let Some(before) = query.before {
if entry.metadata.created_at > before {
return false;
}
}
if !query.tags.is_empty() {
let has_tag = query.tags.iter().any(|qt| entry.tags.contains(qt));
if !has_tag {
return false;
}
}
true
}
fn calculate_relevance(&self, entry: &MemoryEntry, query: &MemoryQuery) -> f32 {
let mut score = 0.0;
score += entry.metadata.attention * 0.3;
score += entry.metadata.importance * 0.2;
let age_secs = entry.metadata.created_at.age_secs();
let recency = 1.0 / (1.0 + (age_secs as f32 / 3600.0)); score += recency * 0.2;
if !query.tags.is_empty() {
let matching_tags = query
.tags
.iter()
.filter(|qt| entry.tags.contains(qt))
.count();
let tag_score = matching_tags as f32 / query.tags.len() as f32;
score += tag_score * 0.15;
}
if let (Some(ref query_emb), Some(ref entry_emb)) = (&query.embedding, &entry.embedding) {
let similarity = query_emb.cosine_similarity(entry_emb);
score += similarity * 0.15;
}
if let Some(ref text) = query.text {
let text_lower = text.to_lowercase();
let data_str = entry.data.to_string().to_lowercase();
let type_str = entry.entry_type.to_lowercase();
if data_str.contains(&text_lower) || type_str.contains(&text_lower) {
score += 0.2;
}
}
score.min(1.0)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `"test"`-typed entry whose payload stores `label` under `"name"`.
    fn make_entry(label: &str) -> MemoryEntry {
        let payload = serde_json::json!({"name": label});
        MemoryEntry::new("test", payload)
    }

    /// A fresh store using the default configuration.
    fn default_stm() -> ShortTermMemory {
        ShortTermMemory::new(StmConfig::default())
    }

    #[test]
    fn test_store_retrieve() {
        let mut stm = default_stm();
        let stored_id = stm.store(make_entry("test1")).unwrap();
        assert!(stm.get(&stored_id).unwrap().is_some());
    }

    #[test]
    fn test_capacity_limit() {
        let mut stm = ShortTermMemory::new(StmConfig {
            max_entries: 2,
            ..Default::default()
        });
        for label in ["test1", "test2", "test3"] {
            stm.store(make_entry(label)).unwrap();
        }
        assert!(stm.len() <= 2);
    }

    #[test]
    fn test_query_by_type() {
        let mut stm = default_stm();
        for kind in ["sensor", "event"] {
            stm.store(MemoryEntry::new(kind, serde_json::json!({})))
                .unwrap();
        }
        let hits = stm.query(&MemoryQuery::entry_type("sensor")).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entry.entry_type, "sensor");
    }

    #[test]
    fn test_decay() {
        let mut stm = ShortTermMemory::new(StmConfig {
            decay_interval: std::time::Duration::from_secs(0),
            decay_factor: 0.5,
            ..Default::default()
        });
        let id = stm.store(make_entry("test1")).unwrap();
        let attention_of =
            |store: &ShortTermMemory| store.get(&id).unwrap().unwrap().metadata.attention;
        let before = attention_of(&stm);
        stm.decay().unwrap();
        let after = attention_of(&stm);
        assert!(after < before);
    }

    #[test]
    fn test_get_recent() {
        let mut stm = default_stm();
        ["test1", "test2", "test3"].iter().for_each(|label| {
            stm.store(make_entry(label)).unwrap();
        });
        assert_eq!(stm.get_recent(2).unwrap().len(), 2);
    }
}