Skip to main content

ai_agent/services/compact/
session_memory_compact.rs

1// Source: ~/claudecode/openclaudecode/src/services/compact/sessionMemoryCompact.ts
2//! Session memory compaction.
3//!
4//! Uses pre-extracted session memory as the summary instead of making an API call.
5//! Keeps recent messages above minimum thresholds and preserves API invariants.
6
7use crate::services::compact::microcompact::estimate_message_tokens;
8use crate::services::compact::prompt::get_compact_user_summary_message;
9use crate::types::{Message, MessageRole};
10use crate::utils::env_utils;
11use std::sync::atomic::{AtomicBool, Ordering};
12
/// Configuration for session memory compaction thresholds.
///
/// `calculate_messages_to_keep_index` keeps pulling older messages into the
/// kept tail until both minimums are satisfied, and stops as soon as the
/// running token total reaches `max_tokens`.
#[derive(Debug, Clone)]
pub struct SessionMemoryCompactConfig {
    /// Minimum tokens to preserve after compaction
    pub min_tokens: usize,
    /// Minimum number of messages with text blocks to keep
    pub min_text_block_messages: usize,
    /// Maximum tokens to preserve after compaction (hard cap)
    pub max_tokens: usize,
}
23
24impl Default for SessionMemoryCompactConfig {
25    fn default() -> Self {
26        Self {
27            min_tokens: 10_000,
28            min_text_block_messages: 5,
29            max_tokens: 40_000,
30        }
31    }
32}
33
// Current configuration: a process-wide value behind a mutex, created lazily
// from `SessionMemoryCompactConfig::default()` on first access.
static SM_COMPACT_CONFIG: std::sync::LazyLock<std::sync::Mutex<SessionMemoryCompactConfig>> =
    std::sync::LazyLock::new(|| std::sync::Mutex::new(SessionMemoryCompactConfig::default()));
