cognee_chunking/config.rs
1//! Chunking configuration — tokenizer selection via environment variables.
2//!
3//! [`TokenCounterKind`] selects which token counting implementation to use based on
4//! environment variables and the active embedding provider. Call [`TokenCounterKind::from_env`]
5//! at pipeline construction time to pick the best available counter automatically, then
6//! call [`TokenCounterKind::build`] to construct the counter.
7
8use std::path::PathBuf;
9
10use serde::{Deserialize, Serialize};
11
12use crate::error::ChunkingError;
13use crate::token_counter::{TokenCounter, WordCounter};
14
15/// Selects which token counting implementation to use.
16///
17/// `from_env()` picks the best available counter based on env vars and the current
18/// embedding provider setting. `WordCounter` is the last-resort fallback.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub enum TokenCounterKind {
21 /// Accurate BPE/WordPiece via a HuggingFace tokenizer model ID (requires network or cache).
22 HuggingFace { model_id: String },
23 /// Accurate BPE/WordPiece from a local tokenizer.json file.
24 HuggingFaceFile { path: PathBuf },
25 /// TikToken cl100k_base BPE (for OpenAI models).
26 TikToken,
27 /// Whitespace word count. Last-resort fallback.
28 Word,
29}
30
31impl TokenCounterKind {
32 /// Determine the best available token counter from the environment.
33 ///
34 /// Mirrors Python's `LiteLLMEmbeddingEngine.get_tokenizer()` logic, which selects a
35 /// tokenizer based on the provider and stores it on the engine instance. Python's
36 /// `chunk_by_sentence()` calls `embedding_engine.tokenizer.count_tokens()` directly —
37 /// the tokenizer is a property of the engine, not a separate config. The Rust design
38 /// decouples them (`TokenCounterKind` is independent of the engine), but the selection
39 /// logic below preserves the same provider → tokenizer mapping.
40 ///
41 /// Priority order (highest wins):
42 /// 1. `COGNEE_TOKEN_COUNTER=tiktoken` → TikToken
43 /// 2. `COGNEE_TOKEN_COUNTER=huggingface` or `COGNEE_TOKEN_COUNTER=hf` → check
44 /// `HUGGINGFACE_TOKENIZER`
45 /// 3. `HUGGINGFACE_TOKENIZER` env var is set → HuggingFace { model_id }
46 /// 4. `EMBEDDING_PROVIDER=onnx` or `fastembed` and `EMBEDDING_TOKENIZER_PATH` is set
47 /// and the file exists → HuggingFaceFile
48 /// 5. `EMBEDDING_PROVIDER=openai` or `openai_compatible` → TikToken
49 /// 6. `EMBEDDING_PROVIDER=ollama` and `HUGGINGFACE_TOKENIZER` set → HuggingFace
50 /// 7. Fallback → Word
51 pub fn from_env() -> Self {
52 // Priority 1 & 2: explicit COGNEE_TOKEN_COUNTER override
53 if let Ok(counter) = std::env::var("COGNEE_TOKEN_COUNTER") {
54 match counter.to_lowercase().as_str() {
55 "tiktoken" => return TokenCounterKind::TikToken,
56 "word" => return TokenCounterKind::Word,
57 "huggingface" | "hf" => {
58 if let Ok(model_id) = std::env::var("HUGGINGFACE_TOKENIZER")
59 && !model_id.trim().is_empty()
60 {
61 return TokenCounterKind::HuggingFace { model_id };
62 }
63 // explicit hf requested but no model id — fall through to other priorities
64 }
65 _ => {}
66 }
67 }
68
69 // Priority 3: HUGGINGFACE_TOKENIZER set (any provider)
70 if let Ok(model_id) = std::env::var("HUGGINGFACE_TOKENIZER")
71 && !model_id.trim().is_empty()
72 {
73 return TokenCounterKind::HuggingFace { model_id };
74 }
75
76 // Priority 4–6: based on EMBEDDING_PROVIDER
77 // Python's default embedding provider is `openai`, whose default tokenizer is
78 // tiktoken cl100k_base. Match that when EMBEDDING_PROVIDER is unset so an
79 // out-of-box OpenAI-family setup counts BPE tokens, not whitespace.
80 // Users who explicitly set EMBEDDING_PROVIDER=onnx (or point to a tokenizer
81 // file via EMBEDDING_TOKENIZER_PATH) get the HuggingFaceFile path as before.
82 let provider = std::env::var("EMBEDDING_PROVIDER")
83 .unwrap_or_else(|_| "openai".to_string())
84 .to_lowercase();
85
86 match provider.as_str() {
87 "onnx" | "fastembed" => {
88 // Try to reuse the ONNX engine's tokenizer file
89 if let Ok(path) = std::env::var("EMBEDDING_TOKENIZER_PATH") {
90 let p = PathBuf::from(&path);
91 if p.exists() {
92 return TokenCounterKind::HuggingFaceFile { path: p };
93 }
94 }
95 // No tokenizer file available — fall through to Word
96 TokenCounterKind::Word
97 }
98 "openai" | "openai_compatible" => TokenCounterKind::TikToken,
99 "ollama" => {
100 if let Ok(model_id) = std::env::var("HUGGINGFACE_TOKENIZER")
101 && !model_id.trim().is_empty()
102 {
103 return TokenCounterKind::HuggingFace { model_id };
104 }
105 TokenCounterKind::Word
106 }
107 _ => TokenCounterKind::Word,
108 }
109 }
110
111 /// Construct a boxed `TokenCounter` from this kind.
112 ///
113 /// Returns an error if the selected kind cannot be constructed (e.g. file not found,
114 /// model download failed). When the relevant Cargo feature is disabled, silently falls
115 /// back to `WordCounter` and logs a warning — so the crate compiles without optional
116 /// features but users get a visible signal that their configured tokenizer is inactive.
117 pub fn build(self) -> Result<Box<dyn TokenCounter + Send + Sync>, ChunkingError> {
118 match self {
119 TokenCounterKind::Word => Ok(Box::new(WordCounter)),
120
121 #[cfg(feature = "hf-tokenizer")]
122 TokenCounterKind::HuggingFace { model_id } => {
123 let counter =
124 crate::token_counter::HuggingFaceTokenCounter::from_pretrained(&model_id)?;
125 Ok(Box::new(counter))
126 }
127
128 #[cfg(feature = "hf-tokenizer")]
129 TokenCounterKind::HuggingFaceFile { path } => {
130 let counter = crate::token_counter::HuggingFaceTokenCounter::from_file(path)?;
131 Ok(Box::new(counter))
132 }
133
134 #[cfg(feature = "tiktoken")]
135 TokenCounterKind::TikToken => {
136 let counter = crate::token_counter::TikTokenCounter::cl100k_base()?;
137 Ok(Box::new(counter))
138 }
139
140 // When the relevant feature is disabled, fall back to Word with a warning.
141 // This keeps the crate usable without optional features while signalling to
142 // the user that their configured tokenizer is not active.
143 #[cfg(not(feature = "hf-tokenizer"))]
144 TokenCounterKind::HuggingFace { model_id: _ } => {
145 eprintln!(
146 "cognee-chunking: HuggingFace tokenizer requested but `hf-tokenizer` \
147 feature is not enabled — falling back to WordCounter"
148 );
149 Ok(Box::new(WordCounter))
150 }
151
152 #[cfg(not(feature = "hf-tokenizer"))]
153 TokenCounterKind::HuggingFaceFile { path: _ } => {
154 eprintln!(
155 "cognee-chunking: HuggingFaceFile tokenizer requested but `hf-tokenizer` \
156 feature is not enabled — falling back to WordCounter"
157 );
158 Ok(Box::new(WordCounter))
159 }
160
161 #[cfg(not(feature = "tiktoken"))]
162 TokenCounterKind::TikToken => {
163 eprintln!(
164 "cognee-chunking: TikToken tokenizer requested but `tiktoken` feature is \
165 not enabled — falling back to WordCounter"
166 );
167 Ok(Box::new(WordCounter))
168 }
169 }
170 }
171}
172
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use std::ffi::{OsStr, OsString};
    use std::sync::{Mutex, MutexGuard, PoisonError};

    use super::*;

    /// Every environment variable that `TokenCounterKind::from_env` consults.
    const ENV_VARS: [&str; 4] = [
        "EMBEDDING_PROVIDER",
        "COGNEE_TOKEN_COUNTER",
        "HUGGINGFACE_TOKENIZER",
        "EMBEDDING_TOKENIZER_PATH",
    ];

    /// Serializes every test in this module that mutates the process environment, so the
    /// tests stay deterministic even when the harness runs them on several threads.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// RAII sandbox for env-dependent tests.
    ///
    /// On creation it takes `ENV_LOCK`, snapshots `ENV_VARS`, and clears them. On drop
    /// (including during a panicking assertion) it restores the snapshot before releasing
    /// the lock.
    struct EnvSandbox {
        saved: Vec<(&'static str, Option<OsString>)>,
        // Declared last so it is released only after `Drop::drop` has restored the vars.
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvSandbox {
        fn new() -> Self {
            // A previous test panicking while holding the lock only poisons it; the
            // guarded data is `()`, so recovering is always fine.
            let lock = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
            let saved = ENV_VARS
                .iter()
                .map(|&key| (key, std::env::var_os(key)))
                .collect();
            for key in ENV_VARS {
                // SAFETY: env mutation is `unsafe` in edition 2024 because it races with
                // concurrent env reads on other threads. `ENV_LOCK` serializes all env
                // access in this module. Tests elsewhere in the crate must not read these
                // vars concurrently; the project harness is described as running
                // `--test-threads=1`, which also covers that.
                unsafe { std::env::remove_var(key) };
            }
            Self { saved, _lock: lock }
        }

        /// Set one of `ENV_VARS` for the lifetime of this sandbox.
        fn set(&self, key: &str, value: impl AsRef<OsStr>) {
            debug_assert!(ENV_VARS.contains(&key), "{key} would not be restored");
            // SAFETY: `self` holds `ENV_LOCK`; see `EnvSandbox::new`.
            unsafe { std::env::set_var(key, value) };
        }
    }

    impl Drop for EnvSandbox {
        fn drop(&mut self) {
            for (key, value) in &self.saved {
                // SAFETY: `ENV_LOCK` is still held; the guard field drops after this body.
                unsafe {
                    match value {
                        Some(value) => std::env::set_var(key, value),
                        None => std::env::remove_var(key),
                    }
                }
            }
        }
    }

    /// When no env vars are set the default provider is treated as `openai`, which maps
    /// to `TikToken` — matching Python's out-of-box cl100k_base tokenizer.
    #[test]
    fn from_env_defaults_to_tiktoken_for_openai_family() {
        // Bound to a named variable (not `_`) so the sandbox lives until the end of the test.
        let _env = EnvSandbox::new();
        assert!(matches!(
            TokenCounterKind::from_env(),
            TokenCounterKind::TikToken
        ));
    }

    /// Explicitly setting EMBEDDING_PROVIDER=onnx still falls back to Word when
    /// no tokenizer file is available (existing ONNX-user behaviour is unchanged).
    #[test]
    fn from_env_onnx_without_tokenizer_falls_back_to_word() {
        let env = EnvSandbox::new();
        env.set("EMBEDDING_PROVIDER", "onnx");
        assert!(matches!(
            TokenCounterKind::from_env(),
            TokenCounterKind::Word
        ));
    }

    /// With EMBEDDING_PROVIDER=onnx and a real tokenizer file, that file is reused.
    #[test]
    fn from_env_onnx_with_tokenizer_file_uses_it() {
        let env = EnvSandbox::new();
        let tokenizer_file = std::env::temp_dir().join(format!(
            "cognee-chunking-test-tokenizer-{}.json",
            std::process::id()
        ));
        std::fs::write(&tokenizer_file, "{}").expect("write temp tokenizer file");
        env.set("EMBEDDING_PROVIDER", "onnx");
        env.set("EMBEDDING_TOKENIZER_PATH", &tokenizer_file);

        let kind = TokenCounterKind::from_env();
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&tokenizer_file);

        assert!(matches!(
            kind,
            TokenCounterKind::HuggingFaceFile { ref path } if path == &tokenizer_file
        ));
    }

    /// COGNEE_TOKEN_COUNTER overrides the provider default and is case-insensitive.
    #[test]
    fn from_env_explicit_word_override_beats_provider() {
        let env = EnvSandbox::new();
        env.set("EMBEDDING_PROVIDER", "openai");
        env.set("COGNEE_TOKEN_COUNTER", "WORD");
        assert!(matches!(
            TokenCounterKind::from_env(),
            TokenCounterKind::Word
        ));
    }

    #[test]
    fn from_env_explicit_tiktoken_override_beats_onnx_provider() {
        let env = EnvSandbox::new();
        env.set("EMBEDDING_PROVIDER", "onnx");
        env.set("COGNEE_TOKEN_COUNTER", "tiktoken");
        assert!(matches!(
            TokenCounterKind::from_env(),
            TokenCounterKind::TikToken
        ));
    }

    /// HUGGINGFACE_TOKENIZER wins over any provider-based default.
    #[test]
    fn from_env_huggingface_tokenizer_beats_provider_default() {
        let env = EnvSandbox::new();
        env.set("EMBEDDING_PROVIDER", "openai");
        env.set("HUGGINGFACE_TOKENIZER", "bert-base-uncased");
        assert!(matches!(
            TokenCounterKind::from_env(),
            TokenCounterKind::HuggingFace { ref model_id } if model_id == "bert-base-uncased"
        ));
    }

    #[test]
    fn word_variant_builds() {
        let counter = TokenCounterKind::Word.build();
        assert!(counter.is_ok());
        let counter = counter.unwrap();
        assert_eq!(counter.count_tokens("hello world"), 2);
    }

    #[test]
    fn word_variant_builds_empty() {
        let counter = TokenCounterKind::Word.build().unwrap();
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    #[cfg(feature = "tiktoken")]
    fn tiktoken_variant_builds() {
        let counter = TokenCounterKind::TikToken.build();
        assert!(counter.is_ok());
    }

    #[test]
    #[cfg(not(feature = "hf-tokenizer"))]
    fn hf_falls_back_without_feature() {
        let counter = TokenCounterKind::HuggingFace {
            model_id: "bert-base-uncased".to_string(),
        }
        .build();
        assert!(counter.is_ok(), "should fall back to WordCounter");
        let counter = counter.unwrap();
        assert_eq!(counter.count_tokens("hello world"), 2);
    }

    #[test]
    #[cfg(not(feature = "tiktoken"))]
    fn tiktoken_falls_back_without_feature() {
        let counter = TokenCounterKind::TikToken.build();
        assert!(counter.is_ok(), "should fall back to WordCounter");
        let counter = counter.unwrap();
        assert_eq!(counter.count_tokens("hello world"), 2);
    }
}