use crate::types::*;
use serde::{Deserialize, Serialize};
/// Rough token estimate for `text`: one token per four UTF-8 bytes, rounded up.
///
/// An empty string costs zero tokens.
pub fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    text.len().div_ceil(BYTES_PER_TOKEN)
}
/// Heuristic token cost of a single message.
///
/// User and assistant messages cost their content plus a fixed overhead of 4.
/// Tool results cost their content, their tool name, and an overhead of 8.
/// Extension messages are estimated from their serialized `data` plus 4.
pub fn message_tokens(msg: &AgentMessage) -> usize {
    const ROLE_OVERHEAD: usize = 4;
    const TOOL_RESULT_OVERHEAD: usize = 8;
    match msg {
        AgentMessage::Llm(Message::User { content, .. } | Message::Assistant { content, .. }) => {
            content_tokens(content) + ROLE_OVERHEAD
        }
        AgentMessage::Llm(Message::ToolResult {
            content, tool_name, ..
        }) => content_tokens(content) + estimate_tokens(tool_name) + TOOL_RESULT_OVERHEAD,
        AgentMessage::Extension(ext) => estimate_tokens(&ext.data.to_string()) + ROLE_OVERHEAD,
    }
}
/// Heuristic token cost of a list of content blocks.
///
/// Text and thinking are estimated via [`estimate_tokens`].
/// Tool calls cost their name, their serialized arguments, and an overhead of 8.
/// Images are sized from their payload length and clamped to `85..=16_000`.
fn content_tokens(content: &[Content]) -> usize {
    let mut tokens = 0;
    for block in content {
        tokens += match block {
            Content::Text { text } => estimate_tokens(text),
            Content::Thinking { thinking, .. } => estimate_tokens(thinking),
            Content::Image { data, .. } => {
                // `* 3 / 4` presumably converts a base64 length to raw bytes — TODO confirm
                // the encoding of `data`.
                let decoded_len = data.len() * 3 / 4;
                (decoded_len / 750).clamp(85, 16_000)
            }
            Content::ToolCall {
                name, arguments, ..
            } => estimate_tokens(name) + estimate_tokens(&arguments.to_string()) + 8,
        };
    }
    tokens
}
pub fn total_tokens(messages: &[AgentMessage]) -> usize {
messages.iter().map(message_tokens).sum()
}
/// Estimates the current context size.
///
/// It anchors on the most recent non-zero usage report from the provider. Heuristic
/// estimates are then added for every message after the anchored index.
pub struct ContextTracker {
    /// Sum of input, output and cache token counts from the last non-zero report.
    last_usage_tokens: Option<usize>,
    /// `message_index` passed alongside that report.
    last_usage_index: Option<usize>,
}

impl ContextTracker {
    /// Creates a tracker with no recorded usage.
    ///
    /// Until usage is recorded, estimates fall back to [`total_tokens`].
    pub fn new() -> Self {
        Self {
            last_usage_index: None,
            last_usage_tokens: None,
        }
    }

    /// Records a usage report for the message at `message_index`.
    ///
    /// The stored total is `input + output + cache_read + cache_write`. A report that
    /// sums to zero is ignored, so any earlier anchor stays in place.
    pub fn record_usage(&mut self, usage: &Usage, message_index: usize) {
        let total = usage.input + usage.output + usage.cache_read + usage.cache_write;
        if total > 0 {
            (self.last_usage_tokens, self.last_usage_index) =
                (Some(total as usize), Some(message_index));
        }
    }

    /// Estimates the token count of `messages`.
    ///
    /// When an anchor is recorded and its index lies inside `messages`, returns the
    /// reported total plus estimates for every message after the anchor. Otherwise
    /// the whole slice is estimated with [`total_tokens`].
    ///
    /// NOTE(review): the anchor is a raw index. Callers presumably must `reset` after
    /// the message list is compacted or reordered — confirm.
    pub fn estimate_context_tokens(&self, messages: &[AgentMessage]) -> usize {
        let (Some(reported), Some(anchor)) = (self.last_usage_tokens, self.last_usage_index)
        else {
            return total_tokens(messages);
        };
        if anchor >= messages.len() {
            return total_tokens(messages);
        }
        reported + total_tokens(&messages[anchor + 1..])
    }