// NOTE(review): nothing in the visible part of this file reads or writes this
// flag; presumably it records that the configuration was loaded from an
// external source — confirm it is still needed.
static CONFIG_INITIALIZED: AtomicBool = AtomicBool::new(false);
38
39/// Get the current session memory compact configuration
40pub fn get_session_memory_compact_config() -> SessionMemoryCompactConfig {
41    SM_COMPACT_CONFIG.lock().unwrap().clone()
42}
43
44/// Check if session memory compaction should be used
45pub fn should_use_session_memory_compaction() -> bool {
46    // Allow env override for testing
47    if env_utils::is_env_truthy(
48        std::env::var("ENABLE_CLAUDE_CODE_SM_COMPACT")
49            .ok()
50            .as_deref(),
51    ) {
52        return true;
53    }
54    if env_utils::is_env_truthy(
55        std::env::var("DISABLE_CLAUDE_CODE_SM_COMPACT")
56            .ok()
57            .as_deref(),
58    ) {
59        return false;
60    }
61
62    // For now, default to false (feature-gated in TypeScript)
63    false
64}
65
66/// Check if a message contains text blocks
67pub fn has_text_blocks(message: &Message) -> bool {
68    match &message.role {
69        MessageRole::Assistant => !message.content.is_empty(),
70        MessageRole::User => !message.content.is_empty(),
71        _ => false,
72    }
73}
74
75/// Check if a message is a compact boundary message
76pub fn is_compact_boundary_message(message: &Message) -> bool {
77    matches!(message.role, MessageRole::System)
78        && (message
79            .content
80            .contains("[Previous conversation summarized]")
81            || message.content.contains("compacted")
82            || message.content.contains("summarized"))
83}
84
85/// Collect tool_result IDs from a user message
86fn get_tool_result_ids(message: &Message) -> Vec<String> {
87    if !matches!(message.role, MessageRole::Tool) {
88        return Vec::new();
89    }
90    message.tool_call_id.clone().into_iter().collect()
91}
92
93/// Check if an assistant message contains tool_use blocks with any of the given ids
94fn has_tool_use_with_ids(
95    message: &Message,
96    tool_use_ids: &std::collections::HashSet<String>,
97) -> bool {
98    if !matches!(message.role, MessageRole::Assistant) {
99        return false;
100    }
101    if let Some(tool_calls) = &message.tool_calls {
102        for tc in tool_calls {
103            if tool_use_ids.contains(&tc.id) {
104                return true;
105            }
106        }
107    }
108    false
109}
110
111/// Adjust the start index to ensure we don't split tool_use/tool_result pairs
112/// or thinking blocks that share the same message.id with kept assistant messages.
113pub fn adjust_index_to_preserve_api_invariants(messages: &[Message], start_index: usize) -> usize {
114    if start_index <= 0 || start_index >= messages.len() {
115        return start_index;
116    }
117
118    let mut adjusted_index = start_index;
119
120    // Step 1: Handle tool_use/tool_result pairs
121    // Collect tool_result IDs from ALL messages in the kept range
122    let all_tool_result_ids: std::collections::HashSet<String> = messages[start_index..]
123        .iter()
124        .flat_map(get_tool_result_ids)
125        .collect();
126
127    if !all_tool_result_ids.is_empty() {
128        // Collect tool_use IDs already in the kept range
129        let tool_use_ids_in_kept_range: std::collections::HashSet<String> = messages[start_index..]
130            .iter()
131            .filter(|m| matches!(m.role, MessageRole::Assistant))
132            .flat_map(|m| m.tool_calls.iter().flatten().map(|tc| tc.id.clone()))
133            .collect();
134
135        // Only look for tool_uses that are NOT already in the kept range
136        let needed_tool_use_ids: std::collections::HashSet<String> = all_tool_result_ids
137            .difference(&tool_use_ids_in_kept_range)
138            .cloned()
139            .collect();
140
141        // Find the assistant message(s) with matching tool_use blocks
142        for i in (0..adjusted_index).rev() {
143            if has_tool_use_with_ids(&messages[i], &needed_tool_use_ids) {
144                adjusted_index = i;
145                // Remove found tool_use_ids from the set
146                if let Some(tool_calls) = &messages[i].tool_calls {
147                    for tc in tool_calls {
148                        if needed_tool_use_ids.contains(&tc.id) {
149                            // Can't remove from HashSet in this loop, just continue
150                        }
151                    }
152                }
153            }
154        }
155    }
156
157    // Step 2: Handle thinking blocks that share message.id with kept assistant messages
158    // Note: api_types::Message doesn't have message_id field, so skip this logic
159    // In the original TypeScript, this handled thinking blocks that share IDs with assistant messages
160
161    adjusted_index
162}
163
164/// Calculate the starting index for messages to keep after compaction.
165pub fn calculate_messages_to_keep_index(
166    messages: &[Message],
167    last_summarized_index: usize,
168) -> usize {
169    if messages.is_empty() {
170        return 0;
171    }
172
173    let config = get_session_memory_compact_config();
174
175    // Start from the message after last_summarized_index
176    let mut start_index = if last_summarized_index < messages.len() {
177        last_summarized_index + 1
178    } else {
179        messages.len()
180    };
181
182    // Calculate current tokens and text-block message count from start_index to end
183    let mut total_tokens = 0;
184    let mut text_block_message_count = 0;
185
186    for i in start_index..messages.len() {
187        total_tokens += estimate_message_tokens(&[messages[i].clone()]);
188        if has_text_blocks(&messages[i]) {
189            text_block_message_count += 1;
190        }
191    }
192
193    // Check if we already hit the max cap
194    if total_tokens >= config.max_tokens {
195        return adjust_index_to_preserve_api_invariants(messages, start_index);
196    }
197
198    // Check if we already meet both minimums
199    if total_tokens >= config.min_tokens
200        && text_block_message_count >= config.min_text_block_messages
201    {
202        return adjust_index_to_preserve_api_invariants(messages, start_index);
203    }
204
205    // Expand backwards until we meet both minimums or hit max cap
206    // Floor at the last compact boundary
207    let floor = messages
208        .iter()
209        .rposition(|m| is_compact_boundary_message(m))
210        .map(|idx| idx + 1)
211        .unwrap_or(0);
212
213    let mut i = if start_index > 0 { start_index - 1 } else { 0 };
214    loop {
215        if i < floor {
216            break;
217        }
218        let msg = &messages[i];
219        let msg_tokens = estimate_message_tokens(&[msg.clone()]);
220        total_tokens += msg_tokens;
221        if has_text_blocks(msg) {
222            text_block_message_count += 1;
223        }
224        start_index = i;
225
226        // Stop if we hit the max cap
227        if total_tokens >= config.max_tokens {
228            break;
229        }
230
231        // Stop if we meet both minimums
232        if total_tokens >= config.min_tokens
233            && text_block_message_count >= config.min_text_block_messages
234        {
235            break;
236        }
237
238        if i == 0 {
239            break;
240        }
241        i -= 1;
242    }
243
244    adjust_index_to_preserve_api_invariants(messages, start_index)
245}
246
/// Default session memory template content (matches the template created by session_memory.rs)
///
/// `is_session_memory_empty` compares session memory against this text after
/// trimming both sides, so leading/trailing whitespace here is not significant
/// but every interior character is.
fn get_session_memory_template() -> &'static str {
    r#"# Session Notes

This file contains automatically extracted notes about the current conversation.

## Key Points

-

## Decisions Made

-

## Open Items

-

## Context

