// cognee_chunking/token_counter.rs
/// Measures the length of a piece of text in "tokens".
///
/// What counts as a token is up to the implementor: [`WordCounter`] counts
/// whitespace-separated words, while tokenizer-backed implementors count the
/// tokens a specific model tokenizer would produce. Code that is generic over
/// `TokenCounter` can therefore switch counting strategies without changes.
///
/// The trait is object-safe (a single `&self` method, no generics), so
/// `&dyn TokenCounter` and `Box<dyn TokenCounter>` can be used when the
/// concrete counter is only known at runtime.
pub trait TokenCounter {
    /// Returns the number of tokens in `text`.
    ///
    /// Marked `#[must_use]`: computing a count and discarding it is always a
    /// mistake, and the attribute also applies to calls resolved to any impl.
    #[must_use]
    fn count_tokens(&self, text: &str) -> usize;
}
6
7impl<T: TokenCounter + ?Sized> TokenCounter for Box<T> {
10 fn count_tokens(&self, text: &str) -> usize {
11 (**self).count_tokens(text)
12 }
13}
14
15impl<T: TokenCounter + ?Sized> TokenCounter for &T {
17 fn count_tokens(&self, text: &str) -> usize {
18 (*self).count_tokens(text)
19 }
20}
21
22#[derive(Debug, Clone, Default)]
24pub struct WordCounter;
25
26impl TokenCounter for WordCounter {
27 fn count_tokens(&self, text: &str) -> usize {
28 text.split_whitespace().count()
29 }
30}
31
32#[cfg(any(feature = "hf-tokenizer", feature = "tiktoken"))]
33use crate::error::ChunkingError;
34#[cfg(feature = "hf-tokenizer")]
35use std::{path::Path, sync::Arc};
36
/// A [`TokenCounter`] backed by a Hugging Face `tokenizers::Tokenizer`.
///
/// The tokenizer is held behind an [`Arc`], so cloning a counter is cheap and
/// all clones share the same loaded tokenizer.
#[cfg(feature = "hf-tokenizer")]
#[derive(Clone)]
pub struct HuggingFaceTokenCounter {
    tokenizer: Arc<tokenizers::Tokenizer>,
}

#[cfg(feature = "hf-tokenizer")]
impl std::fmt::Debug for HuggingFaceTokenCounter {
    // Hand-written rather than derived so `{:?}` output stays short regardless
    // of the tokenizer's size, and does not depend on the wrapped tokenizer's
    // own `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HuggingFaceTokenCounter")
            .finish_non_exhaustive()
    }
}
45
#[cfg(feature = "hf-tokenizer")]
impl HuggingFaceTokenCounter {
    /// Wraps an already-constructed `tokenizers::Tokenizer`.
    ///
    /// Useful when the tokenizer was obtained some other way than
    /// [`from_file`](Self::from_file) or
    /// [`from_pretrained`](Self::from_pretrained).
    #[must_use]
    pub fn from_tokenizer(tokenizer: tokenizers::Tokenizer) -> Self {
        Self {
            tokenizer: Arc::new(tokenizer),
        }
    }

    /// Loads a serialized tokenizer (e.g. a `tokenizer.json`) from `path`.
    ///
    /// # Errors
    ///
    /// Returns [`ChunkingError::TokenizerError`] if the tokenizer cannot be
    /// loaded; the message names the offending path.
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, ChunkingError> {
        let path = path.as_ref();
        let tokenizer = tokenizers::Tokenizer::from_file(path).map_err(|e| {
            ChunkingError::TokenizerError(format!(
                "failed to load tokenizer from {}: {e}",
                path.display()
            ))
        })?;
        Ok(Self::from_tokenizer(tokenizer))
    }

    /// Loads the tokenizer published under `model_id` (a Hugging Face Hub
    /// identifier) via `tokenizers::Tokenizer::from_pretrained` with default
    /// parameters. This may need network access.
    ///
    /// # Errors
    ///
    /// Returns [`ChunkingError::TokenizerError`] if the tokenizer cannot be
    /// fetched or parsed; the message names the requested model id.
    pub fn from_pretrained(model_id: &str) -> Result<Self, ChunkingError> {
        let tokenizer = tokenizers::Tokenizer::from_pretrained(model_id, None).map_err(|e| {
            ChunkingError::TokenizerError(format!(
                "failed to load pretrained tokenizer '{model_id}': {e}"
            ))
        })?;
        Ok(Self::from_tokenizer(tokenizer))
    }
}
67
#[cfg(feature = "hf-tokenizer")]
impl TokenCounter for HuggingFaceTokenCounter {
    /// Counts the tokens the wrapped tokenizer produces for `text`, without
    /// adding special tokens (the `false` passed to `encode`).
    ///
    /// `count_tokens` cannot report errors, so if encoding fails this falls
    /// back to [`WordCounter`]'s whitespace word count instead of panicking.
    fn count_tokens(&self, text: &str) -> usize {
        match self.tokenizer.encode(text, false) {
            Ok(encoding) => encoding.len(),
            // NOTE(review): the word-count fallback usually under-counts
            // relative to a subword tokenizer, and the failure is silent.
            // Callers treating the result as an upper bound may want this
            // logged.
            Err(_) => WordCounter.count_tokens(text),
        }
    }
}
77
/// A [`TokenCounter`] backed by a `tiktoken_rs` BPE encoder.
///
/// The encoder is held behind an `Arc`, so cloning a counter is cheap and all
/// clones share one encoder instead of rebuilding it.
#[cfg(feature = "tiktoken")]
#[derive(Clone)]
pub struct TikTokenCounter {
    bpe: std::sync::Arc<tiktoken_rs::CoreBPE>,
}

