
ferrum_tokenizer/implementations/huggingface.rs

//! HuggingFace tokenizer implementation

use crate::{IncrementalTokenizer, Tokenizer, TokenizerFactory, TokenizerInfo, TokenizerType};
use async_trait::async_trait;
use ferrum_types::{Result, SpecialTokens, TokenId};
use parking_lot::RwLock;
use std::sync::Arc;
use tokenizers::Tokenizer as HfTokenizer;
use tracing::debug;

/// HuggingFace tokenizer wrapper
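///
/// Wraps a `tokenizers::Tokenizer` and caches decoded prefixes to speed up
/// streaming decodes.
///
/// # Example
///
/// A minimal sketch, run inside an async context (not compiled here; the
/// crate path and tokenizer file are assumptions):
///
/// ```ignore
/// use ferrum_tokenizer::Tokenizer;
///
/// let tok = HuggingFaceTokenizer::from_file("path/to/tokenizer.json").await?;
/// let ids = tok.encode("hello world", /* add_special */ true)?;
/// let text = tok.decode(&ids, /* skip_special */ true)?;
/// ```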
pub struct HuggingFaceTokenizer {
    tokenizer: Arc<HfTokenizer>,
    special_tokens: SpecialTokens,
    info: TokenizerInfo,
    /// Incremental decode cache for efficiency
    decode_cache: RwLock<DecodeCache>,
}

/// Incremental decoding state
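///
/// Used with the [`IncrementalTokenizer`] interface to stream text out token
/// by token. A typical loop looks like this (sketch; `tokenizer` and the
/// token source are hypothetical):
///
/// ```ignore
/// let mut state = tokenizer.create_state();
/// for token in generated_tokens {
///     // Each call returns only the newly decoded fragment.
///     let delta = tokenizer.decode_incremental_with_state(&mut state, token)?;
///     print!("{delta}");
/// }
/// ```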
#[derive(Debug, Clone, Default)]
pub struct IncrementalState {
    /// Accumulated tokens
    tokens: Vec<TokenId>,
    /// Decoded text so far
    text: String,
}

/// Cache for decoded token sequences
#[derive(Debug, Default)]
struct DecodeCache {
    cache: std::collections::HashMap<Vec<TokenId>, String>,
    max_size: usize,
}

impl DecodeCache {
    fn new(max_size: usize) -> Self {
        Self {
            cache: std::collections::HashMap::new(),
            max_size,
        }
    }

    fn get(&self, tokens: &[TokenId]) -> Option<&String> {
        self.cache.get(tokens)
    }

    fn insert(&mut self, tokens: Vec<TokenId>, text: String) {
        // Coarse eviction: once the cache is full, drop roughly half of the
        // entries. HashMap iteration order is arbitrary, so this is not LRU;
        // it only bounds memory.
        if self.cache.len() >= self.max_size {
            let to_remove: Vec<_> = self
                .cache
                .keys()
                .take(self.cache.len() / 2)
                .cloned()
                .collect();
            for key in to_remove {
                self.cache.remove(&key);
            }
        }
        self.cache.insert(tokens, text);
    }
}

impl HuggingFaceTokenizer {
    /// Create new HuggingFace tokenizer
    pub async fn new(tokenizer: HfTokenizer) -> Result<Self> {
        let vocab_size = tokenizer.get_vocab_size(false);

        // Extract special tokens
        let special_tokens = extract_special_tokens(&tokenizer)?;

        let info = TokenizerInfo {
            tokenizer_type: TokenizerType::BPE, // Most HF tokenizers use BPE
            vocab_size,
            special_tokens: special_tokens.clone(),
            supports_incremental: true,
            supports_chat_template: false, // MVP: chat template support disabled
            max_token_length: None,        // HF tokenizers don't expose this directly
            model_name: None,              // Can be set externally
        };

        debug!(
            "Created HuggingFace tokenizer with vocab size {}",
            vocab_size
        );

        Ok(Self {
            tokenizer: Arc::new(tokenizer),
            special_tokens,
            info,
            decode_cache: RwLock::new(DecodeCache::new(1000)),
        })
    }

    /// Create from file path
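    ///
    /// # Example
    ///
    /// Sketch, assuming a local `tokenizer.json` exported by the HuggingFace
    /// `tokenizers` library (the path is hypothetical):
    ///
    /// ```ignore
    /// let tok = HuggingFaceTokenizer::from_file("models/tokenizer.json").await?;
    /// ```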
    pub async fn from_file(path: &str) -> Result<Self> {
        let tokenizer = HfTokenizer::from_file(path).map_err(|e| {
            ferrum_types::FerrumError::tokenizer(format!("Failed to load tokenizer: {}", e))
        })?;
        Self::new(tokenizer).await
    }

    /// Create from HuggingFace Hub
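    ///
    /// Downloads `tokenizer.json` for `repo_id` via `hf_hub` (network access
    /// required). The `revision` argument is currently ignored (see the note
    /// in the body). Sketch with an illustrative repo id:
    ///
    /// ```ignore
    /// let tok = HuggingFaceTokenizer::from_pretrained("gpt2", None).await?;
    /// ```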
    pub async fn from_pretrained(repo_id: &str, _revision: Option<&str>) -> Result<Self> {
        let api = hf_hub::api::tokio::Api::new().map_err(|e| {
            ferrum_types::FerrumError::tokenizer(format!("Failed to create HF API: {}", e))
        })?;

        let repo = api.repo(hf_hub::Repo::model(repo_id.to_string()));

        // Note: hf_hub::api::tokio::ApiRepo doesn't have set_revision in newer versions;
        // revision is handled via the Repo struct or api.model_with_revision.
        let tokenizer_file = repo.get("tokenizer.json").await.map_err(|e| {
            ferrum_types::FerrumError::tokenizer(format!("Failed to download tokenizer: {}", e))
        })?;

        let tokenizer = HfTokenizer::from_file(&tokenizer_file).map_err(|e| {
            ferrum_types::FerrumError::tokenizer(format!("Failed to load tokenizer: {}", e))
        })?;

        Self::new(tokenizer).await
    }
}

impl Tokenizer for HuggingFaceTokenizer {
    fn encode(&self, text: &str, add_special: bool) -> Result<Vec<TokenId>> {
        let encoding = self
            .tokenizer
            .encode(text, add_special)
            .map_err(|e| ferrum_types::FerrumError::tokenizer(format!("Encoding failed: {}", e)))?;

        Ok(encoding
            .get_ids()
            .iter()
            .map(|&id| TokenId::new(id))
            .collect())
    }

    fn decode(&self, tokens: &[TokenId], skip_special: bool) -> Result<String> {
        let token_ids: Vec<u32> = tokens.iter().map(|t| t.get()).collect();

        let text = self
            .tokenizer
            .decode(&token_ids, skip_special)
            .map_err(|e| ferrum_types::FerrumError::tokenizer(format!("Decoding failed: {}", e)))?;

        Ok(text)
    }

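    /// Decode `prev` plus `next` and return only the newly produced text.
    ///
    /// Worked example (token ids illustrative): if `prev` decodes to
    /// `"Hello"` and `prev + [next]` decodes to `"Hello world"`, the returned
    /// delta is `" world"`. Decoded prefix strings are cached so a streaming
    /// caller does not re-decode `prev` on every step.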
    fn decode_incremental(&self, prev: &[TokenId], next: TokenId) -> Result<String> {
        // Check cache first. Clone the cached text out so the read guard is
        // dropped before we take the write lock below (parking_lot locks are
        // not reentrant; holding both would deadlock).
        let cached_prev = self.decode_cache.read().get(prev).cloned();
        if let Some(prev_text) = cached_prev {
            let mut all_tokens = prev.to_vec();
            all_tokens.push(next);
            let full_text = self.decode(&all_tokens, true)?;

            // Cache the new sequence
            {
                let mut cache = self.decode_cache.write();
                cache.insert(all_tokens, full_text.clone());
            }

            // Return only the delta. Slice defensively: with byte-level
            // tokenizers the new decode may not end on the old char boundary.
            return Ok(full_text.get(prev_text.len()..).unwrap_or("").to_string());
        }

        // No cache hit, decode both
        let prev_text = if prev.is_empty() {
            String::new()
        } else {
            self.decode(prev, true)?
        };

        let mut all_tokens = prev.to_vec();
        all_tokens.push(next);
        let full_text = self.decode(&all_tokens, true)?;

        // Update cache
        {
            let mut cache = self.decode_cache.write();
            if !prev.is_empty() {
                cache.insert(prev.to_vec(), prev_text.clone());
            }
            cache.insert(all_tokens, full_text.clone());
        }

        Ok(full_text.get(prev_text.len()..).unwrap_or("").to_string())
    }

    fn vocab_size(&self) -> usize {
        self.info.vocab_size
    }

    fn special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    fn token_id(&self, text: &str) -> Option<TokenId> {
        self.tokenizer.token_to_id(text).map(TokenId::new)
    }

    fn token_text(&self, _token_id: TokenId) -> Option<&str> {
        // HF tokenizer doesn't expose this efficiently; return None
        None
    }

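    /// Render messages with the MVP fallback template: each message becomes
    /// `"{role}: {content}"` on its own line. For example, a user message
    /// "Hi" followed by an assistant message "Hello" renders as:
    ///
    /// ```text
    /// user: Hi
    /// assistant: Hello
    /// ```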
    fn apply_chat_template(
        &self,
        messages: &[ferrum_interfaces::tokenizer::ChatMessage],
    ) -> Result<String> {
        // MVP: simple concatenation
        let mut result = String::new();
        for msg in messages {
            result.push_str(&format!("{}: {}\n", msg.role, msg.content));
        }
        Ok(result.trim_end().to_string())
    }

    fn info(&self) -> TokenizerInfo {
        self.info.clone()
    }
}

impl IncrementalTokenizer for HuggingFaceTokenizer {
    type State = IncrementalState;

    fn create_state(&self) -> Self::State {
        IncrementalState::default()
    }

    fn decode_incremental_with_state(
        &self,
        state: &mut Self::State,
        token: TokenId,
    ) -> Result<String> {
        state.tokens.push(token);

        // Decode all tokens
        let full_text = self.decode(&state.tokens, true)?;

        // Calculate delta. Slice defensively so a non-char-boundary (or a
        // decode that got shorter) cannot panic.
        let delta = full_text.get(state.text.len()..).unwrap_or("").to_string();

        // Update state
        state.text = full_text;

        Ok(delta)
    }

    fn reset_state(&self, state: &mut Self::State) {
        state.tokens.clear();
        state.text.clear();
    }

    fn get_decoded_text(&self, state: &Self::State) -> String {
        state.text.clone()
    }
}

/// HuggingFace tokenizer factory
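///
/// # Example
///
/// Loading through the trait-object interface (sketch; the crate path and
/// file path are assumptions):
///
/// ```ignore
/// use ferrum_tokenizer::TokenizerFactory;
///
/// let factory = HuggingFaceTokenizerFactory::new();
/// let tok = factory.load_from_file("models/tokenizer.json").await?;
/// ```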
#[derive(Debug, Clone, Default)]
pub struct HuggingFaceTokenizerFactory;

impl HuggingFaceTokenizerFactory {
    pub fn new() -> Self {
        Self
    }
}

#[async_trait]
impl TokenizerFactory for HuggingFaceTokenizerFactory {
    async fn load_from_file(&self, path: &str) -> Result<Box<dyn Tokenizer>> {
        let tokenizer = HuggingFaceTokenizer::from_file(path).await?;
        Ok(Box::new(tokenizer))
    }

    async fn load_from_bytes(&self, data: &[u8]) -> Result<Box<dyn Tokenizer>> {
        let tokenizer = HfTokenizer::from_bytes(data).map_err(|e| {
            ferrum_types::FerrumError::tokenizer(format!(
                "Failed to load tokenizer from bytes: {}",
                e
            ))
        })?;
        let tokenizer = HuggingFaceTokenizer::new(tokenizer).await?;
        Ok(Box::new(tokenizer))
    }

    async fn load_from_hub(
        &self,
        repo_id: &str,
        revision: Option<&str>,
    ) -> Result<Box<dyn Tokenizer>> {
        let tokenizer = HuggingFaceTokenizer::from_pretrained(repo_id, revision).await?;
        Ok(Box::new(tokenizer))
    }

    async fn create_from_config(
        &self,
        config: &ferrum_interfaces::tokenizer::TokenizerConfig,
    ) -> Result<Box<dyn Tokenizer>> {
        // Load from path specified in config
        self.load_from_file(&config.path).await
    }

    fn supported_types(&self) -> Vec<TokenizerType> {
        vec![
            TokenizerType::BPE,
            TokenizerType::WordPiece,
            TokenizerType::SentencePiece,
        ]
    }
}

// ============================================================================
// Helper Functions
// ============================================================================

/// Extract special tokens from HF tokenizer
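///
/// Probes the vocabulary for common naming conventions: `<s>`/`[BOS]`/`<bos>`
/// for BOS, `</s>`/`[EOS]`/`<eos>` for EOS, and the analogous spellings for
/// UNK, PAD, SEP, CLS, and MASK. Tokenizers that use other names will come
/// back with those fields set to `None`.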
fn extract_special_tokens(tokenizer: &HfTokenizer) -> Result<SpecialTokens> {
    let bos_token = tokenizer
        .token_to_id("<s>")
        .or_else(|| tokenizer.token_to_id("[BOS]"))
        .or_else(|| tokenizer.token_to_id("<bos>"))
        .map(TokenId::new);

    let eos_token = tokenizer
        .token_to_id("</s>")
        .or_else(|| tokenizer.token_to_id("[EOS]"))
        .or_else(|| tokenizer.token_to_id("<eos>"))
        .map(TokenId::new);

    let unk_token = tokenizer
        .token_to_id("<unk>")
        .or_else(|| tokenizer.token_to_id("[UNK]"))
        .map(TokenId::new);

    let pad_token = tokenizer
        .token_to_id("<pad>")
        .or_else(|| tokenizer.token_to_id("[PAD]"))
        .map(TokenId::new);

    let sep_token = tokenizer
        .token_to_id("[SEP]")
        .or_else(|| tokenizer.token_to_id("<sep>"))
        .map(TokenId::new);

    let cls_token = tokenizer
        .token_to_id("[CLS]")
        .or_else(|| tokenizer.token_to_id("<cls>"))
        .map(TokenId::new);

    let mask_token = tokenizer
        .token_to_id("[MASK]")
        .or_else(|| tokenizer.token_to_id("<mask>"))
        .map(TokenId::new);

    Ok(SpecialTokens {
        bos_token,
        eos_token,
        unk_token,
        pad_token,
        sep_token,
        cls_token,
        mask_token,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decode_cache_creation() {
        let cache = DecodeCache::new(100);
        assert_eq!(cache.max_size, 100);
        assert_eq!(cache.cache.len(), 0);
    }

    #[test]
    fn test_decode_cache_insert_and_get() {
        let mut cache = DecodeCache::new(10);
        let tokens = vec![TokenId::new(1), TokenId::new(2)];
        let text = "hello".to_string();

        cache.insert(tokens.clone(), text.clone());

        let result = cache.get(&tokens);
        assert!(result.is_some());
        assert_eq!(result.unwrap(), &text);
    }

    #[test]
    fn test_decode_cache_eviction() {
        let mut cache = DecodeCache::new(2);

        // Fill the cache to its limit
        cache.insert(vec![TokenId::new(1)], "a".to_string());
        cache.insert(vec![TokenId::new(2)], "b".to_string());

        assert_eq!(cache.cache.len(), 2);

        // Trigger eviction
        cache.insert(vec![TokenId::new(3)], "c".to_string());

        // Some of the old entries should have been evicted
        assert!(cache.cache.len() <= 2);
    }

    #[test]
    fn test_incremental_state_default() {
        let state = IncrementalState::default();
        let debug_str = format!("{:?}", state);
        assert!(debug_str.contains("IncrementalState"));
    }

    #[test]
    fn test_incremental_state_clone() {
        let state = IncrementalState::default();
        let cloned = state.clone();

        // The clone should match the original
        let state_str = format!("{:?}", state);
        let cloned_str = format!("{:?}", cloned);
        assert_eq!(state_str, cloned_str);
    }

    #[test]
    fn test_huggingface_tokenizer_factory_creation() {
        let factory = HuggingFaceTokenizerFactory::new();
        let debug_str = format!("{:?}", factory);
        assert!(debug_str.contains("HuggingFaceTokenizerFactory"));
    }

    #[test]
    fn test_huggingface_tokenizer_factory_default() {
        let factory = HuggingFaceTokenizerFactory::default();
        let debug_str = format!("{:?}", factory);
        assert!(debug_str.contains("HuggingFaceTokenizerFactory"));
    }

    #[test]
    fn test_huggingface_tokenizer_factory_clone() {
        let factory = HuggingFaceTokenizerFactory::new();
        let cloned = factory.clone();

        let factory_str = format!("{:?}", factory);
        let cloned_str = format!("{:?}", cloned);
        assert_eq!(factory_str, cloned_str);
    }

    #[test]
    fn test_huggingface_tokenizer_factory_supported_types() {
        let factory = HuggingFaceTokenizerFactory::new();
        let types = factory.supported_types();

        assert!(!types.is_empty());
        assert!(types.contains(&TokenizerType::BPE));
    }

    #[test]
    fn test_extract_special_tokens_with_mock_tokenizer() {
        use tokenizers::models::bpe::{Vocab, BPE};
        use tokenizers::{AddedToken, Tokenizer as HfTokenizer};

        // Build a small mock tokenizer
        let vocab: Vocab = [
            ("hello".to_string(), 0),
            ("<s>".to_string(), 1),
            ("</s>".to_string(), 2),
            ("<unk>".to_string(), 3),
            ("<pad>".to_string(), 4),
        ]
        .into_iter()
        .collect();

        let merges = vec![];
        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .unk_token("<unk>".to_string())
            .build()
            .unwrap();

        let mut tokenizer = HfTokenizer::new(bpe);
        tokenizer.add_special_tokens(&[
            AddedToken::from("<s>", true),
            AddedToken::from("</s>", true),
            AddedToken::from("<unk>", true),
            AddedToken::from("<pad>", true),
        ]);

        // Extract the special tokens
        let result = extract_special_tokens(&tokenizer);
        assert!(result.is_ok());

        let special_tokens = result.unwrap();
        assert!(special_tokens.bos_token.is_some());
        assert!(special_tokens.eos_token.is_some());
        assert!(special_tokens.unk_token.is_some());
        assert!(special_tokens.pad_token.is_some());
    }

    #[tokio::test]
    async fn test_huggingface_tokenizer_with_mock() {
        use tokenizers::models::bpe::{Vocab, BPE};
        use tokenizers::{AddedToken, Tokenizer as HfTokenizer};

        let vocab: Vocab = [
            ("hello".to_string(), 0),
            ("world".to_string(), 1),
            ("<s>".to_string(), 2),
            ("</s>".to_string(), 3),
            ("<unk>".to_string(), 4),
        ]
        .into_iter()
        .collect();

        let merges = vec![];
        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .unk_token("<unk>".to_string())
            .build()
            .unwrap();

        let mut hf_tokenizer = HfTokenizer::new(bpe);
        hf_tokenizer.add_special_tokens(&[
            AddedToken::from("<s>", true),
            AddedToken::from("</s>", true),
            AddedToken::from("<unk>", true),
        ]);

        // Creating a HuggingFaceTokenizer should succeed
        let result = HuggingFaceTokenizer::new(hf_tokenizer).await;
        assert!(result.is_ok());

        let tokenizer = result.unwrap();
        assert_eq!(tokenizer.vocab_size(), 5);
    }

    #[tokio::test]
    async fn test_tokenizer_encode_decode() {
        use tokenizers::models::bpe::{Vocab, BPE};
        use tokenizers::{AddedToken, Tokenizer as HfTokenizer};

        let vocab: Vocab = [
            ("hello".to_string(), 0),
            ("world".to_string(), 1),
            ("<s>".to_string(), 2),
            ("</s>".to_string(), 3),
            ("<unk>".to_string(), 4),
        ]
        .into_iter()
        .collect();

        let merges = vec![];
        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .unk_token("<unk>".to_string())
            .build()
            .unwrap();

        let mut hf_tokenizer = HfTokenizer::new(bpe);
        hf_tokenizer.add_special_tokens(&[
            AddedToken::from("<s>", true),
            AddedToken::from("</s>", true),
            AddedToken::from("<unk>", true),
        ]);

        let tokenizer = HuggingFaceTokenizer::new(hf_tokenizer).await.unwrap();

        // Test encode: even input that cannot be encoded falls back to UNK tokens
        let result = tokenizer.encode("hello", false);
        assert!(result.is_ok());

        let _tokens = result.unwrap();
        // The tokenizer may return an empty list or UNK tokens;
        // we only verify that the call succeeds.

        // Test decode with empty tokens
        let decoded = tokenizer.decode(&[], false);
        assert!(decoded.is_ok());
    }

    #[tokio::test]
    async fn test_tokenizer_special_tokens() {
        use tokenizers::models::bpe::{Vocab, BPE};
        use tokenizers::{AddedToken, Tokenizer as HfTokenizer};

        let vocab: Vocab = [
            ("hello".to_string(), 0),
            ("<s>".to_string(), 1),
            ("</s>".to_string(), 2),
        ]
        .into_iter()
        .collect();

        let merges = vec![];
        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .build()
            .unwrap();

        let mut hf_tokenizer = HfTokenizer::new(bpe);
        hf_tokenizer.add_special_tokens(&[
            AddedToken::from("<s>", true),
            AddedToken::from("</s>", true),
        ]);

        let tokenizer = HuggingFaceTokenizer::new(hf_tokenizer).await.unwrap();
        let special_tokens = tokenizer.special_tokens();

        // At least one of the special tokens should be found
        assert!(special_tokens.bos_token.is_some() || special_tokens.eos_token.is_some());
    }

    #[tokio::test]
    async fn test_tokenizer_token_id_lookup() {
        use tokenizers::models::bpe::{Vocab, BPE};
        use tokenizers::Tokenizer as HfTokenizer;

        let vocab: Vocab = [("hello".to_string(), 0), ("world".to_string(), 1)]
            .into_iter()
            .collect();

        let merges = vec![];
        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .build()
            .unwrap();

        let hf_tokenizer = HfTokenizer::new(bpe);
        let tokenizer = HuggingFaceTokenizer::new(hf_tokenizer).await.unwrap();

        // Test token_id lookup
        let token_id = tokenizer.token_id("hello");
        assert!(token_id.is_some());
        assert_eq!(token_id.unwrap().get(), 0);
    }

    #[tokio::test]
    async fn test_tokenizer_info() {
        use tokenizers::models::bpe::{Vocab, BPE};
        use tokenizers::Tokenizer as HfTokenizer;

        let vocab: Vocab = [("hello".to_string(), 0), ("world".to_string(), 1)]
            .into_iter()
            .collect();

        let merges = vec![];
        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .build()
            .unwrap();

        let hf_tokenizer = HfTokenizer::new(bpe);
        let tokenizer = HuggingFaceTokenizer::new(hf_tokenizer).await.unwrap();

        let info = tokenizer.info();
        assert_eq!(info.vocab_size, 2);
        assert!(info.supports_incremental);
        assert_eq!(info.tokenizer_type, TokenizerType::BPE);
    }

    #[tokio::test]
    async fn test_incremental_tokenizer_interface() {
        use tokenizers::models::bpe::{Vocab, BPE};
        use tokenizers::Tokenizer as HfTokenizer;

        let vocab: Vocab = [("hello".to_string(), 0), ("world".to_string(), 1)]
            .into_iter()
            .collect();

        let merges = vec![];
        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .build()
            .unwrap();

        let hf_tokenizer = HfTokenizer::new(bpe);
        let tokenizer = HuggingFaceTokenizer::new(hf_tokenizer).await.unwrap();

        // Exercise the incremental decoding interface
        let mut state = tokenizer.create_state();

        // Feed a single token
        let result = tokenizer.decode_incremental_with_state(&mut state, TokenId::new(0));
        assert!(result.is_ok());

        // Reset the state
        tokenizer.reset_state(&mut state);
        let text = tokenizer.get_decoded_text(&state);
        assert!(text.is_empty());
    }
}
695}