use crate::llm::{ChatOutcome, ChatRequest, Content, ContentBlock, LlmProvider, Message, Role};
use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use std::fmt::Write;
use std::sync::Arc;

use super::config::CompactionConfig;
use super::estimator::TokenEstimator;

#[async_trait]
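/// Strategy for keeping a conversation inside a model's context window by
/// summarizing older messages into a compact recap.
///
/// A sketch of how a caller might drive the trait, assuming it owns the
/// history as a `Vec<Message>`:
///
/// ```ignore
/// if compactor.needs_compaction(&history) {
///     let result = compactor.compact_history(history).await?;
///     history = result.messages;
/// }
/// ```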
pub trait ContextCompactor: Send + Sync {
    /// Summarizes the given messages into a single text summary.
    async fn compact(&self, messages: &[Message]) -> Result<String>;

    /// Estimates the total token count of the given messages.
    fn estimate_tokens(&self, messages: &[Message]) -> usize;

    /// Returns `true` if the history has grown past the configured thresholds.
    fn needs_compaction(&self, messages: &[Message]) -> bool;

    /// Replaces older messages with a summary, keeping the most recent ones.
    async fn compact_history(&self, messages: Vec<Message>) -> Result<CompactionResult>;
}

/// Outcome of a [`ContextCompactor::compact_history`] call.
#[derive(Debug, Clone)]
pub struct CompactionResult {
    /// The compacted message history.
    pub messages: Vec<Message>,
    /// Number of messages before compaction.
    pub original_count: usize,
    /// Number of messages after compaction.
    pub new_count: usize,
    /// Estimated token count before compaction.
    pub original_tokens: usize,
    /// Estimated token count after compaction.
    pub new_tokens: usize,
}
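
/// LLM-backed [`ContextCompactor`] that asks the wrapped provider to
/// summarize older history.
///
/// A minimal construction sketch; `MyProvider` is a placeholder for any
/// [`LlmProvider`] implementation:
///
/// ```ignore
/// let compactor = LlmContextCompactor::with_defaults(Arc::new(MyProvider::new()));
/// let result = compactor.compact_history(history).await?;
/// ```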
pub struct LlmContextCompactor<P: LlmProvider> {
    provider: Arc<P>,
    config: CompactionConfig,
}

impl<P: LlmProvider> LlmContextCompactor<P> {
    /// Creates a compactor from a provider and an explicit configuration.
    #[must_use]
    pub const fn new(provider: Arc<P>, config: CompactionConfig) -> Self {
        Self { provider, config }
    }

    /// Creates a compactor with the default [`CompactionConfig`].
    #[must_use]
    pub fn with_defaults(provider: Arc<P>) -> Self {
        Self::new(provider, CompactionConfig::default())
    }

    /// Returns the active configuration.
    #[must_use]
    pub const fn config(&self) -> &CompactionConfig {
        &self.config
    }
    /// Renders messages as plain text for the summarization prompt.
    fn format_messages_for_summary(messages: &[Message]) -> String {
        let mut output = String::new();

        for message in messages {
            let role = match message.role {
                Role::User => "User",
                Role::Assistant => "Assistant",
            };

            let _ = write!(output, "{role}: ");

            match &message.content {
                Content::Text(text) => {
                    let _ = writeln!(output, "{text}");
                }
                Content::Blocks(blocks) => {
                    for block in blocks {
                        match block {
                            ContentBlock::Text { text } => {
                                let _ = writeln!(output, "{text}");
                            }
                            ContentBlock::Thinking { thinking } => {
                                let _ = writeln!(output, "[Thinking: {thinking}]");
                            }
                            ContentBlock::ToolUse { name, input, .. } => {
                                let _ = writeln!(
                                    output,
                                    "[Called tool: {name} with input: {}]",
                                    serde_json::to_string(input).unwrap_or_default()
                                );
                            }
                            ContentBlock::ToolResult {
                                content, is_error, ..
                            } => {
                                let status = if is_error.unwrap_or(false) {
                                    "error"
                                } else {
                                    "success"
                                };
                                // Truncate by characters, not bytes, so multi-byte
                                // UTF-8 content cannot be split mid-code-point.
                                let truncated = if content.chars().count() > 500 {
                                    let prefix: String = content.chars().take(500).collect();
                                    format!("{prefix}... (truncated)")
                                } else {
                                    content.clone()
                                };
                                let _ = writeln!(output, "[Tool result ({status}): {truncated}]");
                            }
                        }
                    }
                }
            }
            output.push('\n');
        }

        output
    }

    /// Builds the summarization prompt around the rendered conversation.
    fn build_summary_prompt(messages_text: &str) -> String {
        format!(
            r"Summarize this conversation concisely, preserving:
- Key decisions and conclusions reached
- Important file paths, code changes, and technical details
- Current task context and what has been accomplished
- Any pending items, errors encountered, or next steps

Be specific about technical details (file names, function names, error messages) as these are critical for continuing the work.

Conversation:
{messages_text}

Provide a concise summary (aim for 500-1000 words):"
        )
    }
}

#[async_trait]
impl<P: LlmProvider> ContextCompactor for LlmContextCompactor<P> {
    async fn compact(&self, messages: &[Message]) -> Result<String> {
        let messages_text = Self::format_messages_for_summary(messages);
        let prompt = Self::build_summary_prompt(&messages_text);

        let request = ChatRequest {
            system: "You are a precise summarizer. Your task is to create concise but complete summaries of conversations, preserving all technical details that would be needed to continue the work.".to_string(),
            messages: vec![Message::user(prompt)],
            tools: None,
            max_tokens: 2000,
            thinking: None,
        };

        let outcome = self
            .provider
            .chat(request)
            .await
            .context("Failed to call LLM for summarization")?;

        match outcome {
            ChatOutcome::Success(response) => response
                .first_text()
                .map(String::from)
                .context("No text in summarization response"),
            ChatOutcome::RateLimited => {
                bail!("Rate limited during summarization")
            }
            ChatOutcome::InvalidRequest(msg) => {
                bail!("Invalid request during summarization: {msg}")
            }
            ChatOutcome::ServerError(msg) => {
                bail!("Server error during summarization: {msg}")
            }
        }
    }

    fn estimate_tokens(&self, messages: &[Message]) -> usize {
        TokenEstimator::estimate_history(messages)
    }

    fn needs_compaction(&self, messages: &[Message]) -> bool {
        if !self.config.auto_compact {
            return false;
        }

        if messages.len() < self.config.min_messages_for_compaction {
            return false;
        }

        let estimated_tokens = self.estimate_tokens(messages);
        estimated_tokens > self.config.threshold_tokens
    }

    async fn compact_history(&self, messages: Vec<Message>) -> Result<CompactionResult> {
        let original_count = messages.len();
        let original_tokens = self.estimate_tokens(&messages);

        // Nothing to do: the whole history already fits in the retained window.
        if messages.len() <= self.config.retain_recent {
            return Ok(CompactionResult {
                messages,
                original_count,
                new_count: original_count,
                original_tokens,
                new_tokens: original_tokens,
            });
        }

        // Everything before the split point is summarized; the rest is kept verbatim.
        let split_point = messages.len().saturating_sub(self.config.retain_recent);
        let (to_summarize, to_keep) = messages.split_at(split_point);

        let summary = self.compact(to_summarize).await?;

        // Rebuild the history: summary, acknowledgment, then the retained messages.
        let mut new_messages = Vec::with_capacity(2 + to_keep.len());

        new_messages.push(Message::user(format!(
            "[Previous conversation summary]\n\n{summary}"
        )));

        new_messages.push(Message::assistant(
            "I understand the context from the summary. Let me continue from where we left off.",
        ));

        new_messages.extend(to_keep.iter().cloned());

        let new_count = new_messages.len();
        let new_tokens = self.estimate_tokens(&new_messages);

        Ok(CompactionResult {
            messages: new_messages,
            original_count,
            new_count,
            original_tokens,
            new_tokens,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{ChatResponse, StopReason, Usage};

    struct MockProvider {
        summary_response: String,
    }

    impl MockProvider {
        fn new(summary: &str) -> Self {
            Self {
                summary_response: summary.to_string(),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            Ok(ChatOutcome::Success(ChatResponse {
                id: "test".to_string(),
                content: vec![ContentBlock::Text {
                    text: self.summary_response.clone(),
                }],
                model: "mock".to_string(),
                stop_reason: Some(StopReason::EndTurn),
                usage: Usage {
                    input_tokens: 100,
                    output_tokens: 50,
                },
            }))
        }

        fn model(&self) -> &'static str {
            "mock-model"
        }

        fn provider(&self) -> &'static str {
            "mock"
        }
    }
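
    // `compact` should hand back the provider's summary text verbatim.
    #[tokio::test]
    async fn test_compact_returns_provider_summary() -> Result<()> {
        let provider = Arc::new(MockProvider::new("the summary"));
        let compactor = LlmContextCompactor::with_defaults(provider);

        let messages = vec![Message::user("Hello"), Message::assistant("Hi")];
        let summary = compactor.compact(&messages).await?;

        assert_eq!(summary, "the summary");
        Ok(())
    }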

    #[test]
    fn test_needs_compaction_below_threshold() {
        let provider = Arc::new(MockProvider::new("summary"));
        let config = CompactionConfig::default()
            .with_threshold_tokens(10_000)
            .with_min_messages(5);
        let compactor = LlmContextCompactor::new(provider, config);

        let messages = vec![
            Message::user("Hello"),
            Message::assistant("Hi"),
            Message::user("How are you?"),
        ];

        assert!(!compactor.needs_compaction(&messages));
    }

    #[test]
    fn test_needs_compaction_above_threshold() {
        let provider = Arc::new(MockProvider::new("summary"));
        let config = CompactionConfig::default()
            .with_threshold_tokens(50)
            .with_min_messages(3);
        let compactor = LlmContextCompactor::new(provider, config);

        let messages = vec![
            Message::user("Hello, this is a longer message to test compaction"),
            Message::assistant(
                "Hi there! This is also a longer response to help trigger compaction",
            ),
            Message::user("Great, let's continue with even more text here"),
            Message::assistant("Absolutely, adding more content to ensure we exceed the threshold"),
        ];

        assert!(compactor.needs_compaction(&messages));
    }

    #[test]
    fn test_needs_compaction_auto_disabled() {
        let provider = Arc::new(MockProvider::new("summary"));
        let config = CompactionConfig::default()
            .with_threshold_tokens(10)
            .with_min_messages(1)
            .with_auto_compact(false);
        let compactor = LlmContextCompactor::new(provider, config);

        let messages = vec![
            Message::user("Hello, this is a longer message"),
            Message::assistant("Response here"),
        ];

        assert!(!compactor.needs_compaction(&messages));
    }

    #[tokio::test]
    async fn test_compact_history() -> Result<()> {
        let provider = Arc::new(MockProvider::new(
            "User asked about Rust programming. Assistant explained ownership, borrowing, and lifetimes.",
        ));
        let config = CompactionConfig::default()
            .with_retain_recent(2)
            .with_min_messages(3);
        let compactor = LlmContextCompactor::new(provider, config);

        let messages = vec![
            Message::user(
                "What is Rust? I've heard it's a systems programming language but I don't know much about it. Can you explain the key features and why people are excited about it?",
            ),
            Message::assistant(
                "Rust is a systems programming language focused on safety, speed, and concurrency. It achieves memory safety without garbage collection through its ownership system. The key features include zero-cost abstractions, guaranteed memory safety, threads without data races, and minimal runtime.",
            ),
            Message::user(
                "Tell me about ownership in detail. How does it work and what are the rules? I want to understand this core concept thoroughly.",
            ),
            Message::assistant(
                "Ownership is Rust's central feature with three rules: each value has one owner, only one owner at a time, and the value is dropped when owner goes out of scope. This system prevents memory leaks, double frees, and dangling pointers at compile time.",
            ),
            Message::user("What about borrowing?"),
            Message::assistant("Borrowing allows references to data without taking ownership."),
        ];

        let result = compactor.compact_history(messages).await?;

        // Summary + acknowledgment + the 2 retained messages.
        assert_eq!(result.new_count, 4);
        assert_eq!(result.original_count, 6);

        assert!(
            result.new_tokens < result.original_tokens,
            "Expected fewer tokens after compaction: new={} < original={}",
            result.new_tokens,
            result.original_tokens
        );

        if let Content::Text(text) = &result.messages[0].content {
            assert!(text.contains("Previous conversation summary"));
        } else {
            panic!("expected the first message to be a text summary");
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_compact_history_too_few_messages() -> Result<()> {
        let provider = Arc::new(MockProvider::new("summary"));
        let config = CompactionConfig::default().with_retain_recent(5);
        let compactor = LlmContextCompactor::new(provider, config);

        let messages = vec![
            Message::user("Hello"),
            Message::assistant("Hi"),
            Message::user("Bye"),
        ];

        let result = compactor.compact_history(messages).await?;

        // Unchanged: the history fits within the retained window.
        assert_eq!(result.new_count, 3);
        assert_eq!(result.messages.len(), 3);

        Ok(())
    }

    #[test]
    fn test_format_messages_for_summary() {
        let messages = vec![Message::user("Hello"), Message::assistant("Hi there!")];

        let formatted =
            LlmContextCompactor::<MockProvider>::format_messages_for_summary(&messages);

        assert!(formatted.contains("User: Hello"));
        assert!(formatted.contains("Assistant: Hi there!"));
    }

    #[test]
    fn test_format_messages_for_summary_truncates_tool_results_unicode_safely() {
        // 600 two-byte characters: byte-indexed truncation at 500 could split
        // a code point, while char-based truncation must not.
        let long_unicode = "é".repeat(600);

        let messages = vec![Message {
            role: Role::Assistant,
            content: Content::Blocks(vec![ContentBlock::ToolResult {
                tool_use_id: "tool-1".to_string(),
                content: long_unicode,
                is_error: Some(false),
            }]),
        }];

        let formatted =
            LlmContextCompactor::<MockProvider>::format_messages_for_summary(&messages);

        assert!(formatted.contains("... (truncated)"));
    }
}
460}