/// Strategy for measuring the size of a piece of text in tokens.
///
/// What counts as a "token" is up to the implementor: `WordCounter` in this
/// module counts whitespace-separated words, while the feature-gated counters
/// delegate to real tokenizers. The method takes `&self` and has no generic
/// parameters, so the trait can be used as `dyn TokenCounter`.
pub trait TokenCounter {
    /// Returns the number of tokens that `text` occupies under this counter.
    fn count_tokens(&self, text: &str) -> usize;
}
/// Lets a boxed counter (including `Box<dyn TokenCounter>`) be used wherever a
/// [`TokenCounter`] is expected by forwarding to the boxed value.
impl<T: TokenCounter + ?Sized> TokenCounter for Box<T> {
    fn count_tokens(&self, text: &str) -> usize {
        // Naming `T` explicitly dispatches to the boxed value's implementation
        // (via deref coercion of `&Box<T>` to `&T`) instead of re-entering
        // this impl.
        T::count_tokens(self, text)
    }
}
/// Lets a shared reference to a counter (including `&dyn TokenCounter`) be
/// used wherever a [`TokenCounter`] is expected by forwarding to the referent.
impl<T: TokenCounter + ?Sized> TokenCounter for &T {
    fn count_tokens(&self, text: &str) -> usize {
        // `self` is `&&T` here; naming `T` explicitly dispatches to the
        // referent's implementation instead of re-entering this impl.
        T::count_tokens(self, text)
    }
}
/// Tokenizer-free counter that treats every whitespace-separated word as one
/// token.
#[derive(Debug, Clone, Default)]
pub struct WordCounter;

impl TokenCounter for WordCounter {
    /// Counts the whitespace-separated words in `text`.
    ///
    /// Empty and whitespace-only input yields `0`; punctuation attached to a
    /// word does not create additional tokens.
    fn count_tokens(&self, text: &str) -> usize {
        let words = text.split_whitespace();
        words.count()
    }
}
#[cfg(any(feature = "hf-tokenizer", feature = "tiktoken"))]
use crate::error::ChunkingError;
#[cfg(feature = "hf-tokenizer")]
use std::{path::Path, sync::Arc};
/// Token counter backed by a `tokenizers::Tokenizer`.
///
/// Only compiled when the `hf-tokenizer` feature is enabled.
#[cfg(feature = "hf-tokenizer")]
pub struct HuggingFaceTokenCounter {
    /// The loaded tokenizer, held behind an `Arc`.
    tokenizer: Arc<tokenizers::Tokenizer>,
}

#[cfg(feature = "hf-tokenizer")]
impl HuggingFaceTokenCounter {
    /// Builds a counter from the tokenizer definition stored at `path`.
    ///
    /// # Errors
    ///
    /// Returns [`ChunkingError::TokenizerError`] carrying the underlying
    /// error's message when `tokenizers::Tokenizer::from_file` fails.
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, ChunkingError> {
        match tokenizers::Tokenizer::from_file(path) {
            Ok(loaded) => Ok(Self {
                tokenizer: Arc::new(loaded),
            }),
            Err(err) => Err(ChunkingError::TokenizerError(err.to_string())),
        }
    }

    /// Builds a counter from the pretrained tokenizer identified by
    /// `model_id`.
    ///
    /// `None` is passed as the second argument of
    /// `tokenizers::Tokenizer::from_pretrained`, i.e. no extra parameters are
    /// supplied.
    ///
    /// # Errors
    ///
    /// Returns [`ChunkingError::TokenizerError`] carrying the underlying
    /// error's message when the tokenizer cannot be obtained.
    pub fn from_pretrained(model_id: &str) -> Result<Self, ChunkingError> {
        match tokenizers::Tokenizer::from_pretrained(model_id, None) {
            Ok(loaded) => Ok(Self {
                tokenizer: Arc::new(loaded),
            }),
            Err(err) => Err(ChunkingError::TokenizerError(err.to_string())),
        }
    }
}
#[cfg(feature = "hf-tokenizer")]
impl TokenCounter for HuggingFaceTokenCounter {
    /// Encodes `text` with the wrapped tokenizer and returns the encoding's
    /// length.
    ///
    /// The `false` argument to `encode` is the tokenizer's
    /// `add_special_tokens` flag (per the `tokenizers` API). The trait cannot
    /// report errors, so if encoding fails this falls back to counting
    /// whitespace-separated words as a best-effort estimate.
    fn count_tokens(&self, text: &str) -> usize {
        match self.tokenizer.encode(text, false) {
            Ok(encoding) => encoding.len(),
            Err(_) => text.split_whitespace().count(),
        }
    }
}
/// Token counter backed by a `tiktoken_rs::CoreBPE` encoder.
///
/// Only compiled when the `tiktoken` feature is enabled.
#[cfg(feature = "tiktoken")]
pub struct TikTokenCounter {
    /// The byte-pair encoder used to tokenize text.
    bpe: tiktoken_rs::CoreBPE,
}

