use crate::Result;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tiktoken_rs::{get_bpe_from_model, CoreBPE};
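/// Configuration for the tiktoken-backed token counter: which encoding model to
/// load, whether results may be cached, and an optional overall token budget.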
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerConfig {
pub encoding_model: String,
pub enable_caching: bool,
pub token_budget: Option<usize>,
}
impl Default for TokenizerConfig {
fn default() -> Self {
Self {
encoding_model: "gpt-4".to_string(),
enable_caching: true,
            token_budget: Some(128000),
        }
}
}
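/// Process-wide counter, lazily initialized with the default configuration.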
static GLOBAL_TOKEN_COUNTER: Lazy<TokenCounter> = Lazy::new(|| {
TokenCounter::new(TokenizerConfig::default())
.expect("Failed to initialize global token counter")
});
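/// Counts tokens with a shared `CoreBPE` encoder loaded for the configured model.
///
/// Illustrative sketch (not compiled as a doctest; module path assumed):
///
/// ```ignore
/// let counter = TokenCounter::new(TokenizerConfig::default())?;
/// let tokens = counter.count_tokens("fn main() {}")?;
/// ```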
pub struct TokenCounter {
config: TokenizerConfig,
bpe: Arc<CoreBPE>,
}
impl std::fmt::Debug for TokenCounter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TokenCounter")
.field("config", &self.config)
.field("bpe", &"<CoreBPE>")
.finish()
}
}
impl TokenCounter {
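    /// Builds a counter for `config.encoding_model`, failing if tiktoken has no
    /// encoding registered for that model name.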
pub fn new(config: TokenizerConfig) -> Result<Self> {
let bpe = get_bpe_from_model(&config.encoding_model).map_err(|e| {
crate::ScribeError::tokenization(format!(
"Failed to load tokenizer for model '{}': {}",
config.encoding_model, e
))
})?;
Ok(Self {
config,
bpe: Arc::new(bpe),
})
}
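    /// Fallible constructor using `TokenizerConfig::default()`; this is not the
    /// `Default` trait, since loading the encoder can fail.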
pub fn default() -> Result<Self> {
Self::new(TokenizerConfig::default())
}
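    /// Returns the lazily initialized process-wide counter.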
pub fn global() -> &'static TokenCounter {
&GLOBAL_TOKEN_COUNTER
}
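    /// Exact token count for `content`, including special tokens.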
pub fn count_tokens(&self, content: &str) -> Result<usize> {
let tokens = self.bpe.encode_with_special_tokens(content);
Ok(tokens.len())
}
pub fn count_tokens_batch(&self, contents: &[&str]) -> Result<usize> {
let mut total = 0;
for content in contents {
total += self.count_tokens(content)?;
}
Ok(total)
}
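    /// Exact count scaled by a per-language multiplier, as a rough allowance for
    /// how verbose different languages tend to be in practice.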
pub fn estimate_file_tokens(
&self,
content: &str,
file_path: &std::path::Path,
) -> Result<usize> {
let base_tokens = self.count_tokens(content)?;
let multiplier = self.get_language_multiplier(file_path);
Ok((base_tokens as f64 * multiplier).ceil() as usize)
}
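    // Heuristic multipliers keyed on file extension; 1.0 means "count as-is".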
fn get_language_multiplier(&self, file_path: &std::path::Path) -> f64 {
let extension = file_path
.extension()
.and_then(|ext| ext.to_str())
.unwrap_or("");
match extension {
"java" | "csharp" | "cs" => 1.2,
"py" | "python" => 0.9,
"js" | "javascript" | "ts" | "typescript" => 0.95,
"rs" | "rust" => 1.0,
"go" => 0.95,
"json" | "yaml" | "yml" | "toml" => 0.8,
"xml" | "html" | "htm" => 1.1,
"md" | "markdown" | "txt" => 0.7,
_ => 1.0,
}
}
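    /// True if `content` fits within the configured budget (or no budget is set).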
pub fn fits_budget(&self, content: &str) -> Result<bool> {
if let Some(budget) = self.config.token_budget {
let token_count = self.count_tokens(content)?;
Ok(token_count <= budget)
} else {
            Ok(true)
        }
}
pub fn remaining_budget(&self, used_tokens: usize) -> Option<usize> {
self.config
.token_budget
.map(|budget| budget.saturating_sub(used_tokens))
}
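    /// Splits `content` into chunks of at most `chunk_size` tokens and decodes each
    /// chunk back to text. Decoding fails if a chunk boundary splits a multi-byte
    /// UTF-8 sequence.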
pub fn chunk_content(&self, content: &str, chunk_size: usize) -> Result<Vec<String>> {
let tokens = self.bpe.encode_with_special_tokens(content);
let mut chunks = Vec::new();
for chunk_tokens in tokens.chunks(chunk_size) {
let chunk_text = self.bpe.decode(chunk_tokens.to_vec()).map_err(|e| {
crate::ScribeError::tokenization(format!("Failed to decode token chunk: {}", e))
})?;
chunks.push(chunk_text);
}
Ok(chunks)
}
pub fn config(&self) -> &TokenizerConfig {
&self.config
}
pub fn set_token_budget(&mut self, budget: Option<usize>) {
self.config.token_budget = budget;
}
}
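/// Tracks consumption of a fixed token budget, with a two-phase
/// reserve-then-confirm flow for work whose final size is not yet known.
///
/// Illustrative sketch (not compiled as a doctest; values chosen for the example):
///
/// ```ignore
/// let mut budget = TokenBudget::new(1_000);
/// assert!(budget.reserve(200));       // hold 200 tokens for pending work
/// budget.confirm_reservation(150);    // 150 were actually used...
/// budget.release_reservation(50);     // ...and the rest returned to the pool
/// assert_eq!(budget.available(), 850);
/// ```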
#[derive(Debug, Clone)]
pub struct TokenBudget {
total_budget: usize,
used_tokens: usize,
reserved_tokens: usize,
}
impl TokenBudget {
pub fn new(total_budget: usize) -> Self {
Self {
total_budget,
used_tokens: 0,
reserved_tokens: 0,
}
}
pub fn total(&self) -> usize {
self.total_budget
}
pub fn used(&self) -> usize {
self.used_tokens
}
pub fn reserved(&self) -> usize {
self.reserved_tokens
}
pub fn available(&self) -> usize {
self.total_budget
.saturating_sub(self.used_tokens + self.reserved_tokens)
}
pub fn can_allocate(&self, tokens: usize) -> bool {
self.available() >= tokens
}
pub fn allocate(&mut self, tokens: usize) -> bool {
if self.can_allocate(tokens) {
self.used_tokens += tokens;
true
} else {
false
}
}
pub fn reserve(&mut self, tokens: usize) -> bool {
if self.available() >= tokens {
self.reserved_tokens += tokens;
true
} else {
false
}
}
pub fn confirm_reservation(&mut self, tokens: usize) {
let to_confirm = tokens.min(self.reserved_tokens);
self.reserved_tokens -= to_confirm;
self.used_tokens += to_confirm;
}
pub fn release_reservation(&mut self, tokens: usize) {
self.reserved_tokens = self.reserved_tokens.saturating_sub(tokens);
}
    pub fn utilization(&self) -> f64 {
        if self.total_budget == 0 {
            return 0.0;
        }
        (self.used_tokens as f64 / self.total_budget as f64) * 100.0
    }
pub fn reset(&mut self) {
self.used_tokens = 0;
self.reserved_tokens = 0;
}
}
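/// Helpers for comparing tiktoken counts against the legacy chars/4 estimate and
/// for picking a token budget per model and content type.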
pub mod utils {
use super::*;
pub fn estimate_tokens_legacy(content: &str) -> usize {
(content.chars().count() as f64 / 4.0).ceil() as usize
}
pub fn compare_tokenization_accuracy(
content: &str,
counter: &TokenCounter,
) -> Result<TokenizationComparison> {
let tiktoken_count = counter.count_tokens(content)?;
let legacy_count = estimate_tokens_legacy(content);
let accuracy_ratio = if legacy_count > 0 {
tiktoken_count as f64 / legacy_count as f64
} else {
1.0
};
Ok(TokenizationComparison {
tiktoken_count,
legacy_count,
accuracy_ratio,
improvement: if accuracy_ratio < 1.0 {
Some((1.0 - accuracy_ratio) * 100.0)
} else {
None
},
})
}
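    /// Recommended budget: a rough per-model base, scaled down for code-heavy and
    /// mixed content.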
pub fn recommend_token_budget(model: &str, content_type: ContentType) -> usize {
let base_budget = match model {
"gpt-4" | "gpt-4-turbo" => 128000,
"gpt-4-32k" => 32000,
"gpt-3.5-turbo" => 16000,
"gpt-3.5-turbo-16k" => 16000,
_ => 8000, };
match content_type {
            ContentType::Code => (base_budget as f64 * 0.8) as usize,
            ContentType::Documentation => base_budget,
ContentType::Mixed => (base_budget as f64 * 0.9) as usize,
}
}
}
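/// Broad content category used when recommending a token budget.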
#[derive(Debug, Clone, Copy)]
pub enum ContentType {
Code,
Documentation,
Mixed,
}
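/// Side-by-side result of the tiktoken count and the legacy chars/4 estimate.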
#[derive(Debug, Clone)]
pub struct TokenizationComparison {
pub tiktoken_count: usize,
pub legacy_count: usize,
pub accuracy_ratio: f64,
    pub improvement: Option<f64>,
}
impl TokenizationComparison {
pub fn format(&self) -> String {
match self.improvement {
Some(improvement) => format!(
"Tiktoken: {} tokens, Legacy: {} tokens, {:.1}% more accurate",
self.tiktoken_count, self.legacy_count, improvement
),
None => format!(
"Tiktoken: {} tokens, Legacy: {} tokens, {:.2}x ratio",
self.tiktoken_count, self.legacy_count, self.accuracy_ratio
),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::Path;
#[test]
fn test_token_counter_creation() {
let config = TokenizerConfig::default();
let counter = TokenCounter::new(config);
assert!(counter.is_ok());
}
#[test]
fn test_basic_token_counting() {
let counter = TokenCounter::default().unwrap();
let simple_text = "Hello, world!";
let count = counter.count_tokens(simple_text).unwrap();
assert!(count > 0);
        assert!(count < 10);
    }
#[test]
fn test_code_token_counting() {
let counter = TokenCounter::default().unwrap();
let rust_code = r#"
fn main() {
println!("Hello, world!");
let x = 42;
if x > 0 {
println!("Positive number: {}", x);
}
}
"#;
let count = counter.count_tokens(rust_code).unwrap();
        assert!(count > 20);
        assert!(count < 100);
    }
#[test]
fn test_language_multipliers() {
let counter = TokenCounter::default().unwrap();
let content = "function test() { return 42; }";
let js_tokens = counter
.estimate_file_tokens(content, Path::new("test.js"))
.unwrap();
let java_tokens = counter
.estimate_file_tokens(content, Path::new("test.java"))
.unwrap();
let py_tokens = counter
.estimate_file_tokens(content, Path::new("test.py"))
.unwrap();
assert!(java_tokens >= js_tokens);
assert!(py_tokens <= js_tokens);
}
#[test]
fn test_token_budget() {
let mut budget = TokenBudget::new(1000);
assert_eq!(budget.total(), 1000);
assert_eq!(budget.used(), 0);
assert_eq!(budget.available(), 1000);
assert!(budget.allocate(300));
assert_eq!(budget.used(), 300);
assert_eq!(budget.available(), 700);
assert!(budget.reserve(200));
assert_eq!(budget.reserved(), 200);
assert_eq!(budget.available(), 500);
budget.confirm_reservation(150);
assert_eq!(budget.used(), 450);
assert_eq!(budget.reserved(), 50);
assert_eq!(budget.available(), 500);
}
#[test]
fn test_content_chunking() {
let counter = TokenCounter::default().unwrap();
let long_content = "word ".repeat(1000); let chunks = counter.chunk_content(&long_content, 100).unwrap();
assert!(chunks.len() > 1);
for chunk in &chunks {
let chunk_tokens = counter.count_tokens(chunk).unwrap();
            assert!(chunk_tokens <= 120);
        }
}
#[test]
fn test_tokenization_comparison() {
let counter = TokenCounter::default().unwrap();
let code_content = r#"
use std::collections::HashMap;
fn process_data(input: &str) -> Result<HashMap<String, i32>, Box<dyn std::error::Error>> {
let mut result = HashMap::new();
for line in input.lines() {
let parts: Vec<&str> = line.split(':').collect();
if parts.len() == 2 {
result.insert(parts[0].to_string(), parts[1].parse()?);
}
}
Ok(result)
}
"#;
let comparison = utils::compare_tokenization_accuracy(code_content, &counter).unwrap();
assert!(comparison.tiktoken_count > 0);
assert!(comparison.legacy_count > 0);
assert!(comparison.accuracy_ratio > 0.0);
let formatted = comparison.format();
assert!(formatted.contains("Tiktoken"));
assert!(formatted.contains("Legacy"));
}
#[test]
fn test_budget_recommendations() {
let code_budget = utils::recommend_token_budget("gpt-4", ContentType::Code);
let doc_budget = utils::recommend_token_budget("gpt-4", ContentType::Documentation);
let mixed_budget = utils::recommend_token_budget("gpt-4", ContentType::Mixed);
        assert!(code_budget < doc_budget);
        assert!(mixed_budget > code_budget);
assert!(mixed_budget < doc_budget);
}
}