use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use crate::compaction::estimate_input_tokens_conservative;
use crate::models::{Message, SystemPrompt};
/// Maximum number of `(revision, tokens)` entries kept in the cache's audit
/// ring; once this many are stored, `push_audit` drops the oldest entry
/// before appending a new one.
const AUDIT_RING_CAPACITY: usize = 64;
/// Single-entry memo of the most recently computed input-token estimate.
///
/// The stored estimate is keyed by a caller-supplied message revision together
/// with a fingerprint of the system prompt. Alongside the memo the struct keeps
/// hit/miss counters and a bounded log of recently computed estimates.
#[derive(Clone, Debug, Default)]
pub struct TokenEstimateCache {
    /// Message revision the cache is currently keyed on.
    messages_revision: u64,
    /// System-prompt fingerprint the cache is currently keyed on.
    system_fingerprint: u64,
    /// The memoised estimate; `None` while nothing valid is cached.
    cached_tokens: Option<usize>,
    /// `(revision, tokens)` pairs recorded on each recomputation, oldest first.
    audit_ring: Vec<(u64, usize)>,
    /// Number of lookups answered from the memo.
    hits: u64,
    /// Number of lookups that had to recompute the estimate.
    misses: u64,
}
impl TokenEstimateCache {
    /// Creates an empty cache: nothing memoised, zeroed counters, empty audit log.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the token estimate for `messages` plus `system_prompt`, reusing
    /// the memoised value when possible.
    ///
    /// A lookup is a hit when a value is cached and both `messages_revision`
    /// and the system-prompt fingerprint equal the ones the cache is keyed on.
    /// The contents of `messages` are never compared, so the caller must
    /// supply a different revision whenever the messages change.
    ///
    /// On a miss the estimate is recomputed with
    /// `estimate_input_tokens_conservative`, stored under the new key, counted
    /// as a miss and appended to the audit ring.
    pub fn lookup_or_compute(
        &mut self,
        messages_revision: u64,
        system_prompt: Option<&SystemPrompt>,
        messages: &[Message],
    ) -> usize {
        let fingerprint = fingerprint_system_prompt(system_prompt);
        let key_matches = messages_revision == self.messages_revision
            && fingerprint == self.system_fingerprint;

        match self.cached_tokens {
            Some(cached) if key_matches => {
                self.hits = self.hits.saturating_add(1);
                cached
            }
            _ => {
                let fresh = estimate_input_tokens_conservative(messages, system_prompt);
                self.cached_tokens = Some(fresh);
                self.system_fingerprint = fingerprint;
                self.messages_revision = messages_revision;
                self.misses = self.misses.saturating_add(1);
                self.push_audit(messages_revision, fresh);
                fresh
            }
        }
    }

    /// Advances the stored revision to `revision` and drops the memoised value.
    ///
    /// Does nothing unless `revision` is strictly greater than the stored one.
    // NOTE(review): the `allow` suggests this has no non-test caller yet — confirm.
    #[allow(dead_code)]
    pub fn bump_messages_revision(&mut self, revision: u64) {
        if revision <= self.messages_revision {
            return;
        }
        self.cached_tokens = None;
        self.messages_revision = revision;
    }

    /// Drops the memoised value and also resets the system fingerprint, the
    /// audit ring and both counters. The stored message revision is left as is.
    // NOTE(review): the `allow` suggests this has no non-test caller yet — confirm.
    #[allow(dead_code)]
    pub fn invalidate(&mut self) {
        self.hits = 0;
        self.misses = 0;
        self.audit_ring.clear();
        self.system_fingerprint = 0;
        self.cached_tokens = None;
    }

    /// Returns the counters as `(hits, misses)`.
    #[allow(dead_code)]
    #[must_use]
    pub fn stats(&self) -> (u64, u64) {
        let Self { hits, misses, .. } = self;
        (*hits, *misses)
    }

    /// Returns the recorded `(revision, tokens)` pairs, oldest first.
    #[allow(dead_code)]
    #[must_use]
    pub fn recent_audit(&self) -> &[(u64, usize)] {
        self.audit_ring.as_slice()
    }

