use async_trait::async_trait;
use serde_json::Value;
use std::sync::Mutex;
use crate::traits::memory::{Memory, MemoryEntry};
use crate::types::message::{Message, MessageRole};
use crate::Result;
/// A [`Memory`] decorator that compresses long conversation histories.
///
/// When a session holds more than `threshold` messages, [`Memory::messages`] returns the
/// session's system messages unchanged, then a single system message summarising the older
/// non-system messages, then the `keep_recent` most recent non-system messages. The other
/// [`Memory`] operations are forwarded to the wrapped store.
///
/// The `M: Memory` bound is placed on the `impl` blocks rather than on the struct itself,
/// so the type can be named without repeating the bound.
pub struct CompressedMemory<M> {
    /// The wrapped store that owns the real message history.
    inner: M,
    /// Total message count (system messages included) above which compression applies.
    threshold: usize,
    /// Number of most recent non-system messages returned verbatim after the summary.
    keep_recent: usize,
    /// Cache of generated summaries, keyed by `"{session_id}:{split_point}"`.
    summaries: Mutex<std::collections::HashMap<String, String>>,
}
impl<M: Memory> CompressedMemory<M> {
    /// Maximum number of bytes of a message's content kept on its summary line.
    const SUMMARY_SNIPPET_BYTES: usize = 100;

    /// Wraps `inner`, compressing any session history longer than `threshold` messages
    /// down to a summary plus the `keep_recent` most recent non-system messages.
    #[must_use]
    pub fn new(inner: M, threshold: usize, keep_recent: usize) -> Self {
        Self {
            inner,
            threshold,
            keep_recent,
            summaries: Mutex::new(std::collections::HashMap::new()),
        }
    }

    /// Total message count above which compression is applied.
    #[must_use]
    pub fn threshold(&self) -> usize {
        self.threshold
    }

    /// Number of most recent non-system messages kept verbatim after the summary.
    #[must_use]
    pub fn keep_recent(&self) -> usize {
        self.keep_recent
    }

    /// Returns the longest prefix of `text` that is at most `max_bytes` bytes long and ends
    /// on a UTF-8 character boundary, together with whether `text` was longer than
    /// `max_bytes` (i.e. whether anything was cut off).
    fn truncate_at_char_boundary(text: &str, max_bytes: usize) -> (&str, bool) {
        if text.len() <= max_bytes {
            return (text, false);
        }
        // Index 0 is always a char boundary, so the search cannot come up empty.
        let end = (0..=max_bytes)
            .rev()
            .find(|&i| text.is_char_boundary(i))
            .unwrap_or(0);
        (&text[..end], true)
    }

    /// Renders `messages` as plain text: a `[Compressed context summary]` header line
    /// followed by one `- Role: content` line per message. Content longer than
    /// [`Self::SUMMARY_SNIPPET_BYTES`] bytes is cut on a character boundary (so multi-byte
    /// UTF-8 never panics) and suffixed with `...`.
    fn summarize(messages: &[Message]) -> String {
        use std::fmt::Write as _;

        let mut summary = String::from("[Compressed context summary]\n");
        for msg in messages {
            let role = match msg.role {
                MessageRole::User => "User",
                MessageRole::Assistant => "Assistant",
                MessageRole::System => "System",
                MessageRole::Tool => "Tool",
            };
            let (snippet, truncated) =
                Self::truncate_at_char_boundary(&msg.content, Self::SUMMARY_SNIPPET_BYTES);
            let ellipsis = if truncated { "..." } else { "" };
            // Writing into a `String` is infallible; write directly instead of allocating
            // an intermediate `String` per message with `format!`.
            let _ = writeln!(summary, "- {role}: {snippet}{ellipsis}");
        }
        summary
    }
}
/// Builds the summary-cache key for `session_id` compressed at `split_point`.
fn summary_cache_key(session_id: &str, split_point: usize) -> String {
    format!("{session_id}:{split_point}")
}

/// Returns `true` if `key` was produced by [`summary_cache_key`] for exactly `session_id`.
///
/// A plain prefix test is not enough: `"s1"` is a prefix of `"s10:4"`, and `"a:"` is a
/// prefix of `"a:b:4"` (a key of session `"a:b"`). Because the split point is rendered as
/// decimal digits, a key belongs to `session_id` iff what follows `"{session_id}:"` is a
/// non-empty run of ASCII digits.
fn is_summary_cache_key_for(key: &str, session_id: &str) -> bool {
    key.strip_prefix(session_id)
        .and_then(|rest| rest.strip_prefix(':'))
        .is_some_and(|digits| !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit()))
}

