use infernum_core::Message;
use super::types::*;
/// Manages the agent loop's conversation window: holds the message list,
/// keeps a running heuristic token estimate, and records the compression
/// events applied to keep the window under its token budget.
#[derive(Debug)]
pub struct ContextWindowManager {
// Conversation messages in order; compression mutates or removes entries.
messages: Vec<Message>,
// Latest loop-state snapshot, echoed back by `snapshot()`.
state_snapshot: LoopStateSnapshot,
// Heuristic token count of `messages` (~4 bytes per token, see `estimate_tokens`).
estimated_tokens: u32,
// Hard token budget; `pressure()` is estimated / max, clamped to [0, 1].
max_context_tokens: u32,
// Every compression applied so far, for diagnostics/reporting.
compressions: Vec<CompressionEvent>,
// Token total as if no compression had run (grows with every push).
original_tokens: u32,
}
impl ContextWindowManager {
    /// Flat per-message token overhead (role/framing metadata). Kept equal to
    /// the `+ 4` applied per message by `estimate_tokens_for_messages` so that
    /// pushed messages and initial messages are counted consistently.
    const MESSAGE_OVERHEAD_TOKENS: u32 = 4;

    /// Creates an empty manager with the given hard token ceiling.
    pub fn new(max_context_tokens: u32) -> Self {
        Self {
            messages: Vec::new(),
            state_snapshot: LoopStateSnapshot {
                iteration: 0,
                max_iterations: 0,
                token_budget_remaining: max_context_tokens,
                tools_available: Vec::new(),
                context_pressure: 0.0,
            },
            estimated_tokens: 0,
            max_context_tokens,
            compressions: Vec::new(),
            original_tokens: 0,
        }
    }

    /// Current messages, in conversation order.
    pub fn messages(&self) -> &[Message] {
        &self.messages
    }

    /// Current heuristic token estimate for the whole window.
    pub fn estimated_tokens(&self) -> u32 {
        self.estimated_tokens
    }

    /// Fraction of the context budget consumed, clamped to [0.0, 1.0].
    /// A zero budget is treated as fully saturated.
    pub fn pressure(&self) -> f32 {
        if self.max_context_tokens == 0 {
            return 1.0;
        }
        (self.estimated_tokens as f32 / self.max_context_tokens as f32).clamp(0.0, 1.0)
    }

    /// True once more than 80% of the budget is in use.
    pub fn is_under_pressure(&self) -> bool {
        self.pressure() > 0.8
    }

    /// Compression events recorded so far.
    pub fn compressions(&self) -> &[CompressionEvent] {
        &self.compressions
    }

    /// Replaces the loop-state snapshot embedded in future `snapshot()` calls.
    pub fn update_state(&mut self, snapshot: LoopStateSnapshot) {
        self.state_snapshot = snapshot;
    }

    /// Seeds the window, resetting both current and original token counts.
    pub fn set_initial_messages(&mut self, messages: Vec<Message>) {
        self.estimated_tokens = estimate_tokens_for_messages(&messages);
        self.original_tokens = self.estimated_tokens;
        self.messages = messages;
    }

    /// Appends a message and updates token accounting.
    ///
    /// Includes the same per-message overhead that
    /// `estimate_tokens_for_messages` applies, so pushed messages and initial
    /// messages are counted identically (previously pushes omitted the
    /// overhead, skewing `pressure()` low for push-heavy windows).
    pub fn push_message(&mut self, message: Message) {
        let tokens =
            estimate_tokens(&message.content).saturating_add(Self::MESSAGE_OVERHEAD_TOKENS);
        self.estimated_tokens = self.estimated_tokens.saturating_add(tokens);
        self.original_tokens = self.original_tokens.saturating_add(tokens);
        self.messages.push(message);
    }

