use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use tokio::sync::Mutex;
use cognis_core::error::Result;
use cognis_core::language_models::chat_model::BaseChatModel;
use cognis_core::messages::{get_buffer_string, Message};
use super::BaseMemory;
/// Conversation memory that keeps a small window of recent messages and
/// folds older messages into a model-generated running summary once the
/// buffer grows past `max_messages`.
pub struct ConversationSummaryMemory {
    // Recent messages not yet folded into the summary.
    messages: Arc<Mutex<Vec<Message>>>,
    // Running summary of everything already evicted from `messages`.
    summary: Arc<Mutex<String>>,
    // Chat model used to produce/update the summary text.
    model: Arc<dyn BaseChatModel>,
    // Buffer size threshold; exceeding it triggers summarization.
    max_messages: usize,
    // Key under which the combined history is exposed to prompts.
    memory_key: String,
}
impl ConversationSummaryMemory {
    /// Creates a summary memory backed by `model`; summarization kicks in
    /// once the buffer holds more than `max_messages` messages. The memory
    /// key defaults to `"history"`.
    pub fn new(model: Arc<dyn BaseChatModel>, max_messages: usize) -> Self {
        let messages = Arc::new(Mutex::new(Vec::new()));
        let summary = Arc::new(Mutex::new(String::new()));
        Self {
            messages,
            summary,
            model,
            max_messages,
            memory_key: String::from("history"),
        }
    }

    /// Builder-style override for the key under which history is exposed.
    pub fn with_memory_key(mut self, key: impl Into<String>) -> Self {
        self.memory_key = key.into();
        self
    }

    /// Asks the chat model to summarize `messages`, folding in
    /// `existing_summary` when one is already present.
    async fn summarize_messages(
        &self,
        messages: &[Message],
        existing_summary: &str,
    ) -> Result<String> {
        let buffer = get_buffer_string(messages, "Human", "AI");
        // Two prompt shapes: fresh summary vs. incremental update.
        let prompt = match existing_summary.is_empty() {
            true => format!(
                "Summarize the following conversation so far:\n{}\n\nSummary:",
                buffer
            ),
            false => format!(
                "Current summary:\n{}\n\nNew conversation:\n{}\n\nUpdated summary:",
                existing_summary, buffer
            ),
        };
        let request = Message::human(prompt);
        let response = self.model.invoke_messages(&[request], None).await?;
        Ok(response.base.content.text())
    }
}
#[async_trait]
impl BaseMemory for ConversationSummaryMemory {
    /// Returns the memory contents under `memory_key`: the running summary
    /// (when non-empty) followed by the buffered recent messages, joined by
    /// a blank line.
    async fn load_memory_variables(&self) -> Result<HashMap<String, Value>> {
        // Lock order: messages, then summary (matches clear()).
        let messages = self.messages.lock().await;
        let summary = self.summary.lock().await;
        let mut parts = Vec::new();
        if !summary.is_empty() {
            parts.push(format!("Summary of earlier conversation:\n{}", *summary));
        }
        if !messages.is_empty() {
            parts.push(get_buffer_string(&messages, "Human", "AI"));
        }
        let mut vars = HashMap::new();
        vars.insert(self.memory_key.clone(), Value::String(parts.join("\n\n")));
        Ok(vars)
    }

    /// Appends the latest exchange and, when the buffer exceeds
    /// `max_messages`, folds everything except the newest exchange into the
    /// summary.
    ///
    /// The messages guard is held across the entire operation — tokio's
    /// async Mutex is designed to be held across `.await`. The previous
    /// implementation took the lock in several short critical sections and
    /// finally overwrote the buffer with a snapshot captured *before* the
    /// (slow) model call, so messages appended by a concurrent
    /// `save_context` during summarization were silently dropped. Holding
    /// the guard makes this method atomic with respect to other memory
    /// operations.
    async fn save_context(&self, input: &Message, output: &Message) -> Result<()> {
        let mut messages = self.messages.lock().await;
        messages.push(input.clone());
        messages.push(output.clone());
        if messages.len() > self.max_messages {
            // Keep the newest exchange (2 messages); summarize the rest.
            let split_at = messages.len().saturating_sub(2);
            let existing_summary = self.summary.lock().await.clone();
            // Summarize before mutating the buffer so a model error leaves
            // the memory state untouched (same error semantics as before).
            let new_summary = self
                .summarize_messages(&messages[..split_at], &existing_summary)
                .await?;
            *self.summary.lock().await = new_summary;
            messages.drain(..split_at);
        }
        Ok(())
    }

    /// Clears both the message buffer and the accumulated summary.
    async fn clear(&self) -> Result<()> {
        // Lock order: messages, then summary (matches load_memory_variables)
        // to avoid lock-order deadlocks.
        let mut messages = self.messages.lock().await;
        let mut summary = self.summary.lock().await;
        messages.clear();
        summary.clear();
        Ok(())
    }

    /// Key under which this memory exposes its variables.
    fn memory_key(&self) -> &str {
        &self.memory_key
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use cognis_core::language_models::fake::FakeListChatModel;
    use cognis_core::messages::Message;

    /// Below the limit no summarization happens: both turns appear verbatim
    /// in the loaded history and the model is never invoked.
    #[tokio::test]
    async fn test_summary_under_limit() {
        let fake = Arc::new(FakeListChatModel::new(vec![
            "This should not be called".to_string(),
        ]));
        let memory = ConversationSummaryMemory::new(fake, 10);
        memory
            .save_context(&Message::human("Hello"), &Message::ai("Hi"))
            .await
            .unwrap();
        let vars = memory.load_memory_variables().await.unwrap();
        let text = vars.get("history").unwrap().as_str().unwrap();
        assert!(text.contains("Hello"));
        assert!(text.contains("Hi"));
    }

    /// Crossing the limit folds the older exchange into a summary while the
    /// newest exchange stays in the buffer.
    #[tokio::test]
    async fn test_summary_triggers() {
        let fake = Arc::new(FakeListChatModel::new(vec![
            "User greeted and AI responded.".to_string(),
        ]));
        let memory = ConversationSummaryMemory::new(fake, 2);
        memory
            .save_context(&Message::human("Hello"), &Message::ai("Hi there"))
            .await
            .unwrap();
        assert_eq!(memory.messages.lock().await.len(), 2);
        memory
            .save_context(&Message::human("How are you?"), &Message::ai("Fine"))
            .await
            .unwrap();
        {
            let buffered = memory.messages.lock().await;
            assert_eq!(buffered.len(), 2);
            assert_eq!(buffered[0].content().text(), "How are you?");
            assert_eq!(buffered[1].content().text(), "Fine");
        }
        {
            let summary = memory.summary.lock().await;
            assert_eq!(*summary, "User greeted and AI responded.");
        }
        let vars = memory.load_memory_variables().await.unwrap();
        let text = vars.get("history").unwrap().as_str().unwrap();
        assert!(text.contains("User greeted and AI responded."));
        assert!(text.contains("How are you?"));
    }
}