traitclaw-core 1.0.0

Core traits, types, and runtime for the TraitClaw AI Agent Framework
Documentation
//! Compressed memory decorator — automatic context window management.
//!
//! `CompressedMemory` wraps any [`Memory`] implementation and automatically
//! summarizes older messages when the conversation exceeds a configured
//! threshold, keeping the agent within its context window budget.
//!
//! # Architecture Decision
//!
//! Uses the **Decorator Pattern** — `CompressedMemory<M>` implements `Memory`
//! for any `M: Memory`. This means it can wrap `InMemoryMemory`, `SqliteMemory`,
//! or even another `CompressedMemory` (stackable).
//!
//! # Example
//!
//! ```rust,no_run
//! use traitclaw_core::memory::compressed::CompressedMemory;
//! use traitclaw_core::memory::in_memory::InMemoryMemory;
//!
//! // Wrap in-memory storage; compress once the history exceeds 20 messages
//! let memory = CompressedMemory::new(
//!     InMemoryMemory::new(),
//!     20,  // compress when > 20 messages
//!     5,   // keep last 5 messages uncompressed
//! );
//! ```

use async_trait::async_trait;
use serde_json::Value;
use std::sync::Mutex;

use crate::traits::memory::{Memory, MemoryEntry};
use crate::types::message::{Message, MessageRole};
use crate::Result;

/// A memory decorator that compresses older messages into summaries.
///
/// When the total number of messages reported by the inner memory exceeds
/// `threshold`, every message with [`MessageRole::System`] is kept verbatim
/// (and moved to the front), the most recent `keep_recent` non-system
/// messages are kept intact, and all older non-system messages are replaced
/// by a single system-role summary message.
///
/// The summary is generated locally by listing each compressed message's
/// role and (truncated) content — no model or provider is called. This type
/// has no hook for plugging in a custom (e.g. LLM-based) summarizer.
///
/// Summaries are cached per session and split point. Appending to or
/// deleting a session through this decorator invalidates that session's
/// cached summaries; writes made to the inner memory directly do not.
pub struct CompressedMemory<M: Memory> {
    /// The wrapped memory that actually stores messages, context and entries.
    inner: M,
    /// Compress when the inner memory's total message count (system
    /// messages included) exceeds this.
    threshold: usize,
    /// Number of most recent non-system messages to keep uncompressed.
    keep_recent: usize,
    /// Cached summaries, keyed by `"{session_id}:{split_point}"`.
    summaries: Mutex<std::collections::HashMap<String, String>>,
}

impl<M: Memory> CompressedMemory<M> {
    /// Maximum number of bytes of one message's content copied into a
    /// summary; longer content is cut at a char boundary and suffixed `...`.
    const SUMMARY_SNIPPET_BYTES: usize = 100;

    /// Create a new compressed memory wrapping the given inner memory.
    ///
    /// - `threshold`: Trigger compression when the total message count
    ///   (system messages included) exceeds this
    /// - `keep_recent`: Number of most recent non-system messages to
    ///   preserve uncompressed
    #[must_use]
    pub fn new(inner: M, threshold: usize, keep_recent: usize) -> Self {
        Self {
            inner,
            threshold,
            keep_recent,
            summaries: Mutex::new(std::collections::HashMap::new()),
        }
    }

    /// Get the compression threshold.
    #[must_use]
    pub fn threshold(&self) -> usize {
        self.threshold
    }

    /// Get the number of recent messages kept uncompressed.
    #[must_use]
    pub fn keep_recent(&self) -> usize {
        self.keep_recent
    }

    /// Build a plain-text summary of `messages`.
    ///
    /// The output is a `[Compressed context summary]` header line followed
    /// by one `- Role: content` line per message. Content longer than
    /// `SUMMARY_SNIPPET_BYTES` bytes is truncated (on a UTF-8 char boundary)
    /// and suffixed with `...`.
    fn summarize(messages: &[Message]) -> String {
        use std::fmt::Write as _;

        let mut summary = String::from("[Compressed context summary]\n");
        for msg in messages {
            let role = match msg.role {
                MessageRole::User => "User",
                MessageRole::Assistant => "Assistant",
                MessageRole::System => "System",
                MessageRole::Tool => "Tool",
            };
            let snippet =
                Self::truncate_at_char_boundary(&msg.content, Self::SUMMARY_SNIPPET_BYTES);
            let ellipsis = if snippet.len() < msg.content.len() {
                "..."
            } else {
                ""
            };
            // Writing into a `String` cannot fail, so the `fmt::Result` is ignored.
            let _ = writeln!(summary, "- {role}: {snippet}{ellipsis}");
        }
        summary
    }

