use crate::context::hardware::estimate_tokens_heuristic;
use crate::context::strategies::CompactStrategy;
use crate::core::message::{Message, ToolCallInfo};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// Summary of one compaction pass, handed to the `on_compact` callback.
///
/// An event is only built when the strategy reports a non-zero phase, i.e.
/// when the message list was actually rewritten.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompactEvent {
    /// Step index supplied by the caller of `maybe_compact`.
    pub step_index: i64,
    /// Token estimate for the message list before compaction.
    pub tokens_before: i64,
    /// Heuristic token estimate for the compacted message list.
    pub tokens_after: i64,
    /// Token budget the strategy was asked to fit into.
    pub budget_tokens: i64,
    /// Number of messages before compaction.
    pub messages_before: usize,
    /// Number of messages after compaction.
    pub messages_after: usize,
    /// Phase value reported by the compaction strategy (never 0 in an event).
    pub phase_reached: i64,
}
/// Callback invoked by `ContextManager::maybe_compact` with the
/// [`CompactEvent`] describing a compaction that just took place.
pub type OnCompactFn = Box<dyn Fn(&CompactEvent) + Send + Sync>;
/// Callback invoked by `ContextManager::check_thresholds` when a usage
/// threshold fires.
///
/// Arguments are, in order: the current token estimate, the token budget, and
/// the usage ratio `tokens / budget` as a fraction. The returned value is
/// passed straight back to the caller of `check_thresholds`.
pub type OnThresholdFn = Box<dyn Fn(i64, i64, f64) -> Option<String> + Send + Sync>;
/// An externally supplied token count together with the fingerprint of the
/// message list it was recorded against, if any.
#[derive(Debug, Clone, Copy)]
struct StoredTokenCount {
    /// The recorded token count.
    count: i64,
    /// Fingerprint of the messages the count belongs to; `None` when the count
    /// was recorded before any message list had been observed.
    messages_fingerprint: Option<u64>,
}

impl StoredTokenCount {
    /// Reports whether this count may be used for a message list with the
    /// given fingerprint.
    ///
    /// A count without a recorded fingerprint is accepted for every list.
    fn matches(self, messages_fingerprint: u64) -> bool {
        match self.messages_fingerprint {
            Some(recorded) => recorded == messages_fingerprint,
            None => true,
        }
    }
}
/// Tracks token usage of a message list against a budget, delegates
/// compaction to a [`CompactStrategy`], and raises usage-threshold callbacks.
pub struct ContextManager {
/// Strategy asked to compact the message list in `maybe_compact`.
strategy: Box<dyn CompactStrategy>,
/// Token budget handed to the strategy and used as the denominator of the
/// usage ratio in `check_thresholds`.
budget_tokens: i64,
/// Optional callback invoked after a compaction that reported a non-zero phase.
on_compact: Option<OnCompactFn>,
/// Optional usage thresholds as fractions of the budget, sorted in `new`.
context_thresholds: Option<Vec<f64>>,
/// Optional callback invoked when a threshold fires.
on_context_threshold: Option<OnThresholdFn>,
/// Token count supplied through `update_token_count`, if still retained.
stored_token_count: Option<StoredTokenCount>,
/// Fingerprint of the message list most recently passed to `maybe_compact`
/// or `check_thresholds`.
last_observed_messages_fingerprint: Option<u64>,
/// One flag per threshold, parallel to `context_thresholds`; a flag is
/// cleared again once usage drops back below its threshold.
fired_thresholds: Vec<bool>,
}
impl ContextManager {
    /// Builds a manager around `strategy` with the given token budget.
    ///
    /// `context_thresholds` are usage fractions (`tokens / budget_tokens`) and
    /// are stored sorted ascending. Threshold notifications are only produced
    /// when both `context_thresholds` and `on_context_threshold` are supplied.
    pub fn new(
        strategy: Box<dyn CompactStrategy>,
        budget_tokens: i64,
        on_compact: Option<OnCompactFn>,
        context_thresholds: Option<Vec<f64>>,
        on_context_threshold: Option<OnThresholdFn>,
    ) -> Self {
        let sorted_thresholds = context_thresholds.map(|mut thresholds| {
            // `total_cmp` is a genuine total order (NaN included), which
            // `sort_by` requires; a comparator that calls NaN "equal" to
            // everything is not one.
            thresholds.sort_by(f64::total_cmp);
            thresholds
        });
        // One "already fired" flag per threshold.
        let fired_thresholds = sorted_thresholds
            .as_ref()
            .map(|thresholds| vec![false; thresholds.len()])
            .unwrap_or_default();
        Self {
            strategy,
            budget_tokens,
            on_compact,
            context_thresholds: sorted_thresholds,
            on_context_threshold,
            stored_token_count: None,
            last_observed_messages_fingerprint: None,
            fired_thresholds,
        }
    }

    /// Returns the configured token budget.
    pub fn budget(&self) -> i64 {
        self.budget_tokens
    }

