// aster/context_mgmt/mod.rs

1use crate::conversation::message::{ActionRequiredData, MessageMetadata};
2use crate::conversation::message::{Message, MessageContent};
3use crate::conversation::{merge_consecutive_messages, Conversation};
4use crate::prompt_template::render_global_file;
5use crate::providers::base::{Provider, ProviderUsage};
6use crate::providers::errors::ProviderError;
7use crate::{config::Config, token_counter::create_token_counter};
8use anyhow::Result;
9use rmcp::model::Role;
10use serde::Serialize;
11use tracing::{debug, info};
12
/// Default fraction of the context window that triggers auto-compaction
/// when no override or config value is provided.
pub const DEFAULT_COMPACTION_THRESHOLD: f64 = 0.8;

// The three continuation prompts below are injected (agent-only) after a
// summary so the model resumes naturally without revealing that compaction
// happened. Which one is used depends on how compaction was triggered.

// Used when the most recent message was a plain user message (normal
// conversational flow).
const CONVERSATION_CONTINUATION_TEXT: &str =
    "The previous message contains a summary that was prepared because a context limit was reached.
Do not mention that you read a summary or that conversation summarization occurred.
Just continue the conversation naturally based on the summarized context";

// Used when compaction interrupted an in-progress tool loop.
const TOOL_LOOP_CONTINUATION_TEXT: &str =
    "The previous message contains a summary that was prepared because a context limit was reached.
Do not mention that you read a summary or that conversation summarization occurred.
Continue calling tools as necessary to complete the task.";

// Used when the user explicitly requested compaction.
const MANUAL_COMPACT_CONTINUATION_TEXT: &str =
    "The previous message contains a summary that was prepared at the user's request.