    /// Applies `strategy`, records a `CompressionEvent` when any tokens were
    /// saved, and returns the number of tokens saved (0 when the strategy was
    /// a no-op or grew the window).
    pub fn compress(&mut self, strategy: &CompressionStrategy, iteration: u32) -> u32 {
        let before = self.estimated_tokens;
        match strategy {
            CompressionStrategy::SummarizeOldResults { keep_recent } => {
                self.summarize_old_results(*keep_recent);
            },
            CompressionStrategy::PruneDeadEnds => {
                self.prune_dead_ends();
            },
            CompressionStrategy::CollapseExploration { summary_tokens } => {
                self.collapse_exploration(*summary_tokens);
            },
            CompressionStrategy::AgentDirected => {
                // Intentionally a no-op here: agent-directed compression is
                // performed externally through the public mutation APIs.
            },
        }
        let after = self.estimated_tokens;
        let saved = before.saturating_sub(after);
        if saved > 0 {
            self.compressions.push(CompressionEvent {
                strategy: strategy.clone(),
                tokens_saved: saved,
                at_iteration: iteration,
            });
        }
        saved
    }

    /// Builds an owned `ContextWindow` view of the current state for
    /// serialization or inspection.
    pub fn snapshot(&self) -> ContextWindow {
        let context_messages: Vec<ContextMessage> = self
            .messages
            .iter()
            .map(|m| ContextMessage {
                // Role is rendered via its Debug form and lowercased,
                // e.g. `Role::Assistant` -> "assistant".
                role: format!("{:?}", m.role).to_lowercase(),
                content: m.content.clone(),
                tool_call_id: m.tool_call_id.clone(),
            })
            .collect();
        ContextWindow {
            messages: context_messages,
            system_state: self.state_snapshot.clone(),
            original_token_count: self.original_tokens,
            current_token_count: self.estimated_tokens,
            compressions_applied: self.compressions.clone(),
        }
    }

    /// Largest index `<= max` that is a valid UTF-8 char boundary in `s`.
    /// (Stable stand-in for the unstable `str::floor_char_boundary`.)
    fn floor_char_boundary(s: &str, max: usize) -> usize {
        if max >= s.len() {
            return s.len();
        }
        let mut i = max;
        while !s.is_char_boundary(i) {
            i -= 1;
        }
        i
    }

    /// Truncates all but the `keep_recent` most recent tool results down to a
    /// short summary prefix, updating the token estimate accordingly.
    fn summarize_old_results(&mut self, keep_recent: u32) {
        let tool_indices: Vec<usize> = self
            .messages
            .iter()
            .enumerate()
            .filter(|(_, m)| m.tool_call_id.is_some())
            .map(|(i, _)| i)
            .collect();
        if tool_indices.len() <= keep_recent as usize {
            return;
        }
        let to_summarize = tool_indices.len() - keep_recent as usize;
        let mut summarized_count = 0;
        for &idx in tool_indices.iter().take(to_summarize) {
            let msg = &self.messages[idx];
            let original_len = msg.content.len();
            // Only results long enough to actually save tokens are summarized.
            if original_len <= 200 {
                continue;
            }
            // Cut at a char boundary: a raw `[..200]` byte slice panics when
            // byte 200 falls inside a multi-byte UTF-8 character.
            let cut = Self::floor_char_boundary(&msg.content, 200);
            let summary = format!(
                "[Summarized] {}... ({} bytes truncated)",
                &msg.content[..cut],
                original_len - cut
            );
            let old_tokens = estimate_tokens(&self.messages[idx].content);
            self.messages[idx].content = summary;
            let new_tokens = estimate_tokens(&self.messages[idx].content);
            self.estimated_tokens = self
                .estimated_tokens
                .saturating_sub(old_tokens)
                .saturating_add(new_tokens);
            summarized_count += 1;
        }
        if summarized_count > 0 {
            tracing::debug!(summarized_count, "Summarized old tool results");
        }
    }

    /// Removes dead-end tool results (errors, empty results) that precede the
    /// most recent assistant message, freeing their tokens.
    fn prune_dead_ends(&mut self) {
        let original_len = self.messages.len();
        // Never prune at or after the last assistant turn: the model may
        // still be reasoning about those results.
        let cutoff = self
            .messages
            .iter()
            .rposition(|m| matches!(m.role, infernum_core::Role::Assistant))
            .unwrap_or(0);
        let mut tokens_freed = 0u32;
        let mut index = 0usize;
        // `retain` is O(n); removing collected indices one `remove()` at a
        // time (the previous approach) is O(n^2) in the message count.
        self.messages.retain(|m| {
            let prune =
                index < cutoff && m.tool_call_id.is_some() && is_dead_end_result(&m.content);
            index += 1;
            if prune {
                tokens_freed = tokens_freed
                    .saturating_add(estimate_tokens(&m.content))
                    .saturating_add(Self::MESSAGE_OVERHEAD_TOKENS);
            }
            !prune
        });
        self.estimated_tokens = self.estimated_tokens.saturating_sub(tokens_freed);
        if self.messages.len() < original_len {
            tracing::debug!(
                pruned = original_len - self.messages.len(),
                tokens_freed,
                "Pruned dead-end tool results"
            );
        }
    }