"#
}
269
270/// Check if session memory content is just the default template (i.e., no real content yet)
271fn is_session_memory_empty(content: &str) -> bool {
272    let template = get_session_memory_template();
273    content.trim() == template.trim()
274}
275
/// Base per-section budget; multiplied by 4 below to obtain the limit that is
/// actually enforced.
const MAX_SECTION_LENGTH: usize = 2000;
/// Largest section size kept verbatim, measured in UTF-8 bytes (`str::len`).
const MAX_CHARS_PER_SECTION: usize = MAX_SECTION_LENGTH * 4;

/// Truncate oversized session memory sections for compact.
///
/// The content is split into sections and each section is limited to
/// `MAX_CHARS_PER_SECTION` bytes; a truncation notice is appended to every
/// section that had to be cut. A new section starts at any line that begins
/// with `#` but not with `## `, so `## ` sub-headings stay inside the
/// enclosing section.
/// NOTE(review): this rule also treats `### ...` and `#tag` lines as section
/// starts — confirm whether only `# ` headings were intended.
///
/// Returns the rebuilt text (every section is followed by a newline) and a
/// flag telling whether any section was truncated. Empty input yields an
/// empty string and `false`.
fn truncate_session_memory_for_compact(content: &str) -> (String, bool) {
    let mut result = String::with_capacity(content.len() + 1);
    let mut was_truncated = false;
    let mut current_section: Vec<String> = Vec::new();

    for line in content.lines() {
        let starts_new_section = line.starts_with('#') && !line.starts_with("## ");
        if starts_new_section && !current_section.is_empty() {
            // Emit the section collected so far before opening the next one.
            flush_section(&current_section, &mut result, &mut was_truncated);
            current_section.clear();
        }
        current_section.push(line.to_string());
    }

    // Emit the trailing section.
    if !current_section.is_empty() {
        flush_section(&current_section, &mut result, &mut was_truncated);
    }

    (result, was_truncated)
}

