// hermes_core/tokenizer/hf_tokenizer.rs
//! HuggingFace tokenizer support for sparse vector queries
//!
//! Provides query-time tokenization using HuggingFace tokenizers.
//! Used when a sparse vector field has `query_tokenizer` configured.
//!
//! Supports both native and WASM targets:
//! - Native: Full support with `onig` regex and HTTP hub downloads
//! - WASM: Limited to `from_bytes()` loading (no HTTP, no onig regex)
use std::collections::HashMap;
#[cfg(feature = "native")]
use std::sync::Arc;

#[cfg(feature = "native")]
use parking_lot::RwLock;
use tokenizers::Tokenizer;

use crate::Result;
use crate::error::Error;
20
21/// Cached HuggingFace tokenizer
22pub struct HfTokenizer {
23    tokenizer: Tokenizer,
24}
25
/// Where a tokenizer definition should be loaded from.
#[derive(Debug, Clone)]
pub enum TokenizerSource {
    /// A HuggingFace hub identifier such as `"bert-base-uncased"` (native only).
    #[cfg(not(target_arch = "wasm32"))]
    HuggingFace(String),
    /// An absolute path to a local tokenizer file (native only).
    #[cfg(not(target_arch = "wasm32"))]
    LocalFile(String),
    /// A path relative to the index directory.
    IndexDirectory(String),
}

impl TokenizerSource {
    /// Parse a tokenizer path string into a [`TokenizerSource`].
    ///
    /// - `index://<rel>` → [`TokenizerSource::IndexDirectory`]
    /// - a leading `/` → [`TokenizerSource::LocalFile`]
    /// - anything else → [`TokenizerSource::HuggingFace`]
    #[cfg(not(target_arch = "wasm32"))]
    pub fn parse(path: &str) -> Self {
        match path.strip_prefix("index://") {
            Some(relative) => Self::IndexDirectory(relative.to_string()),
            None if path.starts_with('/') => Self::LocalFile(path.to_string()),
            None => Self::HuggingFace(path.to_string()),
        }
    }