    /// Collapses all but the four most recent tool results into a single
    /// system-message summary inserted near the top of the window.
    ///
    /// NOTE(review): `_summary_tokens` is currently unused — the summary size
    /// is bounded by the 80-byte snippet cap instead. Confirm whether the
    /// strategy parameter was meant to bound the summary length.
    fn collapse_exploration(&mut self, _summary_tokens: u32) {
        let tool_indices: Vec<usize> = self
            .messages
            .iter()
            .enumerate()
            .filter(|(_, m)| m.tool_call_id.is_some())
            .map(|(i, _)| i)
            .collect();
        // Not worth collapsing until there is a meaningful backlog.
        if tool_indices.len() <= 6 {
            return;
        }
        let keep_count = tool_indices.len().min(4);
        let to_collapse = &tool_indices[..tool_indices.len() - keep_count];
        if to_collapse.is_empty() {
            return;
        }
        let mut summary_parts = Vec::with_capacity(to_collapse.len());
        let mut tokens_freed = 0u32;
        for &idx in to_collapse {
            let msg = &self.messages[idx];
            let tool_id = msg.tool_call_id.as_deref().unwrap_or("unknown");
            let snippet = if msg.content.len() > 80 {
                // Char-boundary-safe cut: a raw `[..80]` byte slice panics on
                // multi-byte UTF-8 content.
                let cut = Self::floor_char_boundary(&msg.content, 80);
                format!("{}...", &msg.content[..cut])
            } else {
                msg.content.clone()
            };
            summary_parts.push(format!("- {tool_id}: {snippet}"));
            tokens_freed = tokens_freed
                .saturating_add(estimate_tokens(&msg.content))
                .saturating_add(Self::MESSAGE_OVERHEAD_TOKENS);
        }
        let summary = format!(
            "[Exploration summary — {} earlier tool results collapsed]\n{}",
            to_collapse.len(),
            summary_parts.join("\n")
        );
        for &idx in to_collapse.iter().rev() {
            self.messages.remove(idx);
        }
        let summary_tokens =
            estimate_tokens(&summary).saturating_add(Self::MESSAGE_OVERHEAD_TOKENS);
        // Account for the summary as an addition rather than folding it into
        // `tokens_freed` with a saturating subtraction, which silently
        // understated the estimate whenever the summary outweighed the savings.
        self.estimated_tokens = self
            .estimated_tokens
            .saturating_sub(tokens_freed)
            .saturating_add(summary_tokens);
        // Insert just after the (typical) system prompt when one exists.
        let insert_pos = 1.min(self.messages.len());
        self.messages.insert(insert_pos, Message::system(summary));
        tracing::debug!(
            collapsed = to_collapse.len(),
            tokens_freed,
            "Collapsed exploration branches"
        );
    }
}
/// Rough token estimate for `text`: about four bytes per token, with a floor
/// of one token so even an empty string costs something.
fn estimate_tokens(text: &str) -> u32 {
    let quarter = (text.len() / 4) as u32;
    if quarter == 0 {
        1
    } else {
        quarter
    }
}
/// Sums the per-message token estimates, adding a flat 4-token overhead per
/// message for role/framing metadata.
fn estimate_tokens_for_messages(messages: &[Message]) -> u32 {
    let mut total: u32 = 0;
    for message in messages {
        total += estimate_tokens(&message.content) + 4;
    }
    total
}
/// Heuristic check for tool results that led nowhere (errors, empty or null
/// results, "not found" responses).
///
/// Comparison is case-insensitive and ignores surrounding whitespace, so a
/// padded `"  Error: x"` or an all-whitespace result is also recognized
/// (previously only exact-prefix, untrimmed matches counted).
fn is_dead_end_result(content: &str) -> bool {
    let lower = content.trim().to_lowercase();
    lower.is_empty()
        || lower == "null"
        || lower.starts_with("error:")
        || lower.starts_with("no results")
        || lower.starts_with("not found")
}
#[cfg(test)]
mod tests {
    use super::*;
    use infernum_core::Message;

