//! HuggingFace tokenizer wrapper (hermes_core/tokenizer/hf_tokenizer.rs).

use std::collections::HashMap;
11#[cfg(feature = "native")]
12use std::sync::Arc;
13
14#[cfg(feature = "native")]
15use parking_lot::RwLock;
16use tokenizers::Tokenizer;
17
18use log::debug;
19
20use crate::Result;
21use crate::error::Error;
22
/// Thin wrapper around a HuggingFace `tokenizers::Tokenizer` that exposes
/// the encoding operations this crate needs (token-id counts and unique ids).
pub struct HfTokenizer {
    // The underlying HuggingFace tokenizer; all encoding is delegated to it.
    tokenizer: Tokenizer,
}
27
/// Describes where a tokenizer definition should be loaded from.
#[derive(Debug, Clone)]
pub enum TokenizerSource {
    /// A HuggingFace Hub model id (e.g. `org/name`). Unavailable on wasm32.
    #[cfg(not(target_arch = "wasm32"))]
    HuggingFace(String),
    /// An absolute filesystem path to a tokenizer file. Unavailable on wasm32.
    #[cfg(not(target_arch = "wasm32"))]
    LocalFile(String),
    /// A path relative to an index directory (the `index://` scheme);
    /// loading goes through `from_bytes` rather than `from_source`.
    IndexDirectory(String),
}
40
41impl TokenizerSource {
42 #[cfg(not(target_arch = "wasm32"))]
48 pub fn parse(path: &str) -> Self {
49 if let Some(relative) = path.strip_prefix("index://") {
50 TokenizerSource::IndexDirectory(relative.to_string())
51 } else if path.starts_with('/') {
52 TokenizerSource::LocalFile(path.to_string())
53 } else {
54 TokenizerSource::HuggingFace(path.to_string())
55 }
56 }
57
58 #[cfg(target_arch = "wasm32")]
62 pub fn parse(path: &str) -> Self {
63 if let Some(relative) = path.strip_prefix("index://") {
64 TokenizerSource::IndexDirectory(relative.to_string())
65 } else {
66 TokenizerSource::IndexDirectory(path.to_string())
68 }
69 }
70}
71
72impl HfTokenizer {
73 #[cfg(not(target_arch = "wasm32"))]
79 pub fn load(name_or_path: &str) -> Result<Self> {
80 let tokenizer = if name_or_path.contains('/') && !name_or_path.starts_with('/') {
81 Tokenizer::from_pretrained(name_or_path, None).map_err(|e| {
83 Error::Tokenizer(format!(
84 "Failed to load tokenizer '{}': {}",
85 name_or_path, e
86 ))
87 })?
88 } else {
89 Tokenizer::from_file(name_or_path).map_err(|e| {
91 Error::Tokenizer(format!(
92 "Failed to load tokenizer from '{}': {}",
93 name_or_path, e
94 ))
95 })?
96 };
97
98 Ok(Self { tokenizer })
99 }
100
101 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
107 let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| {
108 Error::Tokenizer(format!("Failed to parse tokenizer from bytes: {}", e))
109 })?;
110 Ok(Self { tokenizer })
111 }
112
113 #[cfg(not(target_arch = "wasm32"))]
115 pub fn from_source(source: &TokenizerSource) -> Result<Self> {
116 match source {
117 TokenizerSource::HuggingFace(name) => {
118 let tokenizer = Tokenizer::from_pretrained(name, None).map_err(|e| {
119 Error::Tokenizer(format!("Failed to load tokenizer '{}': {}", name, e))
120 })?;
121 Ok(Self { tokenizer })
122 }
123 TokenizerSource::LocalFile(path) => {
124 let tokenizer = Tokenizer::from_file(path).map_err(|e| {
125 Error::Tokenizer(format!("Failed to load tokenizer from '{}': {}", path, e))
126 })?;
127 Ok(Self { tokenizer })
128 }
129 TokenizerSource::IndexDirectory(_) => {
130 Err(Error::Tokenizer(
132 "IndexDirectory source requires using from_bytes with Directory read"
133 .to_string(),
134 ))
135 }
136 }
137 }
138
139 pub fn tokenize(&self, text: &str) -> Result<Vec<(u32, u32)>> {
144 let encoding = self
145 .tokenizer
146 .encode(text, false)
147 .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;
148
149 let mut counts: HashMap<u32, u32> = HashMap::new();
151 for &id in encoding.get_ids() {
152 *counts.entry(id).or_insert(0) += 1;
153 }
154
155 let result: Vec<(u32, u32)> = counts.into_iter().collect();
156 debug!(
157 "Tokenized query: text={:?} tokens={:?} token_ids={:?} unique_count={}",
158 text,
159 encoding.get_tokens(),
160 encoding.get_ids(),
161 result.len()
162 );
163
164 Ok(result)
165 }
166
167 pub fn tokenize_unique(&self, text: &str) -> Result<Vec<u32>> {
169 let encoding = self
170 .tokenizer
171 .encode(text, false)
172 .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;
173
174 let mut ids: Vec<u32> = encoding.get_ids().to_vec();
176 ids.sort_unstable();
177 ids.dedup();
178
179 debug!(
180 "Tokenized query (unique): text={:?} tokens={:?} token_ids={:?} unique_count={}",
181 text,
182 encoding.get_tokens(),
183 encoding.get_ids(),
184 ids.len()
185 );
186
187 Ok(ids)
188 }
189}
190
/// Thread-safe, in-process cache of loaded tokenizers keyed by the
/// name-or-path string passed to [`HfTokenizer::load`].
#[cfg(feature = "native")]
pub struct TokenizerCache {
    // parking_lot RwLock: many concurrent readers on the hot (cache-hit) path.
    cache: RwLock<HashMap<String, Arc<HfTokenizer>>>,
}
196
#[cfg(feature = "native")]
impl Default for TokenizerCache {
    /// Equivalent to [`TokenizerCache::new`]: an empty cache.
    fn default() -> Self {
        Self::new()
    }
}
203
#[cfg(feature = "native")]
impl TokenizerCache {
    /// Create an empty cache.
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Return the cached tokenizer for `name_or_path`, loading and caching it
    /// on first use.
    ///
    /// Loading happens outside any lock so concurrent callers are not
    /// serialized behind a slow load. If two threads race to load the same
    /// key, both loads run, but the `entry` API guarantees every caller ends
    /// up sharing the single `Arc` that won the insert. (The original code
    /// used `insert`, which overwrote the first winner, so racing callers
    /// could hold distinct instances for the same key.)
    ///
    /// # Errors
    /// Propagates any error from [`HfTokenizer::load`].
    pub fn get_or_load(&self, name_or_path: &str) -> Result<Arc<HfTokenizer>> {
        // Fast path: shared read lock only.
        {
            let cache = self.cache.read();
            if let Some(tokenizer) = cache.get(name_or_path) {
                return Ok(Arc::clone(tokenizer));
            }
        }

        // Slow path: load without holding any lock, then insert-or-reuse.
        let tokenizer = Arc::new(HfTokenizer::load(name_or_path)?);
        let mut cache = self.cache.write();
        let entry = cache.entry(name_or_path.to_string()).or_insert(tokenizer);
        Ok(Arc::clone(entry))
    }