#[cfg(feature = "tiktoken")]
impl TikTokenCounter {
    /// Builds a counter that uses the `cl100k_base` encoding.
    ///
    /// # Errors
    ///
    /// Returns [`ChunkingError::TokenizerError`] carrying the underlying
    /// error's message when `tiktoken_rs::cl100k_base` fails.
    pub fn cl100k_base() -> Result<Self, ChunkingError> {
        match tiktoken_rs::cl100k_base() {
            Ok(bpe) => Ok(Self { bpe }),
            Err(err) => Err(ChunkingError::TokenizerError(err.to_string())),
        }
    }
}
#[cfg(feature = "tiktoken")]
impl TokenCounter for TikTokenCounter {
    /// Encodes `text` with the wrapped BPE and returns how many tokens were
    /// produced.
    ///
    /// Uses `encode_with_special_tokens`; presumably special-token markers in
    /// `text` are encoded as special tokens rather than plain text — confirm
    /// against the tiktoken-rs documentation.
    fn count_tokens(&self, text: &str) -> usize {
        let encoded = self.bpe.encode_with_special_tokens(text);
        encoded.len()
    }
}
/// Unit tests for [`WordCounter`].
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Counts `text` with a [`WordCounter`].
    fn words_in(text: &str) -> usize {
        WordCounter.count_tokens(text)
    }

    #[test]
    fn word_counter_empty() {
        assert_eq!(words_in(""), 0);
    }

    #[test]
    fn word_counter_whitespace_only() {
        assert_eq!(words_in(" \n\t "), 0);
    }

    #[test]
    fn word_counter_simple() {
        assert_eq!(words_in("hello world"), 2);
    }

    #[test]
    fn word_counter_punctuation() {
        // Punctuation stays attached to its word, so it adds no tokens.
        assert_eq!(words_in("Hello, world! How are you?"), 5);
    }
}
/// Tests for [`HuggingFaceTokenCounter`]; built only with the `hf-tokenizer`
/// feature.
#[cfg(all(test, feature = "hf-tokenizer"))]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod hf_tests {
    use super::*;

    #[test]
    fn test_from_file_nonexistent() {
        // Loading from a path that does not exist must surface an error
        // instead of panicking.
        assert!(HuggingFaceTokenCounter::from_file("/nonexistent/tokenizer.json").is_err());
    }
}
/// Tests for [`TikTokenCounter`]; built only with the `tiktoken` feature.
#[cfg(all(test, feature = "tiktoken"))]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tiktoken_tests {
    use super::*;

    /// Builds the counter under test, panicking if the encoding cannot load.
    fn load_counter() -> TikTokenCounter {
        TikTokenCounter::cl100k_base().expect("cl100k_base should load")
    }

    #[test]
    fn cl100k_base_constructs() {
        assert!(TikTokenCounter::cl100k_base().is_ok());
    }

    #[test]
    fn counts_known_text() {
        let count = load_counter().count_tokens("Hello, world!");
        assert!(count > 0);
        // A range rather than an exact value keeps the test tolerant of the
        // exact token split.
        assert!((3..=6).contains(&count), "Expected 3-6 tokens, got {count}");
    }

    #[test]
    fn empty_string_is_zero_tokens() {
        assert_eq!(load_counter().count_tokens(""), 0);
    }
}