use tracing::{debug, info, warn};
use crate::reasoning::conversation::{Conversation, MessageRole};
/// Selects how a conversation is compacted once it no longer fits its token budget.
///
/// `PartialEq`/`Eq` are derived so strategies can be compared directly
/// (for example in configuration checks and `assert_eq!` in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ContextStrategy {
    /// Hand the conversation to `Conversation::truncate_to_budget`.
    ///
    /// This is the default strategy and the fallback used by the other two
    /// when their own compaction is not sufficient.
    #[default]
    SlidingWindow,
    /// Replace the content of older tool-result messages with a short
    /// placeholder while leaving the most recent messages untouched.
    ObservationMasking,
    /// Keep the leading anchor messages, collapse the middle of the
    /// conversation into a single summary message, and keep the tail verbatim.
    AnchoredSummary {
        /// Target number of trailing messages to keep verbatim.
        recent_count: usize,
    },
}
/// Pluggable policy for keeping a [`Conversation`] within a token budget.
///
/// Implementors must be `Send + Sync`.
pub trait ContextManager: Send + Sync {
    /// Compacts `conversation` in place; `max_tokens` is the budget the
    /// implementation is expected to aim for.
    fn manage_context(&self, conversation: &mut Conversation, max_tokens: usize);

    /// Short identifier of the active strategy (used as a log field by
    /// `DefaultContextManager`).
    fn strategy_name(&self) -> &str;
}
/// [`ContextManager`] implementation that applies one fixed [`ContextStrategy`].
///
/// `Debug` and `Clone` are derived because the only field supports both and
/// public types are expected to be debuggable.
#[derive(Debug, Clone)]
pub struct DefaultContextManager {
    /// Strategy applied on every `manage_context` call.
    strategy: ContextStrategy,
}
impl DefaultContextManager {
    /// Number of trailing messages that observation masking never rewrites.
    const MASKING_KEEP_RECENT: usize = 6;

    /// Conversations with at most this many messages are never masked.
    const MASKING_MIN_MESSAGES: usize = 3;

    /// Creates a manager that applies `strategy` on every `manage_context` call.
    pub fn new(strategy: ContextStrategy) -> Self {
        Self { strategy }
    }

    /// Delegates to `Conversation::truncate_to_budget` with `max_tokens`.
    fn apply_sliding_window(conversation: &mut Conversation, max_tokens: usize) {
        conversation.truncate_to_budget(max_tokens);
    }

    /// Replaces the content of older tool-result messages with a placeholder.
    ///
    /// Does nothing when the conversation already fits `max_tokens`. The last
    /// `MASKING_KEEP_RECENT` messages are always kept verbatim. If the
    /// conversation is still over budget after masking, the sliding-window
    /// strategy is applied as a fallback.
    fn apply_observation_masking(conversation: &mut Conversation, max_tokens: usize) {
        let estimated = conversation.estimate_tokens();
        if estimated <= max_tokens {
            return;
        }
        info!(
            estimated_tokens = estimated,
            max_tokens,
            over_by = estimated - max_tokens,
            "ObservationMasking: context exceeds budget, masking old tool results"
        );

        let messages = conversation.messages().to_vec();
        let total = messages.len();
        if total <= Self::MASKING_MIN_MESSAGES {
            // NOTE(review): the conversation is left over budget here; confirm
            // that callers tolerate this instead of expecting a truncation.
            warn!("ObservationMasking: only {} messages, cannot mask", total);
            return;
        }

        let keep_recent = Self::MASKING_KEEP_RECENT.min(total);
        let first_protected = total - keep_recent;
        let mut masked_count = 0usize;
        let mut new_messages = Vec::with_capacity(total);

        // Consume the copied messages so each one is moved, not cloned again.
        for (i, mut msg) in messages.into_iter().enumerate() {
            if i < first_protected && msg.role == MessageRole::Tool {
                let placeholder = format!(
                    "[Previous {} result omitted for context management]",
                    msg.tool_name.as_deref().unwrap_or("tool")
                );
                // Only mask when it actually shrinks the message. This avoids
                // growing short results and re-counting messages that were
                // already masked on an earlier pass. Byte length is used as a
                // cheap proxy for token cost.
                if placeholder.len() < msg.content.len() {
                    msg.content = placeholder;
                    masked_count += 1;
                }
            }
            new_messages.push(msg);
        }

        info!(
            masked_tool_results = masked_count,
            kept_recent = keep_recent,
            total_messages = total,
            "ObservationMasking: masked old tool results"
        );

        *conversation = Conversation::new();
        for msg in new_messages {
            conversation.push(msg);
        }

        let still_estimated = conversation.estimate_tokens();
        if still_estimated > max_tokens {
            warn!(
                still_estimated,
                max_tokens, "ObservationMasking insufficient, falling back to SlidingWindow"
            );
            Self::apply_sliding_window(conversation, max_tokens);
        }
    }

