use std::collections::VecDeque;
use std::sync::Arc;
use async_trait::async_trait;
use rucora_core::{
error::MemoryError,
memory::{Memory, MemoryItem, MemoryQuery},
};
use tokio::sync::RwLock;
const DEFAULT_MAX_CAPACITY: usize = 1000;
#[derive(Default)]
pub struct InMemoryMemory {
items: Arc<RwLock<VecDeque<MemoryItem>>>,
max_capacity: usize,
}
impl InMemoryMemory {
    /// Upper limit on the number of slots pre-allocated in the backing queue.
    const INITIAL_ALLOCATION: usize = 64;

    /// Creates an empty memory that retains at most `DEFAULT_MAX_CAPACITY` items.
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_MAX_CAPACITY)
    }

    /// Creates an empty memory that retains at most `max_capacity` items.
    ///
    /// Unlike `Vec::with_capacity`, this sets an eviction *limit*, not merely an allocation
    /// hint: once more than `max_capacity` items are stored, the oldest are dropped.
    /// Passing `0` disables eviction entirely (unbounded growth).
    pub fn with_capacity(max_capacity: usize) -> Self {
        // Pre-allocate a modest buffer, but not more than the bound can ever use. The queue
        // briefly holds `max_capacity + 1` items between `push_back` and eviction in `add`.
        let initial = if max_capacity == 0 {
            Self::INITIAL_ALLOCATION
        } else {
            Self::INITIAL_ALLOCATION.min(max_capacity.saturating_add(1))
        };
        Self {
            items: Arc::new(RwLock::new(VecDeque::with_capacity(initial))),
            max_capacity,
        }
    }

    /// Returns the number of items currently stored.
    pub async fn len(&self) -> usize {
        self.items.read().await.len()
    }

    /// Returns `true` if no items are stored.
    pub async fn is_empty(&self) -> bool {
        self.items.read().await.is_empty()
    }

    /// Removes all stored items. The capacity bound is unchanged.
    pub async fn clear(&self) {
        self.items.write().await.clear();
    }

    /// Evicts items from the front (oldest first) until `items.len() <= max`.
    ///
    /// A `max` of `0` means "unbounded" and leaves `items` untouched.
    fn enforce_capacity(items: &mut VecDeque<MemoryItem>, max: usize) {
        if max == 0 {
            return;
        }
        while items.len() > max {
            items.pop_front();
        }
    }
}
#[async_trait]
impl Memory for InMemoryMemory {
    /// Inserts `item`, or replaces the stored item that has the same `id` (upsert).
    ///
    /// A replacement is written in place and keeps the original item's queue position, so
    /// updating an item does not refresh its eviction order.
    /// NOTE(review): confirm that "update does not count as recent" is intended.
    /// A new item is appended at the back and may evict the oldest items when `max_capacity`
    /// is exceeded. This backend never returns an error.
    async fn add(&self, item: MemoryItem) -> Result<(), MemoryError> {
        let mut items = self.items.write().await;
        if let Some(existing) = items.iter_mut().find(|x| x.id == item.id) {
            *existing = item;
            return Ok(());
        }
        items.push_back(item);
        Self::enforce_capacity(&mut items, self.max_capacity);
        Ok(())
    }

    /// Returns stored items matching `query`. This backend never returns an error.
    ///
    /// * `query.limit == 0` means "no limit".
    /// * An empty `query.text` returns items most-recently-appended first.
    /// * Otherwise, returns items whose id, content, or metadata `to_string()` rendering
    ///   contains `query.text` case-insensitively, ordered by `id` ascending.
    async fn query(&self, query: MemoryQuery) -> Result<Vec<MemoryItem>, MemoryError> {
        let items = self.items.read().await;
        let limit = if query.limit == 0 {
            usize::MAX
        } else {
            query.limit
        };
        let needle = query.text.to_lowercase();

        if needle.is_empty() {
            return Ok(items.iter().rev().take(limit).cloned().collect());
        }

        // Sort borrowed matches and clone only those that survive `limit`, rather than
        // cloning every match up front.
        let mut matches: Vec<&MemoryItem> = items
            .iter()
            .filter(|item| item_matches(item, &needle))
            .collect();
        matches.sort_by(|a, b| a.id.cmp(&b.id));
        Ok(matches.into_iter().take(limit).cloned().collect())
    }
}

/// Returns `true` if `needle` occurs in the item's id, content, or metadata.
///
/// The comparison lowercases each field first, so `needle` must already be lowercased.
/// Metadata is matched through its `to_string()` rendering. The needle can therefore also hit
/// formatting characters of that rendering (e.g. quotes or braces if it is JSON).
fn item_matches(item: &MemoryItem, needle: &str) -> bool {
    if item.id.to_lowercase().contains(needle) || item.content.to_lowercase().contains(needle) {
        return true;
    }
    match &item.metadata {
        Some(meta) => meta.to_string().to_lowercase().contains(needle),
        None => false,
    }
}