    /// Drop every cached tokenizer.
    pub fn clear(&self) {
        self.cache.write().clear();
    }
}
239
// Process-wide cache instance, lazily initialized on first access.
#[cfg(feature = "native")]
static TOKENIZER_CACHE: std::sync::OnceLock<TokenizerCache> = std::sync::OnceLock::new();

/// Access the global [`TokenizerCache`], creating it on first call.
#[cfg(feature = "native")]
pub fn tokenizer_cache() -> &'static TokenizerCache {
    TOKENIZER_CACHE.get_or_init(TokenizerCache::new)
}
249
#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

    // Requires network access to the HuggingFace Hub, hence #[ignore];
    // run explicitly with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn test_load_tokenizer_from_hub() {
        let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
        let tokens = tokenizer.tokenize("hello world").unwrap();
        assert!(!tokens.is_empty());
    }

    // Also hits the Hub, hence #[ignore]. Verifies the returned ids are
    // already deduplicated: re-sorting and deduping must not shrink them.
    #[test]
    #[ignore]
    fn test_tokenize_unique() {
        let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
        let ids = tokenizer.tokenize_unique("the quick brown fox").unwrap();
        let mut sorted = ids.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(ids.len(), sorted.len());
    }

    // Offline smoke test: a fresh cache starts empty.
    #[test]
    fn test_tokenizer_cache() {
        let cache = TokenizerCache::new();
        assert!(cache.cache.read().is_empty());
    }
}