use std::sync::OnceLock;
use tiktoken_rs::CoreBPE;
use crate::intelligence::compression::{detect_encoding, TokenEncoding};
use crate::intelligence::context_builder::TokenCounter;
/// Returns the process-wide `cl100k_base` tokenizer, constructing it on first use.
///
/// # Panics
///
/// Panics if `tiktoken_rs` cannot build the `cl100k_base` encoding.
fn cl100k() -> &'static CoreBPE {
    // Runs at most once; every later call reuses the cached tables.
    fn build() -> CoreBPE {
        tiktoken_rs::cl100k_base().expect("cl100k_base init")
    }

    static CL100K_BPE: OnceLock<CoreBPE> = OnceLock::new();
    CL100K_BPE.get_or_init(build)
}
/// Returns the process-wide `o200k_base` tokenizer, constructing it on first use.
///
/// # Panics
///
/// Panics if `tiktoken_rs` cannot build the `o200k_base` encoding.
fn o200k() -> &'static CoreBPE {
    // Runs at most once; every later call reuses the cached tables.
    fn build() -> CoreBPE {
        tiktoken_rs::o200k_base().expect("o200k_base init")
    }

    static O200K_BPE: OnceLock<CoreBPE> = OnceLock::new();
    O200K_BPE.get_or_init(build)
}
/// Token counter backed by a `tiktoken_rs` BPE encoding.
///
/// The BPE tables themselves live in process-wide `OnceLock` statics (`cl100k` / `o200k`),
/// so a counter only records which encoding to use.
pub struct TiktokenCounter {
    /// Encoding that encode/decode dispatch on.
    encoding: TokenEncoding,
}
impl TiktokenCounter {
pub fn new(encoding: TokenEncoding) -> Self {
Self { encoding }
}
pub fn for_model(model: &str) -> Option<Self> {
detect_encoding(model).map(Self::new)
}
pub fn with_fallback(model: &str) -> Self {
let encoding = detect_encoding(model).unwrap_or(TokenEncoding::Cl100kBase);
Self::new(encoding)
}
pub fn encoding_name(&self) -> &'static str {
self.encoding.as_str()
}
pub(crate) fn encode(&self, text: &str) -> Vec<usize> {
let bpe = match self.encoding {
TokenEncoding::Cl100kBase => cl100k(),
TokenEncoding::O200kBase => o200k(),
};
bpe.encode_with_special_tokens(text)
}
fn decode(&self, ids: &[usize]) -> String {
let bpe = match self.encoding {
TokenEncoding::Cl100kBase => cl100k(),
TokenEncoding::O200kBase => o200k(),
};
bpe.decode(ids.to_vec()).unwrap_or_default()
}
}
impl TokenCounter for TiktokenCounter {
    /// Counts tokens by fully encoding `text` with this counter's BPE.
    fn count_tokens(&self, text: &str) -> usize {
        let token_ids = self.encode(text);
        token_ids.len()
    }
}
/// One token window of a larger text, as produced by `TokenChunker::chunk`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TextChunk {
    /// Decoded text of the window's tokens.
    pub text: String,
    /// Byte offset (not a `char` index) of the window's start in the source text.
    pub start_char: usize,
    /// Byte offset one past the window's end; equals `start_char + text.len()`.
    pub end_char: usize,
    /// Number of tokens in the window.
    pub token_count: usize,
}
/// Splits text into fixed-size, overlapping windows measured in tokens.
pub struct TokenChunker {
    /// Tokenizer used to encode the input and decode each window.
    counter: TiktokenCounter,
    /// Maximum number of tokens per chunk.
    chunk_size: usize,
    /// Tokens shared by consecutive chunks; always `< chunk_size` (asserted in `new`).
    chunk_overlap: usize,
}
impl TokenChunker {
    /// Creates a chunker producing windows of at most `chunk_size` tokens, where
    /// consecutive windows share `chunk_overlap` tokens.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_overlap >= chunk_size` (this includes `chunk_size == 0`),
    /// because the window would otherwise never advance.
    pub fn new(counter: TiktokenCounter, chunk_size: usize, chunk_overlap: usize) -> Self {
        assert!(
            chunk_overlap < chunk_size,
            "chunk_overlap must be < chunk_size"
        );
        Self {
            counter,
            chunk_size,
            chunk_overlap,
        }
    }

