agent_sdk/context/compactor.rs

//! Context compaction implementation.

use crate::llm::{ChatOutcome, ChatRequest, Content, ContentBlock, LlmProvider, Message, Role};
use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use std::fmt::Write;
use std::sync::Arc;

use super::config::CompactionConfig;
use super::estimator::TokenEstimator;

/// Trait for context compaction strategies.
///
/// Implement this trait to provide custom compaction logic.
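///
/// # Example
///
/// A minimal sketch of the intended call flow, assuming a `compactor`
/// implementing this trait and a mutable `history: Vec<Message>` already in
/// scope (both hypothetical, so the doctest is not compiled):
///
/// ```ignore
/// if compactor.needs_compaction(&history) {
///     // Replace the history with the compacted version.
///     history = compactor.compact_history(history).await?.messages;
/// }
/// ```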
#[async_trait]
pub trait ContextCompactor: Send + Sync {
    /// Compact a list of messages into a summary.
    ///
    /// # Errors
    /// Returns an error if summarization fails.
    async fn compact(&self, messages: &[Message]) -> Result<String>;

    /// Estimate tokens for a message list.
    fn estimate_tokens(&self, messages: &[Message]) -> usize;

    /// Check if compaction is needed.
    fn needs_compaction(&self, messages: &[Message]) -> bool;

    /// Perform full compaction, returning new message history.
    ///
    /// # Errors
    /// Returns an error if compaction fails.
    async fn compact_history(&self, messages: Vec<Message>) -> Result<CompactionResult>;
}

/// Result of a compaction operation.
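///
/// # Example
///
/// A sketch of inspecting the savings, assuming a `result` returned by
/// `compact_history` (not compiled as a doctest):
///
/// ```ignore
/// let tokens_saved = result.original_tokens.saturating_sub(result.new_tokens);
/// println!(
///     "compacted {} -> {} messages, saved ~{tokens_saved} tokens",
///     result.original_count, result.new_count
/// );
/// ```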
#[derive(Debug, Clone)]
pub struct CompactionResult {
    /// The new compacted message history.
    pub messages: Vec<Message>,
    /// Number of messages before compaction.
    pub original_count: usize,
    /// Number of messages after compaction.
    pub new_count: usize,
    /// Estimated tokens before compaction.
    pub original_tokens: usize,
    /// Estimated tokens after compaction.
    pub new_tokens: usize,
}

/// LLM-based context compactor.
///
/// Uses the LLM itself to summarize older messages into a compact form.
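///
/// # Example
///
/// A construction sketch; `MyProvider` stands in for any concrete
/// `LlmProvider` implementation and is hypothetical, so the doctest is not
/// compiled:
///
/// ```ignore
/// let provider = Arc::new(MyProvider::new());
/// let compactor = LlmContextCompactor::new(
///     provider,
///     CompactionConfig::default()
///         .with_threshold_tokens(100_000)
///         .with_retain_recent(10),
/// );
/// ```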
pub struct LlmContextCompactor<P: LlmProvider> {
    provider: Arc<P>,
    config: CompactionConfig,
}

impl<P: LlmProvider> LlmContextCompactor<P> {
    /// Create a new LLM context compactor.
    #[must_use]
    pub const fn new(provider: Arc<P>, config: CompactionConfig) -> Self {
        Self { provider, config }
    }

    /// Create with default configuration.
    #[must_use]
    pub fn with_defaults(provider: Arc<P>) -> Self {
        Self::new(provider, CompactionConfig::default())
    }

    /// Get the configuration.
    #[must_use]
    pub const fn config(&self) -> &CompactionConfig {
        &self.config
    }

    /// Format messages for summarization.
    fn format_messages_for_summary(messages: &[Message]) -> String {
        let mut output = String::new();

        for message in messages {
            let role = match message.role {
                Role::User => "User",
                Role::Assistant => "Assistant",
            };

            let _ = write!(output, "{role}: ");

            match &message.content {
                Content::Text(text) => {
                    let _ = writeln!(output, "{text}");
                }
                Content::Blocks(blocks) => {
                    for block in blocks {
                        match block {
                            ContentBlock::Text { text } => {
                                let _ = writeln!(output, "{text}");
                            }
                            ContentBlock::ToolUse { name, input, .. } => {
                                let _ = writeln!(
                                    output,
                                    "[Called tool: {name} with input: {}]",
                                    serde_json::to_string(input).unwrap_or_default()
                                );
                            }
                            ContentBlock::ToolResult {
                                content, is_error, ..
                            } => {
                                let status = if is_error.unwrap_or(false) {
                                    "error"
                                } else {
                                    "success"
                                };
                                // Truncate long tool results (Unicode-safe; avoid slicing mid-codepoint)
                                let truncated = if content.chars().count() > 500 {
                                    let prefix: String = content.chars().take(500).collect();
                                    format!("{prefix}... (truncated)")
                                } else {
                                    content.clone()
                                };
                                let _ = writeln!(output, "[Tool result ({status}): {truncated}]");
                            }
                        }
                    }
                }
            }
            output.push('\n');
        }

        output
    }

    /// Build the summarization prompt.
    fn build_summary_prompt(messages_text: &str) -> String {
        format!(
            r"Summarize this conversation concisely, preserving:
- Key decisions and conclusions reached
- Important file paths, code changes, and technical details
- Current task context and what has been accomplished
- Any pending items, errors encountered, or next steps

Be specific about technical details (file names, function names, error messages) as these are critical for continuing the work.

Conversation:
{messages_text}

Provide a concise summary (aim for 500-1000 words):"
        )
    }
}

#[async_trait]
impl<P: LlmProvider> ContextCompactor for LlmContextCompactor<P> {
    async fn compact(&self, messages: &[Message]) -> Result<String> {
        let messages_text = Self::format_messages_for_summary(messages);
        let prompt = Self::build_summary_prompt(&messages_text);

        let request = ChatRequest {
            system: "You are a precise summarizer. Your task is to create concise but complete summaries of conversations, preserving all technical details that would be needed to continue the work.".to_string(),
            messages: vec![Message::user(prompt)],
            tools: None,
            max_tokens: 2000,
        };

        let outcome = self
            .provider
            .chat(request)
            .await
            .context("Failed to call LLM for summarization")?;

        match outcome {
            ChatOutcome::Success(response) => response
                .first_text()
                .map(String::from)
                .context("No text in summarization response"),
            ChatOutcome::RateLimited => {
                bail!("Rate limited during summarization")
            }
            ChatOutcome::InvalidRequest(msg) => {
                bail!("Invalid request during summarization: {msg}")
            }
            ChatOutcome::ServerError(msg) => {
                bail!("Server error during summarization: {msg}")
            }
        }
    }

    fn estimate_tokens(&self, messages: &[Message]) -> usize {
        TokenEstimator::estimate_history(messages)
    }

    fn needs_compaction(&self, messages: &[Message]) -> bool {
        if !self.config.auto_compact {
            return false;
        }

        if messages.len() < self.config.min_messages_for_compaction {
            return false;
        }

        let estimated_tokens = self.estimate_tokens(messages);
        estimated_tokens > self.config.threshold_tokens
    }

    async fn compact_history(&self, messages: Vec<Message>) -> Result<CompactionResult> {
        let original_count = messages.len();
        let original_tokens = self.estimate_tokens(&messages);

        // Ensure we have enough messages to compact
        if messages.len() <= self.config.retain_recent {
            return Ok(CompactionResult {
                messages,
                original_count,
                new_count: original_count,
                original_tokens,
                new_tokens: original_tokens,
            });
        }

        // Split messages: old messages to summarize, recent messages to keep
        let split_point = messages.len().saturating_sub(self.config.retain_recent);
        let (to_summarize, to_keep) = messages.split_at(split_point);

        // Summarize old messages
        let summary = self.compact(to_summarize).await?;

        // Build new message history
        let mut new_messages = Vec::with_capacity(2 + to_keep.len());

        // Add summary as a user message
        new_messages.push(Message::user(format!(
            "[Previous conversation summary]\n\n{summary}"
        )));

        // Add acknowledgment from assistant
        new_messages.push(Message::assistant(
            "I understand the context from the summary. Let me continue from where we left off.",
        ));

        // Add recent messages
        new_messages.extend(to_keep.iter().cloned());

        let new_count = new_messages.len();
        let new_tokens = self.estimate_tokens(&new_messages);

        Ok(CompactionResult {
            messages: new_messages,
            original_count,
            new_count,
            original_tokens,
            new_tokens,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{ChatResponse, StopReason, Usage};

    struct MockProvider {
        summary_response: String,
    }

    impl MockProvider {
        fn new(summary: &str) -> Self {
            Self {
                summary_response: summary.to_string(),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            Ok(ChatOutcome::Success(ChatResponse {
                id: "test".to_string(),
                content: vec![ContentBlock::Text {
                    text: self.summary_response.clone(),
                }],
                model: "mock".to_string(),
                stop_reason: Some(StopReason::EndTurn),
                usage: Usage {
                    input_tokens: 100,
                    output_tokens: 50,
                },
            }))
        }

        fn model(&self) -> &'static str {
            "mock-model"
        }

        fn provider(&self) -> &'static str {
            "mock"
        }
    }

    #[test]
    fn test_needs_compaction_below_threshold() {
        let provider = Arc::new(MockProvider::new("summary"));
        let config = CompactionConfig::default()
            .with_threshold_tokens(10_000)
            .with_min_messages(5);
        let compactor = LlmContextCompactor::new(provider, config);

        // Only 3 messages, below min_messages
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("Hi"),
            Message::user("How are you?"),
        ];

        assert!(!compactor.needs_compaction(&messages));
    }

    #[test]
    fn test_needs_compaction_above_threshold() {
        let provider = Arc::new(MockProvider::new("summary"));
        let config = CompactionConfig::default()
            .with_threshold_tokens(50) // Very low threshold
            .with_min_messages(3);
        let compactor = LlmContextCompactor::new(provider, config);

        // Messages that exceed threshold
        let messages = vec![
            Message::user("Hello, this is a longer message to test compaction"),
            Message::assistant(
                "Hi there! This is also a longer response to help trigger compaction",
            ),
            Message::user("Great, let's continue with even more text here"),
            Message::assistant("Absolutely, adding more content to ensure we exceed the threshold"),
        ];

        assert!(compactor.needs_compaction(&messages));
    }

    #[test]
    fn test_needs_compaction_auto_disabled() {
        let provider = Arc::new(MockProvider::new("summary"));
        let config = CompactionConfig::default()
            .with_threshold_tokens(10) // Very low
            .with_min_messages(1)
            .with_auto_compact(false);
        let compactor = LlmContextCompactor::new(provider, config);

        let messages = vec![
            Message::user("Hello, this is a longer message"),
            Message::assistant("Response here"),
        ];

        assert!(!compactor.needs_compaction(&messages));
    }

    #[tokio::test]
    async fn test_compact_history() -> Result<()> {
        let provider = Arc::new(MockProvider::new(
            "User asked about Rust programming. Assistant explained ownership, borrowing, and lifetimes.",
        ));
        let config = CompactionConfig::default()
            .with_retain_recent(2)
            .with_min_messages(3);
        let compactor = LlmContextCompactor::new(provider, config);

        // Use longer messages to ensure compaction actually reduces tokens
        let messages = vec![
            Message::user(
                "What is Rust? I've heard it's a systems programming language but I don't know much about it. Can you explain the key features and why people are excited about it?",
            ),
            Message::assistant(
                "Rust is a systems programming language focused on safety, speed, and concurrency. It achieves memory safety without garbage collection through its ownership system. The key features include zero-cost abstractions, guaranteed memory safety, threads without data races, and minimal runtime.",
            ),
            Message::user(
                "Tell me about ownership in detail. How does it work and what are the rules? I want to understand this core concept thoroughly.",
            ),
            Message::assistant(
                "Ownership is Rust's central feature with three rules: each value has one owner, only one owner at a time, and the value is dropped when owner goes out of scope. This system prevents memory leaks, double frees, and dangling pointers at compile time.",
            ),
            Message::user("What about borrowing?"), // Keep
            Message::assistant("Borrowing allows references to data without taking ownership."), // Keep
        ];

        let result = compactor.compact_history(messages).await?;

        // Should have: summary message + ack + 2 recent messages = 4
        assert_eq!(result.new_count, 4);
        assert_eq!(result.original_count, 6);

        // With longer original messages, compaction should reduce tokens
        assert!(
            result.new_tokens < result.original_tokens,
            "Expected fewer tokens after compaction: new={} < original={}",
            result.new_tokens,
            result.original_tokens
        );

        // First message should be the summary
        if let Content::Text(text) = &result.messages[0].content {
            assert!(text.contains("Previous conversation summary"));
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_compact_history_too_few_messages() -> Result<()> {
        let provider = Arc::new(MockProvider::new("summary"));
        let config = CompactionConfig::default().with_retain_recent(5);
        let compactor = LlmContextCompactor::new(provider, config);
        // Only 3 messages, fewer than retain_recent
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("Hi"),
            Message::user("Bye"),
        ];

        let result = compactor.compact_history(messages.clone()).await?;

        // Should return original messages unchanged
        assert_eq!(result.new_count, 3);
        assert_eq!(result.messages.len(), 3);

        Ok(())
    }

    #[test]
    fn test_format_messages_for_summary() {
        let messages = vec![Message::user("Hello"), Message::assistant("Hi there!")];

        let formatted = LlmContextCompactor::<MockProvider>::format_messages_for_summary(&messages);

        assert!(formatted.contains("User: Hello"));
        assert!(formatted.contains("Assistant: Hi there!"));
    }

    #[test]
    fn test_format_messages_for_summary_truncates_tool_results_unicode_safely() {
        let long_unicode = "é".repeat(600);

        let messages = vec![Message {
            role: Role::Assistant,
            content: Content::Blocks(vec![ContentBlock::ToolResult {
                tool_use_id: "tool-1".to_string(),
                content: long_unicode,
                is_error: Some(false),
            }]),
        }];

        let formatted = LlmContextCompactor::<MockProvider>::format_messages_for_summary(&messages);

        assert!(formatted.contains("... (truncated)"));
    }
}