// agentlib_memory/summarizing.rs

1use agentlib_core::{
2    MemoryProvider, MemoryReadOptions, MemoryWriteOptions, ModelMessage, ModelProvider,
3    ModelRequest, Role, async_trait, estimate_messages_tokens,
4};
5use anyhow::Result;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::Mutex;
9
/// Memory provider that keeps a bounded "active window" of recent messages
/// verbatim and rolls older messages into a model-generated summary.
pub struct SummarizingMemory {
    /// Model used to produce/refresh the rolling summary during compression.
    model: Arc<dyn ModelProvider>,
    /// Estimated token budget for the verbatim message window; exceeding it
    /// on `write` triggers compression.
    active_window_tokens: usize,
    /// System prompt sent to the model when asking it to summarize.
    summary_prompt: String,
    /// Per-session state, keyed by session id ("default" when the caller
    /// supplies none).
    sessions: Arc<Mutex<HashMap<String, SessionData>>>,
}
16
/// Per-session memory state.
struct SessionData {
    /// Rolling summary of messages already compressed out of the active
    /// window; `None` until the first compression runs.
    summary: Option<String>,
    /// Most recent messages, kept verbatim.
    active_messages: Vec<ModelMessage>,
}
21
22impl SummarizingMemory {
23    pub fn new(model: Arc<dyn ModelProvider>, active_window_tokens: usize) -> Self {
24        Self {
25            model,
26            active_window_tokens,
27            summary_prompt: "You are a memory compression assistant. Summarize the following conversation concisely.".to_string(),
28            sessions: Arc::new(Mutex::new(HashMap::new())),
29        }
30    }
31
32    async fn compress(&self, session: &mut SessionData) -> Result<()> {
33        let messages = &session.active_messages;
34        let split_at = messages.len() / 2;
35        let (to_compress, to_keep) = messages.split_at(split_at);
36
37        let mut compress_input = String::new();
38        if let Some(summary) = &session.summary {
39            compress_input.push_str(&format!(
40                "Previous summary:\n{}\n\nNew conversation to add:\n",
41                summary
42            ));
43        }
44
45        for msg in to_compress {
46            compress_input.push_str(&format!("{:?}: {}\n", msg.role, msg.content));
47        }
48
49        let response = self
50            .model
51            .complete(ModelRequest {
52                messages: vec![
53                    ModelMessage {
54                        role: Role::System,
55                        content: self.summary_prompt.clone(),
56                        tool_call_id: None,
57                        tool_calls: None,
58                    },
59                    ModelMessage {
60                        role: Role::User,
61                        content: compress_input,
62                        tool_call_id: None,
63                        tool_calls: None,
64                    },
65                ],
66                tools: None,
67            })
68            .await?;
69
70        session.summary = Some(response.message.content);
71        session.active_messages = to_keep.to_vec();
72        Ok(())
73    }
74}
75
76#[async_trait]
77impl MemoryProvider for SummarizingMemory {
78    async fn read(&self, options: MemoryReadOptions) -> Result<Vec<ModelMessage>> {
79        let session_id = options.session_id.unwrap_or_else(|| "default".to_string());
80        let sessions = self.sessions.lock().await;
81        let session = match sessions.get(&session_id) {
82            Some(s) => s,
83            None => return Ok(Vec::new()),
84        };
85
86        let mut messages = Vec::new();
87        if let Some(summary) = &session.summary {
88            messages.push(ModelMessage {
89                role: Role::System,
90                content: format!("[Conversation summary so far]\n{}", summary),
91                tool_call_id: None,
92                tool_calls: None,
93            });
94        }
95        messages.extend(session.active_messages.clone());
96        Ok(messages)
97    }
98
99    async fn write(&self, messages: Vec<ModelMessage>, options: MemoryWriteOptions) -> Result<()> {
100        let session_id = options.session_id.unwrap_or_else(|| "default".to_string());
101
102        let non_system: Vec<ModelMessage> = messages
103            .into_iter()
104            .filter(|m| m.role != Role::System)
105            .collect();
106
107        let mut sessions = self.sessions.lock().await;
108        let session = sessions.entry(session_id).or_insert(SessionData {
109            summary: None,
110            active_messages: Vec::new(),
111        });
112
113        session.active_messages = non_system;
114
115        let tokens = estimate_messages_tokens(&session.active_messages);
116        if tokens > self.active_window_tokens {
117            self.compress(session).await?;
118        }
119
120        Ok(())
121    }
122}