    /// Estimates the token usage of `messages` without touching any state.
    ///
    /// The stored count is returned when it is valid for `messages`;
    /// otherwise the heuristic estimator is used.
    pub fn estimate_tokens(&self, messages: &[Message]) -> i64 {
        let fingerprint = message_fingerprint(messages);
        match self.stored_token_count {
            Some(stored) if stored.matches(fingerprint) => stored.count,
            _ => estimate_tokens_heuristic(messages),
        }
    }

    /// Records an externally measured token count.
    ///
    /// The count is tied to the message list most recently observed by
    /// `maybe_compact` or `check_thresholds`. If no list has been observed
    /// yet, the count is accepted for any message list.
    pub fn update_token_count(&mut self, count: i64) {
        self.stored_token_count = Some(StoredTokenCount {
            count,
            messages_fingerprint: self.last_observed_messages_fingerprint,
        });
    }

    /// Like `estimate_tokens`, but also records `messages` as the latest
    /// observed list and drops a stored count that no longer matches it.
    fn estimate_current_tokens(&mut self, messages: &[Message]) -> i64 {
        let fingerprint = self.observe_messages(messages);
        match self.stored_token_count {
            Some(stored) if stored.matches(fingerprint) => stored.count,
            Some(_) => {
                // The stored count belongs to a different message list.
                self.stored_token_count = None;
                estimate_tokens_heuristic(messages)
            }
            None => estimate_tokens_heuristic(messages),
        }
    }

    /// Remembers the fingerprint of `messages` and returns it.
    fn observe_messages(&mut self, messages: &[Message]) -> u64 {
        let fingerprint = message_fingerprint(messages);
        self.last_observed_messages_fingerprint = Some(fingerprint);
        fingerprint
    }