    /// Forgets any recorded usage anchor.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}

impl Default for ContextTracker {
    /// Same as [`ContextTracker::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Token-budget settings that drive [`compact_messages`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextConfig {
    /// Overall token ceiling; the system-prompt reservation is subtracted from it.
    pub max_context_tokens: usize,
    /// Tokens reserved for the system prompt.
    pub system_prompt_tokens: usize,
    /// Number of trailing messages kept verbatim by the summarize and drop levels.
    pub keep_recent: usize,
    /// Number of leading messages kept when the middle of the history is dropped.
    pub keep_first: usize,
    /// Maximum lines (head + tail) kept per tool-result text block.
    pub tool_output_max_lines: usize,
}

impl Default for ContextConfig {
    fn default() -> Self {
        Self {
            keep_first: 2,
            keep_recent: 10,
            max_context_tokens: 100_000,
            system_prompt_tokens: 4_000,
            tool_output_max_lines: 50,
        }
    }
}

impl ContextConfig {
    /// Builds a config whose `max_context_tokens` is 80% of `context_window`, rounded down.
    ///
    /// Every other field comes from [`ContextConfig::default`].
    pub fn from_context_window(context_window: u32) -> Self {
        let window = context_window as usize;
        Self {
            max_context_tokens: window * 80 / 100,
            ..Self::default()
        }
    }
}
/// Pluggable policy for shrinking a conversation to fit a [`ContextConfig`] budget.
pub trait CompactionStrategy: Send + Sync {
    /// Returns the messages to use in place of `messages`, given the budget in `config`.
    fn compact(&self, messages: Vec<AgentMessage>, config: &ContextConfig) -> Vec<AgentMessage>;
}

/// Built-in strategy that delegates to [`compact_messages`].
///
/// It is a stateless unit struct. `Debug`/`Clone`/`Copy`/`Default` are derived so it
/// can be printed, duplicated, and constructed generically like other config values.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultCompaction;

impl CompactionStrategy for DefaultCompaction {
    fn compact(&self, messages: Vec<AgentMessage>, config: &ContextConfig) -> Vec<AgentMessage> {
        compact_messages(messages, config)
    }
}
/// Shrinks `messages` toward `max_context_tokens - system_prompt_tokens` (saturating).
///
/// Progressively more aggressive stages are tried, returning the first result that fits:
///
/// 1. truncate long tool-result text to its head and tail lines;
/// 2. summarize assistant turns older than the last `keep_recent` messages;
/// 3. keep the first/last messages and drop the middle (the final fallback).
///
/// Input that already fits is returned unchanged.
pub fn compact_messages(messages: Vec<AgentMessage>, config: &ContextConfig) -> Vec<AgentMessage> {
    let budget = config
        .max_context_tokens
        .saturating_sub(config.system_prompt_tokens);
    let fits = |candidate: &[AgentMessage]| total_tokens(candidate) <= budget;

    if fits(messages.as_slice()) {
        return messages;
    }

    let truncated = level1_truncate_tool_outputs(&messages, config.tool_output_max_lines);
    if fits(truncated.as_slice()) {
        return truncated;
    }

    let summarized = level2_summarize_old_turns(&truncated, config.keep_recent);
    if fits(summarized.as_slice()) {
        return summarized;
    }

    level3_drop_middle(&summarized, config, budget)
}
/// Level 1: limits every text block of every tool result to `max_lines` lines.
///
/// Truncation keeps the head and tail of each block (see `truncate_text_head_tail`).
/// Non-text blocks and all other messages are cloned unchanged.
fn level1_truncate_tool_outputs(messages: &[AgentMessage], max_lines: usize) -> Vec<AgentMessage> {
    let mut out = Vec::with_capacity(messages.len());
    for msg in messages {
        let AgentMessage::Llm(Message::ToolResult {
            tool_call_id,
            tool_name,
            content,
            is_error,
            timestamp,
        }) = msg
        else {
            out.push(msg.clone());
            continue;
        };

        let shortened: Vec<Content> = content
            .iter()
            .map(|block| {
                if let Content::Text { text } = block {
                    Content::Text {
                        text: truncate_text_head_tail(text, max_lines),
                    }
                } else {
                    block.clone()
                }
            })
            .collect();

        out.push(AgentMessage::Llm(Message::ToolResult {
            tool_call_id: tool_call_id.clone(),
            tool_name: tool_name.clone(),
            content: shortened,
            is_error: *is_error,
            timestamp: *timestamp,
        }));
    }
    out
}
/// Keeps roughly the first and last `max_lines / 2` lines of `text`.
///
/// The tail gets the extra line when `max_lines` is odd. The omitted middle is replaced
/// by a `[... N lines truncated ...]` marker surrounded by blank lines.
///
/// Text with at most `max_lines` lines is returned byte-for-byte unchanged. Lines are
/// split with `str::lines`, so a truncated result does not keep a trailing newline.
fn truncate_text_head_tail(text: &str, max_lines: usize) -> String {
    let all: Vec<&str> = text.lines().collect();
    let total = all.len();
    if total <= max_lines {
        return text.to_owned();
    }

    let keep_head = max_lines / 2;
    let keep_tail = max_lines - keep_head;
    let dropped = total - keep_head - keep_tail;

    let (head_part, rest) = all.split_at(keep_head);
    let tail_part = &rest[rest.len() - keep_tail..];

    format!(
        "{}\n\n[... {} lines truncated ...]\n\n{}",
        head_part.join("\n"),
        dropped,
        tail_part.join("\n")
    )
}
fn level2_summarize_old_turns(messages: &[AgentMessage], keep_recent: usize) -> Vec<AgentMessage> {
let len = messages.len();
if len <= keep_recent {
return messages.to_vec();
}
let boundary = len - keep_recent;
let mut result = Vec::new();
let mut i = 0;
while i < boundary {
let msg = &messages[i];
match msg {
AgentMessage::Llm(Message::Assistant { content, .. }) => {
let text_parts: Vec<&str> = content
.iter()
.filter_map(|c| match c {
Content::Text { text } => {
if text.len() > 200 {
None } else {
Some(text.as_str())
}
}
_ => None,
})
.collect();
let tool_count = content
.iter()
.filter(|c| matches!(c, Content::ToolCall { .. }))
.count();
let summary = if !text_parts.is_empty() {
text_parts.join(" ")
} else if tool_count > 0 {
format!("[Assistant used {} tool(s)]", tool_count)
} else {
"[Assistant response]".into()
};
result.push(AgentMessage::Llm(Message::User {
content: vec![Content::Text {
text: format!("[Summary] {}", summary),
}],
timestamp: now_ms(),
}));
i += 1;
while i < boundary {
if let AgentMessage::Llm(Message::ToolResult { .. }) = &messages[i] {
i += 1;
} else {
break;
}
}
continue;
}
AgentMessage::Llm(Message::ToolResult { .. }) => {
i += 1;
continue;
}
other => {
result.push(other.clone());
}
}
i += 1;
}
result.extend_from_slice(&messages[boundary..]);
result
}
/// Level 3: keeps the first `keep_first` and last `keep_recent` messages and replaces
/// the middle with a single `[Context compacted: …]` marker.
///
/// Both cut points are adjusted so an assistant message is never separated from the
/// tool results that follow it:
/// - the head is extended over any tool results right after it;
/// - the tail start is moved earlier while it would begin with a tool result.
///
/// Without this the output contains tool calls without results, or results without
/// calls. Falls back to `keep_within_budget` when head and tail meet, or when the
/// spliced result still exceeds `budget`.
fn level3_drop_middle(
    messages: &[AgentMessage],
    config: &ContextConfig,
    budget: usize,
) -> Vec<AgentMessage> {
    let is_tool_result =
        |msg: &AgentMessage| matches!(msg, AgentMessage::Llm(Message::ToolResult { .. }));
    let len = messages.len();

    let mut first_end = config.keep_first.min(len);
    // An empty head (`keep_first == 0`) has no assistant call to protect.
    while first_end > 0 && first_end < len && is_tool_result(&messages[first_end]) {
        first_end += 1;
    }

    let mut recent_start = len.saturating_sub(config.keep_recent);
    // `recent_start < len` guards `keep_recent == 0` (empty tail, nothing to index).
    while recent_start > first_end && recent_start < len && is_tool_result(&messages[recent_start])
    {
        recent_start -= 1;
    }

    if first_end >= recent_start {
        return keep_within_budget(messages, budget);
    }

    let removed = recent_start - first_end;
    let marker = AgentMessage::Llm(Message::User {
        content: vec![Content::Text {
            text: format!(
                "[Context compacted: {} messages removed to fit context window]",
                removed
            ),
        }],
        timestamp: now_ms(),
    });

    let mut result = Vec::with_capacity(first_end + 1 + (len - recent_start));
    result.extend_from_slice(&messages[..first_end]);
    result.push(marker);
    result.extend_from_slice(&messages[recent_start..]);

    if total_tokens(&result) > budget {
        return keep_within_budget(&result, budget);
    }
    result
}
/// Keeps the longest suffix of `messages` whose estimated tokens fit in `budget`.
///
/// When anything is dropped, the suffix is prefixed with a
/// `[Context compacted: N messages removed]` marker.
///
/// If the cut would leave tool results at the front of the kept suffix, those are
/// dropped as well. The assistant message that issued the matching tool calls lies
/// before the cut, so the results would be orphaned. The marker's own tokens are not
/// counted against `budget`.
fn keep_within_budget(messages: &[AgentMessage], budget: usize) -> Vec<AgentMessage> {
    // Count how many trailing messages fit, newest first.
    let mut remaining = budget;
    let mut kept = 0;
    for msg in messages.iter().rev() {
        let tokens = message_tokens(msg);
        if tokens > remaining {
            break;
        }
        remaining -= tokens;
        kept += 1;
    }

    let mut start = messages.len() - kept;
    // Only strip tool results orphaned by our own cut. Leading tool results in an
    // input that fits entirely are left alone.
    if start > 0 {
        while start < messages.len()
            && matches!(&messages[start], AgentMessage::Llm(Message::ToolResult { .. }))
        {
            start += 1;
        }
    }

    if start == 0 {
        return messages.to_vec();
    }

    let mut result = Vec::with_capacity(messages.len() - start + 1);
    result.push(AgentMessage::Llm(Message::User {
        content: vec![Content::Text {
            text: format!("[Context compacted: {} messages removed]", start),
        }],
        timestamp: now_ms(),
    }));
    result.extend_from_slice(&messages[start..]);
    result
}
/// Hard caps on a single agent run, enforced by `ExecutionTracker::check_limits`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionLimits {
    /// Maximum number of recorded turns.
    pub max_turns: usize,
    /// Maximum cumulative tokens across all recorded turns.
    pub max_total_tokens: usize,
    /// Maximum wall-clock time since the tracker was created.
    pub max_duration: std::time::Duration,
}

