use crate::memory_system::{DecayPolicy, MemoryStore, TimeBasedDecay};
use crate::types::{Layer3Result, MemoryEntry, MemoryQuery, MemoryTier};
use async_trait::async_trait;
use parking_lot::RwLock;
use std::collections::VecDeque;
use std::sync::Arc;
#[allow(dead_code)]
pub struct WorkingMemory {
buffer: Arc<RwLock<VecDeque<MemoryEntry>>>,
max_size: usize,
#[allow(dead_code)]
decay_policy: Box<dyn DecayPolicy>,
}
impl WorkingMemory {
pub fn new(max_size: usize) -> Self {
Self {
buffer: Arc::new(RwLock::new(VecDeque::with_capacity(max_size))),
max_size,
decay_policy: Box::new(TimeBasedDecay::default()),
}
}
}
impl Default for WorkingMemory {
fn default() -> Self {
Self::new(100)
}
}
#[async_trait]
impl MemoryStore for WorkingMemory {
fn tier(&self) -> MemoryTier {
MemoryTier::Working
}
async fn store(&self, entry: MemoryEntry) -> Layer3Result<String> {
let mut buffer = self.buffer.write();
if buffer.len() >= self.max_size {
buffer.pop_front();
}
let id = entry.id.clone();
buffer.push_back(entry);
Ok(id)
}
async fn get(&self, id: &str) -> Layer3Result<Option<MemoryEntry>> {
let buffer = self.buffer.read();
Ok(buffer.iter().find(|e| e.id == id).cloned())
}
async fn delete(&self, id: &str) -> Layer3Result<bool> {
let mut buffer = self.buffer.write();
let len_before = buffer.len();
buffer.retain(|e| e.id != id);
Ok(buffer.len() < len_before)
}
async fn query(&self, query: &MemoryQuery) -> Layer3Result<Vec<MemoryEntry>> {
let buffer = self.buffer.read();
let results: Vec<MemoryEntry> = buffer
.iter()
.filter(|e| {
if let Some(tier) = query.tier {
if e.tier != tier {
return false;
}
}
e.content.contains(&query.query)
})
.take(query.limit.unwrap_or(10))
.cloned()
.collect();
Ok(results)
}
async fn list(&self, limit: Option<usize>) -> Layer3Result<Vec<MemoryEntry>> {
let buffer = self.buffer.read();
Ok(buffer
.iter()
.take(limit.unwrap_or(usize::MAX))
.cloned()
.collect())
}
async fn clear(&self) -> Layer3Result<usize> {
let mut buffer = self.buffer.write();
let count = buffer.len();
buffer.clear();
Ok(count)
}
async fn count(&self) -> Layer3Result<usize> {
Ok(self.buffer.read().len())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_working_memory_store() {
let memory = WorkingMemory::new(10);
let entry = MemoryEntry {
id: "test-1".to_string(),
tier: MemoryTier::Working,
content: "test content".to_string(),
metadata: Default::default(),
created_at: chrono::Utc::now(),
last_accessed: chrono::Utc::now(),
access_count: 0,
importance: 0.5,
};
memory.store(entry).await.unwrap();
assert_eq!(memory.count().await.unwrap(), 1);
}
}