    /// Splits `text` into overlapping token windows.
    ///
    /// Windows start every `chunk_size - chunk_overlap` tokens and hold up to
    /// `chunk_size` tokens; the last window ends at the final token. Returns an
    /// empty `Vec` when `text` encodes to no tokens.
    ///
    /// `start_char` is the byte length of the decoded tokens preceding the
    /// window, and `end_char` adds the byte length of the window's decoded text.
    ///
    /// NOTE(review): if a window boundary falls inside a multi-byte UTF-8
    /// character, decoding fails and yields `""`. The affected chunk then has
    /// empty `text` and/or `start_char == 0`. Snapping boundaries to character
    /// boundaries would fix this but changes chunk contents.
    pub fn chunk(&self, text: &str) -> Vec<TextChunk> {
        let ids = self.counter.encode(text);
        if ids.is_empty() {
            return Vec::new();
        }
        // `new` guarantees chunk_overlap < chunk_size, so step >= 1.
        let step = self.chunk_size - self.chunk_overlap;
        let mut chunks = Vec::with_capacity(ids.len().div_ceil(step));

        // `anchor` is the latest token boundary whose prefix `ids[..anchor]`
        // decoded cleanly, and `anchor_byte` is that prefix's byte length.
        // Offsets for later windows extend it by decoding only the tokens added
        // since the anchor. This avoids re-decoding the whole prefix for every
        // window, which is quadratic in the document length.
        let mut anchor = 0usize;
        let mut anchor_byte = 0usize;
        let mut start = 0usize;
        while start < ids.len() {
            if start > anchor {
                let segment = self.counter.decode(&ids[anchor..start]);
                // A non-empty run of tokens always decodes to at least one byte,
                // so "" means decoding failed. In that case `start` splits a UTF-8
                // character, and so does the full prefix `ids[..start]`, because the
                // prefix up to `anchor` is valid.
                if !segment.is_empty() {
                    anchor_byte += segment.len();
                    anchor = start;
                }
            }
            // Matches decoding the full prefix: its byte length when valid, else 0.
            let start_char = if anchor == start { anchor_byte } else { 0 };

            let end = (start + self.chunk_size).min(ids.len());
            let chunk_ids = &ids[start..end];
            let chunk_text = self.counter.decode(chunk_ids);
            let end_char = start_char + chunk_text.len();
            chunks.push(TextChunk {
                text: chunk_text,
                start_char,
                end_char,
                token_count: chunk_ids.len(),
            });
            if end == ids.len() {
                break;
            }
            start += step;
        }
        chunks
    }
}
/// Token accounting for a piece of text before and after preparation/compression.
#[derive(Debug, Clone, PartialEq)]
pub struct TokenBudgetResult {
    /// Token count of the original text.
    pub original_token_count: usize,
    /// Token count of the prepared (e.g. compressed) text.
    pub prepared_token_count: usize,
    /// `prepared_token_count / original_token_count`, or `1.0` when the original is empty.
    pub compression_ratio: f64,
    /// Caller-supplied compression level label, if any (e.g. `"Medium"`).
    pub compression_level: Option<String>,
    /// Identifies how tokens were counted: `"chars/4"` or a tiktoken encoding name.
    pub tokenizer_id: String,
}
impl TokenBudgetResult {
    /// Builds a result using the rough "4 units per token" heuristic.
    ///
    /// Each length is converted with `len.div_ceil(4)`. The lengths are presumably
    /// characters or bytes, per the `"chars/4"` tokenizer id recorded here.
    pub fn from_heuristic(original_len: usize, prepared_len: usize, level: Option<&str>) -> Self {
        Self::from_counts(
            original_len.div_ceil(4),
            prepared_len.div_ceil(4),
            level,
            "chars/4",
        )
    }

