use serde::{Deserialize, Serialize};
use crate::types::{AgentMessage, ContentBlock, LlmMessage};
/// Strategy for estimating how many tokens a single message occupies.
///
/// Implementations must be `Send + Sync`, as required by the supertrait bounds.
pub trait TokenCounter: Send + Sync {
/// Returns the estimated token count for `message`.
fn count_tokens(&self, message: &AgentMessage) -> usize;
}
/// Stateless token counter used whenever no custom [`TokenCounter`] is supplied.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultTokenCounter;
impl TokenCounter for DefaultTokenCounter {
    /// Estimates tokens as the combined byte length of the message's content
    /// blocks divided by four (integer division).
    ///
    /// Text and thinking blocks contribute their string length, tool-call
    /// arguments and extension data contribute the length of their
    /// `to_string()` rendering, and images contribute nothing. Custom messages
    /// are charged a flat 100 tokens.
    fn count_tokens(&self, message: &AgentMessage) -> usize {
        let llm = match message {
            AgentMessage::Custom(_) => return 100,
            AgentMessage::Llm(llm) => llm,
        };

        let mut byte_total: usize = 0;
        for block in content_blocks(llm) {
            byte_total += match block {
                ContentBlock::Text { text } => text.len(),
                ContentBlock::Thinking { thinking, .. } => thinking.len(),
                ContentBlock::ToolCall { arguments, .. } => arguments.to_string().len(),
                ContentBlock::Image { .. } => 0,
                ContentBlock::Extension { data, .. } => data.to_string().len(),
            };
        }
        byte_total / 4
    }
}
/// Estimates the token count of `msg` using [`DefaultTokenCounter`].
pub fn estimate_tokens(msg: &AgentMessage) -> usize {
    let counter = DefaultTokenCounter;
    counter.count_tokens(msg)
}
/// Returns the content blocks of `msg`, whichever LLM message variant it is.
fn content_blocks(msg: &LlmMessage) -> &[ContentBlock] {
    match msg {
        LlmMessage::User(user) => user.content.as_slice(),
        LlmMessage::Assistant(assistant) => assistant.content.as_slice(),
        LlmMessage::ToolResult(result) => result.content.as_slice(),
    }
}
/// Returns `true` when `messages[idx]` exists and is a tool-result message.
///
/// An out-of-range `idx` yields `false` rather than panicking.
fn is_tool_result(messages: &[AgentMessage], idx: usize) -> bool {
    let Some(AgentMessage::Llm(llm)) = messages.get(idx) else {
        return false;
    };
    matches!(llm, LlmMessage::ToolResult(_))
}
/// Collects the ids of every tool call in an assistant message.
///
/// Returns `None` when `message` is not an assistant message or when it
/// contains no tool-call blocks.
fn tool_call_ids(message: &AgentMessage) -> Option<Vec<&str>> {
    let AgentMessage::Llm(LlmMessage::Assistant(assistant)) = message else {
        return None;
    };

    let mut ids: Vec<&str> = Vec::new();
    for block in &assistant.content {
        if let ContentBlock::ToolCall { id, .. } = block {
            ids.push(id.as_str());
        }
    }

    if ids.is_empty() {
        None
    } else {
        Some(ids)
    }
}
/// Returns the tool-call id a tool-result message answers, or `None` for any
/// other kind of message.
fn tool_result_id(message: &AgentMessage) -> Option<&str> {
    let AgentMessage::Llm(LlmMessage::ToolResult(result)) = message else {
        return None;
    };
    Some(result.tool_call_id.as_str())
}
/// Pushes the anchor boundary forward so it never separates an anchored
/// assistant tool call from the tool results that directly follow it.
///
/// `anchor_end` is the exclusive end of the anchored prefix. The function finds
/// the closest message before the boundary that is not a tool result; if that
/// message is an assistant message with tool calls, the boundary is moved past
/// the contiguous run of tool results whose ids match those calls. The returned
/// value is never smaller than `anchor_end`.
fn extend_anchor_for_tool_results(messages: &[AgentMessage], anchor_end: usize) -> usize {
    // Nothing anchored, or everything anchored: there is no boundary to fix.
    if anchor_end == 0 || anchor_end >= messages.len() {
        return anchor_end;
    }

    // Skip backwards over tool results to reach the message that owns them.
    let Some(owner_idx) = (0..anchor_end)
        .rev()
        .find(|&idx| !is_tool_result(messages, idx))
    else {
        return anchor_end;
    };

    let Some(call_ids) = tool_call_ids(&messages[owner_idx]) else {
        return anchor_end;
    };

    // Count the results immediately after the owner that answer one of its calls.
    let first_result = owner_idx + 1;
    let matched_results = messages[first_result..]
        .iter()
        .take_while(|message| {
            matches!(tool_result_id(message), Some(id) if call_ids.contains(&id))
        })
        .count();

    anchor_end.max(first_result + matched_results)
}
/// Summary of one sliding-window compaction pass, as returned by
/// `compact_sliding_window_with`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionReport {
/// Number of messages removed from the history, custom messages included.
pub dropped_count: usize,
/// Estimated token total of the history before compaction.
pub tokens_before: usize,
/// Estimated token total of the history after compaction.
pub tokens_after: usize,
/// Overflow flag; `compact_sliding_window_with` always reports `false` here.
pub overflow: bool,
/// The removed LLM messages in their original order. Custom messages are
/// not recorded, so this can be shorter than `dropped_count`. Skipped when
/// serializing an empty list and defaulted when missing on deserialization.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub dropped_messages: Vec<LlmMessage>,
}
/// Compacts `messages` in place to fit within `budget` tokens, keeping the
/// first `anchor` messages.
///
/// Thin wrapper around `compact_sliding_window_with` that passes `None` for the
/// counter, which selects `DefaultTokenCounter`. Returns `None` when no
/// compaction happened.
pub fn compact_sliding_window(
messages: &mut Vec<AgentMessage>,
budget: usize,
anchor: usize,
) -> Option<CompactionReport> {
compact_sliding_window_with(messages, budget, anchor, None)
}
/// Trims the middle of `messages` in place so the history fits within `budget`
/// tokens, measured with `counter` (or [`DefaultTokenCounter`] when `None`).
///
/// The first `anchor` messages are always kept; the anchor is extended when it
/// would otherwise separate an assistant tool call from its tool results. After
/// the anchor, the longest run of most-recent messages that fits the remaining
/// budget is kept and everything in between is removed. If the kept tail would
/// start on a tool result, the tail is widened backwards past the preceding
/// tool results (without re-checking the budget) so the result keeps the
/// message in front of it.
///
/// Returns `None` when the history already fits the budget or when there is
/// nothing between the anchor and the tail to remove. Otherwise returns a
/// [`CompactionReport`] holding the removed LLM messages and the token totals
/// before and after.
pub fn compact_sliding_window_with(
    messages: &mut Vec<AgentMessage>,
    budget: usize,
    anchor: usize,
    counter: Option<&dyn TokenCounter>,
) -> Option<CompactionReport> {
    let default = DefaultTokenCounter;
    let counter: &dyn TokenCounter = counter.unwrap_or(&default);
    let count = |m: &AgentMessage| counter.count_tokens(m);

    let tokens_before: usize = messages.iter().map(count).sum();
    if tokens_before <= budget {
        return None;
    }

    let len = messages.len();
    let effective_anchor = extend_anchor_for_tool_results(messages, anchor.min(len));
    let anchor_tokens: usize = messages[..effective_anchor].iter().map(count).sum();
    let remaining_budget = budget.saturating_sub(anchor_tokens);

    // Grow the kept tail from the newest message backwards until the next
    // message would exceed what is left of the budget.
    let mut tail_tokens = 0;
    let mut tail_start = len;
    for i in (effective_anchor..len).rev() {
        let msg_tokens = count(&messages[i]);
        if tail_tokens + msg_tokens > remaining_budget {
            break;
        }
        tail_tokens += msg_tokens;
        tail_start = i;
    }

    // Never let the tail begin with a tool result: pull in earlier messages
    // until it starts on something else (or reaches the anchor).
    while tail_start > effective_anchor && tail_start < len && is_tool_result(messages, tail_start)
    {
        tail_start -= 1;
    }

    if tail_start <= effective_anchor {
        return None;
    }

    let dropped_count = tail_start - effective_anchor;

    // Remove the middle range in place and move the LLM messages straight into
    // the report; custom messages are dropped without being recorded. This
    // avoids cloning each dropped message and rebuilding the tail.
    let dropped_messages: Vec<LlmMessage> = messages
        .drain(effective_anchor..tail_start)
        .filter_map(|m| match m {
            AgentMessage::Llm(llm) => Some(llm),
            AgentMessage::Custom(_) => None,
        })
        .collect();

    let tokens_after: usize = messages.iter().map(count).sum();
    Some(CompactionReport {
        dropped_count,
        tokens_before,
        tokens_after,
        overflow: false,
        dropped_messages,
    })
}
/// Builds a compaction closure that applies [`compact_sliding_window`] with
/// `overflow_budget` when its `bool` argument is `true` and `normal_budget`
/// otherwise, always keeping the first `anchor` messages.
#[deprecated(since = "0.5.0", note = "Use SlidingWindowTransformer instead")]
pub fn sliding_window(
    normal_budget: usize,
    overflow_budget: usize,
    anchor: usize,
) -> impl Fn(&mut Vec<AgentMessage>, bool) + Send + Sync {
    move |messages: &mut Vec<AgentMessage>, overflow: bool| {
        let active_budget = if overflow {
            overflow_budget
        } else {
            normal_budget
        };
        // The closure returns nothing, so the compaction report is discarded.
        let _ = compact_sliding_window(messages, active_budget, anchor);
    }
}
/// Reports whether the estimated token total of `messages` exceeds the model's
/// maximum context window.
///
/// Returns `false` when `model` has no capabilities or no
/// `max_context_window`. Tokens are measured with `counter`, falling back to
/// [`DefaultTokenCounter`] when it is `None`.
pub fn is_context_overflow(
    messages: &[AgentMessage],
    model: &crate::types::ModelSpec,
    counter: Option<&dyn TokenCounter>,
) -> bool {
    let Some(max_window) = model
        .capabilities
        .as_ref()
        .and_then(|caps| caps.max_context_window)
    else {
        return false;
    };

    let fallback = DefaultTokenCounter;
    let counter: &dyn TokenCounter = counter.unwrap_or(&fallback);

    let mut total_tokens: usize = 0;
    for message in messages {
        total_tokens += counter.count_tokens(message);
    }
    total_tokens as u64 > max_window
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{
AssistantMessage, ContentBlock, Cost, LlmMessage, StopReason, ToolResultMessage, Usage,
UserMessage,
};
/// Builds a user message holding a single text block with `text`.
fn text_message(text: &str) -> AgentMessage {
    let block = ContentBlock::Text {
        text: String::from(text),
    };
    let user = UserMessage {
        content: vec![block],
        timestamp: 0,
        cache_hint: None,
    };
    AgentMessage::Llm(LlmMessage::User(user))
}
/// Builds an assistant message whose only content is one tool call with `id`
/// and empty JSON-object arguments.
fn tool_call_message(id: &str) -> AgentMessage {
    let call = ContentBlock::ToolCall {
        id: id.into(),
        name: "test".into(),
        arguments: serde_json::json!({}),
        partial_json: None,
    };
    let assistant = AssistantMessage {
        content: vec![call],
        provider: String::new(),
        model_id: String::new(),
        usage: Usage::default(),
        cost: Cost::default(),
        stop_reason: StopReason::ToolUse,
        error_message: None,
        error_kind: None,
        timestamp: 0,
        cache_hint: None,
    };
    AgentMessage::Llm(LlmMessage::Assistant(assistant))
}
/// Builds an assistant message containing one tool call per entry in `ids`,
/// in the same order.
fn multi_tool_call_message(ids: &[&str]) -> AgentMessage {
    let mut calls = Vec::with_capacity(ids.len());
    for &id in ids {
        calls.push(ContentBlock::ToolCall {
            id: id.into(),
            name: "test".into(),
            arguments: serde_json::json!({}),
            partial_json: None,
        });
    }
    let assistant = AssistantMessage {
        content: calls,
        provider: String::new(),
        model_id: String::new(),
        usage: Usage::default(),
        cost: Cost::default(),
        stop_reason: StopReason::ToolUse,
        error_message: None,
        error_kind: None,
        timestamp: 0,
        cache_hint: None,
    };
    AgentMessage::Llm(LlmMessage::Assistant(assistant))
}
/// Builds a successful tool-result message answering call `id` with a single
/// text block containing `text`.
fn tool_result_message(id: &str, text: &str) -> AgentMessage {
    let body = ContentBlock::Text { text: text.into() };
    let result = ToolResultMessage {
        tool_call_id: id.into(),
        content: vec![body],
        is_error: false,
        timestamp: 0,
        details: serde_json::Value::Null,
        cache_hint: None,
    };
    AgentMessage::Llm(LlmMessage::ToolResult(result))
}
/// A history far below the budget is left untouched.
#[test]
#[allow(deprecated)]
fn under_budget_no_change() {
    let mut history = vec![text_message("hello"), text_message("world")];
    let compactor = sliding_window(10_000, 5_000, 1);
    compactor(&mut history, false);
    assert_eq!(history.len(), 2);
}
/// Four 400-byte messages against a 250-token budget shrink to two messages.
#[test]
#[allow(deprecated)]
fn over_budget_trims_middle() {
    let filler = "x".repeat(400);
    let mut history: Vec<AgentMessage> = (0..4).map(|_| text_message(&filler)).collect();
    let compactor = sliding_window(250, 100, 1);
    compactor(&mut history, false);
    assert_eq!(history.len(), 2);
}
/// The normal budget leaves the history alone while the overflow budget
/// forces messages to be dropped.
#[test]
#[allow(deprecated)]
fn overflow_uses_smaller_budget() {
    let filler = "x".repeat(400);
    let mut history: Vec<AgentMessage> = (0..4).map(|_| text_message(&filler)).collect();
    let compactor = sliding_window(1000, 150, 1);

    // Normal mode: the generous budget keeps everything.
    compactor(&mut history, false);
    assert_eq!(history.len(), 4);

    // Overflow mode: the tighter budget must trim.
    compactor(&mut history, true);
    assert!(history.len() < 4);
}
/// If a tool result survives compaction, an assistant tool call must survive too.
///
/// NOTE(review): with the default counter this fixture appears to total about
/// 201 tokens, below the 300 budget, so compaction may never run here — confirm.
#[test]
#[allow(deprecated)]
fn preserves_tool_result_pair() {
    let filler = "x".repeat(400);
    let mut history = vec![
        text_message(&filler),
        text_message(&filler),
        tool_call_message("tc1"),
        tool_result_message("tc1", "result"),
    ];
    let compactor = sliding_window(300, 100, 1);
    compactor(&mut history, false);

    let mut result_kept = false;
    let mut call_kept = false;
    for message in &history {
        match message {
            AgentMessage::Llm(LlmMessage::ToolResult(_)) => result_kept = true,
            AgentMessage::Llm(LlmMessage::Assistant(assistant)) => {
                let has_tool_call = assistant
                    .content
                    .iter()
                    .any(|block| matches!(block, ContentBlock::ToolCall { .. }));
                if has_tool_call {
                    call_kept = true;
                }
            }
            _ => {}
        }
    }

    if result_kept {
        assert!(call_kept);
    }
}
/// Compacting an empty history is a no-op.
#[test]
#[allow(deprecated)]
fn empty_messages_no_change() {
    let mut history: Vec<AgentMessage> = Vec::new();
    let compactor = sliding_window(100, 50, 1);
    compactor(&mut history, false);
    assert!(history.is_empty());
}
/// A lone anchored message is kept even though it alone exceeds the budget.
#[test]
#[allow(deprecated)]
fn single_message_preserved() {
    let oversized = "x".repeat(4000);
    let mut history = vec![text_message(&oversized)];
    let compactor = sliding_window(10, 5, 1);
    compactor(&mut history, false);
    assert_eq!(history.len(), 1);
}
/// The first `anchor` messages survive unchanged even when the budget is
/// smaller than the anchor itself.
#[test]
#[allow(deprecated)]
fn anchor_messages_always_kept() {
    let body = "x".repeat(400);
    let mut history: Vec<AgentMessage> = (0..4).map(|_| text_message(&body)).collect();
    let compactor = sliding_window(50, 25, 2);
    compactor(&mut history, false);

    assert!(history.len() >= 2);
    for anchored in &history[..2] {
        let AgentMessage::Llm(LlmMessage::User(user)) = anchored else {
            panic!("expected user message in anchor position");
        };
        assert_eq!(user.content[0], ContentBlock::Text { text: body.clone() });
    }
}
/// Two 400-byte messages fit a 500-token budget, so nothing is removed.
#[test]
#[allow(deprecated)]
fn all_messages_under_budget_with_large_system_prompt() {
    let first = "a".repeat(400);
    let second = "b".repeat(400);
    let mut history = vec![text_message(&first), text_message(&second)];
    let compactor = sliding_window(500, 250, 1);
    compactor(&mut history, false);
    assert_eq!(history.len(), 2);
}
/// When the kept tail would begin on a tool result, the tool call that
/// precedes it must be kept as well.
#[test]
#[allow(deprecated)]
fn tool_result_at_boundary_preserved() {
    let filler = "x".repeat(400);
    let mut history = vec![
        text_message(&filler),
        text_message(&filler),
        tool_call_message("tc1"),
        tool_result_message("tc1", &filler),
    ];
    let compactor = sliding_window(250, 100, 1);
    compactor(&mut history, false);

    let mut result_kept = false;
    let mut call_kept = false;
    for message in &history {
        match message {
            AgentMessage::Llm(LlmMessage::ToolResult(_)) => result_kept = true,
            AgentMessage::Llm(LlmMessage::Assistant(assistant)) => {
                let has_tool_call = assistant
                    .content
                    .iter()
                    .any(|block| matches!(block, ContentBlock::ToolCall { .. }));
                if has_tool_call {
                    call_kept = true;
                }
            }
            _ => {}
        }
    }

    if result_kept {
        assert!(call_kept, "tool result kept without its preceding tool call");
    }
}
/// An anchor ending right after a tool call must also keep that call's result.
#[test]
#[allow(deprecated)]
fn anchor_boundary_keeps_result_with_anchor_tool_call() {
    let filler = "x".repeat(400);
    let mut history = vec![
        text_message(&filler),
        tool_call_message("tc1"),
        tool_result_message("tc1", &filler),
        text_message(&filler),
        text_message(&filler),
    ];
    let compactor = sliding_window(250, 100, 2);
    compactor(&mut history, false);

    let mut call_kept = false;
    let mut result_kept = false;
    for message in &history {
        match message {
            AgentMessage::Llm(LlmMessage::Assistant(assistant)) => {
                let calls_tc1 = assistant
                    .content
                    .iter()
                    .any(|block| matches!(block, ContentBlock::ToolCall { id, .. } if id == "tc1"));
                if calls_tc1 {
                    call_kept = true;
                }
            }
            AgentMessage::Llm(LlmMessage::ToolResult(result)) => {
                if result.tool_call_id == "tc1" {
                    result_kept = true;
                }
            }
            _ => {}
        }
    }

    assert!(call_kept, "anchor tool call should still be present");
    assert!(
        result_kept,
        "anchor-side compaction must keep the matching tool result"
    );
}
/// An anchor ending on a multi-call assistant message keeps every one of its
/// tool results, in order.
#[test]
#[allow(deprecated)]
fn anchor_boundary_keeps_all_results_for_multi_tool_call_message() {
    let filler = "x".repeat(400);
    let mut history = vec![
        text_message(&filler),
        multi_tool_call_message(&["tc1", "tc2"]),
        tool_result_message("tc1", &filler),
        tool_result_message("tc2", &filler),
        text_message(&filler),
    ];
    let compactor = sliding_window(250, 100, 2);
    compactor(&mut history, false);

    let mut kept_result_ids: Vec<&str> = Vec::new();
    for message in &history {
        if let AgentMessage::Llm(LlmMessage::ToolResult(result)) = message {
            kept_result_ids.push(result.tool_call_id.as_str());
        }
    }
    assert_eq!(kept_result_ids, vec!["tc1", "tc2"]);
}
/// An anchor that ends between two results of the same assistant message is
/// extended to keep the whole result group.
#[test]
#[allow(deprecated)]
fn anchor_boundary_inside_multi_tool_results_keeps_whole_group() {
    let filler = "x".repeat(400);
    let mut history = vec![
        text_message(&filler),
        multi_tool_call_message(&["tc1", "tc2"]),
        tool_result_message("tc1", &filler),
        tool_result_message("tc2", &filler),
        text_message(&filler),
    ];
    let compactor = sliding_window(250, 100, 3);
    compactor(&mut history, false);

    let mut kept_result_ids: Vec<&str> = Vec::new();
    for message in &history {
        if let AgentMessage::Llm(LlmMessage::ToolResult(result)) = message {
            kept_result_ids.push(result.tool_call_id.as_str());
        }
    }
    assert_eq!(kept_result_ids, vec!["tc1", "tc2"]);
}
/// Every tool result left in the history must still have the assistant tool
/// call with the same id.
///
/// NOTE(review): with the default counter this fixture appears to total about
/// 200 tokens, below the 500 budget, so compaction may never run here — confirm.
#[test]
#[allow(deprecated)]
fn consecutive_tool_pairs_preserved() {
    let filler = "x".repeat(400);
    let mut history = vec![
        text_message(&filler),
        text_message(&filler),
        tool_call_message("tc1"),
        tool_result_message("tc1", "r1"),
        tool_call_message("tc2"),
        tool_result_message("tc2", "r2"),
    ];
    let compactor = sliding_window(500, 100, 1);
    compactor(&mut history, false);

    for message in &history {
        let AgentMessage::Llm(LlmMessage::ToolResult(result)) = message else {
            continue;
        };
        let call_present = history.iter().any(|candidate| match candidate {
            AgentMessage::Llm(LlmMessage::Assistant(assistant)) => {
                assistant.content.iter().any(|block| {
                    matches!(block, ContentBlock::ToolCall { id, .. } if id == &result.tool_call_id)
                })
            }
            _ => false,
        });
        assert!(
            call_present,
            "tool result {} kept without its call",
            result.tool_call_id
        );
    }
}
/// Two custom messages at the flat 100-token estimate exceed a 150-token
/// budget, so only the anchored one remains.
#[test]
#[allow(deprecated)]
fn custom_messages_token_estimation() {
    #[derive(Debug)]
    struct TestCustom;
    impl crate::types::CustomMessage for TestCustom {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    let mut history: Vec<AgentMessage> = Vec::new();
    history.push(AgentMessage::Custom(Box::new(TestCustom)));
    history.push(AgentMessage::Custom(Box::new(TestCustom)));

    let compactor = sliding_window(150, 50, 1);
    compactor(&mut history, false);
    assert_eq!(history.len(), 1);
}
/// The same history compacted in overflow mode ends up shorter than when
/// compacted in normal mode.
#[test]
#[allow(deprecated)]
fn overflow_budget_smaller_than_normal() {
    let filler = "x".repeat(400);
    let fresh_history = || -> Vec<AgentMessage> { (0..4).map(|_| text_message(&filler)).collect() };
    let compactor = sliding_window(350, 150, 1);

    let mut normal_history = fresh_history();
    compactor(&mut normal_history, false);
    let normal_count = normal_history.len();

    let mut overflow_history = fresh_history();
    compactor(&mut overflow_history, true);
    let overflow_count = overflow_history.len();

    assert!(
        overflow_count < normal_count,
        "overflow budget ({overflow_count} msgs) should be more aggressive than normal ({normal_count} msgs)"
    );
}
/// `estimate_tokens` agrees with `DefaultTokenCounter`, and 400 bytes of text
/// estimate to 100 tokens.
#[test]
fn default_token_counter_matches_estimate_tokens() {
    let sample = text_message(&"x".repeat(400));
    let counted = DefaultTokenCounter.count_tokens(&sample);
    assert_eq!(counted, estimate_tokens(&sample));
    assert_eq!(counted, 100);
}
/// Custom messages are always estimated at a flat 100 tokens.
#[test]
fn default_token_counter_custom_message_flat_100() {
    #[derive(Debug)]
    struct TestCustom;
    impl crate::types::CustomMessage for TestCustom {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    let custom = AgentMessage::Custom(Box::new(TestCustom));
    let counter = DefaultTokenCounter;
    assert_eq!(counter.count_tokens(&custom), 100);
}
/// Test counter that charges one token per byte of text content, ignores all
/// other block kinds, and charges a flat 50 for custom messages.
struct CharCounter;
impl TokenCounter for CharCounter {
    fn count_tokens(&self, message: &AgentMessage) -> usize {
        let AgentMessage::Llm(llm) = message else {
            return 50;
        };
        let mut text_bytes = 0;
        for block in content_blocks(llm) {
            if let ContentBlock::Text { text } = block {
                text_bytes += text.len();
            }
        }
        text_bytes
    }
}
/// A supplied counter drives both the compaction decision and the reported
/// token totals.
#[test]
fn custom_counter_used_by_compact_sliding_window_with() {
    let filler = "x".repeat(400);
    let mut history: Vec<AgentMessage> = (0..3).map(|_| text_message(&filler)).collect();

    let outcome = compact_sliding_window_with(&mut history, 500, 1, Some(&CharCounter));
    assert!(outcome.is_some());
    assert_eq!(history.len(), 1);

    let report = outcome.unwrap();
    assert_eq!(report.tokens_before, 1200);
    assert_eq!(report.tokens_after, 400);
}
/// With a custom counter, a history under budget yields no report and no change.
#[test]
fn custom_counter_no_compaction_when_under_budget() {
    let filler = "x".repeat(100);
    let mut history = vec![text_message(&filler), text_message(&filler)];
    let outcome = compact_sliding_window_with(&mut history, 500, 1, Some(&CharCounter));
    assert!(outcome.is_none());
    assert_eq!(history.len(), 2);
}
/// The counter-less entry point still compacts using the default estimate.
#[test]
fn compact_sliding_window_backward_compat() {
    let filler = "x".repeat(400);
    let mut history: Vec<AgentMessage> = (0..3).map(|_| text_message(&filler)).collect();
    let outcome = compact_sliding_window(&mut history, 250, 1);
    assert!(outcome.is_some());
    assert_eq!(history.len(), 2);
}
/// The report lists the messages that were removed from the history.
#[test]
fn compaction_report_includes_dropped_messages() {
    let filler = "x".repeat(400);
    let mut history: Vec<AgentMessage> = (0..4).map(|_| text_message(&filler)).collect();
    let report = compact_sliding_window_with(&mut history, 250, 1, None).unwrap();
    assert_eq!(report.dropped_count, 2);
    assert_eq!(report.dropped_messages.len(), 2);
    assert_eq!(history.len(), 2);
}
/// No report is produced when the history already fits the budget.
#[test]
fn compaction_report_dropped_messages_empty_when_no_compaction() {
    let mut history = vec![text_message("hello"), text_message("world")];
    let outcome = compact_sliding_window_with(&mut history, 10_000, 1, None);
    assert!(outcome.is_none());
}
/// Builds a test model spec whose capabilities declare a maximum context
/// window of `window`.
fn model_with_window(window: u64) -> crate::types::ModelSpec {
    let capabilities = crate::types::ModelCapabilities::none().with_max_context_window(window);
    crate::types::ModelSpec {
        provider: "test".into(),
        model_id: "test-model".into(),
        thinking_level: crate::types::ThinkingLevel::default(),
        thinking_budgets: None,
        provider_config: None,
        capabilities: Some(capabilities),
    }
}
/// Builds a test model spec that carries no capabilities at all.
fn model_no_window() -> crate::types::ModelSpec {
    let thinking_level = crate::types::ThinkingLevel::default();
    crate::types::ModelSpec {
        provider: "test".into(),
        model_id: "test-model".into(),
        thinking_level,
        thinking_budgets: None,
        provider_config: None,
        capabilities: None,
    }
}
/// A history estimated below the context window is not an overflow.
#[test]
fn overflow_within_budget_returns_false() {
    let history = vec![text_message(&"x".repeat(400))];
    let model = model_with_window(1000);
    assert!(!is_context_overflow(&history, &model, None));
}
/// A history estimated above the context window is an overflow.
#[test]
fn overflow_exceeding_budget_returns_true() {
    let filler = "x".repeat(400);
    let history = vec![text_message(&filler), text_message(&filler)];
    let model = model_with_window(150);
    assert!(is_context_overflow(&history, &model, None));
}
/// Without a declared context window, overflow is never reported.
#[test]
fn overflow_no_window_returns_false() {
    let history = vec![text_message(&"x".repeat(40_000))];
    let model = model_no_window();
    assert!(!is_context_overflow(&history, &model, None));
}
/// The overflow verdict depends on the counter: the per-byte test counter
/// overflows a 300-token window where the default estimate does not.
#[test]
fn overflow_custom_counter() {
    let history = vec![text_message(&"x".repeat(400))];

    let with_char_counter =
        is_context_overflow(&history, &model_with_window(300), Some(&CharCounter));
    assert!(with_char_counter);

    let with_default_counter = is_context_overflow(&history, &model_with_window(300), None);
    assert!(!with_default_counter);
}
/// An empty history never overflows.
#[test]
fn overflow_empty_messages_returns_false() {
    let history: Vec<AgentMessage> = Vec::new();
    let model = model_with_window(100);
    assert!(!is_context_overflow(&history, &model, None));
}
}