use thiserror::Error;
use tiktoken_rs::{cl100k_base, CoreBPE};
/// Errors that can occur while constructing a [`TokenCounter`].
#[derive(Debug, Error)]
pub enum TokenCounterError {
/// The `cl100k_base` tokenizer data failed to load; carries the
/// underlying error rendered as a string.
#[error("Failed to initialize tokenizer: {0}")]
InitError(String),
}
/// Token counter backed by the OpenAI `cl100k_base` byte-pair encoder.
///
/// All methods take `&self`; encoding through the wrapped [`CoreBPE`] is
/// side-effect free, so a single instance can be shared freely.
pub struct TokenCounter {
// BPE instance loaded once in `new` and reused for every operation.
bpe: CoreBPE,
}
impl TokenCounter {
    /// Creates a counter backed by the `cl100k_base` tokenizer.
    ///
    /// # Errors
    ///
    /// Returns [`TokenCounterError::InitError`] if the tokenizer data
    /// cannot be loaded.
    pub fn new() -> Result<Self, TokenCounterError> {
        let bpe = cl100k_base().map_err(|e| TokenCounterError::InitError(e.to_string()))?;
        Ok(Self { bpe })
    }

    /// Returns the number of tokens in `text`.
    ///
    /// Special tokens in the input are encoded as their token ids rather
    /// than escaped.
    pub fn count(&self, text: &str) -> usize {
        self.bpe.encode_with_special_tokens(text).len()
    }

    /// Token count for a single chat message: content tokens plus a fixed
    /// per-message overhead for the role/formatting wrapper.
    ///
    /// `"function"` and `"tool"` roles carry one extra token; every other
    /// role (including unknown ones) costs 4 — same values as before, with
    /// the identical arms collapsed.
    pub fn count_message(&self, role: &str, content: &str) -> usize {
        let role_overhead = match role {
            "function" | "tool" => 5,
            // system / user / assistant and any unrecognized role.
            _ => 4,
        };
        self.count(content) + role_overhead
    }

    /// Total token count for a conversation of `(role, content)` pairs,
    /// plus 3 tokens of reply priming overhead.
    pub fn count_messages(&self, messages: &[(String, String)]) -> usize {
        let body: usize = messages
            .iter()
            .map(|(role, content)| self.count_message(role, content))
            .sum();
        body + 3
    }

    /// Truncates `text` to at most `max_tokens` tokens.
    ///
    /// Returns the truncated string together with its actual token count
    /// (which may be slightly below `max_tokens`, see below). If the text
    /// already fits, it is returned unchanged with its real count.
    ///
    /// cl100k is a byte-level BPE, so cutting at an arbitrary token
    /// boundary can split a multi-byte UTF-8 sequence and make `decode`
    /// fail. In that case we back off one token at a time until the prefix
    /// decodes cleanly, instead of slicing the original string at a raw
    /// byte offset (which could panic on a non-character boundary).
    pub fn truncate(&self, text: &str, max_tokens: usize) -> (String, usize) {
        let tokens = self.bpe.encode_with_special_tokens(text);
        if tokens.len() <= max_tokens {
            return (text.to_string(), tokens.len());
        }
        let mut end = max_tokens;
        loop {
            match self.bpe.decode(tokens[..end].to_vec()) {
                Ok(truncated) => return (truncated, end),
                Err(_) if end > 0 => end -= 1,
                // Nothing decodable at all: return an empty result rather
                // than panicking.
                Err(_) => return (String::new(), 0),
            }
        }
    }

    /// Splits `text` into chunks of roughly `chunk_size` tokens each.
    ///
    /// If a chunk boundary would split a multi-byte UTF-8 sequence, the
    /// chunk is widened by one token until it decodes cleanly, so no input
    /// is silently dropped. A `chunk_size` of 0 yields an empty vector
    /// (`slice::chunks` would panic on 0).
    pub fn chunk(&self, text: &str, chunk_size: usize) -> Vec<String> {
        if chunk_size == 0 {
            return Vec::new();
        }
        let tokens = self.bpe.encode_with_special_tokens(text);
        let mut chunks = Vec::new();
        let mut start = 0;
        while start < tokens.len() {
            let mut end = usize::min(start + chunk_size, tokens.len());
            loop {
                match self.bpe.decode(tokens[start..end].to_vec()) {
                    Ok(chunk_text) => {
                        chunks.push(chunk_text);
                        start = end;
                        break;
                    }
                    // Boundary split a UTF-8 sequence: widen the chunk.
                    Err(_) if end < tokens.len() => end += 1,
                    // Undecodable tail even at full length; stop rather
                    // than loop forever.
                    Err(_) => {
                        start = tokens.len();
                        break;
                    }
                }
            }
        }
        chunks
    }

    /// Returns `true` if `text` fits within `budget` tokens.
    pub fn will_fit(&self, text: &str, budget: usize) -> bool {
        self.count(text) <= budget
    }
}
impl Default for TokenCounter {
    /// Builds a counter via [`TokenCounter::new`], panicking if the
    /// `cl100k_base` tokenizer data cannot be loaded. Use `new` directly
    /// when the failure should be handled instead.
    fn default() -> Self {
        let counter = Self::new();
        counter.expect("Failed to initialize default token counter")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty input has zero tokens; a short greeting has a small positive count.
    #[test]
    fn test_count_tokens() {
        let tc = TokenCounter::new().unwrap();
        assert_eq!(tc.count(""), 0);
        let n = tc.count("Hello, world!");
        assert!(n > 0 && n < 10);
    }

    /// Per-message overhead makes a message cost more than its bare content.
    #[test]
    fn test_count_message() {
        let tc = TokenCounter::new().unwrap();
        let bare = tc.count("Hello");
        let wrapped = tc.count_message("user", "Hello");
        assert!(wrapped > bare);
    }

    /// Truncation respects the token budget and shortens the text.
    #[test]
    fn test_truncate() {
        let tc = TokenCounter::new().unwrap();
        let original = "This is a long text that we want to truncate to a smaller size.";
        let (shortened, used) = tc.truncate(original, 5);
        assert!(used <= 5);
        assert!(shortened.len() < original.len());
    }

    /// Small text fits a large budget; repeated text overflows a small one.
    #[test]
    fn test_will_fit() {
        let tc = TokenCounter::new().unwrap();
        assert!(tc.will_fit("Hello", 100));
        let big = "Hello ".repeat(1000);
        assert!(!tc.will_fit(big.as_str(), 10));
    }
}