hermes_core/tokenizer/hf_tokenizer.rs

use std::collections::HashMap;

#[cfg(feature = "native")]
use std::sync::Arc;

#[cfg(feature = "native")]
use parking_lot::RwLock;
use tokenizers::Tokenizer;

use log::debug;

use crate::Result;
use crate::error::Error;

/// A thin wrapper around a Hugging Face `tokenizers::Tokenizer`.
pub struct HfTokenizer {
    pub(crate) tokenizer: Tokenizer,
}

/// Where a tokenizer definition can come from.
#[derive(Debug, Clone)]
pub enum TokenizerSource {
    /// A model identifier on the Hugging Face Hub.
    #[cfg(not(target_arch = "wasm32"))]
    HuggingFace(String),
    /// An absolute path to a serialized tokenizer file on disk.
    #[cfg(not(target_arch = "wasm32"))]
    LocalFile(String),
    /// A path resolved relative to an index directory.
    IndexDirectory(String),
}

impl TokenizerSource {
    /// Classifies `path`: an `index://` prefix selects `IndexDirectory`,
    /// an absolute path selects `LocalFile`, and anything else is treated
    /// as a Hugging Face Hub identifier.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn parse(path: &str) -> Self {
        if let Some(relative) = path.strip_prefix("index://") {
            TokenizerSource::IndexDirectory(relative.to_string())
        } else if path.starts_with('/') {
            TokenizerSource::LocalFile(path.to_string())
        } else {
            TokenizerSource::HuggingFace(path.to_string())
        }
    }

    /// On wasm32 there is no filesystem or network access, so every input
    /// resolves to `IndexDirectory`, with any `index://` prefix stripped.
    #[cfg(target_arch = "wasm32")]
    pub fn parse(path: &str) -> Self {
        if let Some(relative) = path.strip_prefix("index://") {
            TokenizerSource::IndexDirectory(relative.to_string())
        } else {
            TokenizerSource::IndexDirectory(path.to_string())
        }
    }
}

impl HfTokenizer {
    /// Loads a tokenizer from the Hugging Face Hub or from a local file.
    /// Inputs containing a `/` that are not absolute paths (e.g. `org/name`)
    /// are fetched from the Hub; everything else is read as a file path.
    /// Note that a bare Hub id without a namespace (e.g. `bert-base-uncased`)
    /// contains no `/` and therefore falls through to the file branch.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn load(name_or_path: &str) -> Result<Self> {
        let tokenizer = if name_or_path.contains('/') && !name_or_path.starts_with('/') {
            Tokenizer::from_pretrained(name_or_path, None).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to load tokenizer '{}': {}",
                    name_or_path, e
                ))
            })?
        } else {
            Tokenizer::from_file(name_or_path).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to load tokenizer from '{}': {}",
                    name_or_path, e
                ))
            })?
        };

        Ok(Self { tokenizer })
    }
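
    // A minimal usage sketch (added for illustration; the model id and path
    // below are placeholders, not values the crate ships with):
    //
    //     // Namespaced Hub id: contains '/', not absolute, so it is fetched
    //     // with `Tokenizer::from_pretrained`.
    //     let hub = HfTokenizer::load("google-bert/bert-base-uncased")?;
    //
    //     // Absolute path: read with `Tokenizer::from_file`.
    //     let local = HfTokenizer::load("/models/tokenizer.json")?;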

    /// Builds a tokenizer from the raw bytes of a serialized `tokenizer.json`.
    /// This is the only constructor available on every target, including wasm32.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| {
            Error::Tokenizer(format!("Failed to parse tokenizer from bytes: {}", e))
        })?;
        Ok(Self { tokenizer })
    }
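
    // A sketch of the wasm-friendly path (added; `read_index_file` is a
    // hypothetical helper standing in for however the host supplies bytes):
    //
    //     let bytes: Vec<u8> = read_index_file("tokenizer.json")?;
    //     let tok = HfTokenizer::from_bytes(&bytes)?;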

    /// Resolves a parsed `TokenizerSource`. `IndexDirectory` cannot be
    /// resolved here: the caller must read the bytes out of the index
    /// first and go through `from_bytes`.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn from_source(source: &TokenizerSource) -> Result<Self> {
        match source {
            TokenizerSource::HuggingFace(name) => {
                let tokenizer = Tokenizer::from_pretrained(name, None).map_err(|e| {
                    Error::Tokenizer(format!("Failed to load tokenizer '{}': {}", name, e))
                })?;
                Ok(Self { tokenizer })
            }
            TokenizerSource::LocalFile(path) => {
                let tokenizer = Tokenizer::from_file(path).map_err(|e| {
                    Error::Tokenizer(format!("Failed to load tokenizer from '{}': {}", path, e))
                })?;
                Ok(Self { tokenizer })
            }
            TokenizerSource::IndexDirectory(_) => Err(Error::Tokenizer(
                "IndexDirectory source requires using from_bytes with Directory read".to_string(),
            )),
        }
    }
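
    // Typical flow on native targets (a sketch; the input string is made up):
    //
    //     let source = TokenizerSource::parse("/models/tokenizer.json");
    //     let tok = HfTokenizer::from_source(&source)?;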

    /// Maps a token id back to its surface form, if the vocabulary knows it.
    pub fn id_to_token(&self, id: u32) -> Option<String> {
        self.tokenizer.id_to_token(id)
    }

    /// Encodes `text` (without special tokens) and returns `(token_id, count)`
    /// pairs, one per distinct token id. The pairs come out of a `HashMap`,
    /// so their order is unspecified.
    pub fn tokenize(&self, text: &str) -> Result<Vec<(u32, u32)>> {
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;

        // Count occurrences of each distinct token id.
        let mut counts: HashMap<u32, u32> = HashMap::new();
        for &id in encoding.get_ids() {
            *counts.entry(id).or_insert(0) += 1;
        }

        let result: Vec<(u32, u32)> = counts.into_iter().collect();
        let paired: Vec<_> = encoding
            .get_tokens()
            .iter()
            .zip(encoding.get_ids())
            .map(|(tok, id)| format!("({:?},{})", tok, id))
            .collect();
        debug!(
            "Tokenized query: text={:?} tokens=[{}] unique_count={}",
            text,
            paired.join(", "),
            result.len()
        );

        Ok(result)
    }
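
    // Sketch of the output shape (added; the ids are illustrative, not real
    // vocabulary entries):
    //
    //     let counts = tok.tokenize("hello hello world")?;
    //     // e.g. [(7592, 2), (2088, 1)], in unspecified order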

    /// Encodes `text` (without special tokens) and returns the distinct
    /// token ids, sorted ascending.
    pub fn tokenize_unique(&self, text: &str) -> Result<Vec<u32>> {
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;

        // Sort then dedup to keep a single copy of each id.
        let mut ids: Vec<u32> = encoding.get_ids().to_vec();
        ids.sort_unstable();
        ids.dedup();

        let paired: Vec<_> = encoding
            .get_tokens()
            .iter()
            .zip(encoding.get_ids())
            .map(|(tok, id)| format!("({:?},{})", tok, id))
            .collect();
        debug!(
            "Tokenized query (unique): text={:?} tokens=[{}] unique_count={}",
            text,
            paired.join(", "),
            ids.len()
        );

        Ok(ids)
    }
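
    // Contrast with `tokenize` (added; ids illustrative): the same input
    // "hello hello world" comes back as a sorted, deduplicated id list,
    // e.g. [2088, 7592], with the counts dropped.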
}

/// A process-wide cache of loaded tokenizers, keyed by the name or path
/// passed to `HfTokenizer::load`.
#[cfg(feature = "native")]
pub struct TokenizerCache {
    cache: RwLock<HashMap<String, Arc<HfTokenizer>>>,
}

#[cfg(feature = "native")]
impl Default for TokenizerCache {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "native")]
impl TokenizerCache {
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Returns a cached tokenizer, loading and inserting it on a miss.
    pub fn get_or_load(&self, name_or_path: &str) -> Result<Arc<HfTokenizer>> {
        // Fast path: hold the read lock only, and drop it before loading.
        {
            let cache = self.cache.read();
            if let Some(tokenizer) = cache.get(name_or_path) {
                return Ok(Arc::clone(tokenizer));
            }
        }

        // Slow path: load outside any lock, then insert. Two threads racing
        // on the same key may both load; the second insert simply wins.
        let tokenizer = Arc::new(HfTokenizer::load(name_or_path)?);
        {
            let mut cache = self.cache.write();
            cache.insert(name_or_path.to_string(), Arc::clone(&tokenizer));
        }

        Ok(tokenizer)
    }

    /// Drops every cached tokenizer.
    pub fn clear(&self) {
        let mut cache = self.cache.write();
        cache.clear();
    }
}

#[cfg(feature = "native")]
static TOKENIZER_CACHE: std::sync::OnceLock<TokenizerCache> = std::sync::OnceLock::new();

/// Lazily-initialized global tokenizer cache.
#[cfg(feature = "native")]
pub fn tokenizer_cache() -> &'static TokenizerCache {
    TOKENIZER_CACHE.get_or_init(TokenizerCache::new)
}
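
// Usage sketch for the global cache (added; the model id is hypothetical):
//
//     let tok = tokenizer_cache().get_or_load("org/model-name")?;
//     // A second call with the same key returns a clone of the same `Arc`
//     // without reloading.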

#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

    // The Hub tests are `#[ignore]`d because they need network access. The
    // model id is namespaced (contains '/') so that `load` takes the Hub
    // branch rather than the local-file branch.
    #[test]
    #[ignore]
    fn test_load_tokenizer_from_hub() {
        let tokenizer = HfTokenizer::load("google-bert/bert-base-uncased").unwrap();
        let tokens = tokenizer.tokenize("hello world").unwrap();
        assert!(!tokens.is_empty());
    }

    #[test]
    #[ignore]
    fn test_tokenize_unique() {
        let tokenizer = HfTokenizer::load("google-bert/bert-base-uncased").unwrap();
        let ids = tokenizer.tokenize_unique("the quick brown fox").unwrap();
        // The result should already be sorted and deduplicated.
        let mut sorted = ids.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(ids.len(), sorted.len());
    }

    #[test]
    fn test_tokenizer_cache() {
        let cache = TokenizerCache::new();
        assert!(cache.cache.read().is_empty());
    }
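
    // Added test: `TokenizerSource::parse` is pure, so it can be exercised
    // without network or filesystem access. Guarded for non-wasm targets
    // because the `LocalFile`/`HuggingFace` variants only exist there.
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_parse_source() {
        assert!(matches!(
            TokenizerSource::parse("index://my-index"),
            TokenizerSource::IndexDirectory(p) if p == "my-index"
        ));
        assert!(matches!(
            TokenizerSource::parse("/tmp/tokenizer.json"),
            TokenizerSource::LocalFile(_)
        ));
        assert!(matches!(
            TokenizerSource::parse("org/model-name"),
            TokenizerSource::HuggingFace(_)
        ));
    }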
}