impl<M: Memory> CompressedMemory<M> {
    /// Drops every cached summary belonging to `session_id`, and no other session's.
    ///
    /// Synchronous on purpose: the `std::sync::Mutex` guard is released before the caller
    /// reaches any `.await`.
    fn invalidate_summaries(&self, session_id: &str) {
        let mut cache = self.summaries.lock().unwrap_or_else(|e| e.into_inner());
        cache.retain(|key, _| !is_summary_cache_key_for(key, session_id));
    }
}

#[async_trait]
impl<M: Memory> Memory for CompressedMemory<M> {
    /// Returns the session history, compressed when it has more than `threshold` messages.
    ///
    /// System messages are always returned first and verbatim. If more than `keep_recent`
    /// non-system messages remain, the older ones are replaced by a single system message
    /// holding their summary (cached per session and split point), followed by the
    /// `keep_recent` most recent ones.
    async fn messages(&self, session_id: &str) -> Result<Vec<Message>> {
        let all_messages = self.inner.messages(session_id).await?;
        if all_messages.len() <= self.threshold {
            return Ok(all_messages);
        }
        let (mut result, mut conversation): (Vec<_>, Vec<_>) = all_messages
            .into_iter()
            .partition(|m| m.role == MessageRole::System);
        if conversation.len() <= self.keep_recent {
            result.extend(conversation);
            return Ok(result);
        }
        // Here `conversation.len() > keep_recent`, so the split point is at least 1.
        let split_point = conversation.len() - self.keep_recent;
        // Move the recent tail out rather than cloning it; `conversation` keeps the head.
        let recent_messages = conversation.split_off(split_point);
        let old_messages = conversation;
        let summary = {
            // The guard is confined to this scope and never held across an `.await`.
            let mut cache = self.summaries.lock().unwrap_or_else(|e| e.into_inner());
            cache
                .entry(summary_cache_key(session_id, split_point))
                .or_insert_with(|| Self::summarize(&old_messages))
                .clone()
        };
        result.push(Message::system(summary));
        result.extend(recent_messages);
        Ok(result)
    }

    /// Invalidates this session's cached summaries, then appends to the inner store.
    async fn append(&self, session_id: &str, message: Message) -> Result<()> {
        self.invalidate_summaries(session_id);
        self.inner.append(session_id, message).await
    }

    /// Forwarded unchanged to the inner store.
    async fn get_context(&self, session_id: &str, key: &str) -> Result<Option<Value>> {
        self.inner.get_context(session_id, key).await
    }

    /// Forwarded unchanged to the inner store.
    async fn set_context(&self, session_id: &str, key: &str, value: Value) -> Result<()> {
        self.inner.set_context(session_id, key, value).await
    }

    /// Forwarded unchanged to the inner store.
    async fn recall(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
        self.inner.recall(query, limit).await
    }

    /// Forwarded unchanged to the inner store.
    async fn store(&self, entry: MemoryEntry) -> Result<()> {
        self.inner.store(entry).await
    }

    /// Forwarded unchanged to the inner store.
    async fn create_session(&self) -> Result<String> {
        self.inner.create_session().await
    }

    /// Forwarded unchanged to the inner store.
    async fn list_sessions(&self) -> Result<Vec<String>> {
        self.inner.list_sessions().await
    }

