use serde::{Deserialize, Serialize};
/// Heuristic, character-based token-count estimator.
///
/// Holds the [`TokenCalculationRules`] applied to every estimate. `Default` is derived
/// (yielding the rules' own defaults, exactly like `TokenEstimator::new()`) so the type
/// satisfies Clippy's `new_without_default` and can be used in `..Default::default()`.
#[derive(Debug, Clone, Default)]
pub struct TokenEstimator {
    /// Ratios and overhead used when converting character counts into tokens.
    model_rules: TokenCalculationRules,
}
/// Tunable parameters that turn per-category character counts into a token estimate.
///
/// Derives serde's `Serialize`/`Deserialize`, so the rules can be persisted or loaded
/// alongside other serde-managed data.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TokenCalculationRules {
    /// How many English (ASCII) characters are assumed to make up one token.
    /// The estimator also applies this ratio to characters that are neither ASCII nor Chinese.
    pub english_char_per_token: f64,
    /// How many Chinese (CJK ideograph) characters are assumed to make up one token.
    pub chinese_char_per_token: f64,
    /// Flat number of tokens added on top of every estimate.
    pub base_token_overhead: usize,
}
impl Default for TokenCalculationRules {
    /// Baseline rules: 4 English characters or 1.5 Chinese characters per token,
    /// plus a flat 50-token overhead on every estimate.
    fn default() -> Self {
        let english_char_per_token = 4.0;
        let chinese_char_per_token = 1.5;
        let base_token_overhead = 50;
        Self {
            english_char_per_token,
            chinese_char_per_token,
            base_token_overhead,
        }
    }
}
/// Outcome of a single token estimate: the estimated token total together with the
/// character counts it was computed from.
#[derive(Clone, Debug)]
pub struct TokenEstimation {
    /// Estimated token total, including the fixed base overhead.
    pub estimated_tokens: usize,
    /// Total number of Unicode scalar values (`char`s) in the input.
    // NOTE(review): the dead_code allowances below presumably exist because nothing in
    // this crate reads these diagnostic fields yet — confirm before removing them.
    #[allow(dead_code)]
    pub character_count: usize,
    /// Number of characters classified as Chinese (CJK ideographs).
    #[allow(dead_code)]
    pub chinese_char_count: usize,
    /// Number of characters classified as English (printable ASCII plus ASCII whitespace).
    #[allow(dead_code)]
    pub english_char_count: usize,
}
impl TokenEstimator {
pub fn new() -> Self {
Self {
model_rules: TokenCalculationRules::default(),
}
}
pub fn estimate_tokens(&self, text: &str) -> TokenEstimation {
let character_count = text.chars().count();
let chinese_char_count = self.count_chinese_chars(text);
let english_char_count = self.count_english_chars(text);
let other_char_count = character_count - chinese_char_count - english_char_count;
let chinese_tokens =
(chinese_char_count as f64 / self.model_rules.chinese_char_per_token).ceil() as usize;
let english_tokens =
(english_char_count as f64 / self.model_rules.english_char_per_token).ceil() as usize;
let other_tokens = if other_char_count > 0 {
(other_char_count as f64 / self.model_rules.english_char_per_token).ceil() as usize
} else {
0
};
let estimated_tokens =
chinese_tokens + english_tokens + other_tokens + self.model_rules.base_token_overhead;
TokenEstimation {
estimated_tokens,
character_count,
chinese_char_count,
english_char_count,
}
}
fn count_chinese_chars(&self, text: &str) -> usize {
text.chars().filter(|c| self.is_chinese_char(*c)).count()
}
fn count_english_chars(&self, text: &str) -> usize {
text.chars()
.filter(|c| {
c.is_ascii_alphabetic()
|| c.is_ascii_whitespace()
|| c.is_ascii_digit()
|| c.is_ascii_punctuation()
})
.count()
}
fn is_chinese_char(&self, c: char) -> bool {
matches!(c as u32,
0x4E00..=0x9FFF | 0x3400..=0x4DBF | 0x20000..=0x2A6DF | 0x2A700..=0x2B73F | 0x2B740..=0x2B81F | 0x2B820..=0x2CEAF | 0x2CEB0..=0x2EBEF | 0x30000..=0x3134F )
}
}