ai_agent/services/compact/
session_memory_compact.rs1use crate::services::compact::microcompact::estimate_message_tokens;
8use crate::services::compact::prompt::get_compact_user_summary_message;
9use crate::types::{Message, MessageRole};
10use crate::utils::env_utils;
11use std::sync::atomic::{AtomicBool, Ordering};
12
/// Thresholds that decide how much of the recent conversation is kept verbatim
/// when compacting with session memory.
///
/// `calculate_messages_to_keep_index` widens the kept tail until both minimums
/// are met or the maximum is reached.
#[derive(Debug, Clone)]
pub struct SessionMemoryCompactConfig {
    /// Minimum estimated token count the kept tail should reach.
    pub min_tokens: usize,
    /// Minimum number of kept messages that must carry text
    /// (as judged by `has_text_blocks`).
    pub min_text_block_messages: usize,
    /// Estimated token count at which widening the kept tail stops.
    pub max_tokens: usize,
}

impl Default for SessionMemoryCompactConfig {
    /// Built-in thresholds: 10k minimum tokens, 5 text messages, 40k token cap.
    fn default() -> Self {
        SessionMemoryCompactConfig {
            max_tokens: 40_000,
            min_text_block_messages: 5,
            min_tokens: 10_000,
        }
    }
}
33
34static SM_COMPACT_CONFIG: std::sync::LazyLock<std::sync::Mutex<SessionMemoryCompactConfig>> =
36 std::sync::LazyLock::new(|| std::sync::Mutex::new(SessionMemoryCompactConfig::default()));
37static CONFIG_INITIALIZED: AtomicBool = AtomicBool::new(false);
38
39pub fn get_session_memory_compact_config() -> SessionMemoryCompactConfig {
41 SM_COMPACT_CONFIG.lock().unwrap().clone()
42}
43
44pub fn should_use_session_memory_compaction() -> bool {
46 if env_utils::is_env_truthy(
48 std::env::var("ENABLE_CLAUDE_CODE_SM_COMPACT")
49 .ok()
50 .as_deref(),
51 ) {
52 return true;
53 }
54 if env_utils::is_env_truthy(
55 std::env::var("DISABLE_CLAUDE_CODE_SM_COMPACT")
56 .ok()
57 .as_deref(),
58 ) {
59 return false;
60 }
61
62 false
64}
65
66pub fn has_text_blocks(message: &Message) -> bool {
68 match &message.role {
69 MessageRole::Assistant => !message.content.is_empty(),
70 MessageRole::User => !message.content.is_empty(),
71 _ => false,
72 }
73}
74
75pub fn is_compact_boundary_message(message: &Message) -> bool {
77 matches!(message.role, MessageRole::System)
78 && (message
79 .content
80 .contains("[Previous conversation summarized]")
81 || message.content.contains("compacted")
82 || message.content.contains("summarized"))
83}
84
85fn get_tool_result_ids(message: &Message) -> Vec<String> {
87 if !matches!(message.role, MessageRole::Tool) {
88 return Vec::new();
89 }
90 message.tool_call_id.clone().into_iter().collect()
91}
92
93fn has_tool_use_with_ids(
95 message: &Message,
96 tool_use_ids: &std::collections::HashSet<String>,
97) -> bool {
98 if !matches!(message.role, MessageRole::Assistant) {
99 return false;
100 }
101 if let Some(tool_calls) = &message.tool_calls {
102 for tc in tool_calls {
103 if tool_use_ids.contains(&tc.id) {
104 return true;
105 }
106 }
107 }
108 false
109}
110
111pub fn adjust_index_to_preserve_api_invariants(messages: &[Message], start_index: usize) -> usize {
114 if start_index <= 0 || start_index >= messages.len() {
115 return start_index;
116 }
117
118 let mut adjusted_index = start_index;
119
120 let all_tool_result_ids: std::collections::HashSet<String> = messages[start_index..]
123 .iter()
124 .flat_map(get_tool_result_ids)
125 .collect();
126
127 if !all_tool_result_ids.is_empty() {
128 let tool_use_ids_in_kept_range: std::collections::HashSet<String> = messages[start_index..]
130 .iter()
131 .filter(|m| matches!(m.role, MessageRole::Assistant))
132 .flat_map(|m| m.tool_calls.iter().flatten().map(|tc| tc.id.clone()))
133 .collect();
134
135 let needed_tool_use_ids: std::collections::HashSet<String> = all_tool_result_ids
137 .difference(&tool_use_ids_in_kept_range)
138 .cloned()
139 .collect();
140
141 for i in (0..adjusted_index).rev() {
143 if has_tool_use_with_ids(&messages[i], &needed_tool_use_ids) {
144 adjusted_index = i;
145 if let Some(tool_calls) = &messages[i].tool_calls {
147 for tc in tool_calls {
148 if needed_tool_use_ids.contains(&tc.id) {
149 }
151 }
152 }
153 }
154 }
155 }
156
157 adjusted_index
162}
163
164pub fn calculate_messages_to_keep_index(
166 messages: &[Message],
167 last_summarized_index: usize,
168) -> usize {
169 if messages.is_empty() {
170 return 0;
171 }
172
173 let config = get_session_memory_compact_config();
174
175 let mut start_index = if last_summarized_index < messages.len() {
177 last_summarized_index + 1
178 } else {
179 messages.len()
180 };
181
182 let mut total_tokens = 0;
184 let mut text_block_message_count = 0;
185
186 for i in start_index..messages.len() {
187 total_tokens += estimate_message_tokens(&[messages[i].clone()]);
188 if has_text_blocks(&messages[i]) {
189 text_block_message_count += 1;
190 }
191 }
192
193 if total_tokens >= config.max_tokens {
195 return adjust_index_to_preserve_api_invariants(messages, start_index);
196 }
197
198 if total_tokens >= config.min_tokens
200 && text_block_message_count >= config.min_text_block_messages
201 {
202 return adjust_index_to_preserve_api_invariants(messages, start_index);
203 }
204
205 let floor = messages
208 .iter()
209 .rposition(|m| is_compact_boundary_message(m))
210 .map(|idx| idx + 1)
211 .unwrap_or(0);
212
213 let mut i = if start_index > 0 { start_index - 1 } else { 0 };
214 loop {
215 if i < floor {
216 break;
217 }
218 let msg = &messages[i];
219 let msg_tokens = estimate_message_tokens(&[msg.clone()]);
220 total_tokens += msg_tokens;
221 if has_text_blocks(msg) {
222 text_block_message_count += 1;
223 }
224 start_index = i;
225
226 if total_tokens >= config.max_tokens {
228 break;
229 }
230
231 if total_tokens >= config.min_tokens
233 && text_block_message_count >= config.min_text_block_messages
234 {
235 break;
236 }
237
238 if i == 0 {
239 break;
240 }
241 i -= 1;
242 }
243
244 adjust_index_to_preserve_api_invariants(messages, start_index)
245}
246
/// Returns the untouched session-notes template.
///
/// `is_session_memory_empty` compares against this text to detect notes that
/// were never filled in.
fn get_session_memory_template() -> &'static str {
    concat!(
        "# Session Notes\n",
        "\n",
        "This file contains automatically extracted notes about the current conversation.\n",
        "\n",
        "## Key Points\n",
        "\n",
        "-\n",
        "\n",
        "## Decisions Made\n",
        "\n",
        "-\n",
        "\n",
        "## Open Items\n",
        "\n",
        "-\n",
        "\n",
        "## Context\n",
        "\n",
    )
}
269
270fn is_session_memory_empty(content: &str) -> bool {
272 let template = get_session_memory_template();
273 content.trim() == template.trim()
274}
275
/// Per-section size budget.
// NOTE(review): presumably a token budget, converted below at ~4 bytes per
// token — confirm the unit.
const MAX_SECTION_LENGTH: usize = 2000;
/// Maximum size of one section in the compacted notes, measured in bytes of
/// UTF-8 (it is compared against `str::len`).
const MAX_CHARS_PER_SECTION: usize = MAX_SECTION_LENGTH * 4;

