use crate::bpe::BpeTokenizer;
use crate::vocab::{Vocabulary, VocabSource, BuiltInVocab};
pub struct TokenCounter {
tokenizer: BpeTokenizer,
model_name: String,
}
impl TokenCounter {
pub fn gpt4() -> Self {
let vocab = Vocabulary::load(VocabSource::BuiltIn(BuiltInVocab::Cl100kBase)).unwrap();
Self { tokenizer: BpeTokenizer::new(vocab), model_name: "gpt-4".into() }
}
pub fn gpt35() -> Self {
let vocab = Vocabulary::load(VocabSource::BuiltIn(BuiltInVocab::P50kBase)).unwrap();
Self { tokenizer: BpeTokenizer::new(vocab), model_name: "gpt-3.5".into() }
}
pub fn llama() -> Self {
let vocab = Vocabulary::load(VocabSource::BuiltIn(BuiltInVocab::Llama)).unwrap();
Self { tokenizer: BpeTokenizer::new(vocab), model_name: "llama".into() }
}
pub fn from_vocab_file(path: &str, model_name: &str) -> Result<Self, crate::vocab::VocabError> {
let vocab = Vocabulary::load(VocabSource::JsonFile(path.into()))?;
Ok(Self { tokenizer: BpeTokenizer::new(vocab), model_name: model_name.into() })
}
pub fn count(&self, text: &str) -> usize {
self.tokenizer.count_tokens(text)
}
pub fn count_message(&self, role: &str, content: &str) -> usize {
let overhead = 4; self.tokenizer.count_tokens(role) + self.tokenizer.count_tokens(content) + overhead
}
pub fn count_messages(&self, messages: &[(String, String)]) -> usize {
let base_overhead = 3; messages.iter().map(|(role, content)| self.count_message(role, content)).sum::<usize>() + base_overhead
}
pub fn fits(&self, text: &str, max_tokens: usize) -> bool {
self.count(text) <= max_tokens
}
pub fn model(&self) -> &str { &self.model_name }
pub fn tokenizer(&self) -> &BpeTokenizer { &self.tokenizer }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The GPT-4 constructor reports its model name and produces a non-zero count.
    #[test]
    fn test_gpt4_counter() {
        let counter = TokenCounter::gpt4();
        assert_eq!(counter.model(), "gpt-4");
        assert!(counter.count("Hello, world!") > 0);
    }

    /// A single message costs its role and content tokens plus exactly 4 overhead tokens.
    #[test]
    fn test_message_overhead() {
        let counter = TokenCounter::gpt4();
        let msg_count = counter.count_message("user", "Hello");
        let text_count = counter.count("Hello");
        assert!(msg_count > text_count);
        assert_eq!(msg_count, counter.count("user") + text_count + 4);
    }

    /// `fits` accepts short text, rejects over-budget text, and is inclusive at the boundary.
    #[test]
    fn test_fits() {
        let counter = TokenCounter::gpt4();
        assert!(counter.fits("short", 100));
        let long_text = "a".repeat(1000);
        assert!(!counter.fits(&long_text, 1));
        // The previous assertion guarantees `n > 1`, so `n - 1` cannot underflow.
        let n = counter.count(&long_text);
        assert!(counter.fits(&long_text, n));
        assert!(!counter.fits(&long_text, n - 1));
    }

    /// A conversation costs the sum of its messages plus exactly 3 overhead tokens.
    #[test]
    fn test_multiple_messages() {
        let counter = TokenCounter::gpt4();
        let messages: Vec<(String, String)> = vec![
            ("system".into(), "You are helpful.".into()),
            ("user".into(), "Hello".into()),
        ];
        let per_message: usize = messages
            .iter()
            .map(|(role, content)| counter.count_message(role, content))
            .sum();
        let count = counter.count_messages(&messages);
        assert!(count > 0);
        assert_eq!(count, per_message + 3);
    }

    /// An empty conversation still pays the fixed per-conversation overhead.
    #[test]
    fn test_empty_conversation_has_only_base_overhead() {
        let counter = TokenCounter::gpt4();
        assert_eq!(counter.count_messages(&[]), 3);
    }
}