//! agentlib-memory 0.1.0
//!
//! Advanced memory providers and history management for AgentLib.
//! Documentation
use agentlib_core::{
    MemoryProvider, MemoryReadOptions, MemoryWriteOptions, ModelMessage, ModelProvider,
    ModelRequest, Role, async_trait, estimate_messages_tokens,
};
use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;

/// Conversation memory that keeps a window of recent messages per session and folds
/// older messages into a model-generated running summary.
pub struct SummarizingMemory {
    /// Per-session state, keyed by session id and shared behind an async mutex.
    sessions: Arc<Mutex<HashMap<String, SessionData>>>,
    /// Model asked to produce the running summary.
    model: Arc<dyn ModelProvider>,
    /// System prompt sent to `model` whenever history is compressed.
    summary_prompt: String,
    /// Token budget for the active (not yet summarized) messages; exceeding it on
    /// write triggers compression.
    active_window_tokens: usize,
}

/// State tracked for a single conversation session.
struct SessionData {
    /// Recent messages kept verbatim (System messages are filtered out before storing).
    active_messages: Vec<ModelMessage>,
    /// Running summary of messages that have already been compressed away, if any.
    summary: Option<String>,
}

impl SummarizingMemory {
    pub fn new(model: Arc<dyn ModelProvider>, active_window_tokens: usize) -> Self {
        Self {
            model,
            active_window_tokens,
            summary_prompt: "You are a memory compression assistant. Summarize the following conversation concisely.".to_string(),
            sessions: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    async fn compress(&self, session: &mut SessionData) -> Result<()> {
        let messages = &session.active_messages;
        let split_at = messages.len() / 2;
        let (to_compress, to_keep) = messages.split_at(split_at);

        let mut compress_input = String::new();
        if let Some(summary) = &session.summary {
            compress_input.push_str(&format!(
                "Previous summary:\n{}\n\nNew conversation to add:\n",
                summary
            ));
        }

        for msg in to_compress {
            compress_input.push_str(&format!("{:?}: {}\n", msg.role, msg.content));
        }

        let response = self
            .model
            .complete(ModelRequest {
                messages: vec![
                    ModelMessage {
                        role: Role::System,
                        content: self.summary_prompt.clone(),
                        tool_call_id: None,
                        tool_calls: None,
                    },
                    ModelMessage {
                        role: Role::User,
                        content: compress_input,
                        tool_call_id: None,
                        tool_calls: None,
                    },
                ],
                tools: None,
            })
            .await?;

        session.summary = Some(response.message.content);
        session.active_messages = to_keep.to_vec();
        Ok(())
    }
}

#[async_trait]
impl MemoryProvider for SummarizingMemory {
    /// Returns the session's history: the running summary as a leading System message
    /// (when one exists) followed by the active messages. Unknown sessions yield an
    /// empty history; a missing session id maps to `"default"`.
    async fn read(&self, options: MemoryReadOptions) -> Result<Vec<ModelMessage>> {
        let key = options
            .session_id
            .unwrap_or_else(|| "default".to_string());
        let guard = self.sessions.lock().await;

        let history: Vec<ModelMessage> = guard.get(&key).map_or_else(Vec::new, |data| {
            let summary_message = data.summary.as_ref().map(|text| ModelMessage {
                role: Role::System,
                content: format!("[Conversation summary so far]\n{}", text),
                tool_call_id: None,
                tool_calls: None,
            });
            summary_message
                .into_iter()
                .chain(data.active_messages.iter().cloned())
                .collect()
        });
        Ok(history)
    }

    /// Replaces the session's active messages with the non-System messages in
    /// `messages`, creating the session if needed. Compresses once when the estimated
    /// token count exceeds the active window.
    ///
    /// NOTE(review): the shared sessions lock stays held while `compress` awaits the
    /// model, so reads/writes for every session wait on that call; consider
    /// per-session locking.
    async fn write(&self, messages: Vec<ModelMessage>, options: MemoryWriteOptions) -> Result<()> {
        let key = options
            .session_id
            .unwrap_or_else(|| "default".to_string());

        // The summary injected by `read` comes back as a System message; dropping all
        // System messages keeps only real conversation turns in the active window.
        let mut retained = messages;
        retained.retain(|m| m.role != Role::System);

        let mut guard = self.sessions.lock().await;
        let data = guard.entry(key).or_insert(SessionData {
            summary: None,
            active_messages: Vec::new(),
        });
        data.active_messages = retained;

        if estimate_messages_tokens(&data.active_messages) > self.active_window_tokens {
            self.compress(data).await?;
        }
        Ok(())
    }
}