use crate::client::{ContentPart, Message, MessageContent};
/// Token budget derived from a model's context window.
pub struct BudgetConfig {
    /// Size of the whole context window, in tokens.
    pub total_context: u32,
    /// Tokens set aside for the model's reply.
    pub response_headroom: u32,
}

impl BudgetConfig {
    /// Builds a config for a window of `context_tokens`, reserving 15% of it
    /// (truncated toward zero) as response headroom.
    pub fn from_context_tokens(context_tokens: u32) -> Self {
        const HEADROOM_RATIO: f32 = 0.15;
        let reserved = context_tokens as f32 * HEADROOM_RATIO;
        Self {
            total_context: context_tokens,
            response_headroom: reserved as u32,
        }
    }

    /// Tokens left for the prompt once the response headroom is subtracted.
    ///
    /// Saturates at zero when the headroom exceeds the total context.
    pub fn usable(&self) -> u32 {
        let Self {
            total_context,
            response_headroom,
        } = *self;
        total_context.saturating_sub(response_headroom)
    }

    /// Prompt size above which compression starts: 80% of [`Self::usable`],
    /// truncated toward zero.
    pub fn compression_threshold(&self) -> u32 {
        const TRIGGER_RATIO: f32 = 0.80;
        let trigger = self.usable() as f32 * TRIGGER_RATIO;
        trigger as u32
    }
}
/// Rough token estimate for `s`: one token per four characters (Unicode
/// scalar values, not bytes), rounded down, plus a fixed overhead of 10.
pub fn estimate_tokens(s: &str) -> usize {
    const CHARS_PER_TOKEN: usize = 4;
    const FIXED_OVERHEAD: usize = 10;

    let char_count = s.chars().count();
    FIXED_OVERHEAD + char_count / CHARS_PER_TOKEN
}
/// Sums the per-message token estimates for every message in `messages`.
///
/// Returns 0 for an empty slice.
pub fn estimate_messages(messages: &[Message]) -> usize {
    messages
        .iter()
        .fold(0, |running, msg| running + estimate_message(msg))
}
/// Estimates the tokens of a single message.
///
/// Plain-text content is estimated directly. Multi-part content is the sum of
/// the estimates of each part's text (for `Text` parts) or result body (for
/// `ToolResult` parts), so every part contributes its own fixed overhead.
fn estimate_message(m: &Message) -> usize {
    let parts = match &m.content {
        MessageContent::Text(body) => return estimate_tokens(body),
        MessageContent::Parts(parts) => parts,
    };

    let mut total = 0usize;
    for part in parts.iter() {
        total += match part {
            ContentPart::Text { text } => estimate_tokens(text),
            ContentPart::ToolResult { content, .. } => estimate_tokens(content),
        };
    }
    total
}
/// Applies a [`BudgetConfig`] to a conversation, shrinking the message
/// history once its estimated size crosses the compression threshold.
pub struct Budget {
    config: BudgetConfig,
}

impl Budget {
    /// Creates a budget for a context window of `context_tokens` tokens.
    pub fn new(context_tokens: u32) -> Self {
        let config = BudgetConfig::from_context_tokens(context_tokens);
        Self { config }
    }

    /// Returns the total context window size this budget was built from.
    pub fn total_context(&self) -> u32 {
        self.config.total_context
    }

    /// Shrinks `messages` in place when the estimated prompt size
    /// (`messages` plus `system_tokens`) exceeds the compression threshold.
    ///
    /// Returns `(estimated_tokens, changed)`:
    /// * under the threshold: the current estimate and `false`, nothing is touched;
    /// * otherwise older tool results are compressed first (pass 1);
    /// * if that is not enough, the oldest assistant turn is dropped (pass 2).
    ///
    /// Pass 2 removes at most one turn per call, so the returned estimate can
    /// still be above the threshold.
    pub fn enforce(&self, messages: &mut Vec<Message>, system_tokens: usize) -> (usize, bool) {
        let limit = self.config.compression_threshold() as usize;

        let initial = system_tokens + estimate_messages(messages);
        if initial <= limit {
            return (initial, false);
        }

        // Pass 1: replace bulky, older tool output with short summaries.
        self.compress_tool_results(messages);
        let compressed = system_tokens + estimate_messages(messages);
        if compressed > limit {
            // Pass 2: still too large, so drop the oldest assistant turn.
            self.trim_oldest_turns(messages);
            return (system_tokens + estimate_messages(messages), true);
        }
        (compressed, true)
    }

    /// Compresses tool results in every `tool` message that comes before the
    /// most recent `tool` message; the most recent one is left untouched.
    ///
    /// Only `ToolResult` parts whose content is longer than 200 bytes are
    /// rewritten. With no `tool` message at all, nothing is changed.
    fn compress_tool_results(&self, messages: &mut Vec<Message>) {
        let newest_tool = messages
            .iter()
            .rposition(|m| m.role == "tool")
            .unwrap_or(0);

        let older_tool_messages = messages
            .iter_mut()
            .take(newest_tool)
            .filter(|m| m.role == "tool");

        for msg in older_tool_messages {
            if let MessageContent::Parts(parts) = &mut msg.content {
                for part in parts.iter_mut() {
                    if let ContentPart::ToolResult { content, .. } = part {
                        if content.len() > 200 {
                            *content = compress_tool_content(content);
                        }
                    }
                }
            }
        }
    }