    /// Appends `(revision, tokens)` to the audit ring, first discarding the
    /// oldest entry when the ring already holds `AUDIT_RING_CAPACITY` entries.
    fn push_audit(&mut self, revision: u64, tokens: usize) {
        let ring = &mut self.audit_ring;
        if ring.len() >= AUDIT_RING_CAPACITY {
            ring.remove(0);
        }
        ring.push((revision, tokens));
    }
}
/// Reduces an optional system prompt to a `u64` used as part of the cache key.
///
/// An absent prompt maps to `0`. Otherwise a tag naming the variant is hashed
/// first, followed by the prompt content: the string for `Text`, or the block
/// count and then each block's `block_type` and `text` for `Blocks`. No other
/// block fields are fed into the hash.
fn fingerprint_system_prompt(system: Option<&SystemPrompt>) -> u64 {
    let mut state = DefaultHasher::new();
    match system {
        None => return 0,
        Some(SystemPrompt::Text(body)) => {
            "text".hash(&mut state);
            body.hash(&mut state);
        }
        Some(SystemPrompt::Blocks(items)) => {
            "blocks".hash(&mut state);
            // Hash the count first so the block boundaries are part of the key.
            items.len().hash(&mut state);
            for item in items {
                item.block_type.hash(&mut state);
                item.text.hash(&mut state);
            }
        }
    }
    state.finish()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{ContentBlock, SystemBlock};

    /// How many entries beyond the ring capacity the overflow test records.
    const OVERFLOW: usize = 10;

    /// Builds a `user` message containing a single text block with `s`.
    fn user_text(s: &str) -> Message {
        Message {
            role: "user".to_string(),
            content: vec![ContentBlock::Text {
                text: s.to_string(),
                cache_control: None,
            }],
        }
    }

    /// Builds a plain-text system prompt.
    fn sys_text(s: &str) -> SystemPrompt {
        SystemPrompt::Text(s.to_string())
    }

    /// Builds a block-style system prompt holding one `text` block.
    fn sys_block(s: &str) -> SystemPrompt {
        SystemPrompt::Blocks(vec![SystemBlock {
            block_type: "text".to_string(),
            text: s.to_string(),
            cache_control: None,
        }])
    }

    /// A fresh cache has nothing memoised, so the first lookup must compute.
    #[test]
    fn first_call_is_a_miss() {
        let mut cache = TokenEstimateCache::new();
        let messages = vec![user_text("hello world")];
        let tokens = cache.lookup_or_compute(1, None, &messages);
        let (hits, misses) = cache.stats();
        assert!(tokens > 0);
        assert_eq!(hits, 0);
        assert_eq!(misses, 1);
    }

    /// Same revision and same (absent) system prompt: the second lookup hits
    /// and returns the same value as the first.
    #[test]
    fn repeated_call_with_same_revision_is_a_hit() {
        let mut cache = TokenEstimateCache::new();
        let messages = vec![user_text("hello world")];
        let first = cache.lookup_or_compute(1, None, &messages);
        let second = cache.lookup_or_compute(1, None, &messages);
        let (hits, misses) = cache.stats();
        assert_eq!(first, second);
        assert_eq!(hits, 1);
        assert_eq!(misses, 1);
    }

    /// Passing a different revision forces a recomputation.
    #[test]
    fn revision_bump_invalidates() {
        let mut cache = TokenEstimateCache::new();
        let messages = vec![user_text("hi")];
        let a = cache.lookup_or_compute(1, None, &messages);
        let b = cache.lookup_or_compute(2, None, &messages);
        let (hits, misses) = cache.stats();
        assert_eq!(a, b);
        assert_eq!(hits, 0);
        assert_eq!(misses, 2);
    }

    /// A different text system prompt under the same revision is a miss.
    #[test]
    fn system_prompt_change_invalidates() {
        let mut cache = TokenEstimateCache::new();
        let messages = vec![user_text("hi")];
        let _ = cache.lookup_or_compute(1, Some(&sys_text("alpha")), &messages);
        let _ = cache.lookup_or_compute(1, Some(&sys_text("beta")), &messages);
        let (hits, misses) = cache.stats();
        assert_eq!(hits, 0);
        assert_eq!(misses, 2);
    }

    /// After `bump_messages_revision(2)` the stored revision already equals 2,
    /// so the lookup at revision 2 can only miss if the memo was cleared.
    #[test]
    fn bump_messages_revision_clears_cache() {
        let mut cache = TokenEstimateCache::new();
        let messages = vec![user_text("x")];
        let _ = cache.lookup_or_compute(1, None, &messages);
        cache.bump_messages_revision(2);
        let _ = cache.lookup_or_compute(2, None, &messages);
        let (hits, misses) = cache.stats();
        assert_eq!(hits, 0);
        assert_eq!(misses, 2);
    }

    /// Bumping to a lower revision must leave the memo intact.
    #[test]
    fn bump_to_smaller_revision_is_noop() {
        let mut cache = TokenEstimateCache::new();
        let messages = vec![user_text("x")];
        let _ = cache.lookup_or_compute(5, None, &messages);
        cache.bump_messages_revision(2);
        let _ = cache.lookup_or_compute(5, None, &messages);
        let (hits, _) = cache.stats();
        assert_eq!(hits, 1, "downward revision bumps must not invalidate");
    }

    /// `invalidate` zeroes the counters, empties the audit ring and drops the
    /// memo, so the next lookup at the previously cached key is a miss.
    #[test]
    fn invalidate_resets_state() {
        let mut cache = TokenEstimateCache::new();
        let messages = vec![user_text("x")];
        let _ = cache.lookup_or_compute(1, None, &messages);
        let _ = cache.lookup_or_compute(1, None, &messages);
        cache.invalidate();
        let (hits, misses) = cache.stats();
        assert_eq!(hits, 0);
        assert_eq!(misses, 0);
        assert!(cache.recent_audit().is_empty());

        let _ = cache.lookup_or_compute(1, None, &messages);
        assert_eq!(cache.stats(), (0, 1));
        assert_eq!(cache.recent_audit().len(), 1);
    }

    /// Block-style prompts that differ only in text must not share a cache key.
    #[test]
    fn blocks_system_prompt_yields_distinct_fingerprint() {
        let blocks_a = sys_block("alpha");
        let blocks_b = sys_block("beta");
        let mut cache = TokenEstimateCache::new();
        let messages = vec![user_text("hi")];
        let _ = cache.lookup_or_compute(1, Some(&blocks_a), &messages);
        let _ = cache.lookup_or_compute(1, Some(&blocks_b), &messages);
        let (hits, misses) = cache.stats();
        assert_eq!(hits, 0);
        assert_eq!(misses, 2);
    }

    /// Every miss is recorded, in order, with the revision it was computed for
    /// and the value that the lookup returned.
    #[test]
    fn audit_ring_records_recent_pairs() {
        let mut cache = TokenEstimateCache::new();
        let messages = vec![user_text("hi")];
        let mut expected = Vec::new();
        for rev in 1..=5u64 {
            let tokens = cache.lookup_or_compute(rev, None, &messages);
            expected.push((rev, tokens));
        }
        // Comparing against the returned values (instead of the ring against
        // itself) checks both halves of every recorded pair.
        assert_eq!(cache.recent_audit(), expected.as_slice());
    }

    /// The ring never exceeds its capacity and evicts the oldest entries first.
    #[test]
    fn audit_ring_bounded_by_capacity() {
        let mut cache = TokenEstimateCache::new();
        let messages = vec![user_text("hi")];
        let total = u64::try_from(AUDIT_RING_CAPACITY + OVERFLOW).expect("count fits in u64");
        let oldest_kept = u64::try_from(OVERFLOW + 1).expect("count fits in u64");
        for rev in 1..=total {
            let _ = cache.lookup_or_compute(rev, None, &messages);
        }
        let ring = cache.recent_audit();
        assert_eq!(ring.len(), AUDIT_RING_CAPACITY);
        assert_eq!(ring.first().map(|&(rev, _)| rev), Some(oldest_kept));
        assert_eq!(ring.last().map(|&(rev, _)| rev), Some(total));
    }
}