    /// Builds a result by counting real tokens of `original` and `prepared` with
    /// `counter`; `tokenizer_id` is the counter's encoding name.
    pub fn from_tiktoken(
        original: &str,
        prepared: &str,
        level: Option<&str>,
        counter: &TiktokenCounter,
    ) -> Self {
        Self::from_counts(
            counter.count_tokens(original),
            counter.count_tokens(prepared),
            level,
            counter.encoding_name(),
        )
    }

    /// Assembles a result from precomputed token counts (shared by both constructors).
    fn from_counts(orig: usize, prep: usize, level: Option<&str>, tokenizer_id: &str) -> Self {
        // An empty original has nothing to compress: report "unchanged" instead of dividing by zero.
        let compression_ratio = if orig == 0 {
            1.0
        } else {
            prep as f64 / orig as f64
        };
        Self {
            original_token_count: orig,
            prepared_token_count: prep,
            compression_ratio,
            compression_level: level.map(str::to_string),
            tokenizer_id: tokenizer_id.to_string(),
        }
    }
}
/// Newtype around [`TiktokenCounter`] that implements `TokenCounter` by
/// delegating to the wrapped counter.
pub struct TiktokenTokenCounter(pub TiktokenCounter);

impl TokenCounter for TiktokenTokenCounter {
    /// Forwards to the inner counter's `count_tokens`.
    fn count_tokens(&self, text: &str) -> usize {
        let Self(inner) = self;
        inner.count_tokens(text)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intelligence::context_builder::ContextBuilder;
    use crate::intelligence::context_compression::{ContextCompressor, MemoryInput};

    #[test]
    fn test_for_model_known_counts_tokens() {
        let counter = TiktokenCounter::for_model("gpt-4").expect("gpt-4 should be known");
        let count = counter.count_tokens("hello world");
        assert!(count > 0, "should count at least 1 token");
        assert!(count <= 5, "hello world should be a small number of tokens");
    }

    #[test]
    fn test_for_model_unknown_returns_none() {
        let result = TiktokenCounter::for_model("totally-unknown-model-xyz");
        assert!(result.is_none());
    }

    #[test]
    fn test_with_fallback_unknown_still_counts() {
        let counter = TiktokenCounter::with_fallback("totally-unknown-model-xyz");
        assert_eq!(counter.encoding_name(), "cl100k_base");
        let count = counter.count_tokens("hello world");
        assert!(count > 0);
    }

    #[test]
    fn test_chunker_respects_chunk_size() {
        let counter = TiktokenCounter::with_fallback("claude");
        let chunk_size = 10;
        let chunker = TokenChunker::new(counter, chunk_size, 2);
        let word = "word ";
        let text = word.repeat(100);
        let chunks = chunker.chunk(&text);
        assert!(!chunks.is_empty());
        for chunk in &chunks {
            assert!(
                chunk.token_count <= chunk_size,
                "chunk has {} tokens, expected <= {}",
                chunk.token_count,
                chunk_size
            );
        }
    }

    #[test]
    fn test_chunker_overlap() {
        let text = "the quick brown fox jumps over the lazy dog and then runs away";
        let chunk_size = 5;
        let overlap = 2;
        let counter = TiktokenCounter::with_fallback("claude");
        let full_ids = counter.encode(text);
        let chunker = TokenChunker::new(
            TiktokenCounter::with_fallback("claude"),
            chunk_size,
            overlap,
        );
        let chunks = chunker.chunk(text);
        // The sentence is well over `chunk_size` tokens. Requiring two windows
        // keeps the overlap check below from being skipped silently.
        assert!(
            chunks.len() >= 2,
            "expected several chunks for {} tokens, got {}",
            full_ids.len(),
            chunks.len()
        );
        let c0_ids = counter.encode(&chunks[0].text);
        let c1_ids = counter.encode(&chunks[1].text);
        let tail = &c0_ids[c0_ids.len().saturating_sub(overlap)..];
        let head = &c1_ids[..overlap.min(c1_ids.len())];
        assert_eq!(
            tail, head,
            "overlap tokens should match between consecutive chunks"
        );
        // The second window starts `chunk_size - overlap` tokens into the full encoding.
        let step = chunk_size - overlap;
        let second_end = (step + chunk_size).min(full_ids.len());
        assert_eq!(chunks[1].text, counter.decode(&full_ids[step..second_end]));
    }

    #[test]
    fn test_chunk_byte_offsets() {
        let counter = TiktokenCounter::with_fallback("gpt-4");
        let chunk_size = 3;
        let chunker = TokenChunker::new(counter, chunk_size, 0);
        let text = "hello world foo bar";
        let chunks = chunker.chunk(text);
        for chunk in &chunks {
            let slice = &text[chunk.start_char..chunk.end_char];
            assert_eq!(
                slice, chunk.text,
                "byte range should match decoded chunk text"
            );
        }
    }

    #[test]
    fn test_chunk_byte_offsets_with_overlap() {
        let text = "alpha beta gamma delta epsilon zeta eta theta iota kappa lambda";
        let chunker = TokenChunker::new(TiktokenCounter::with_fallback("gpt-4"), 4, 2);
        let chunks = chunker.chunk(text);
        assert!(chunks.len() > 1, "text should span several chunks");
        assert_eq!(chunks[0].start_char, 0);
        assert_eq!(chunks.last().map(|c| c.end_char), Some(text.len()));
        for chunk in &chunks {
            assert_eq!(
                &text[chunk.start_char..chunk.end_char],
                chunk.text,
                "byte range should match decoded chunk text"
            );
        }
    }

    #[test]
    fn test_chunk_offsets_match_prefix_decoding() {
        // Non-ASCII input, so some window boundaries may split multi-byte characters.
        let text = "naïve café, 日本語のテキスト, emoji 🎉🦀 and plain ascii words";
        let counter = TiktokenCounter::with_fallback("gpt-4");
        let ids = counter.encode(text);
        let (chunk_size, overlap) = (3, 1);
        let step = chunk_size - overlap;
        let chunker = TokenChunker::new(
            TiktokenCounter::with_fallback("gpt-4"),
            chunk_size,
            overlap,
        );
        let chunks = chunker.chunk(text);
        assert!(!chunks.is_empty());
        for (i, chunk) in chunks.iter().enumerate() {
            let start = i * step;
            assert_eq!(
                chunk.start_char,
                counter.decode(&ids[..start]).len(),
                "chunk {i}: start offset should equal decoded prefix length"
            );
            assert_eq!(chunk.end_char, chunk.start_char + chunk.text.len());
        }
    }

    #[test]
    fn test_compressor_with_token_counter() {
        let counter = TiktokenCounter::with_fallback("claude");
        let mut compressor = ContextCompressor::with_token_counter(1000, counter);
        let memories = vec![MemoryInput {
            id: 1,
            content: "This is a test memory with several words.".to_string(),
            importance: 1.0,
        }];
        let plan = compressor.compress_for_context_with_diagnostics(&memories);
        assert_eq!(plan.entries.len(), 1);
        assert_eq!(plan.entries[0].tokenizer_id, "cl100k_base");
    }

    #[test]
    fn test_token_budget_result_tiktoken() {
        let counter = TiktokenCounter::with_fallback("gpt-4");
        let original = "This is a fairly long sentence with many words in it.";
        let prepared = "long sentence many words";
        let result = TokenBudgetResult::from_tiktoken(original, prepared, Some("Medium"), &counter);
        assert!(result.original_token_count > result.prepared_token_count);
        assert!(result.compression_ratio < 1.0);
        assert_eq!(result.compression_level.as_deref(), Some("Medium"));
        assert_eq!(result.tokenizer_id, "cl100k_base");
    }

    #[test]
    fn test_context_builder_with_tiktoken() {
        let builder = ContextBuilder::with_tiktoken("gpt-4");
        let count = builder.estimate_tokens("hello world foo bar baz qux quux corge grault");
        assert!(count > 0);
        assert!(count < 30, "count should be reasonable, got {}", count);
    }
}