
hermes_core/tokenizer/hf_tokenizer.rs

//! HuggingFace tokenizer support for sparse vector queries
//!
//! Provides query-time tokenization using HuggingFace tokenizers.
//! Used when a sparse vector field has `query_tokenizer` configured.
//!
//! Supports both native and WASM targets:
//! - Native: Full support with `onig` regex and HTTP hub downloads
//! - WASM: Limited to `from_bytes()` loading (no HTTP, no onig regex)
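//!
//! A minimal query-time sketch (`from_bytes` is available on every target;
//! the byte source shown here is illustrative):
//!
//! ```ignore
//! let bytes = std::fs::read("tokenizer.json")?;
//! let tok = HfTokenizer::from_bytes(&bytes)?;
//! let weighted = tok.tokenize("sparse vector query")?; // Vec<(token_id, count)>
//! ```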

use std::collections::HashMap;
#[cfg(feature = "native")]
use std::sync::Arc;

#[cfg(feature = "native")]
use parking_lot::RwLock;
use tokenizers::Tokenizer;

use log::debug;

use crate::Result;
use crate::error::Error;

/// Cached HuggingFace tokenizer
pub struct HfTokenizer {
    pub(crate) tokenizer: Tokenizer,
}

/// Tokenizer source - where to load the tokenizer from
#[derive(Debug, Clone)]
pub enum TokenizerSource {
    /// Load from HuggingFace hub (e.g., "bert-base-uncased") - native only
    #[cfg(not(target_arch = "wasm32"))]
    HuggingFace(String),
    /// Load from local file path - native only
    #[cfg(not(target_arch = "wasm32"))]
    LocalFile(String),
    /// Load from index directory (relative path within index)
    IndexDirectory(String),
}

impl TokenizerSource {
    /// Parse a tokenizer path string into a TokenizerSource
    ///
    /// - Paths starting with `index://` are relative to the index directory
    /// - On native: paths starting with `/` are absolute local paths
    /// - On native: other paths are treated as HuggingFace hub identifiers
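    ///
    /// Illustrative mapping (mirrors the branches below):
    ///
    /// ```ignore
    /// // `index://` prefix -> IndexDirectory (path relative to the index)
    /// TokenizerSource::parse("index://tokenizer.json");
    /// // leading `/` -> LocalFile (absolute path)
    /// TokenizerSource::parse("/models/tokenizer.json");
    /// // anything else -> HuggingFace hub identifier
    /// TokenizerSource::parse("google-bert/bert-base-uncased");
    /// ```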
    #[cfg(not(target_arch = "wasm32"))]
    pub fn parse(path: &str) -> Self {
        if let Some(relative) = path.strip_prefix("index://") {
            TokenizerSource::IndexDirectory(relative.to_string())
        } else if path.starts_with('/') {
            TokenizerSource::LocalFile(path.to_string())
        } else {
            TokenizerSource::HuggingFace(path.to_string())
        }
    }

    /// Parse a tokenizer path string into a TokenizerSource (WASM version)
    ///
    /// On WASM there is no local filesystem or hub access, so every path is
    /// treated as index-relative (the `index://` prefix is optional).
    #[cfg(target_arch = "wasm32")]
    pub fn parse(path: &str) -> Self {
        if let Some(relative) = path.strip_prefix("index://") {
            TokenizerSource::IndexDirectory(relative.to_string())
        } else {
            // On WASM, treat all paths as index-relative
            TokenizerSource::IndexDirectory(path.to_string())
        }
    }
}

impl HfTokenizer {
    /// Load a tokenizer from HuggingFace hub or local path (native only)
    ///
    /// Examples:
    /// - `"Alibaba-NLP/gte-Qwen2-1.5B-instruct"` - from HuggingFace hub
    /// - `"/path/to/tokenizer.json"` - from local file
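    ///
    /// Note: a bare name without a `/` (e.g. `"bert-base-uncased"`) is treated
    /// as a local file path by the heuristic below; use the org-qualified hub
    /// form (e.g. `"google-bert/bert-base-uncased"`) to force a hub download.
    ///
    /// ```ignore
    /// let tok = HfTokenizer::load("google-bert/bert-base-uncased")?; // needs network
    /// let pairs = tok.tokenize("hello world")?;
    /// ```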
    #[cfg(not(target_arch = "wasm32"))]
    pub fn load(name_or_path: &str) -> Result<Self> {
        let tokenizer = if name_or_path.contains('/') && !name_or_path.starts_with('/') {
            // Looks like a HuggingFace hub identifier
            Tokenizer::from_pretrained(name_or_path, None).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to load tokenizer '{}': {}",
                    name_or_path, e
                ))
            })?
        } else {
            // Local file path
            Tokenizer::from_file(name_or_path).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to load tokenizer from '{}': {}",
                    name_or_path, e
                ))
            })?
        };

        Ok(Self { tokenizer })
    }

    /// Load a tokenizer from bytes (e.g., read from Directory)
    ///
    /// This allows loading tokenizers from any Directory implementation,
    /// including remote storage like S3 or HTTP.
    /// This is the primary method for WASM targets.
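    ///
    /// Sketch (the byte source is illustrative; any read that yields the raw
    /// `tokenizer.json` bytes works):
    ///
    /// ```ignore
    /// let bytes = std::fs::read("tokenizer.json")?;
    /// let tok = HfTokenizer::from_bytes(&bytes)?;
    /// ```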
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| {
            Error::Tokenizer(format!("Failed to parse tokenizer from bytes: {}", e))
        })?;
        Ok(Self { tokenizer })
    }

    /// Load from a TokenizerSource (native only)
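    ///
    /// `IndexDirectory` sources are not handled here: read the bytes from the
    /// index `Directory` and call [`HfTokenizer::from_bytes`] instead.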
    #[cfg(not(target_arch = "wasm32"))]
    pub fn from_source(source: &TokenizerSource) -> Result<Self> {
        match source {
            TokenizerSource::HuggingFace(name) => {
                let tokenizer = Tokenizer::from_pretrained(name, None).map_err(|e| {
                    Error::Tokenizer(format!("Failed to load tokenizer '{}': {}", name, e))
                })?;
                Ok(Self { tokenizer })
            }
            TokenizerSource::LocalFile(path) => {
                let tokenizer = Tokenizer::from_file(path).map_err(|e| {
                    Error::Tokenizer(format!("Failed to load tokenizer from '{}': {}", path, e))
                })?;
                Ok(Self { tokenizer })
            }
            TokenizerSource::IndexDirectory(_) => {
                // For index directory, caller must use from_bytes with data read from Directory
                Err(Error::Tokenizer(
                    "IndexDirectory source requires using from_bytes with Directory read"
                        .to_string(),
                ))
            }
        }
    }

    /// Resolve a token ID back to its text representation
    pub fn id_to_token(&self, id: u32) -> Option<String> {
        self.tokenizer.id_to_token(id)
    }

    /// Tokenize text and return token IDs
    ///
    /// Returns a vector of (token_id, count) pairs where count is the
    /// number of times each token appears in the text.
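    ///
    /// For example, tokenizing `"hello hello world"` yields pairs like
    /// `[(id("hello"), 2), (id("world"), 1)]`, in no particular order
    /// (the counts are accumulated in a `HashMap`).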
    pub fn tokenize(&self, text: &str) -> Result<Vec<(u32, u32)>> {
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;

        // Count token occurrences
        let mut counts: HashMap<u32, u32> = HashMap::new();
        for &id in encoding.get_ids() {
            *counts.entry(id).or_insert(0) += 1;
        }

        let result: Vec<(u32, u32)> = counts.into_iter().collect();
        let paired: Vec<_> = encoding
            .get_tokens()
            .iter()
            .zip(encoding.get_ids())
            .map(|(tok, id)| format!("({:?},{})", tok, id))
            .collect();
        debug!(
            "Tokenized query: text={:?} tokens=[{}] unique_count={}",
            text,
            paired.join(", "),
            result.len()
        );

        Ok(result)
    }

    /// Tokenize text and return unique token IDs (for the `one` weighting
    /// mode, where every query token gets the same weight)
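    ///
    /// E.g. `"the the quick"` yields the two distinct token IDs, sorted and
    /// deduplicated, rather than per-token counts.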
    pub fn tokenize_unique(&self, text: &str) -> Result<Vec<u32>> {
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;

        // Get unique token IDs
        let mut ids: Vec<u32> = encoding.get_ids().to_vec();
        ids.sort_unstable();
        ids.dedup();

        let paired: Vec<_> = encoding
            .get_tokens()
            .iter()
            .zip(encoding.get_ids())
            .map(|(tok, id)| format!("({:?},{})", tok, id))
            .collect();
        debug!(
            "Tokenized query (unique): text={:?} tokens=[{}] unique_count={}",
            text,
            paired.join(", "),
            ids.len()
        );

        Ok(ids)
    }
}

/// Global tokenizer cache for reuse across queries
#[cfg(feature = "native")]
pub struct TokenizerCache {
    cache: RwLock<HashMap<String, Arc<HfTokenizer>>>,
}

#[cfg(feature = "native")]
impl Default for TokenizerCache {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "native")]
impl TokenizerCache {
    /// Create a new tokenizer cache
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Get or load a tokenizer
    pub fn get_or_load(&self, name_or_path: &str) -> Result<Arc<HfTokenizer>> {
        // Check cache first
        {
            let cache = self.cache.read();
            if let Some(tokenizer) = cache.get(name_or_path) {
                return Ok(Arc::clone(tokenizer));
            }
        }

        // Load and cache
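        // (Two racing threads can both miss the read check and both load; the
        // second insert overwrites the first, which is harmless for a cache
        // and avoids holding the write lock during the load itself.)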
        let tokenizer = Arc::new(HfTokenizer::load(name_or_path)?);
        {
            let mut cache = self.cache.write();
            cache.insert(name_or_path.to_string(), Arc::clone(&tokenizer));
        }

        Ok(tokenizer)
    }

    /// Clear the cache
    pub fn clear(&self) {
        let mut cache = self.cache.write();
        cache.clear();
    }
}

/// Global tokenizer cache instance
#[cfg(feature = "native")]
static TOKENIZER_CACHE: std::sync::OnceLock<TokenizerCache> = std::sync::OnceLock::new();

/// Get the global tokenizer cache
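///
/// Typical usage (hub loading needs network access):
///
/// ```ignore
/// let tok = tokenizer_cache().get_or_load("google-bert/bert-base-uncased")?;
/// let ids = tok.tokenize_unique("reused across queries")?;
/// ```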
#[cfg(feature = "native")]
pub fn tokenizer_cache() -> &'static TokenizerCache {
    TOKENIZER_CACHE.get_or_init(TokenizerCache::new)
}

#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

    // Note: the hub tests require network access to download tokenizers.
    // They are ignored by default to avoid CI issues.

    #[test]
    #[ignore]
    fn test_load_tokenizer_from_hub() {
        // `load` routes org-qualified names (containing '/') to the hub.
        let tokenizer = HfTokenizer::load("google-bert/bert-base-uncased").unwrap();
        let tokens = tokenizer.tokenize("hello world").unwrap();
        assert!(!tokens.is_empty());
    }

    #[test]
    #[ignore]
    fn test_tokenize_unique() {
        let tokenizer = HfTokenizer::load("google-bert/bert-base-uncased").unwrap();
        let ids = tokenizer.tokenize_unique("the quick brown fox").unwrap();
        // Should have unique tokens
        let mut sorted = ids.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(ids.len(), sorted.len());
    }
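
    // TokenizerSource::parse is pure string matching, so this check runs
    // offline; the paths below are illustrative.
    #[test]
    fn test_tokenizer_source_parse() {
        assert!(matches!(
            TokenizerSource::parse("index://tokenizer.json"),
            TokenizerSource::IndexDirectory(_)
        ));
        assert!(matches!(
            TokenizerSource::parse("/abs/path/tokenizer.json"),
            TokenizerSource::LocalFile(_)
        ));
        assert!(matches!(
            TokenizerSource::parse("google-bert/bert-base-uncased"),
            TokenizerSource::HuggingFace(_)
        ));
    }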

    #[test]
    fn test_tokenizer_cache() {
        let cache = TokenizerCache::new();
        // Just test that the cache structure works
        assert!(cache.cache.read().is_empty());
    }
}