    /// Runs the compaction strategy over `messages`.
    ///
    /// Returns the input borrowed when the strategy reports phase `0`.
    /// Otherwise the `on_compact` callback (if any) receives a
    /// [`CompactEvent`], the stored token count is discarded, and the
    /// compacted list is returned as an owned value.
    pub fn maybe_compact<'a>(
        &mut self,
        messages: &'a [Message],
        step_index: i64,
        step_hint: Option<&str>,
    ) -> std::borrow::Cow<'a, [Message]> {
        // Evaluated first: this also records `messages` as the observed list.
        let tokens_before = self.estimate_current_tokens(messages);
        let (compacted, phase) = self
            .strategy
            .compact(messages, self.budget_tokens, step_hint);
        if phase == 0 {
            return std::borrow::Cow::Borrowed(messages);
        }
        let tokens_after = estimate_tokens_heuristic(&compacted);
        let event = CompactEvent {
            step_index,
            tokens_before,
            tokens_after,
            budget_tokens: self.budget_tokens,
            messages_before: messages.len(),
            messages_after: compacted.len(),
            phase_reached: phase,
        };
        if let Some(callback) = &self.on_compact {
            callback(&event);
        }
        // Any stored count described the pre-compaction list.
        self.stored_token_count = None;
        std::borrow::Cow::Owned(compacted)
    }

    /// Checks usage of `messages` against the configured thresholds.
    ///
    /// Returns `None` when thresholds or the callback are not configured, when
    /// the budget is not positive, or when no threshold was newly crossed.
    /// Otherwise the callback is invoked once with
    /// `(tokens, budget_tokens, tokens / budget_tokens)` and its result is
    /// returned.
    ///
    /// Every threshold at or below the current usage is marked as fired in the
    /// same call, so a jump past several thresholds produces a single
    /// notification rather than one per threshold on later calls. A threshold
    /// is re-armed once usage drops back below it.
    pub fn check_thresholds(&mut self, messages: &[Message]) -> Option<String> {
        if self.context_thresholds.is_none() || self.on_context_threshold.is_none() {
            return None;
        }
        if self.budget_tokens <= 0 {
            return None;
        }
        let tokens = self.estimate_current_tokens(messages);
        let pct = tokens as f64 / self.budget_tokens as f64;
        let thresholds = self.context_thresholds.as_ref()?;

        let mut newly_crossed = false;
        for (fired, &threshold) in self.fired_thresholds.iter_mut().zip(thresholds) {
            if pct >= threshold {
                newly_crossed |= !*fired;
                *fired = true;
            } else {
                // Below the threshold (or the threshold is NaN): re-arm it.
                *fired = false;
            }
        }
        if !newly_crossed {
            return None;
        }

        let callback = self.on_context_threshold.as_ref()?;
        callback(tokens, self.budget_tokens, pct)
    }
}
/// Computes a hash over the message list that is used to tell whether a
/// stored token count still belongs to the same messages.
///
/// The hash covers the list length and, for each message in order: role,
/// content, the `msg_type`, `step_index`, `original_type` and `token_estimate`
/// metadata fields, tool name, tool call id, and the tool calls.
fn message_fingerprint(messages: &[Message]) -> u64 {
    let mut state = DefaultHasher::new();
    messages.len().hash(&mut state);
    messages.iter().for_each(|msg| {
        msg.role.hash(&mut state);
        msg.content.hash(&mut state);

        let meta = &msg.metadata;
        meta.msg_type.hash(&mut state);
        meta.step_index.hash(&mut state);
        meta.original_type.hash(&mut state);
        meta.token_estimate.hash(&mut state);

        msg.tool_name.hash(&mut state);
        msg.tool_call_id.hash(&mut state);
        hash_tool_calls(&msg.tool_calls, &mut state);
    });
    state.finish()
}
/// Feeds an optional list of tool calls into `hasher`.
///
/// A presence flag is hashed first so that `None` and `Some(vec![])` differ.
/// For each call the name, call id and (optional) arguments are hashed;
/// argument values are hashed through their `to_string()` rendering.
///
/// NOTE(review): arguments are hashed in iteration order. If `args` is an
/// unordered map, equal argument sets could hash differently — confirm the
/// concrete type.
fn hash_tool_calls(tool_calls: &Option<Vec<ToolCallInfo>>, hasher: &mut DefaultHasher) {
    let calls = match tool_calls {
        None => {
            false.hash(hasher);
            return;
        }
        Some(calls) => calls,
    };

    true.hash(hasher);
    calls.len().hash(hasher);
    for call in calls {
        call.name.hash(hasher);
        call.call_id.hash(hasher);
        if let Some(args) = &call.args {
            true.hash(hasher);
            args.len().hash(hasher);
            for (key, value) in args {
                key.hash(hasher);
                value.to_string().hash(hasher);
            }
        } else {
            false.hash(hasher);
        }
    }
}
/// Default threshold callback: formats a human-readable context-usage warning.
///
/// * `tokens` – current token estimate.
/// * `budget` – token budget.
/// * `pct` – usage as a fraction (`0.80` means 80%).
///
/// The wording escalates at `pct >= 0.65` and again at `pct >= 0.80`. The
/// displayed percentage is rounded to the nearest whole number; plain
/// truncation would show e.g. `0.57` as 56% because `0.57 * 100.0` is
/// slightly below 57 in floating point. Always returns `Some`.
pub fn default_context_warning(tokens: i64, budget: i64, pct: f64) -> Option<String> {
    let pct_display = (pct * 100.0).round() as i64;
    let headline = if pct >= 0.80 {
        "Context window nearly full:"
    } else if pct >= 0.65 {
        "Context window filling up:"
    } else {
        "Context usage at"
    };
    Some(format!(
        "{} {}% ({} / {} tokens)",
        headline, pct_display, tokens, budget
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::strategies::NoCompact;
    use crate::core::message::{Message, MessageMeta, MessageRole, MessageType};

    /// A list holding one user message whose content is 100 `a` characters.
    fn hundred_char_user_message() -> Vec<Message> {
        vec![Message::new(
            MessageRole::User,
            "a".repeat(100),
            MessageMeta::new(MessageType::UserInput),
        )]
    }

    /// A manager with the no-op strategy, a 1000-token budget and no callbacks.
    fn plain_manager() -> ContextManager {
        ContextManager::new(Box::new(NoCompact), 1000, None, None, None)
    }

    /// The event struct exposes the values it was built with.
    #[test]
    fn compact_event_fields() {
        let CompactEvent {
            step_index,
            tokens_after,
            phase_reached,
            ..
        } = CompactEvent {
            step_index: 5,
            tokens_before: 1000,
            tokens_after: 500,
            budget_tokens: 800,
            messages_before: 10,
            messages_after: 6,
            phase_reached: 2,
        };
        assert_eq!(step_index, 5);
        assert_eq!(tokens_after, 500);
        assert_eq!(phase_reached, 2);
    }

    /// Without a stored count the heuristic estimate is returned.
    #[test]
    fn estimate_tokens_heuristic_fallback() {
        let manager = plain_manager();
        let estimate = manager.estimate_tokens(&hundred_char_user_message());
        assert_eq!(estimate, 25);
    }

    /// A count stored before any messages were observed wins over the heuristic.
    #[test]
    fn update_token_count_overrides_heuristic() {
        let mut manager = plain_manager();
        manager.update_token_count(500);
        let estimate = manager.estimate_tokens(&hundred_char_user_message());
        assert_eq!(estimate, 500);
    }

    /// The default warning changes wording at 65% and 80% usage.
    #[test]
    fn default_warning_escalates() {
        let at_half = default_context_warning(400, 800, 0.50).unwrap();
        assert!(at_half.contains("50%"));
        assert!(!at_half.contains("nearly full"));
        assert!(!at_half.contains("filling up"));

        let at_65 = default_context_warning(520, 800, 0.65).unwrap();
        assert!(at_65.contains("filling up"));

        let at_80 = default_context_warning(640, 800, 0.80).unwrap();
        assert!(at_80.contains("nearly full"));
    }
}