use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use crate::config::MemoryConfig;
use crate::error::Result;
use crate::llm::Message;
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ConversationHistory {
messages: VecDeque<Message>,
max_messages: usize,
}
impl ConversationHistory {
    /// Creates an empty history that retains at most `max_messages` messages.
    ///
    /// The cap is not validated: with `max_messages == 0`, every non-system
    /// message is evicted as soon as it is added.
    pub fn new(max_messages: usize) -> Self {
        let messages = VecDeque::new();
        Self {
            messages,
            max_messages,
        }
    }

    /// Appends `message`, then trims the history back down to the cap.
    ///
    /// Trimming removes the oldest message whose role is not `System`, one at
    /// a time. If only system messages remain, trimming stops, so the history
    /// can exceed the cap in that case.
    pub fn add(&mut self, message: Message) {
        self.messages.push_back(message);

        while self.messages.len() > self.max_messages {
            let oldest_evictable = self
                .messages
                .iter()
                .position(|m| m.role != crate::llm::Role::System);

            match oldest_evictable {
                Some(index) => {
                    self.messages.remove(index);
                }
                // Nothing but system messages left: leave them in place.
                None => break,
            }
        }
    }

    /// Returns a cloned snapshot of all stored messages, oldest first.
    pub fn messages(&self) -> Vec<Message> {
        let mut snapshot = Vec::with_capacity(self.messages.len());
        snapshot.extend(self.messages.iter().cloned());
        snapshot
    }

    /// Removes every stored message, system messages included.
    pub fn clear(&mut self) {
        self.messages.clear();
    }

    /// Returns the number of stored messages.
    pub fn len(&self) -> usize {
        self.messages.len()
    }

    /// Returns `true` when no messages are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Storage backend for conversation memory.
///
/// Implementors must provide the four short-term operations. The long-term
/// operations have default implementations that do nothing, so a backend
/// without long-term storage can leave them alone.
#[async_trait]
pub trait MemoryAdapter: Send + Sync {
    /// Stores `message` in short-term memory.
    async fn store_short_term(&mut self, message: Message) -> Result<()>;

    /// Returns up to `limit` short-term messages that match `query`.
    async fn search_short_term(&self, query: &str, limit: usize) -> Result<Vec<Message>>;

    /// Returns the messages currently held in short-term memory.
    async fn get_short_term(&self) -> Result<Vec<Message>>;

    /// Empties short-term memory.
    async fn clear_short_term(&mut self) -> Result<()>;

    /// Stores `_text` with optional `_metadata` in long-term memory.
    ///
    /// The default implementation ignores both arguments and returns `Ok(())`.
    async fn store_long_term(
        &mut self,
        _text: &str,
        _metadata: Option<serde_json::Value>,
    ) -> Result<()> {
        Ok(())
    }

    /// Searches long-term memory for up to `_limit` entries matching `_query`.
    ///
    /// The default implementation ignores both arguments and returns an empty
    /// list.
    async fn search_long_term(&self, _query: &str, _limit: usize) -> Result<Vec<String>> {
        Ok(Vec::new())
    }
}
/// A [`MemoryAdapter`] that keeps short-term memory in a
/// [`ConversationHistory`] held directly in this struct.
///
/// It does not override the trait's long-term methods, so those keep their
/// no-op defaults.
#[derive(Debug, Clone)]
pub struct InMemoryAdapter {
    /// Bounded message log backing the short-term operations.
    history: ConversationHistory,
}
impl InMemoryAdapter {
    /// Creates an adapter whose history retains at most `max_messages`
    /// messages.
    pub fn new(max_messages: usize) -> Self {
        let history = ConversationHistory::new(max_messages);
        Self { history }
    }
}
impl Default for InMemoryAdapter {
fn default() -> Self {
Self::new(100)
}
}
#[async_trait]
impl MemoryAdapter for InMemoryAdapter {
    /// Appends `message` to the history; eviction is handled by
    /// `ConversationHistory::add`. Always returns `Ok(())`.
    async fn store_short_term(&mut self, message: Message) -> Result<()> {
        self.history.add(message);
        Ok(())
    }

    /// Returns up to `limit` messages whose content contains `query`,
    /// compared case-insensitively, oldest match first.
    ///
    /// An empty `query` matches every message; a `limit` of 0 yields an empty
    /// result.
    async fn search_short_term(&self, query: &str, limit: usize) -> Result<Vec<Message>> {
        let needle = query.to_lowercase();

        // Filter and limit over borrowed messages, cloning only the hits,
        // rather than cloning the whole history up front.
        let hits: Vec<Message> = self
            .history
            .messages
            .iter()
            .filter(|m| m.content.to_lowercase().contains(&needle))
            .take(limit)
            .cloned()
            .collect();

        Ok(hits)
    }

    /// Returns a cloned snapshot of the whole history, oldest first.
    async fn get_short_term(&self) -> Result<Vec<Message>> {
        Ok(self.history.messages())
    }

