
hermes_core/tokenizer/hf_tokenizer.rs

//! HuggingFace tokenizer support for sparse vector queries
//!
//! Provides query-time tokenization using HuggingFace tokenizers.
//! Used when a sparse vector field has `query_tokenizer` configured.
//!
//! Supports both native and WASM targets:
//! - Native: Full support with `onig` regex and HTTP hub downloads
//! - WASM: Limited to `from_bytes()` loading (no HTTP, no onig regex)
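//!
//! A minimal usage sketch (assuming this module is exposed as
//! `hermes_core::tokenizer::hf_tokenizer`; the example is `ignore`d because
//! hub loading needs network access):
//!
//! ```ignore
//! use hermes_core::tokenizer::hf_tokenizer::HfTokenizer;
//!
//! // Load from the HuggingFace hub, then tokenize a query string.
//! let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
//! let weighted = tokenizer.tokenize("hello world hello").unwrap(); // [(token_id, count), ...]
//! ```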

use std::collections::HashMap;
#[cfg(feature = "native")]
use std::sync::Arc;

#[cfg(feature = "native")]
use parking_lot::RwLock;
use tokenizers::Tokenizer;

use log::debug;

use crate::Result;
use crate::error::Error;

/// Cached HuggingFace tokenizer
pub struct HfTokenizer {
    tokenizer: Tokenizer,
}

/// Tokenizer source - where to load the tokenizer from
#[derive(Debug, Clone)]
pub enum TokenizerSource {
    /// Load from HuggingFace hub (e.g., "bert-base-uncased") - native only
    #[cfg(not(target_arch = "wasm32"))]
    HuggingFace(String),
    /// Load from local file path - native only
    #[cfg(not(target_arch = "wasm32"))]
    LocalFile(String),
    /// Load from index directory (relative path within index)
    IndexDirectory(String),
}

impl TokenizerSource {
    /// Parse a tokenizer path string into a TokenizerSource
    ///
    /// - Paths starting with `index://` are relative to index directory
    /// - On native: Paths starting with `/` are absolute local paths
    /// - On native: Other paths are treated as HuggingFace hub identifiers
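    ///
    /// A sketch of the resulting variants (`ignore`d doctest; the module path
    /// is an assumption):
    ///
    /// ```ignore
    /// use hermes_core::tokenizer::hf_tokenizer::TokenizerSource;
    ///
    /// // index:// prefixes are stripped and resolved inside the index directory
    /// assert!(matches!(TokenizerSource::parse("index://tok.json"), TokenizerSource::IndexDirectory(_)));
    /// // absolute paths load from the local filesystem
    /// assert!(matches!(TokenizerSource::parse("/tmp/tokenizer.json"), TokenizerSource::LocalFile(_)));
    /// // everything else is treated as a HuggingFace hub identifier
    /// assert!(matches!(TokenizerSource::parse("bert-base-uncased"), TokenizerSource::HuggingFace(_)));
    /// ```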
    #[cfg(not(target_arch = "wasm32"))]
    pub fn parse(path: &str) -> Self {
        if let Some(relative) = path.strip_prefix("index://") {
            TokenizerSource::IndexDirectory(relative.to_string())
        } else if path.starts_with('/') {
            TokenizerSource::LocalFile(path.to_string())
        } else {
            TokenizerSource::HuggingFace(path.to_string())
        }
    }

    /// Parse a tokenizer path string into a TokenizerSource (WASM version)
    ///
    /// On WASM there is no filesystem or hub access, so every path is treated
    /// as index-relative; an `index://` prefix is stripped if present.
    #[cfg(target_arch = "wasm32")]
    pub fn parse(path: &str) -> Self {
        if let Some(relative) = path.strip_prefix("index://") {
            TokenizerSource::IndexDirectory(relative.to_string())
        } else {
            // On WASM, treat all paths as index-relative
            TokenizerSource::IndexDirectory(path.to_string())
        }
    }
}

impl HfTokenizer {
    /// Load a tokenizer from HuggingFace hub or local path (native only)
    ///
    /// Examples:
    /// - `"Alibaba-NLP/gte-Qwen2-1.5B-instruct"` - from HuggingFace hub
    /// - `"/path/to/tokenizer.json"` - from local file
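    ///
    /// A minimal sketch (`ignore`d doctest: hub loading needs network access):
    ///
    /// ```ignore
    /// // Anything that is not an absolute path is fetched from the hub.
    /// let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
    /// ```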
    #[cfg(not(target_arch = "wasm32"))]
    pub fn load(name_or_path: &str) -> Result<Self> {
        let tokenizer = if !name_or_path.starts_with('/') {
            // Not an absolute path: treat as a HuggingFace hub identifier
            // (e.g. "bert-base-uncased" or "org/model"), matching
            // `TokenizerSource::parse`
            Tokenizer::from_pretrained(name_or_path, None).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to load tokenizer '{}': {}",
                    name_or_path, e
                ))
            })?
        } else {
            // Absolute local file path
            Tokenizer::from_file(name_or_path).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to load tokenizer from '{}': {}",
                    name_or_path, e
                ))
            })?
        };

        Ok(Self { tokenizer })
    }

    /// Load a tokenizer from bytes (e.g., read from Directory)
    ///
    /// This allows loading tokenizers from any Directory implementation,
    /// including remote storage like S3 or HTTP.
    /// This is the primary method for WASM targets.
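    ///
    /// A sketch of loading from raw bytes (`ignore`d doctest; `std::fs` stands
    /// in for a Directory read):
    ///
    /// ```ignore
    /// let bytes = std::fs::read("tokenizer.json").unwrap();
    /// let tokenizer = HfTokenizer::from_bytes(&bytes).unwrap();
    /// ```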
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| {
            Error::Tokenizer(format!("Failed to parse tokenizer from bytes: {}", e))
        })?;
        Ok(Self { tokenizer })
    }

    /// Load from a TokenizerSource (native only)
    #[cfg(not(target_arch = "wasm32"))]
    pub fn from_source(source: &TokenizerSource) -> Result<Self> {
        match source {
            TokenizerSource::HuggingFace(name) => {
                let tokenizer = Tokenizer::from_pretrained(name, None).map_err(|e| {
                    Error::Tokenizer(format!("Failed to load tokenizer '{}': {}", name, e))
                })?;
                Ok(Self { tokenizer })
            }
            TokenizerSource::LocalFile(path) => {
                let tokenizer = Tokenizer::from_file(path).map_err(|e| {
                    Error::Tokenizer(format!("Failed to load tokenizer from '{}': {}", path, e))
                })?;
                Ok(Self { tokenizer })
            }
            TokenizerSource::IndexDirectory(_) => {
                // For index directory, caller must use from_bytes with data read from Directory
                Err(Error::Tokenizer(
                    "IndexDirectory source requires using from_bytes with Directory read"
                        .to_string(),
                ))
            }
        }
    }

    /// Tokenize text and return token IDs
    ///
    /// Returns a vector of (token_id, count) pairs where count is the
    /// number of times each token appears in the text. Pairs are in
    /// arbitrary order (they come from a `HashMap`).
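    ///
    /// A small sketch (`ignore`d doctest; concrete IDs depend on the loaded
    /// vocabulary):
    ///
    /// ```ignore
    /// let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
    /// let counts = tokenizer.tokenize("hello hello world").unwrap();
    /// // "hello" appears twice, so its token ID carries a count of 2.
    /// assert!(counts.iter().any(|&(_, count)| count == 2));
    /// ```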
    pub fn tokenize(&self, text: &str) -> Result<Vec<(u32, u32)>> {
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;

        // Count token occurrences
        let mut counts: HashMap<u32, u32> = HashMap::new();
        for &id in encoding.get_ids() {
            *counts.entry(id).or_insert(0) += 1;
        }

        let result: Vec<(u32, u32)> = counts.into_iter().collect();
        debug!(
            "Tokenized query: text={:?} tokens={:?} token_ids={:?} unique_count={}",
            text,
            encoding.get_tokens(),
            encoding.get_ids(),
            result.len()
        );

        Ok(result)
    }

    /// Tokenize text and return unique token IDs (for `weighting: one`)
    ///
    /// IDs are returned sorted and deduplicated.
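    ///
    /// A small sketch (`ignore`d doctest; concrete IDs depend on the
    /// vocabulary):
    ///
    /// ```ignore
    /// let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
    /// let ids = tokenizer.tokenize_unique("hello hello world").unwrap();
    /// // Duplicates collapse and the result is strictly increasing.
    /// assert!(ids.windows(2).all(|w| w[0] < w[1]));
    /// ```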
    pub fn tokenize_unique(&self, text: &str) -> Result<Vec<u32>> {
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;

        // Get unique token IDs
        let mut ids: Vec<u32> = encoding.get_ids().to_vec();
        ids.sort_unstable();
        ids.dedup();

        debug!(
            "Tokenized query (unique): text={:?} tokens={:?} token_ids={:?} unique_count={}",
            text,
            encoding.get_tokens(),
            encoding.get_ids(),
            ids.len()
        );

        Ok(ids)
    }
}

/// Global tokenizer cache for reuse across queries
#[cfg(feature = "native")]
pub struct TokenizerCache {
    cache: RwLock<HashMap<String, Arc<HfTokenizer>>>,
}

#[cfg(feature = "native")]
impl Default for TokenizerCache {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "native")]
impl TokenizerCache {
    /// Create a new tokenizer cache
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Get or load a tokenizer
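    ///
    /// A usage sketch (`ignore`d doctest; the first call needs network access):
    ///
    /// ```ignore
    /// let cache = TokenizerCache::new();
    /// let a = cache.get_or_load("bert-base-uncased").unwrap();
    /// let b = cache.get_or_load("bert-base-uncased").unwrap(); // served from cache
    /// assert!(std::sync::Arc::ptr_eq(&a, &b));
    /// ```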
    pub fn get_or_load(&self, name_or_path: &str) -> Result<Arc<HfTokenizer>> {
        // Check cache first
        {
            let cache = self.cache.read();
            if let Some(tokenizer) = cache.get(name_or_path) {
                return Ok(Arc::clone(tokenizer));
            }
        }

        // Load and cache
        let tokenizer = Arc::new(HfTokenizer::load(name_or_path)?);
        {
            let mut cache = self.cache.write();
            cache.insert(name_or_path.to_string(), Arc::clone(&tokenizer));
        }

        Ok(tokenizer)
    }

    /// Clear the cache
    pub fn clear(&self) {
        let mut cache = self.cache.write();
        cache.clear();
    }
}

/// Global tokenizer cache instance
#[cfg(feature = "native")]
static TOKENIZER_CACHE: std::sync::OnceLock<TokenizerCache> = std::sync::OnceLock::new();

/// Get the global tokenizer cache
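///
/// A usage sketch (`ignore`d doctest; assumes the `native` feature is
/// enabled):
///
/// ```ignore
/// // Every call returns the same process-wide cache.
/// let tokenizer = tokenizer_cache().get_or_load("bert-base-uncased").unwrap();
/// ```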
#[cfg(feature = "native")]
pub fn tokenizer_cache() -> &'static TokenizerCache {
    TOKENIZER_CACHE.get_or_init(TokenizerCache::new)
}

#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

    // Note: The #[ignore]d tests require network access to download tokenizers;
    // they are skipped by default to avoid CI issues.

    #[test]
    #[ignore]
    fn test_load_tokenizer_from_hub() {
        let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
        let tokens = tokenizer.tokenize("hello world").unwrap();
        assert!(!tokens.is_empty());
    }

    #[test]
    #[ignore]
    fn test_tokenize_unique() {
        let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
        let ids = tokenizer.tokenize_unique("the quick brown fox").unwrap();
        // Should have unique tokens
        let mut sorted = ids.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(ids.len(), sorted.len());
    }

    #[test]
    fn test_tokenizer_cache() {
        let cache = TokenizerCache::new();
        // Just test that the cache structure works
        assert!(cache.cache.read().is_empty());
    }
}