Skip to main content

cognee_chunking/
config.rs

1//! Chunking configuration — tokenizer selection via environment variables.
2//!
3//! [`TokenCounterKind`] selects which token counting implementation to use based on
4//! environment variables and the active embedding provider. Call [`TokenCounterKind::from_env`]
5//! at pipeline construction time to pick the best available counter automatically, then
6//! call [`TokenCounterKind::build`] to construct the counter.
7
8use std::path::PathBuf;
9
10use serde::{Deserialize, Serialize};
11
12use crate::error::ChunkingError;
13use crate::token_counter::{TokenCounter, WordCounter};
14
15/// Selects which token counting implementation to use.
16///
17/// `from_env()` picks the best available counter based on env vars and the current
18/// embedding provider setting. `WordCounter` is the last-resort fallback.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub enum TokenCounterKind {
21    /// Accurate BPE/WordPiece via a HuggingFace tokenizer model ID (requires network or cache).
22    HuggingFace { model_id: String },
23    /// Accurate BPE/WordPiece from a local tokenizer.json file.
24    HuggingFaceFile { path: PathBuf },
25    /// TikToken cl100k_base BPE (for OpenAI models).
26    TikToken,
27    /// Whitespace word count. Last-resort fallback.
28    Word,
29}
30
31impl TokenCounterKind {
32    /// Determine the best available token counter from the environment.
33    ///
34    /// Mirrors Python's `LiteLLMEmbeddingEngine.get_tokenizer()` logic, which selects a
35    /// tokenizer based on the provider and stores it on the engine instance. Python's
36    /// `chunk_by_sentence()` calls `embedding_engine.tokenizer.count_tokens()` directly —
37    /// the tokenizer is a property of the engine, not a separate config. The Rust design
38    /// decouples them (`TokenCounterKind` is independent of the engine), but the selection
39    /// logic below preserves the same provider → tokenizer mapping.
40    ///
41    /// Priority order (highest wins):
42    /// 1. `COGNEE_TOKEN_COUNTER=tiktoken` → TikToken
43    /// 2. `COGNEE_TOKEN_COUNTER=huggingface` or `COGNEE_TOKEN_COUNTER=hf` → check
44    ///    `HUGGINGFACE_TOKENIZER`
45    /// 3. `HUGGINGFACE_TOKENIZER` env var is set → HuggingFace { model_id }
46    /// 4. `EMBEDDING_PROVIDER=onnx` or `fastembed` and `EMBEDDING_TOKENIZER_PATH` is set
47    ///    and the file exists → HuggingFaceFile
48    /// 5. `EMBEDDING_PROVIDER=openai` or `openai_compatible` → TikToken
49    /// 6. `EMBEDDING_PROVIDER=ollama` and `HUGGINGFACE_TOKENIZER` set → HuggingFace
50    /// 7. Fallback → Word
51    pub fn from_env() -> Self {
52        // Priority 1 & 2: explicit COGNEE_TOKEN_COUNTER override
53        if let Ok(counter) = std::env::var("COGNEE_TOKEN_COUNTER") {
54            match counter.to_lowercase().as_str() {
55                "tiktoken" => return TokenCounterKind::TikToken,
56                "word" => return TokenCounterKind::Word,
57                "huggingface" | "hf" => {
58                    if let Ok(model_id) = std::env::var("HUGGINGFACE_TOKENIZER")
59                        && !model_id.trim().is_empty()
60                    {
61                        return TokenCounterKind::HuggingFace { model_id };
62                    }
63                    // explicit hf requested but no model id — fall through to other priorities
64                }
65                _ => {}
66            }
67        }
68
69        // Priority 3: HUGGINGFACE_TOKENIZER set (any provider)
70        if let Ok(model_id) = std::env::var("HUGGINGFACE_TOKENIZER")
71            && !model_id.trim().is_empty()
72        {
73            return TokenCounterKind::HuggingFace { model_id };
74        }
75
76        // Priority 4–6: based on EMBEDDING_PROVIDER
77        // Python's default embedding provider is `openai`, whose default tokenizer is
78        // tiktoken cl100k_base. Match that when EMBEDDING_PROVIDER is unset so an
79        // out-of-box OpenAI-family setup counts BPE tokens, not whitespace.
80        // Users who explicitly set EMBEDDING_PROVIDER=onnx (or point to a tokenizer
81        // file via EMBEDDING_TOKENIZER_PATH) get the HuggingFaceFile path as before.
82        let provider = std::env::var("EMBEDDING_PROVIDER")
83            .unwrap_or_else(|_| "openai".to_string())
84            .to_lowercase();
85
86        match provider.as_str() {
87            "onnx" | "fastembed" => {
88                // Try to reuse the ONNX engine's tokenizer file
89                if let Ok(path) = std::env::var("EMBEDDING_TOKENIZER_PATH") {
90                    let p = PathBuf::from(&path);
91                    if p.exists() {
92                        return TokenCounterKind::HuggingFaceFile { path: p };
93                    }
94                }
95                // No tokenizer file available — fall through to Word
96                TokenCounterKind::Word
97            }
98            "openai" | "openai_compatible" => TokenCounterKind::TikToken,
99            "ollama" => {
100                if let Ok(model_id) = std::env::var("HUGGINGFACE_TOKENIZER")
101                    && !model_id.trim().is_empty()
102                {
103                    return TokenCounterKind::HuggingFace { model_id };
104                }
105                TokenCounterKind::Word
106            }
107            _ => TokenCounterKind::Word,
108        }
109    }
110
111    /// Construct a boxed `TokenCounter` from this kind.
112    ///
113    /// Returns an error if the selected kind cannot be constructed (e.g. file not found,
114    /// model download failed). When the relevant Cargo feature is disabled, silently falls
115    /// back to `WordCounter` and logs a warning — so the crate compiles without optional
116    /// features but users get a visible signal that their configured tokenizer is inactive.
117    pub fn build(self) -> Result<Box<dyn TokenCounter + Send + Sync>, ChunkingError> {
118        match self {
119            TokenCounterKind::Word => Ok(Box::new(WordCounter)),
120
121            #[cfg(feature = "hf-tokenizer")]
122            TokenCounterKind::HuggingFace { model_id } => {
123                let counter =
124                    crate::token_counter::HuggingFaceTokenCounter::from_pretrained(&model_id)?;
125                Ok(Box::new(counter))
126            }
127
128            #[cfg(feature = "hf-tokenizer")]
129            TokenCounterKind::HuggingFaceFile { path } => {
130                let counter = crate::token_counter::HuggingFaceTokenCounter::from_file(path)?;
131                Ok(Box::new(counter))
132            }
133
134            #[cfg(feature = "tiktoken")]
135            TokenCounterKind::TikToken => {
136                let counter = crate::token_counter::TikTokenCounter::cl100k_base()?;
137                Ok(Box::new(counter))
138            }
139
140            // When the relevant feature is disabled, fall back to Word with a warning.
141            // This keeps the crate usable without optional features while signalling to
142            // the user that their configured tokenizer is not active.
143            #[cfg(not(feature = "hf-tokenizer"))]
144            TokenCounterKind::HuggingFace { model_id: _ } => {
145                eprintln!(
146                    "cognee-chunking: HuggingFace tokenizer requested but `hf-tokenizer` \
147                     feature is not enabled — falling back to WordCounter"
148                );
149                Ok(Box::new(WordCounter))
150            }
151
152            #[cfg(not(feature = "hf-tokenizer"))]
153            TokenCounterKind::HuggingFaceFile { path: _ } => {
154                eprintln!(
155                    "cognee-chunking: HuggingFaceFile tokenizer requested but `hf-tokenizer` \
156                     feature is not enabled — falling back to WordCounter"
157                );
158                Ok(Box::new(WordCounter))
159            }
160
161            #[cfg(not(feature = "tiktoken"))]
162            TokenCounterKind::TikToken => {
163                eprintln!(
164                    "cognee-chunking: TikToken tokenizer requested but `tiktoken` feature is \
165                     not enabled — falling back to WordCounter"
166                );
167                Ok(Box::new(WordCounter))
168            }
169        }
170    }
171}
172
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use std::sync::{Mutex, PoisonError};

    use super::*;

    /// Every environment variable `TokenCounterKind::from_env` consults.
    const SELECTOR_VARS: [&str; 4] = [
        "COGNEE_TOKEN_COUNTER",
        "HUGGINGFACE_TOKENIZER",
        "EMBEDDING_PROVIDER",
        "EMBEDDING_TOKENIZER_PATH",
    ];

    /// Serialises the env-mutating tests in this module. `cargo test` runs tests on
    /// parallel threads by default. Without this lock, one test could set
    /// `EMBEDDING_PROVIDER` while another is inside `from_env`.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Run `TokenCounterKind::from_env` with exactly `vars` set and every other selector
    /// variable removed. All selector variables are removed again before returning.
    fn from_env_with(vars: &[(&str, &str)]) -> TokenCounterKind {
        // A test that panicked while holding the lock leaves it poisoned; the guarded data
        // is `()`, so recovering the guard is always fine.
        let _guard = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);

        // SAFETY: `set_var`/`remove_var` are `unsafe` in edition 2024 because another
        // thread may read the environment concurrently. `ENV_LOCK` serialises every env
        // access made by the tests in this module. Env readers elsewhere in the test
        // binary are covered by the project harness running with `--test-threads=1`.
        unsafe {
            for name in SELECTOR_VARS {
                std::env::remove_var(name);
            }
            for &(name, value) in vars {
                std::env::set_var(name, value);
            }
        }

        let kind = TokenCounterKind::from_env();

        // SAFETY: as above — `ENV_LOCK` is still held.
        unsafe {
            for name in SELECTOR_VARS {
                std::env::remove_var(name);
            }
        }
        kind
    }

    /// When no env vars are set the default provider is treated as `openai`, which maps
    /// to `TikToken` — matching Python's out-of-box cl100k_base tokenizer.
    #[test]
    fn from_env_defaults_to_tiktoken_for_openai_family() {
        assert!(matches!(from_env_with(&[]), TokenCounterKind::TikToken));
    }

    /// Explicitly setting EMBEDDING_PROVIDER=onnx still falls back to Word when
    /// no tokenizer file is available (existing ONNX-user behaviour is unchanged).
    #[test]
    fn from_env_onnx_without_tokenizer_falls_back_to_word() {
        assert!(matches!(
            from_env_with(&[("EMBEDDING_PROVIDER", "onnx")]),
            TokenCounterKind::Word
        ));
    }

    /// An ONNX-family provider reuses an existing tokenizer file.
    #[test]
    fn from_env_onnx_with_tokenizer_file_uses_it() {
        let file = std::env::temp_dir().join(format!(
            "cognee-chunking-config-test-{}.json",
            std::process::id()
        ));
        std::fs::write(&file, "{}").expect("temp tokenizer file should be writable");
        let kind = from_env_with(&[
            ("EMBEDDING_PROVIDER", "fastembed"),
            (
                "EMBEDDING_TOKENIZER_PATH",
                file.to_str().expect("temp dir path should be UTF-8"),
            ),
        ]);
        std::fs::remove_file(&file).expect("temp tokenizer file should be removable");
        assert!(matches!(
            &kind,
            TokenCounterKind::HuggingFaceFile { path } if path == &file
        ));
    }

    /// COGNEE_TOKEN_COUNTER overrides the provider mapping and is case-insensitive.
    #[test]
    fn from_env_explicit_override_beats_provider() {
        assert!(matches!(
            from_env_with(&[
                ("COGNEE_TOKEN_COUNTER", "TikToken"),
                ("EMBEDDING_PROVIDER", "onnx"),
            ]),
            TokenCounterKind::TikToken
        ));
        assert!(matches!(
            from_env_with(&[
                ("COGNEE_TOKEN_COUNTER", "word"),
                ("EMBEDDING_PROVIDER", "openai"),
            ]),
            TokenCounterKind::Word
        ));
    }

    /// HUGGINGFACE_TOKENIZER wins over the provider mapping.
    #[test]
    fn from_env_huggingface_tokenizer_wins_over_provider() {
        let kind = from_env_with(&[
            ("HUGGINGFACE_TOKENIZER", "bert-base-uncased"),
            ("EMBEDDING_PROVIDER", "openai"),
        ]);
        assert!(matches!(
            &kind,
            TokenCounterKind::HuggingFace { model_id } if model_id == "bert-base-uncased"
        ));
    }

    /// A whitespace-only HUGGINGFACE_TOKENIZER is treated as unset.
    #[test]
    fn from_env_blank_huggingface_tokenizer_is_ignored() {
        assert!(matches!(
            from_env_with(&[
                ("HUGGINGFACE_TOKENIZER", "   "),
                ("EMBEDDING_PROVIDER", "openai"),
            ]),
            TokenCounterKind::TikToken
        ));
    }

    #[test]
    fn from_env_provider_is_case_insensitive() {
        assert!(matches!(
            from_env_with(&[("EMBEDDING_PROVIDER", "OpenAI_Compatible")]),
            TokenCounterKind::TikToken
        ));
    }

    #[test]
    fn from_env_ollama_without_huggingface_tokenizer_is_word() {
        assert!(matches!(
            from_env_with(&[("EMBEDDING_PROVIDER", "ollama")]),
            TokenCounterKind::Word
        ));
    }

    #[test]
    fn word_variant_builds() {
        let counter = TokenCounterKind::Word.build();
        assert!(counter.is_ok());
        let counter = counter.unwrap();
        assert_eq!(counter.count_tokens("hello world"), 2);
    }

    #[test]
    fn word_variant_builds_empty() {
        let counter = TokenCounterKind::Word.build().unwrap();
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    #[cfg(feature = "tiktoken")]
    fn tiktoken_variant_builds() {
        let counter = TokenCounterKind::TikToken.build();
        assert!(counter.is_ok());
    }

    #[test]
    #[cfg(not(feature = "hf-tokenizer"))]
    fn hf_falls_back_without_feature() {
        let counter = TokenCounterKind::HuggingFace {
            model_id: "bert-base-uncased".to_string(),
        }
        .build();
        assert!(counter.is_ok(), "should fall back to WordCounter");
        let counter = counter.unwrap();
        assert_eq!(counter.count_tokens("hello world"), 2);
    }

    #[test]
    #[cfg(not(feature = "tiktoken"))]
    fn tiktoken_falls_back_without_feature() {
        let counter = TokenCounterKind::TikToken.build();
        assert!(counter.is_ok(), "should fall back to WordCounter");
        let counter = counter.unwrap();
        assert_eq!(counter.count_tokens("hello world"), 2);
    }
}