#[cfg(feature = "tiktoken")]
impl std::fmt::Debug for TikTokenCounter {
    // Hand-written rather than derived so `{:?}` output stays short and does
    // not depend on the encoder's own `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TikTokenCounter").finish_non_exhaustive()
    }
}

#[cfg(feature = "tiktoken")]
impl TikTokenCounter {
    /// Builds a counter for the `cl100k_base` encoding.
    ///
    /// Every call constructs a fresh encoder. To reuse one, clone an existing
    /// counter instead of calling this again.
    ///
    /// # Errors
    ///
    /// Returns [`ChunkingError::TokenizerError`] if `tiktoken_rs` fails to
    /// build the encoding.
    pub fn cl100k_base() -> Result<Self, ChunkingError> {
        let bpe = tiktoken_rs::cl100k_base().map_err(|e| {
            ChunkingError::TokenizerError(format!("failed to load cl100k_base encoding: {e}"))
        })?;
        Ok(Self {
            bpe: std::sync::Arc::new(bpe),
        })
    }
}
96
#[cfg(feature = "tiktoken")]
impl TokenCounter for TikTokenCounter {
    /// Counts BPE tokens in `text` using `encode_with_special_tokens`.
    ///
    /// Special-token text in the input (e.g. `<|endoftext|>`) is recognised as
    /// a special token rather than being split into ordinary tokens.
    fn count_tokens(&self, text: &str) -> usize {
        let token_ids = self.bpe.encode_with_special_tokens(text);
        token_ids.len()
    }
}
103
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Generic helper: forces resolution through a `C: TokenCounter` bound,
    /// which exercises the blanket `&T` / `Box<T>` impls.
    fn count_with<C: TokenCounter>(counter: C, text: &str) -> usize {
        counter.count_tokens(text)
    }

    #[test]
    fn word_counter_empty() {
        assert_eq!(WordCounter.count_tokens(""), 0);
    }

    #[test]
    fn word_counter_whitespace_only() {
        assert_eq!(WordCounter.count_tokens(" \n\t "), 0);
    }

    #[test]
    fn word_counter_simple() {
        assert_eq!(WordCounter.count_tokens("hello world"), 2);
    }

    #[test]
    fn word_counter_punctuation() {
        assert_eq!(WordCounter.count_tokens("Hello, world! How are you?"), 5);
    }

    #[test]
    fn word_counter_unicode_whitespace() {
        // U+3000 (ideographic space) and U+00A0 (no-break space) are Unicode
        // White_Space, so `split_whitespace` separates on them too.
        assert_eq!(WordCounter.count_tokens("one\u{3000}two\u{00A0}three"), 3);
    }

    #[test]
    fn reference_forwards_to_referent() {
        assert_eq!(count_with(&WordCounter, "a b c"), 3);
        let dynamic: &dyn TokenCounter = &WordCounter;
        assert_eq!(count_with(dynamic, "a b"), 2);
    }

    #[test]
    fn box_forwards_to_boxed_value() {
        let boxed: Box<dyn TokenCounter> = Box::new(WordCounter);
        assert_eq!(boxed.count_tokens("x y z"), 3);
        assert_eq!(count_with(boxed, ""), 0);
    }
}
133
#[cfg(all(test, feature = "hf-tokenizer"))]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod hf_tests {
    use super::*;

    #[test]
    fn from_file_nonexistent_returns_tokenizer_error() {
        let result = HuggingFaceTokenCounter::from_file("/nonexistent/tokenizer.json");
        // Pin the specific variant, not just "some error", so a change in how
        // load failures are reported is caught.
        assert!(
            matches!(result, Err(ChunkingError::TokenizerError(_))),
            "expected ChunkingError::TokenizerError for a missing tokenizer file"
        );
    }
}
149
#[cfg(all(test, feature = "tiktoken"))]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tiktoken_tests {
    use super::*;

    /// Builds the `cl100k_base` counter shared by the counting tests,
    /// panicking with a clear message if the encoding cannot be loaded.
    fn cl100k() -> TikTokenCounter {
        TikTokenCounter::cl100k_base().expect("cl100k_base should load")
    }

    #[test]
    fn cl100k_base_constructs() {
        assert!(TikTokenCounter::cl100k_base().is_ok());
    }

    #[test]
    fn counts_known_text() {
        let count = cl100k().count_tokens("Hello, world!");
        assert!(count > 0);
        assert!((3..=6).contains(&count), "Expected 3-6 tokens, got {count}");
    }

    #[test]
    fn empty_string_is_zero_tokens() {
        assert_eq!(cl100k().count_tokens(""), 0);
    }
}