impl Default for ExecutionLimits {
    /// 50 turns, one million tokens, ten minutes.
    fn default() -> Self {
        const TEN_MINUTES: std::time::Duration = std::time::Duration::from_secs(10 * 60);
        Self {
            max_duration: TEN_MINUTES,
            max_total_tokens: 1_000_000,
            max_turns: 50,
        }
    }
}
/// Counts turns, tokens and elapsed time for one run and compares them to
/// [`ExecutionLimits`].
pub struct ExecutionTracker {
    pub limits: ExecutionLimits,
    pub turns: usize,
    pub tokens_used: usize,
    pub started_at: std::time::Instant,
}

impl ExecutionTracker {
    /// Starts tracking now, with zero turns and zero tokens used.
    pub fn new(limits: ExecutionLimits) -> Self {
        let started_at = std::time::Instant::now();
        Self {
            limits,
            started_at,
            tokens_used: 0,
            turns: 0,
        }
    }

    /// Records one completed turn that consumed `tokens` tokens.
    pub fn record_turn(&mut self, tokens: usize) {
        self.tokens_used += tokens;
        self.turns += 1;
    }

    /// Returns a description of the first exceeded limit, or `None` if all are satisfied.
    ///
    /// Limits are checked in order: turns, then tokens, then wall-clock duration. Each
    /// limit counts as reached once the value is greater than or equal to the cap.
    pub fn check_limits(&self) -> Option<String> {
        let limits = &self.limits;
        if self.turns >= limits.max_turns {
            Some(format!(
                "Max turns reached ({}/{})",
                self.turns, limits.max_turns
            ))
        } else if self.tokens_used >= limits.max_total_tokens {
            Some(format!(
                "Max tokens reached ({}/{})",
                self.tokens_used, limits.max_total_tokens
            ))
        } else {
            let elapsed = self.started_at.elapsed();
            (elapsed >= limits.max_duration).then(|| {
                format!(
                    "Max duration reached ({:.0}s/{:.0}s)",
                    elapsed.as_secs_f64(),
                    limits.max_duration.as_secs_f64()
                )
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` newline-joined lines of the form `"{prefix} {n}"`, with `n` starting at 1.
    fn numbered_lines(prefix: &str, count: usize) -> String {
        (1..=count)
            .map(|n| format!("{} {}", prefix, n))
            .collect::<Vec<_>>()
            .join("\n")
    }

    #[test]
    fn test_estimate_tokens() {
        let estimate = estimate_tokens("hello world");
        assert!(estimate > 0);
        assert!(estimate < 10);
        assert_eq!(estimate_tokens(""), 0);
    }

    #[test]
    fn test_context_config_from_context_window() {
        let config = ContextConfig::from_context_window(200_000);
        assert_eq!(config.max_context_tokens, 160_000);
        assert_eq!(config.system_prompt_tokens, 4_000);
        assert_eq!(config.keep_recent, 10);

        assert_eq!(
            ContextConfig::from_context_window(1_000_000).max_context_tokens,
            800_000
        );
        assert_eq!(
            ContextConfig::from_context_window(128_000).max_context_tokens,
            102_400
        );
    }

    #[test]
    fn test_truncate_head_tail() {
        let text = numbered_lines("line", 100);
        let result = truncate_text_head_tail(&text, 10);
        assert!(result.contains("line 1"));
        assert!(result.contains("line 5"));
        assert!(result.contains("line 100"));
        assert!(result.contains("truncated"));
        assert!(!result.contains("line 50"));
    }

    #[test]
    fn test_level1_truncation() {
        let messages = vec![
            AgentMessage::Llm(Message::user("do something")),
            AgentMessage::Llm(Message::ToolResult {
                tool_call_id: "tc-1".into(),
                tool_name: "bash".into(),
                content: vec![Content::Text {
                    text: numbered_lines("output line", 200),
                }],
                is_error: false,
                timestamp: 0,
            }),
        ];

        let compacted = level1_truncate_tool_outputs(&messages, 20);

        let AgentMessage::Llm(Message::ToolResult { content, .. }) = &compacted[1] else {
            panic!("expected tool result");
        };
        let Content::Text { text } = &content[0] else {
            panic!("expected text content");
        };
        assert!(text.contains("truncated"));
        assert!(text.contains("output line 1"));
        assert!(text.contains("output line 200"));
        assert!(text.lines().count() < 50);
    }

    #[test]
    fn test_compact_within_budget() {
        let messages = vec![
            AgentMessage::Llm(Message::user("Hello")),
            AgentMessage::Llm(Message::user("World")),
        ];
        let result = compact_messages(messages, &ContextConfig::default());
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_compact_drops_middle_when_needed() {
        let messages: Vec<AgentMessage> = (0..100)
            .map(|i| {
                AgentMessage::Llm(Message::user(format!(
                    "Message {} {}",
                    i,
                    "x".repeat(200)
                )))
            })
            .collect();
        let config = ContextConfig {
            max_context_tokens: 500,
            system_prompt_tokens: 100,
            keep_recent: 5,
            keep_first: 2,
            tool_output_max_lines: 20,
        };

        let result = compact_messages(messages, &config);

        assert!(result.len() < 100);
        assert!(result.len() >= 2);
    }

    #[test]
    fn test_context_tracker_no_usage() {
        let tracker = ContextTracker::new();
        let messages = vec![
            AgentMessage::Llm(Message::user("Hello")),
            AgentMessage::Llm(Message::user("World")),
        ];

        let estimate = tracker.estimate_context_tokens(&messages);

        assert!(estimate > 0);
        assert_eq!(estimate, total_tokens(&messages));
    }

    #[test]
    fn test_context_tracker_with_usage() {
        let mut tracker = ContextTracker::new();
        let messages = vec![
            AgentMessage::Llm(Message::user("Hello")),
            AgentMessage::Llm(Message::Assistant {
                content: vec![Content::Text {
                    text: "Hi there!".into(),
                }],
                stop_reason: StopReason::Stop,
                model: "test".into(),
                provider: "test".into(),
                usage: Usage {
                    input: 100,
                    output: 50,
                    ..Default::default()
                },
                timestamp: 0,
                error_message: None,
            }),
            AgentMessage::Llm(Message::user("Follow up question here")),
        ];

        let reported = Usage {
            input: 100,
            output: 50,
            ..Default::default()
        };
        tracker.record_usage(&reported, 1);

        let estimate = tracker.estimate_context_tokens(&messages);
        let trailing_estimate = message_tokens(&messages[2]);
        assert_eq!(estimate, 150 + trailing_estimate);
    }

    #[test]
    fn test_context_tracker_reset() {
        let mut tracker = ContextTracker::new();
        let reported = Usage {
            input: 1000,
            output: 500,
            ..Default::default()
        };
        tracker.record_usage(&reported, 5);
        tracker.reset();

        let messages = vec![AgentMessage::Llm(Message::user("test"))];
        assert_eq!(
            tracker.estimate_context_tokens(&messages),
            total_tokens(&messages)
        );
    }

    #[test]
    fn test_execution_limits() {
        let limits = ExecutionLimits {
            max_turns: 3,
            max_total_tokens: 1000,
            max_duration: std::time::Duration::from_secs(60),
        };
        let mut tracker = ExecutionTracker::new(limits);
        assert!(tracker.check_limits().is_none());

        tracker.record_turn(100);
        tracker.record_turn(100);
        assert!(tracker.check_limits().is_none());

        tracker.record_turn(100);
        assert!(tracker.check_limits().is_some());
    }
}