hermes_core/tokenizer/hf_tokenizer.rs

//! HuggingFace tokenizer support for sparse vector queries
//!
//! Provides query-time tokenization using HuggingFace tokenizers.
//! Used when a sparse vector field has `query_tokenizer` configured.
//!
//! Supports both native and WASM targets:
//! - Native: Full support with `onig` regex and HTTP hub downloads
//! - WASM: Limited to `from_bytes()` loading (no HTTP, no onig regex)
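//!
//! A typical query-time flow, as a brief sketch (fenced `ignore` since it
//! needs hub access and elides the surrounding error handling):
//!
//! ```ignore
//! let tokenizer = HfTokenizer::load("bert-base-uncased")?;
//! let weighted = tokenizer.tokenize("hello world")?;       // (token_id, count) pairs
//! let unique = tokenizer.tokenize_unique("hello world")?;  // deduplicated ids
//! ```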

use std::collections::HashMap;
#[cfg(feature = "native")]
use std::sync::Arc;

#[cfg(feature = "native")]
use parking_lot::RwLock;
use tokenizers::Tokenizer;

use crate::Result;
use crate::error::Error;

/// Cached HuggingFace tokenizer
pub struct HfTokenizer {
    tokenizer: Tokenizer,
}

/// Tokenizer source - where to load the tokenizer from
#[derive(Debug, Clone)]
pub enum TokenizerSource {
    /// Load from HuggingFace hub (e.g., "bert-base-uncased") - native only
    #[cfg(not(target_arch = "wasm32"))]
    HuggingFace(String),
    /// Load from local file path - native only
    #[cfg(not(target_arch = "wasm32"))]
    LocalFile(String),
    /// Load from index directory (relative path within index)
    IndexDirectory(String),
}

impl TokenizerSource {
    /// Parse a tokenizer path string into a TokenizerSource
    ///
    /// - Paths starting with `index://` are relative to the index directory
    /// - On native: Paths starting with `/` are absolute local paths
    /// - On native: Other paths are treated as HuggingFace hub identifiers
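    ///
    /// The mapping, sketched as `ignore`d assertions (paths are illustrative):
    ///
    /// ```ignore
    /// assert!(matches!(TokenizerSource::parse("index://tok.json"),
    ///                  TokenizerSource::IndexDirectory(_)));
    /// assert!(matches!(TokenizerSource::parse("/models/tokenizer.json"),
    ///                  TokenizerSource::LocalFile(_)));
    /// assert!(matches!(TokenizerSource::parse("bert-base-uncased"),
    ///                  TokenizerSource::HuggingFace(_)));
    /// ```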
    #[cfg(not(target_arch = "wasm32"))]
    pub fn parse(path: &str) -> Self {
        if let Some(relative) = path.strip_prefix("index://") {
            TokenizerSource::IndexDirectory(relative.to_string())
        } else if path.starts_with('/') {
            TokenizerSource::LocalFile(path.to_string())
        } else {
            TokenizerSource::HuggingFace(path.to_string())
        }
    }

    /// Parse a tokenizer path string into a TokenizerSource (WASM version)
    ///
    /// On WASM there is no hub or filesystem access, so every path is
    /// resolved relative to the index directory.
    #[cfg(target_arch = "wasm32")]
    pub fn parse(path: &str) -> Self {
        if let Some(relative) = path.strip_prefix("index://") {
            TokenizerSource::IndexDirectory(relative.to_string())
        } else {
            // On WASM, treat all paths as index-relative
            TokenizerSource::IndexDirectory(path.to_string())
        }
    }
}

impl HfTokenizer {
    /// Load a tokenizer from HuggingFace hub or local path (native only)
    ///
    /// If `name_or_path` names an existing file it is loaded from disk;
    /// anything else is treated as a HuggingFace hub identifier.
    ///
    /// Examples:
    /// - `"Alibaba-NLP/gte-Qwen2-1.5B-instruct"` - from HuggingFace hub
    /// - `"/path/to/tokenizer.json"` - from local file
    #[cfg(not(target_arch = "wasm32"))]
    pub fn load(name_or_path: &str) -> Result<Self> {
        let tokenizer = if std::path::Path::new(name_or_path).is_file() {
            // Local file path (absolute or relative)
            Tokenizer::from_file(name_or_path).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to load tokenizer from '{}': {}",
                    name_or_path, e
                ))
            })?
        } else {
            // Hub identifier; checked second because bare names like
            // "bert-base-uncased" are valid hub ids despite having no '/'
            Tokenizer::from_pretrained(name_or_path, None).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to load tokenizer '{}': {}",
                    name_or_path, e
                ))
            })?
        };

        Ok(Self { tokenizer })
    }

    /// Load a tokenizer from bytes (e.g., read from Directory)
    ///
    /// This allows loading tokenizers from any Directory implementation,
    /// including remote storage like S3 or HTTP.
    /// This is the primary method for WASM targets.
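    ///
    /// A minimal sketch (`ignore`d; `std::fs` stands in for a Directory read):
    ///
    /// ```ignore
    /// let bytes = std::fs::read("tokenizer.json")?;
    /// let tokenizer = HfTokenizer::from_bytes(&bytes)?;
    /// ```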
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| {
            Error::Tokenizer(format!("Failed to parse tokenizer from bytes: {}", e))
        })?;
        Ok(Self { tokenizer })
    }

    /// Load from a TokenizerSource (native only)
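    ///
    /// `IndexDirectory` sources cannot be loaded here. A caller sketch
    /// (`ignore`d; `dir.read` is a hypothetical Directory handle):
    ///
    /// ```ignore
    /// let tokenizer = match &source {
    ///     TokenizerSource::IndexDirectory(rel) => {
    ///         let bytes = dir.read(rel)?; // read bytes via the index Directory
    ///         HfTokenizer::from_bytes(&bytes)?
    ///     }
    ///     other => HfTokenizer::from_source(other)?,
    /// };
    /// ```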
    #[cfg(not(target_arch = "wasm32"))]
    pub fn from_source(source: &TokenizerSource) -> Result<Self> {
        match source {
            TokenizerSource::HuggingFace(name) => {
                let tokenizer = Tokenizer::from_pretrained(name, None).map_err(|e| {
                    Error::Tokenizer(format!("Failed to load tokenizer '{}': {}", name, e))
                })?;
                Ok(Self { tokenizer })
            }
            TokenizerSource::LocalFile(path) => {
                let tokenizer = Tokenizer::from_file(path).map_err(|e| {
                    Error::Tokenizer(format!("Failed to load tokenizer from '{}': {}", path, e))
                })?;
                Ok(Self { tokenizer })
            }
            TokenizerSource::IndexDirectory(_) => {
                // For index directory, caller must use from_bytes with data read from Directory
                Err(Error::Tokenizer(
                    "IndexDirectory source requires using from_bytes with Directory read"
                        .to_string(),
                ))
            }
        }
    }

    /// Tokenize text and return token IDs
    ///
    /// Returns a vector of (token_id, count) pairs where count is the
    /// number of times each token appears in the text.
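    ///
    /// For example (`ignore`d; assumes a tokenizer where "the" is one token):
    ///
    /// ```ignore
    /// // In "the cat saw the dog", the id for "the" carries count 2.
    /// let pairs = tokenizer.tokenize("the cat saw the dog")?;
    /// assert!(pairs.iter().any(|&(_, count)| count == 2));
    /// ```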
    pub fn tokenize(&self, text: &str) -> Result<Vec<(u32, u32)>> {
        // add_special_tokens = false: queries only need content tokens,
        // not markers like [CLS]/[SEP]
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;

        // Count token occurrences
        let mut counts: HashMap<u32, u32> = HashMap::new();
        for &id in encoding.get_ids() {
            *counts.entry(id).or_insert(0) += 1;
        }

        Ok(counts.into_iter().collect())
    }

    /// Tokenize text and return unique token IDs
    ///
    /// Intended for the `one` weighting mode, where each distinct token
    /// contributes equally regardless of its frequency in the text.
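    ///
    /// For example (`ignore`d; repeated terms collapse to one entry):
    ///
    /// ```ignore
    /// let ids = tokenizer.tokenize_unique("the cat and the dog")?;
    /// // "the" appears once in `ids` no matter how often it occurs in text.
    /// ```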
    pub fn tokenize_unique(&self, text: &str) -> Result<Vec<u32>> {
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;

        // Get unique token IDs
        let mut ids: Vec<u32> = encoding.get_ids().to_vec();
        ids.sort_unstable();
        ids.dedup();

        Ok(ids)
    }
}

/// Global tokenizer cache for reuse across queries
#[cfg(feature = "native")]
pub struct TokenizerCache {
    cache: RwLock<HashMap<String, Arc<HfTokenizer>>>,
}

#[cfg(feature = "native")]
impl Default for TokenizerCache {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "native")]
impl TokenizerCache {
    /// Create a new tokenizer cache
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Get or load a tokenizer, caching it for subsequent calls
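    ///
    /// For example (`ignore`d; the first call may download from the hub):
    ///
    /// ```ignore
    /// let first = tokenizer_cache().get_or_load("bert-base-uncased")?;
    /// let second = tokenizer_cache().get_or_load("bert-base-uncased")?;
    /// assert!(Arc::ptr_eq(&first, &second)); // same cached instance
    /// ```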
    pub fn get_or_load(&self, name_or_path: &str) -> Result<Arc<HfTokenizer>> {
        // Check cache first
        {
            let cache = self.cache.read();
            if let Some(tokenizer) = cache.get(name_or_path) {
                return Ok(Arc::clone(tokenizer));
            }
        }

        // Load and cache
        let tokenizer = Arc::new(HfTokenizer::load(name_or_path)?);
        {
            let mut cache = self.cache.write();
            cache.insert(name_or_path.to_string(), Arc::clone(&tokenizer));
        }

        Ok(tokenizer)
    }

    /// Clear the cache
    pub fn clear(&self) {
        let mut cache = self.cache.write();
        cache.clear();
    }
}

/// Global tokenizer cache instance
#[cfg(feature = "native")]
static TOKENIZER_CACHE: std::sync::OnceLock<TokenizerCache> = std::sync::OnceLock::new();

/// Get the global tokenizer cache
#[cfg(feature = "native")]
pub fn tokenizer_cache() -> &'static TokenizerCache {
    TOKENIZER_CACHE.get_or_init(TokenizerCache::new)
}

#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

    // Note: the #[ignore]d tests below require network access to download
    // tokenizers from the hub; they are skipped by default to avoid CI
    // issues (run them with `cargo test -- --ignored`).

    #[test]
    #[ignore]
    fn test_load_tokenizer_from_hub() {
        let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
        let tokens = tokenizer.tokenize("hello world").unwrap();
        assert!(!tokens.is_empty());
    }

    #[test]
    #[ignore]
    fn test_tokenize_unique() {
        let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
        let ids = tokenizer.tokenize_unique("the quick brown fox").unwrap();
        // Should have unique tokens
        let mut sorted = ids.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(ids.len(), sorted.len());
    }

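    // No network needed: exercises the documented parse() mapping
    // with illustrative paths.
    #[test]
    fn test_parse_tokenizer_source() {
        assert!(matches!(
            TokenizerSource::parse("index://tokenizers/query.json"),
            TokenizerSource::IndexDirectory(p) if p == "tokenizers/query.json"
        ));
        assert!(matches!(
            TokenizerSource::parse("/models/tokenizer.json"),
            TokenizerSource::LocalFile(_)
        ));
        assert!(matches!(
            TokenizerSource::parse("bert-base-uncased"),
            TokenizerSource::HuggingFace(_)
        ));
    }
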
    #[test]
    fn test_tokenizer_cache() {
        let cache = TokenizerCache::new();
        // Just test that the cache structure works
        assert!(cache.cache.read().is_empty());
    }
}