    /// Collapses the middle of the conversation into one summary message.
    ///
    /// Does nothing when the conversation already fits `max_tokens`.
    /// Otherwise the result is: the leading anchor messages, a single user
    /// message stating how many messages / tool calls / tool results were
    /// omitted (only when something was omitted), and the trailing messages.
    /// The tail holds `recent_count` messages, widened backwards when it would
    /// otherwise start on a tool result. If the result is still over budget,
    /// the sliding-window strategy is applied as a fallback.
    fn apply_anchored_summary(
        conversation: &mut Conversation,
        max_tokens: usize,
        recent_count: usize,
    ) {
        if conversation.estimate_tokens() <= max_tokens {
            return;
        }

        let mut messages = conversation.messages().to_vec();
        let total = messages.len();

        // Anchors are the unbroken leading run of system messages, plus user
        // messages located at index 0 or 1.
        let anchor_end = messages
            .iter()
            .enumerate()
            .take_while(|(i, msg)| {
                msg.role == MessageRole::System || (msg.role == MessageRole::User && *i <= 1)
            })
            .count();

        let keep_recent = recent_count.min(total.saturating_sub(anchor_end));
        let mut recent_start = total.saturating_sub(keep_recent);

        // A tool result at the head of the kept tail would lose the assistant
        // tool-call message it answers (that message would be summarized away).
        // Chat-completion style APIs reject such orphaned tool results, so pull
        // the window back to the first non-tool message.
        while recent_start > anchor_end
            && matches!(messages.get(recent_start), Some(m) if m.role == MessageRole::Tool)
        {
            recent_start -= 1;
        }

        // Split the copy into anchors (left in `messages`), middle and tail.
        let recent = messages.split_off(recent_start);
        let middle = messages.split_off(anchor_end);

        let summary_message = if middle.is_empty() {
            None
        } else {
            let tool_calls_in_middle = middle.iter().filter(|m| !m.tool_calls.is_empty()).count();
            let tool_results_in_middle = middle
                .iter()
                .filter(|m| m.role == MessageRole::Tool)
                .count();
            let summary = format!(
                "[Context summary: {} messages omitted ({} tool calls, {} tool results). The conversation continued with the agent working on the task.]",
                middle.len(),
                tool_calls_in_middle,
                tool_results_in_middle
            );
            Some(crate::reasoning::conversation::ConversationMessage::user(
                summary,
            ))
        };

        *conversation = Conversation::new();
        for msg in messages.into_iter().chain(summary_message).chain(recent) {
            conversation.push(msg);
        }

        if conversation.estimate_tokens() > max_tokens {
            Self::apply_sliding_window(conversation, max_tokens);
        }
    }
}
impl Default for DefaultContextManager {
    /// Builds a manager that uses the sliding-window strategy.
    fn default() -> Self {
        Self {
            strategy: ContextStrategy::SlidingWindow,
        }
    }
}
impl ContextManager for DefaultContextManager {
    /// Runs the configured strategy on `conversation` and logs the outcome.
    ///
    /// A debug event is emitted before the strategy runs; an info event is
    /// emitted afterwards only when the estimated token count went down.
    fn manage_context(&self, conversation: &mut Conversation, max_tokens: usize) {
        let tokens_before = conversation.estimate_tokens();
        let len_before = conversation.len();
        debug!(
            strategy = self.strategy_name(),
            estimated_tokens = tokens_before,
            max_tokens,
            message_count = len_before,
            "Context management check"
        );

        match &self.strategy {
            ContextStrategy::AnchoredSummary { recent_count } => {
                Self::apply_anchored_summary(conversation, max_tokens, *recent_count)
            }
            ContextStrategy::ObservationMasking => {
                Self::apply_observation_masking(conversation, max_tokens)
            }
            ContextStrategy::SlidingWindow => Self::apply_sliding_window(conversation, max_tokens),
        }

        let tokens_after = conversation.estimate_tokens();
        let len_after = conversation.len();
        // Nothing to report unless the estimate actually shrank.
        if tokens_after >= tokens_before {
            return;
        }
        info!(
            strategy = self.strategy_name(),
            before_tokens = tokens_before,
            after_tokens = tokens_after,
            tokens_saved = tokens_before - tokens_after,
            messages_before = len_before,
            messages_after = len_after,
            "Context compaction triggered"
        );
    }

