use tokio::sync::Mutex;
use crate::error::Result;
use crate::memory::traits::Memory;
use crate::model::SharedModel;
use crate::model::types::{ChatRequest, Message, Role};
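/// Default message-count threshold that triggers summarization.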
const DEFAULT_MAX_MESSAGES: usize = 20;
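/// Default number of most recent messages kept verbatim when summarizing.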
const DEFAULT_RETAIN_RECENT: usize = 10;
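/// Default system prompt for the summarization request.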
const DEFAULT_SUMMARY_PROMPT: &str = "\
You are a conversation summarizer. Summarize the following conversation \
into a concise paragraph that preserves all important facts, decisions, \
tool results, and context. Be specific — include names, numbers, and \
outcomes. Do not include any preamble, just the summary.";
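/// Conversation memory that keeps the most recent messages verbatim and folds
/// older messages into a running summary generated by the wrapped model.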
pub struct SummaryMemory {
messages: Mutex<Vec<Message>>,
summary: Mutex<Option<String>>,
model: SharedModel,
max_messages: usize,
retain_recent: usize,
summary_prompt: String,
}
impl SummaryMemory {
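    /// Creates a summary memory backed by `model`, using the default thresholds and prompt.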
pub fn new(model: SharedModel) -> Self {
Self {
messages: Mutex::new(Vec::new()),
summary: Mutex::new(None),
model,
max_messages: DEFAULT_MAX_MESSAGES,
retain_recent: DEFAULT_RETAIN_RECENT,
summary_prompt: DEFAULT_SUMMARY_PROMPT.to_string(),
}
}
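    /// Sets the message-count threshold above which older messages are summarized.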
pub fn with_max_messages(mut self, max: usize) -> Self {
self.max_messages = max;
self
}
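    /// Sets how many of the most recent messages are kept verbatim when summarizing.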
pub fn with_retain_recent(mut self, count: usize) -> Self {
self.retain_recent = count;
self
}
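    /// Overrides the system prompt used for summarization requests.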
pub fn with_summary_prompt(mut self, prompt: impl Into<String>) -> Self {
self.summary_prompt = prompt.into();
self
}
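    /// Returns the current running summary, if one has been generated.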
pub async fn current_summary(&self) -> Option<String> {
self.summary.lock().await.clone()
}
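    /// Drains and summarizes the oldest messages once the buffer exceeds
    /// `max_messages`, keeping the `retain_recent` newest messages verbatim.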
async fn maybe_summarize(&self) -> Result<()> {
let mut messages = self.messages.lock().await;
if messages.len() <= self.max_messages {
return Ok(());
}
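        // Everything older than the `retain_recent` most recent messages gets summarized.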
let split_at = messages.len().saturating_sub(self.retain_recent);
if split_at == 0 {
return Ok(());
}
let to_summarize: Vec<Message> = messages.drain(..split_at).collect();
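        // Build the summarization request: system prompt first, then an
        // instruction that carries the previous summary when one exists.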
let mut summary_input = Vec::new();
        summary_input.push(Message::system(self.summary_prompt.clone()));
let existing_summary = self.summary.lock().await;
if let Some(ref prev) = *existing_summary {
summary_input.push(Message::user(format!(
"Previous conversation summary:\n{prev}\n\nNew messages to incorporate:"
)));
} else {
            summary_input.push(Message::user("Conversation to summarize:"));
}
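        // Render the drained messages (including tool calls) as plain text.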
let mut conversation_text = String::new();
for msg in &to_summarize {
            let role = match msg.role {
                Role::System => "System",
                Role::User => "User",
                Role::Assistant => "Assistant",
                Role::Tool => "Tool",
            };
if let Some(ref content) = msg.content {
conversation_text.push_str(&format!("{role}: {content}\n"));
}
            for tc in &msg.tool_calls {
                conversation_text.push_str(&format!(
                    "Assistant called tool '{}' with args: {}\n",
                    tc.name, tc.arguments
                ));
            }
}
summary_input.push(Message::user(conversation_text));
let request = ChatRequest {
messages: summary_input,
tools: Vec::new(),
temperature: Some(0.0),
max_tokens: Some(512),
};
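        // Release both locks before awaiting the model so other tasks can read
        // or append messages while the summary is generated.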
drop(existing_summary);
drop(messages);
tracing::debug!(
messages_summarized = to_summarize.len(),
"generating conversation summary"
);
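        // Note: if this call fails, the drained messages have already been
        // removed from the buffer and are not restored.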
let response = self.model.generate_erased(&request).await?;
let summary_text = response.text().to_string();
let mut summary = self.summary.lock().await;
*summary = Some(summary_text);
Ok(())
}
}
impl Memory for SummaryMemory {
async fn add_message(&self, message: Message) -> Result<()> {
{
let mut messages = self.messages.lock().await;
messages.push(message);
}
self.maybe_summarize().await?;
Ok(())
}
async fn get_messages(&self) -> Result<Vec<Message>> {
let messages = self.messages.lock().await;
let summary = self.summary.lock().await;
let mut result = Vec::new();
if let Some(ref s) = *summary {
result.push(Message::system(format!(
"Previous conversation summary: {s}"
)));
}
result.extend(messages.clone());
Ok(result)
}
async fn clear(&self) -> Result<()> {
let mut messages = self.messages.lock().await;
let mut summary = self.summary.lock().await;
messages.clear();
*summary = None;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::model::Model;
use crate::model::types::{ChatResponse, Role, StopReason, Usage};
use crate::stream::ResponseStream;
use std::sync::Arc;
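    /// Deterministic stand-in model: its "summary" reports whether a previous
    /// summary was supplied and how many conversation lines it received.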
struct FakeSummarizerModel;
impl Model for FakeSummarizerModel {
async fn generate(&self, request: &ChatRequest) -> Result<ChatResponse> {
let has_previous = request
.messages
.iter()
.any(|m| {
m.content
.as_ref()
.is_some_and(|c| c.contains("Previous conversation summary"))
});
let msg_count = request
.messages
.last()
.and_then(|m| m.content.as_ref())
.map(|c| c.lines().count())
.unwrap_or(0);
let text = if has_previous {
format!("Extended summary incorporating {msg_count} new lines.")
} else {
format!("Summary of {msg_count} conversation lines.")
};
Ok(ChatResponse {
message: Message::assistant(text),
stop_reason: StopReason::EndTurn,
usage: Some(Usage::default()),
})
}
async fn generate_stream(&self, _request: &ChatRequest) -> Result<ResponseStream> {
Ok(Box::pin(futures::stream::empty()))
}
}
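    /// Builds a `SummaryMemory` over the fake model with the given thresholds.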
fn make_memory(max_messages: usize, retain_recent: usize) -> SummaryMemory {
let model: SharedModel = Arc::new(FakeSummarizerModel);
SummaryMemory::new(model)
.with_max_messages(max_messages)
.with_retain_recent(retain_recent)
}
#[tokio::test]
async fn test_add_and_get_below_threshold() {
let memory = make_memory(10, 5);
memory.add_message(Message::user("hello")).await.unwrap();
memory
.add_message(Message::assistant("hi"))
.await
.unwrap();
let msgs = memory.get_messages().await.unwrap();
assert_eq!(msgs.len(), 2);
assert!(memory.current_summary().await.is_none());
}
#[tokio::test]
async fn test_summarizes_when_exceeding_threshold() {
let memory = make_memory(5, 2);
for i in 0..6 {
memory
.add_message(Message::user(format!("message {i}")))
.await
.unwrap();
}
let summary = memory.current_summary().await;
assert!(summary.is_some());
let msgs = memory.get_messages().await.unwrap();
assert_eq!(msgs[0].role, Role::System);
assert!(msgs[0]
.content
.as_ref()
.unwrap()
.contains("Summary"));
}
#[tokio::test]
async fn test_retains_recent_messages() {
let memory = make_memory(4, 2);
for i in 0..5 {
memory
.add_message(Message::user(format!("msg {i}")))
.await
.unwrap();
}
let msgs = memory.get_messages().await.unwrap();
assert_eq!(msgs.len(), 3);
assert_eq!(msgs[0].role, Role::System);
assert_eq!(msgs[1].content.as_deref(), Some("msg 3"));
assert_eq!(msgs[2].content.as_deref(), Some("msg 4"));
}
#[tokio::test]
async fn test_clear_resets_summary() {
let memory = make_memory(3, 1);
for i in 0..5 {
memory
.add_message(Message::user(format!("msg {i}")))
.await
.unwrap();
}
assert!(memory.current_summary().await.is_some());
memory.clear().await.unwrap();
assert!(memory.current_summary().await.is_none());
assert!(memory.get_messages().await.unwrap().is_empty());
}
#[tokio::test]
async fn test_incremental_summarization() {
let memory = make_memory(4, 2);
for i in 0..5 {
memory
.add_message(Message::user(format!("batch1 msg {i}")))
.await
.unwrap();
}
let first_summary = memory.current_summary().await.unwrap();
for i in 0..5 {
memory
.add_message(Message::user(format!("batch2 msg {i}")))
.await
.unwrap();
}
let second_summary = memory.current_summary().await.unwrap();
assert_ne!(first_summary, second_summary);
}
#[tokio::test]
async fn test_custom_summary_prompt() {
let model: SharedModel = Arc::new(FakeSummarizerModel);
let memory = SummaryMemory::new(model)
.with_max_messages(3)
.with_retain_recent(1)
.with_summary_prompt("Custom summarization instructions");
assert_eq!(memory.summary_prompt, "Custom summarization instructions");
}
#[tokio::test]
async fn test_empty_memory() {
let memory = make_memory(10, 5);
assert!(memory.current_summary().await.is_none());
assert!(memory.get_messages().await.unwrap().is_empty());
}
}