1use crate::conversation::message::{ActionRequiredData, MessageMetadata};
2use crate::conversation::message::{Message, MessageContent};
3use crate::conversation::{merge_consecutive_messages, Conversation};
4use crate::prompt_template::render_global_file;
5use crate::providers::base::{Provider, ProviderUsage};
6use crate::providers::errors::ProviderError;
7use crate::{config::Config, token_counter::create_token_counter};
8use anyhow::Result;
9use rmcp::model::Role;
10use serde::Serialize;
11use tracing::{debug, info};
12
/// Default fraction of the model's context window (0.0–1.0) at which
/// automatic compaction triggers; can be overridden per-call or via the
/// `ASTER_AUTO_COMPACT_THRESHOLD` config parameter.
pub const DEFAULT_COMPACTION_THRESHOLD: f64 = 0.8;

/// Continuation instruction injected after the summary when auto-compaction
/// ran and the conversation's latest message was a plain-text user turn.
const CONVERSATION_CONTINUATION_TEXT: &str =
    "The previous message contains a summary that was prepared because a context limit was reached.
Do not mention that you read a summary or that conversation summarization occurred.
Just continue the conversation naturally based on the summarized context";

/// Continuation instruction injected after the summary when auto-compaction
/// ran mid-task (the latest message was not a plain-text user turn, e.g.
/// during a tool-call loop).
const TOOL_LOOP_CONTINUATION_TEXT: &str =
    "The previous message contains a summary that was prepared because a context limit was reached.
Do not mention that you read a summary or that conversation summarization occurred.
Continue calling tools as necessary to complete the task.";

/// Continuation instruction injected after the summary when the user
/// explicitly requested compaction.
const MANUAL_COMPACT_CONTINUATION_TEXT: &str =
    "The previous message contains a summary that was prepared at the user's request.