    // A fresh manager starts empty, with zero pressure.
    #[test]
    fn test_new_context_manager() {
        let manager = ContextWindowManager::new(4096);
        assert_eq!(manager.estimated_tokens(), 0);
        assert_eq!(manager.pressure(), 0.0);
        assert!(!manager.is_under_pressure());
    }

    #[test]
    fn test_push_message_updates_tokens() {
        let mut manager = ContextWindowManager::new(4096);
        manager.push_message(Message::user("Hello, world!"));
        assert!(manager.estimated_tokens() > 0);
    }

    // 360 chars is roughly 90 tokens against a 100-token budget,
    // which must register as high pressure.
    #[test]
    fn test_pressure_calculation() {
        let mut manager = ContextWindowManager::new(100);
        manager.push_message(Message::user(&"x".repeat(360)));
        assert!(manager.pressure() > 0.8);
        assert!(manager.is_under_pressure());
    }

    #[test]
    fn test_set_initial_messages() {
        let mut manager = ContextWindowManager::new(4096);
        let seed = vec![
            Message::system("You are a helper."),
            Message::user("Do something."),
        ];
        manager.set_initial_messages(seed);
        assert_eq!(manager.messages().len(), 2);
        assert!(manager.estimated_tokens() > 0);
    }

    // Summarization must keep the most recent tool results intact and only
    // shrink the older ones.
    #[test]
    fn test_summarize_old_results_keeps_recent() {
        let mut manager = ContextWindowManager::new(10000);
        manager.push_message(Message::system("sys"));
        for i in 0..5 {
            let mut msg = Message::tool_result(format!("call_{i}"), &"x".repeat(500));
            msg.tool_call_id = Some(format!("call_{i}"));
            manager.push_message(msg);
        }
        let saved = manager.compress(
            &CompressionStrategy::SummarizeOldResults { keep_recent: 2 },
            1,
        );
        assert!(saved > 0);
        let tool_msgs: Vec<_> = manager
            .messages()
            .iter()
            .filter(|m| m.tool_call_id.is_some())
            .collect();
        // All five survive; the last two keep their full-length content.
        assert_eq!(tool_msgs.len(), 5);
        assert!(tool_msgs[3].content.len() >= 500);
        assert!(tool_msgs[4].content.len() >= 500);
    }

    // Pruning removes the errored tool result but keeps the useful one.
    #[test]
    fn test_prune_dead_ends() {
        let mut manager = ContextWindowManager::new(10000);
        manager.push_message(Message::system("sys"));
        let mut good = Message::tool_result("call_good", "Found 42 results");
        good.tool_call_id = Some("call_good".to_string());
        manager.push_message(good);
        let mut bad = Message::tool_result("call_bad", "Error: file not found");
        bad.tool_call_id = Some("call_bad".to_string());
        manager.push_message(bad);
        manager.push_message(Message::assistant("Let me try something else."));
        let saved = manager.compress(&CompressionStrategy::PruneDeadEnds, 1);
        assert!(saved > 0);
        let remaining_tool_results = manager
            .messages()
            .iter()
            .filter(|m| m.tool_call_id.is_some())
            .count();
        assert_eq!(remaining_tool_results, 1);
    }

    #[test]
    fn test_snapshot() {
        let mut manager = ContextWindowManager::new(4096);
        manager.set_initial_messages(vec![Message::system("sys"), Message::user("do")]);
        let snap = manager.snapshot();
        assert_eq!(snap.messages.len(), 2);
        // No compression has run, so original and current counts agree.
        assert_eq!(snap.original_token_count, snap.current_token_count);
    }

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens(""), 1);
        assert_eq!(estimate_tokens("abcd"), 1);
        assert_eq!(estimate_tokens(&"a".repeat(100)), 25);
    }

    #[test]
    fn test_is_dead_end_result() {
        assert!(is_dead_end_result("Error: something went wrong"));
        assert!(is_dead_end_result("No results found"));
        assert!(is_dead_end_result("Not found"));
        assert!(is_dead_end_result(""));
        assert!(!is_dead_end_result("Found 42 items"));
    }
}