use async_trait::async_trait;
use std::collections::HashMap;
use serde_json::Value;
use super::base::{BaseMemory, MemoryError, ChatMessageHistory};
use crate::schema::Message;
/// Conversation memory that exposes only the most recent part of the chat
/// history.
///
/// Every saved message stays in `chat_memory`; only what is read back is
/// limited, to the last `k * 2` messages.
#[derive(Debug)]
pub struct ConversationBufferWindowMemory {
/// Backing store holding every message saved so far.
chat_memory: ChatMessageHistory,
/// Window size; the window holds at most `k * 2` messages.
k: usize,
/// Key looked up in the `inputs` map when saving context (default `"input"`).
input_key: String,
/// Key looked up in the `outputs` map when saving context (default `"output"`).
output_key: String,
/// Key under which the windowed history is returned (default `"history"`).
memory_key: String,
/// When `true`, the history is returned as a JSON array of serialized
/// messages; otherwise as a single formatted string (default `false`).
return_messages: bool,
}
impl ConversationBufferWindowMemory {
    /// Creates a memory whose window covers at most the last `k * 2` messages.
    ///
    /// Defaults: input key `"input"`, output key `"output"`, memory key
    /// `"history"`, and string (not message-list) output.
    pub fn new(k: usize) -> Self {
        Self {
            chat_memory: ChatMessageHistory::new(),
            k,
            input_key: "input".to_string(),
            output_key: "output".to_string(),
            memory_key: "history".to_string(),
            return_messages: false,
        }
    }

    /// Sets the key read from the `inputs` map when saving context.
    pub fn with_input_key(mut self, key: String) -> Self {
        self.input_key = key;
        self
    }

    /// Sets the key read from the `outputs` map when saving context.
    pub fn with_output_key(mut self, key: String) -> Self {
        self.output_key = key;
        self
    }

    /// Sets the key under which the windowed history is returned.
    pub fn with_memory_key(mut self, key: String) -> Self {
        self.memory_key = key;
        self
    }

    /// Chooses between a JSON array of messages (`true`) and a single
    /// formatted string (`false`) when loading memory variables.
    pub fn with_return_messages(mut self, return_messages: bool) -> Self {
        self.return_messages = return_messages;
        self
    }

    /// Returns the full, un-windowed message history.
    pub fn chat_memory(&self) -> &ChatMessageHistory {
        &self.chat_memory
    }

    /// Returns the configured window size.
    pub fn k(&self) -> usize {
        self.k
    }

    /// Returns a copy of the last `k * 2` stored messages (or all of them when
    /// fewer are stored).
    ///
    /// The window is counted in messages, not turns, so it only lines up with
    /// turn boundaries when every saved turn contributed both sides.
    fn get_window_messages(&self) -> Vec<Message> {
        let messages = self.chat_memory.messages();
        // Saturate instead of multiplying unchecked: a very large `k` (for
        // example `usize::MAX` as "no limit") must mean "everything", not an
        // overflow panic or a wrapped-around tiny window.
        let max_messages = self.k.saturating_mul(2);
        // Start index is 0 whenever the history is shorter than the window.
        let start = messages.len().saturating_sub(max_messages);
        messages[start..].to_vec()
    }

    /// Renders the windowed messages as `"<Role>: <content>"` lines joined by
    /// `'\n'`, with no trailing newline.
    fn buffer_as_string(&self) -> String {
        self.get_window_messages()
            .iter()
            .map(|msg| {
                let role = match msg.message_type {
                    crate::schema::MessageType::Human => "Human",
                    crate::schema::MessageType::AI => "AI",
                    crate::schema::MessageType::System => "System",
                    crate::schema::MessageType::Tool { .. } => "Tool",
                };
                format!("{}: {}", role, msg.content)
            })
            .collect::<Vec<_>>()
            .join("\n")
    }
}
impl Default for ConversationBufferWindowMemory {
    /// Builds a memory with a window size of 5 and all other settings as
    /// chosen by [`ConversationBufferWindowMemory::new`].
    fn default() -> Self {
        // Window size used when the caller does not pick one.
        const DEFAULT_WINDOW: usize = 5;
        Self::new(DEFAULT_WINDOW)
    }
}
#[async_trait]
impl BaseMemory for ConversationBufferWindowMemory {
    /// Names the single variable this memory provides: the configured memory
    /// key.
    fn memory_variables(&self) -> Vec<&str> {
        vec![self.memory_key.as_str()]
    }