    /// Drops this session's cached summaries (only this session's; previously a bare
    /// prefix match also evicted e.g. `"s10"` when deleting `"s1"`), then deletes the
    /// session from the inner store.
    async fn delete_session(&self, session_id: &str) -> Result<()> {
        self.invalidate_summaries(session_id);
        self.inner.delete_session(session_id).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::in_memory::InMemoryMemory;

    // At or below the threshold, history is passed through untouched.
    #[tokio::test]
    async fn test_below_threshold_no_compression() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 10, 3);
        for i in 0..5 {
            memory
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }
        let msgs = memory.messages("s1").await.unwrap();
        assert_eq!(msgs.len(), 5, "should return all messages uncompressed");
    }

    // 8 messages with threshold 5 / keep 3: one summary followed by the last 3 messages.
    #[tokio::test]
    async fn test_above_threshold_compresses() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 5, 3);
        for i in 0..8 {
            memory
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }
        let msgs = memory.messages("s1").await.unwrap();
        assert_eq!(msgs.len(), 4, "should compress to summary + 3 recent");
        assert!(
            msgs[0].content.contains("[Compressed context summary]"),
            "first msg should be summary, got: {}",
            msgs[0].content
        );
        assert_eq!(msgs[1].content, "msg 5");
        assert_eq!(msgs[2].content, "msg 6");
        assert_eq!(msgs[3].content, "msg 7");
    }

    // System messages are kept verbatim and come first, ahead of the summary.
    #[tokio::test]
    async fn test_system_prompt_preserved() {
        let inner = InMemoryMemory::new();
        inner
            .append("s1", Message::system("You are helpful"))
            .await
            .unwrap();
        for i in 0..10 {
            inner
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }
        let memory = CompressedMemory::new(inner, 5, 3);
        let msgs = memory.messages("s1").await.unwrap();
        assert_eq!(msgs[0].role, MessageRole::System);
        assert_eq!(msgs[0].content, "You are helpful");
    }

    // Appending a message must not leave the previous summary in place.
    #[tokio::test]
    async fn test_append_invalidates_cache() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 3, 2);
        for i in 0..5 {
            memory
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }
        let msgs1 = memory.messages("s1").await.unwrap();
        let summary1 = msgs1[0].content.clone();
        memory.append("s1", Message::user("msg 5")).await.unwrap();
        let msgs2 = memory.messages("s1").await.unwrap();
        assert_ne!(msgs2[0].content, summary1);
    }

    // The inner decorator (threshold 20) passes 15 messages through; the outer one
    // (threshold 10, keep 3) compresses them to a summary plus 3 messages.
    #[tokio::test]
    async fn test_stackable_decorator() {
        let inner = CompressedMemory::new(InMemoryMemory::new(), 20, 5);
        let outer = CompressedMemory::new(inner, 10, 3);
        for i in 0..15 {
            outer
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }
        let msgs = outer.messages("s1").await.unwrap();
        assert_eq!(msgs.len(), 4, "stacked decorators should compress");
    }

    // Working-memory context calls are forwarded to the inner store.
    #[tokio::test]
    async fn test_working_memory_delegates() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 5, 2);
        memory
            .set_context("s1", "key", serde_json::json!("value"))
            .await
            .unwrap();
        let val = memory.get_context("s1", "key").await.unwrap();
        assert_eq!(val, Some(serde_json::json!("value")));
    }

    // Truncating multi-byte UTF-8 content past the snippet limit must not panic.
    #[tokio::test]
    async fn test_multibyte_content_no_panic() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 3, 1);
        let long_vietnamese = "Xin chào! Đây là một tin nhắn rất dài bằng tiếng Việt để kiểm tra rằng việc cắt ngắn không gây lỗi panic khi gặp ký tự đa byte UTF-8.";
        assert!(
            long_vietnamese.len() > 100,
            "test premise: string must be >100 bytes"
        );
        for _ in 0..5 {
            memory
                .append("s1", Message::user(long_vietnamese))
                .await
                .unwrap();
        }
        let msgs = memory.messages("s1").await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert!(msgs[0].content.contains("[Compressed"));
    }

    // With a zero threshold, any non-empty history beyond `keep_recent` is compressed.
    #[tokio::test]
    async fn test_threshold_zero_always_compresses() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 0, 1);
        memory.append("s1", Message::user("a")).await.unwrap();
        memory.append("s1", Message::user("b")).await.unwrap();
        let msgs = memory.messages("s1").await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert!(msgs[0].content.contains("[Compressed"));
        assert_eq!(msgs[1].content, "b");
    }

    // When `keep_recent` covers the whole conversation, nothing is summarised.
    #[tokio::test]
    async fn test_keep_recent_exceeds_message_count() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 0, 100);
        memory
            .append("s1", Message::user("only one"))
            .await
            .unwrap();
        let msgs = memory.messages("s1").await.unwrap();
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].content, "only one");
    }
}