// agentlib_memory/summarizing.rs
use agentlib_core::{
2 MemoryProvider, MemoryReadOptions, MemoryWriteOptions, ModelMessage, ModelProvider,
3 ModelRequest, Role, async_trait, estimate_messages_tokens,
4};
5use anyhow::Result;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::Mutex;
9
/// A [`MemoryProvider`] that keeps a bounded window of recent messages per
/// session and, when the window's estimated token count is exceeded,
/// compresses the older half of the conversation into a rolling text summary
/// via a call to `model`.
pub struct SummarizingMemory {
    // Model used to produce the compression summaries.
    model: Arc<dyn ModelProvider>,
    // Approximate token budget for the verbatim (uncompressed) messages.
    active_window_tokens: usize,
    // System prompt sent with every compression request.
    summary_prompt: String,
    // Per-session state keyed by session id ("default" when none is given).
    sessions: Arc<Mutex<HashMap<String, SessionData>>>,
}
16
/// Per-session memory state.
struct SessionData {
    // Rolling summary of previously compressed messages; `None` until the
    // first compression has run for this session.
    summary: Option<String>,
    // Messages still stored verbatim; trimmed when their estimated token
    // count exceeds the configured window.
    active_messages: Vec<ModelMessage>,
}
21
22impl SummarizingMemory {
23 pub fn new(model: Arc<dyn ModelProvider>, active_window_tokens: usize) -> Self {
24 Self {
25 model,
26 active_window_tokens,
27 summary_prompt: "You are a memory compression assistant. Summarize the following conversation concisely.".to_string(),
28 sessions: Arc::new(Mutex::new(HashMap::new())),
29 }
30 }
31
32 async fn compress(&self, session: &mut SessionData) -> Result<()> {
33 let messages = &session.active_messages;
34 let split_at = messages.len() / 2;
35 let (to_compress, to_keep) = messages.split_at(split_at);
36
37 let mut compress_input = String::new();
38 if let Some(summary) = &session.summary {
39 compress_input.push_str(&format!(
40 "Previous summary:\n{}\n\nNew conversation to add:\n",
41 summary
42 ));
43 }
44
45 for msg in to_compress {
46 compress_input.push_str(&format!("{:?}: {}\n", msg.role, msg.content));
47 }
48
49 let response = self
50 .model
51 .complete(ModelRequest {
52 messages: vec![
53 ModelMessage {
54 role: Role::System,
55 content: self.summary_prompt.clone(),
56 tool_call_id: None,
57 tool_calls: None,
58 },
59 ModelMessage {
60 role: Role::User,
61 content: compress_input,
62 tool_call_id: None,
63 tool_calls: None,
64 },
65 ],
66 tools: None,
67 })
68 .await?;
69
70 session.summary = Some(response.message.content);
71 session.active_messages = to_keep.to_vec();
72 Ok(())
73 }
74}
75
76#[async_trait]
77impl MemoryProvider for SummarizingMemory {
78 async fn read(&self, options: MemoryReadOptions) -> Result<Vec<ModelMessage>> {
79 let session_id = options.session_id.unwrap_or_else(|| "default".to_string());
80 let sessions = self.sessions.lock().await;
81 let session = match sessions.get(&session_id) {
82 Some(s) => s,
83 None => return Ok(Vec::new()),
84 };
85
86 let mut messages = Vec::new();
87 if let Some(summary) = &session.summary {
88 messages.push(ModelMessage {
89 role: Role::System,
90 content: format!("[Conversation summary so far]\n{}", summary),
91 tool_call_id: None,
92 tool_calls: None,
93 });
94 }
95 messages.extend(session.active_messages.clone());
96 Ok(messages)
97 }
98
99 async fn write(&self, messages: Vec<ModelMessage>, options: MemoryWriteOptions) -> Result<()> {
100 let session_id = options.session_id.unwrap_or_else(|| "default".to_string());
101
102 let non_system: Vec<ModelMessage> = messages
103 .into_iter()
104 .filter(|m| m.role != Role::System)
105 .collect();
106
107 let mut sessions = self.sessions.lock().await;
108 let session = sessions.entry(session_id).or_insert(SessionData {
109 summary: None,
110 active_messages: Vec::new(),
111 });
112
113 session.active_messages = non_system;
114
115 let tokens = estimate_messages_tokens(&session.active_messages);
116 if tokens > self.active_window_tokens {
117 self.compress(session).await?;
118 }
119
120 Ok(())
121 }
122}