use anyhow::Result;
use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tiktoken_rs::{cl100k_base, CoreBPE};
/// Counts `cl100k_base` (tiktoken) tokens for strings, memoizing results
/// in a thread-safe cache so repeated texts are only encoded once.
pub struct TokenCounter {
    // Shared BPE encoder loaded once in `new`.
    encoder: Arc<CoreBPE>,
    // Memoized token counts keyed by a 64-bit hash of the input text
    // (see `calculate_hash`).
    // NOTE(review): keying by hash alone means a 64-bit collision would
    // return a wrong count, and the cache grows without bound — presumably
    // acceptable tradeoffs here; confirm for long-running processes.
    cache: Arc<Mutex<HashMap<u64, usize>>>,
}
impl TokenCounter {
    /// Creates a counter backed by the `cl100k_base` BPE encoding.
    ///
    /// # Errors
    /// Returns an error if the encoder data fails to load.
    pub fn new() -> Result<Self> {
        let encoder = cl100k_base()?;
        Ok(TokenCounter {
            encoder: Arc::new(encoder),
            cache: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Counts the tokens in `text`, consulting the shared cache first.
    ///
    /// The encode step runs outside the lock so parallel callers do not
    /// serialize on it; a concurrent duplicate insert is harmless (same
    /// key, same value).
    ///
    /// # Errors
    /// Infallible in practice (`encode_with_special_tokens` does not fail);
    /// the `Result` keeps the signature uniform with the other methods.
    pub fn count_tokens(&self, text: &str) -> Result<usize> {
        let hash = calculate_hash(text);
        // Fast path: cached result. Recover from a poisoned mutex (a panic
        // in another thread while holding the lock) instead of silently
        // bypassing the cache forever — the cached data is plain counts and
        // remains valid after a panic.
        {
            let cache = self.cache.lock().unwrap_or_else(|p| p.into_inner());
            if let Some(&count) = cache.get(&hash) {
                return Ok(count);
            }
        }
        let count = self.encoder.encode_with_special_tokens(text).len();
        self.cache.lock().unwrap_or_else(|p| p.into_inner()).insert(hash, count);
        Ok(count)
    }

    /// Counts tokens for each string in `texts` in parallel (rayon),
    /// short-circuiting on the first error.
    pub fn count_tokens_parallel(&self, texts: &[String]) -> Result<Vec<usize>> {
        texts.par_iter().map(|text| self.count_tokens(text)).collect()
    }

    /// Counts tokens for one file as it would be rendered in output:
    /// a `## <path>` header, a fenced code block holding `content`, and a
    /// closing fence. Overhead = header + footer tokens.
    pub fn count_file_tokens(&self, content: &str, path: &str) -> Result<FileTokenCount> {
        let content_tokens = self.count_tokens(content)?;
        // Markdown wrapper added around each file; counted separately so the
        // caller can distinguish payload from formatting overhead.
        let header = format!("## {path}\n\n```\n");
        let footer = "\n```\n\n";
        let header_tokens = self.count_tokens(&header)?;
        let footer_tokens = self.count_tokens(footer)?;
        Ok(FileTokenCount {
            content_tokens,
            overhead_tokens: header_tokens + footer_tokens,
            total_tokens: content_tokens + header_tokens + footer_tokens,
        })
    }

    /// Sums per-file token counts over `(path, content)` pairs, returning
    /// aggregate totals plus the per-file breakdown in input order.
    ///
    /// # Errors
    /// Propagates the first error from `count_file_tokens`.
    pub fn estimate_total_tokens(&self, files: &[(String, String)]) -> Result<TotalTokenEstimate> {
        let mut total_content = 0;
        let mut total_overhead = 0;
        // Exactly one entry per input file, so preallocate.
        let mut file_counts = Vec::with_capacity(files.len());
        for (path, content) in files {
            let count = self.count_file_tokens(content, path)?;
            total_content += count.content_tokens;
            total_overhead += count.overhead_tokens;
            file_counts.push((path.clone(), count));
        }
        Ok(TotalTokenEstimate {
            total_tokens: total_content + total_overhead,
            content_tokens: total_content,
            overhead_tokens: total_overhead,
            file_counts,
        })
    }
}
impl Default for TokenCounter {
    /// Builds a counter via [`TokenCounter::new`], panicking if the
    /// `cl100k_base` encoder fails to load (treated as an unrecoverable
    /// setup bug rather than a runtime error).
    fn default() -> Self {
        Self::new().expect("Failed to create token counter")
    }
}
/// Token breakdown for a single file: the file's own content plus the
/// markdown header/fence overhead added by `count_file_tokens`.
#[derive(Debug, Clone)]
pub struct FileTokenCount {
    /// Tokens in the file content itself.
    pub content_tokens: usize,
    /// Tokens consumed by the `## path` header and the code fences.
    pub overhead_tokens: usize,
    /// `content_tokens + overhead_tokens`.
    pub total_tokens: usize,
}
/// Aggregate token estimate across a set of files, as produced by
/// `estimate_total_tokens`.
#[derive(Debug)]
pub struct TotalTokenEstimate {
    /// `content_tokens + overhead_tokens`.
    pub total_tokens: usize,
    /// Sum of per-file content tokens.
    pub content_tokens: usize,
    /// Sum of per-file formatting overhead tokens.
    pub overhead_tokens: usize,
    /// Per-file breakdown, `(path, count)`, in input order.
    pub file_counts: Vec<(String, FileTokenCount)>,
}
/// Hashes `text` into the 64-bit cache key used by `TokenCounter`'s cache,
/// via the standard library's `DefaultHasher`.
fn calculate_hash(text: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    let mut state = DefaultHasher::new();
    Hash::hash(text, &mut state);
    Hasher::finish(&state)
}
/// Returns `true` if adding `file_tokens` to `current_tokens` would push the
/// running total past `max_tokens`.
///
/// Uses `checked_add` so an overflowing sum reports "exceeds" instead of
/// panicking in debug builds (or wrapping to a false negative in release).
pub fn would_exceed_limit(current_tokens: usize, file_tokens: usize, max_tokens: usize) -> bool {
    current_tokens
        .checked_add(file_tokens)
        .map_or(true, |total| total > max_tokens)
}
/// Number of tokens still available before `max_tokens` is reached;
/// never underflows (yields 0 when already at or over the limit).
pub fn remaining_tokens(current_tokens: usize, max_tokens: usize) -> usize {
    if current_tokens >= max_tokens {
        0
    } else {
        max_tokens - current_tokens
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_counting() {
        let counter = TokenCounter::new().unwrap();
        // Non-empty text yields tokens; empty text yields none.
        assert!(counter.count_tokens("Hello, world!").unwrap() > 0);
        assert_eq!(counter.count_tokens("").unwrap(), 0);
        // The cached second call must agree with the first.
        let text = "This is a test text for caching";
        let first = counter.count_tokens(text).unwrap();
        let second = counter.count_tokens(text).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn test_file_token_counting() {
        let counter = TokenCounter::new().unwrap();
        let result = counter
            .count_file_tokens("fn main() {\n println!(\"Hello, world!\");\n}", "src/main.rs")
            .unwrap();
        assert!(result.content_tokens > 0);
        assert!(result.overhead_tokens > 0);
        // Total is exactly content plus overhead.
        assert_eq!(result.total_tokens, result.content_tokens + result.overhead_tokens);
    }

    #[test]
    fn test_parallel_counting() {
        let counter = TokenCounter::new().unwrap();
        let inputs: Vec<String> =
            ["First text", "Second text", "Third text"].iter().map(|s| s.to_string()).collect();
        let results = counter.count_tokens_parallel(&inputs).unwrap();
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|&c| c > 0));
    }

    #[test]
    fn test_token_limit_checks() {
        assert!(would_exceed_limit(900, 200, 1000));
        assert!(!would_exceed_limit(800, 200, 1000));
        assert_eq!(remaining_tokens(300, 1000), 700);
        assert_eq!(remaining_tokens(1100, 1000), 0);
    }

    #[test]
    fn test_total_estimation() {
        let counter = TokenCounter::new().unwrap();
        let files = vec![
            ("file1.rs".to_string(), "content1".to_string()),
            ("file2.rs".to_string(), "content2".to_string()),
        ];
        let estimate = counter.estimate_total_tokens(&files).unwrap();
        assert!(estimate.total_tokens > 0);
        assert_eq!(estimate.file_counts.len(), 2);
    }
}