/// Append one section (its lines joined with `\n`) to `result`.
///
/// Sections within the limit are copied verbatim followed by a newline.
/// Longer sections are cut to at most `MAX_CHARS_PER_SECTION` bytes, followed
/// by a truncation notice, and `was_truncated` is set.
fn flush_section(lines: &[String], result: &mut String, was_truncated: &mut bool) {
    let joined = lines.join("\n");
    if joined.len() <= MAX_CHARS_PER_SECTION {
        result.push_str(&joined);
        result.push('\n');
        return;
    }

    // The limit is a byte offset and may land inside a multi-byte character;
    // slicing there would panic. Back off to the previous character boundary
    // (offset 0 is always a boundary, so this terminates).
    let mut cut = MAX_CHARS_PER_SECTION;
    while !joined.is_char_boundary(cut) {
        cut -= 1;
    }

    result.push_str(&joined[..cut]);
    result.push_str("\n[... section truncated for length ...]\n");
    *was_truncated = true;
}
319
/// Format session memory content for compact summary.
///
/// Removes every complete `<analysis>...</analysis>` block (each opening tag
/// is paired with the nearest closing tag after it, so text between two
/// separate blocks is preserved), replaces `<summary>` with `Summary:\n`,
/// drops `</summary>`, and trims surrounding whitespace.
///
/// An opening tag without a later closing tag, or a stray closing tag, is
/// left in place untouched.
fn format_compact_summary_text(summary: &str) -> String {
    const ANALYSIS_OPEN: &str = "<analysis>";
    const ANALYSIS_CLOSE: &str = "</analysis>";

    // Single forward pass: copy everything outside analysis blocks.
    let mut text = String::with_capacity(summary.len());
    let mut rest = summary;
    while let Some(open) = rest.find(ANALYSIS_OPEN) {
        let body_start = open + ANALYSIS_OPEN.len();
        // Search for the closing tag only after the opening tag, so a closing
        // tag that appears earlier in the text can never be paired with it.
        let Some(close) = rest[body_start..].find(ANALYSIS_CLOSE) else {
            break;
        };
        text.push_str(&rest[..open]);
        // Skip the full closing tag, using its real length.
        rest = &rest[body_start + close + ANALYSIS_CLOSE.len()..];
    }
    text.push_str(rest);

    // Replace <summary> and </summary> tags
    let text = text
        .replace("<summary>", "Summary:\n")
        .replace("</summary>", "");

    text.trim().to_string()
}
338
339/// Try to use session memory for compaction instead of traditional compaction.
340/// Returns None if session memory compaction cannot be used.
341pub async fn try_session_memory_compaction(
342    messages: &[Message],
343    _agent_id: Option<&str>,
344    auto_compact_threshold: Option<usize>,
345) -> Option<SessionMemoryCompactResult> {
346    if !should_use_session_memory_compaction() {
347        return None;
348    }
349
350    // Wait for any in-progress extraction to complete
351    crate::session_memory::wait_for_session_memory_extraction().await;
352
353    // Get session memory content from file
354    let session_memory = match crate::session_memory::get_session_memory_content().await {
355        Ok(Some(content)) => content,
356        _ => return None,
357    };
358
359    // Check if session memory has real content (not just the template)
360    if is_session_memory_empty(&session_memory) {
361        return None;
362    }
363
364    // Determine last summarized index
365    let last_summarized_index =
366        crate::session_memory::get_last_summarized_message_id_as_index(messages)
367            .unwrap_or(messages.len().saturating_sub(1));
368
369    // Calculate which messages to keep
370    let start_index = calculate_messages_to_keep_index(messages, last_summarized_index.min(messages.len().saturating_sub(1)));
371    let messages_to_keep: Vec<Message> = messages[start_index..]
372        .iter()
373        .filter(|m| !is_compact_boundary_message(m))
374        .cloned()
375        .collect();
376
377    let pre_compact_token_count = estimate_message_tokens(messages);
378
379    // Truncate session memory if needed for compact
380    let (session_memory, _was_truncated) = truncate_session_memory_for_compact(&session_memory);
381
382    // Format summary (strip analysis tags, format for display)
383    let formatted_summary = format_compact_summary_text(&session_memory);
384
385    // Build the boundary content
386    let boundary_content = format!(
387        "[Previous conversation summarized]\n\n{}",
388        get_compact_user_summary_message(&formatted_summary, Some(true), None, Some(true))
389    );
390
391    // Count tokens of boundary + kept messages
392    let boundary_msg = Message {
393        role: MessageRole::System,
394        content: boundary_content,
395        is_meta: Some(true),
396            uuid: None,
397        ..Default::default()
398    };
399    let post_compact_token_count = estimate_message_tokens(
400        &[boundary_msg]
401            .iter()
402            .chain(messages_to_keep.iter())
403            .cloned()
404            .collect::<Vec<_>>()
405            .as_slice(),
406    );
407
408    // Check if compaction would re-trigger (post >= threshold)
409    if let Some(threshold) = auto_compact_threshold {
410        if post_compact_token_count >= threshold {
411            return None;
412        }
413    }
414
415    Some(SessionMemoryCompactResult {
416        compacted: true,
417        messages_to_keep,
418        session_memory_content: session_memory,
419        pre_compact_token_count,
420        post_compact_token_count,
421    })
422}
423
/// Result from session memory compaction
#[derive(Debug, Clone)]
pub struct SessionMemoryCompactResult {
    /// Set to `true` by `try_session_memory_compaction` whenever it returns a result.
    pub compacted: bool,
    /// Tail of the conversation retained after compaction, with compact
    /// boundary messages filtered out.
    pub messages_to_keep: Vec<Message>,
    /// Session memory text used as the summary, after per-section truncation.
    pub session_memory_content: String,
    /// Token estimate for the full message list before compaction.
    pub pre_compact_token_count: usize,
    /// Token estimate for the new boundary message plus the kept messages.
    pub post_compact_token_count: usize,
}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message with the given role and content; all other fields
    /// take their defaults.
    fn message_with(role: MessageRole, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            ..Default::default()
        }
    }

    /// The shared configuration starts out with the documented defaults.
    #[test]
    fn test_default_config() {
        let SessionMemoryCompactConfig {
            min_tokens,
            min_text_block_messages,
            max_tokens,
        } = get_session_memory_compact_config();

        assert_eq!(min_tokens, 10_000);
        assert_eq!(min_text_block_messages, 5);
        assert_eq!(max_tokens, 40_000);
    }

    /// Non-empty user content counts as text; empty content does not.
    #[test]
    fn test_has_text_blocks() {
        let with_text = message_with(MessageRole::User, "Hello");
        assert!(has_text_blocks(&with_text));

        let without_text = message_with(MessageRole::User, "");
        assert!(!has_text_blocks(&without_text));
    }

    /// An empty message list leaves the index untouched.
    #[test]
    fn test_adjust_index_empty_messages() {
        let adjusted = adjust_index_to_preserve_api_invariants(&[], 0);
        assert_eq!(adjusted, 0);
    }

    /// An empty message list keeps from index 0.
    #[test]
    fn test_calculate_messages_to_keep_empty() {
        let keep_from = calculate_messages_to_keep_index(&[], 0);
        assert_eq!(keep_from, 0);
    }

    /// System messages carrying the summary marker are boundaries; ordinary
    /// user messages are not.
    #[test]
    fn test_is_compact_boundary_message() {
        let marker = message_with(MessageRole::System, "[Previous conversation summarized]");
        assert!(is_compact_boundary_message(&marker));

        let ordinary = message_with(MessageRole::User, "Hello");
        assert!(!is_compact_boundary_message(&ordinary));
    }
}