Skip to main content

cognee_chunking/
token_counter.rs

/// Counts how many tokens a piece of text occupies.
///
/// Chunking functions such as `chunk_text` accept `impl TokenCounter`, so the
/// counting strategy can be swapped without touching chunking logic:
/// [`WordCounter`] counts whitespace-separated words, while the feature-gated
/// `HuggingFaceTokenCounter` (`hf-tokenizer`) and `TikTokenCounter`
/// (`tiktoken`) count real model tokens.
pub trait TokenCounter {
    /// Returns the number of tokens in `text`.
    fn count_tokens(&self, text: &str) -> usize;
}
6
7/// Blanket implementation so `Box<dyn TokenCounter + Send + Sync>` can be passed
8/// to functions that accept `impl TokenCounter` (like `chunk_text`).
9impl<T: TokenCounter + ?Sized> TokenCounter for Box<T> {
10    fn count_tokens(&self, text: &str) -> usize {
11        (**self).count_tokens(text)
12    }
13}
14
15/// Blanket implementation so `&dyn TokenCounter` can be used anywhere `TokenCounter` is required.
16impl<T: TokenCounter + ?Sized> TokenCounter for &T {
17    fn count_tokens(&self, text: &str) -> usize {
18        (*self).count_tokens(text)
19    }
20}
21
22/// Simple token counter that splits on whitespace and counts words.
23#[derive(Debug, Clone, Default)]
24pub struct WordCounter;
25
26impl TokenCounter for WordCounter {
27    fn count_tokens(&self, text: &str) -> usize {
28        text.split_whitespace().count()
29    }
30}
31
32#[cfg(any(feature = "hf-tokenizer", feature = "tiktoken"))]
33use crate::error::ChunkingError;
34#[cfg(feature = "hf-tokenizer")]
35use std::{path::Path, sync::Arc};
36
/// Token counter backed by a HuggingFace `tokenizers` tokenizer.
///
/// Drop-in replacement for [`WordCounter`] when accurate BPE/WordPiece token
/// counts are needed. Use when chunking for models that use HuggingFace
/// tokenizers (BGE, MiniLM, etc.).
///
/// The tokenizer is held behind an [`Arc`], so cloning a counter is cheap and
/// all clones share the same loaded tokenizer.
#[cfg(feature = "hf-tokenizer")]
#[derive(Clone)]
pub struct HuggingFaceTokenCounter {
    tokenizer: Arc<tokenizers::Tokenizer>,
}

#[cfg(feature = "hf-tokenizer")]
impl std::fmt::Debug for HuggingFaceTokenCounter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Omit the tokenizer internals to keep debug output compact.
        f.debug_struct("HuggingFaceTokenCounter")
            .finish_non_exhaustive()
    }
}
45
#[cfg(feature = "hf-tokenizer")]
impl HuggingFaceTokenCounter {
    /// Loads a tokenizer from a local `tokenizer.json` file.
    ///
    /// # Errors
    ///
    /// Returns [`ChunkingError::TokenizerError`] if the file cannot be read or
    /// parsed, or if the tokenizer cannot be reconfigured for counting.
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, ChunkingError> {
        let tokenizer = tokenizers::Tokenizer::from_file(path).map_err(Self::tokenizer_error)?;
        Self::from_tokenizer(tokenizer)
    }

    /// Loads a tokenizer by HuggingFace model ID (requires network access).
    /// Caches locally in the HuggingFace cache directory.
    ///
    /// # Errors
    ///
    /// Returns [`ChunkingError::TokenizerError`] if the tokenizer cannot be
    /// fetched or parsed, or if it cannot be reconfigured for counting.
    pub fn from_pretrained(model_id: &str) -> Result<Self, ChunkingError> {
        let tokenizer = tokenizers::Tokenizer::from_pretrained(model_id, None)
            .map_err(Self::tokenizer_error)?;
        Self::from_tokenizer(tokenizer)
    }

    /// Wraps a loaded tokenizer after switching off truncation and padding.
    ///
    /// A `tokenizer.json` may ship with truncation (a `max_length`) or padding
    /// settings. Left enabled, `encode` would cap the count of long texts or
    /// inflate the count of short ones, and chunks would be mis-sized.
    fn from_tokenizer(mut tokenizer: tokenizers::Tokenizer) -> Result<Self, ChunkingError> {
        tokenizer
            .with_truncation(None)
            .map_err(Self::tokenizer_error)?;
        tokenizer.with_padding(None);
        Ok(Self {
            tokenizer: Arc::new(tokenizer),
        })
    }

