hermes_core/tokenizer/hf_tokenizer.rs

//! Hugging Face tokenizer support: loading from the Hub, local files, an
//! index directory, or raw bytes, plus a process-wide tokenizer cache.

use std::collections::HashMap;
// Arc and the parking_lot RwLock are only needed by the native-only cache.
#[cfg(feature = "native")]
use std::sync::Arc;

#[cfg(feature = "native")]
use parking_lot::RwLock;
use tokenizers::Tokenizer;

use crate::Result;
use crate::error::Error;

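/// Thin wrapper around a Hugging Face [`tokenizers::Tokenizer`].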
pub struct HfTokenizer {
    tokenizer: Tokenizer,
}

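/// Where a tokenizer definition can be loaded from.
///
/// A sketch of how [`TokenizerSource::parse`] routes strings on non-wasm
/// builds (`ignore`d because the exact re-export path is an assumption):
///
/// ```ignore
/// use hermes_core::tokenizer::TokenizerSource;
///
/// assert!(matches!(TokenizerSource::parse("index://tok.json"),
///                  TokenizerSource::IndexDirectory(_)));
/// assert!(matches!(TokenizerSource::parse("/models/tokenizer.json"),
///                  TokenizerSource::LocalFile(_)));
/// assert!(matches!(TokenizerSource::parse("bert-base-uncased"),
///                  TokenizerSource::HuggingFace(_)));
/// ```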
#[derive(Debug, Clone)]
pub enum TokenizerSource {
    /// A model identifier on the Hugging Face Hub.
    #[cfg(not(target_arch = "wasm32"))]
    HuggingFace(String),
    /// An absolute path to a tokenizer file on disk.
    #[cfg(not(target_arch = "wasm32"))]
    LocalFile(String),
    /// A path relative to the index directory.
    IndexDirectory(String),
}

impl TokenizerSource {
    /// Parses a path string into a source on native builds.
    ///
    /// `index://<path>` selects the index directory, absolute paths become
    /// local files, and anything else is treated as a Hub identifier.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn parse(path: &str) -> Self {
        if let Some(relative) = path.strip_prefix("index://") {
            TokenizerSource::IndexDirectory(relative.to_string())
        } else if path.starts_with('/') {
            TokenizerSource::LocalFile(path.to_string())
        } else {
            TokenizerSource::HuggingFace(path.to_string())
        }
    }

    /// On wasm32 there is no filesystem or Hub access, so every path
    /// resolves into the index directory.
    #[cfg(target_arch = "wasm32")]
    pub fn parse(path: &str) -> Self {
        if let Some(relative) = path.strip_prefix("index://") {
            TokenizerSource::IndexDirectory(relative.to_string())
        } else {
            TokenizerSource::IndexDirectory(path.to_string())
        }
    }
}

impl HfTokenizer {
    /// Loads a tokenizer, preferring an existing file on disk and otherwise
    /// treating `name_or_path` as a Hugging Face Hub identifier such as
    /// "bert-base-uncased" or "org/model".
    #[cfg(not(target_arch = "wasm32"))]
    pub fn load(name_or_path: &str) -> Result<Self> {
        let tokenizer = if std::path::Path::new(name_or_path).is_file() {
            // An existing file (e.g. a serialized tokenizer.json) wins.
            Tokenizer::from_file(name_or_path).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to load tokenizer from '{}': {}",
                    name_or_path, e
                ))
            })?
        } else {
            // Anything else is assumed to be a Hub identifier.
            Tokenizer::from_pretrained(name_or_path, None).map_err(|e| {
                Error::Tokenizer(format!(
                    "Failed to load tokenizer '{}': {}",
                    name_or_path, e
                ))
            })?
        };

        Ok(Self { tokenizer })
    }

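    /// Builds a tokenizer from the raw bytes of a tokenizer definition
    /// (typically a serialized tokenizer.json). This is the only loading
    /// path compiled on wasm32.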
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| {
            Error::Tokenizer(format!("Failed to parse tokenizer from bytes: {}", e))
        })?;
        Ok(Self { tokenizer })
    }

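    /// Loads a tokenizer from a parsed [`TokenizerSource`].
    ///
    /// `IndexDirectory` cannot be resolved here: this type has no access to
    /// index storage, so read the file yourself and use [`Self::from_bytes`].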
    #[cfg(not(target_arch = "wasm32"))]
    pub fn from_source(source: &TokenizerSource) -> Result<Self> {
        match source {
            TokenizerSource::HuggingFace(name) => {
                let tokenizer = Tokenizer::from_pretrained(name, None).map_err(|e| {
                    Error::Tokenizer(format!("Failed to load tokenizer '{}': {}", name, e))
                })?;
                Ok(Self { tokenizer })
            }
            TokenizerSource::LocalFile(path) => {
                let tokenizer = Tokenizer::from_file(path).map_err(|e| {
                    Error::Tokenizer(format!("Failed to load tokenizer from '{}': {}", path, e))
                })?;
                Ok(Self { tokenizer })
            }
            TokenizerSource::IndexDirectory(_) => Err(Error::Tokenizer(
                "IndexDirectory source requires using from_bytes with Directory read".to_string(),
            )),
        }
    }

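    /// Tokenizes `text` (without special tokens) and returns unordered
    /// `(token_id, count)` pairs, one per distinct token. For example, a
    /// text whose ids are `[5, 9, 5]` yields `[(5, 2), (9, 1)]` in some
    /// order.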
    pub fn tokenize(&self, text: &str) -> Result<Vec<(u32, u32)>> {
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;

        // Collapse the id sequence into per-id occurrence counts.
        let mut counts: HashMap<u32, u32> = HashMap::new();
        for &id in encoding.get_ids() {
            *counts.entry(id).or_insert(0) += 1;
        }

        Ok(counts.into_iter().collect())
    }

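    /// Tokenizes `text` and returns its token ids, sorted and deduplicated.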
    pub fn tokenize_unique(&self, text: &str) -> Result<Vec<u32>> {
        let encoding = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Tokenizer(format!("Tokenization failed: {}", e)))?;

        let mut ids: Vec<u32> = encoding.get_ids().to_vec();
        ids.sort_unstable();
        ids.dedup();

        Ok(ids)
    }
}

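/// Process-wide cache of loaded tokenizers, keyed by the string passed to
/// [`TokenizerCache::get_or_load`]. Entries are shared via [`Arc`], so
/// repeated lookups are cheap.
///
/// A usage sketch (`ignore`d because the exact re-export path is an
/// assumption, not something this file pins down):
///
/// ```ignore
/// use hermes_core::tokenizer::tokenizer_cache;
///
/// let tok = tokenizer_cache().get_or_load("bert-base-uncased")?;
/// let ids = tok.tokenize_unique("hello world")?;
/// ```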
#[cfg(feature = "native")]
pub struct TokenizerCache {
    cache: RwLock<HashMap<String, Arc<HfTokenizer>>>,
}

#[cfg(feature = "native")]
impl Default for TokenizerCache {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(feature = "native")]
impl TokenizerCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Returns the cached tokenizer for `name_or_path`, loading it on first
    /// use. Two threads racing on the same key may both load; the second
    /// insert harmlessly overwrites the first.
    pub fn get_or_load(&self, name_or_path: &str) -> Result<Arc<HfTokenizer>> {
        // Fast path: shared read lock.
        {
            let cache = self.cache.read();
            if let Some(tokenizer) = cache.get(name_or_path) {
                return Ok(Arc::clone(tokenizer));
            }
        }

        // Slow path: load outside any lock, then publish under the write lock.
        let tokenizer = Arc::new(HfTokenizer::load(name_or_path)?);
        {
            let mut cache = self.cache.write();
            cache.insert(name_or_path.to_string(), Arc::clone(&tokenizer));
        }

        Ok(tokenizer)
    }

    /// Evicts every cached tokenizer.
    pub fn clear(&self) {
        let mut cache = self.cache.write();
        cache.clear();
    }
}

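/// Lazily initialized global [`TokenizerCache`].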
#[cfg(feature = "native")]
static TOKENIZER_CACHE: std::sync::OnceLock<TokenizerCache> = std::sync::OnceLock::new();

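/// Returns the process-wide tokenizer cache, creating it on first use.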
#[cfg(feature = "native")]
pub fn tokenizer_cache() -> &'static TokenizerCache {
    TOKENIZER_CACHE.get_or_init(TokenizerCache::new)
}

#[cfg(test)]
#[cfg(feature = "native")]
mod tests {
    use super::*;

    // Ignored by default: downloads "bert-base-uncased" from the Hub.
    #[test]
    #[ignore]
    fn test_load_tokenizer_from_hub() {
        let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
        let tokens = tokenizer.tokenize("hello world").unwrap();
        assert!(!tokens.is_empty());
    }

    // Also ignored: needs the same Hub download.
    #[test]
    #[ignore]
    fn test_tokenize_unique() {
        let tokenizer = HfTokenizer::load("bert-base-uncased").unwrap();
        let ids = tokenizer.tokenize_unique("the quick brown fox").unwrap();
        // The output must already be sorted and deduplicated.
        let mut sorted = ids.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(ids.len(), sorted.len());
    }

    #[test]
    fn test_tokenizer_cache() {
        let cache = TokenizerCache::new();
        assert!(cache.cache.read().is_empty());
    }
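
    // A sketch of a cache-sharing check, ignored because it downloads from
    // the Hub: both lookups should return the same Arc'd instance.
    #[test]
    #[ignore]
    fn test_cache_returns_shared_instance() {
        let cache = TokenizerCache::new();
        let a = cache.get_or_load("bert-base-uncased").unwrap();
        let b = cache.get_or_load("bert-base-uncased").unwrap();
        assert!(Arc::ptr_eq(&a, &b));
    }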
}