//! A small subword tokenizer supporting BPE-style merges and WordPiece-style
//! greedy longest-match segmentation, with a fixed set of special tokens.

use std::collections::HashMap;

/// Special tokens that are always present in the vocabulary, cannot be
/// removed, and are skipped by `decode`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SpecialToken {
    /// Classification token, conventionally placed at the start of a sequence.
    Cls,
    /// Separator token between segments.
    Sep,
    /// Padding token.
    Pad,
    /// Unknown-token fallback.
    Unk,
    /// Mask token used for masked-language-model objectives.
    Mask,
}

impl SpecialToken {
    /// Returns the canonical string form of the special token.
    pub fn as_str(&self) -> &'static str {
        match self {
            SpecialToken::Cls => "[CLS]",
            SpecialToken::Sep => "[SEP]",
            SpecialToken::Pad => "[PAD]",
            SpecialToken::Unk => "[UNK]",
            SpecialToken::Mask => "[MASK]",
        }
    }

    /// Returns all special tokens in the order they are added to a new vocabulary.
    pub fn all() -> &'static [SpecialToken] {
        &[
            SpecialToken::Cls,
            SpecialToken::Sep,
            SpecialToken::Pad,
            SpecialToken::Unk,
            SpecialToken::Mask,
        ]
    }
}

/// Subword segmentation strategy used by `Tokenizer::encode`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenizerMode {
    /// Byte-pair-encoding style: split into characters, then apply merge rules in order.
    Bpe,
    /// WordPiece style: greedy longest-match with `##` continuation prefixes.
    WordPiece,
}

/// A single BPE merge: adjacent `left` and `right` symbols are replaced by `merged`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MergeRule {
    pub left: String,
    pub right: String,
    pub merged: String,
}

/// Result of encoding a text: the subword tokens and their vocabulary ids.
#[derive(Debug, Clone)]
pub struct EncodeResult {
    /// Subword tokens after segmentation and truncation.
    pub tokens: Vec<String>,
    /// Vocabulary id for each token (unknown tokens map to the `[UNK]` id).
    pub ids: Vec<u32>,
}

/// Configuration for a `Tokenizer`.
#[derive(Debug, Clone)]
pub struct TokenizerConfig {
    /// Segmentation strategy.
    pub mode: TokenizerMode,
    /// Maximum number of tokens kept per encoded text.
    pub max_length: usize,
    /// Whether input text is lowercased before segmentation.
    pub lowercase: bool,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            mode: TokenizerMode::Bpe,
            max_length: 512,
            lowercase: true,
        }
    }
}

/// A vocabulary plus segmentation logic for encoding text to ids and back.
pub struct Tokenizer {
    config: TokenizerConfig,
    /// Token string -> id.
    token_to_id: HashMap<String, u32>,
    /// Id -> token string.
    id_to_token: HashMap<u32, String>,
    /// Next id handed out by `add_token`.
    next_id: u32,
    /// BPE merge rules in registration order.
    merge_rules: Vec<MergeRule>,
}

impl Tokenizer {
    /// Creates a tokenizer with the given configuration and registers all
    /// special tokens, which therefore occupy the first ids.
    pub fn new(config: TokenizerConfig) -> Self {
        let mut tok = Self {
            config,
            token_to_id: HashMap::new(),
            id_to_token: HashMap::new(),
            next_id: 0,
            merge_rules: Vec::new(),
        };
        for st in SpecialToken::all() {
            tok.add_token(st.as_str());
        }
        tok
    }

    /// Convenience constructor for a BPE tokenizer with default settings.
    pub fn bpe() -> Self {
        Self::new(TokenizerConfig {
            mode: TokenizerMode::Bpe,
            ..TokenizerConfig::default()
        })
    }

    /// Convenience constructor for a WordPiece tokenizer with default settings.
    pub fn wordpiece() -> Self {
        Self::new(TokenizerConfig {
            mode: TokenizerMode::WordPiece,
            ..TokenizerConfig::default()
        })
    }

    /// Adds a token to the vocabulary and returns its id. Adding an existing
    /// token is a no-op that returns the existing id.
    pub fn add_token(&mut self, token: &str) -> u32 {
        if let Some(&id) = self.token_to_id.get(token) {
            return id;
        }
        let id = self.next_id;
        self.next_id += 1;
        self.token_to_id.insert(token.to_string(), id);
        self.id_to_token.insert(id, token.to_string());
        id
    }

    /// Removes a token from the vocabulary. Returns `false` if the token is a
    /// special token or is not present.
    pub fn remove_token(&mut self, token: &str) -> bool {
        for st in SpecialToken::all() {
            if st.as_str() == token {
                return false;
            }
        }
        if let Some(id) = self.token_to_id.remove(token) {
            self.id_to_token.remove(&id);
            return true;
        }
        false
    }

    /// Number of tokens currently in the vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.token_to_id.len()
    }

    /// Returns `true` if the token is in the vocabulary.
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_to_id.contains_key(token)
    }

    /// Registers a BPE merge rule and adds the merged symbol to the vocabulary.
    /// Rules are applied in registration order during encoding.
    pub fn add_merge_rule(&mut self, left: &str, right: &str) {
        let merged = format!("{left}{right}");
        self.add_token(&merged);
        self.merge_rules.push(MergeRule {
            left: left.to_string(),
            right: right.to_string(),
            merged,
        });
    }

    /// Number of registered merge rules.
    pub fn merge_rule_count(&self) -> usize {
        self.merge_rules.len()
    }

    /// Looks up the id of a token, if present.
    pub fn token_to_id(&self, token: &str) -> Option<u32> {
        self.token_to_id.get(token).copied()
    }

    /// Looks up the token string for an id, if present.
    pub fn id_to_token(&self, id: u32) -> Option<&str> {
        self.id_to_token.get(&id).map(String::as_str)
    }

    /// Id of the `[UNK]` token.
    pub fn unk_id(&self) -> u32 {
        self.token_to_id
            .get(SpecialToken::Unk.as_str())
            .copied()
            .unwrap_or(0)
    }

    /// Id of the `[CLS]` token.
    pub fn cls_id(&self) -> u32 {
        self.token_to_id
            .get(SpecialToken::Cls.as_str())
            .copied()
            .unwrap_or(0)
    }

    /// Id of the `[SEP]` token.
    pub fn sep_id(&self) -> u32 {
        self.token_to_id
            .get(SpecialToken::Sep.as_str())
            .copied()
            .unwrap_or(0)
    }

    /// Id of the `[PAD]` token.
    pub fn pad_id(&self) -> u32 {
        self.token_to_id
            .get(SpecialToken::Pad.as_str())
            .copied()
            .unwrap_or(0)
    }

    /// Encodes a text: optional lowercasing, subword segmentation according to
    /// the configured mode, truncation to `max_length`, and id lookup. Unknown
    /// subwords map to the `[UNK]` id.
    pub fn encode(&self, text: &str) -> EncodeResult {
        let text = if self.config.lowercase {
            text.to_lowercase()
        } else {
            text.to_string()
        };

        let sub_tokens = match &self.config.mode {
            TokenizerMode::Bpe => self.bpe_tokenize(&text),
            TokenizerMode::WordPiece => self.wordpiece_tokenize(&text),
        };

        let max = self.config.max_length;
        let truncated: Vec<String> = sub_tokens.into_iter().take(max).collect();
        let ids: Vec<u32> = truncated
            .iter()
            .map(|t| {
                self.token_to_id
                    .get(t.as_str())
                    .copied()
                    .unwrap_or_else(|| self.unk_id())
            })
            .collect();

        EncodeResult {
            tokens: truncated,
            ids,
        }
    }

    /// Decodes ids back into text. Special tokens are skipped; remaining
    /// subwords are joined according to the configured mode.
    pub fn decode(&self, ids: &[u32]) -> String {
        let mut parts: Vec<String> = Vec::with_capacity(ids.len());
        for &id in ids {
            if let Some(tok) = self.id_to_token.get(&id) {
                let is_special = SpecialToken::all().iter().any(|st| st.as_str() == tok);
                if is_special {
                    continue;
                }
                parts.push(tok.clone());
            }
        }
        self.merge_subwords(&parts)
    }

    /// Encodes each text independently.
    pub fn encode_batch(&self, texts: &[&str]) -> Vec<EncodeResult> {
        texts.iter().map(|t| self.encode(t)).collect()
    }

    /// Joins decoded subwords into a string. WordPiece continuation tokens
    /// (`##`-prefixed) are glued onto the previous word; everything else is
    /// separated by single spaces.
    fn merge_subwords(&self, tokens: &[String]) -> String {
        if tokens.is_empty() {
            return String::new();
        }

        match &self.config.mode {
            TokenizerMode::WordPiece => {
                let mut result = String::new();
                for tok in tokens {
                    if let Some(suffix) = tok.strip_prefix("##") {
                        result.push_str(suffix);
                    } else {
                        if !result.is_empty() {
                            result.push(' ');
                        }
                        result.push_str(tok);
                    }
                }
                result
            }
            TokenizerMode::Bpe => tokens.join(" "),
        }
    }

    /// BPE segmentation: each whitespace-separated word is split into single
    /// characters, merge rules are applied in registration order, and any
    /// resulting symbol that is not in the vocabulary becomes `[UNK]`.
    fn bpe_tokenize(&self, text: &str) -> Vec<String> {
        let words: Vec<&str> = text.split_whitespace().collect();
        let mut all_tokens: Vec<String> = Vec::new();

        for word in words {
            let mut symbols: Vec<String> = word.chars().map(|c| c.to_string()).collect();

            for rule in &self.merge_rules {
                symbols = Self::apply_merge(&symbols, &rule.left, &rule.right, &rule.merged);
            }

            for sym in symbols {
                if self.token_to_id.contains_key(&sym) {
                    all_tokens.push(sym);
                } else {
                    all_tokens.push(SpecialToken::Unk.as_str().to_string());
                }
            }
        }

        all_tokens
    }

    /// Applies a single merge rule left-to-right over a symbol sequence,
    /// replacing every adjacent `left`/`right` pair with `merged`.
    fn apply_merge(symbols: &[String], left: &str, right: &str, merged: &str) -> Vec<String> {
        let mut result: Vec<String> = Vec::with_capacity(symbols.len());
        let mut i = 0;
        while i < symbols.len() {
            if i + 1 < symbols.len() && symbols[i] == left && symbols[i + 1] == right {
                result.push(merged.to_string());
                i += 2;
            } else {
                result.push(symbols[i].clone());
                i += 1;
            }
        }
        result
    }

    /// WordPiece segmentation: for each word, greedily match the longest
    /// vocabulary entry from the current position (non-initial pieces are
    /// looked up with a `##` prefix). If no piece matches, emit `[UNK]` for a
    /// single character and continue.
    fn wordpiece_tokenize(&self, text: &str) -> Vec<String> {
        let words: Vec<&str> = text.split_whitespace().collect();
        let mut all_tokens: Vec<String> = Vec::new();

        for word in words {
            let chars: Vec<char> = word.chars().collect();
            let n = chars.len();
            let mut start = 0;

            while start < n {
                let mut end = n;
                let mut found = false;

                while start < end {
                    let sub: String = chars[start..end].iter().collect();
                    let candidate = if start == 0 {
                        sub.clone()
                    } else {
                        format!("##{sub}")
                    };

                    if self.token_to_id.contains_key(&candidate) {
                        all_tokens.push(candidate);
                        start = end;
                        found = true;
                        break;
                    }
                    end -= 1;
                }

                if !found {
                    all_tokens.push(SpecialToken::Unk.as_str().to_string());
                    start += 1;
                }
            }
        }

        all_tokens
    }

    /// Configured maximum number of tokens per encoded text.
    pub fn max_length(&self) -> usize {
        self.config.max_length
    }

    /// Configured segmentation mode.
    pub fn mode(&self) -> &TokenizerMode {
        &self.config.mode
    }

    /// Whether input text is lowercased before encoding.
    pub fn is_lowercase(&self) -> bool {
        self.config.lowercase
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn bpe_tokenizer() -> Tokenizer {
        Tokenizer::bpe()
    }

    fn wp_tokenizer() -> Tokenizer {
        Tokenizer::wordpiece()
    }

    // --- Special tokens ---

    #[test]
    fn test_special_token_cls_str() {
        assert_eq!(SpecialToken::Cls.as_str(), "[CLS]");
    }

    #[test]
    fn test_special_token_sep_str() {
        assert_eq!(SpecialToken::Sep.as_str(), "[SEP]");
    }

    #[test]
    fn test_special_token_pad_str() {
        assert_eq!(SpecialToken::Pad.as_str(), "[PAD]");
    }

    #[test]
    fn test_special_token_unk_str() {
        assert_eq!(SpecialToken::Unk.as_str(), "[UNK]");
    }

    #[test]
    fn test_special_token_mask_str() {
        assert_eq!(SpecialToken::Mask.as_str(), "[MASK]");
    }

    #[test]
    fn test_special_token_all_count() {
        assert_eq!(SpecialToken::all().len(), 5);
    }

    // --- Construction ---

    #[test]
    fn test_new_bpe_has_special_tokens() {
        let tok = bpe_tokenizer();
        assert!(tok.contains_token("[CLS]"));
        assert!(tok.contains_token("[SEP]"));
        assert!(tok.contains_token("[PAD]"));
        assert!(tok.contains_token("[UNK]"));
        assert!(tok.contains_token("[MASK]"));
    }

    #[test]
    fn test_new_bpe_vocab_size() {
        let tok = bpe_tokenizer();
        // Only the five special tokens are present in a fresh vocabulary.
        assert_eq!(tok.vocab_size(), 5);
    }

    #[test]
    fn test_new_wordpiece_mode() {
        let tok = wp_tokenizer();
        assert_eq!(*tok.mode(), TokenizerMode::WordPiece);
    }

    #[test]
    fn test_bpe_mode() {
        let tok = bpe_tokenizer();
        assert_eq!(*tok.mode(), TokenizerMode::Bpe);
    }

    // --- Vocabulary add/remove ---

    #[test]
    fn test_add_token_returns_new_id() {
        let mut tok = bpe_tokenizer();
        let id1 = tok.add_token("hello");
        let id2 = tok.add_token("world");
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_add_token_idempotent() {
        let mut tok = bpe_tokenizer();
        let id1 = tok.add_token("hello");
        let id2 = tok.add_token("hello");
        assert_eq!(id1, id2);
        // Five special tokens plus "hello".
        assert_eq!(tok.vocab_size(), 6);
    }

    #[test]
    fn test_remove_token_normal() {
        let mut tok = bpe_tokenizer();
        tok.add_token("temp");
        assert!(tok.contains_token("temp"));
        assert!(tok.remove_token("temp"));
        assert!(!tok.contains_token("temp"));
    }

    #[test]
    fn test_remove_special_token_prevented() {
        let mut tok = bpe_tokenizer();
        assert!(!tok.remove_token("[CLS]"));
        assert!(tok.contains_token("[CLS]"));
    }

    #[test]
    fn test_remove_nonexistent_returns_false() {
        let mut tok = bpe_tokenizer();
        assert!(!tok.remove_token("nonexistent"));
    }

    #[test]
    fn test_vocab_size_grows() {
        let mut tok = bpe_tokenizer();
        assert_eq!(tok.vocab_size(), 5);
        tok.add_token("a");
        tok.add_token("b");
        assert_eq!(tok.vocab_size(), 7);
    }

    // --- Id lookups ---

    #[test]
    fn test_token_to_id_roundtrip() {
        let mut tok = bpe_tokenizer();
        let id = tok.add_token("cat");
        assert_eq!(tok.token_to_id("cat"), Some(id));
        assert_eq!(tok.id_to_token(id), Some("cat"));
    }

    #[test]
    fn test_token_to_id_missing() {
        let tok = bpe_tokenizer();
        assert_eq!(tok.token_to_id("missing"), None);
    }

    #[test]
    fn test_id_to_token_missing() {
        let tok = bpe_tokenizer();
        assert_eq!(tok.id_to_token(9999), None);
    }

    #[test]
    fn test_unk_id() {
        let tok = bpe_tokenizer();
        let unk = tok.unk_id();
        assert_eq!(tok.id_to_token(unk), Some("[UNK]"));
    }

    #[test]
    fn test_cls_id() {
        let tok = bpe_tokenizer();
        let cls = tok.cls_id();
        assert_eq!(tok.id_to_token(cls), Some("[CLS]"));
    }

    #[test]
    fn test_sep_id() {
        let tok = bpe_tokenizer();
        let sep = tok.sep_id();
        assert_eq!(tok.id_to_token(sep), Some("[SEP]"));
    }

    #[test]
    fn test_pad_id() {
        let tok = bpe_tokenizer();
        let pad = tok.pad_id();
        assert_eq!(tok.id_to_token(pad), Some("[PAD]"));
    }

    // --- Merge rules ---

    #[test]
    fn test_add_merge_rule_creates_merged_token() {
        let mut tok = bpe_tokenizer();
        tok.add_token("h");
        tok.add_token("e");
        tok.add_merge_rule("h", "e");
        assert!(tok.contains_token("he"));
        assert_eq!(tok.merge_rule_count(), 1);
    }

    #[test]
    fn test_bpe_merge_rules_applied_in_order() {
        let mut tok = bpe_tokenizer();
        tok.add_token("h");
        tok.add_token("e");
        tok.add_token("l");
        tok.add_token("o");
        tok.add_merge_rule("h", "e");
        tok.add_merge_rule("l", "o");
        tok.add_merge_rule("he", "l");
        tok.add_merge_rule("hel", "lo");
        let result = tok.encode("hello");
        assert!(result.tokens.contains(&"hello".to_string()));
    }

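    // Added illustrative example (not part of the original suite): because
    // `bpe_tokenize` splits on whitespace before applying merge rules, a merge
    // rule can only join symbols within a single word, never across words.
    #[test]
    fn test_bpe_merges_do_not_cross_word_boundaries() {
        let mut tok = bpe_tokenizer();
        tok.add_token("a");
        tok.add_token("b");
        tok.add_merge_rule("a", "b");
        // Within one word the pair is merged...
        assert_eq!(tok.encode("ab").tokens, vec!["ab"]);
        // ...but across a space the two words are segmented independently.
        assert_eq!(tok.encode("a b").tokens, vec!["a", "b"]);
    }
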
    // --- BPE encoding ---

    #[test]
    fn test_bpe_encode_unknown_chars() {
        let tok = bpe_tokenizer();
        let result = tok.encode("xyz");
        assert!(result.ids.iter().all(|&id| id == tok.unk_id()));
    }

    #[test]
    fn test_bpe_encode_single_char_tokens() {
        let mut tok = bpe_tokenizer();
        tok.add_token("a");
        tok.add_token("b");
        let result = tok.encode("ab");
        assert_eq!(result.tokens, vec!["a", "b"]);
    }

    #[test]
    fn test_bpe_encode_multiple_words() {
        let mut tok = bpe_tokenizer();
        tok.add_token("h");
        tok.add_token("i");
        let result = tok.encode("hi hi");
        // "h", "i", "h", "i"
        assert_eq!(result.tokens.len(), 4);
    }

    // --- WordPiece encoding ---

    #[test]
    fn test_wordpiece_full_word_match() {
        let mut tok = wp_tokenizer();
        tok.add_token("hello");
        let result = tok.encode("hello");
        assert_eq!(result.tokens, vec!["hello"]);
    }

    #[test]
    fn test_wordpiece_continuation_tokens() {
        let mut tok = wp_tokenizer();
        tok.add_token("un");
        tok.add_token("##believ");
        tok.add_token("##able");
        let result = tok.encode("unbelievable");
        assert_eq!(result.tokens, vec!["un", "##believ", "##able"]);
    }

    #[test]
    fn test_wordpiece_unknown_fallback() {
        let tok = wp_tokenizer();
        let result = tok.encode("xyz");
        assert!(result.ids.iter().all(|&id| id == tok.unk_id()));
    }

    #[test]
    fn test_wordpiece_multiple_words() {
        let mut tok = wp_tokenizer();
        tok.add_token("hello");
        tok.add_token("world");
        let result = tok.encode("hello world");
        assert_eq!(result.tokens, vec!["hello", "world"]);
    }

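    // Added illustrative example (not part of the original suite): when only a
    // prefix of a word is in the vocabulary, `wordpiece_tokenize` keeps the
    // matched piece and falls back to one `[UNK]` per unmatched character.
    #[test]
    fn test_wordpiece_partial_match_falls_back_per_char() {
        let mut tok = wp_tokenizer();
        tok.add_token("do");
        let result = tok.encode("dog");
        assert_eq!(result.tokens, vec!["do", "[UNK]"]);
    }
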
    // --- Decoding ---

    #[test]
    fn test_bpe_decode_simple() {
        let mut tok = bpe_tokenizer();
        let id_a = tok.add_token("hello");
        let id_b = tok.add_token("world");
        let decoded = tok.decode(&[id_a, id_b]);
        assert_eq!(decoded, "hello world");
    }

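    // Added illustrative example (not part of the original suite): BPE decoding
    // simply joins subwords with spaces, so a word that was split by merge
    // rules does not round-trip back to its original spelling.
    #[test]
    fn test_bpe_decode_space_joins_subwords() {
        let mut tok = bpe_tokenizer();
        tok.add_token("o");
        tok.add_merge_rule("h", "e");
        tok.add_merge_rule("l", "l");
        let result = tok.encode("hello");
        assert_eq!(result.tokens, vec!["he", "ll", "o"]);
        assert_eq!(tok.decode(&result.ids), "he ll o");
    }
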
    #[test]
    fn test_wordpiece_decode_merges_continuations() {
        let mut tok = wp_tokenizer();
        let id_un = tok.add_token("un");
        let id_do = tok.add_token("##do");
        let decoded = tok.decode(&[id_un, id_do]);
        assert_eq!(decoded, "undo");
    }

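    // Added illustrative example (not part of the original suite): for
    // WordPiece, encode followed by decode reconstructs the original word,
    // because `##` continuation pieces are glued back onto the preceding token.
    #[test]
    fn test_wordpiece_encode_decode_roundtrip() {
        let mut tok = wp_tokenizer();
        tok.add_token("un");
        tok.add_token("##believ");
        tok.add_token("##able");
        let result = tok.encode("unbelievable");
        assert_eq!(tok.decode(&result.ids), "unbelievable");
    }
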
    #[test]
    fn test_decode_skips_special_tokens() {
        let tok = bpe_tokenizer();
        let cls = tok.cls_id();
        let sep = tok.sep_id();
        let decoded = tok.decode(&[cls, sep]);
        assert_eq!(decoded, "");
    }

    #[test]
    fn test_decode_empty() {
        let tok = bpe_tokenizer();
        assert_eq!(tok.decode(&[]), "");
    }

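    // Added illustrative example (not part of the original suite): text that
    // encodes entirely to `[UNK]` ids decodes to an empty string, because
    // `[UNK]` is a special token and `decode` skips special tokens.
    #[test]
    fn test_decode_drops_unk_ids() {
        let tok = bpe_tokenizer();
        let result = tok.encode("xyz");
        assert!(result.ids.iter().all(|&id| id == tok.unk_id()));
        assert_eq!(tok.decode(&result.ids), "");
    }
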
    // --- Truncation ---

    #[test]
    fn test_truncation_at_max_length() {
        let mut tok = Tokenizer::new(TokenizerConfig {
            mode: TokenizerMode::Bpe,
            max_length: 3,
            lowercase: true,
        });
        tok.add_token("a");
        tok.add_token("b");
        tok.add_token("c");
        tok.add_token("d");
        let result = tok.encode("a b c d");
        assert_eq!(result.tokens.len(), 3);
        assert_eq!(result.ids.len(), 3);
    }

    #[test]
    fn test_truncation_shorter_text_unaffected() {
        let mut tok = Tokenizer::new(TokenizerConfig {
            mode: TokenizerMode::Bpe,
            max_length: 100,
            lowercase: true,
        });
        tok.add_token("x");
        let result = tok.encode("x");
        assert_eq!(result.tokens.len(), 1);
    }

    // --- Batch encoding ---

    #[test]
    fn test_encode_batch_count() {
        let mut tok = bpe_tokenizer();
        tok.add_token("a");
        let results = tok.encode_batch(&["a", "a a", "a a a"]);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_encode_batch_independent() {
        let mut tok = bpe_tokenizer();
        tok.add_token("x");
        tok.add_token("y");
        let results = tok.encode_batch(&["x", "y"]);
        assert_ne!(results[0].ids, results[1].ids);
    }

    #[test]
    fn test_encode_batch_empty() {
        let tok = bpe_tokenizer();
        let results = tok.encode_batch(&[]);
        assert!(results.is_empty());
    }

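    // Added illustrative example (not part of the original suite): each text in
    // a batch is encoded independently, so truncation to `max_length` applies
    // per text rather than across the whole batch.
    #[test]
    fn test_encode_batch_truncates_each_text() {
        let mut tok = Tokenizer::new(TokenizerConfig {
            mode: TokenizerMode::Bpe,
            max_length: 2,
            lowercase: true,
        });
        tok.add_token("a");
        tok.add_token("b");
        tok.add_token("c");
        let results = tok.encode_batch(&["a", "a b c"]);
        assert_eq!(results[0].tokens.len(), 1);
        assert_eq!(results[1].tokens.len(), 2);
    }
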
    // --- Case handling ---

    #[test]
    fn test_lowercase_enabled() {
        let mut tok = Tokenizer::new(TokenizerConfig {
            mode: TokenizerMode::Bpe,
            max_length: 512,
            lowercase: true,
        });
        tok.add_token("hello");
        // Both inputs normalize to the same ids once lowercased.
        let r1 = tok.encode("HELLO");
        let r2 = tok.encode("hello");
        assert_eq!(r1.ids, r2.ids);
    }

    #[test]
    fn test_lowercase_disabled() {
        let mut tok = Tokenizer::new(TokenizerConfig {
            mode: TokenizerMode::Bpe,
            max_length: 512,
            lowercase: false,
        });
        tok.add_token("A");
        tok.add_token("a");
        let r1 = tok.encode("A");
        let r2 = tok.encode("a");
        assert_ne!(r1.ids, r2.ids);
    }

    // --- Accessors ---

    #[test]
    fn test_max_length_accessor() {
        let tok = bpe_tokenizer();
        assert_eq!(tok.max_length(), 512);
    }

    #[test]
    fn test_is_lowercase_accessor() {
        let tok = bpe_tokenizer();
        assert!(tok.is_lowercase());
    }

    // --- Edge cases ---

    #[test]
    fn test_encode_empty_string() {
        let tok = bpe_tokenizer();
        let result = tok.encode("");
        assert!(result.tokens.is_empty());
        assert!(result.ids.is_empty());
    }

    #[test]
    fn test_encode_whitespace_only() {
        let tok = bpe_tokenizer();
        let result = tok.encode(" ");
        assert!(result.tokens.is_empty());
    }

    #[test]
    fn test_wordpiece_greedy_longest_match() {
        let mut tok = wp_tokenizer();
        tok.add_token("play");
        tok.add_token("##ing");
        tok.add_token("##i");
        tok.add_token("##n");
        tok.add_token("##g");
        let result = tok.encode("playing");
        // Greedy matching prefers "##ing" over "##i" + "##n" + "##g".
        assert_eq!(result.tokens, vec!["play", "##ing"]);
    }

    #[test]
    fn test_merge_rule_struct_fields() {
        let rule = MergeRule {
            left: "a".to_string(),
            right: "b".to_string(),
            merged: "ab".to_string(),
        };
        assert_eq!(rule.left, "a");
        assert_eq!(rule.right, "b");
        assert_eq!(rule.merged, "ab");
    }

    #[test]
    fn test_encode_result_tokens_and_ids_same_length() {
        let mut tok = bpe_tokenizer();
        tok.add_token("t");
        tok.add_token("e");
        tok.add_token("s");
        let result = tok.encode("test");
        assert_eq!(result.tokens.len(), result.ids.len());
    }

    #[test]
    fn test_tokenizer_config_default() {
        let cfg = TokenizerConfig::default();
        assert_eq!(cfg.mode, TokenizerMode::Bpe);
        assert_eq!(cfg.max_length, 512);
        assert!(cfg.lowercase);
    }
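
    // Added illustrative example (not part of the original suite): struct-update
    // syntax with `TokenizerConfig::default()` overrides a single field, the
    // same pattern `Tokenizer::bpe()` and `Tokenizer::wordpiece()` use internally.
    #[test]
    fn test_tokenizer_config_struct_update() {
        let cfg = TokenizerConfig {
            max_length: 8,
            ..TokenizerConfig::default()
        };
        assert_eq!(cfg.max_length, 8);
        assert_eq!(cfg.mode, TokenizerMode::Bpe);
        assert!(cfg.lowercase);
    }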
}