    /// Converts a `tokenizers` error into this crate's error type.
    fn tokenizer_error(err: tokenizers::Error) -> ChunkingError {
        ChunkingError::TokenizerError(err.to_string())
    }
}
67
#[cfg(feature = "hf-tokenizer")]
impl TokenCounter for HuggingFaceTokenCounter {
    /// Returns the number of tokens the tokenizer produces for `text`, without
    /// adding special tokens (`add_special_tokens = false`).
    ///
    /// Counting is best-effort: if encoding fails, the whitespace word count
    /// from [`WordCounter`] is returned instead of an error.
    fn count_tokens(&self, text: &str) -> usize {
        // NOTE(review): `encoding.len()` reflects any truncation/padding configured
        // on the tokenizer — confirm those are disabled for accurate counts.
        match self.tokenizer.encode(text, false) {
            Ok(encoding) => encoding.len(),
            // The encode error is deliberately discarded; fall back to word count.
            Err(_) => WordCounter.count_tokens(text),
        }
    }
}
77
/// Token counter using TikToken BPE encoding (cl100k_base).
///
/// Use when chunking for OpenAI models (text-embedding-3-large, GPT-4, etc.).
/// Matches Python's `TikTokenTokenizer` with cl100k_base encoding.
#[cfg(feature = "tiktoken")]
pub struct TikTokenCounter {
    bpe: tiktoken_rs::CoreBPE,
}

#[cfg(feature = "tiktoken")]
impl std::fmt::Debug for TikTokenCounter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Omit the BPE tables to keep debug output compact.
        f.debug_struct("TikTokenCounter").finish_non_exhaustive()
    }
}
86
#[cfg(feature = "tiktoken")]
impl TikTokenCounter {
    /// Builds a counter using the `cl100k_base` encoding (GPT-4,
    /// text-embedding-3-large).
    ///
    /// # Errors
    ///
    /// Returns [`ChunkingError::TokenizerError`] if `tiktoken_rs` fails to
    /// construct the encoding.
    pub fn cl100k_base() -> Result<Self, ChunkingError> {
        tiktoken_rs::cl100k_base()
            .map(|bpe| Self { bpe })
            .map_err(|err| ChunkingError::TokenizerError(err.to_string()))
    }
}
96
#[cfg(feature = "tiktoken")]
impl TokenCounter for TikTokenCounter {
    /// Counts BPE tokens using `encode_with_special_tokens`, so special-token
    /// text such as `<|endoftext|>` is encoded rather than rejected.
    fn count_tokens(&self, text: &str) -> usize {
        let token_ids = self.bpe.encode_with_special_tokens(text);
        token_ids.len()
    }
}
103
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Mirrors how chunking functions take their counter: by `impl TokenCounter`.
    fn count_via_impl(counter: impl TokenCounter, text: &str) -> usize {
        counter.count_tokens(text)
    }

    #[test]
    fn word_counter_empty() {
        assert_eq!(WordCounter.count_tokens(""), 0);
    }

    #[test]
    fn word_counter_whitespace_only() {
        assert_eq!(WordCounter.count_tokens("   \n\t  "), 0);
    }

    #[test]
    fn word_counter_simple() {
        assert_eq!(WordCounter.count_tokens("hello world"), 2);
    }

    #[test]
    fn word_counter_punctuation() {
        assert_eq!(WordCounter.count_tokens("Hello, world! How are you?"), 5);
    }

    #[test]
    fn word_counter_collapses_mixed_whitespace() {
        assert_eq!(WordCounter.count_tokens("  one\t\ttwo\n\nthree  "), 3);
    }

    #[test]
    fn boxed_dyn_counter_forwards() {
        let boxed: Box<dyn TokenCounter + Send + Sync> = Box::new(WordCounter);
        assert_eq!(boxed.count_tokens("a b c"), 3);
        assert_eq!(count_via_impl(boxed, "a b"), 2);
    }

    #[test]
    fn dyn_reference_counter_forwards() {
        let counter: &dyn TokenCounter = &WordCounter;
        assert_eq!(count_via_impl(counter, "a b"), 2);
        assert_eq!(count_via_impl(&WordCounter, "x"), 1);
    }
}
133
#[cfg(all(test, feature = "hf-tokenizer"))]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod hf_tests {
    use super::*;

    #[test]
    fn test_from_file_nonexistent() {
        let result = HuggingFaceTokenCounter::from_file("/nonexistent/tokenizer.json");
        // A missing file must surface as the crate's tokenizer error variant.
        assert!(
            matches!(result, Err(ChunkingError::TokenizerError(_))),
            "expected ChunkingError::TokenizerError for a missing tokenizer file"
        );
    }
}
149
#[cfg(all(test, feature = "tiktoken"))]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tiktoken_tests {
    use super::*;

    fn counter() -> TikTokenCounter {
        TikTokenCounter::cl100k_base().expect("cl100k_base should load")
    }

    #[test]
    fn cl100k_base_constructs() {
        assert!(TikTokenCounter::cl100k_base().is_ok());
    }

    #[test]
    fn counts_known_text() {
        // cl100k_base splits "Hello, world!" into "Hello" "," " world" "!".
        assert_eq!(counter().count_tokens("Hello, world!"), 4);
    }

    #[test]
    fn empty_string_is_zero_tokens() {
        assert_eq!(counter().count_tokens(""), 0);
    }

    #[test]
    fn special_token_counts_as_one() {
        // `encode_with_special_tokens` maps the literal special token to one id.
        assert_eq!(counter().count_tokens("<|endoftext|>"), 1);
    }
}