    /// Parse a tokenizer path string into a [`TokenizerSource`] (WASM build).
    ///
    /// WASM can only read from the index directory, so every path is treated
    /// as index-relative; an `index://` prefix is stripped first.
    #[cfg(target_arch = "wasm32")]
    pub fn parse(path: &str) -> Self {
        let relative = path.strip_prefix("index://").unwrap_or(path);
        Self::IndexDirectory(relative.to_string())
    }
}
69
70impl HfTokenizer {
71    /// Load a tokenizer from HuggingFace hub or local path (native only)
72    ///
73    /// Examples:
74    /// - `"Alibaba-NLP/gte-Qwen2-1.5B-instruct"` - from HuggingFace hub
75    /// - `"/path/to/tokenizer.json"` - from local file
76    #[cfg(not(target_arch = "wasm32"))]
77    pub fn load(name_or_path: &str) -> Result<Self> {
78        let tokenizer = if name_or_path.contains('/') && !name_or_path.starts_with('/') {
79            // Looks like a HuggingFace hub identifier
80            Tokenizer::from_pretrained(name_or_path, None).map_err(|e| {
81                Error::Tokenizer(format!(
82                    "Failed to load tokenizer '{}': {}",
83                    name_or_path, e
84                ))
85            })?
86        } else {
87            // Local file path
88            Tokenizer::from_file(name_or_path).map_err(|e| {
89                Error::Tokenizer(format!(
90                    "Failed to load tokenizer from '{}': {}",
91                    name_or_path, e
92                ))
93            })?
94        };
95
96        Ok(Self { tokenizer })
97    }
98
99    /// Load a tokenizer from bytes (e.g., read from Directory)
100    ///
101    /// This allows loading tokenizers from any Directory implementation,
102    /// including remote storage like S3 or HTTP.
103    /// This is the primary method for WASM targets.
104    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
105        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| {
106            Error::Tokenizer(format!("Failed to parse tokenizer from bytes: {}", e))
107        })?;
108        Ok(Self { tokenizer })
109    }
110
111    /// Load from a TokenizerSource (native only)
112    #[cfg(not(target_arch = "wasm32"))]
113    pub fn from_source(source: &TokenizerSource) -> Result<Self> {
114        match source {
115            TokenizerSource::HuggingFace(name) => {
116                let tokenizer = Tokenizer::from_pretrained(name, None).map_err(|e| {
117                    Error::Tokenizer(format!("Failed to load tokenizer '{}': {}", name, e))
118                })?;
119                Ok(Self { tokenizer })
120            }
121            TokenizerSource::LocalFile(path) => {
122                let tokenizer = Tokenizer::from_file(path).map_err(|e| {
123                    Error::Tokenizer(format!("Failed to load tokenizer from '{}': {}", path, e))
124                })?;
125                Ok(Self { tokenizer })
126            }
127            TokenizerSource::IndexDirectory(_) => {
128                // For index directory, caller must use from_bytes with data read from Directory
129                Err(Error::Tokenizer(
130                    "IndexDirectory source requires using from_bytes with Directory read"
131                        .to_string(),
132                ))
133            }
134        }
135    }
136
137    /// Tokenize text and return token IDs
138    ///
139    /// Returns a vector of (token_id, count) pairs where count is the
140    /// number of times each token appears in the text.
141    pub fn tokenize(&self, text: &str) -> Result<Vec<(u32, u32)>> {
142        let encoding = self
143            .tokenizer
144            .encode(text, false)
145            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;
146
147        // Count token occurrences
148        let mut counts: HashMap<u32, u32> = HashMap::new();
149        for &id in encoding.get_ids() {
150            *counts.entry(id).or_insert(0) += 1;
151        }
152
153        Ok(counts.into_iter().collect())
154    }
155
156    /// Tokenize text and return unique token IDs (for weighting: one)
157    pub fn tokenize_unique(&self, text: &str) -> Result<Vec<u32>> {
158        let encoding = self
159            .tokenizer
160            .encode(text, false)
161            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;
162
163        // Get unique token IDs
164        let mut ids: Vec<u32> = encoding.get_ids().to_vec();
165        ids.sort_unstable();
166        ids.dedup();
167
168        Ok(ids)
169    }
170}
171
/// Process-wide cache of loaded tokenizers, shared across queries.
///
/// Keys are the `name_or_path` strings passed to [`HfTokenizer::load`].
#[cfg(feature = "native")]
pub struct TokenizerCache {
    cache: RwLock<HashMap<String, Arc<HfTokenizer>>>,
}
177
#[cfg(feature = "native")]
impl Default for TokenizerCache {
    /// Equivalent to [`TokenizerCache::new`].
    fn default() -> Self {
        TokenizerCache::new()
    }
}
184
#[cfg(feature = "native")]
impl TokenizerCache {
    /// Create an empty tokenizer cache.
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Return the cached tokenizer for `name_or_path`, loading it on a miss.
    ///
    /// NOTE(review): two threads racing on the same miss may both load the
    /// tokenizer; the later insert simply replaces the earlier one, which is
    /// harmless but does duplicate work.
    ///
    /// # Errors
    /// Propagates any [`HfTokenizer::load`] failure.
    pub fn get_or_load(&self, name_or_path: &str) -> Result<Arc<HfTokenizer>> {
        // Fast path: read lock only.
        if let Some(hit) = self.cache.read().get(name_or_path) {
            return Ok(Arc::clone(hit));
        }

        // Miss: load outside the lock, then insert under the write lock.
        let loaded = Arc::new(HfTokenizer::load(name_or_path)?);
        self.cache
            .write()
            .insert(name_or_path.to_string(), Arc::clone(&loaded));

        Ok(loaded)
    }

    /// Drop every cached tokenizer.
    pub fn clear(&self) {
        self.cache.write().clear();
    }
}
220
/// Process-wide [`TokenizerCache`] instance, initialized lazily on first use.
#[cfg(feature = "native")]
static TOKENIZER_CACHE: std::sync::OnceLock<TokenizerCache> = std::sync::OnceLock::new();

/// Access the global tokenizer cache, creating it on the first call.
#[cfg(feature = "native")]
pub fn tokenizer_cache() -> &'static TokenizerCache {
    TOKENIZER_CACHE.get_or_init(TokenizerCache::new)
}
230
#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

    // These tests hit the network to download tokenizer definitions, so they
    // are #[ignore]d by default (run with `cargo test -- --ignored`).

    #[test]
    #[ignore]
    fn test_load_tokenizer_from_hub() {
        let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
        assert!(!tokenizer.tokenize("hello world").unwrap().is_empty());
    }

    #[test]
    #[ignore]
    fn test_tokenize_unique() {
        let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
        let ids = tokenizer.tokenize_unique("the quick brown fox").unwrap();
        // Uniqueness check: dedup must not shrink the list.
        let mut deduped = ids.clone();
        deduped.sort_unstable();
        deduped.dedup();
        assert_eq!(ids.len(), deduped.len());
    }

    #[test]
    fn test_tokenizer_cache() {
        // A fresh cache starts empty.
        let cache = TokenizerCache::new();
        assert!(cache.cache.read().is_empty());
    }
}