
agent_sdk/context/compactor.rs

//! Context compaction implementation.

use crate::llm::{ChatOutcome, ChatRequest, Content, ContentBlock, LlmProvider, Message, Role};
use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use std::fmt::Write;
use std::sync::Arc;

use super::config::CompactionConfig;
use super::estimator::TokenEstimator;

/// Trait for context compaction strategies.
///
/// Implement this trait to provide custom compaction logic.
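///
/// # Example
///
/// A minimal sketch of a custom strategy (the `PassthroughCompactor` type is
/// hypothetical, shown only to illustrate the trait surface):
///
/// ```ignore
/// struct PassthroughCompactor;
///
/// #[async_trait]
/// impl ContextCompactor for PassthroughCompactor {
///     async fn compact(&self, messages: &[Message]) -> Result<String> {
///         // Hypothetical trivial "summary": just report what was elided.
///         Ok(format!("{} earlier messages elided", messages.len()))
///     }
///
///     fn estimate_tokens(&self, messages: &[Message]) -> usize {
///         TokenEstimator::estimate_history(messages)
///     }
///
///     fn needs_compaction(&self, _messages: &[Message]) -> bool {
///         false // never triggers automatic compaction
///     }
///
///     async fn compact_history(&self, messages: Vec<Message>) -> Result<CompactionResult> {
///         // Pass the history through unchanged.
///         let count = messages.len();
///         let tokens = self.estimate_tokens(&messages);
///         Ok(CompactionResult {
///             messages,
///             original_count: count,
///             new_count: count,
///             original_tokens: tokens,
///             new_tokens: tokens,
///         })
///     }
/// }
/// ```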
#[async_trait]
pub trait ContextCompactor: Send + Sync {
    /// Compact a list of messages into a summary.
    ///
    /// # Errors
    /// Returns an error if summarization fails.
    async fn compact(&self, messages: &[Message]) -> Result<String>;

    /// Estimate tokens for a message list.
    fn estimate_tokens(&self, messages: &[Message]) -> usize;

    /// Check if compaction is needed.
    fn needs_compaction(&self, messages: &[Message]) -> bool;

    /// Perform full compaction, returning new message history.
    ///
    /// # Errors
    /// Returns an error if compaction fails.
    async fn compact_history(&self, messages: Vec<Message>) -> Result<CompactionResult>;
}

/// Result of a compaction operation.
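///
/// # Example
///
/// A sketch of inspecting the savings (the `result` binding is assumed to
/// come from a [`ContextCompactor::compact_history`] call):
///
/// ```ignore
/// let saved = result.original_tokens.saturating_sub(result.new_tokens);
/// println!(
///     "compacted {} -> {} messages, ~{saved} tokens saved",
///     result.original_count, result.new_count
/// );
/// ```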
#[derive(Debug, Clone)]
pub struct CompactionResult {
    /// The new compacted message history.
    pub messages: Vec<Message>,
    /// Number of messages before compaction.
    pub original_count: usize,
    /// Number of messages after compaction.
    pub new_count: usize,
    /// Estimated tokens before compaction.
    pub original_tokens: usize,
    /// Estimated tokens after compaction.
    pub new_tokens: usize,
}

/// LLM-based context compactor.
///
/// Uses the LLM itself to summarize older messages into a compact form.
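///
/// # Example
///
/// A minimal sketch of the intended call flow (`provider` stands in for any
/// concrete [`LlmProvider`] implementation):
///
/// ```ignore
/// let compactor = LlmContextCompactor::with_defaults(Arc::new(provider));
///
/// if compactor.needs_compaction(&history) {
///     // Older messages are summarized; the most recent ones are kept verbatim.
///     let result = compactor.compact_history(history).await?;
///     history = result.messages;
/// }
/// ```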
pub struct LlmContextCompactor<P: LlmProvider> {
    provider: Arc<P>,
    config: CompactionConfig,
}

impl<P: LlmProvider> LlmContextCompactor<P> {
    /// Create a new LLM context compactor.
    #[must_use]
    pub const fn new(provider: Arc<P>, config: CompactionConfig) -> Self {
        Self { provider, config }
    }

    /// Create with default configuration.
    #[must_use]
    pub fn with_defaults(provider: Arc<P>) -> Self {
        Self::new(provider, CompactionConfig::default())
    }

    /// Get the configuration.
    #[must_use]
    pub const fn config(&self) -> &CompactionConfig {
        &self.config
    }

    /// Format messages for summarization.
    fn format_messages_for_summary(messages: &[Message]) -> String {
        let mut output = String::new();

        for message in messages {
            let role = match message.role {
                Role::User => "User",
                Role::Assistant => "Assistant",
            };

            let _ = write!(output, "{role}: ");

            match &message.content {
                Content::Text(text) => {
                    let _ = writeln!(output, "{text}");
                }
                Content::Blocks(blocks) => {
                    for block in blocks {
                        match block {
                            ContentBlock::Text { text } => {
                                let _ = writeln!(output, "{text}");
                            }
                            ContentBlock::Thinking { thinking } => {
                                // Include thinking in summaries for context
                                let _ = writeln!(output, "[Thinking: {thinking}]");
                            }
                            ContentBlock::ToolUse { name, input, .. } => {
                                let _ = writeln!(
                                    output,
                                    "[Called tool: {name} with input: {}]",
                                    serde_json::to_string(input).unwrap_or_default()
                                );
                            }
                            ContentBlock::ToolResult {
                                content, is_error, ..
                            } => {
                                let status = if is_error.unwrap_or(false) {
                                    "error"
                                } else {
                                    "success"
                                };
                                // Truncate long tool results (Unicode-safe; avoid slicing mid-codepoint)
                                let truncated = if content.chars().count() > 500 {
                                    let prefix: String = content.chars().take(500).collect();
                                    format!("{prefix}... (truncated)")
                                } else {
                                    content.clone()
                                };
                                let _ = writeln!(output, "[Tool result ({status}): {truncated}]");
                            }
                        }
                    }
                }
            }
            output.push('\n');
        }

        output
    }

    /// Build the summarization prompt.
    fn build_summary_prompt(messages_text: &str) -> String {
        format!(
            r"Summarize this conversation concisely, preserving:
- Key decisions and conclusions reached
- Important file paths, code changes, and technical details
- Current task context and what has been accomplished
- Any pending items, errors encountered, or next steps

Be specific about technical details (file names, function names, error messages) as these are critical for continuing the work.

Conversation:
{messages_text}

Provide a concise summary (aim for 500-1000 words):"
        )
    }
}
156
#[async_trait]
impl<P: LlmProvider> ContextCompactor for LlmContextCompactor<P> {
    async fn compact(&self, messages: &[Message]) -> Result<String> {
        let messages_text = Self::format_messages_for_summary(messages);
        let prompt = Self::build_summary_prompt(&messages_text);

        let request = ChatRequest {
            system: "You are a precise summarizer. Your task is to create concise but complete summaries of conversations, preserving all technical details that would be needed to continue the work.".to_string(),
            messages: vec![Message::user(prompt)],
            tools: None,
            max_tokens: 2000,
            thinking: None,
        };

        let outcome = self
            .provider
            .chat(request)
            .await
            .context("Failed to call LLM for summarization")?;

        match outcome {
            ChatOutcome::Success(response) => response
                .first_text()
                .map(String::from)
                .context("No text in summarization response"),
            ChatOutcome::RateLimited => {
                bail!("Rate limited during summarization")
            }
            ChatOutcome::InvalidRequest(msg) => {
                bail!("Invalid request during summarization: {msg}")
            }
            ChatOutcome::ServerError(msg) => {
                bail!("Server error during summarization: {msg}")
            }
        }
    }

    fn estimate_tokens(&self, messages: &[Message]) -> usize {
        TokenEstimator::estimate_history(messages)
    }

    fn needs_compaction(&self, messages: &[Message]) -> bool {
        if !self.config.auto_compact {
            return false;
        }

        if messages.len() < self.config.min_messages_for_compaction {
            return false;
        }

        let estimated_tokens = self.estimate_tokens(messages);
        estimated_tokens > self.config.threshold_tokens
    }

    async fn compact_history(&self, messages: Vec<Message>) -> Result<CompactionResult> {
        let original_count = messages.len();
        let original_tokens = self.estimate_tokens(&messages);

        // Ensure we have enough messages to compact
        if messages.len() <= self.config.retain_recent {
            return Ok(CompactionResult {
                messages,
                original_count,
                new_count: original_count,
                original_tokens,
                new_tokens: original_tokens,
            });
        }

        // Split messages: old messages to summarize, recent messages to keep
        let split_point = messages.len().saturating_sub(self.config.retain_recent);
        let (to_summarize, to_keep) = messages.split_at(split_point);

        // Summarize old messages
        let summary = self.compact(to_summarize).await?;

        // Build new message history
        let mut new_messages = Vec::with_capacity(2 + to_keep.len());

        // Add summary as a user message
        new_messages.push(Message::user(format!(
            "[Previous conversation summary]\n\n{summary}"
        )));

        // Add acknowledgment from assistant
        new_messages.push(Message::assistant(
            "I understand the context from the summary. Let me continue from where we left off.",
        ));

        // Add recent messages
        new_messages.extend(to_keep.iter().cloned());

        let new_count = new_messages.len();
        let new_tokens = self.estimate_tokens(&new_messages);

        Ok(CompactionResult {
            messages: new_messages,
            original_count,
            new_count,
            original_tokens,
            new_tokens,
        })
    }
}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{ChatResponse, StopReason, Usage};

    struct MockProvider {
        summary_response: String,
    }

    impl MockProvider {
        fn new(summary: &str) -> Self {
            Self {
                summary_response: summary.to_string(),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            Ok(ChatOutcome::Success(ChatResponse {
                id: "test".to_string(),
                content: vec![ContentBlock::Text {
                    text: self.summary_response.clone(),
                }],
                model: "mock".to_string(),
                stop_reason: Some(StopReason::EndTurn),
                usage: Usage {
                    input_tokens: 100,
                    output_tokens: 50,
                },
            }))
        }

        fn model(&self) -> &'static str {
            "mock-model"
        }

        fn provider(&self) -> &'static str {
            "mock"
        }
    }

    #[test]
    fn test_needs_compaction_below_threshold() {
        let provider = Arc::new(MockProvider::new("summary"));
        let config = CompactionConfig::default()
            .with_threshold_tokens(10_000)
            .with_min_messages(5);
        let compactor = LlmContextCompactor::new(provider, config);

        // Only 3 messages, below min_messages
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("Hi"),
            Message::user("How are you?"),
        ];

        assert!(!compactor.needs_compaction(&messages));
    }

    #[test]
    fn test_needs_compaction_above_threshold() {
        let provider = Arc::new(MockProvider::new("summary"));
        let config = CompactionConfig::default()
            .with_threshold_tokens(50) // Very low threshold
            .with_min_messages(3);
        let compactor = LlmContextCompactor::new(provider, config);

        // Messages that exceed threshold
        let messages = vec![
            Message::user("Hello, this is a longer message to test compaction"),
            Message::assistant(
                "Hi there! This is also a longer response to help trigger compaction",
            ),
            Message::user("Great, let's continue with even more text here"),
            Message::assistant("Absolutely, adding more content to ensure we exceed the threshold"),
        ];

        assert!(compactor.needs_compaction(&messages));
    }

    #[test]
    fn test_needs_compaction_auto_disabled() {
        let provider = Arc::new(MockProvider::new("summary"));
        let config = CompactionConfig::default()
            .with_threshold_tokens(10) // Very low
            .with_min_messages(1)
            .with_auto_compact(false);
        let compactor = LlmContextCompactor::new(provider, config);

        let messages = vec![
            Message::user("Hello, this is a longer message"),
            Message::assistant("Response here"),
        ];

        assert!(!compactor.needs_compaction(&messages));
    }

    #[tokio::test]
    async fn test_compact_history() -> Result<()> {
        let provider = Arc::new(MockProvider::new(
            "User asked about Rust programming. Assistant explained ownership, borrowing, and lifetimes.",
        ));
        let config = CompactionConfig::default()
            .with_retain_recent(2)
            .with_min_messages(3);
        let compactor = LlmContextCompactor::new(provider, config);

        // Use longer messages to ensure compaction actually reduces tokens
        let messages = vec![
            Message::user(
                "What is Rust? I've heard it's a systems programming language but I don't know much about it. Can you explain the key features and why people are excited about it?",
            ),
            Message::assistant(
                "Rust is a systems programming language focused on safety, speed, and concurrency. It achieves memory safety without garbage collection through its ownership system. The key features include zero-cost abstractions, guaranteed memory safety, threads without data races, and minimal runtime.",
            ),
            Message::user(
                "Tell me about ownership in detail. How does it work and what are the rules? I want to understand this core concept thoroughly.",
            ),
            Message::assistant(
                "Ownership is Rust's central feature with three rules: each value has one owner, only one owner at a time, and the value is dropped when owner goes out of scope. This system prevents memory leaks, double frees, and dangling pointers at compile time.",
            ),
            Message::user("What about borrowing?"), // Keep
            Message::assistant("Borrowing allows references to data without taking ownership."), // Keep
        ];

        let result = compactor.compact_history(messages).await?;

        // Should have: summary message + ack + 2 recent messages = 4
        assert_eq!(result.new_count, 4);
        assert_eq!(result.original_count, 6);

        // With longer original messages, compaction should reduce tokens
        assert!(
            result.new_tokens < result.original_tokens,
            "Expected fewer tokens after compaction: new={} < original={}",
            result.new_tokens,
            result.original_tokens
        );

        // First message should be the summary
        if let Content::Text(text) = &result.messages[0].content {
            assert!(text.contains("Previous conversation summary"));
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_compact_history_too_few_messages() -> Result<()> {
        let provider = Arc::new(MockProvider::new("summary"));
        let config = CompactionConfig::default().with_retain_recent(5);
        let compactor = LlmContextCompactor::new(provider, config);

        // Only 3 messages, less than retain_recent
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("Hi"),
            Message::user("Bye"),
        ];

        let result = compactor.compact_history(messages.clone()).await?;

        // Should return original messages unchanged
        assert_eq!(result.new_count, 3);
        assert_eq!(result.messages.len(), 3);

        Ok(())
    }

    #[test]
    fn test_format_messages_for_summary() {
        let messages = vec![Message::user("Hello"), Message::assistant("Hi there!")];

        let formatted = LlmContextCompactor::<MockProvider>::format_messages_for_summary(&messages);

        assert!(formatted.contains("User: Hello"));
        assert!(formatted.contains("Assistant: Hi there!"));
    }

    #[test]
    fn test_format_messages_for_summary_truncates_tool_results_unicode_safely() {
        let long_unicode = "é".repeat(600);

        let messages = vec![Message {
            role: Role::Assistant,
            content: Content::Blocks(vec![ContentBlock::ToolResult {
                tool_use_id: "tool-1".to_string(),
                content: long_unicode,
                is_error: Some(false),
            }]),
        }];

        let formatted = LlmContextCompactor::<MockProvider>::format_messages_for_summary(&messages);

        assert!(formatted.contains("... (truncated)"));
    }
}