    /// Returns the snake_case identifier of the configured strategy.
    fn strategy_name(&self) -> &str {
        match &self.strategy {
            ContextStrategy::AnchoredSummary { .. } => "anchored_summary",
            ContextStrategy::ObservationMasking => "observation_masking",
            ContextStrategy::SlidingWindow => "sliding_window",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reasoning::conversation::{ConversationMessage, ToolCall};

    /// Builds a system prompt followed by 20 rounds of: user question,
    /// assistant tool call, tool result, assistant answer.
    fn build_long_conversation() -> Conversation {
        let mut conversation = Conversation::with_system("You are a research agent.");
        for round in 0..20 {
            let question = ConversationMessage::user(format!(
                "Research question {} about a topic that requires multiple paragraphs of text to describe properly",
                round
            ));
            let call = ConversationMessage::assistant_tool_calls(vec![ToolCall {
                id: format!("call_{}", round),
                name: "web_search".into(),
                arguments: format!(r#"{{"query": "topic {} detailed information"}}"#, round),
            }]);
            let result = ConversationMessage::tool_result(
                format!("call_{}", round),
                "web_search",
                format!("Here are the detailed results for query {}. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", round),
            );
            let answer = ConversationMessage::assistant(format!(
                "Based on the search results for question {}, I found that the topic involves multiple interesting aspects that we should discuss in detail.",
                round
            ));
            conversation.push(question);
            conversation.push(call);
            conversation.push(result);
            conversation.push(answer);
        }
        conversation
    }

    /// A conversation far below the budget keeps its token estimate.
    #[test]
    fn test_sliding_window_no_truncation_needed() {
        let manager = DefaultContextManager::new(ContextStrategy::SlidingWindow);
        let mut conversation = Conversation::with_system("sys");
        conversation.push(ConversationMessage::user("hi"));
        conversation.push(ConversationMessage::assistant("hello"));
        let tokens_before = conversation.estimate_tokens();

        manager.manage_context(&mut conversation, 10000);

        assert_eq!(conversation.estimate_tokens(), tokens_before);
    }

    /// Sliding window shrinks a long conversation to the budget and keeps
    /// the system message first.
    #[test]
    fn test_sliding_window_truncation() {
        let manager = DefaultContextManager::new(ContextStrategy::SlidingWindow);
        let mut conversation = build_long_conversation();
        let len_before = conversation.len();

        manager.manage_context(&mut conversation, 200);

        assert!(conversation.len() < len_before);
        assert!(conversation.estimate_tokens() <= 200);
        assert_eq!(conversation.messages()[0].role, MessageRole::System);
    }

    /// Observation masking either leaves a masked tool result behind or ends
    /// up within budget.
    #[test]
    fn test_observation_masking() {
        let manager = DefaultContextManager::new(ContextStrategy::ObservationMasking);
        let mut conversation = build_long_conversation();

        manager.manage_context(&mut conversation, 500);

        let found_masked = conversation
            .messages()
            .iter()
            .any(|m| m.role == MessageRole::Tool && m.content.contains("omitted"));
        assert!(found_masked || conversation.estimate_tokens() <= 500);
    }

    /// Anchored summary shortens the conversation, keeps the system message
    /// first, and either inserts a summary or ends up within budget.
    #[test]
    fn test_anchored_summary() {
        let manager =
            DefaultContextManager::new(ContextStrategy::AnchoredSummary { recent_count: 6 });
        let mut conversation = build_long_conversation();
        let len_before = conversation.len();

        manager.manage_context(&mut conversation, 500);

        assert!(conversation.len() < len_before);
        assert_eq!(conversation.messages()[0].role, MessageRole::System);
        let has_summary = conversation
            .messages()
            .iter()
            .any(|m| m.content.contains("Context summary"));
        assert!(has_summary || conversation.estimate_tokens() <= 500);
    }

    /// Every strategy reports its snake_case name.
    #[test]
    fn test_strategy_name() {
        let sliding = DefaultContextManager::new(ContextStrategy::SlidingWindow);
        assert_eq!(sliding.strategy_name(), "sliding_window");

        let masking = DefaultContextManager::new(ContextStrategy::ObservationMasking);
        assert_eq!(masking.strategy_name(), "observation_masking");

        let anchored =
            DefaultContextManager::new(ContextStrategy::AnchoredSummary { recent_count: 4 });
        assert_eq!(anchored.strategy_name(), "anchored_summary");
    }

    /// A conversation within budget keeps its message count.
    #[test]
    fn test_context_within_budget_untouched() {
        let manager = DefaultContextManager::new(ContextStrategy::ObservationMasking);
        let mut conversation = Conversation::with_system("sys");
        conversation.push(ConversationMessage::user("short"));
        let len_before = conversation.len();

        manager.manage_context(&mut conversation, 100_000);

        assert_eq!(conversation.len(), len_before);
    }
}