use crate::api::types::{ContentBlock, Message, MessageContent};
const TOOL_OUTPUT_MAX_CHARS: usize = 30_000;
use std::collections::HashSet;
use std::sync::LazyLock;
use tiktoken_rs::{cl100k_base, CoreBPE};
static TOKENIZER: LazyLock<CoreBPE> =
LazyLock::new(|| cl100k_base().expect("failed to initialize cl100k tokenizer"));
static NO_SPECIAL: LazyLock<HashSet<&'static str>> = LazyLock::new(HashSet::new);
fn count_tokens(text: &str) -> usize {
TOKENIZER.encode(text, &NO_SPECIAL).0.len()
}
pub fn estimate_tokens(messages: &[Message]) -> usize {
let mut total = 0;
for msg in messages {
match &msg.content {
MessageContent::Text(text) => {
total += count_tokens(text);
}
MessageContent::Blocks(blocks) => {
for block in blocks {
match block {
ContentBlock::Text { text } => {
total += count_tokens(text);
}
ContentBlock::ToolUse { input, name, .. } => {
total += count_tokens(name);
total += count_tokens(&input.to_string());
}
ContentBlock::ToolResult { content, .. } => {
total += count_tokens(content);
}
}
}
}
}
}
total
}
pub fn truncate_tool_output(output: &str) -> (String, bool) {
if output.len() <= TOOL_OUTPUT_MAX_CHARS {
return (output.to_string(), false);
}
let keep_start = TOOL_OUTPUT_MAX_CHARS * 2 / 3;
let keep_end = TOOL_OUTPUT_MAX_CHARS / 6;
let start = &output[..keep_start];
let end = &output[output.len() - keep_end..];
let truncated_chars = output.len() - keep_start - keep_end;
let result = format!("{start}\n\n... ({truncated_chars} characters truncated) ...\n\n{end}");
(result, true)
}
pub fn snip_old_messages(messages: &[Message], keep_recent: usize) -> Option<Vec<Message>> {
if messages.len() <= keep_recent + 2 {
return None; }
let snip_count = messages.len() - keep_recent;
let snipped = &messages[..snip_count];
let kept = &messages[snip_count..];
let snip_tokens = estimate_tokens(snipped);
let marker = Message::user(&format!(
"[{snip_count} earlier messages snipped (~{snip_tokens} tokens). The conversation continues below.]"
));
let mut result = vec![marker];
result.extend_from_slice(kept);
Some(result)
}
pub enum CompactStrategy {
None,
Snip,
Summarize,
}
pub fn should_compact(messages: &[Message], context_window: usize) -> CompactStrategy {
let tokens = estimate_tokens(messages);
let threshold_snip = context_window * 60 / 100; let threshold_summarize = context_window * 80 / 100;
if tokens > threshold_summarize {
CompactStrategy::Summarize
} else if tokens > threshold_snip {
CompactStrategy::Snip
} else {
CompactStrategy::None
}
}
pub fn context_window_for_model(model: &str) -> usize {
if model.contains("opus") {
200_000
} else if model.contains("sonnet") {
200_000
} else if model.contains("haiku") {
200_000
} else if model.contains("gpt-4o") {
128_000
} else if model.contains("gpt-4") {
128_000
} else if model.contains("gpt-3.5") {
16_000
} else {
128_000
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn estimate_tokens_empty() {
assert_eq!(estimate_tokens(&[]), 0);
}
#[test]
fn estimate_tokens_text() {
let msgs = vec![Message::user("hello world")]; let tokens = estimate_tokens(&msgs);
assert!(tokens > 0);
assert!(tokens < 10);
}
#[test]
fn truncate_short_output_unchanged() {
let (result, truncated) = truncate_tool_output("short");
assert_eq!(result, "short");
assert!(!truncated);
}
#[test]
fn truncate_long_output() {
let long = "x".repeat(50_000);
let (result, truncated) = truncate_tool_output(&long);
assert!(truncated);
assert!(result.len() < long.len());
assert!(result.contains("truncated"));
}
#[test]
fn snip_not_enough_messages() {
let msgs = vec![Message::user("hi"), Message::assistant_text("hello")];
assert!(snip_old_messages(&msgs, 5).is_none());
}
#[test]
fn snip_keeps_recent() {
let msgs: Vec<Message> = (0..20)
.map(|i| Message::user(&format!("message {}", i)))
.collect();
let result = snip_old_messages(&msgs, 5).unwrap();
assert_eq!(result.len(), 6);
if let MessageContent::Text(text) = &result.last().unwrap().content {
assert_eq!(text, "message 19");
}
}
#[test]
fn should_compact_small_conversation() {
let msgs = vec![Message::user("hi")];
assert!(matches!(
should_compact(&msgs, 200_000),
CompactStrategy::None
));
}
#[test]
fn context_window_known_models() {
assert_eq!(
context_window_for_model("claude-sonnet-4-20250514"),
200_000
);
assert_eq!(context_window_for_model("gpt-4o"), 128_000);
}
}