use crate::config::CompactionConfig;
use crate::providers::ToolDefinition;
use crate::session::{ContentPart, Message, Role};
/// Fixed per-message overhead added to every message estimate
/// (presumably role / separator tokens of the chat framing).
const MESSAGE_FRAMING_TOKENS: usize = 4;
/// Flat token charge for each image content part, independent of image size.
const IMAGE_TOKEN_ESTIMATE: usize = 2_000;
/// Bytes per token for tool results, tool-call arguments and tool schemas.
/// Smaller than [`CHARS_PER_TOKEN`], so such content is weighted heavier.
const TOOL_RESULT_CHARS_PER_TOKEN: f64 = 2.0;
/// Bytes per token assumed for ordinary text (divisor applied to `str::len`).
const CHARS_PER_TOKEN: f64 = 4.0;
/// How the conversation history should be reduced, as recommended by
/// `ContextMonitor::suggest_strategy`.
///
/// `keep_recent` is the number of most recent messages the compaction step is
/// expected to leave intact; how it is applied is up to the consumer of this
/// value (not shown in this module).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompactionStrategy {
    /// Usage is within the threshold; nothing needs to be done.
    None,
    /// Moderate pressure: summarize older history.
    Summarize { keep_recent: usize },
    /// Emergency / critical pressure: drop older history outright.
    Truncate { keep_recent: usize },
}
/// Severity tier of context-window pressure, ordered from least to most urgent.
///
/// Tiers are selected by comparing the estimated usage ratio against the
/// monitor's `threshold`, `emergency_threshold` and `critical_threshold`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompactionUrgency {
    /// Above `threshold` but below `emergency_threshold`.
    Normal,
    /// At or above `emergency_threshold` but below `critical_threshold`.
    Emergency,
    /// At or above `critical_threshold`.
    Critical,
}
/// Outcome of `ContextMonitor::preflight_check`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PreflightAction {
    /// Nothing was modified and the request fits.
    Ok,
    /// Tool results were trimmed or compacted in place and the request now fits.
    Trimmed,
    /// Even after in-place trimming the request is over the limit; the caller
    /// must compact the history.
    NeedsCompaction,
}
/// Estimates conversation size against a model's context window and decides
/// when, and how aggressively, history should be trimmed or compacted.
///
/// Every `*_threshold`, `*_ratio` and `*_share` field is a fraction of
/// `context_limit`.
#[derive(Debug, Clone)]
pub struct ContextMonitor {
    /// Size of the context window, in (estimated) tokens.
    context_limit: usize,
    /// Usage ratio above which compaction is recommended.
    threshold: f64,
    /// Usage ratio at which urgency becomes `Emergency`; also caps preflight.
    emergency_threshold: f64,
    /// Usage ratio at which urgency becomes `Critical`.
    critical_threshold: f64,
    /// Share of the window available for input (the preflight budget).
    input_headroom_ratio: f64,
    /// Largest share of the window a single tool result may occupy.
    single_tool_result_share: f64,
    /// Multiplier applied to raw token estimates to err on the high side.
    safety_margin: f64,
}
impl ContextMonitor {
pub fn new(context_limit: usize, threshold: f64) -> Self {
Self::new_with_thresholds(context_limit, threshold, 0.90, 0.95)
}
pub fn new_with_thresholds(
context_limit: usize,
threshold: f64,
emergency_threshold: f64,
critical_threshold: f64,
) -> Self {
Self {
context_limit,
threshold,
emergency_threshold,
critical_threshold,
input_headroom_ratio: 0.75,
single_tool_result_share: 0.50,
safety_margin: 1.2,
}
}
pub fn from_config(config: &CompactionConfig) -> Self {
Self {
context_limit: config.context_limit,
threshold: config.threshold,
emergency_threshold: config.emergency_threshold,
critical_threshold: config.critical_threshold,
input_headroom_ratio: config.input_headroom_ratio,
single_tool_result_share: config.single_tool_result_share,
safety_margin: config.safety_margin,
}
}
pub fn estimate_tokens(messages: &[Message]) -> usize {
Self::estimate_tokens_with_margin(messages, 1.2)
}
pub fn estimate_tokens_with_margin(messages: &[Message], safety_margin: f64) -> usize {
let raw: f64 = messages.iter().map(Self::estimate_message_tokens).sum();
(raw * safety_margin) as usize
}
fn estimate_message_tokens(msg: &Message) -> f64 {
let is_tool_result = msg.role == Role::Tool;
let cpt = if is_tool_result {
TOOL_RESULT_CHARS_PER_TOKEN
} else {
CHARS_PER_TOKEN
};
let content_tokens = msg.content.len() as f64 / cpt;
let parts_tokens: f64 = if msg.content_parts.is_empty() {
0.0
} else {
msg.content_parts
.iter()
.map(|part| match part {
ContentPart::Text { text } => text.len() as f64 / cpt,
ContentPart::Image { .. } => IMAGE_TOKEN_ESTIMATE as f64,
})
.sum()
};
let body_tokens = content_tokens.max(parts_tokens);
let tool_call_tokens: f64 = msg
.tool_calls
.as_ref()
.map(|calls| {
calls
.iter()
.map(|tc| {
tc.name.len() as f64 / CHARS_PER_TOKEN
+ tc.arguments.len() as f64 / TOOL_RESULT_CHARS_PER_TOKEN
})
.sum()
})
.unwrap_or(0.0);
body_tokens + tool_call_tokens + MESSAGE_FRAMING_TOKENS as f64
}
pub fn estimate_tokens_full(
messages: &[Message],
tool_definitions: &[ToolDefinition],
safety_margin: f64,
) -> usize {
let msg_tokens: f64 = messages.iter().map(Self::estimate_message_tokens).sum();
let tool_def_tokens: f64 = tool_definitions
.iter()
.map(|td| {
let schema = td.parameters.to_string();
(td.name.len() + td.description.len()) as f64 / CHARS_PER_TOKEN
+ schema.len() as f64 / TOOL_RESULT_CHARS_PER_TOKEN
})
.sum();
((msg_tokens + tool_def_tokens) * safety_margin) as usize
}
pub fn context_budget(&self) -> usize {
(self.context_limit as f64 * self.input_headroom_ratio) as usize
}
pub fn needs_compaction(&self, messages: &[Message]) -> bool {
let estimated = Self::estimate_tokens_with_margin(messages, self.safety_margin);
estimated as f64 > self.threshold * self.context_limit as f64
}
pub fn urgency(&self, messages: &[Message]) -> Option<CompactionUrgency> {
let estimated = Self::estimate_tokens_with_margin(messages, self.safety_margin);
let ratio = estimated as f64 / self.context_limit as f64;
if ratio <= self.threshold {
None
} else if ratio >= self.critical_threshold {
Some(CompactionUrgency::Critical)
} else if ratio >= self.emergency_threshold {
Some(CompactionUrgency::Emergency)
} else {
Some(CompactionUrgency::Normal)
}
}
pub fn suggest_strategy(&self, messages: &[Message]) -> CompactionStrategy {
let estimated = Self::estimate_tokens_with_margin(messages, self.safety_margin);
let ratio = estimated as f64 / self.context_limit as f64;
match self.urgency(messages) {
None => CompactionStrategy::None,
Some(CompactionUrgency::Critical) => CompactionStrategy::Truncate { keep_recent: 3 },
Some(CompactionUrgency::Emergency) => CompactionStrategy::Truncate { keep_recent: 5 },
Some(CompactionUrgency::Normal) => {
if ratio > 0.85 {
CompactionStrategy::Summarize { keep_recent: 5 }
} else {
CompactionStrategy::Summarize { keep_recent: 8 }
}
}
}
}
pub fn preflight_check(
&self,
messages: &mut [Message],
tool_definitions: &[ToolDefinition],
) -> PreflightAction {
let budget = self.context_budget();
let single_result_cap_tokens =
(self.context_limit as f64 * self.single_tool_result_share) as usize;
let single_result_cap_chars =
(single_result_cap_tokens as f64 * TOOL_RESULT_CHARS_PER_TOKEN) as usize;
let mut trimmed = false;
for msg in messages.iter_mut() {
if msg.role == Role::Tool && msg.content.len() > single_result_cap_chars {
let original_len = msg.content.len();
let head_budget = single_result_cap_chars * 7 / 10; let tail_budget = single_result_cap_chars * 3 / 10;
let head = safe_truncate(&msg.content, head_budget);
let tail = safe_truncate_tail(&msg.content, tail_budget);
let truncated = original_len - head.len() - tail.len();
msg.content = format!("{head}\n...[truncated {truncated} bytes]...\n{tail}");
if !msg.content_parts.is_empty() {
msg.content_parts = vec![ContentPart::Text {
text: msg.content.clone(),
}];
}
trimmed = true;
}
}
let total = Self::estimate_tokens_with_margin(messages, self.safety_margin);
if total > budget {
let tool_indices: Vec<usize> = messages
.iter()
.enumerate()
.filter(|(_, m)| m.role == Role::Tool && m.content != "[compacted]")
.map(|(i, _)| i)
.collect();
for idx in tool_indices {
messages[idx].content = "[compacted]".to_string();
if !messages[idx].content_parts.is_empty() {
messages[idx].content_parts = vec![ContentPart::Text {
text: "[compacted]".to_string(),
}];
}
trimmed = true;
let new_total = Self::estimate_tokens_with_margin(messages, self.safety_margin);
if new_total <= budget {
break;
}
}
}
let final_total =
Self::estimate_tokens_full(messages, tool_definitions, self.safety_margin);
let hard_limit = (self.context_limit as f64 * self.emergency_threshold) as usize;
let compaction_limit = budget.min(hard_limit);
if final_total > compaction_limit {
return PreflightAction::NeedsCompaction;
}
if trimmed {
PreflightAction::Trimmed
} else {
PreflightAction::Ok
}
}
}
/// Returns the longest prefix of `s` that is at most `max_bytes` long and ends
/// on a UTF-8 character boundary (so slicing never panics mid-character).
fn safe_truncate(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    // Walk backwards from the byte limit to the nearest boundary; index 0 is
    // always a boundary, so the fallback is never actually taken.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
/// Returns the longest suffix of `s` that is at most `max_bytes` long and
/// starts on a UTF-8 character boundary; may be empty.
fn safe_truncate_tail(s: &str, max_bytes: usize) -> &str {
    let total = s.len();
    if total <= max_bytes {
        return s;
    }
    // Walk forward from the earliest allowed start to the first boundary; if
    // none exists before the end, the suffix is empty.
    let begin = (total - max_bytes..total)
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(total);
    &s[begin..]
}
/// Defaults for a 180k-token window.
///
/// Compaction is recommended above 70% usage, becomes an emergency at 90% and
/// critical at 95%. 75% of the window is available for input, one tool result
/// may use half the window, and estimates carry a 1.2 safety margin.
impl Default for ContextMonitor {
    fn default() -> Self {
        Self {
            safety_margin: 1.2,
            single_tool_result_share: 0.50,
            input_headroom_ratio: 0.75,
            critical_threshold: 0.95,
            emergency_threshold: 0.90,
            threshold: 0.70,
            context_limit: 180_000,
        }
    }
}
// Unit tests for token estimation, urgency tiers, strategy selection and the
// preflight trimming/compaction passes. Expected numbers are hand-derived from
// `estimate_message_tokens`:
// - text costs bytes / 4 (bytes / 2 for tool-role messages);
// - each message adds 4 framing tokens;
// - the sum is multiplied by the safety margin and truncated to an integer.
#[cfg(test)]
mod tests {
use super::*;
use crate::session::ToolCall;
// Builds a plain user message.
fn make_message(content: &str) -> Message {
Message::user(content)
}
// Estimate for a single message with no safety margin (margin 1.0).
fn raw_estimate(msg: &Message) -> usize {
ContextMonitor::estimate_tokens_with_margin(&[msg.clone()], 1.0)
}
// No messages -> no framing charged -> 0.
#[test]
fn test_estimate_tokens_empty_messages() {
let messages: Vec<Message> = vec![];
assert_eq!(
ContextMonitor::estimate_tokens_with_margin(&messages, 1.0),
0
);
}
// "Hello world" = 11 bytes: 11 / 4 = 2.75, + 4 framing = 6.75, truncated to 6.
#[test]
fn test_estimate_tokens_single_message() {
let messages = vec![make_message("Hello world")];
assert_eq!(
ContextMonitor::estimate_tokens_with_margin(&messages, 1.0),
6
);
}
// Empty content still pays the 4 framing tokens.
#[test]
fn test_estimate_tokens_empty_content() {
let messages = vec![make_message("")];
assert_eq!(
ContextMonitor::estimate_tokens_with_margin(&messages, 1.0),
4
);
}
// Default margin 1.2: 6.75 * 1.2 = 8.1, truncated to 8.
#[test]
fn test_estimate_tokens_with_safety_margin() {
let messages = vec![make_message("Hello world")];
assert_eq!(ContextMonitor::estimate_tokens(&messages), 8);
}
// Tool-role content is charged at 2 bytes/token instead of 4, so identical
// text costs more as a tool result.
// NOTE(review): "a]" looks like a typo for "a"; any 200-byte string works, so
// the outcome is unaffected.
#[test]
fn test_tool_result_weighted_heavier() {
let text = "a]".repeat(100); let user_msg = Message::user(&text);
let tool_msg = Message::tool_result("call_1", &text);
let user_tokens = raw_estimate(&user_msg);
let tool_tokens = raw_estimate(&tool_msg);
assert!(
tool_tokens > user_tokens,
"Tool results should be weighted heavier: tool={tool_tokens} user={user_tokens}"
);
}
// Tool-call name and JSON arguments add to the assistant message's estimate.
#[test]
fn test_tool_call_arguments_counted() {
let mut msg = Message::assistant("ok");
msg.tool_calls = Some(vec![ToolCall::new(
"call_1",
"shell",
r#"{"command": "ls -la /very/long/path/to/something"}"#,
)]);
let with_calls = raw_estimate(&msg);
let plain = raw_estimate(&Message::assistant("ok"));
assert!(
with_calls > plain,
"Tool call arguments should add tokens: with={with_calls} plain={plain}"
);
}
// Tool definitions (name, description, JSON schema) count toward the full
// request estimate.
#[test]
fn test_estimate_tokens_full_includes_tool_defs() {
let messages = vec![make_message("Hello")];
let tool_defs = vec![ToolDefinition::new(
"shell",
"Execute a shell command on the system",
serde_json::json!({
"type": "object",
"properties": {
"command": {"type": "string", "description": "The command to run"}
},
"required": ["command"]
}),
)];
let without = ContextMonitor::estimate_tokens_full(&messages, &[], 1.0);
let with = ContextMonitor::estimate_tokens_full(&messages, &tool_defs, 1.0);
assert!(
with > without,
"Tool definitions should add tokens: with={with} without={without}"
);
}
// 5 x "hello": (1.25 + 4) * 5 * 1.2 = 31.5 -> 31 of 1000, well below 70%.
// 20 x 200 bytes: (50 + 4) * 20 * 1.2 = 1296 -> ratio 1.296 >= 0.95 (critical).
#[test]
fn test_urgency_tiers() {
let monitor = ContextMonitor::new_with_thresholds(1000, 0.70, 0.90, 0.95);
let small: Vec<Message> = (0..5).map(|_| make_message("hello")).collect();
assert_eq!(monitor.urgency(&small), None);
let text = "x".repeat(200);
let critical: Vec<Message> = (0..20).map(|_| make_message(&text)).collect();
let est = ContextMonitor::estimate_tokens(&critical);
assert!(est > 950, "Expected >950, got {est}");
assert_eq!(
monitor.urgency(&critical),
Some(CompactionUrgency::Critical)
);
}
// (5 / 4 + 4) * 1.2 = 6.3 -> 6, far below 0.80 * 10_000 = 8000.
#[test]
fn test_needs_compaction_below_threshold() {
let monitor = ContextMonitor::new(10_000, 0.80);
let messages = vec![make_message("Hello")];
assert!(!monitor.needs_compaction(&messages));
}
// (300 / 4 + 4) * 1.2 = 94.8 -> 94 > 0.80 * 100 = 80.
#[test]
fn test_needs_compaction_above_threshold() {
let monitor = ContextMonitor::new(100, 0.80);
let messages = vec![make_message(&"x".repeat(300))];
assert!(monitor.needs_compaction(&messages));
}
#[test]
fn test_strategy_below_threshold() {
let monitor = ContextMonitor::new(100_000, 0.80);
let messages = vec![make_message("Hello world")];
assert_eq!(
monitor.suggest_strategy(&messages),
CompactionStrategy::None
);
}
// (2700 / 4 + 4) * 1.2 = 814.8 -> 814: ratio 0.814 is over 0.80 but not over
// 0.85, so Normal urgency summarizes keeping 8.
#[test]
fn test_strategy_above_threshold_below_85() {
let monitor = ContextMonitor::new(1000, 0.80);
let messages = vec![make_message(&"x".repeat(2700))];
let est = ContextMonitor::estimate_tokens(&messages);
assert!(est > 800 && est < 850, "Expected 800-850, got {est}");
assert_eq!(
monitor.suggest_strategy(&messages),
CompactionStrategy::Summarize { keep_recent: 8 }
);
}
// (2900 / 4 + 4) * 1.2 = 874.8 -> 874: ratio 0.874 > 0.85 but below the 0.90
// emergency tier, so Normal urgency summarizes keeping 5.
#[test]
fn test_strategy_above_85() {
let monitor = ContextMonitor::new(1000, 0.80);
let messages = vec![make_message(&"x".repeat(2900))];
let est = ContextMonitor::estimate_tokens(&messages);
assert!(est > 850 && est < 900, "Expected 850-900, got {est}");
assert_eq!(
monitor.suggest_strategy(&messages),
CompactionStrategy::Summarize { keep_recent: 5 }
);
}
// (3200 / 4 + 4) * 1.2 = 964.8 -> 964 >= 0.95 * 1000 (critical tier from
// `new`), so truncate keeping 3.
#[test]
fn test_strategy_above_95() {
let monitor = ContextMonitor::new(1000, 0.80);
let messages = vec![make_message(&"x".repeat(3200))];
let est = ContextMonitor::estimate_tokens(&messages);
assert!(est > 950, "Expected >950, got {est}");
assert_eq!(
monitor.suggest_strategy(&messages),
CompactionStrategy::Truncate { keep_recent: 3 }
);
}
#[test]
fn test_empty_message_list_strategy() {
let monitor = ContextMonitor::new(100_000, 0.80);
assert_eq!(monitor.suggest_strategy(&[]), CompactionStrategy::None);
assert!(!monitor.needs_compaction(&[]));
}
#[test]
fn test_single_message_no_compaction() {
let monitor = ContextMonitor::new(100_000, 0.80);
let messages = vec![make_message("Just one message here")];
assert!(!monitor.needs_compaction(&messages));
assert_eq!(
monitor.suggest_strategy(&messages),
CompactionStrategy::None
);
}
// Limit 0.10 * 100 = 10 tokens. One "Hello world" = 6.75 * 1.2 -> 8 (fits);
// two = 13.5 * 1.2 -> 16 (over).
#[test]
fn test_custom_threshold() {
let monitor = ContextMonitor::new(100, 0.10);
let messages = vec![make_message("Hello world")];
assert!(!monitor.needs_compaction(&messages));
let messages = vec![make_message("Hello world"), make_message("Hello world")];
assert!(monitor.needs_compaction(&messages));
}
#[test]
fn test_default_values() {
let monitor = ContextMonitor::default();
let messages = vec![make_message("Hello"), make_message("World")];
assert!(!monitor.needs_compaction(&messages));
assert_eq!(
monitor.suggest_strategy(&messages),
CompactionStrategy::None
);
}
#[test]
fn test_preflight_ok_when_small() {
let monitor = ContextMonitor::default();
let mut messages = vec![make_message("Hello")];
assert_eq!(
monitor.preflight_check(&mut messages, &[]),
PreflightAction::Ok
);
}
// Per-result cap: 1000 * 0.50 = 500 tokens = 1000 bytes at 2 bytes/token, so
// the 5000-byte tool result must be cut by pass 1.
#[test]
fn test_preflight_trims_oversized_tool_result() {
let monitor = ContextMonitor::new(1000, 0.70);
let big_result = "x".repeat(5000); let mut messages = vec![
Message::user("hi"),
Message::tool_result("call_1", &big_result),
];
let action = monitor.preflight_check(&mut messages, &[]);
assert!(
matches!(action, PreflightAction::Trimmed | PreflightAction::Ok),
"Expected trimmed or ok, got {action:?}"
);
assert!(
messages[1].content.len() < 5000,
"Tool result should have been trimmed"
);
}
// Budget = 50_000 * 0.70 input headroom = 35_000.
#[test]
fn test_from_config() {
let config = CompactionConfig {
enabled: true,
context_limit: 50_000,
threshold: 0.60,
emergency_threshold: 0.85,
critical_threshold: 0.90,
input_headroom_ratio: 0.70,
single_tool_result_share: 0.40,
safety_margin: 1.3,
overflow_retries: 5,
};
let monitor = ContextMonitor::from_config(&config);
assert_eq!(monitor.context_budget(), 35_000); }
// Budget 200 * 0.75 = 150 tokens. Three tool results of roughly 200 tokens
// each stay far over budget after pass 1, so pass 2 must replace at least one
// with "[compacted]".
#[test]
fn test_preflight_pass2_compacts_oldest_tool_results() {
let config = CompactionConfig {
enabled: true,
context_limit: 200,
threshold: 0.70,
emergency_threshold: 0.90,
critical_threshold: 0.95,
input_headroom_ratio: 0.75,
single_tool_result_share: 0.90,
safety_margin: 1.0, overflow_retries: 3,
};
let monitor = ContextMonitor::from_config(&config);
assert_eq!(monitor.context_budget(), 150);
let mut messages = vec![
Message::tool_result("c1", &"a".repeat(400)),
Message::tool_result("c2", &"b".repeat(400)),
Message::tool_result("c3", &"c".repeat(400)),
];
let action = monitor.preflight_check(&mut messages, &[]);
let compacted_count = messages
.iter()
.filter(|m| m.content == "[compacted]")
.count();
assert!(
compacted_count > 0,
"Pass 2 should have compacted some tool results, got 0"
);
assert!(
matches!(
action,
PreflightAction::Trimmed | PreflightAction::NeedsCompaction
),
"Expected Trimmed or NeedsCompaction, got {action:?}"
);
}
// Relies on Message::tool_result populating content_parts (asserted below).
// When pass 2 compacts the message, its parts must mirror the placeholder.
#[test]
fn test_preflight_pass2_syncs_content_parts() {
let config = CompactionConfig {
enabled: true,
context_limit: 100,
input_headroom_ratio: 0.50,
single_tool_result_share: 0.90,
safety_margin: 1.0,
..Default::default()
};
let monitor = ContextMonitor::from_config(&config);
let mut messages = vec![Message::tool_result("c1", &"x".repeat(400))];
assert!(!messages[0].content_parts.is_empty());
monitor.preflight_check(&mut messages, &[]);
if messages[0].content == "[compacted]" {
assert_eq!(messages[0].content_parts.len(), 1);
if let ContentPart::Text { text } = &messages[0].content_parts[0] {
assert_eq!(text, "[compacted]");
}
}
}
// Pass 2 only compacts tool results, so an oversized user message cannot be
// shrunk: 200 / 4 + 4 = 54 tokens exceeds the 50 * 0.50 = 25-token budget.
#[test]
fn test_preflight_pass3_needs_compaction() {
let config = CompactionConfig {
enabled: true,
context_limit: 50,
input_headroom_ratio: 0.50,
single_tool_result_share: 0.90,
safety_margin: 1.0,
..Default::default()
};
let monitor = ContextMonitor::from_config(&config);
let mut messages = vec![Message::user(&"x".repeat(200))];
let action = monitor.preflight_check(&mut messages, &[]);
assert_eq!(action, PreflightAction::NeedsCompaction);
}
// Parts estimate: 7 / 4 text + 2000 image = 2001.75. That beats the 1.75 from
// content; + 4 framing -> 2005.
#[test]
fn test_estimate_tokens_with_image() {
use crate::session::ImageSource;
let mut msg = Message::user("caption");
msg.content_parts = vec![
ContentPart::Text {
text: "caption".to_string(),
},
ContentPart::Image {
source: ImageSource::Base64 {
data: "abc".to_string(),
},
media_type: "image/png".to_string(),
},
];
let tokens = raw_estimate(&msg);
assert!(
tokens > 2000,
"Image should contribute ~2000 tokens, got {tokens}"
);
}
// content "short" ~1.25 tokens vs parts 1000 / 4 = 250 tokens; the larger wins.
#[test]
fn test_estimate_uses_max_of_content_and_parts() {
let mut msg = Message::user("short");
msg.content_parts = vec![ContentPart::Text {
text: "x".repeat(1000),
}];
let tokens = raw_estimate(&msg);
assert!(
tokens > 200,
"Should use content_parts estimate when larger, got {tokens}"
);
}
// (2400 / 4 + 4) * 1.2 = 724.8 -> 724: ratio 0.724 lies between 0.70 and 0.90.
#[test]
fn test_urgency_normal_tier() {
let monitor = ContextMonitor::new_with_thresholds(1000, 0.70, 0.90, 0.95);
let messages = vec![make_message(&"x".repeat(2400))];
assert_eq!(monitor.urgency(&messages), Some(CompactionUrgency::Normal));
}
// (3050 / 4 + 4) * 1.2 = 919.8 -> 919: ratio 0.919 lies between 0.90 and 0.95.
#[test]
fn test_urgency_emergency_tier() {
let monitor = ContextMonitor::new_with_thresholds(1000, 0.70, 0.90, 0.95);
let messages = vec![make_message(&"x".repeat(3050))];
let est = ContextMonitor::estimate_tokens(&messages);
assert!(est >= 900 && est < 950, "Expected 900-950, got {est}");
assert_eq!(
monitor.urgency(&messages),
Some(CompactionUrgency::Emergency)
);
}
}