use crate::api::types::{ContentBlock, Message, MessageContent};
/// Size cap applied by `truncate_tool_output` to a single tool result.
/// NOTE(review): it is compared against `str::len`, i.e. UTF-8 *bytes*, despite the `_CHARS` suffix.
const TOOL_OUTPUT_MAX_CHARS: usize = 30_000;
use std::collections::HashSet;
use std::sync::LazyLock;
use tiktoken_rs::{cl100k_base, CoreBPE};
/// Process-wide `cl100k_base` encoder, constructed lazily on first use.
///
/// # Panics
/// The first access panics if the encoder cannot be built.
static TOKENIZER: LazyLock<CoreBPE> = LazyLock::new(|| {
    cl100k_base().expect("failed to initialize cl100k tokenizer")
});

/// Empty allow-list passed to `encode`: no special tokens are permitted, so
/// special-token strings in the input are presumably tokenized as plain text
/// (tiktoken-rs `encode` semantics — confirm if this matters).
static NO_SPECIAL: LazyLock<HashSet<&'static str>> = LazyLock::new(HashSet::new);

/// Returns how many cl100k tokens `text` encodes to.
fn count_tokens(text: &str) -> usize {
    let (token_ids, ..) = TOKENIZER.encode(text, &NO_SPECIAL);
    token_ids.len()
}
/// Estimates how many cl100k tokens `messages` occupy.
///
/// Plain-text messages count their text. For block messages:
/// - text blocks count their text;
/// - tool calls count the tool name plus the `to_string()` rendering of `input`;
/// - tool results count their `content`.
///
/// Roles and any per-message framing are not counted, so this is an
/// approximation rather than an exact API-side count.
pub fn estimate_tokens(messages: &[Message]) -> usize {
    messages
        .iter()
        .map(|message| match &message.content {
            MessageContent::Text(text) => count_tokens(text),
            MessageContent::Blocks(blocks) => blocks
                .iter()
                .map(|block| match block {
                    ContentBlock::Text { text } => count_tokens(text),
                    ContentBlock::ToolUse { name, input, .. } => {
                        count_tokens(name) + count_tokens(&input.to_string())
                    }
                    ContentBlock::ToolResult { content, .. } => count_tokens(content),
                })
                .sum::<usize>(),
        })
        .sum()
}
/// Caps a tool's output so a single result cannot flood the context window.
///
/// Outputs whose UTF-8 length is at most [`TOOL_OUTPUT_MAX_CHARS`] *bytes*
/// (the check uses `str::len`, despite the constant's name) are returned
/// unchanged. Longer outputs keep a head of about two thirds of that budget
/// and a tail of about one sixth. The two are joined by a marker stating how
/// many characters were dropped.
///
/// Returns the possibly-shortened text and `true` when anything was removed.
pub fn truncate_tool_output(output: &str) -> (String, bool) {
    if output.len() <= TOOL_OUTPUT_MAX_CHARS {
        return (output.to_string(), false);
    }
    let keep_start = TOOL_OUTPUT_MAX_CHARS * 2 / 3;
    let keep_end = TOOL_OUTPUT_MAX_CHARS / 6;
    // NOTE(review): these helpers are assumed to cut on char boundaries; the
    // multi-byte test in this file relies on that. Their exact semantics live
    // in crate::utils.
    let start = crate::utils::truncate_str(output, keep_start);
    let end = crate::utils::tail_str(output, keep_end);
    // The marker promises a *character* count, so count chars instead of
    // subtracting byte lengths. The byte difference overstates the loss for
    // multi-byte text by up to 4x. Saturating subtraction keeps this from
    // underflowing if the head and tail ever overlap.
    let truncated_chars = output
        .chars()
        .count()
        .saturating_sub(start.chars().count())
        .saturating_sub(end.chars().count());
    let result = format!("{start}\n\n... ({truncated_chars} characters truncated) ...\n\n{end}");
    (result, true)
}
/// Reports whether `msg` carries at least one `ToolResult` block.
///
/// Plain-text messages never do.
fn contains_tool_result(msg: &Message) -> bool {
    let MessageContent::Blocks(blocks) = &msg.content else {
        return false;
    };
    blocks
        .iter()
        .any(|block| matches!(block, ContentBlock::ToolResult { .. }))
}
pub fn snip_old_messages(messages: &[Message], keep_recent: usize) -> Option<Vec<Message>> {
if messages.len() <= keep_recent + 2 {
return None; }
let mut snip_count = messages.len() - keep_recent;
while snip_count > 0 && contains_tool_result(&messages[snip_count]) {
snip_count -= 1;
}
if snip_count == 0 {
return None; }
let snipped = &messages[..snip_count];
let kept = &messages[snip_count..];
let snip_tokens = estimate_tokens(snipped);
let marker = Message::user(&format!(
"[{snip_count} earlier messages snipped (~{snip_tokens} tokens). The conversation continues below.]"
));
let mut result = vec![marker];
result.extend_from_slice(kept);
Some(result)
}
/// Compaction recommendation produced by `should_compact`.
///
/// The variants are plain tags, so the enum is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompactStrategy {
    /// Estimated usage is at or below 60% of the context window.
    None,
    /// Estimated usage is above 60% but not above 80% of the context window.
    Snip,
    /// Estimated usage is above 80% of the context window.
    Summarize,
}
/// Chooses how aggressively to compact `messages`, given a window of
/// `context_window` tokens.
///
/// - Estimated usage strictly above 80% of the window → [`CompactStrategy::Summarize`].
/// - Strictly above 60% → [`CompactStrategy::Snip`].
/// - Otherwise → [`CompactStrategy::None`].
///
/// Both thresholds use integer division.
pub fn should_compact(messages: &[Message], context_window: usize) -> CompactStrategy {
    let used = estimate_tokens(messages);
    let snip_at = context_window * 60 / 100;
    let summarize_at = context_window * 80 / 100;
    match used {
        n if n > summarize_at => CompactStrategy::Summarize,
        n if n > snip_at => CompactStrategy::Snip,
        _ => CompactStrategy::None,
    }
}
/// Returns the context-window size, in tokens, assumed for `model`.
///
/// Entries are checked in order by substring, and the first match wins.
/// Unrecognised model names fall back to 128k.
pub fn context_window_for_model(model: &str) -> usize {
    const KNOWN: [(&str, usize); 6] = [
        ("opus", 200_000),
        ("sonnet", 200_000),
        ("haiku", 200_000),
        ("gpt-4o", 128_000),
        ("gpt-4", 128_000),
        ("gpt-3.5", 16_000),
    ];
    const FALLBACK: usize = 128_000;
    KNOWN
        .iter()
        .find(|(needle, _)| model.contains(*needle))
        .map_or(FALLBACK, |&(_, window)| window)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn estimate_tokens_empty() {
        assert_eq!(estimate_tokens(&[]), 0);
    }

    #[test]
    fn estimate_tokens_text() {
        let msgs = vec![Message::user("hello world")];
        let tokens = estimate_tokens(&msgs);
        // Loose bounds: the exact count depends on the cl100k vocabulary.
        assert!(tokens > 0);
        assert!(tokens < 10);
    }

    #[test]
    fn truncate_short_output_unchanged() {
        let (result, truncated) = truncate_tool_output("short");
        assert_eq!(result, "short");
        assert!(!truncated);
    }

    #[test]
    fn truncate_long_output() {
        let long = "x".repeat(50_000);
        let (result, truncated) = truncate_tool_output(&long);
        assert!(truncated);
        assert!(result.len() < long.len());
        assert!(result.contains("truncated"));
    }

    #[test]
    fn truncate_long_multibyte_output_no_panic() {
        // Each crab is 4 bytes in UTF-8, so this is 60_000 bytes: over the byte
        // limit, and any cut that is not on a char boundary would panic.
        let long = "🦀".repeat(15_000);
        let (result, truncated) = truncate_tool_output(&long);
        assert!(truncated);
        assert!(result.contains("truncated"));
        assert!(result.starts_with('🦀'));
        assert!(result.ends_with('🦀'));
    }

    #[test]
    fn snip_not_enough_messages() {
        let msgs = vec![Message::user("hi"), Message::assistant_text("hello")];
        assert!(snip_old_messages(&msgs, 5).is_none());
    }

    #[test]
    fn snip_keeps_recent() {
        let msgs: Vec<Message> = (0..20)
            .map(|i| Message::user(&format!("message {i}")))
            .collect();
        let result = snip_old_messages(&msgs, 5).unwrap();
        // The marker plus the 5 most recent messages.
        assert_eq!(result.len(), 6);
        // NOTE(review): this `if let` silently skips the check if the last
        // message is not `MessageContent::Text`.
        if let MessageContent::Text(text) = &result.last().unwrap().content {
            assert_eq!(text, "message 19");
        }
    }

    /// Panics if any tool_result refers to a tool_use id not seen in an earlier block.
    fn assert_no_orphaned_tool_results(messages: &[Message]) {
        let mut seen_tool_use_ids = std::collections::HashSet::new();
        for msg in messages {
            if let MessageContent::Blocks(blocks) = &msg.content {
                for block in blocks {
                    match block {
                        ContentBlock::ToolUse { id, .. } => {
                            seen_tool_use_ids.insert(id.clone());
                        }
                        ContentBlock::ToolResult { tool_use_id, .. } => {
                            assert!(
                                seen_tool_use_ids.contains(tool_use_id),
                                "orphaned tool_result: {tool_use_id}"
                            );
                        }
                        ContentBlock::Text { .. } => {}
                    }
                }
            }
        }
    }

    /// Builds one three-message exchange for round `n`:
    /// a user request, an assistant tool_use `tu_{n}`, and its tool_result.
    fn tool_round(n: usize) -> Vec<Message> {
        vec![
            Message::user(&format!("request {n}")),
            Message::assistant_blocks(vec![ContentBlock::ToolUse {
                id: format!("tu_{n}"),
                name: "Read".to_string(),
                input: serde_json::json!({"file_path": "/tmp/x"}),
            }]),
            Message::tool_results(vec![ContentBlock::ToolResult {
                tool_use_id: format!("tu_{n}"),
                content: "contents".to_string(),
                is_error: None,
            }]),
        ]
    }

    #[test]
    fn snip_never_orphans_tool_results() {
        // 18 messages (6 rounds). Try every keep_recent value from 1 to 17.
        let mut msgs: Vec<Message> = Vec::new();
        for n in 0..6 {
            msgs.extend(tool_round(n));
        }
        for keep_recent in 1..msgs.len() {
            if let Some(snipped) = snip_old_messages(&msgs, keep_recent) {
                assert_no_orphaned_tool_results(&snipped);
            }
        }
    }

    #[test]
    fn snip_backs_up_to_include_tool_use() {
        // With 18 messages and keep_recent = 4, the initial cut is index 14,
        // which is round 4's tool_result message. The cut backs up to index 13,
        // the assistant tool_use, so snipped[0] is the marker and snipped[1]
        // opens with that tool_use.
        let mut msgs: Vec<Message> = Vec::new();
        for n in 0..6 {
            msgs.extend(tool_round(n));
        }
        let snipped = snip_old_messages(&msgs, 4).unwrap();
        assert_no_orphaned_tool_results(&snipped);
        assert!(snipped.len() > 4);
        if let MessageContent::Blocks(blocks) = &snipped[1].content {
            assert!(matches!(blocks[0], ContentBlock::ToolUse { .. }));
        } else {
            panic!("expected the kept window to open with the tool_use message");
        }
    }

    #[test]
    fn snip_returns_none_when_no_safe_cut() {
        let mut msgs = vec![Message::user("start")];
        msgs.push(Message::assistant_blocks(vec![ContentBlock::ToolUse {
            id: "tu_0".to_string(),
            name: "Read".to_string(),
            input: serde_json::json!({}),
        }]));
        for _ in 0..8 {
            msgs.push(Message::tool_results(vec![ContentBlock::ToolResult {
                tool_use_id: "tu_0".to_string(),
                content: "x".to_string(),
                is_error: None,
            }]));
        }
        // Pass only the eight tool-result messages. Every candidate cut then
        // lands on a tool result, so no safe cut exists.
        let all_results: Vec<Message> = msgs[2..].to_vec();
        assert!(snip_old_messages(&all_results, 2).is_none());
    }

    #[test]
    fn should_compact_small_conversation() {
        let msgs = vec![Message::user("hi")];
        assert!(matches!(
            should_compact(&msgs, 200_000),
            CompactStrategy::None
        ));
    }

    #[test]
    fn context_window_known_models() {
        assert_eq!(
            context_window_for_model("claude-sonnet-4-20250514"),
            200_000
        );
        assert_eq!(context_window_for_model("gpt-4o"), 128_000);
    }
}