use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use tokio::sync::Mutex;
use cognis_core::error::Result;
use cognis_core::messages::{get_buffer_string, Message};
use super::BaseMemory;
/// Sliding-window conversation memory: retains only the most recent `k`
/// human/AI turns (2·k messages) and exposes them under `memory_key`.
pub struct ConversationWindowMemory {
// Shared, async-guarded message buffer; oldest entries are evicted on save.
messages: Arc<Mutex<Vec<Message>>>,
// Number of conversation turns to keep (each turn = one human + one AI message).
k: usize,
// Key under which the history is published by `load_memory_variables`.
memory_key: String,
// true: expose history as a JSON array of messages; false: as one transcript string.
return_messages: bool,
}
impl ConversationWindowMemory {
/// Creates a window memory that keeps the last `k` conversation turns
/// (each turn is one human message plus one AI message).
///
/// Defaults: memory key `"history"`, structured messages returned
/// (`return_messages = true`).
pub fn new(k: usize) -> Self {
Self {
messages: Arc::new(Mutex::new(Vec::new())),
k,
memory_key: "history".to_string(),
return_messages: true,
}
}
/// Builder-style setter for the key under which history is exposed.
pub fn with_memory_key(mut self, key: impl Into<String>) -> Self {
self.memory_key = key.into();
self
}
/// Builder-style setter: `true` exposes serialized message values,
/// `false` exposes a single formatted transcript string.
pub fn with_return_messages(mut self, return_messages: bool) -> Self {
self.return_messages = return_messages;
self
}
/// Drops the oldest messages so at most `k` turns (2·k messages) remain.
fn trim(messages: &mut Vec<Message>, k: usize) {
// saturating_mul: a pathological `k` near usize::MAX must not overflow
// (which would panic in debug builds); saturation simply means "keep all".
let max_messages = k.saturating_mul(2);
if messages.len() > max_messages {
let drain_count = messages.len() - max_messages;
messages.drain(..drain_count);
}
}
}
#[async_trait]
impl BaseMemory for ConversationWindowMemory {
/// Publishes the buffered window under `memory_key`, either as a JSON
/// array of serialized messages or as one formatted transcript string.
async fn load_memory_variables(&self) -> Result<HashMap<String, Value>> {
let guard = self.messages.lock().await;
let history = if self.return_messages {
// A message that fails to serialize degrades to JSON null rather
// than failing the entire load.
Value::Array(
guard
.iter()
.map(|msg| serde_json::to_value(msg).unwrap_or(Value::Null))
.collect(),
)
} else {
Value::String(get_buffer_string(&guard, "Human", "AI"))
};
Ok(HashMap::from([(self.memory_key.clone(), history)]))
}
/// Appends one human/AI exchange, then evicts anything beyond `k` turns.
async fn save_context(&self, input: &Message, output: &Message) -> Result<()> {
let mut guard = self.messages.lock().await;
guard.extend([input.clone(), output.clone()]);
Self::trim(&mut guard, self.k);
Ok(())
}
/// Discards every buffered message.
async fn clear(&self) -> Result<()> {
self.messages.lock().await.clear();
Ok(())
}
/// Key under which `load_memory_variables` publishes the history.
fn memory_key(&self) -> &str {
&self.memory_key
}
}
#[cfg(test)]
mod tests {
use super::*;
use cognis_core::messages::Message;
/// Saving three turns into a k=2 window must evict the oldest pair,
/// leaving the last two human/AI exchanges in order.
#[tokio::test]
async fn test_window_keeps_k_turns() {
let memory = ConversationWindowMemory::new(2);
memory
.save_context(&Message::human("Turn 1"), &Message::ai("Response 1"))
.await
.unwrap();
memory
.save_context(&Message::human("Turn 2"), &Message::ai("Response 2"))
.await
.unwrap();
memory
.save_context(&Message::human("Turn 3"), &Message::ai("Response 3"))
.await
.unwrap();
let vars = memory.load_memory_variables().await.unwrap();
assert_eq!(vars["history"].as_array().unwrap().len(), 4);
let stored = memory.messages.lock().await;
let expected = ["Turn 2", "Response 2", "Turn 3", "Response 3"];
for (msg, want) in stored.iter().zip(expected) {
assert_eq!(msg.content().text(), want);
}
}
/// Fewer turns than the window size: nothing is evicted.
#[tokio::test]
async fn test_window_under_limit() {
let memory = ConversationWindowMemory::new(5);
memory
.save_context(&Message::human("Hi"), &Message::ai("Hello"))
.await
.unwrap();
memory
.save_context(&Message::human("How?"), &Message::ai("Fine"))
.await
.unwrap();
let vars = memory.load_memory_variables().await.unwrap();
assert_eq!(vars["history"].as_array().unwrap().len(), 4);
}
}