Do not mention that you read a summary or that conversation summarization occurred.
Just continue the conversation naturally based on the summarized context";
29
/// Template context for rendering `summarize_oneshot.md`.
#[derive(Serialize)]
struct SummarizeContext {
    /// The whole conversation flattened to plain text, one formatted
    /// message per line group (see `format_message_for_compacting`).
    messages: String,
}
34
/// Compact messages by summarizing them
///
/// This function performs the actual compaction by summarizing messages and updating
/// their visibility metadata. It does not check thresholds - use `check_if_compaction_needed`
/// first to determine if compaction is necessary.
///
/// # Arguments
/// * `provider` - The provider to use for summarization
/// * `conversation` - The current conversation history
/// * `manual_compact` - If true, this is a manual compaction (don't preserve user message)
///
/// # Returns
/// * A tuple containing:
///   - `Conversation`: The compacted messages
///   - `ProviderUsage`: Provider usage from summarization
pub async fn compact_messages(
    provider: &dyn Provider,
    conversation: &Conversation,
    manual_compact: bool,
) -> Result<(Conversation, ProviderUsage)> {
    info!("Performing message compaction");

    let messages = conversation.messages();

    // True when the message carries text content and no tool request/response
    // content. Note: other content (e.g. images) does not disqualify it, but
    // only the text survives the re-add below — NOTE(review): confirm that
    // dropping non-text content from the preserved message is intended.
    let has_text_only = |msg: &Message| {
        let has_text = msg
            .content
            .iter()
            .any(|c| matches!(c, MessageContent::Text(_)));
        let has_tool_content = msg.content.iter().any(|c| {
            matches!(
                c,
                MessageContent::ToolRequest(_) | MessageContent::ToolResponse(_)
            )
        });
        has_text && !has_tool_content
    };

    // Joins all text parts of a message with newlines; None if there is no text.
    let extract_text = |msg: &Message| -> Option<String> {
        let text_parts: Vec<String> = msg
            .content
            .iter()
            .filter_map(|c| {
                if let MessageContent::Text(text) = c {
                    Some(text.text.clone())
                } else {
                    None
                }
            })
            .collect();

        if text_parts.is_empty() {
            None
        } else {
            Some(text_parts.join("\n"))
        }
    };

    // Find and preserve the most recent user message for non-manual compacts
    let (preserved_user_message, is_most_recent) = if !manual_compact {
        // Scan from the end for the latest agent-visible, text-only user message.
        let found_msg = messages.iter().enumerate().rev().find(|(_, msg)| {
            msg.is_agent_visible()
                && matches!(msg.role, rmcp::model::Role::User)
                && has_text_only(msg)
        });

        if let Some((idx, msg)) = found_msg {
            // Track whether it is literally the last message; this decides both
            // which continuation text to use and whether to hide the original copy.
            let is_last = idx == messages.len() - 1;
            (Some(msg.clone()), is_last)
        } else {
            (None, false)
        }
    } else {
        (None, false)
    };

    let messages_to_compact = messages.as_slice();

    let (summary_message, summarization_usage) = do_compact(provider, messages_to_compact).await?;

    // Create the final message list with updated visibility metadata:
    // 1. Original messages become user_visible but not agent_visible
    // 2. Summary message becomes agent_visible but not user_visible
    // 3. Assistant messages to continue the conversation are also agent_visible but not user_visible
    let mut final_messages = Vec::new();

    for (idx, msg) in messages_to_compact.iter().enumerate() {
        let updated_metadata = if is_most_recent
            && idx == messages_to_compact.len() - 1
            && preserved_user_message.is_some()
        {
            // This is the most recent message and we're preserving it by adding a fresh copy
            MessageMetadata::invisible()
        } else {
            msg.metadata.with_agent_invisible()
        };
        let updated_msg = msg.clone().with_metadata(updated_metadata);
        final_messages.push(updated_msg);
    }

    // The summary is only for the agent; users keep seeing the original history.
    let summary_msg = summary_message.with_metadata(MessageMetadata::agent_only());

    let mut continuation_messages = vec![summary_msg];

    // Pick the continuation prompt matching how compaction was triggered.
    let continuation_text = if manual_compact {
        MANUAL_COMPACT_CONTINUATION_TEXT
    } else if is_most_recent {
        CONVERSATION_CONTINUATION_TEXT
    } else {
        TOOL_LOOP_CONTINUATION_TEXT
    };

    let continuation_msg = Message::assistant()
        .with_text(continuation_text)
        .with_metadata(MessageMetadata::agent_only());
    continuation_messages.push(continuation_msg);

    // Merge summary + continuation in case they end up with the same role.
    let (merged_continuation, _issues) = merge_consecutive_messages(continuation_messages);
    final_messages.extend(merged_continuation);

    // Re-add the preserved user message as a fresh, fully visible copy so the
    // agent still responds to it after compaction. Only its text is kept.
    if let Some(user_msg) = preserved_user_message {
        if let Some(text) = extract_text(&user_msg) {
            final_messages.push(Message::user().with_text(&text));
        }
    }

    Ok((
        Conversation::new_unvalidated(final_messages),
        summarization_usage,
    ))
}
166
167/// Check if messages exceed the auto-compaction threshold
168pub async fn check_if_compaction_needed(
169    provider: &dyn Provider,
170    conversation: &Conversation,
171    threshold_override: Option<f64>,
172    session: &crate::session::Session,
173) -> Result<bool> {
174    let messages = conversation.messages();
175    let config = Config::global();
176    let threshold = threshold_override.unwrap_or_else(|| {
177        config
178            .get_param::<f64>("ASTER_AUTO_COMPACT_THRESHOLD")
179            .unwrap_or(DEFAULT_COMPACTION_THRESHOLD)
180    });
181
182    let context_limit = provider.get_model_config().context_limit();
183
184    let (current_tokens, token_source) = match session.total_tokens {
185        Some(tokens) => (tokens as usize, "session metadata"),
186        None => {
187            let token_counter = create_token_counter()
188                .await
189                .map_err(|e| anyhow::anyhow!("Failed to create token counter: {}", e))?;
190
191            let token_counts: Vec<_> = messages
192                .iter()
193                .filter(|m| m.is_agent_visible())
194                .map(|msg| token_counter.count_chat_tokens("", std::slice::from_ref(msg), &[]))
195                .collect();
196
197            (token_counts.iter().sum(), "estimated")
198        }
199    };
200
201    let usage_ratio = current_tokens as f64 / context_limit as f64;
202
203    let needs_compaction = if threshold <= 0.0 || threshold >= 1.0 {
204        false // Auto-compact is disabled.
205    } else {
206        usage_ratio > threshold
207    };
208
209    debug!(
210        "Compaction check: {} / {} tokens ({:.1}%), threshold: {:.1}%, needs compaction: {}, source: {}",
211        current_tokens,
212        context_limit,
213        usage_ratio * 100.0,
214        threshold * 100.0,
215        needs_compaction,
216        token_source
217    );
218
219    Ok(needs_compaction)
220}
221
222fn filter_tool_responses<'a>(messages: &[&'a Message], remove_percent: u32) -> Vec<&'a Message> {
223    fn has_tool_response(msg: &Message) -> bool {
224        msg.content
225            .iter()
226            .any(|c| matches!(c, MessageContent::ToolResponse(_)))
227    }
228
229    if remove_percent == 0 {
230        return messages.to_vec();
231    }
232
233    let tool_indices: Vec<usize> = messages
234        .iter()
235        .enumerate()
236        .filter(|(_, msg)| has_tool_response(msg))
237        .map(|(i, _)| i)
238        .collect();
239
240    if tool_indices.is_empty() {
241        return messages.to_vec();
242    }
243
244    let num_to_remove = ((tool_indices.len() * remove_percent as usize) / 100).max(1);
245
246    let middle = tool_indices.len() / 2;
247    let mut indices_to_remove = Vec::new();
248
249    // Middle out
250    for i in 0..num_to_remove {
251        if i % 2 == 0 {
252            let offset = i / 2;
253            if middle > offset {
254                indices_to_remove.push(tool_indices[middle - offset - 1]);
255            }
256        } else {
257            let offset = i / 2;
258            if middle + offset < tool_indices.len() {
259                indices_to_remove.push(tool_indices[middle + offset]);
260            }
261        }
262    }
263
264    messages
265        .iter()
266        .enumerate()
267        .filter(|(i, _)| !indices_to_remove.contains(i))
268        .map(|(_, msg)| *msg)
269        .collect()
270}
271
/// Summarize the agent-visible portion of `messages` into a single message.
///
/// If the provider reports `ContextLengthExceeded`, retries with progressively
/// more tool-response messages removed from the middle of the conversation
/// (0%, 10%, 20%, 50%, 100%) before giving up.
///
/// # Returns
/// * The summary message (role rewritten to `User`) and the provider usage
///   recorded for the summarization call.
async fn do_compact(
    provider: &dyn Provider,
    messages: &[Message],
) -> Result<(Message, ProviderUsage), anyhow::Error> {
    // Only messages the agent can currently see are part of the summary input.
    let agent_visible_messages: Vec<&Message> = messages
        .iter()
        .filter(|msg| msg.is_agent_visible())
        .collect();

    // Try progressively removing more tool response messages from the middle to reduce context length
    let removal_percentages = [0, 10, 20, 50, 100];

    for (attempt, &remove_percent) in removal_percentages.iter().enumerate() {
        let filtered_messages = filter_tool_responses(&agent_visible_messages, remove_percent);

        // Flatten the conversation into plain text for the summarization prompt.
        let messages_text = filtered_messages
            .iter()
            .map(|&msg| format_message_for_compacting(msg))
            .collect::<Vec<_>>()
            .join("\n");

        let context = SummarizeContext {
            messages: messages_text,
        };

        // The conversation text is embedded into the system prompt template.
        let system_prompt = render_global_file("summarize_oneshot.md", &context)?;

        let user_message = Message::user()
            .with_text("Please summarize the conversation history provided in the system prompt.");
        let summarization_request = vec![user_message];

        match provider
            .complete_fast(&system_prompt, &summarization_request, &[])
            .await
        {
            Ok((mut response, mut provider_usage)) => {
                // The summary is stored with the user role — NOTE(review):
                // presumably so the next turn treats it as incoming context;
                // confirm against how the summary message is consumed.
                response.role = Role::User;

                provider_usage
                    .ensure_tokens(&system_prompt, &summarization_request, &response, &[])
                    .await
                    .map_err(|e| anyhow::anyhow!("Failed to ensure usage tokens: {}", e))?;

                return Ok((response, provider_usage));
            }
            Err(e) => {
                if matches!(e, ProviderError::ContextLengthExceeded(_)) {
                    // Still too large: move on to the next, more aggressive
                    // removal percentage if any remain.
                    if attempt < removal_percentages.len() - 1 {
                        continue;
                    } else {
                        return Err(anyhow::anyhow!(
                            "Failed to compact: context limit exceeded even after removing all tool responses"
                        ));
                    }
                }
                // Any other provider error is fatal immediately.
                return Err(e.into());
            }
        }
    }

    // Unreachable in practice: every loop iteration returns. Kept so the
    // function type-checks without a trailing panic.
    Err(anyhow::anyhow!(
        "Unexpected: exhausted all attempts without returning"
    ))
}
336
337fn format_message_for_compacting(msg: &Message) -> String {
338    let content_parts: Vec<String> = msg
339        .content
340        .iter()
341        .map(|content| match content {
342            MessageContent::Text(text) => text.text.clone(),
343            MessageContent::Image(img) => format!("[image: {}]", img.mime_type),
344            MessageContent::ToolRequest(req) => {
345                if let Ok(call) = &req.tool_call {
346                    format!(
347                        "tool_request({}): {}",
348                        call.name,
349                        serde_json::to_string_pretty(&call.arguments)
350                            .unwrap_or_else(|_| "<<invalid json>>".to_string())
351                    )
352                } else {
353                    "tool_request: [error]".to_string()
354                }
355            }
356            MessageContent::ToolResponse(res) => {
357                if let Ok(result) = &res.tool_result {
358                    let text_items: Vec<String> = result
359                        .content
360                        .iter()
361                        .filter_map(|content| {
362                            content.as_text().map(|text_str| text_str.text.clone())
363                        })
364                        .collect();
365
366                    if !text_items.is_empty() {
367                        format!("tool_response: {}", text_items.join("\n"))
368                    } else {
369                        "tool_response: [non-text content]".to_string()
370                    }
371                } else {
372                    "tool_response: [error]".to_string()
373                }
374            }
375            MessageContent::ToolConfirmationRequest(req) => {
376                format!("tool_confirmation_request: {}", req.tool_name)
377            }
378            MessageContent::ActionRequired(action) => match &action.data {
379                ActionRequiredData::ToolConfirmation { tool_name, .. } => {
380                    format!("action_required(tool_confirmation): {}", tool_name)
381                }
382                ActionRequiredData::Elicitation { message, .. } => {
383                    format!("action_required(elicitation): {}", message)
384                }
385                ActionRequiredData::ElicitationResponse { id, .. } => {
386                    format!("action_required(elicitation_response): {}", id)
387                }
388            },
389            MessageContent::FrontendToolRequest(req) => {
390                if let Ok(call) = &req.tool_call {
391                    format!("frontend_tool_request: {}", call.name)
392                } else {
393                    "frontend_tool_request: [error]".to_string()
394                }
395            }
396            MessageContent::Thinking(thinking) => format!("thinking: {}", thinking.thinking),
397            MessageContent::RedactedThinking(_) => "redacted_thinking".to_string(),
398            MessageContent::SystemNotification(notification) => {
399                format!("system_notification: {}", notification.msg)
400            }
401        })
402        .collect();
403
404    let role_str = match msg.role {
405        Role::User => "user",
406        Role::Assistant => "assistant",
407    };
408
409    if content_parts.is_empty() {
410        format!("[{}]: <empty message>", role_str)
411    } else {
412        format!("[{}]: {}", role_str, content_parts.join("\n"))
413    }
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        model::ModelConfig,
        providers::{
            base::{ProviderMetadata, Usage},
            errors::ProviderError,
        },
    };
    use async_trait::async_trait;
    use rmcp::model::{AnnotateAble, CallToolRequestParam, RawContent, Tool};

    /// Provider stub that always returns a canned message. When
    /// `max_tool_responses` is set, it fails with `ContextLengthExceeded`
    /// whenever the request carries more than that many tool-response
    /// messages — used to exercise the progressive-removal retry path.
    struct MockProvider {
        message: Message,
        config: ModelConfig,
        max_tool_responses: Option<usize>,
    }

    impl MockProvider {
        // Builds a mock with the given canned response and context limit.
        fn new(message: Message, context_limit: usize) -> Self {
            Self {
                message,
                config: ModelConfig {
                    model_name: "test".to_string(),
                    context_limit: Some(context_limit),
                    temperature: None,
                    max_tokens: None,
                    toolshim: false,
                    toolshim_model: None,
                    fast_model: None,
                },
                max_tool_responses: None,
            }
        }

        // Enables the tool-response cap that triggers simulated context overflow.
        fn with_max_tool_responses(mut self, max: usize) -> Self {
            self.max_tool_responses = Some(max);
            self
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::new("mock", "", "", "", vec![""], "", vec![])
        }

        fn get_name(&self) -> &str {
            "mock"
        }

        async fn complete_with_model(
            &self,
            _model_config: &ModelConfig,
            _system: &str,
            messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            // If max_tool_responses is set, fail if we have too many
            if let Some(max) = self.max_tool_responses {
                let tool_response_count = messages
                    .iter()
                    .filter(|m| {
                        m.content
                            .iter()
                            .any(|c| matches!(c, MessageContent::ToolResponse(_)))
                    })
                    .count();

                if tool_response_count > max {
                    return Err(ProviderError::ContextLengthExceeded(format!(
                        "Too many tool responses: {} > {}",
                        tool_response_count, max
                    )));
                }
            }

            Ok((
                self.message.clone(),
                ProviderUsage::new("mock-model".to_string(), Usage::default()),
            ))
        }

        fn get_model_config(&self) -> ModelConfig {
            self.config.clone()
        }
    }

    /// Compacting a user/tool-request/tool-response exchange must still yield
    /// a valid agent-visible conversation (tool pairing intact).
    #[tokio::test]
    async fn test_keeps_tool_request() {
        let response_message = Message::assistant().with_text("<mock summary>");
        let provider = MockProvider::new(response_message, 1);
        let basic_conversation = vec![
            Message::user().with_text("read hello.txt"),
            Message::assistant().with_tool_request(
                "tool_0",
                Ok(CallToolRequestParam {
                    name: "read_file".into(),
                    arguments: None,
                }),
            ),
            Message::user().with_tool_response(
                "tool_0",
                Ok(rmcp::model::CallToolResult {
                    content: vec![RawContent::text("hello, world").no_annotation()],
                    structured_content: None,
                    is_error: Some(false),
                    meta: None,
                }),
            ),
        ];

        let conversation = Conversation::new_unvalidated(basic_conversation);
        let (compacted_conversation, _usage) = compact_messages(&provider, &conversation, false)
            .await
            .unwrap();

        let agent_conversation = compacted_conversation.agent_visible_messages();

        // Validation would fail if compaction broke the conversation structure.
        let _ = Conversation::new(agent_conversation)
            .expect("compaction should produce a valid conversation");
    }

    /// When the provider keeps rejecting the summarization request with
    /// `ContextLengthExceeded`, compaction must retry with progressively more
    /// tool responses removed until the request fits.
    #[tokio::test]
    async fn test_progressive_removal_on_context_exceeded() {
        let response_message = Message::assistant().with_text("<mock summary>");
        // Set max to 2 tool responses - will trigger progressive removal
        let provider = MockProvider::new(response_message, 1000).with_max_tool_responses(2);

        // Create a conversation with many tool responses
        let mut messages = vec![Message::user().with_text("start")];
        for i in 0..10 {
            messages.push(Message::assistant().with_tool_request(
                format!("tool_{}", i),
                Ok(CallToolRequestParam {
                    name: "read_file".into(),
                    arguments: None,
                }),
            ));
            messages.push(Message::user().with_tool_response(
                format!("tool_{}", i),
                Ok(rmcp::model::CallToolResult {
                    content: vec![RawContent::text(format!("response{}", i)).no_annotation()],
                    structured_content: None,
                    is_error: Some(false),
                    meta: None,
                }),
            ));
        }

        let conversation = Conversation::new_unvalidated(messages);
        let result = compact_messages(&provider, &conversation, false).await;

        // Should succeed after progressive removal
        assert!(
            result.is_ok(),
            "Should succeed with progressive removal: {:?}",
            result.err()
        );
    }
}
573            "Should succeed with progressive removal: {:?}",
574            result.err()
575        );
576    }
577}