    /// Return the longest prefix of `text` that is at most `max_bytes` long
    /// and ends on a UTF-8 char boundary, so slicing never panics on
    /// multi-byte characters.
    fn truncate_at_char_boundary(text: &str, max_bytes: usize) -> &str {
        if text.len() <= max_bytes {
            return text;
        }
        // Byte 0 is always a char boundary, so `find` always succeeds.
        let end = (0..=max_bytes)
            .rev()
            .find(|&i| text.is_char_boundary(i))
            .unwrap_or(0);
        &text[..end]
    }
}

#[async_trait]
impl<M: Memory> Memory for CompressedMemory<M> {
    async fn messages(&self, session_id: &str) -> Result<Vec<Message>> {
        let all_messages = self.inner.messages(session_id).await?;

        if all_messages.len() <= self.threshold {
            return Ok(all_messages);
        }

        // Separate system prompt (if present)
        let (system_msgs, conversation): (Vec<_>, Vec<_>) = all_messages
            .into_iter()
            .partition(|m| m.role == MessageRole::System);

        if conversation.len() <= self.keep_recent {
            let mut result = system_msgs;
            result.extend(conversation);
            return Ok(result);
        }

        // Split into old (to compress) and recent (to keep)
        let split_point = conversation.len().saturating_sub(self.keep_recent);
        let (old_messages, recent_messages) = conversation.split_at(split_point);

        // Generate or retrieve cached summary
        let summary = {
            let mut cache = self.summaries.lock().unwrap_or_else(|e| e.into_inner());
            let key = format!("{session_id}:{split_point}");
            cache
                .entry(key)
                .or_insert_with(|| Self::summarize(old_messages))
                .clone()
        };

        // Reconstruct: system + summary + recent
        let mut result = system_msgs;
        result.push(Message::system(summary));
        result.extend(recent_messages.iter().cloned());

        Ok(result)
    }

    async fn append(&self, session_id: &str, message: Message) -> Result<()> {
        // Invalidate summary cache for this session
        {
            let mut cache = self.summaries.lock().unwrap_or_else(|e| e.into_inner());
            let prefix = format!("{session_id}:");
            cache.retain(|k, _| !k.starts_with(&prefix));
        }
        self.inner.append(session_id, message).await
    }

    async fn get_context(&self, session_id: &str, key: &str) -> Result<Option<Value>> {
        self.inner.get_context(session_id, key).await
    }

    async fn set_context(&self, session_id: &str, key: &str, value: Value) -> Result<()> {
        self.inner.set_context(session_id, key, value).await
    }

    async fn recall(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
        self.inner.recall(query, limit).await
    }

    async fn store(&self, entry: MemoryEntry) -> Result<()> {
        self.inner.store(entry).await
    }

    async fn create_session(&self) -> Result<String> {
        self.inner.create_session().await
    }

    async fn list_sessions(&self) -> Result<Vec<String>> {
        self.inner.list_sessions().await
    }

    async fn delete_session(&self, session_id: &str) -> Result<()> {
        // Clean up summary cache
        {
            let mut cache = self.summaries.lock().unwrap_or_else(|e| e.into_inner());
            cache.retain(|k, _| !k.starts_with(session_id));
        }
        self.inner.delete_session(session_id).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::in_memory::InMemoryMemory;

    #[tokio::test]
    async fn test_below_threshold_no_compression() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 10, 3);

        // Add 5 messages (below threshold of 10)
        for i in 0..5 {
            memory
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }

        let msgs = memory.messages("s1").await.unwrap();
        assert_eq!(msgs.len(), 5, "should return all messages uncompressed");
    }

    #[tokio::test]
    async fn test_above_threshold_compresses() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 5, 3);

        // Add 8 messages (above threshold of 5)
        for i in 0..8 {
            memory
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }

        let msgs = memory.messages("s1").await.unwrap();
        // Should have: 1 summary + 3 recent = 4
        assert_eq!(msgs.len(), 4, "should compress to summary + 3 recent");

        // First message should be the summary
        assert!(
            msgs[0].content.contains("[Compressed context summary]"),
            "first msg should be summary, got: {}",
            msgs[0].content
        );

        // Last 3 should be the recent messages
        assert_eq!(msgs[1].content, "msg 5");
        assert_eq!(msgs[2].content, "msg 6");
        assert_eq!(msgs[3].content, "msg 7");
    }

    #[tokio::test]
    async fn test_summary_covers_only_old_messages() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 5, 3);

        for i in 0..8 {
            memory
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }

        let msgs = memory.messages("s1").await.unwrap();
        let summary = &msgs[0].content;
        // Messages 0..=4 are compressed; 5..=7 are kept verbatim.
        assert!(summary.contains("- User: msg 0\n"), "got: {summary}");
        assert!(summary.contains("- User: msg 4\n"), "got: {summary}");
        assert!(!summary.contains("msg 5"), "got: {summary}");
    }

    #[tokio::test]
    async fn test_long_message_truncated_in_summary() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 1, 1);

        let long = "x".repeat(150);
        memory
            .append("s1", Message::user(long.clone()))
            .await
            .unwrap();
        memory.append("s1", Message::user("tail")).await.unwrap();

        let msgs = memory.messages("s1").await.unwrap();
        let expected = format!("- User: {}...\n", "x".repeat(100));
        assert!(msgs[0].content.contains(&expected), "got: {}", msgs[0].content);
        assert!(!msgs[0].content.contains(&long));
    }

    #[tokio::test]
    async fn test_system_prompt_preserved() {
        let inner = InMemoryMemory::new();
        inner
            .append("s1", Message::system("You are helpful"))
            .await
            .unwrap();
        for i in 0..10 {
            inner
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }

        let memory = CompressedMemory::new(inner, 5, 3);
        let msgs = memory.messages("s1").await.unwrap();

        // System prompt should be first
        assert_eq!(msgs[0].role, MessageRole::System);
        assert_eq!(msgs[0].content, "You are helpful");
    }

    #[tokio::test]
    async fn test_append_invalidates_cache() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 3, 2);

        for i in 0..5 {
            memory
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }

        // First call populates cache
        let msgs1 = memory.messages("s1").await.unwrap();
        let summary1 = msgs1[0].content.clone();

        // Add new message — should invalidate cache
        memory.append("s1", Message::user("msg 5")).await.unwrap();

        let msgs2 = memory.messages("s1").await.unwrap();
        // Summary should be different (more messages compressed)
        assert_ne!(msgs2[0].content, summary1);
    }

    #[tokio::test]
    async fn test_delete_session_clears_summary_cache() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 2, 1);

        for i in 0..4 {
            memory
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }

        // Populate the cache, then make sure deletion drops it.
        memory.messages("s1").await.unwrap();
        assert!(!memory.summaries.lock().unwrap().is_empty());

        memory.delete_session("s1").await.unwrap();
        assert!(memory.summaries.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_stackable_decorator() {
        // CompressedMemory wrapping CompressedMemory
        let inner = CompressedMemory::new(InMemoryMemory::new(), 20, 5);
        let outer = CompressedMemory::new(inner, 10, 3);

        for i in 0..15 {
            outer
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }

        let msgs = outer.messages("s1").await.unwrap();
        // The inner threshold (20) is never exceeded, so only the outer layer
        // compresses (threshold 10, keep_recent 3): 1 summary + 3 recent = 4
        assert_eq!(msgs.len(), 4, "stacked decorators should compress");
    }

    #[tokio::test]
    async fn test_working_memory_delegates() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 5, 2);

        memory
            .set_context("s1", "key", serde_json::json!("value"))
            .await
            .unwrap();

        let val = memory.get_context("s1", "key").await.unwrap();
        assert_eq!(val, Some(serde_json::json!("value")));
    }

    #[tokio::test]
    async fn test_multibyte_content_no_panic() {
        // Regression: &msg.content[..100] used to panic on multi-byte UTF-8
        let memory = CompressedMemory::new(InMemoryMemory::new(), 3, 1);

        // Vietnamese text > 100 bytes (accented Vietnamese letters take 2-3 bytes in UTF-8)
        let long_vietnamese = "Xin chào! Đây là một tin nhắn rất dài bằng tiếng Việt để kiểm tra rằng việc cắt ngắn không gây lỗi panic khi gặp ký tự đa byte UTF-8.";
        assert!(
            long_vietnamese.len() > 100,
            "test premise: string must be >100 bytes"
        );

        for _ in 0..5 {
            memory
                .append("s1", Message::user(long_vietnamese))
                .await
                .unwrap();
        }

        // This should NOT panic
        let msgs = memory.messages("s1").await.unwrap();
        assert_eq!(msgs.len(), 2); // 1 summary + 1 recent
        assert!(msgs[0].content.contains("[Compressed"));
    }

    #[tokio::test]
    async fn test_threshold_zero_always_compresses() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 0, 1);

        memory.append("s1", Message::user("a")).await.unwrap();
        memory.append("s1", Message::user("b")).await.unwrap();

        let msgs = memory.messages("s1").await.unwrap();
        // threshold=0 means even 2 messages triggers compression
        // summary + 1 recent = 2
        assert_eq!(msgs.len(), 2);
        assert!(msgs[0].content.contains("[Compressed"));
        assert_eq!(msgs[1].content, "b");
    }

    #[tokio::test]
    async fn test_keep_recent_exceeds_message_count() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 0, 100);

        memory
            .append("s1", Message::user("only one"))
            .await
            .unwrap();

        let msgs = memory.messages("s1").await.unwrap();
        // keep_recent (100) > message count (1) → no compression
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].content, "only one");
    }
}