Skip to main content

cersei_agent/
compact.rs

1//! Auto-compact: context window management for long conversations.
2//!
3//! When the conversation approaches the context window limit, older messages
4//! are summarized to free space while preserving essential context.
5
6use cersei_provider::Provider;
7use cersei_types::*;
8
9// ─── Constants ───────────────────────────────────────────────────────────────
10
/// Usage ratio (tokens used / context limit) at or above which
/// [`should_compact`] reports that compaction should run.
pub const AUTOCOMPACT_TRIGGER_FRACTION: f64 = 0.9;
/// How many of the newest messages [`auto_compact_if_needed`] asks
/// [`compact_conversation`] to leave out of the summary.
pub const KEEP_RECENT_MESSAGES: usize = 10;
/// Number of back-to-back failed compactions after which
/// [`AutoCompactState::on_failure`] sets the `disabled` flag.
pub const MAX_CONSECUTIVE_FAILURES: u32 = 3;
/// Usage ratio at or above which the state is at least
/// [`TokenWarningState::Warning`].
pub const WARNING_PCT: f64 = 0.8;
/// Usage ratio at or above which the state is
/// [`TokenWarningState::Critical`].
pub const CRITICAL_PCT: f64 = 0.95;
21
22// ─── Types ───────────────────────────────────────────────────────────────────
23
/// Bookkeeping for auto-compaction over the lifetime of a session.
///
/// The `Default` value is "no compactions yet, no failures, not disabled".
#[derive(Clone, Debug, Default)]
pub struct AutoCompactState {
    /// Incremented by `on_success`.
    pub compaction_count: u32,
    /// Incremented by `on_failure`; reset to zero by `on_success`.
    pub consecutive_failures: u32,
    /// When `true`, `should_auto_compact` always answers `false`.
    pub disabled: bool,
}
31
32impl AutoCompactState {
33    pub fn on_success(&mut self) {
34        self.compaction_count += 1;
35        self.consecutive_failures = 0;
36    }
37
38    pub fn on_failure(&mut self) {
39        self.consecutive_failures += 1;
40        if self.consecutive_failures >= MAX_CONSECUTIVE_FAILURES {
41            self.disabled = true;
42        }
43    }
44}
45
/// How full the context window is, as classified by
/// `calculate_token_warning_state`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TokenWarningState {
    /// Usage ratio is below the warning threshold.
    Ok,
    /// Usage ratio is at or above the warning threshold but below the
    /// critical one.
    Warning,
    /// Usage ratio is at or above the critical threshold.
    Critical,
}
56
57/// A semantically coherent group of messages for summarization.
58#[derive(Debug, Clone)]
59pub struct MessageGroup {
60    pub messages: Vec<Message>,
61    pub topic_hint: Option<String>,
62    pub token_estimate: usize,
63}
64
/// Outcome reported by a compaction run.
#[derive(Clone, Debug)]
pub struct CompactResult {
    /// Message count before compaction.
    pub messages_before: usize,
    /// Message count after compaction.
    pub messages_after: usize,
    /// Rough number of tokens occupied by the messages that were summarized.
    pub tokens_freed_estimate: u64,
    /// Summary text; empty when nothing was compacted.
    pub summary: String,
}
73
/// Reason a compaction was started.
#[derive(Clone, Copy, Debug)]
pub enum CompactTrigger {
    /// The automatic usage threshold was reached.
    AutoThreshold,
    /// Compaction was requested explicitly.
    Manual,
    /// The context window overflowed.
    ContextOverflow,
}
81
82// ─── Token estimation ────────────────────────────────────────────────────────
83
/// Cheap token estimate for a piece of text: one token per four bytes,
/// rounded down.
///
/// The count is based on `str::len`, i.e. UTF-8 bytes rather than
/// characters, so non-ASCII text is weighted by its encoded size.
pub fn estimate_tokens(text: &str) -> u64 {
    const BYTES_PER_TOKEN: u64 = 4;
    let byte_count = text.len() as u64;
    byte_count / BYTES_PER_TOKEN
}
88
89/// Estimate tokens for a list of messages.
90pub fn estimate_messages_tokens(messages: &[Message]) -> u64 {
91    messages
92        .iter()
93        .map(|m| estimate_tokens(&m.get_all_text()))
94        .sum()
95}
96
/// Look up the context window size (in tokens) assumed for a model name.
///
/// Matching is by case-sensitive substring (prefix for the `o1`/`o3`
/// families) and the first rule that matches wins, so more specific names
/// such as `gpt-4o` are tested before the broader `gpt-4`. Names that match
/// no rule get the large default of 200 000.
pub fn context_window_for_model(model: &str) -> u64 {
    let has = |needle: &str| model.contains(needle);

    // Rule order is significant; do not reorder the branches.
    if has("gpt-5") || has("gemini") {
        1_000_000
    } else if model.starts_with("o1")
        || model.starts_with("o3")
        || has("opus")
        || has("sonnet")
        || has("haiku")
    {
        200_000
    } else if has("gpt-4o") || has("gpt-4-turbo") {
        128_000
    } else if has("gpt-4") {
        8_192
    } else if has("gpt-3.5") {
        16_385
    } else if has("llama") {
        8_192
    } else {
        // Unknown model: default to large.
        200_000
    }
}
114
115// ─── Warning state ───────────────────────────────────────────────────────────
116
117/// Calculate the token warning state given current usage.
118pub fn calculate_token_warning_state(tokens_used: u64, context_limit: u64) -> TokenWarningState {
119    if context_limit == 0 {
120        return TokenWarningState::Ok;
121    }
122    let pct = tokens_used as f64 / context_limit as f64;
123    if pct >= CRITICAL_PCT {
124        TokenWarningState::Critical
125    } else if pct >= WARNING_PCT {
126        TokenWarningState::Warning
127    } else {
128        TokenWarningState::Ok
129    }
130}
131
132// ─── Should compact ──────────────────────────────────────────────────────────
133
134/// Check if compaction should trigger.
135pub fn should_compact(tokens_used: u64, context_limit: u64) -> bool {
136    if context_limit == 0 {
137        return false;
138    }
139    (tokens_used as f64 / context_limit as f64) >= AUTOCOMPACT_TRIGGER_FRACTION
140}
141
142/// Check if auto-compact should run (considering state/circuit breaker).
143pub fn should_auto_compact(tokens_used: u64, context_limit: u64, state: &AutoCompactState) -> bool {
144    if state.disabled {
145        return false;
146    }
147    should_compact(tokens_used, context_limit)
148}
149
/// Whether usage is high enough to call for an emergency context collapse.
///
/// True when `tokens_used / context_limit` is at least 0.98; always false
/// for a zero `context_limit`.
pub fn should_context_collapse(tokens_used: u64, context_limit: u64) -> bool {
    context_limit != 0 && (tokens_used as f64 / context_limit as f64) >= 0.98
}
157
158// ─── Message grouping ────────────────────────────────────────────────────────
159
160/// Extract a topic hint from messages (first file path or tool name).
161fn extract_topic_hint(messages: &[Message]) -> Option<String> {
162    for msg in messages {
163        for block in msg.content_blocks() {
164            match &block {
165                ContentBlock::ToolUse { name, input, .. } => {
166                    if let Some(path) = input.get("file_path").and_then(|v| v.as_str()) {
167                        return Some(path.to_string());
168                    }
169                    return Some(name.clone());
170                }
171                _ => {}
172            }
173        }
174    }
175    None
176}
177
178/// Group messages into semantically coherent chunks at API-round boundaries.
179/// Each group = one assistant response + its tool results.
180pub fn group_messages_for_compact(messages: &[Message]) -> Vec<MessageGroup> {
181    let mut groups: Vec<MessageGroup> = Vec::new();
182    let mut current: Vec<Message> = Vec::new();
183
184    for msg in messages {
185        current.push(msg.clone());
186        // End group at assistant messages that don't have tool use (end of a "round")
187        if msg.role == Role::Assistant && !msg.has_tool_use() {
188            let token_est = current.iter().map(|m| m.get_all_text().len() / 4).sum();
189            let hint = extract_topic_hint(&current);
190            groups.push(MessageGroup {
191                messages: std::mem::take(&mut current),
192                topic_hint: hint,
193                token_estimate: token_est,
194            });
195        }
196    }
197    // Leftover messages
198    if !current.is_empty() {
199        let token_est = current.iter().map(|m| m.get_all_text().len() / 4).sum();
200        let hint = extract_topic_hint(&current);
201        groups.push(MessageGroup {
202            messages: current,
203            topic_hint: hint,
204            token_estimate: token_est,
205        });
206    }
207    groups
208}
209
210// ─── Snip compact (simple truncation) ────────────────────────────────────────
211
212/// Remove oldest messages, keeping only the newest `keep_n`.
213/// Returns (remaining messages, estimated tokens freed).
214pub fn snip_compact(messages: Vec<Message>, keep_n: usize) -> (Vec<Message>, u64) {
215    if messages.len() <= keep_n {
216        return (messages, 0);
217    }
218    let removed = &messages[..messages.len() - keep_n];
219    let freed = estimate_messages_tokens(removed);
220    let kept = messages[messages.len() - keep_n..].to_vec();
221    (kept, freed)
222}
223
224/// Calculate how many messages to keep given a token budget.
225pub fn calculate_messages_to_keep_index(messages: &[Message], token_budget: u64) -> usize {
226    let mut total: u64 = 0;
227    for (i, msg) in messages.iter().rev().enumerate() {
228        total += estimate_tokens(&msg.get_all_text());
229        if total > token_budget {
230            return messages.len() - i;
231        }
232    }
233    0 // keep all
234}
235
236// ─── Collapse strategies ─────────────────────────────────────────────────────
237
238/// Collapse repeated file read results: if the same file is read multiple
239/// times, only keep the latest result.
240pub fn collapse_read_tool_results(messages: Vec<Message>) -> Vec<Message> {
241    let mut seen_files: std::collections::HashSet<String> = std::collections::HashSet::new();
242    let mut result: Vec<Message> = Vec::new();
243
244    // Process in reverse to keep latest reads
245    for msg in messages.into_iter().rev() {
246        let dominated = match &msg.content {
247            MessageContent::Blocks(blocks) => {
248                blocks.iter().all(|b| {
249                    if let ContentBlock::ToolResult {
250                        tool_use_id,
251                        content,
252                        ..
253                    } = b
254                    {
255                        // Check if this is a file read result we've already seen
256                        if let ToolResultContent::Text(text) = content {
257                            if text.contains('\t') {
258                                // Line-numbered output = file read
259                                let key = tool_use_id.clone();
260                                if seen_files.contains(&key) {
261                                    return true; // dominated, skip
262                                }
263                                seen_files.insert(key);
264                            }
265                        }
266                        false
267                    } else {
268                        false
269                    }
270                })
271            }
272            _ => false,
273        };
274
275        if !dominated {
276            result.push(msg);
277        }
278    }
279
280    result.reverse();
281    result
282}
283
284// ─── Compact prompt ──────────────────────────────────────────────────────────
285
/// Produce the instruction text sent to the LLM when asking for a summary.
///
/// The fixed base prompt is returned as-is when `custom_instructions` is
/// `None`; otherwise the instructions are appended after a blank line under
/// an "Additional context: " label.
pub fn get_compact_prompt(custom_instructions: Option<&str>) -> String {
    const BASE_PROMPT: &str = concat!(
        "Summarize the conversation so far. Focus on:\n",
        "1. Key decisions made and their rationale\n",
        "2. Files that were read, created, or modified (with paths)\n",
        "3. Tool results that are still relevant\n",
        "4. Outstanding tasks or next steps\n",
        "5. Any errors encountered and how they were resolved\n",
        "\n",
        "Be concise but preserve all actionable information. ",
        "Use bullet points. Include file paths verbatim.",
    );

    match custom_instructions {
        Some(extra) => format!("{}\n\nAdditional context: {}", BASE_PROMPT, extra),
        None => BASE_PROMPT.to_string(),
    }
}
304
/// Wrap raw summarizer output in a `<context_summary>` envelope.
///
/// Leading and trailing whitespace of `raw` is trimmed before it is placed
/// between the introductory sentence and the closing tag.
pub fn format_compact_summary(raw: &str) -> String {
    let body = raw.trim();
    [
        "<context_summary>",
        "The following is a summary of the conversation so far:",
        "",
        body,
        "</context_summary>",
    ]
    .join("\n")
}
315
316// ─── Full compaction (requires provider call) ────────────────────────────────
317
318/// Compact the conversation by summarizing older messages.
319///
320/// 1. Split messages into "old" (to compact) and "recent" (to keep)
321/// 2. Group old messages by topic
322/// 3. Send to provider for summarization
323/// 4. Replace old messages with summary
324pub async fn compact_conversation(
325    provider: &dyn Provider,
326    messages: &[Message],
327    model: &str,
328    keep_recent: usize,
329    custom_instructions: Option<&str>,
330) -> Result<CompactResult> {
331    let messages_before = messages.len();
332
333    if messages.len() <= keep_recent {
334        return Ok(CompactResult {
335            messages_before,
336            messages_after: messages_before,
337            tokens_freed_estimate: 0,
338            summary: String::new(),
339        });
340    }
341
342    let split_idx = messages.len() - keep_recent;
343    let old_messages = &messages[..split_idx];
344    let recent_messages = &messages[split_idx..];
345
346    // Build compaction request
347    let old_text: String = old_messages
348        .iter()
349        .map(|m| {
350            let role = match m.role {
351                Role::User => "User",
352                Role::Assistant => "Assistant",
353                Role::System => "System",
354            };
355            format!("{}: {}", role, m.get_all_text())
356        })
357        .collect::<Vec<_>>()
358        .join("\n\n");
359
360    let compact_prompt = get_compact_prompt(custom_instructions);
361    let request = cersei_provider::CompletionRequest {
362        model: model.to_string(),
363        messages: vec![
364            Message::user(format!(
365                "Here is the conversation history to summarize:\n\n{}\n\n{}",
366                old_text, compact_prompt
367            )),
368        ],
369        system: Some("You are a conversation summarizer. Be concise and preserve all actionable information.".into()),
370        tools: Vec::new(),
371        max_tokens: 4096,
372        temperature: Some(0.0),
373        stop_sequences: Vec::new(),
374        options: cersei_provider::ProviderOptions::default(),
375    };
376
377    // Collect streaming response into a complete message
378    let stream = provider.complete(request).await?;
379    let mut rx = stream.into_receiver();
380    let mut accumulator = cersei_provider::StreamAccumulator::new();
381    while let Some(event) = rx.recv().await {
382        accumulator.process_event(event);
383    }
384    let response = accumulator.into_response()?;
385    let summary_text = response.message.get_all_text();
386    let formatted_summary = format_compact_summary(&summary_text);
387
388    let tokens_freed = estimate_messages_tokens(old_messages);
389
390    // Build compacted messages: summary + recent
391    let messages_after = 1 + recent_messages.len(); // summary message + recent
392
393    Ok(CompactResult {
394        messages_before,
395        messages_after,
396        tokens_freed_estimate: tokens_freed,
397        summary: formatted_summary,
398    })
399}
400
401/// Check and run auto-compact if needed. Returns None if no compaction needed.
402pub async fn auto_compact_if_needed(
403    provider: &dyn Provider,
404    messages: &[Message],
405    model: &str,
406    tokens_used: u64,
407    state: &mut AutoCompactState,
408) -> Option<CompactResult> {
409    let context_limit = context_window_for_model(model);
410    if !should_auto_compact(tokens_used, context_limit, state) {
411        return None;
412    }
413
414    match compact_conversation(provider, messages, model, KEEP_RECENT_MESSAGES, None).await {
415        Ok(result) => {
416            state.on_success();
417            Some(result)
418        }
419        Err(_) => {
420            state.on_failure();
421            None
422        }
423    }
424}
425
426// ─── Tests ───────────────────────────────────────────────────────────────────
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `n` alternating messages, starting with a user turn.
    fn make_messages(n: usize) -> Vec<Message> {
        let mut out = Vec::with_capacity(n);
        for i in 0..n {
            let msg = match i % 2 {
                0 => Message::user(format!("User message {}", i)),
                _ => Message::assistant(format!("Assistant response {} with some longer text to simulate real content that takes up tokens in the context window.", i)),
            };
            out.push(msg);
        }
        out
    }

    // ── Warning state ──

    #[test]
    fn test_token_warning_ok() {
        let state = calculate_token_warning_state(50_000, 200_000);
        assert_eq!(state, TokenWarningState::Ok);
    }

    #[test]
    fn test_token_warning_warning() {
        let state = calculate_token_warning_state(170_000, 200_000);
        assert_eq!(state, TokenWarningState::Warning);
    }

    #[test]
    fn test_token_warning_critical() {
        let state = calculate_token_warning_state(196_000, 200_000);
        assert_eq!(state, TokenWarningState::Critical);
    }

    // ── Trigger checks ──

    #[test]
    fn test_should_compact() {
        let limit = 200_000;
        // 50% and 85% are below the trigger.
        assert!(!should_compact(100_000, limit));
        assert!(!should_compact(170_000, limit));
        // 92.5% and 97.5% are above it.
        assert!(should_compact(185_000, limit));
        assert!(should_compact(195_000, limit));
    }

    #[test]
    fn test_should_auto_compact_disabled() {
        let mut state = AutoCompactState::default();
        state.disabled = true;
        assert!(!should_auto_compact(195_000, 200_000, &state));
    }

    #[test]
    fn test_circuit_breaker() {
        let mut state = AutoCompactState::default();
        for _ in 0..2 {
            state.on_failure();
        }
        assert!(!state.disabled);
        // The third failure trips the breaker.
        state.on_failure();
        assert!(state.disabled);
    }

    // ── Snip compact ──

    #[test]
    fn test_snip_compact() {
        let (kept, freed) = snip_compact(make_messages(20), 10);
        assert_eq!(kept.len(), 10);
        assert!(freed > 0);
    }

    #[test]
    fn test_snip_compact_already_small() {
        let (kept, freed) = snip_compact(make_messages(5), 10);
        assert_eq!(kept.len(), 5);
        assert_eq!(freed, 0);
    }

    // ── Grouping ──

    #[test]
    fn test_group_messages() {
        let messages = vec![
            Message::user("Read file A"),
            Message::assistant("Contents of A"),
            Message::user("Now edit B"),
            Message::assistant("Edited B"),
        ];

        let groups = group_messages_for_compact(&messages);
        assert_eq!(groups.len(), 2);
    }

    // ── Estimation ──

    #[test]
    fn test_estimate_tokens() {
        // 11 chars / 4
        assert_eq!(estimate_tokens("hello world"), 2);
        assert_eq!(estimate_tokens(""), 0);
        let long_text = "x".repeat(1000);
        assert!(estimate_tokens(&long_text) > 200);
    }

    #[test]
    fn test_context_window_for_model() {
        let cases = [
            ("claude-sonnet-4-6", 200_000),
            ("gpt-4o", 128_000),
            ("gpt-4", 8_192),
        ];
        for (model, expected) in cases {
            assert_eq!(context_window_for_model(model), expected);
        }
    }

    // ── Prompt and summary text ──

    #[test]
    fn test_compact_prompt_with_instructions() {
        let prompt = get_compact_prompt(Some("Focus on API changes"));
        for needle in ["Focus on API changes", "Summarize"] {
            assert!(prompt.contains(needle));
        }
    }

    #[test]
    fn test_format_compact_summary() {
        let summary = format_compact_summary("- Did X\n- Did Y");
        for needle in ["<context_summary>", "- Did X"] {
            assert!(summary.contains(needle));
        }
    }

    // ── Keep index ──

    #[test]
    fn test_calculate_messages_to_keep_index() {
        let messages = make_messages(20);
        let idx = calculate_messages_to_keep_index(&messages, 100);
        assert!(idx > 0);
        assert!(idx < 20);
    }

    #[test]
    fn test_messages_to_keep_all_fit() {
        let messages = make_messages(3);
        // A budget this large keeps everything.
        assert_eq!(calculate_messages_to_keep_index(&messages, 100_000), 0);
    }
}