hermes_core/tokenizer/hf_tokenizer.rs

use std::collections::HashMap;
#[cfg(feature = "native")]
use std::sync::Arc;

#[cfg(feature = "native")]
use parking_lot::RwLock;
use tokenizers::Tokenizer;

use crate::Result;
use crate::error::Error;

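/// Thin wrapper around a Hugging Face [`tokenizers::Tokenizer`] exposing the
/// tokenization operations this crate needs.
///
/// A minimal usage sketch (the `hermes_core::tokenizer` import path and Hub
/// access are assumptions, hence `no_run`):
///
/// ```no_run
/// use hermes_core::tokenizer::HfTokenizer;
///
/// let tok = HfTokenizer::load("bert-base-uncased").unwrap();
/// // "hello" occurs twice, so one (token_id, count) pair has count == 2.
/// let counts = tok.tokenize("hello hello world").unwrap();
/// assert!(counts.iter().any(|&(_, count)| count == 2));
/// ```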
pub struct HfTokenizer {
    tokenizer: Tokenizer,
}

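/// Where a tokenizer definition can be loaded from.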
#[derive(Debug, Clone)]
pub enum TokenizerSource {
    #[cfg(not(target_arch = "wasm32"))]
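    /// A model identifier resolved against the Hugging Face Hub.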
    HuggingFace(String),
    #[cfg(not(target_arch = "wasm32"))]
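    /// A path to a tokenizer JSON file on the local filesystem.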
    LocalFile(String),
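    /// A path relative to the index directory; the caller reads the bytes and
    /// loads them via [`HfTokenizer::from_bytes`].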
    IndexDirectory(String),
}

impl TokenizerSource {
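    /// Parses a path string into a `TokenizerSource`: an `index://` prefix
    /// selects [`TokenizerSource::IndexDirectory`] (prefix stripped), absolute
    /// paths select [`TokenizerSource::LocalFile`], and anything else is
    /// treated as a Hugging Face Hub identifier.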
    #[cfg(not(target_arch = "wasm32"))]
    pub fn parse(path: &str) -> Self {
        if let Some(relative) = path.strip_prefix("index://") {
            TokenizerSource::IndexDirectory(relative.to_string())
        } else if path.starts_with('/') {
            TokenizerSource::LocalFile(path.to_string())
        } else {
            TokenizerSource::HuggingFace(path.to_string())
        }
    }

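    /// Parses a path string into a `TokenizerSource`.
    ///
    /// On wasm32 the local-file and Hub variants do not exist, so every path
    /// resolves to [`TokenizerSource::IndexDirectory`]; an `index://` prefix
    /// is stripped if present.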
    #[cfg(target_arch = "wasm32")]
    pub fn parse(path: &str) -> Self {
        if let Some(relative) = path.strip_prefix("index://") {
            TokenizerSource::IndexDirectory(relative.to_string())
        } else {
            TokenizerSource::IndexDirectory(path.to_string())
        }
    }
}

impl HfTokenizer {
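    /// Loads a tokenizer from a local file if `name_or_path` names one on
    /// disk, otherwise treats it as a Hugging Face Hub identifier.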
    #[cfg(not(target_arch = "wasm32"))]
    pub fn load(name_or_path: &str) -> Result<Self> {
        // Prefer an existing file on disk; otherwise treat the argument as a
        // Hub identifier. (Routing on `contains('/')` alone would misdirect
        // legacy Hub names such as "bert-base-uncased" to the file loader.)
        let tokenizer = if std::path::Path::new(name_or_path).is_file() {
            Tokenizer::from_file(name_or_path).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to load tokenizer from '{}': {}",
                    name_or_path, e
                ))
            })?
        } else {
            Tokenizer::from_pretrained(name_or_path, None).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to load tokenizer '{}': {}",
                    name_or_path, e
                ))
            })?
        };

        Ok(Self { tokenizer })
    }

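    /// Builds a tokenizer from the raw bytes of a tokenizer JSON definition.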
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| {
            Error::Tokenizer(format!("Failed to parse tokenizer from bytes: {}", e))
        })?;
        Ok(Self { tokenizer })
    }

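    /// Loads a tokenizer from a parsed [`TokenizerSource`].
    ///
    /// [`TokenizerSource::IndexDirectory`] cannot be resolved here; read the
    /// bytes through the index and call [`HfTokenizer::from_bytes`] instead.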
    #[cfg(not(target_arch = "wasm32"))]
    pub fn from_source(source: &TokenizerSource) -> Result<Self> {
        match source {
            TokenizerSource::HuggingFace(name) => {
                let tokenizer = Tokenizer::from_pretrained(name, None).map_err(|e| {
                    Error::Tokenizer(format!("Failed to load tokenizer '{}': {}", name, e))
                })?;
                Ok(Self { tokenizer })
            }
            TokenizerSource::LocalFile(path) => {
                let tokenizer = Tokenizer::from_file(path).map_err(|e| {
                    Error::Tokenizer(format!("Failed to load tokenizer from '{}': {}", path, e))
                })?;
                Ok(Self { tokenizer })
            }
            TokenizerSource::IndexDirectory(_) => {
                Err(Error::Tokenizer(
                    "IndexDirectory source requires using from_bytes with Directory read"
                        .to_string(),
                ))
            }
        }
    }

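    /// Encodes `text` (without special tokens) and returns `(token_id, count)`
    /// pairs: a sparse term-frequency vector in arbitrary order.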
    pub fn tokenize(&self, text: &str) -> Result<Vec<(u32, u32)>> {
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;

        // Count occurrences of each token id.
        let mut counts: HashMap<u32, u32> = HashMap::new();
        for &id in encoding.get_ids() {
            *counts.entry(id).or_insert(0) += 1;
        }

        Ok(counts.into_iter().collect())
    }

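    /// Encodes `text` (without special tokens) and returns the sorted,
    /// deduplicated token ids.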
    pub fn tokenize_unique(&self, text: &str) -> Result<Vec<u32>> {
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;

        // Sort, then drop adjacent duplicates.
        let mut ids: Vec<u32> = encoding.get_ids().to_vec();
        ids.sort_unstable();
        ids.dedup();

        Ok(ids)
    }
}

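/// A thread-safe cache of loaded tokenizers, keyed by the name or path they
/// were loaded from.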
#[cfg(feature = "native")]
pub struct TokenizerCache {
    cache: RwLock<HashMap<String, Arc<HfTokenizer>>>,
}

#[cfg(feature = "native")]
impl Default for TokenizerCache {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "native")]
impl TokenizerCache {
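    /// Creates an empty cache.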
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

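    /// Returns the cached tokenizer for `name_or_path`, loading and inserting
    /// it on a cache miss.
    ///
    /// Concurrent callers that miss simultaneously may each load the
    /// tokenizer; the last insert wins, which wastes work but is otherwise
    /// harmless.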
    pub fn get_or_load(&self, name_or_path: &str) -> Result<Arc<HfTokenizer>> {
        // Fast path: check the cache under a read lock.
        {
            let cache = self.cache.read();
            if let Some(tokenizer) = cache.get(name_or_path) {
                return Ok(Arc::clone(tokenizer));
            }
        }

        // Miss: load outside any lock, then publish under a write lock.
        let tokenizer = Arc::new(HfTokenizer::load(name_or_path)?);
        {
            let mut cache = self.cache.write();
            cache.insert(name_or_path.to_string(), Arc::clone(&tokenizer));
        }

        Ok(tokenizer)
    }

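    /// Removes every cached tokenizer.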
    pub fn clear(&self) {
        let mut cache = self.cache.write();
        cache.clear();
    }
}

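// Lazily initialized global instance behind `tokenizer_cache()`.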
#[cfg(feature = "native")]
static TOKENIZER_CACHE: std::sync::OnceLock<TokenizerCache> = std::sync::OnceLock::new();

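/// Returns the process-wide [`TokenizerCache`], creating it on first use.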
#[cfg(feature = "native")]
pub fn tokenizer_cache() -> &'static TokenizerCache {
    TOKENIZER_CACHE.get_or_init(TokenizerCache::new)
}

#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

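    // The Hub tests are #[ignore]d because they need network access; run them
    // with `cargo test -- --ignored`.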
    #[test]
    #[ignore]
    fn test_load_tokenizer_from_hub() {
        let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
        let tokens = tokenizer.tokenize("hello world").unwrap();
        assert!(!tokens.is_empty());
    }

    #[test]
    #[ignore]
    fn test_tokenize_unique() {
        let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
        let ids = tokenizer.tokenize_unique("the quick brown fox").unwrap();
        // The returned ids must already be sorted and deduplicated, so
        // re-sorting and deduplicating must not change the length.
        let mut sorted = ids.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(ids.len(), sorted.len());
    }

258
259 #[test]
260 fn test_tokenizer_cache() {
261 let cache = TokenizerCache::new();
262 assert!(cache.cache.read().is_empty());
264 }
265}