use agentlib_core::{
MemoryProvider, MemoryReadOptions, MemoryWriteOptions, ModelMessage, ModelProvider,
ModelRequest, Role, async_trait, estimate_messages_tokens,
};
use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Conversation memory that keeps recent messages verbatim and folds older
/// ones into a model-generated running summary, tracked per session id.
pub struct SummarizingMemory {
/// Model that is asked to produce the summary text when a session is compressed.
model: Arc<dyn ModelProvider>,
/// Token threshold: when the estimated token count of a session's active
/// messages exceeds this value on write, the session is compressed.
active_window_tokens: usize,
/// Content of the system message sent to `model` with every summarization request.
summary_prompt: String,
/// Per-session state keyed by session id, guarded by an async mutex.
sessions: Arc<Mutex<HashMap<String, SessionData>>>,
}
/// State stored for a single session.
struct SessionData {
/// Running summary returned by the model; `None` until the session has been
/// compressed at least once.
summary: Option<String>,
/// Messages currently kept verbatim (not yet folded into `summary`).
active_messages: Vec<ModelMessage>,
}
impl SummarizingMemory {
    /// Creates a memory with no stored sessions and the built-in summary prompt.
    ///
    /// * `model` - provider used to generate summaries.
    /// * `active_window_tokens` - estimated-token threshold above which a
    ///   session's active messages are compressed on write.
    pub fn new(model: Arc<dyn ModelProvider>, active_window_tokens: usize) -> Self {
        Self {
            model,
            active_window_tokens,
            summary_prompt: "You are a memory compression assistant. Summarize the following conversation concisely.".to_string(),
            sessions: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Folds the older half of `session.active_messages` into `session.summary`.
    ///
    /// The first `len / 2` messages (together with the previous summary, if
    /// any) are sent to the model; its reply becomes the new summary and those
    /// messages are removed from the active list. The remaining messages are
    /// kept verbatim.
    ///
    /// If there are fewer than two active messages there is nothing to fold,
    /// so the call returns without contacting the model and without touching
    /// the existing summary.
    ///
    /// # Errors
    ///
    /// Returns the error from the model call; in that case `session` is left
    /// unmodified.
    async fn compress(&self, session: &mut SessionData) -> Result<()> {
        let split_at = session.active_messages.len() / 2;
        if split_at == 0 {
            // A lone (possibly oversized) message cannot be split. Calling the
            // model here would summarize an empty conversation and overwrite
            // the existing summary with the result.
            return Ok(());
        }

        // NOTE(review): the split point ignores message roles, so a tool call
        // and its result could end up on opposite sides — confirm whether
        // callers rely on those staying together.
        let mut compress_input = String::new();
        if let Some(previous) = &session.summary {
            compress_input.push_str(&format!(
                "Previous summary:\n{}\n\nNew conversation to add:\n",
                previous
            ));
        }
        for msg in session.active_messages.iter().take(split_at) {
            compress_input.push_str(&format!("{:?}: {}\n", msg.role, msg.content));
        }

        let response = self
            .model
            .complete(ModelRequest {
                messages: vec![
                    ModelMessage {
                        role: Role::System,
                        content: self.summary_prompt.clone(),
                        tool_call_id: None,
                        tool_calls: None,
                    },
                    ModelMessage {
                        role: Role::User,
                        content: compress_input,
                        tool_call_id: None,
                        tool_calls: None,
                    },
                ],
                tools: None,
            })
            .await?;

        // Only mutate the session once the model call has succeeded.
        session.summary = Some(response.message.content);
        // Drop the summarized prefix in place; the tail stays as-is without cloning.
        session.active_messages.drain(..split_at);
        Ok(())
    }
}
#[async_trait]
impl MemoryProvider for SummarizingMemory {
    /// Returns the stored conversation for the requested session.
    ///
    /// The session id falls back to `"default"` when none is given. An unknown
    /// session yields an empty list. Otherwise the result is the summary (if
    /// one exists) wrapped in a system message, followed by clones of the
    /// active messages in order.
    async fn read(&self, options: MemoryReadOptions) -> Result<Vec<ModelMessage>> {
        let key = options.session_id.unwrap_or_else(|| "default".to_string());
        let guard = self.sessions.lock().await;

        let Some(data) = guard.get(&key) else {
            return Ok(Vec::new());
        };

        // Zero or one leading system message carrying the running summary.
        let summary_message = data.summary.as_ref().map(|text| ModelMessage {
            role: Role::System,
            content: format!("[Conversation summary so far]\n{}", text),
            tool_call_id: None,
            tool_calls: None,
        });

        Ok(summary_message
            .into_iter()
            .chain(data.active_messages.iter().cloned())
            .collect())
    }

    /// Stores `messages` as the active messages of the requested session.
    ///
    /// System messages are dropped, and the remaining messages *replace* the
    /// session's previous active list (the session is created on first use;
    /// the id falls back to `"default"`). If the estimated token count of the
    /// new active list exceeds `active_window_tokens`, the session is
    /// compressed before returning.
    ///
    /// # Errors
    ///
    /// Propagates any error from compression; the new active messages have
    /// already been stored at that point.
    async fn write(&self, messages: Vec<ModelMessage>, options: MemoryWriteOptions) -> Result<()> {
        let key = options.session_id.unwrap_or_else(|| "default".to_string());

        let mut incoming = messages;
        incoming.retain(|m| m.role != Role::System);

        // NOTE(review): this guard stays alive across the `compress` await
        // below, so other `read`/`write` calls wait until the model call ends.
        let mut guard = self.sessions.lock().await;
        let data = guard.entry(key).or_insert_with(|| SessionData {
            summary: None,
            active_messages: Vec::new(),
        });
        data.active_messages = incoming;

        if estimate_messages_tokens(&data.active_messages) > self.active_window_tokens {
            self.compress(data).await?;
        }
        Ok(())
    }
}