use agentlib_core::{
MemoryProvider, MemoryReadOptions, MemoryWriteOptions, ModelMessage, Role, async_trait,
};
use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// In-memory conversation buffer that keeps a bounded list of messages for
/// each session id.
///
/// Nothing is persisted: all data lives in a `HashMap` guarded by an async
/// `tokio::sync::Mutex`.
pub struct BufferMemory {
/// Upper bound on how many messages `write` retains for a single session;
/// when more are supplied, only the trailing `max_messages` are kept.
max_messages: usize,
/// Session id -> stored messages. The map sits behind `Arc<Mutex<..>>`, so
/// every access first awaits the lock.
store: Arc<Mutex<HashMap<String, Vec<ModelMessage>>>>,
}
impl BufferMemory {
    /// Creates an empty buffer that retains at most `max_messages` messages
    /// per session.
    pub fn new(max_messages: usize) -> Self {
        // Start with no sessions; entries are created lazily by `write`.
        let store = Arc::new(Mutex::new(HashMap::new()));
        Self {
            max_messages,
            store,
        }
    }
}
#[async_trait]
impl MemoryProvider for BufferMemory {
    /// Returns a copy of the messages stored for the requested session.
    ///
    /// A missing `session_id` selects the `"default"` session. A session that
    /// has never been written yields an empty vector.
    async fn read(&self, options: MemoryReadOptions) -> Result<Vec<ModelMessage>> {
        let key = match options.session_id {
            Some(id) => id,
            None => "default".to_string(),
        };
        let guard = self.store.lock().await;
        let history = match guard.get(&key) {
            Some(saved) => saved.clone(),
            None => Vec::new(),
        };
        Ok(history)
    }

    /// Replaces the stored messages of a session with `messages`.
    ///
    /// System-role messages are dropped before storing, and only the last
    /// `max_messages` of the remaining ones are kept. A missing `session_id`
    /// selects the `"default"` session.
    async fn write(&self, messages: Vec<ModelMessage>, options: MemoryWriteOptions) -> Result<()> {
        let key = match options.session_id {
            Some(id) => id,
            None => "default".to_string(),
        };

        // Filter and trim before taking the lock so it is held only for the insert.
        let mut kept = messages;
        kept.retain(|msg| msg.role != Role::System);

        // Discard the oldest entries so that at most `max_messages` remain.
        let overflow = kept.len().saturating_sub(self.max_messages);
        if overflow > 0 {
            drop(kept.drain(..overflow));
        }

        self.store.lock().await.insert(key, kept);
        Ok(())
    }

    /// Removes stored messages.
    ///
    /// With `Some(id)` only that session is removed. With `None` every
    /// session is removed — note that, unlike `read` and `write`, `None` does
    /// not mean the `"default"` session here.
    async fn clear(&self, session_id: Option<&str>) -> Result<()> {
        let mut guard = self.store.lock().await;
        match session_id {
            Some(key) => {
                guard.remove(key);
            }
            None => guard.clear(),
        }
        Ok(())
    }
}