use crate::message::Message;
use crate::error::Result;
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Asynchronous storage for a sequence of [`Message`]s.
///
/// Methods are declared `async` through `async_trait` and return the crate's
/// `Result` type. The `Send + Sync` supertrait bound means any implementor can
/// be moved to and shared between threads.
///
/// `add` and `clear` take `&mut self`, while the read methods take `&self`.
#[async_trait]
pub trait Memory: Send + Sync {
    /// Adds `message` to the store.
    ///
    /// `MemoryStore` appends it and, when bounded, evicts the oldest entries.
    async fn add(&mut self, message: Message) -> Result<()>;
    /// Returns all stored messages as an owned `Vec`.
    async fn get_all(&self) -> Result<Vec<Message>>;
    /// Returns up to `count` of the most recently added messages.
    ///
    /// NOTE(review): this mirrors what `MemoryStore` does (oldest-to-newest
    /// order, fewer than `count` if the store is smaller); confirm that other
    /// implementors honour the same contract.
    async fn get_recent(&self, count: usize) -> Result<Vec<Message>>;
    /// Removes every stored message.
    async fn clear(&mut self) -> Result<()>;
}
/// In-memory [`Memory`] implementation backed by a `Vec<Message>`, optionally
/// capped at a maximum number of messages.
///
/// The buffer sits behind `Arc<RwLock<..>>`. Cloning a `MemoryStore` with the
/// derived `Clone` therefore produces a second handle to the *same* message
/// buffer, not an independent copy. Only `max_size` is copied by value.
#[derive(Debug, Clone)]
pub struct MemoryStore {
    /// Stored messages in insertion order, oldest first. `add` pushes to the
    /// back and evicts from the front.
    messages: Arc<RwLock<Vec<Message>>>,
    /// Maximum number of retained messages; `None` means unbounded. When `add`
    /// pushes the length past this value, the oldest messages are discarded.
    max_size: Option<usize>,
}
impl MemoryStore {
pub fn new() -> Self {
Self {
messages: Arc::new(RwLock::new(Vec::new())),
max_size: None,
}
}
pub fn with_max_size(max_size: usize) -> Self {
Self {
messages: Arc::new(RwLock::new(Vec::new())),
max_size: Some(max_size),
}
}
}
impl Default for MemoryStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Memory for MemoryStore {
    /// Appends `message`. If the store is bounded and now holds more than
    /// `max_size` entries, it discards the oldest ones until exactly
    /// `max_size` remain.
    ///
    /// Never returns an error.
    async fn add(&mut self, message: Message) -> Result<()> {
        let mut guard = self.messages.write().await;
        guard.push(message);
        // An unbounded store has no overflow. A bounded store overflows by
        // however many entries exceed the limit, which may be zero.
        let overflow = self
            .max_size
            .map_or(0, |limit| guard.len().saturating_sub(limit));
        if overflow > 0 {
            // Evict from the front so the newest messages survive.
            // NOTE(review): draining the head of a `Vec` shifts every remaining
            // element, so a full bounded store pays O(max_size) per `add`; a
            // `VecDeque` would make eviction O(1).
            guard.drain(..overflow);
        }
        Ok(())
    }

    /// Returns a snapshot copy of every stored message, oldest first.
    ///
    /// Never returns an error.
    async fn get_all(&self) -> Result<Vec<Message>> {
        let guard = self.messages.read().await;
        Ok(guard.to_vec())
    }

    /// Returns copies of the last `count` stored messages, oldest first.
    /// If fewer than `count` are stored, all of them are returned.
    ///
    /// Never returns an error.
    async fn get_recent(&self, count: usize) -> Result<Vec<Message>> {
        let guard = self.messages.read().await;
        // Number of leading (older) entries to leave out; saturates to zero
        // when `count` exceeds the current length.
        let skip = guard.len().saturating_sub(count);
        Ok(guard.iter().skip(skip).cloned().collect())
    }

    /// Drops every stored message. The bounded/unbounded setting is unchanged.
    ///
    /// Never returns an error.
    async fn clear(&mut self) -> Result<()> {
        self.messages.write().await.clear();
        Ok(())
    }
}