    /// Returns a one-entry map from the memory key to the windowed history.
    ///
    /// The value is a JSON array of serialized messages when
    /// `return_messages` is set, otherwise one formatted string. `_inputs` is
    /// not consulted.
    async fn load_memory_variables(
        &self,
        _inputs: &HashMap<String, String>,
    ) -> Result<HashMap<String, Value>, MemoryError> {
        let history = if self.return_messages {
            let serialized: Vec<Value> = self
                .get_window_messages()
                .iter()
                // A message that fails to serialize becomes JSON `null`
                // rather than failing the whole load.
                .map(|message| serde_json::to_value(message).unwrap_or(Value::Null))
                .collect();
            Value::Array(serialized)
        } else {
            Value::String(self.buffer_as_string())
        };

        Ok(HashMap::from([(self.memory_key.clone(), history)]))
    }

    /// Appends the user input and the AI output of one exchange to the
    /// history.
    ///
    /// A side whose key is missing from its map is skipped silently; this
    /// method always returns `Ok(())`.
    async fn save_context(
        &mut self,
        inputs: &HashMap<String, String>,
        outputs: &HashMap<String, String>,
    ) -> Result<(), MemoryError> {
        let human_text = inputs.get(&self.input_key);
        if let Some(text) = human_text {
            self.chat_memory.add_user_message(text);
        }

        let ai_text = outputs.get(&self.output_key);
        if let Some(text) = ai_text {
            self.chat_memory.add_ai_message(text);
        }

        Ok(())
    }

    /// Clears the underlying chat history.
    async fn clear(&mut self) -> Result<(), MemoryError> {
        self.chat_memory.clear();
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Saves one question/answer exchange under the default `"input"` /
    /// `"output"` keys.
    async fn save_turn(
        memory: &mut ConversationBufferWindowMemory,
        question: String,
        answer: String,
    ) {
        let inputs = HashMap::from([("input".to_string(), question)]);
        let outputs = HashMap::from([("output".to_string(), answer)]);
        memory.save_context(&inputs, &outputs).await.unwrap();
    }

    /// Loads the memory variables and returns the `"history"` entry as an
    /// owned string.
    async fn load_history(memory: &ConversationBufferWindowMemory) -> String {
        let memory_vars = memory.load_memory_variables(&HashMap::new()).await.unwrap();
        memory_vars
            .get("history")
            .unwrap()
            .as_str()
            .unwrap()
            .to_owned()
    }

    /// With `k = 2` and three exchanges saved, everything is stored but only
    /// the last two exchanges are visible.
    #[tokio::test]
    async fn test_window_memory() {
        let mut memory = ConversationBufferWindowMemory::new(2);
        for i in 1..=3 {
            save_turn(&mut memory, format!("问题{}", i), format!("答案{}", i)).await;
        }

        // Storage is not truncated; only the view is windowed.
        assert_eq!(memory.chat_memory().len(), 6);

        let history = load_history(&memory).await;
        assert!(!history.contains("问题1"));
        assert!(!history.contains("答案1"));
        assert!(history.contains("问题2"));
        assert!(history.contains("答案3"));
    }

    /// When fewer exchanges than `k` are stored, all of them are visible.
    #[tokio::test]
    async fn test_window_memory_smaller_than_k() {
        let mut memory = ConversationBufferWindowMemory::new(5);
        for i in 1..=2 {
            save_turn(&mut memory, format!("问题{}", i), format!("答案{}", i)).await;
        }

        let history = load_history(&memory).await;
        assert!(history.contains("问题1"));
        assert!(history.contains("问题2"));
    }

    /// `clear` empties the stored history.
    #[tokio::test]
    async fn test_window_memory_clear() {
        let mut memory = ConversationBufferWindowMemory::new(2);
        save_turn(&mut memory, "测试".to_string(), "收到".to_string()).await;
        assert_eq!(memory.chat_memory().len(), 2);

        memory.clear().await.unwrap();
        assert_eq!(memory.chat_memory().len(), 0);
    }
}