/// Splits `content` into sections and caps each one at `MAX_CHARS_PER_SECTION`.
///
/// A new section starts at every line that begins with `#`, except lines
/// beginning with `"## "`, which stay inside the current section. Each emitted
/// section is followed by a newline.
///
/// Returns the rebuilt text and whether any section had to be shortened.
fn truncate_session_memory_for_compact(content: &str) -> (String, bool) {
    let mut result = String::new();
    let mut was_truncated = false;
    let mut current_section: Vec<String> = Vec::new();

    for line in content.lines() {
        let starts_new_section = line.starts_with('#') && !line.starts_with("## ");
        if starts_new_section {
            if !current_section.is_empty() {
                flush_section(&current_section, &mut result, &mut was_truncated);
            }
            current_section.clear();
        }
        current_section.push(line.to_string());
    }

    if !current_section.is_empty() {
        flush_section(&current_section, &mut result, &mut was_truncated);
    }

    (result, was_truncated)
}

/// Appends one section (its lines joined with `\n`) to `result`.
///
/// Sections over `MAX_CHARS_PER_SECTION` bytes are cut and followed by a
/// truncation marker, and `was_truncated` is set to `true`.
fn flush_section(lines: &[String], result: &mut String, was_truncated: &mut bool) {
    let joined = lines.join("\n");

    if joined.len() <= MAX_CHARS_PER_SECTION {
        result.push_str(&joined);
        result.push('\n');
        return;
    }

    // Slicing a `str` in the middle of a multi-byte character panics, so back
    // up to the nearest character boundary at or below the byte limit.
    let mut cut = MAX_CHARS_PER_SECTION;
    while !joined.is_char_boundary(cut) {
        cut -= 1;
    }

    result.push_str(&joined[..cut]);
    result.push_str("\n[... section truncated for length ...]\n");
    *was_truncated = true;
}
319
/// Prepares a raw summary for display inside the compact boundary message.
///
/// Every complete `<analysis>…</analysis>` block is removed, `<summary>` is
/// replaced by a `Summary:` heading, `</summary>` is dropped, and the result
/// is trimmed. An `<analysis>` tag with no closing tag after it is left in
/// place.
fn format_compact_summary_text(summary: &str) -> String {
    const OPEN_TAG: &str = "<analysis>";
    const CLOSE_TAG: &str = "</analysis>";

    let mut text = summary.to_string();

    while let Some(start) = text.find(OPEN_TAG) {
        let body_start = start + OPEN_TAG.len();
        // Look for the closing tag only *after* the opening one, so a stray
        // earlier closing tag can neither reorder the slices nor loop forever.
        let Some(close_offset) = text[body_start..].find(CLOSE_TAG) else {
            break;
        };
        let end = body_start + close_offset + CLOSE_TAG.len();
        text = format!("{}{}", &text[..start], &text[end..]);
    }

    text.replace("<summary>", "Summary:\n")
        .replace("</summary>", "")
        .trim()
        .to_string()
}
338
339pub async fn try_session_memory_compaction(
342 messages: &[Message],
343 _agent_id: Option<&str>,
344 auto_compact_threshold: Option<usize>,
345) -> Option<SessionMemoryCompactResult> {
346 if !should_use_session_memory_compaction() {
347 return None;
348 }
349
350 crate::session_memory::wait_for_session_memory_extraction().await;
352
353 let session_memory = match crate::session_memory::get_session_memory_content().await {
355 Ok(Some(content)) => content,
356 _ => return None,
357 };
358
359 if is_session_memory_empty(&session_memory) {
361 return None;
362 }
363
364 let last_summarized_index =
366 crate::session_memory::get_last_summarized_message_id_as_index(messages)
367 .unwrap_or(messages.len().saturating_sub(1));
368
369 let start_index = calculate_messages_to_keep_index(messages, last_summarized_index.min(messages.len().saturating_sub(1)));
371 let messages_to_keep: Vec<Message> = messages[start_index..]
372 .iter()
373 .filter(|m| !is_compact_boundary_message(m))
374 .cloned()
375 .collect();
376
377 let pre_compact_token_count = estimate_message_tokens(messages);
378
379 let (session_memory, _was_truncated) = truncate_session_memory_for_compact(&session_memory);
381
382 let formatted_summary = format_compact_summary_text(&session_memory);
384
385 let boundary_content = format!(
387 "[Previous conversation summarized]\n\n{}",
388 get_compact_user_summary_message(&formatted_summary, Some(true), None, Some(true))
389 );
390
391 let boundary_msg = Message {
393 role: MessageRole::System,
394 content: boundary_content,
395 is_meta: Some(true),
396 uuid: None,
397 ..Default::default()
398 };
399 let post_compact_token_count = estimate_message_tokens(
400 &[boundary_msg]
401 .iter()
402 .chain(messages_to_keep.iter())
403 .cloned()
404 .collect::<Vec<_>>()
405 .as_slice(),
406 );
407
408 if let Some(threshold) = auto_compact_threshold {
410 if post_compact_token_count >= threshold {
411 return None;
412 }
413 }
414
415 Some(SessionMemoryCompactResult {
416 compacted: true,
417 messages_to_keep,
418 session_memory_content: session_memory,
419 pre_compact_token_count,
420 post_compact_token_count,
421 })
422}
423
424#[derive(Debug, Clone)]
426pub struct SessionMemoryCompactResult {
427 pub compacted: bool,
428 pub messages_to_keep: Vec<Message>,
429 pub session_memory_content: String,
430 pub pre_compact_token_count: usize,
431 pub post_compact_token_count: usize,
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message with the given role and text; all other fields use defaults.
    fn message(role: MessageRole, text: &str) -> Message {
        Message {
            role,
            content: text.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_default_config() {
        let SessionMemoryCompactConfig {
            min_tokens,
            min_text_block_messages,
            max_tokens,
        } = get_session_memory_compact_config();

        assert_eq!(min_tokens, 10_000);
        assert_eq!(min_text_block_messages, 5);
        assert_eq!(max_tokens, 40_000);
    }

    #[test]
    fn test_has_text_blocks() {
        let with_text = message(MessageRole::User, "Hello");
        assert!(has_text_blocks(&with_text));

        let without_text = message(MessageRole::User, "");
        assert!(!has_text_blocks(&without_text));
    }

    #[test]
    fn test_adjust_index_empty_messages() {
        let no_messages: &[Message] = &[];
        assert_eq!(adjust_index_to_preserve_api_invariants(no_messages, 0), 0);
    }

    #[test]
    fn test_calculate_messages_to_keep_empty() {
        let no_messages: &[Message] = &[];
        assert_eq!(calculate_messages_to_keep_index(no_messages, 0), 0);
    }

    #[test]
    fn test_is_compact_boundary_message() {
        let boundary = message(MessageRole::System, "[Previous conversation summarized]");
        assert!(is_compact_boundary_message(&boundary));

        let normal = message(MessageRole::User, "Hello");
        assert!(!is_compact_boundary_message(&normal));
    }
}