use crate::{IncrementalTokenizer, Tokenizer, TokenizerFactory, TokenizerInfo, TokenizerType};
use async_trait::async_trait;
use ferrum_types::{Result, SpecialTokens, TokenId};
use parking_lot::RwLock;
use std::sync::Arc;
use tokenizers::Tokenizer as HfTokenizer;
use tracing::debug;
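/// [`Tokenizer`] backed by the HuggingFace `tokenizers` crate, with a small
/// cache of decoded token sequences to speed up incremental decoding.
///
/// A minimal usage sketch (the file path is hypothetical):
///
/// ```ignore
/// let tokenizer = HuggingFaceTokenizer::from_file("models/tokenizer.json").await?;
/// let ids = tokenizer.encode("hello world", false)?;
/// let text = tokenizer.decode(&ids, true)?;
/// ```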
pub struct HuggingFaceTokenizer {
    tokenizer: Arc<HfTokenizer>,
    special_tokens: SpecialTokens,
    info: TokenizerInfo,
    decode_cache: RwLock<DecodeCache>,
}
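/// Accumulated state for stateful incremental decoding: the token ids seen
/// so far and the full text decoded from them.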
#[derive(Debug, Clone, Default)]
pub struct IncrementalState {
    tokens: Vec<TokenId>,
    text: String,
}
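/// A bounded map from token-id sequences to their decoded text. When the
/// cache fills up, roughly half of the entries are evicted to make room.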
#[derive(Debug, Default)]
struct DecodeCache {
    cache: std::collections::HashMap<Vec<TokenId>, String>,
    max_size: usize,
}
impl DecodeCache {
    fn new(max_size: usize) -> Self {
        Self {
            cache: std::collections::HashMap::new(),
            max_size,
        }
    }

    fn get(&self, tokens: &[TokenId]) -> Option<&String> {
        self.cache.get(tokens)
    }

    fn insert(&mut self, tokens: Vec<TokenId>, text: String) {
        if self.cache.len() >= self.max_size {
            // Simple bulk eviction: drop half of the entries. HashMap
            // iteration order is arbitrary, so this is not LRU.
            let to_remove: Vec<_> = self
                .cache
                .keys()
                .take(self.cache.len() / 2)
                .cloned()
                .collect();
            for key in to_remove {
                self.cache.remove(&key);
            }
        }
        self.cache.insert(tokens, text);
    }
}
impl HuggingFaceTokenizer {
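    /// Wraps an already-constructed [`HfTokenizer`], probing it for special
    /// tokens and building the static [`TokenizerInfo`] up front.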
    pub async fn new(tokenizer: HfTokenizer) -> Result<Self> {
        let vocab_size = tokenizer.get_vocab_size(false);

        let special_tokens = extract_special_tokens(&tokenizer)?;

        let info = TokenizerInfo {
            tokenizer_type: TokenizerType::BPE,
            vocab_size,
            special_tokens: special_tokens.clone(),
            supports_incremental: true,
            supports_chat_template: false,
            max_token_length: None,
            model_name: None,
        };

        debug!(
            "Created HuggingFace tokenizer with vocab size {}",
            vocab_size
        );

        Ok(Self {
            tokenizer: Arc::new(tokenizer),
            special_tokens,
            info,
            decode_cache: RwLock::new(DecodeCache::new(1000)),
        })
    }
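    /// Loads a tokenizer from a local `tokenizer.json` file.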
    pub async fn from_file(path: &str) -> Result<Self> {
        let tokenizer = HfTokenizer::from_file(path).map_err(|e| {
            ferrum_types::FerrumError::tokenizer(format!("Failed to load tokenizer: {}", e))
        })?;
        Self::new(tokenizer).await
    }
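    /// Downloads `tokenizer.json` from the HuggingFace Hub and loads it,
    /// pinning the given revision when one is provided.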
    pub async fn from_pretrained(repo_id: &str, revision: Option<&str>) -> Result<Self> {
        let api = hf_hub::api::tokio::Api::new().map_err(|e| {
            ferrum_types::FerrumError::tokenizer(format!("Failed to create HF API: {}", e))
        })?;

        // Respect the requested revision instead of silently ignoring it.
        let repo = match revision {
            Some(rev) => api.repo(hf_hub::Repo::with_revision(
                repo_id.to_string(),
                hf_hub::RepoType::Model,
                rev.to_string(),
            )),
            None => api.repo(hf_hub::Repo::model(repo_id.to_string())),
        };

        let tokenizer_file = repo.get("tokenizer.json").await.map_err(|e| {
            ferrum_types::FerrumError::tokenizer(format!("Failed to download tokenizer: {}", e))
        })?;

        let tokenizer = HfTokenizer::from_file(&tokenizer_file).map_err(|e| {
            ferrum_types::FerrumError::tokenizer(format!("Failed to load tokenizer: {}", e))
        })?;

        Self::new(tokenizer).await
    }
}
impl Tokenizer for HuggingFaceTokenizer {
    fn encode(&self, text: &str, add_special: bool) -> Result<Vec<TokenId>> {
        let encoding = self
            .tokenizer
            .encode(text, add_special)
            .map_err(|e| ferrum_types::FerrumError::tokenizer(format!("Encoding failed: {}", e)))?;

        Ok(encoding
            .get_ids()
            .iter()
            .map(|&id| TokenId::new(id))
            .collect())
    }

    fn decode(&self, tokens: &[TokenId], skip_special: bool) -> Result<String> {
        let token_ids: Vec<u32> = tokens.iter().map(|t| t.get()).collect();

        let text = self
            .tokenizer
            .decode(&token_ids, skip_special)
            .map_err(|e| ferrum_types::FerrumError::tokenizer(format!("Decoding failed: {}", e)))?;

        Ok(text)
    }
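    /// Decodes only the text produced by appending `next` to `prev`. The
    /// decoded text for `prev` is looked up in (and written back to) the
    /// decode cache; the returned delta is the byte suffix of the full
    /// decode beyond the previous text.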
    fn decode_incremental(&self, prev: &[TokenId], next: TokenId) -> Result<String> {
        // Copy the cached length out before taking the write lock below;
        // holding the read guard across the write would deadlock, since
        // parking_lot locks are not reentrant.
        let cached_prev_len = self.decode_cache.read().get(prev).map(|text| text.len());
        if let Some(prev_len) = cached_prev_len {
            let mut all_tokens = prev.to_vec();
            all_tokens.push(next);
            let full_text = self.decode(&all_tokens, true)?;

            {
                let mut cache = self.decode_cache.write();
                cache.insert(all_tokens, full_text.clone());
            }

            // Slice defensively: appending a token can rewrite earlier bytes
            // (e.g. partial UTF-8 sequences), so the cached text is not
            // guaranteed to be a byte prefix of the new text.
            return Ok(full_text.get(prev_len..).unwrap_or_default().to_string());
        }

        let prev_text = if prev.is_empty() {
            String::new()
        } else {
            self.decode(prev, true)?
        };

        let mut all_tokens = prev.to_vec();
        all_tokens.push(next);
        let full_text = self.decode(&all_tokens, true)?;

        {
            let mut cache = self.decode_cache.write();
            if !prev.is_empty() {
                cache.insert(prev.to_vec(), prev_text.clone());
            }
            cache.insert(all_tokens, full_text.clone());
        }

        Ok(full_text.get(prev_text.len()..).unwrap_or_default().to_string())
    }
    fn vocab_size(&self) -> usize {
        self.info.vocab_size
    }

    fn special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    fn token_id(&self, text: &str) -> Option<TokenId> {
        self.tokenizer.token_to_id(text).map(TokenId::new)
    }

    fn token_text(&self, _token_id: TokenId) -> Option<&str> {
        // `HfTokenizer::id_to_token` returns an owned `String`, which cannot
        // be handed out as `&str` from this signature, so lookup is
        // unsupported here.
        None
    }
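    /// Fallback template: `info.supports_chat_template` is `false`, so this
    /// renders plain `role: content` lines rather than applying a
    /// model-specific chat template.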
    fn apply_chat_template(
        &self,
        messages: &[ferrum_interfaces::tokenizer::ChatMessage],
    ) -> Result<String> {
        let mut result = String::new();
        for msg in messages {
            result.push_str(&format!("{}: {}\n", msg.role, msg.content));
        }
        Ok(result.trim_end().to_string())
    }

    fn info(&self) -> TokenizerInfo {
        self.info.clone()
    }
}
impl IncrementalTokenizer for HuggingFaceTokenizer {
    type State = IncrementalState;

    fn create_state(&self) -> Self::State {
        IncrementalState::default()
    }
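    /// Appends `token` to the state, re-decodes the full sequence, and
    /// returns only the newly produced text. A streaming loop might look
    /// like this (the `next_token()` source is hypothetical):
    ///
    /// ```ignore
    /// let mut state = tokenizer.create_state();
    /// while let Some(token) = next_token() {
    ///     let delta = tokenizer.decode_incremental_with_state(&mut state, token)?;
    ///     print!("{delta}");
    /// }
    /// ```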
    fn decode_incremental_with_state(
        &self,
        state: &mut Self::State,
        token: TokenId,
    ) -> Result<String> {
        state.tokens.push(token);

        let full_text = self.decode(&state.tokens, true)?;

        // Defensive slice: the previous text is not guaranteed to be a byte
        // prefix of the new text (see `decode_incremental`).
        let delta = full_text.get(state.text.len()..).unwrap_or_default().to_string();

        state.text = full_text;

        Ok(delta)
    }

    fn reset_state(&self, state: &mut Self::State) {
        state.tokens.clear();
        state.text.clear();
    }

    fn get_decoded_text(&self, state: &Self::State) -> String {
        state.text.clone()
    }
}
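/// [`TokenizerFactory`] that produces [`HuggingFaceTokenizer`] instances
/// from local files, raw bytes, or the HuggingFace Hub.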
#[derive(Debug, Clone, Default)]
pub struct HuggingFaceTokenizerFactory;

impl HuggingFaceTokenizerFactory {
    pub fn new() -> Self {
        Self
    }
}

#[async_trait]
impl TokenizerFactory for HuggingFaceTokenizerFactory {
    async fn load_from_file(&self, path: &str) -> Result<Box<dyn Tokenizer>> {
        let tokenizer = HuggingFaceTokenizer::from_file(path).await?;
        Ok(Box::new(tokenizer))
    }

    async fn load_from_bytes(&self, data: &[u8]) -> Result<Box<dyn Tokenizer>> {
        let tokenizer = HfTokenizer::from_bytes(data).map_err(|e| {
            ferrum_types::FerrumError::tokenizer(format!(
                "Failed to load tokenizer from bytes: {}",
                e
            ))
        })?;
        let tokenizer = HuggingFaceTokenizer::new(tokenizer).await?;
        Ok(Box::new(tokenizer))
    }

    async fn load_from_hub(
        &self,
        repo_id: &str,
        revision: Option<&str>,
    ) -> Result<Box<dyn Tokenizer>> {
        let tokenizer = HuggingFaceTokenizer::from_pretrained(repo_id, revision).await?;
        Ok(Box::new(tokenizer))
    }

    async fn create_from_config(
        &self,
        config: &ferrum_interfaces::tokenizer::TokenizerConfig,
    ) -> Result<Box<dyn Tokenizer>> {
        self.load_from_file(&config.path).await
    }

    fn supported_types(&self) -> Vec<TokenizerType> {
        vec![
            TokenizerType::BPE,
            TokenizerType::WordPiece,
            TokenizerType::SentencePiece,
        ]
    }
}
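/// Probes the vocabulary for common special-token spellings (SentencePiece
/// style `<s>`/`</s>` and BERT style `[SEP]`/`[CLS]` variants), taking the
/// first match for each slot.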
fn extract_special_tokens(tokenizer: &HfTokenizer) -> Result<SpecialTokens> {
    let bos_token = tokenizer
        .token_to_id("<s>")
        .or_else(|| tokenizer.token_to_id("[BOS]"))
        .or_else(|| tokenizer.token_to_id("<bos>"))
        .map(TokenId::new);

    let eos_token = tokenizer
        .token_to_id("</s>")
        .or_else(|| tokenizer.token_to_id("[EOS]"))
        .or_else(|| tokenizer.token_to_id("<eos>"))
        .map(TokenId::new);

    let unk_token = tokenizer
        .token_to_id("<unk>")
        .or_else(|| tokenizer.token_to_id("[UNK]"))
        .map(TokenId::new);

    let pad_token = tokenizer
        .token_to_id("<pad>")
        .or_else(|| tokenizer.token_to_id("[PAD]"))
        .map(TokenId::new);

    let sep_token = tokenizer
        .token_to_id("[SEP]")
        .or_else(|| tokenizer.token_to_id("<sep>"))
        .map(TokenId::new);

    let cls_token = tokenizer
        .token_to_id("[CLS]")
        .or_else(|| tokenizer.token_to_id("<cls>"))
        .map(TokenId::new);

    let mask_token = tokenizer
        .token_to_id("[MASK]")
        .or_else(|| tokenizer.token_to_id("<mask>"))
        .map(TokenId::new);

    Ok(SpecialTokens {
        bos_token,
        eos_token,
        unk_token,
        pad_token,
        sep_token,
        cls_token,
        mask_token,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decode_cache_creation() {
        let cache = DecodeCache::new(100);
        assert_eq!(cache.max_size, 100);
        assert_eq!(cache.cache.len(), 0);
    }

    #[test]
    fn test_decode_cache_insert_and_get() {
        let mut cache = DecodeCache::new(10);
        let tokens = vec![TokenId::new(1), TokenId::new(2)];
        let text = "hello".to_string();

        cache.insert(tokens.clone(), text.clone());

        let result = cache.get(&tokens);
        assert!(result.is_some());
        assert_eq!(result.unwrap(), &text);
    }

    #[test]
    fn test_decode_cache_eviction() {
        let mut cache = DecodeCache::new(2);

        cache.insert(vec![TokenId::new(1)], "a".to_string());
        cache.insert(vec![TokenId::new(2)], "b".to_string());

        assert_eq!(cache.cache.len(), 2);

        cache.insert(vec![TokenId::new(3)], "c".to_string());

        assert!(cache.cache.len() <= 2);
    }
    #[test]
    fn test_incremental_state_default() {
        let state = IncrementalState::default();
        let debug_str = format!("{:?}", state);
        assert!(debug_str.contains("IncrementalState"));
    }

    #[test]
    fn test_incremental_state_clone() {
        let state = IncrementalState::default();
        let cloned = state.clone();

        let state_str = format!("{:?}", state);
        let cloned_str = format!("{:?}", cloned);
        assert_eq!(state_str, cloned_str);
    }

    #[test]
    fn test_huggingface_tokenizer_factory_creation() {
        let factory = HuggingFaceTokenizerFactory::new();
        let debug_str = format!("{:?}", factory);
        assert!(debug_str.contains("HuggingFaceTokenizerFactory"));
    }

    #[test]
    fn test_huggingface_tokenizer_factory_default() {
        let factory = HuggingFaceTokenizerFactory::default();
        let debug_str = format!("{:?}", factory);
        assert!(debug_str.contains("HuggingFaceTokenizerFactory"));
    }

    #[test]
    fn test_huggingface_tokenizer_factory_clone() {
        let factory = HuggingFaceTokenizerFactory::new();
        let cloned = factory.clone();

        let factory_str = format!("{:?}", factory);
        let cloned_str = format!("{:?}", cloned);
        assert_eq!(factory_str, cloned_str);
    }
    #[test]
    fn test_huggingface_tokenizer_factory_supported_types() {
        let factory = HuggingFaceTokenizerFactory::new();
        let types = factory.supported_types();

        assert!(types.len() >= 1);
        assert!(types.contains(&TokenizerType::BPE));
    }

    #[test]
    fn test_extract_special_tokens_with_mock_tokenizer() {
        use tokenizers::models::bpe::{Vocab, BPE};
        use tokenizers::AddedToken;

        let vocab: Vocab = [
            ("hello".to_string(), 0),
            ("<s>".to_string(), 1),
            ("</s>".to_string(), 2),
            ("<unk>".to_string(), 3),
            ("<pad>".to_string(), 4),
        ]
        .into_iter()
        .collect();

        let merges = vec![];
        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .unk_token("<unk>".to_string())
            .build()
            .unwrap();

        let mut tokenizer = HfTokenizer::new(bpe);
        tokenizer.add_special_tokens(&[
            AddedToken::from("<s>", true),
            AddedToken::from("</s>", true),
            AddedToken::from("<unk>", true),
            AddedToken::from("<pad>", true),
        ]);

        let result = extract_special_tokens(&tokenizer);
        assert!(result.is_ok());

        let special_tokens = result.unwrap();
        assert!(special_tokens.bos_token.is_some());
        assert!(special_tokens.eos_token.is_some());
        assert!(special_tokens.unk_token.is_some());
        assert!(special_tokens.pad_token.is_some());
    }
    #[tokio::test]
    async fn test_huggingface_tokenizer_with_mock() {
        use tokenizers::models::bpe::{Vocab, BPE};
        use tokenizers::AddedToken;

        let vocab: Vocab = [
            ("hello".to_string(), 0),
            ("world".to_string(), 1),
            ("<s>".to_string(), 2),
            ("</s>".to_string(), 3),
            ("<unk>".to_string(), 4),
        ]
        .into_iter()
        .collect();

        let merges = vec![];
        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .unk_token("<unk>".to_string())
            .build()
            .unwrap();

        let mut hf_tokenizer = HfTokenizer::new(bpe);
        hf_tokenizer.add_special_tokens(&[
            AddedToken::from("<s>", true),
            AddedToken::from("</s>", true),
            AddedToken::from("<unk>", true),
        ]);

        let result = HuggingFaceTokenizer::new(hf_tokenizer).await;
        assert!(result.is_ok());

        let tokenizer = result.unwrap();
        assert_eq!(tokenizer.vocab_size(), 5);
    }

    #[tokio::test]
    async fn test_tokenizer_encode_decode() {
        use tokenizers::models::bpe::{Vocab, BPE};
        use tokenizers::AddedToken;

        let vocab: Vocab = [
            ("hello".to_string(), 0),
            ("world".to_string(), 1),
            ("<s>".to_string(), 2),
            ("</s>".to_string(), 3),
            ("<unk>".to_string(), 4),
        ]
        .into_iter()
        .collect();

        let merges = vec![];
        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .unk_token("<unk>".to_string())
            .build()
            .unwrap();

        let mut hf_tokenizer = HfTokenizer::new(bpe);
        hf_tokenizer.add_special_tokens(&[
            AddedToken::from("<s>", true),
            AddedToken::from("</s>", true),
            AddedToken::from("<unk>", true),
        ]);

        let tokenizer = HuggingFaceTokenizer::new(hf_tokenizer).await.unwrap();

        let result = tokenizer.encode("hello", false);
        assert!(result.is_ok());

        // Round-trip the encoded tokens rather than decoding an empty slice.
        let tokens = result.unwrap();
        let decoded = tokenizer.decode(&tokens, false);
        assert!(decoded.is_ok());
    }
    #[tokio::test]
    async fn test_tokenizer_special_tokens() {
        use tokenizers::models::bpe::{Vocab, BPE};
        use tokenizers::AddedToken;

        let vocab: Vocab = [
            ("hello".to_string(), 0),
            ("<s>".to_string(), 1),
            ("</s>".to_string(), 2),
        ]
        .into_iter()
        .collect();

        let merges = vec![];
        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .build()
            .unwrap();

        let mut hf_tokenizer = HfTokenizer::new(bpe);
        hf_tokenizer.add_special_tokens(&[
            AddedToken::from("<s>", true),
            AddedToken::from("</s>", true),
        ]);

        let tokenizer = HuggingFaceTokenizer::new(hf_tokenizer).await.unwrap();
        let special_tokens = tokenizer.special_tokens();

        assert!(special_tokens.bos_token.is_some() || special_tokens.eos_token.is_some());
    }

    #[tokio::test]
    async fn test_tokenizer_token_id_lookup() {
        use tokenizers::models::bpe::{Vocab, BPE};

        let vocab: Vocab = [("hello".to_string(), 0), ("world".to_string(), 1)]
            .into_iter()
            .collect();

        let merges = vec![];
        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .build()
            .unwrap();

        let hf_tokenizer = HfTokenizer::new(bpe);
        let tokenizer = HuggingFaceTokenizer::new(hf_tokenizer).await.unwrap();

        let token_id = tokenizer.token_id("hello");
        assert!(token_id.is_some());
        assert_eq!(token_id.unwrap().get(), 0);
    }
    #[tokio::test]
    async fn test_tokenizer_info() {
        use tokenizers::models::bpe::{Vocab, BPE};

        let vocab: Vocab = [("hello".to_string(), 0), ("world".to_string(), 1)]
            .into_iter()
            .collect();

        let merges = vec![];
        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .build()
            .unwrap();

        let hf_tokenizer = HfTokenizer::new(bpe);
        let tokenizer = HuggingFaceTokenizer::new(hf_tokenizer).await.unwrap();

        let info = tokenizer.info();
        assert_eq!(info.vocab_size, 2);
        assert!(info.supports_incremental);
        assert_eq!(info.tokenizer_type, TokenizerType::BPE);
    }

    #[tokio::test]
    async fn test_incremental_tokenizer_interface() {
        use tokenizers::models::bpe::{Vocab, BPE};

        let vocab: Vocab = [("hello".to_string(), 0), ("world".to_string(), 1)]
            .into_iter()
            .collect();

        let merges = vec![];
        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .build()
            .unwrap();

        let hf_tokenizer = HfTokenizer::new(bpe);
        let tokenizer = HuggingFaceTokenizer::new(hf_tokenizer).await.unwrap();

        let mut state = tokenizer.create_state();

        let result = tokenizer.decode_incremental_with_state(&mut state, TokenId::new(0));
        assert!(result.is_ok());

        tokenizer.reset_state(&mut state);
        let text = tokenizer.get_decoded_text(&state);
        assert!(text.is_empty());
    }
}