    /// Removes the oldest assistant message, together with the message right
    /// after it when that one has the `tool` role.
    ///
    /// The first message and the last four are never chosen as the assistant
    /// message to drop; histories of five messages or fewer are left alone.
    fn trim_oldest_turns(&self, messages: &mut Vec<Message>) {
        const PROTECTED_TAIL: usize = 4;

        let total = messages.len();
        if total <= PROTECTED_TAIL + 1 {
            return;
        }

        // Index 0 is skipped on purpose: the scan starts at the second message.
        let scan_end = total - PROTECTED_TAIL;
        let oldest_assistant = (1..scan_end).find(|&i| messages[i].role == "assistant");

        if let Some(start) = oldest_assistant {
            let followed_by_tool = messages
                .get(start + 1)
                .map_or(false, |next| next.role == "tool");
            let end = if followed_by_tool { start + 2 } else { start + 1 };
            messages.drain(start..end);
        }
    }
}
/// Replaces a tool result with a short stand-in.
///
/// When the first line looks like a file-read header (`[<path> — ...`), the
/// result is a one-line summary naming the path and, if any lines contain
/// `" | "`, how many such lines there were. Otherwise only the first line of
/// `content` is kept.
fn compress_tool_content(content: &str) -> String {
    let first = match content.lines().next() {
        Some(line) => line,
        None => content,
    };

    let is_read_header = first.starts_with('[') && first.contains(" — ");
    if !is_read_header {
        return first.to_string();
    }

    // Strip the leading bracket(s), keep what precedes the dash, then drop any
    // trailing bracket and surrounding whitespace.
    let without_bracket = first.trim_start_matches('[');
    let before_dash = without_bracket
        .split(" —")
        .next()
        .unwrap_or(without_bracket);
    let path_part = before_dash.trim_end_matches(']').trim();

    let line_count = content
        .lines()
        .filter(|line| line.contains(" | "))
        .count();

    match line_count {
        0 => format!("[content compressed — ✓ Read {path_part}. Ask to recall if needed.]"),
        _ => format!("[content compressed — ✓ Read {path_part} ({line_count} lines). Ask to recall if needed.]"),
    }
}
/// Detects repeated tool calls by remembering the most recent invocations.
#[derive(Default)]
pub struct LoopDetector {
    /// Sliding window of `(tool name, fingerprint)` pairs, oldest first.
    recent: Vec<(String, String)>,
}

impl LoopDetector {
    /// Records a tool call and reports whether it repeats a recent one.
    ///
    /// The call is fingerprinted as `tool_name` plus the first 400 bytes of
    /// `args` (fewer if byte 400 falls inside a multi-byte character, so the
    /// cut always lands on a UTF-8 boundary). Only the five most recent calls
    /// are remembered.
    ///
    /// Returns `true` when the same fingerprint appears at least twice in
    /// that window, counting the call just recorded.
    pub fn record(&mut self, tool_name: &str, args: &str) -> bool {
        const MAX_ARG_BYTES: usize = 400;
        const WINDOW: usize = 5;

        // Slicing a `str` at a non-boundary byte offset panics, so back off to
        // the nearest boundary; offset 0 is always one, so this terminates.
        let mut cut = args.len().min(MAX_ARG_BYTES);
        while !args.is_char_boundary(cut) {
            cut -= 1;
        }

        let fp = format!("{tool_name}::{}", &args[..cut]);
        self.recent.push((tool_name.to_string(), fp.clone()));
        if self.recent.len() > WINDOW {
            self.recent.remove(0);
        }

        let count = self.recent.iter().filter(|(_, seen)| *seen == fp).count();
        count >= 2
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 15% of a 1000-token window is reserved for the response.
    #[test]
    fn test_budget_config_from_context_tokens() {
        let headroom = BudgetConfig::from_context_tokens(1000).response_headroom;
        assert_eq!(headroom, 150);
    }

    /// Usable tokens are the total minus the response headroom.
    #[test]
    fn test_budget_config_usable() {
        let usable = BudgetConfig {
            total_context: 1000,
            response_headroom: 150,
        }
        .usable();
        assert_eq!(usable, 850);
    }

    /// The compression threshold is 80% of the usable 850 tokens.
    #[test]
    fn test_budget_config_compression_threshold() {
        let threshold = BudgetConfig::from_context_tokens(1000).compression_threshold();
        assert_eq!(threshold, 680);
    }

    /// 13 characters / 4 = 3, plus the fixed overhead of 10.
    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens("Hello, world!"), 13);
    }

    /// A single plain-text message: 12 characters / 4 = 3, plus 10.
    #[test]
    fn test_estimate_messages() {
        let only_message = Message {
            role: "assistant".to_string(),
            content: MessageContent::Text("Test message".to_string()),
        };
        let history = vec![only_message];
        assert_eq!(estimate_messages(&history), 13);
    }

    /// The first call is new; an identical second call is flagged as a repeat.
    #[test]
    fn test_loop_detector_record() {
        let mut detector = LoopDetector::default();
        let first_seen = detector.record("read_file", "path=src/budget.rs");
        assert!(!first_seen);
        let second_seen = detector.record("read_file", "path=src/budget.rs");
        assert!(second_seen);
    }

    /// A file-read header is collapsed into the compressed-read summary.
    #[test]
    fn test_compress_tool_content() {
        let summary = compress_tool_content("[src/budget.rs — 228 lines total, showing ...]");
        assert!(summary.starts_with("[content compressed — ✓ Read "));
    }
}