Do not mention that you read a summary or that conversation summarization occurred.
Just continue the conversation naturally based on the summarized context";
29
/// Template context for rendering the `summarize_oneshot.md` prompt file.
#[derive(Serialize)]
struct SummarizeContext {
    /// The formatted conversation transcript (one `[role]: …` entry per
    /// message, newline-joined).
    messages: String,
}
34
/// Compacts `conversation` by inserting an LLM-generated summary while
/// keeping the original messages around as agent-invisible history.
///
/// The resulting conversation contains, in order:
/// 1. every original message, marked agent-invisible (one special case is
///    marked fully invisible — see the loop below),
/// 2. the summary message (agent-only),
/// 3. a continuation instruction (agent-only) whose wording depends on how
///    compaction was triggered,
/// 4. for automatic compaction only: a fresh copy of the text of the last
///    agent-visible, text-only user message, so the pending request survives.
///
/// # Arguments
/// * `provider` - provider used to generate the summary
/// * `conversation` - the conversation to compact
/// * `manual_compact` - true when the user explicitly asked for compaction;
///   skips user-message preservation and changes the continuation text
///
/// # Errors
/// Propagates failures from the summarization call (`do_compact`).
pub async fn compact_messages(
    provider: &dyn Provider,
    conversation: &Conversation,
    manual_compact: bool,
) -> Result<(Conversation, ProviderUsage)> {
    info!("Performing message compaction");

    let messages = conversation.messages();

    // True iff the message has at least one text part and no tool
    // request/response content.
    let has_text_only = |msg: &Message| {
        let has_text = msg
            .content
            .iter()
            .any(|c| matches!(c, MessageContent::Text(_)));
        let has_tool_content = msg.content.iter().any(|c| {
            matches!(
                c,
                MessageContent::ToolRequest(_) | MessageContent::ToolResponse(_)
            )
        });
        has_text && !has_tool_content
    };

    // Joins all text parts of a message with newlines; None if it has none.
    let extract_text = |msg: &Message| -> Option<String> {
        let text_parts: Vec<String> = msg
            .content
            .iter()
            .filter_map(|c| {
                if let MessageContent::Text(text) = c {
                    Some(text.text.clone())
                } else {
                    None
                }
            })
            .collect();

        if text_parts.is_empty() {
            None
        } else {
            Some(text_parts.join("\n"))
        }
    };

    // For automatic compaction, scan backwards for the most recent
    // agent-visible, text-only user message so it can be re-appended after
    // the summary. `is_most_recent` records whether that message is the
    // conversation's very last message.
    let (preserved_user_message, is_most_recent) = if !manual_compact {
        let found_msg = messages.iter().enumerate().rev().find(|(_, msg)| {
            msg.is_agent_visible()
                && matches!(msg.role, rmcp::model::Role::User)
                && has_text_only(msg)
        });

        if let Some((idx, msg)) = found_msg {
            let is_last = idx == messages.len() - 1;
            (Some(msg.clone()), is_last)
        } else {
            (None, false)
        }
    } else {
        (None, false)
    };

    let messages_to_compact = messages.as_slice();

    // Generate the summary first (do_compact retries with tool responses
    // progressively removed on context overflow) before touching metadata.
    let (summary_message, summarization_usage) = do_compact(provider, messages_to_compact).await?;

    let mut final_messages = Vec::new();

    // Keep every original message but hide it from the agent. Exception:
    // when the preserved user message is also the final message, mark it
    // `invisible()` instead, since a fresh copy of its text is appended at
    // the end below. NOTE(review): this presumably hides the stale duplicate
    // from the user as well — confirm `MessageMetadata::invisible()`
    // semantics vs `with_agent_invisible()`.
    for (idx, msg) in messages_to_compact.iter().enumerate() {
        let updated_metadata = if is_most_recent
            && idx == messages_to_compact.len() - 1
            && preserved_user_message.is_some()
        {
            MessageMetadata::invisible()
        } else {
            msg.metadata.with_agent_invisible()
        };
        let updated_msg = msg.clone().with_metadata(updated_metadata);
        final_messages.push(updated_msg);
    }

    // The summary itself is only ever shown to the agent.
    let summary_msg = summary_message.with_metadata(MessageMetadata::agent_only());

    let mut continuation_messages = vec![summary_msg];

    // Pick the continuation instruction that matches how compaction was
    // triggered and where the conversation left off.
    let continuation_text = if manual_compact {
        MANUAL_COMPACT_CONTINUATION_TEXT
    } else if is_most_recent {
        CONVERSATION_CONTINUATION_TEXT
    } else {
        TOOL_LOOP_CONTINUATION_TEXT
    };

    let continuation_msg = Message::assistant()
        .with_text(continuation_text)
        .with_metadata(MessageMetadata::agent_only());
    continuation_messages.push(continuation_msg);

    // The summary comes back as a user-role message (do_compact sets
    // Role::User); merge any consecutive same-role messages to keep the
    // transcript well-formed.
    let (merged_continuation, _issues) = merge_consecutive_messages(continuation_messages);
    final_messages.extend(merged_continuation);

    // Re-append the preserved user message (text parts only) so the agent
    // still sees the pending request after compaction.
    if let Some(user_msg) = preserved_user_message {
        if let Some(text) = extract_text(&user_msg) {
            final_messages.push(Message::user().with_text(&text));
        }
    }

    Ok((
        Conversation::new_unvalidated(final_messages),
        summarization_usage,
    ))
}
166
167pub async fn check_if_compaction_needed(
169 provider: &dyn Provider,
170 conversation: &Conversation,
171 threshold_override: Option<f64>,
172 session: &crate::session::Session,
173) -> Result<bool> {
174 let messages = conversation.messages();
175 let config = Config::global();
176 let threshold = threshold_override.unwrap_or_else(|| {
177 config
178 .get_param::<f64>("ASTER_AUTO_COMPACT_THRESHOLD")
179 .unwrap_or(DEFAULT_COMPACTION_THRESHOLD)
180 });
181
182 let context_limit = provider.get_model_config().context_limit();
183
184 let (current_tokens, token_source) = match session.total_tokens {
185 Some(tokens) => (tokens as usize, "session metadata"),
186 None => {
187 let token_counter = create_token_counter()
188 .await
189 .map_err(|e| anyhow::anyhow!("Failed to create token counter: {}", e))?;
190
191 let token_counts: Vec<_> = messages
192 .iter()
193 .filter(|m| m.is_agent_visible())
194 .map(|msg| token_counter.count_chat_tokens("", std::slice::from_ref(msg), &[]))
195 .collect();
196
197 (token_counts.iter().sum(), "estimated")
198 }
199 };
200
201 let usage_ratio = current_tokens as f64 / context_limit as f64;
202
203 let needs_compaction = if threshold <= 0.0 || threshold >= 1.0 {
204 false } else {
206 usage_ratio > threshold
207 };
208
209 debug!(
210 "Compaction check: {} / {} tokens ({:.1}%), threshold: {:.1}%, needs compaction: {}, source: {}",
211 current_tokens,
212 context_limit,
213 usage_ratio * 100.0,
214 threshold * 100.0,
215 needs_compaction,
216 token_source
217 );
218
219 Ok(needs_compaction)
220}
221
222fn filter_tool_responses<'a>(messages: &[&'a Message], remove_percent: u32) -> Vec<&'a Message> {
223 fn has_tool_response(msg: &Message) -> bool {
224 msg.content
225 .iter()
226 .any(|c| matches!(c, MessageContent::ToolResponse(_)))
227 }
228
229 if remove_percent == 0 {
230 return messages.to_vec();
231 }
232
233 let tool_indices: Vec<usize> = messages
234 .iter()
235 .enumerate()
236 .filter(|(_, msg)| has_tool_response(msg))
237 .map(|(i, _)| i)
238 .collect();
239
240 if tool_indices.is_empty() {
241 return messages.to_vec();
242 }
243
244 let num_to_remove = ((tool_indices.len() * remove_percent as usize) / 100).max(1);
245
246 let middle = tool_indices.len() / 2;
247 let mut indices_to_remove = Vec::new();
248
249 for i in 0..num_to_remove {
251 if i % 2 == 0 {
252 let offset = i / 2;
253 if middle > offset {
254 indices_to_remove.push(tool_indices[middle - offset - 1]);
255 }
256 } else {
257 let offset = i / 2;
258 if middle + offset < tool_indices.len() {
259 indices_to_remove.push(tool_indices[middle + offset]);
260 }
261 }
262 }
263
264 messages
265 .iter()
266 .enumerate()
267 .filter(|(i, _)| !indices_to_remove.contains(i))
268 .map(|(_, msg)| *msg)
269 .collect()
270}
271
272async fn do_compact(
273 provider: &dyn Provider,
274 messages: &[Message],
275) -> Result<(Message, ProviderUsage), anyhow::Error> {
276 let agent_visible_messages: Vec<&Message> = messages
277 .iter()
278 .filter(|msg| msg.is_agent_visible())
279 .collect();
280
281 let removal_percentages = [0, 10, 20, 50, 100];
283
284 for (attempt, &remove_percent) in removal_percentages.iter().enumerate() {
285 let filtered_messages = filter_tool_responses(&agent_visible_messages, remove_percent);
286
287 let messages_text = filtered_messages
288 .iter()
289 .map(|&msg| format_message_for_compacting(msg))
290 .collect::<Vec<_>>()
291 .join("\n");
292
293 let context = SummarizeContext {
294 messages: messages_text,
295 };
296
297 let system_prompt = render_global_file("summarize_oneshot.md", &context)?;
298
299 let user_message = Message::user()
300 .with_text("Please summarize the conversation history provided in the system prompt.");
301 let summarization_request = vec![user_message];
302
303 match provider
304 .complete_fast(&system_prompt, &summarization_request, &[])
305 .await
306 {
307 Ok((mut response, mut provider_usage)) => {
308 response.role = Role::User;
309
310 provider_usage
311 .ensure_tokens(&system_prompt, &summarization_request, &response, &[])
312 .await
313 .map_err(|e| anyhow::anyhow!("Failed to ensure usage tokens: {}", e))?;
314
315 return Ok((response, provider_usage));
316 }
317 Err(e) => {
318 if matches!(e, ProviderError::ContextLengthExceeded(_)) {
319 if attempt < removal_percentages.len() - 1 {
320 continue;
321 } else {
322 return Err(anyhow::anyhow!(
323 "Failed to compact: context limit exceeded even after removing all tool responses"
324 ));
325 }
326 }
327 return Err(e.into());
328 }
329 }
330 }
331
332 Err(anyhow::anyhow!(
333 "Unexpected: exhausted all attempts without returning"
334 ))
335}
336
/// Renders a single message as one plain-text entry (`[role]: content`) for
/// inclusion in the summarization prompt.
///
/// Text and thinking content are included verbatim; other content kinds are
/// reduced to short labeled placeholders (tool names, mime types, ids) to
/// keep the prompt compact.
fn format_message_for_compacting(msg: &Message) -> String {
    let content_parts: Vec<String> = msg
        .content
        .iter()
        .map(|content| match content {
            MessageContent::Text(text) => text.text.clone(),
            MessageContent::Image(img) => format!("[image: {}]", img.mime_type),
            MessageContent::ToolRequest(req) => {
                // Show the tool name plus pretty-printed JSON arguments when
                // the request parsed; otherwise a generic error marker.
                if let Ok(call) = &req.tool_call {
                    format!(
                        "tool_request({}): {}",
                        call.name,
                        serde_json::to_string_pretty(&call.arguments)
                            .unwrap_or_else(|_| "<<invalid json>>".to_string())
                    )
                } else {
                    "tool_request: [error]".to_string()
                }
            }
            MessageContent::ToolResponse(res) => {
                // Only textual parts of a successful tool result are kept;
                // non-text payloads collapse to a placeholder.
                if let Ok(result) = &res.tool_result {
                    let text_items: Vec<String> = result
                        .content
                        .iter()
                        .filter_map(|content| {
                            content.as_text().map(|text_str| text_str.text.clone())
                        })
                        .collect();

                    if !text_items.is_empty() {
                        format!("tool_response: {}", text_items.join("\n"))
                    } else {
                        "tool_response: [non-text content]".to_string()
                    }
                } else {
                    "tool_response: [error]".to_string()
                }
            }
            MessageContent::ToolConfirmationRequest(req) => {
                format!("tool_confirmation_request: {}", req.tool_name)
            }
            MessageContent::ActionRequired(action) => match &action.data {
                ActionRequiredData::ToolConfirmation { tool_name, .. } => {
                    format!("action_required(tool_confirmation): {}", tool_name)
                }
                ActionRequiredData::Elicitation { message, .. } => {
                    format!("action_required(elicitation): {}", message)
                }
                ActionRequiredData::ElicitationResponse { id, .. } => {
                    format!("action_required(elicitation_response): {}", id)
                }
            },
            MessageContent::FrontendToolRequest(req) => {
                if let Ok(call) = &req.tool_call {
                    format!("frontend_tool_request: {}", call.name)
                } else {
                    "frontend_tool_request: [error]".to_string()
                }
            }
            MessageContent::Thinking(thinking) => format!("thinking: {}", thinking.thinking),
            MessageContent::RedactedThinking(_) => "redacted_thinking".to_string(),
            MessageContent::SystemNotification(notification) => {
                format!("system_notification: {}", notification.msg)
            }
        })
        .collect();

    let role_str = match msg.role {
        Role::User => "user",
        Role::Assistant => "assistant",
    };

    if content_parts.is_empty() {
        format!("[{}]: <empty message>", role_str)
    } else {
        format!("[{}]: {}", role_str, content_parts.join("\n"))
    }
}
415
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        model::ModelConfig,
        providers::{
            base::{ProviderMetadata, Usage},
            errors::ProviderError,
        },
    };
    use async_trait::async_trait;
    use rmcp::model::{AnnotateAble, CallToolRequestParam, RawContent, Tool};

    // Test double that returns a canned message for every completion and can
    // optionally simulate a provider-side context-length limit.
    struct MockProvider {
        // The message returned by every successful completion.
        message: Message,
        config: ModelConfig,
        // When set, completions fail with ContextLengthExceeded if the
        // request contains more than this many tool-response messages.
        max_tool_responses: Option<usize>,
    }

    impl MockProvider {
        // Builds a mock that always succeeds with `message` and advertises
        // the given context limit.
        fn new(message: Message, context_limit: usize) -> Self {
            Self {
                message,
                config: ModelConfig {
                    model_name: "test".to_string(),
                    context_limit: Some(context_limit),
                    temperature: None,
                    max_tokens: None,
                    toolshim: false,
                    toolshim_model: None,
                    fast_model: None,
                },
                max_tool_responses: None,
            }
        }

        // Builder-style setter for the simulated tool-response limit.
        fn with_max_tool_responses(mut self, max: usize) -> Self {
            self.max_tool_responses = Some(max);
            self
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::new("mock", "", "", "", vec![""], "", vec![])
        }

        fn get_name(&self) -> &str {
            "mock"
        }

        async fn complete_with_model(
            &self,
            _model_config: &ModelConfig,
            _system: &str,
            messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            // Simulate a context overflow once the request carries too many
            // tool responses, so the progressive-removal retry path can be
            // exercised.
            if let Some(max) = self.max_tool_responses {
                let tool_response_count = messages
                    .iter()
                    .filter(|m| {
                        m.content
                            .iter()
                            .any(|c| matches!(c, MessageContent::ToolResponse(_)))
                    })
                    .count();

                if tool_response_count > max {
                    return Err(ProviderError::ContextLengthExceeded(format!(
                        "Too many tool responses: {} > {}",
                        tool_response_count, max
                    )));
                }
            }

            Ok((
                self.message.clone(),
                ProviderUsage::new("mock-model".to_string(), Usage::default()),
            ))
        }

        fn get_model_config(&self) -> ModelConfig {
            self.config.clone()
        }
    }

    // Compacting a conversation that ends with a tool request/response pair
    // must still yield a valid agent-visible conversation.
    #[tokio::test]
    async fn test_keeps_tool_request() {
        let response_message = Message::assistant().with_text("<mock summary>");
        let provider = MockProvider::new(response_message, 1);
        let basic_conversation = vec![
            Message::user().with_text("read hello.txt"),
            Message::assistant().with_tool_request(
                "tool_0",
                Ok(CallToolRequestParam {
                    name: "read_file".into(),
                    arguments: None,
                }),
            ),
            Message::user().with_tool_response(
                "tool_0",
                Ok(rmcp::model::CallToolResult {
                    content: vec![RawContent::text("hello, world").no_annotation()],
                    structured_content: None,
                    is_error: Some(false),
                    meta: None,
                }),
            ),
        ];

        let conversation = Conversation::new_unvalidated(basic_conversation);
        let (compacted_conversation, _usage) = compact_messages(&provider, &conversation, false)
            .await
            .unwrap();

        let agent_conversation = compacted_conversation.agent_visible_messages();

        // Validation (Conversation::new) is the actual assertion here.
        let _ = Conversation::new(agent_conversation)
            .expect("compaction should produce a valid conversation");
    }

    // With the provider rejecting requests containing more than two tool
    // responses, compaction should still succeed by progressively stripping
    // tool responses from the summarization input.
    #[tokio::test]
    async fn test_progressive_removal_on_context_exceeded() {
        let response_message = Message::assistant().with_text("<mock summary>");
        let provider = MockProvider::new(response_message, 1000).with_max_tool_responses(2);

        // Build a conversation with ten tool request/response pairs.
        let mut messages = vec![Message::user().with_text("start")];
        for i in 0..10 {
            messages.push(Message::assistant().with_tool_request(
                format!("tool_{}", i),
                Ok(CallToolRequestParam {
                    name: "read_file".into(),
                    arguments: None,
                }),
            ));
            messages.push(Message::user().with_tool_response(
                format!("tool_{}", i),
                Ok(rmcp::model::CallToolResult {
                    content: vec![RawContent::text(format!("response{}", i)).no_annotation()],
                    structured_content: None,
                    is_error: Some(false),
                    meta: None,
                }),
            ));
        }

        let conversation = Conversation::new_unvalidated(messages);
        let result = compact_messages(&provider, &conversation, false).await;

        assert!(
            result.is_ok(),
            "Should succeed with progressive removal: {:?}",
            result.err()
        );
    }
}