    /// Removes every message from the history. Always returns `Ok(())`.
    async fn clear_short_term(&mut self) -> Result<()> {
        self.history.clear();
        Ok(())
    }
}
/// Front-end over a boxed [`MemoryAdapter`], paired with the
/// [`MemoryConfig`] it was built from.
pub struct Memory {
// Backend that actually stores and retrieves messages.
adapter: Box<dyn MemoryAdapter>,
// Settings read by this type: `use_short_term` gates `store`, and
// `max_messages` sizes the adapter created by `in_memory`.
config: MemoryConfig,
}
impl Memory {
    /// Wraps `adapter` in a box and pairs it with `config`.
    pub fn new(adapter: impl MemoryAdapter + 'static, config: MemoryConfig) -> Self {
        let adapter: Box<dyn MemoryAdapter> = Box::new(adapter);
        Self { adapter, config }
    }

    /// Builds a memory backed by an [`InMemoryAdapter`] capped at
    /// `config.max_messages`.
    pub fn in_memory(config: MemoryConfig) -> Self {
        let adapter = InMemoryAdapter::new(config.max_messages);
        Self::new(adapter, config)
    }

    /// Builds an in-memory instance from `MemoryConfig::default()`.
    pub fn default_memory() -> Self {
        let config = MemoryConfig::default();
        Self::in_memory(config)
    }

    /// Stores `message` in short-term memory.
    ///
    /// When `config.use_short_term` is false the message is dropped and
    /// `Ok(())` is returned without touching the adapter.
    pub async fn store(&mut self, message: Message) -> Result<()> {
        if !self.config.use_short_term {
            return Ok(());
        }
        self.adapter.store_short_term(message).await
    }

    /// Returns the adapter's short-term messages.
    ///
    /// Unlike `store`, this does not consult `config.use_short_term`.
    pub async fn history(&self) -> Result<Vec<Message>> {
        let messages = self.adapter.get_short_term().await?;
        Ok(messages)
    }

    /// Forwards a short-term search for `query`, capped at `limit` results,
    /// to the adapter.
    pub async fn search(&self, query: &str, limit: usize) -> Result<Vec<Message>> {
        let matches = self.adapter.search_short_term(query, limit).await?;
        Ok(matches)
    }

    /// Clears the adapter's short-term memory.
    pub async fn clear(&mut self) -> Result<()> {
        self.adapter.clear_short_term().await?;
        Ok(())
    }

    /// Returns the configuration this memory was built with.
    pub fn config(&self) -> &MemoryConfig {
        &self.config
    }
}
impl Default for Memory {
fn default() -> Self {
Self::default_memory()
}
}
impl std::fmt::Debug for Memory {
    /// Formats the memory for diagnostics, showing only `config`.
    ///
    /// The boxed adapter is a trait object with no `Debug` bound, so it is
    /// left out. `finish_non_exhaustive` appends `..` to the output so the
    /// omission is visible to the reader.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Memory")
            .field("config", &self.config)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::Role;

    /// Basic bookkeeping: length, emptiness, and clearing.
    #[test]
    fn test_conversation_history() {
        let mut history = ConversationHistory::new(5);
        assert!(history.is_empty());

        history.add(Message::user("Hello"));
        history.add(Message::assistant("Hi there!"));
        assert_eq!(history.len(), 2);
        assert!(!history.is_empty());

        history.clear();
        assert_eq!(history.len(), 0);
        assert!(history.is_empty());
    }

    /// Going over the cap evicts the oldest non-system messages and keeps the
    /// system message at the front.
    #[test]
    fn test_history_trimming() {
        let mut history = ConversationHistory::new(3);
        history.add(Message::system("You are helpful"));
        history.add(Message::user("1"));
        history.add(Message::assistant("1"));
        history.add(Message::user("2"));
        history.add(Message::assistant("2"));

        assert_eq!(history.len(), 3);
        let kept = history.messages();
        assert_eq!(kept[0].role, Role::System);
        // The first exchange ("1") was evicted; only the second survives.
        assert_eq!(kept[1].content, "2");
        assert_eq!(kept[2].content, "2");
    }

    /// System messages are never evicted, even when they alone exceed the cap.
    #[test]
    fn test_system_messages_are_never_evicted() {
        let mut history = ConversationHistory::new(1);
        history.add(Message::system("a"));
        history.add(Message::system("b"));

        assert_eq!(history.len(), 2);
        assert!(history.messages().iter().all(|m| m.role == Role::System));
    }

    /// Store, fetch, search (case-insensitive, limited), and clear through
    /// the adapter interface.
    #[tokio::test]
    async fn test_in_memory_adapter() {
        let mut adapter = InMemoryAdapter::default();
        adapter
            .store_short_term(Message::user("Hello world"))
            .await
            .unwrap();
        adapter
            .store_short_term(Message::assistant("Hi!"))
            .await
            .unwrap();

        let messages = adapter.get_short_term().await.unwrap();
        assert_eq!(messages.len(), 2);

        let search = adapter.search_short_term("world", 10).await.unwrap();
        assert_eq!(search.len(), 1);

        // Matching ignores case.
        let search = adapter.search_short_term("WORLD", 10).await.unwrap();
        assert_eq!(search.len(), 1);

        // Both messages contain "h", but the limit caps the result.
        let search = adapter.search_short_term("h", 1).await.unwrap();
        assert_eq!(search.len(), 1);
        let search = adapter.search_short_term("h", 0).await.unwrap();
        assert!(search.is_empty());

        adapter.clear_short_term().await.unwrap();
        assert!(adapter.get_short_term().await.unwrap().is_empty());
    }
}