1use std::collections::HashMap;
7
8use unicode_normalization::UnicodeNormalization;
9
10use crate::gguf::{GgufFile, MetadataValue};
11
/// Family of tokenization algorithm a vocabulary was built for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenizerType {
    /// Byte-pair encoding (GGUF model names `llama`, `bpe`, `gpt2`).
    BPE,
    /// SentencePiece / unigram (GGUF model names `sentencepiece`, `spm`).
    SentencePiece,
    /// WordPiece (GGUF model names `wordpiece`, `bert`).
    WordPiece,
    /// Any model name not recognized above.
    Unknown,
}

impl TokenizerType {
    /// Maps a GGUF `tokenizer.ggml.model` string to a tokenizer family.
    ///
    /// Matching is case-insensitive; unrecognized names map to `Unknown`.
    ///
    /// NOTE(review): `"llama"` is mapped to `BPE`. In this file the type only
    /// steers encoding when a normalizer/pre-tokenizer is configured — confirm
    /// this mapping is intended for llama-style vocabularies.
    pub fn from_gguf_str(s: &str) -> Self {
        const NAMES: &[(&str, TokenizerType)] = &[
            ("llama", TokenizerType::BPE),
            ("bpe", TokenizerType::BPE),
            ("gpt2", TokenizerType::BPE),
            ("sentencepiece", TokenizerType::SentencePiece),
            ("spm", TokenizerType::SentencePiece),
            ("wordpiece", TokenizerType::WordPiece),
            ("bert", TokenizerType::WordPiece),
        ];

        let key = s.to_lowercase();
        NAMES
            .iter()
            .find(|(name, _)| *name == key)
            .map_or(Self::Unknown, |&(_, kind)| kind)
    }
}
37
/// Per-token classification loaded from GGUF `tokenizer.ggml.token_type`.
///
/// `Tokenizer::load_token_types` maps GGUF codes 1 → `Normal`, 2 → `Unknown`,
/// 3 → `Control`, 6 → `Byte`; any other code, or a non-`Int32` entry,
/// becomes `Normal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TokenType {
    /// Ordinary vocabulary piece; also the default.
    #[default]
    Normal,
    /// Control token: matched verbatim in input text when encoding and
    /// skipped when decoding.
    Control,
    /// Byte-fallback token (GGUF code 6).
    Byte,
    /// Unknown-token type (GGUF code 2).
    Unknown,
}
51
52#[derive(Debug, Clone)]
54pub enum Normalizer {
55 NFC,
56 NFKC,
57 NFD,
58 NFKD,
59 Lowercase,
60 Strip { left: bool, right: bool },
61 Prepend(String),
62 Replace { pattern: String, content: String },
63 StripAccents,
64 Sequence(Vec<Normalizer>),
65}
66
67impl Normalizer {
68 fn apply(&self, text: &str) -> String {
69 match self {
70 Self::NFC => text.nfc().collect(),
71 Self::NFKC => text.nfkc().collect(),
72 Self::NFD => text.nfd().collect(),
73 Self::NFKD => text.nfkd().collect(),
74 Self::Lowercase => text.to_lowercase(),
75 Self::Strip { left, right } => {
76 let s = if *left { text.trim_start() } else { text };
77 if *right { s.trim_end().to_string() } else { s.to_string() }
78 }
79 Self::Prepend(prefix) => format!("{}{}", prefix, text),
80 Self::Replace { pattern, content } => text.replace(pattern.as_str(), content.as_str()),
81 Self::StripAccents => {
82 text.nfkd()
83 .filter(|c| !unicode_normalization::char::is_combining_mark(*c))
84 .collect()
85 }
86 Self::Sequence(normalizers) => {
87 let mut result = text.to_string();
88 for n in normalizers {
89 result = n.apply(&result);
90 }
91 result
92 }
93 }
94 }
95}
96
/// Splits normalized text into pieces that are then encoded independently.
#[derive(Debug, Clone)]
pub enum PreTokenizer {
    /// Split before each space, keeping the space attached to the following
    /// piece; optionally prepend a space to text that lacks one.
    ByteLevel { add_prefix_space: bool },
    /// Split on Unicode whitespace, dropping it.
    Whitespace,
    /// Replace spaces with `replacement`, each one glued to the piece that
    /// follows it; optionally prepend a space first.
    Metaspace { replacement: char, add_prefix_space: bool },
    /// Isolate every ASCII punctuation char as its own piece.
    Punctuation,
    /// Separate ASCII digits from other text: one piece per digit when
    /// `individual_digits`, otherwise one piece per run of digits.
    Digits { individual_digits: bool },
    /// Apply each pre-tokenizer to every piece produced by the previous one.
    Sequence(Vec<PreTokenizer>),
}

impl PreTokenizer {
    /// Splits `text` into pre-tokens.
    fn apply(&self, text: &str) -> Vec<String> {
        match self {
            Self::ByteLevel { add_prefix_space } => {
                let text = if *add_prefix_space && !text.starts_with(' ') {
                    format!(" {}", text)
                } else {
                    text.to_string()
                };
                let mut tokens = Vec::new();
                let mut current = String::new();
                for ch in text.chars() {
                    if ch == ' ' && !current.is_empty() {
                        tokens.push(std::mem::take(&mut current));
                    }
                    current.push(ch);
                }
                if !current.is_empty() {
                    tokens.push(current);
                }
                tokens
            }
            Self::Whitespace => text.split_whitespace().map(|s| s.to_string()).collect(),
            Self::Metaspace { replacement, add_prefix_space } => {
                let text = if *add_prefix_space && !text.starts_with(' ') {
                    format!(" {}", text)
                } else {
                    text.to_string()
                };
                // Every space becomes `replacement` glued to the piece after it.
                // The first split part has no space before it. If the text
                // starts with a space, that part is empty and dropped, and the
                // leading space is carried by the second part's replacement.
                // Emitting an extra standalone replacement here would double
                // the word-start marker (" hello" -> ["▁", "▁hello"]).
                text.split(' ')
                    .enumerate()
                    .map(|(i, s)| {
                        if i == 0 {
                            s.to_string()
                        } else {
                            format!("{}{}", replacement, s)
                        }
                    })
                    .filter(|s| !s.is_empty())
                    .collect()
            }
            Self::Punctuation => {
                let mut result = Vec::new();
                let mut current = String::new();
                for ch in text.chars() {
                    if ch.is_ascii_punctuation() {
                        if !current.is_empty() {
                            result.push(std::mem::take(&mut current));
                        }
                        result.push(ch.to_string());
                    } else {
                        current.push(ch);
                    }
                }
                if !current.is_empty() {
                    result.push(current);
                }
                result
            }
            Self::Digits { individual_digits } => {
                let mut result = Vec::new();
                let mut current = String::new();
                let mut current_is_digit = false;
                for ch in text.chars() {
                    let is_digit = ch.is_ascii_digit();
                    if is_digit && *individual_digits {
                        if !current.is_empty() {
                            result.push(std::mem::take(&mut current));
                        }
                        result.push(ch.to_string());
                        continue;
                    }
                    // Grouped mode: start a new piece at each digit/non-digit
                    // boundary, so "abc123def" -> ["abc", "123", "def"].
                    if !current.is_empty() && is_digit != current_is_digit {
                        result.push(std::mem::take(&mut current));
                    }
                    current_is_digit = is_digit;
                    current.push(ch);
                }
                if !current.is_empty() {
                    result.push(current);
                }
                result
            }
            Self::Sequence(pre_tokenizers) => {
                let mut segments = vec![text.to_string()];
                for pt in pre_tokenizers {
                    segments = segments.iter().flat_map(|seg| pt.apply(seg)).collect();
                }
                segments
            }
        }
    }
}
212
/// One element of a `PostProcessor::TemplateProcessing` template.
#[derive(Debug, Clone)]
pub enum TemplateElement {
    /// A fixed special token; `Tokenizer::encode` emits `token_id` here.
    /// NOTE(review): `id` is not read by `encode` — presumably the token's
    /// string form; confirm.
    SpecialToken { id: String, token_id: u32 },
    /// Placeholder that `Tokenizer::encode` replaces with the encoded ids.
    /// NOTE(review): `type_id` is not read by `encode`.
    Sequence { type_id: u32 },
}
219
/// Post-processing applied to an encoded sequence.
///
/// `Tokenizer::encode` applies only `TemplateProcessing::single`, and only
/// when called with `add_bos == false`; `pair` and `ByteLevel` are not read
/// by `encode`.
#[derive(Debug, Clone)]
pub enum PostProcessor {
    /// Wrap the sequence according to a template of special tokens.
    TemplateProcessing {
        single: Vec<TemplateElement>,
        pair: Vec<TemplateElement>,
    },
    /// Byte-level post-processing.
    ByteLevel { trim_offsets: bool },
}
229
/// Ids of the tokens the tokenizer treats specially.
///
/// All four ids are skipped by the decoders (see `Tokenizer::is_special_token`).
#[derive(Debug, Clone)]
pub struct SpecialTokens {
    /// Beginning-of-sequence id; prepended by `encode` when `add_bos` is set.
    pub bos_token_id: u32,
    /// End-of-sequence id.
    pub eos_token_id: u32,
    /// Padding id, if the vocabulary defines one.
    pub pad_token_id: Option<u32>,
    /// Unknown-token id, used as a last-resort substitute when encoding.
    pub unk_token_id: Option<u32>,
}

impl Default for SpecialTokens {
    /// unk = 0, bos = 1, eos = 2, no padding token.
    fn default() -> Self {
        let (unk, bos, eos) = (0, 1, 2);
        Self {
            unk_token_id: Some(unk),
            bos_token_id: bos,
            eos_token_id: eos,
            pad_token_id: None,
        }
    }
}
253
/// Errors produced while loading a tokenizer or encoding/decoding text.
#[derive(thiserror::Error, Debug)]
pub enum TokenizerError {
    /// Required tokenizer data is absent or has an unexpected shape; also
    /// used for file-read failures in `Tokenizer::from_hf_json`.
    /// NOTE(review): the message says "in GGUF" even when raised by the
    /// Hugging Face JSON loaders.
    #[error("Missing tokenizer data in GGUF: {0}")]
    MissingData(String),

    /// A token id outside the vocabulary was passed to a decoder.
    #[error("Invalid token: {0}")]
    InvalidToken(String),

    /// Input could not be processed; `from_hf_json_str` uses it for
    /// unparsable JSON.
    #[error("Encoding error: {0}")]
    EncodingError(String),
}
266
/// Result alias used by the fallible tokenizer operations in this module.
pub type TokenizerResult<T> = Result<T, TokenizerError>;
268
/// Drains and returns as much decodable text as possible from `buf`.
///
/// Well-formed UTF-8 is returned as-is. Each invalid sequence becomes one
/// U+FFFD, with the same granularity as `String::from_utf8_lossy`. Only a
/// trailing *incomplete* sequence (a valid prefix of a multi-byte char) is
/// left in `buf`, so that a later call can complete it.
///
/// Invalid bytes are resolved immediately rather than one per call. This
/// means an invalid lead byte can no longer block the valid text that
/// follows it.
fn flush_valid_utf8(buf: &mut Vec<u8>) -> String {
    let mut out = String::new();
    let mut consumed = 0;

    while consumed < buf.len() {
        match std::str::from_utf8(&buf[consumed..]) {
            Ok(rest) => {
                out.push_str(rest);
                consumed = buf.len();
            }
            Err(err) => {
                let valid = err.valid_up_to();
                let prefix = &buf[consumed..consumed + valid];
                out.push_str(
                    std::str::from_utf8(prefix).expect("prefix reported valid by Utf8Error"),
                );
                consumed += valid;
                match err.error_len() {
                    // A definitely-invalid sequence: replace it and keep going.
                    Some(bad_len) => {
                        out.push('\u{FFFD}');
                        consumed += bad_len;
                    }
                    // Input ended mid-character: keep those bytes for later.
                    None => break,
                }
            }
        }
    }

    buf.drain(..consumed);
    out
}
302
/// Builds the GPT-2 byte ↔ visible-char tables.
///
/// Bytes in `33..=126`, `161..=172` and `174..=255` map to the char with the
/// same code point. The remaining 68 bytes, in ascending order, map to
/// U+0100, U+0101, …. Returns `(char → byte, byte → char)`.
fn build_gpt2_mappings() -> (HashMap<char, u8>, [char; 256]) {
    let maps_to_itself = |b: u8| matches!(b, 33..=126 | 161..=172 | 174..=255);

    let mut byte_to_unicode = ['\0'; 256];
    let mut next_shifted: u32 = 256;
    for b in 0..=255u8 {
        byte_to_unicode[usize::from(b)] = if maps_to_itself(b) {
            char::from(b)
        } else {
            let shifted = char::from_u32(next_shifted).unwrap();
            next_shifted += 1;
            shifted
        };
    }

    let unicode_to_byte: HashMap<char, u8> = (0..=255u8)
        .map(|b| (byte_to_unicode[usize::from(b)], b))
        .collect();

    (unicode_to_byte, byte_to_unicode)
}
336
/// A piece of input produced by `Tokenizer::split_with_special_tokens`.
#[derive(Debug, Clone)]
enum TextSegment {
    /// Ordinary text that still has to be tokenized.
    Text(String),
    /// A control token found verbatim in the input; carries its id.
    SpecialToken(u32),
}
345
/// Text tokenizer built from GGUF metadata (`from_gguf`) or a Hugging Face
/// `tokenizer.json` (`from_hf_json` / `from_hf_json_str`).
#[derive(Debug)]
pub struct Tokenizer {
    /// Vocabulary string → token id.
    token_to_id: HashMap<String, u32>,
    /// Token id → vocabulary string.
    id_to_token: Vec<String>,
    /// Per-token scores (GGUF `tokenizer.ggml.scores`); used by unigram encoding.
    scores: Vec<f32>,
    /// `(left id, right id)` → `(merged id, rank)`; lower rank merges first.
    merges: HashMap<(u32, u32), (u32, usize)>,
    /// BOS/EOS/PAD/UNK ids.
    pub special_tokens: SpecialTokens,
    /// Algorithm family; selects the encoder on the normalizer/pre-tokenizer path.
    pub tokenizer_type: TokenizerType,
    /// Number of tokens loaded.
    pub vocab_size: usize,
    /// Per-token classification (normal/control/byte/unknown).
    token_types: Vec<TokenType>,
    /// Visible char → raw byte; `Some` when the GPT-2 byte-level mapping is used.
    gpt2_unicode_to_byte: Option<HashMap<char, u8>>,
    /// Raw byte → visible char; `Some` when the GPT-2 byte-level mapping is used.
    gpt2_byte_to_unicode: Option<[char; 256]>,
    /// Optional normalizer (left `None` by `from_gguf`; presumably set by the
    /// HF JSON loader — confirm).
    normalizer: Option<Normalizer>,
    /// Optional pre-tokenizer (left `None` by `from_gguf`).
    pre_tokenizer: Option<PreTokenizer>,
    /// Optional post-processor (left `None` by `from_gguf`).
    post_processor: Option<PostProcessor>,
    /// WordPiece continuation prefix (`"##"` from `from_gguf`).
    wordpiece_prefix: String,
    /// Non-empty control-token strings with ids, sorted longest first, so
    /// `split_with_special_tokens` can prefer the longest match.
    control_token_strings: Vec<(String, u32)>,
    /// Whether GGUF metadata explicitly specified a BOS token id.
    pub has_explicit_bos: bool,
    /// Whether the BPE/greedy encoders prepend a space to the input.
    pub add_space_prefix: bool,
}
385
386impl Tokenizer {
387 pub fn from_gguf(gguf: &GgufFile) -> TokenizerResult<Self> {
389 let model_str = gguf
391 .data
392 .get_string("tokenizer.ggml.model")
393 .unwrap_or("bpe");
394 let tokenizer_type = TokenizerType::from_gguf_str(model_str);
395
396 let uses_gpt2_bytes = model_str == "gpt2"
398 || gguf
399 .data
400 .get_string("tokenizer.ggml.pre")
401 .is_some_and(|p| {
402 matches!(
403 p,
404 "qwen2" | "gpt-2" | "gpt2" | "starcoder" | "deepseek-llm" | "deepseek-coder"
405 )
406 });
407
408 let tokens = Self::load_tokens(gguf)?;
410 let vocab_size = tokens.len();
411
412 let mut token_to_id = HashMap::with_capacity(vocab_size);
414 let mut id_to_token = Vec::with_capacity(vocab_size);
415
416 for (id, token) in tokens.into_iter().enumerate() {
417 token_to_id.insert(token.clone(), id as u32);
418 id_to_token.push(token);
419 }
420
421 let scores = Self::load_scores(gguf, vocab_size);
423
424 let token_types = Self::load_token_types(gguf, vocab_size);
426
427 let merges = Self::load_merges(gguf, &token_to_id);
429
430 let special_tokens = Self::load_special_tokens(gguf);
432
433 let (gpt2_unicode_to_byte, gpt2_byte_to_unicode) = if uses_gpt2_bytes {
434 let (u2b, b2u) = build_gpt2_mappings();
435 (Some(u2b), Some(b2u))
436 } else {
437 (None, None)
438 };
439
440 let has_explicit_bos = gguf.data.get_u32("tokenizer.ggml.bos_token_id").is_some();
441 let add_space_prefix = gguf
442 .data
443 .get_bool("tokenizer.ggml.add_space_prefix")
444 .unwrap_or(true);
445
446 let mut control_token_strings: Vec<(String, u32)> = token_types
447 .iter()
448 .enumerate()
449 .filter(|(_, tt)| **tt == TokenType::Control)
450 .filter_map(|(id, _)| {
451 let s = &id_to_token[id];
452 if !s.is_empty() {
453 Some((s.clone(), id as u32))
454 } else {
455 None
456 }
457 })
458 .collect();
459 control_token_strings.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
460
461 Ok(Self {
462 token_to_id,
463 id_to_token,
464 scores,
465 merges,
466 special_tokens,
467 tokenizer_type,
468 vocab_size,
469 token_types,
470 gpt2_unicode_to_byte,
471 gpt2_byte_to_unicode,
472 normalizer: None,
473 pre_tokenizer: None,
474 post_processor: None,
475 wordpiece_prefix: "##".to_string(),
476 control_token_strings,
477 has_explicit_bos,
478 add_space_prefix,
479 })
480 }
481
482 fn load_tokens(gguf: &GgufFile) -> TokenizerResult<Vec<String>> {
484 let tokens_value = gguf
485 .data
486 .metadata
487 .get("tokenizer.ggml.tokens")
488 .ok_or_else(|| TokenizerError::MissingData("tokenizer.ggml.tokens".into()))?;
489
490 match tokens_value {
491 MetadataValue::Array(arr) => {
492 let mut tokens = Vec::with_capacity(arr.values.len());
493 for value in &arr.values {
494 match value {
495 MetadataValue::String(s) => tokens.push(s.clone()),
496 _ => {
497 return Err(TokenizerError::MissingData(
498 "Expected string tokens".into(),
499 ));
500 }
501 }
502 }
503 Ok(tokens)
504 }
505 _ => Err(TokenizerError::MissingData("Expected token array".into())),
506 }
507 }
508
509 fn load_scores(gguf: &GgufFile, vocab_size: usize) -> Vec<f32> {
511 let scores_value = gguf.data.metadata.get("tokenizer.ggml.scores");
512
513 match scores_value {
514 Some(MetadataValue::Array(arr)) => {
515 let mut scores = Vec::with_capacity(arr.values.len());
516 for value in &arr.values {
517 match value {
518 MetadataValue::Float32(f) => scores.push(*f),
519 _ => scores.push(0.0),
520 }
521 }
522 scores
523 }
524 _ => vec![0.0; vocab_size],
525 }
526 }
527
528 fn load_token_types(gguf: &GgufFile, vocab_size: usize) -> Vec<TokenType> {
530 let types_value = gguf.data.metadata.get("tokenizer.ggml.token_type");
531
532 match types_value {
533 Some(MetadataValue::Array(arr)) => {
534 let mut types = Vec::with_capacity(arr.values.len());
535 for value in &arr.values {
536 let token_type = match value {
537 MetadataValue::Int32(t) => match *t {
538 1 => TokenType::Normal,
539 2 => TokenType::Unknown,
540 3 => TokenType::Control,
541 6 => TokenType::Byte,
542 _ => TokenType::Normal,
543 },
544 _ => TokenType::Normal,
545 };
546 types.push(token_type);
547 }
548 types
549 }
550 _ => vec![TokenType::Normal; vocab_size],
551 }
552 }
553
554 fn load_merges(
556 gguf: &GgufFile,
557 token_to_id: &HashMap<String, u32>,
558 ) -> HashMap<(u32, u32), (u32, usize)> {
559 let mut merges = HashMap::new();
560
561 let merges_value = gguf.data.metadata.get("tokenizer.ggml.merges");
562
563 if let Some(MetadataValue::Array(arr)) = merges_value {
564 for (priority, value) in arr.values.iter().enumerate() {
565 if let MetadataValue::String(merge_str) = value {
566 let parts: Vec<&str> = merge_str.split(' ').collect();
568 if parts.len() == 2
569 && let (Some(&id1), Some(&id2)) =
570 (token_to_id.get(parts[0]), token_to_id.get(parts[1]))
571 {
572 let merged = format!("{}{}", parts[0], parts[1]);
574 if let Some(&merged_id) = token_to_id.get(&merged) {
575 merges.insert((id1, id2), (merged_id, priority));
576 }
577 }
578 }
579 }
580 }
581
582 merges
583 }
584
585 fn load_special_tokens(gguf: &GgufFile) -> SpecialTokens {
587 SpecialTokens {
588 bos_token_id: gguf
589 .data
590 .get_u32("tokenizer.ggml.bos_token_id")
591 .unwrap_or(1),
592 eos_token_id: gguf
593 .data
594 .get_u32("tokenizer.ggml.eos_token_id")
595 .unwrap_or(2),
596 pad_token_id: gguf.data.get_u32("tokenizer.ggml.padding_token_id"),
597 unk_token_id: gguf.data.get_u32("tokenizer.ggml.unknown_token_id"),
598 }
599 }
600
601 fn split_with_special_tokens(&self, text: &str) -> Vec<TextSegment> {
607 if self.control_token_strings.is_empty() {
608 return vec![TextSegment::Text(text.to_string())];
609 }
610
611 let mut segments = Vec::new();
612 let mut remaining = text;
613
614 while !remaining.is_empty() {
615 let mut earliest_pos = remaining.len();
616 let mut matched_len = 0;
617 let mut matched_id = 0u32;
618
619 for (tok_str, tok_id) in &self.control_token_strings {
620 if let Some(pos) = remaining.find(tok_str.as_str()) {
621 if pos < earliest_pos
622 || (pos == earliest_pos && tok_str.len() > matched_len)
623 {
624 earliest_pos = pos;
625 matched_len = tok_str.len();
626 matched_id = *tok_id;
627 }
628 }
629 }
630
631 if matched_len == 0 {
632 segments.push(TextSegment::Text(remaining.to_string()));
633 break;
634 }
635
636 if earliest_pos > 0 {
637 segments.push(TextSegment::Text(remaining[..earliest_pos].to_string()));
638 }
639 segments.push(TextSegment::SpecialToken(matched_id));
640 remaining = &remaining[earliest_pos + matched_len..];
641 }
642
643 segments
644 }
645
646 fn encode_text_segment(&self, text: &str) -> TokenizerResult<Vec<u32>> {
648 if text.is_empty() {
649 return Ok(vec![]);
650 }
651 if self.normalizer.is_some() || self.pre_tokenizer.is_some() {
652 let normalized = match &self.normalizer {
653 Some(n) => n.apply(text),
654 None => text.to_string(),
655 };
656 let pre_tokens = match &self.pre_tokenizer {
657 Some(pt) => pt.apply(&normalized),
658 None => vec![normalized],
659 };
660 let mut tokens = Vec::new();
661 for pre_token in &pre_tokens {
662 if pre_token.is_empty() {
663 continue;
664 }
665 match self.tokenizer_type {
666 TokenizerType::SentencePiece => {
667 tokens.extend(self.encode_unigram(pre_token)?);
668 }
669 TokenizerType::WordPiece => {
670 tokens.extend(self.encode_wordpiece(pre_token)?);
671 }
672 _ => {
673 tokens.extend(self.encode_bpe_pretokenized(pre_token)?);
674 }
675 }
676 }
677 Ok(tokens)
678 } else if !self.merges.is_empty() {
679 self.encode_bpe(text)
680 } else {
681 self.encode_sentencepiece(text)
682 }
683 }
684
685 pub fn encode(&self, text: &str, add_bos: bool) -> TokenizerResult<Vec<u32>> {
687 let mut tokens = Vec::new();
688
689 if add_bos {
690 tokens.push(self.special_tokens.bos_token_id);
691 }
692
693 let segments = self.split_with_special_tokens(text);
694 for segment in segments {
695 match segment {
696 TextSegment::Text(t) => {
697 tokens.extend(self.encode_text_segment(&t)?);
698 }
699 TextSegment::SpecialToken(id) => {
700 tokens.push(id);
701 }
702 }
703 }
704
705 if !add_bos {
706 if let Some(PostProcessor::TemplateProcessing { ref single, .. }) = self.post_processor {
707 let mut processed = Vec::new();
708 for elem in single {
709 match elem {
710 TemplateElement::SpecialToken { token_id, .. } => {
711 processed.push(*token_id);
712 }
713 TemplateElement::Sequence { .. } => {
714 processed.extend(&tokens);
715 }
716 }
717 }
718 return Ok(processed);
719 }
720 }
721
722 Ok(tokens)
723 }
724
725 fn encode_sentencepiece(&self, text: &str) -> TokenizerResult<Vec<u32>> {
727 let mut result = Vec::new();
728
729 let text_with_prefix = if self.add_space_prefix {
731 format!(" {}", text)
732 } else {
733 text.to_string()
734 };
735 let chars: Vec<char> = text_with_prefix.chars().collect();
736 let mut pos = 0;
737
738 while pos < chars.len() {
739 let mut best_len = 0;
740 let mut best_id = None;
741
742 for end in (pos + 1..=chars.len()).rev() {
745 let substr: String = chars[pos..end].iter().collect();
746
747 let spm_str = substr.replace(' ', "▁");
749 if let Some(&id) = self.token_to_id.get(&spm_str) {
750 best_len = end - pos;
751 best_id = Some(id);
752 break; }
754
755 if let Some(&id) = self.token_to_id.get(&substr) {
757 best_len = end - pos;
758 best_id = Some(id);
759 break; }
761 }
762
763 if let Some(id) = best_id {
764 result.push(id);
765 pos += best_len;
766 } else {
767 let ch = chars[pos];
769 let ch_str = ch.to_string();
770
771 if ch == ' '
773 && let Some(&id) = self.token_to_id.get("▁")
774 {
775 result.push(id);
776 pos += 1;
777 continue;
778 }
779
780 if let Some(&id) = self.token_to_id.get(&ch_str) {
782 result.push(id);
783 pos += 1;
784 continue;
785 }
786
787 for byte in ch_str.as_bytes() {
789 let byte_token = format!("<0x{:02X}>", byte);
790 if let Some(&id) = self.token_to_id.get(&byte_token) {
791 result.push(id);
792 } else if let Some(unk_id) = self.special_tokens.unk_token_id {
793 result.push(unk_id);
794 }
795 }
796 pos += 1;
797 }
798 }
799
800 Ok(result)
801 }
802
803 fn encode_bpe(&self, text: &str) -> TokenizerResult<Vec<u32>> {
805 if self.gpt2_byte_to_unicode.is_some() {
806 return self.encode_bpe_gpt2(text);
807 }
808
809 let mut result = Vec::new();
810
811 let text_with_prefix = if self.add_space_prefix && !text.starts_with(' ') && !text.is_empty() {
812 format!(" {}", text)
813 } else {
814 text.to_string()
815 };
816
817 for segment in self.split_into_segments(&text_with_prefix) {
818 if segment.is_empty() {
819 continue;
820 }
821
822 if let Some(&id) = self.token_to_id.get(&segment) {
823 result.push(id);
824 continue;
825 }
826
827 let mut tokens = self.text_to_initial_tokens(&segment)?;
828 self.apply_bpe_merges(&mut tokens);
829 result.extend(tokens);
830 }
831
832 Ok(result)
833 }
834
835 fn encode_bpe_gpt2(&self, text: &str) -> TokenizerResult<Vec<u32>> {
840 let b2u = self.gpt2_byte_to_unicode.as_ref().unwrap();
841 let mut result = Vec::new();
842
843 for segment in Self::gpt2_pretokenize(text) {
844 if segment.is_empty() {
845 continue;
846 }
847
848 let mapped: String = segment.as_bytes().iter().map(|&b| b2u[b as usize]).collect();
849
850 if let Some(&id) = self.token_to_id.get(&mapped) {
851 result.push(id);
852 continue;
853 }
854
855 let mut tokens: Vec<u32> = Vec::with_capacity(mapped.len());
856 for ch in mapped.chars() {
857 let ch_str = ch.to_string();
858 if let Some(&id) = self.token_to_id.get(&ch_str) {
859 tokens.push(id);
860 } else if let Some(unk_id) = self.special_tokens.unk_token_id {
861 tokens.push(unk_id);
862 }
863 }
864
865 self.apply_bpe_merges(&mut tokens);
866 result.extend(tokens);
867 }
868
869 Ok(result)
870 }
871
    /// Splits `text` into chunks approximating GPT-2 style pre-tokenization.
    ///
    /// Rules, applied left to right:
    /// - a `' '` starts a chunk. If the next char is alphanumeric or `_`, the
    ///   following run of alphanumeric/`_` chars joins it; otherwise the space
    ///   is a chunk on its own;
    /// - a run of `'\n'`, `'\r'`, `'\t'` forms one chunk;
    /// - a run of alphabetic chars or `_` forms one chunk;
    /// - ASCII digits are grouped into chunks of at most three;
    /// - any other char becomes a single-char chunk.
    ///
    /// NOTE(review): this is a hand-written approximation, not the GPT-2
    /// regex. For example, `" abc1"` stays one chunk while `"abc1"` splits
    /// into `"abc"` and `"1"`.
    fn gpt2_pretokenize(text: &str) -> Vec<String> {
        let mut chunks = Vec::new();
        let chars: Vec<char> = text.chars().collect();
        let mut i = 0;

        while i < chars.len() {
            let ch = chars[i];

            if ch == ' ' {
                // Space + the word (letters, digits, '_') that follows it.
                let mut chunk = String::new();
                chunk.push(ch);
                i += 1;
                if i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_') {
                    while i < chars.len()
                        && !chars[i].is_whitespace()
                        && (chars[i].is_alphanumeric() || chars[i] == '_')
                    {
                        chunk.push(chars[i]);
                        i += 1;
                    }
                }
                chunks.push(chunk);
            } else if ch == '\n' || ch == '\r' || ch == '\t' {
                // A run of line breaks / tabs stays together.
                let mut chunk = String::new();
                while i < chars.len()
                    && (chars[i] == '\n' || chars[i] == '\r' || chars[i] == '\t')
                {
                    chunk.push(chars[i]);
                    i += 1;
                }
                chunks.push(chunk);
            } else if ch.is_alphabetic() || ch == '_' {
                // A word not preceded by a space (digits are NOT included here).
                let mut chunk = String::new();
                while i < chars.len() && (chars[i].is_alphabetic() || chars[i] == '_') {
                    chunk.push(chars[i]);
                    i += 1;
                }
                chunks.push(chunk);
            } else if ch.is_ascii_digit() {
                // Digits in groups of at most three.
                let mut chunk = String::new();
                let mut count = 0;
                while i < chars.len() && chars[i].is_ascii_digit() && count < 3 {
                    chunk.push(chars[i]);
                    i += 1;
                    count += 1;
                }
                chunks.push(chunk);
            } else {
                // Punctuation, other whitespace, symbols: one char per chunk.
                chunks.push(ch.to_string());
                i += 1;
            }
        }

        chunks
    }
932
933 fn apply_bpe_merges(&self, tokens: &mut Vec<u32>) {
935 loop {
936 if tokens.len() < 2 {
937 break;
938 }
939
940 let mut best_merge: Option<(usize, u32, usize)> = None;
941
942 for i in 0..tokens.len() - 1 {
943 let pair = (tokens[i], tokens[i + 1]);
944 if let Some(&(merged_id, priority)) = self.merges.get(&pair)
945 && (best_merge.is_none() || priority < best_merge.unwrap().2)
946 {
947 best_merge = Some((i, merged_id, priority));
948 }
949 }
950
951 match best_merge {
952 Some((pos, merged_id, _)) => {
953 tokens[pos] = merged_id;
954 tokens.remove(pos + 1);
955 }
956 None => break,
957 }
958 }
959 }
960
961 fn split_into_segments(&self, text: &str) -> Vec<String> {
968 let mut segments = Vec::new();
969 let mut current = String::new();
970
971 for ch in text.chars() {
972 if ch.is_whitespace() {
975 if !current.is_empty() {
976 segments.push(current.clone());
977 current.clear();
978 }
979 current.push(ch);
980 } else if ch.is_ascii_punctuation() {
981 if !current.is_empty() {
983 segments.push(current.clone());
984 current.clear();
985 }
986 segments.push(ch.to_string());
987 } else {
988 current.push(ch);
989 }
990 }
991
992 if !current.is_empty() {
993 segments.push(current);
994 }
995
996 segments
997 }
998
999 fn text_to_initial_tokens(&self, text: &str) -> TokenizerResult<Vec<u32>> {
1001 let mut tokens = Vec::new();
1002
1003 for ch in text.chars() {
1004 let ch_str = ch.to_string();
1005
1006 if let Some(&id) = self.token_to_id.get(&ch_str) {
1007 tokens.push(id);
1008 continue;
1009 }
1010
1011 if ch == ' '
1012 && let Some(&id) = self.token_to_id.get("▁")
1013 {
1014 tokens.push(id);
1015 continue;
1016 }
1017
1018 for byte in ch_str.as_bytes() {
1019 let byte_token = format!("<0x{:02X}>", byte);
1020 if let Some(&id) = self.token_to_id.get(&byte_token) {
1021 tokens.push(id);
1022 } else if let Some(unk_id) = self.special_tokens.unk_token_id {
1023 tokens.push(unk_id);
1024 }
1025 }
1026 }
1027
1028 Ok(tokens)
1029 }
1030
1031 #[allow(dead_code)]
1033 fn encode_fallback(&self, text: &str) -> TokenizerResult<Vec<u32>> {
1034 let mut tokens = Vec::new();
1035
1036 for ch in text.chars() {
1037 let ch_str = ch.to_string();
1038 if let Some(&id) = self.token_to_id.get(&ch_str) {
1039 tokens.push(id);
1040 } else {
1041 for byte in ch_str.as_bytes() {
1043 let byte_token = format!("<0x{:02X}>", byte);
1044 if let Some(&id) = self.token_to_id.get(&byte_token) {
1045 tokens.push(id);
1046 } else if let Some(unk_id) = self.special_tokens.unk_token_id {
1047 tokens.push(unk_id);
1048 }
1049 }
1050 }
1051 }
1052
1053 Ok(tokens)
1054 }
1055
1056 fn encode_unigram(&self, text: &str) -> TokenizerResult<Vec<u32>> {
1058 if text.is_empty() {
1059 return Ok(vec![]);
1060 }
1061
1062 let char_boundaries: Vec<usize> = text
1063 .char_indices()
1064 .map(|(i, _)| i)
1065 .chain(std::iter::once(text.len()))
1066 .collect();
1067 let n = char_boundaries.len() - 1;
1068
1069 const NEG_INF: f64 = -1e18;
1070 let mut best_score = vec![NEG_INF; n + 1];
1071 let mut best_path: Vec<Option<(u32, usize)>> = vec![None; n + 1];
1072 best_score[0] = 0.0;
1073
1074 let max_token_chars = 128;
1075
1076 for end in 1..=n {
1077 let end_byte = char_boundaries[end];
1078 let min_start = end.saturating_sub(max_token_chars);
1079
1080 for start in (min_start..end).rev() {
1081 if best_score[start] <= NEG_INF {
1082 continue;
1083 }
1084 let start_byte = char_boundaries[start];
1085 let substr = &text[start_byte..end_byte];
1086
1087 if let Some(&id) = self.token_to_id.get(substr) {
1088 let score = *self.scores.get(id as usize).unwrap_or(&0.0) as f64;
1089 let candidate = best_score[start] + score;
1090 if candidate > best_score[end] {
1091 best_score[end] = candidate;
1092 best_path[end] = Some((id, start));
1093 }
1094 }
1095 }
1096
1097 if best_path[end].is_none() && best_score[end - 1] > NEG_INF {
1099 let start_byte = char_boundaries[end - 1];
1100 let end_byte_val = char_boundaries[end];
1101 let ch_str = &text[start_byte..end_byte_val];
1102
1103 if let Some(&id) = self.token_to_id.get(ch_str) {
1104 let score = *self.scores.get(id as usize).unwrap_or(&-10.0) as f64;
1105 best_score[end] = best_score[end - 1] + score;
1106 best_path[end] = Some((id, end - 1));
1107 } else {
1108 for byte in ch_str.as_bytes() {
1110 let byte_token = format!("<0x{:02X}>", byte);
1111 if let Some(&id) = self.token_to_id.get(&byte_token) {
1112 let score = *self.scores.get(id as usize).unwrap_or(&-10.0) as f64;
1113 let candidate = best_score[end - 1] + score;
1114 if candidate > best_score[end] {
1115 best_score[end] = candidate;
1116 best_path[end] = Some((id, end - 1));
1117 }
1118 }
1119 }
1120 }
1121 }
1122 }
1123
1124 if best_score[n] <= NEG_INF {
1125 return self.encode_unigram_fallback(text);
1126 }
1127
1128 let mut result = Vec::new();
1129 let mut pos = n;
1130 while pos > 0 {
1131 if let Some((token_id, start)) = best_path[pos] {
1132 result.push(token_id);
1133 pos = start;
1134 } else {
1135 break;
1136 }
1137 }
1138 result.reverse();
1139 Ok(result)
1140 }
1141
1142 fn encode_unigram_fallback(&self, text: &str) -> TokenizerResult<Vec<u32>> {
1144 let mut result = Vec::new();
1145 for ch in text.chars() {
1146 let ch_str = ch.to_string();
1147 if let Some(&id) = self.token_to_id.get(&ch_str) {
1148 result.push(id);
1149 } else {
1150 for byte in ch_str.as_bytes() {
1151 let byte_token = format!("<0x{:02X}>", byte);
1152 if let Some(&id) = self.token_to_id.get(&byte_token) {
1153 result.push(id);
1154 } else if let Some(unk_id) = self.special_tokens.unk_token_id {
1155 result.push(unk_id);
1156 }
1157 }
1158 }
1159 }
1160 Ok(result)
1161 }
1162
1163 fn encode_wordpiece(&self, text: &str) -> TokenizerResult<Vec<u32>> {
1165 if text.is_empty() {
1166 return Ok(vec![]);
1167 }
1168
1169 let mut result = Vec::new();
1170 let chars: Vec<char> = text.chars().collect();
1171
1172 let words: Vec<String> = text.split_whitespace().map(|s| s.to_string()).collect();
1174 let words = if words.is_empty() {
1175 vec![text.to_string()]
1176 } else {
1177 words
1178 };
1179
1180 for word in &words {
1181 let word_chars: Vec<char> = word.chars().collect();
1182 if word_chars.len() > 200 {
1183 if let Some(unk_id) = self.special_tokens.unk_token_id {
1184 result.push(unk_id);
1185 }
1186 continue;
1187 }
1188
1189 let mut start = 0;
1190 let mut is_first_subword = true;
1191
1192 while start < word_chars.len() {
1193 let mut end = word_chars.len();
1194 let mut found = false;
1195
1196 while start < end {
1197 let substr: String = word_chars[start..end].iter().collect();
1198 let candidate = if is_first_subword {
1199 substr.clone()
1200 } else {
1201 format!("{}{}", self.wordpiece_prefix, substr)
1202 };
1203
1204 if let Some(&id) = self.token_to_id.get(&candidate) {
1205 result.push(id);
1206 found = true;
1207 break;
1208 }
1209 end -= 1;
1210 }
1211
1212 if !found {
1213 if let Some(unk_id) = self.special_tokens.unk_token_id {
1214 result.push(unk_id);
1215 }
1216 break;
1217 }
1218
1219 start = end;
1220 is_first_subword = false;
1221 }
1222 }
1223
1224 let _ = chars; Ok(result)
1226 }
1227
1228 fn encode_bpe_pretokenized(&self, text: &str) -> TokenizerResult<Vec<u32>> {
1230 if let Some(&id) = self.token_to_id.get(text) {
1231 return Ok(vec![id]);
1232 }
1233
1234 let mut tokens = self.text_to_initial_tokens(text)?;
1235 self.apply_bpe_merges(&mut tokens);
1236 Ok(tokens)
1237 }
1238
1239 pub fn decode(&self, tokens: &[u32]) -> TokenizerResult<String> {
1241 if let Some(ref u2b) = self.gpt2_unicode_to_byte {
1242 return self.decode_gpt2(tokens, u2b);
1243 }
1244 self.decode_sentencepiece(tokens)
1245 }
1246
1247 fn decode_gpt2(
1253 &self,
1254 tokens: &[u32],
1255 unicode_to_byte: &HashMap<char, u8>,
1256 ) -> TokenizerResult<String> {
1257 let mut raw_bytes: Vec<u8> = Vec::new();
1258
1259 for &token_id in tokens {
1260 if self.is_special_token(token_id) {
1261 continue;
1262 }
1263
1264 let token_str = self.id_to_token.get(token_id as usize).ok_or_else(|| {
1265 TokenizerError::InvalidToken(format!("Unknown token ID: {}", token_id))
1266 })?;
1267
1268 if self.get_token_type(token_id) == TokenType::Control {
1270 continue;
1271 }
1272
1273 if token_str.starts_with("<0x")
1275 && token_str.ends_with('>')
1276 && token_str.len() == 6
1277 && let Ok(byte) = u8::from_str_radix(&token_str[3..5], 16)
1278 {
1279 raw_bytes.push(byte);
1280 continue;
1281 }
1282
1283 for ch in token_str.chars() {
1285 if let Some(&b) = unicode_to_byte.get(&ch) {
1286 raw_bytes.push(b);
1287 } else {
1288 let mut buf = [0u8; 4];
1290 let encoded = ch.encode_utf8(&mut buf);
1291 raw_bytes.extend_from_slice(encoded.as_bytes());
1292 }
1293 }
1294 }
1295
1296 Ok(String::from_utf8_lossy(&raw_bytes).into_owned())
1297 }
1298
1299 fn decode_sentencepiece(&self, tokens: &[u32]) -> TokenizerResult<String> {
1301 let mut text = String::new();
1302 let mut byte_buffer: Vec<u8> = Vec::new();
1303
1304 for &token_id in tokens {
1305 if self.is_special_token(token_id) {
1306 continue;
1307 }
1308
1309 if self.get_token_type(token_id) == TokenType::Control {
1310 continue;
1311 }
1312
1313 let token_str = self.id_to_token.get(token_id as usize).ok_or_else(|| {
1314 TokenizerError::InvalidToken(format!("Unknown token ID: {}", token_id))
1315 })?;
1316
1317 if token_str.starts_with("<0x")
1319 && token_str.ends_with('>')
1320 && token_str.len() == 6
1321 && let Ok(byte) = u8::from_str_radix(&token_str[3..5], 16)
1322 {
1323 byte_buffer.push(byte);
1324 continue;
1325 }
1326
1327 if !byte_buffer.is_empty() {
1329 text.push_str(&String::from_utf8_lossy(&byte_buffer));
1330 byte_buffer.clear();
1331 }
1332
1333 text.push_str(&token_str.replace('▁', " "));
1335 }
1336
1337 if !byte_buffer.is_empty() {
1339 text.push_str(&String::from_utf8_lossy(&byte_buffer));
1340 }
1341
1342 Ok(text)
1343 }
1344
1345 pub fn decode_token(&self, token_id: u32) -> TokenizerResult<String> {
1347 self.decode(&[token_id])
1348 }
1349
1350 pub fn decode_token_streaming(
1356 &self,
1357 token_id: u32,
1358 pending: &mut Vec<u8>,
1359 ) -> TokenizerResult<String> {
1360 if self.is_special_token(token_id) || self.get_token_type(token_id) == TokenType::Control {
1361 let flushed = flush_valid_utf8(pending);
1363 return Ok(flushed);
1364 }
1365
1366 let token_str = self.id_to_token.get(token_id as usize).ok_or_else(|| {
1367 TokenizerError::InvalidToken(format!("Unknown token ID: {}", token_id))
1368 })?;
1369
1370 if token_str.starts_with("<0x")
1372 && token_str.ends_with('>')
1373 && token_str.len() == 6
1374 && let Ok(byte) = u8::from_str_radix(&token_str[3..5], 16)
1375 {
1376 pending.push(byte);
1377 return Ok(flush_valid_utf8(pending));
1378 }
1379
1380 if let Some(ref u2b) = self.gpt2_unicode_to_byte {
1381 for ch in token_str.chars() {
1383 if let Some(&b) = u2b.get(&ch) {
1384 pending.push(b);
1385 } else {
1386 let mut buf = [0u8; 4];
1387 let encoded = ch.encode_utf8(&mut buf);
1388 pending.extend_from_slice(encoded.as_bytes());
1389 }
1390 }
1391 Ok(flush_valid_utf8(pending))
1392 } else {
1393 let mut result = flush_valid_utf8(pending);
1395 result.push_str(&token_str.replace('▁', " "));
1396 Ok(result)
1397 }
1398 }
1399
1400 pub fn get_token(&self, id: u32) -> Option<&str> {
1402 self.id_to_token.get(id as usize).map(|s| s.as_str())
1403 }
1404
1405 pub fn get_token_id(&self, token: &str) -> Option<u32> {
1407 self.token_to_id.get(token).copied()
1408 }
1409
1410 pub fn get_token_type(&self, id: u32) -> TokenType {
1412 self.token_types
1413 .get(id as usize)
1414 .copied()
1415 .unwrap_or(TokenType::Normal)
1416 }
1417
1418 pub fn is_special_token(&self, id: u32) -> bool {
1420 id == self.special_tokens.bos_token_id
1421 || id == self.special_tokens.eos_token_id
1422 || self.special_tokens.pad_token_id == Some(id)
1423 || self.special_tokens.unk_token_id == Some(id)
1424 }
1425
1426 pub fn from_hf_json(path: impl AsRef<std::path::Path>) -> TokenizerResult<Self> {
1431 let path = path.as_ref();
1432 let data = std::fs::read_to_string(path)
1433 .map_err(|e| TokenizerError::MissingData(format!("{}: {}", path.display(), e)))?;
1434
1435 Self::from_hf_json_str(&data)
1436 }
1437
1438 pub fn from_hf_json_str(json: &str) -> TokenizerResult<Self> {
1440 let root: serde_json::Value = serde_json::from_str(json)
1441 .map_err(|e| TokenizerError::EncodingError(format!("Invalid tokenizer.json: {}", e)))?;
1442
1443 let model = root
1444 .get("model")
1445 .ok_or_else(|| TokenizerError::MissingData("model section in tokenizer.json".into()))?;
1446
1447 let model_type = model.get("type").and_then(|v| v.as_str()).unwrap_or("BPE");
1448 let tokenizer_type = match model_type {
1449 "BPE" => TokenizerType::BPE,
1450 "Unigram" => TokenizerType::SentencePiece,
1451 "WordPiece" => TokenizerType::WordPiece,
1452 _ => TokenizerType::Unknown,
1453 };
1454
1455 let mut token_to_id = HashMap::new();
1456 let mut id_to_token: Vec<String>;
1457 let mut scores: Vec<f32>;
1458 let mut merges = HashMap::new();
1459 let mut wordpiece_prefix = "##".to_string();
1460 let mut model_unk_token: Option<String> = None;
1461
1462 match tokenizer_type {
1463 TokenizerType::SentencePiece => {
1464 let vocab_arr = model
1466 .get("vocab")
1467 .and_then(|v| v.as_array())
1468 .ok_or_else(|| {
1469 TokenizerError::MissingData("Unigram vocab array".into())
1470 })?;
1471
1472 id_to_token = Vec::with_capacity(vocab_arr.len());
1473 scores = Vec::with_capacity(vocab_arr.len());
1474
1475 for (id, entry) in vocab_arr.iter().enumerate() {
1476 let arr = entry.as_array().ok_or_else(|| {
1477 TokenizerError::MissingData(format!(
1478 "Unigram vocab entry {} not an array",
1479 id
1480 ))
1481 })?;
1482 let token = arr
1483 .first()
1484 .and_then(|v| v.as_str())
1485 .ok_or_else(|| {
1486 TokenizerError::MissingData(format!(
1487 "Unigram vocab entry {} missing token",
1488 id
1489 ))
1490 })?;
1491 let score = arr
1492 .get(1)
1493 .and_then(|v| v.as_f64())
1494 .unwrap_or(0.0) as f32;
1495
1496 token_to_id.insert(token.to_string(), id as u32);
1497 id_to_token.push(token.to_string());
1498 scores.push(score);
1499 }
1500
1501 if let Some(unk_id) = model.get("unk_id").and_then(|v| v.as_u64()) {
1502 model_unk_token = id_to_token.get(unk_id as usize).cloned();
1503 }
1504 }
1505 TokenizerType::WordPiece => {
1506 let vocab_obj = model
1508 .get("vocab")
1509 .and_then(|v| v.as_object())
1510 .ok_or_else(|| {
1511 TokenizerError::MissingData("WordPiece vocab object".into())
1512 })?;
1513
1514 let vocab_size = vocab_obj.len();
1515 id_to_token = vec![String::new(); vocab_size];
1516
1517 for (token, id_val) in vocab_obj {
1518 let id = id_val.as_u64().ok_or_else(|| {
1519 TokenizerError::MissingData(format!("Invalid vocab ID for '{}'", token))
1520 })? as u32;
1521 token_to_id.insert(token.clone(), id);
1522 if (id as usize) < id_to_token.len() {
1523 id_to_token[id as usize] = token.clone();
1524 }
1525 }
1526
1527 if let Some(prefix) = model
1528 .get("continuing_subword_prefix")
1529 .and_then(|v| v.as_str())
1530 {
1531 wordpiece_prefix = prefix.to_string();
1532 }
1533 if let Some(unk) = model.get("unk_token").and_then(|v| v.as_str()) {
1534 model_unk_token = Some(unk.to_string());
1535 }
1536
1537 scores = vec![0.0; id_to_token.len()];
1538 }
1539 _ => {
1540 let vocab_obj = model
1542 .get("vocab")
1543 .and_then(|v| v.as_object())
1544 .ok_or_else(|| {
1545 TokenizerError::MissingData("BPE vocab object".into())
1546 })?;
1547
1548 let vocab_size = vocab_obj.len();
1549 id_to_token = vec![String::new(); vocab_size];
1550
1551 for (token, id_val) in vocab_obj {
1552 let id = id_val.as_u64().ok_or_else(|| {
1553 TokenizerError::MissingData(format!("Invalid vocab ID for '{}'", token))
1554 })? as u32;
1555 token_to_id.insert(token.clone(), id);
1556 if (id as usize) < id_to_token.len() {
1557 id_to_token[id as usize] = token.clone();
1558 }
1559 }
1560
1561 if let Some(merges_arr) = model.get("merges").and_then(|v| v.as_array()) {
1562 for (priority, merge_val) in merges_arr.iter().enumerate() {
1563 let (part0, part1) = if let Some(merge_str) = merge_val.as_str() {
1567 let parts: Vec<&str> = merge_str.split(' ').collect();
1568 if parts.len() == 2 {
1569 (parts[0].to_string(), parts[1].to_string())
1570 } else {
1571 continue;
1572 }
1573 } else if let Some(arr) = merge_val.as_array() {
1574 if arr.len() == 2 {
1575 if let (Some(a), Some(b)) = (arr[0].as_str(), arr[1].as_str()) {
1576 (a.to_string(), b.to_string())
1577 } else {
1578 continue;
1579 }
1580 } else {
1581 continue;
1582 }
1583 } else {
1584 continue;
1585 };
1586
1587 if let (Some(&id1), Some(&id2)) =
1588 (token_to_id.get(&part0), token_to_id.get(&part1))
1589 {
1590 let merged = format!("{}{}", part0, part1);
1591 if let Some(&merged_id) = token_to_id.get(&merged) {
1592 merges.insert((id1, id2), (merged_id, priority));
1593 }
1594 }
1595 }
1596 }
1597
1598 scores = vec![0.0; id_to_token.len()];
1599 }
1600 }
1601
1602 let vocab_size = id_to_token.len();
1603
1604 let mut bos_token_id: Option<u32> = None;
1606 let mut eos_token_id: Option<u32> = None;
1607 let mut pad_token_id: Option<u32> = None;
1608 let mut unk_token_id: Option<u32> = None;
1609
1610 if let Some(added_tokens) = root.get("added_tokens").and_then(|v| v.as_array()) {
1611 for token_obj in added_tokens {
1612 let content = token_obj
1613 .get("content")
1614 .and_then(|v| v.as_str())
1615 .unwrap_or("");
1616 let id = token_obj
1617 .get("id")
1618 .and_then(|v| v.as_u64())
1619 .map(|v| v as u32);
1620 let special = token_obj
1621 .get("special")
1622 .and_then(|v| v.as_bool())
1623 .unwrap_or(false);
1624
1625 if let Some(id) = id {
1626 token_to_id.insert(content.to_string(), id);
1627 if (id as usize) < id_to_token.len() {
1628 id_to_token[id as usize] = content.to_string();
1629 }
1630
1631 if special {
1632 let content_lower = content.to_lowercase();
1633 if content_lower.contains("bos")
1634 || content == "<s>"
1635 || content == "<|begin_of_text|>"
1636 || content == "<|startoftext|>"
1637 {
1638 bos_token_id = Some(id);
1639 }
1640 if content_lower.contains("eos")
1641 || content == "</s>"
1642 || content == "<|end_of_text|>"
1643 || content == "<|endoftext|>"
1644 || content == "<|eot_id|>"
1645 {
1646 eos_token_id = Some(id);
1647 }
1648 if content_lower.contains("pad") || content == "<pad>" {
1649 pad_token_id = Some(id);
1650 }
1651 if content_lower.contains("unk") || content == "<unk>" {
1652 unk_token_id = Some(id);
1653 }
1654 }
1655 }
1656 }
1657 }
1658
1659 if unk_token_id.is_none() {
1661 if let Some(ref unk_str) = model_unk_token {
1662 unk_token_id = token_to_id.get(unk_str).copied();
1663 }
1664 }
1665
1666 if let Some(post_proc) = root.get("post_processor") {
1668 if let Some(special_tokens_map) = post_proc.get("special_tokens") {
1669 if let Some(bos_obj) = special_tokens_map
1670 .get("<s>")
1671 .or_else(|| special_tokens_map.get("<|begin_of_text|>"))
1672 && let Some(ids) = bos_obj.get("ids").and_then(|v| v.as_array())
1673 && let Some(id) = ids.first().and_then(|v| v.as_u64())
1674 {
1675 bos_token_id = bos_token_id.or(Some(id as u32));
1676 }
1677 if let Some(eos_obj) = special_tokens_map
1678 .get("</s>")
1679 .or_else(|| special_tokens_map.get("<|end_of_text|>"))
1680 && let Some(ids) = eos_obj.get("ids").and_then(|v| v.as_array())
1681 && let Some(id) = ids.first().and_then(|v| v.as_u64())
1682 {
1683 eos_token_id = eos_token_id.or(Some(id as u32));
1684 }
1685 }
1686 }
1687
1688 let special_tokens = SpecialTokens {
1689 bos_token_id: bos_token_id.unwrap_or(1),
1690 eos_token_id: eos_token_id.unwrap_or(2),
1691 pad_token_id,
1692 unk_token_id,
1693 };
1694
1695 let mut token_types = vec![TokenType::Normal; vocab_size];
1697 for &id in [special_tokens.bos_token_id, special_tokens.eos_token_id].iter() {
1698 if (id as usize) < token_types.len() {
1699 token_types[id as usize] = TokenType::Control;
1700 }
1701 }
1702 if let Some(pad_id) = special_tokens.pad_token_id
1703 && (pad_id as usize) < token_types.len()
1704 {
1705 token_types[pad_id as usize] = TokenType::Control;
1706 }
1707 if let Some(unk_id) = special_tokens.unk_token_id
1708 && (unk_id as usize) < token_types.len()
1709 {
1710 token_types[unk_id as usize] = TokenType::Control;
1711 }
1712 for (token, &id) in &token_to_id {
1713 if token.starts_with("<0x")
1714 && token.ends_with('>')
1715 && token.len() == 6
1716 && (id as usize) < token_types.len()
1717 {
1718 token_types[id as usize] = TokenType::Byte;
1719 }
1720 }
1721
1722 let uses_byte_level = root
1724 .get("pre_tokenizer")
1725 .and_then(|v| v.get("type").or_else(|| {
1726 v.get("pretokenizers").and_then(|arr| {
1728 arr.as_array().and_then(|a| {
1729 a.iter().find_map(|pt| {
1730 pt.get("type").filter(|t| t.as_str() == Some("ByteLevel"))
1731 })
1732 })
1733 })
1734 }))
1735 .and_then(|v| v.as_str())
1736 .is_some_and(|t| t == "ByteLevel");
1737
1738 let (gpt2_unicode_to_byte, gpt2_byte_to_unicode) = if tokenizer_type == TokenizerType::BPE && uses_byte_level {
1739 let (u2b, b2u) = build_gpt2_mappings();
1740 (Some(u2b), Some(b2u))
1741 } else {
1742 (None, None)
1743 };
1744
1745 let normalizer = root.get("normalizer")
1747 .and_then(|v| if v.is_null() { None } else { Self::parse_normalizer(v) });
1748 let pre_tokenizer = root.get("pre_tokenizer")
1749 .and_then(|v| if v.is_null() { None } else { Self::parse_pre_tokenizer(v) });
1750 let post_processor = root.get("post_processor")
1751 .and_then(|v| if v.is_null() { None } else { Self::parse_post_processor(v, &token_to_id) });
1752
1753 let mut control_token_strings: Vec<(String, u32)> = token_types
1754 .iter()
1755 .enumerate()
1756 .filter(|(_, tt)| **tt == TokenType::Control)
1757 .filter_map(|(id, _)| {
1758 let s = &id_to_token[id];
1759 if !s.is_empty() {
1760 Some((s.clone(), id as u32))
1761 } else {
1762 None
1763 }
1764 })
1765 .collect();
1766 control_token_strings.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
1767
1768 Ok(Self {
1769 token_to_id,
1770 id_to_token,
1771 scores,
1772 merges,
1773 special_tokens,
1774 tokenizer_type,
1775 vocab_size,
1776 token_types,
1777 gpt2_unicode_to_byte,
1778 gpt2_byte_to_unicode,
1779 normalizer,
1780 pre_tokenizer,
1781 post_processor,
1782 wordpiece_prefix,
1783 control_token_strings,
1784 has_explicit_bos: bos_token_id.is_some(),
1785 add_space_prefix: true, })
1787 }
1788
1789 fn parse_normalizer(value: &serde_json::Value) -> Option<Normalizer> {
1790 let type_str = value.get("type")?.as_str()?;
1791 match type_str {
1792 "NFC" => Some(Normalizer::NFC),
1793 "NFKC" => Some(Normalizer::NFKC),
1794 "NFD" => Some(Normalizer::NFD),
1795 "NFKD" => Some(Normalizer::NFKD),
1796 "Lowercase" => Some(Normalizer::Lowercase),
1797 "Strip" => {
1798 let left = value.get("strip_left").and_then(|v| v.as_bool()).unwrap_or(true);
1799 let right = value.get("strip_right").and_then(|v| v.as_bool()).unwrap_or(true);
1800 Some(Normalizer::Strip { left, right })
1801 }
1802 "Prepend" => {
1803 let prepend = value.get("prepend").and_then(|v| v.as_str()).unwrap_or("▁");
1804 Some(Normalizer::Prepend(prepend.to_string()))
1805 }
1806 "Replace" => {
1807 let pattern = value
1808 .get("pattern")
1809 .and_then(|v| v.get("String").and_then(|s| s.as_str()))
1810 .unwrap_or("");
1811 let content = value.get("content").and_then(|v| v.as_str()).unwrap_or("");
1812 Some(Normalizer::Replace {
1813 pattern: pattern.to_string(),
1814 content: content.to_string(),
1815 })
1816 }
1817 "StripAccents" => Some(Normalizer::StripAccents),
1818 "Sequence" => {
1819 let normalizers = value.get("normalizers")?.as_array()?;
1820 let parsed: Vec<Normalizer> = normalizers
1821 .iter()
1822 .filter_map(|v| Self::parse_normalizer(v))
1823 .collect();
1824 if parsed.is_empty() {
1825 None
1826 } else {
1827 Some(Normalizer::Sequence(parsed))
1828 }
1829 }
1830 "BertNormalizer" => {
1831 let mut seq = Vec::new();
1832 if value
1833 .get("lowercase")
1834 .and_then(|v| v.as_bool())
1835 .unwrap_or(true)
1836 {
1837 seq.push(Normalizer::Lowercase);
1838 }
1839 if value
1840 .get("strip_accents")
1841 .and_then(|v| v.as_bool())
1842 .unwrap_or(false)
1843 {
1844 seq.push(Normalizer::StripAccents);
1845 }
1846 match seq.len() {
1847 0 => None,
1848 1 => Some(seq.remove(0)),
1849 _ => Some(Normalizer::Sequence(seq)),
1850 }
1851 }
1852 "Precompiled" => Some(Normalizer::NFC),
1853 _ => None,
1854 }
1855 }
1856
1857 fn parse_pre_tokenizer(value: &serde_json::Value) -> Option<PreTokenizer> {
1858 let type_str = value.get("type")?.as_str()?;
1859 match type_str {
1860 "ByteLevel" => {
1861 let add_prefix_space = value
1862 .get("add_prefix_space")
1863 .and_then(|v| v.as_bool())
1864 .unwrap_or(true);
1865 Some(PreTokenizer::ByteLevel { add_prefix_space })
1866 }
1867 "Whitespace" | "WhitespaceSplit" => Some(PreTokenizer::Whitespace),
1868 "Metaspace" => {
1869 let replacement = value
1870 .get("replacement")
1871 .and_then(|v| v.as_str())
1872 .and_then(|s| s.chars().next())
1873 .unwrap_or('▁');
1874 let add_prefix_space = value
1875 .get("add_prefix_space")
1876 .and_then(|v| v.as_bool())
1877 .unwrap_or(true);
1878 Some(PreTokenizer::Metaspace {
1879 replacement,
1880 add_prefix_space,
1881 })
1882 }
1883 "Punctuation" | "BertPreTokenizer" => Some(PreTokenizer::Punctuation),
1884 "Digits" => {
1885 let individual_digits = value
1886 .get("individual_digits")
1887 .and_then(|v| v.as_bool())
1888 .unwrap_or(false);
1889 Some(PreTokenizer::Digits { individual_digits })
1890 }
1891 "Sequence" => {
1892 let pretokenizers = value.get("pretokenizers")?.as_array()?;
1893 let parsed: Vec<PreTokenizer> = pretokenizers
1894 .iter()
1895 .filter_map(|v| Self::parse_pre_tokenizer(v))
1896 .collect();
1897 if parsed.is_empty() {
1898 None
1899 } else {
1900 Some(PreTokenizer::Sequence(parsed))
1901 }
1902 }
1903 _ => None,
1904 }
1905 }
1906
1907 fn parse_post_processor(
1908 value: &serde_json::Value,
1909 token_to_id: &HashMap<String, u32>,
1910 ) -> Option<PostProcessor> {
1911 let type_str = value.get("type")?.as_str()?;
1912 match type_str {
1913 "TemplateProcessing" => {
1914 let parse_template = |arr: &[serde_json::Value]| -> Vec<TemplateElement> {
1915 arr.iter()
1916 .filter_map(|item| {
1917 if let Some(special) = item.get("SpecialToken") {
1918 let id_str = special.get("id")?.as_str()?;
1919 let token_id = token_to_id.get(id_str).copied()?;
1920 Some(TemplateElement::SpecialToken {
1921 id: id_str.to_string(),
1922 token_id,
1923 })
1924 } else if item.get("Sequence").is_some() {
1925 let type_id = item
1926 .get("Sequence")
1927 .and_then(|s| s.get("id"))
1928 .and_then(|v| v.as_u64())
1929 .unwrap_or(0) as u32;
1930 Some(TemplateElement::Sequence { type_id })
1931 } else {
1932 None
1933 }
1934 })
1935 .collect()
1936 };
1937
1938 let single = value
1939 .get("single")
1940 .and_then(|v| v.as_array())
1941 .map(|a| parse_template(a))
1942 .unwrap_or_default();
1943 let pair = value
1944 .get("pair")
1945 .and_then(|v| v.as_array())
1946 .map(|a| parse_template(a))
1947 .unwrap_or_default();
1948
1949 Some(PostProcessor::TemplateProcessing { single, pair })
1950 }
1951 "ByteLevel" => {
1952 let trim_offsets = value
1953 .get("trim_offsets")
1954 .and_then(|v| v.as_bool())
1955 .unwrap_or(true);
1956 Some(PostProcessor::ByteLevel { trim_offsets })
1957 }
1958 "BertProcessing" => {
1959 let mut single = Vec::new();
1960 let mut pair = Vec::new();
1961
1962 if let Some(cls) = value.get("cls").and_then(|v| v.as_array()) {
1963 if let (Some(token), Some(id)) = (
1964 cls.first().and_then(|v| v.as_str()),
1965 cls.get(1).and_then(|v| v.as_u64()),
1966 ) {
1967 let elem = TemplateElement::SpecialToken {
1968 id: token.to_string(),
1969 token_id: id as u32,
1970 };
1971 single.push(elem.clone());
1972 pair.push(elem);
1973 }
1974 }
1975
1976 single.push(TemplateElement::Sequence { type_id: 0 });
1977 pair.push(TemplateElement::Sequence { type_id: 0 });
1978
1979 if let Some(sep) = value.get("sep").and_then(|v| v.as_array()) {
1980 if let (Some(token), Some(id)) = (
1981 sep.first().and_then(|v| v.as_str()),
1982 sep.get(1).and_then(|v| v.as_u64()),
1983 ) {
1984 let elem = TemplateElement::SpecialToken {
1985 id: token.to_string(),
1986 token_id: id as u32,
1987 };
1988 single.push(elem.clone());
1989 pair.push(elem.clone());
1990 pair.push(TemplateElement::Sequence { type_id: 1 });
1991 pair.push(elem);
1992 }
1993 }
1994
1995 Some(PostProcessor::TemplateProcessing { single, pair })
1996 }
1997 _ => None,
1998 }
1999 }
2000}
2001
// Unit tests covering tokenizer-type parsing, special-token defaults, the GPT-2
// byte mapping, normalizers, pre-tokenizers, and building tokenizers from
// HuggingFace `tokenizer.json` text (Unigram, WordPiece and BPE models).
#[cfg(test)]
mod tests {
    use super::*;

    // GGUF tokenizer model names map onto the tokenizer families.
    #[test]
    fn test_tokenizer_type_parsing() {
        assert_eq!(TokenizerType::from_gguf_str("llama"), TokenizerType::BPE);
        assert_eq!(TokenizerType::from_gguf_str("bpe"), TokenizerType::BPE);
        assert_eq!(
            TokenizerType::from_gguf_str("sentencepiece"),
            TokenizerType::SentencePiece
        );
    }

    // Default special tokens use BOS = 1 and EOS = 2.
    #[test]
    fn test_special_tokens_default() {
        let special = SpecialTokens::default();
        assert_eq!(special.bos_token_id, 1);
        assert_eq!(special.eos_token_id, 2);
    }

    // The unicode->byte table covers all 256 bytes. Printable characters map to
    // themselves, while space, newline and tab map from 'Ġ', 'Ċ' and 'ĉ'.
    #[test]
    fn test_gpt2_unicode_to_byte_table() {
        let (table, _) = build_gpt2_mappings();
        assert_eq!(table.len(), 256);

        assert_eq!(table[&'!'], b'!');
        assert_eq!(table[&'A'], b'A');
        assert_eq!(table[&'~'], b'~');

        // Remapped whitespace bytes.
        assert_eq!(table[&'Ġ'], b' ');
        assert_eq!(table[&'Ċ'], b'\n');
        assert_eq!(table[&'ĉ'], b'\t');
        // Upper Latin-1 range.
        assert_eq!(table[&'¡'], 0xA1);
        assert_eq!(table[&'®'], 0xAE);
        assert_eq!(table[&'ÿ'], 0xFF);
    }

    // Mapping byte-level characters back to bytes reconstructs UTF-8 text,
    // including a multi-byte emoji.
    #[test]
    fn test_gpt2_decode_space_and_emoji() {
        let (table, _) = build_gpt2_mappings();

        let bytes: Vec<u8> = "ĠHello".chars().map(|c| table[&c]).collect();
        assert_eq!(String::from_utf8(bytes).unwrap(), " Hello");

        let bytes: Vec<u8> = "ðŁĺĬ".chars().map(|c| table[&c]).collect();
        let decoded = String::from_utf8(bytes).unwrap();
        assert_eq!(decoded, "😊");
    }

    // NFC composes 'e' + combining acute accent into the single code point 'é'.
    #[test]
    fn test_normalizer_nfc() {
        let norm = Normalizer::NFC;
        let decomposed = "e\u{0301}";
        let result = norm.apply(decomposed);
        assert_eq!(result, "\u{00E9}");
    }

    #[test]
    fn test_normalizer_lowercase() {
        let norm = Normalizer::Lowercase;
        assert_eq!(norm.apply("HELLO World"), "hello world");
    }

    #[test]
    fn test_normalizer_strip_accents() {
        let norm = Normalizer::StripAccents;
        assert_eq!(norm.apply("café"), "cafe");
        assert_eq!(norm.apply("naïve"), "naive");
    }

    // Normalizers in a Sequence are applied in order.
    #[test]
    fn test_normalizer_sequence() {
        let norm = Normalizer::Sequence(vec![
            Normalizer::NFKC,
            Normalizer::Lowercase,
        ]);
        assert_eq!(norm.apply("HÉLLO"), "héllo");
    }

    #[test]
    fn test_normalizer_replace() {
        let norm = Normalizer::Replace {
            pattern: " ".to_string(),
            content: "▁".to_string(),
        };
        assert_eq!(norm.apply("hello world"), "hello▁world");
    }

    #[test]
    fn test_pre_tokenizer_whitespace() {
        let pt = PreTokenizer::Whitespace;
        assert_eq!(pt.apply("Hello world test"), vec!["Hello", "world", "test"]);
    }

    // ByteLevel keeps the leading space on each word. `add_prefix_space`
    // controls whether the first word also gets one.
    #[test]
    fn test_pre_tokenizer_byte_level() {
        let pt = PreTokenizer::ByteLevel { add_prefix_space: true };
        let result = pt.apply("Hello world");
        assert_eq!(result, vec![" Hello", " world"]);

        let pt_no_space = PreTokenizer::ByteLevel { add_prefix_space: false };
        let result = pt_no_space.apply("Hello world");
        assert_eq!(result, vec!["Hello", " world"]);
    }

    // Punctuation is split into its own pieces; whitespace stays attached.
    #[test]
    fn test_pre_tokenizer_punctuation() {
        let pt = PreTokenizer::Punctuation;
        let result = pt.apply("Hello, world!");
        assert_eq!(result, vec!["Hello", ",", " world", "!"]);
    }

    #[test]
    fn test_pre_tokenizer_digits() {
        let pt = PreTokenizer::Digits { individual_digits: true };
        let result = pt.apply("abc123def");
        assert_eq!(result, vec!["abc", "1", "2", "3", "def"]);
    }

    // Whitespace followed by Punctuation yields bare words and punctuation marks.
    #[test]
    fn test_pre_tokenizer_sequence() {
        let pt = PreTokenizer::Sequence(vec![
            PreTokenizer::Whitespace,
            PreTokenizer::Punctuation,
        ]);
        let result = pt.apply("Hello, world!");
        assert_eq!(result, vec!["Hello", ",", "world", "!"]);
    }

    // A Unigram model loads as SentencePiece and keeps its per-token scores.
    #[test]
    fn test_unigram_from_hf_json() {
        let json = r#"{
            "model": {
                "type": "Unigram",
                "unk_id": 0,
                "vocab": [
                    ["<unk>", 0.0],
                    ["▁", -1.0],
                    ["▁the", -2.0],
                    ["▁a", -2.5],
                    ["h", -3.0],
                    ["e", -3.0],
                    ["l", -3.0],
                    ["o", -3.0],
                    ["he", -2.0],
                    ["llo", -2.5]
                ]
            },
            "pre_tokenizer": {
                "type": "Metaspace",
                "replacement": "▁",
                "add_prefix_space": true
            },
            "added_tokens": [
                {"id": 0, "content": "<unk>", "special": true}
            ]
        }"#;

        let tok = Tokenizer::from_hf_json_str(json).unwrap();
        assert_eq!(tok.tokenizer_type, TokenizerType::SentencePiece);
        assert_eq!(tok.vocab_size, 10);
        assert!(tok.scores.iter().any(|&s| s != 0.0));
    }

    // Whole words that are in the WordPiece vocab encode to a single id each.
    #[test]
    fn test_wordpiece_from_hf_json() {
        let json = r###"{
            "model": {
                "type": "WordPiece",
                "unk_token": "[UNK]",
                "continuing_subword_prefix": "##",
                "vocab": {
                    "[UNK]": 0,
                    "[CLS]": 1,
                    "[SEP]": 2,
                    "hello": 3,
                    "world": 4,
                    "he": 5,
                    "##llo": 6,
                    "wo": 7,
                    "##rld": 8
                }
            },
            "normalizer": {
                "type": "BertNormalizer",
                "lowercase": true,
                "strip_accents": false
            },
            "pre_tokenizer": {
                "type": "BertPreTokenizer"
            },
            "added_tokens": [
                {"id": 0, "content": "[UNK]", "special": true},
                {"id": 1, "content": "[CLS]", "special": true},
                {"id": 2, "content": "[SEP]", "special": true}
            ]
        }"###;

        let tok = Tokenizer::from_hf_json_str(json).unwrap();
        assert_eq!(tok.tokenizer_type, TokenizerType::WordPiece);
        assert_eq!(tok.wordpiece_prefix, "##");

        let tokens = tok.encode("hello", false).unwrap();
        assert_eq!(tokens, vec![3]);

        let tokens = tok.encode("hello world", false).unwrap();
        assert_eq!(tokens, vec![3, 4]);
    }

    // A full-word match wins ("unknown" -> 6). Otherwise the word is split into a
    // prefix plus "##" continuation pieces ("thes" -> "the" + "##s").
    #[test]
    fn test_wordpiece_subword_splitting() {
        let json = r###"{
            "model": {
                "type": "WordPiece",
                "unk_token": "[UNK]",
                "continuing_subword_prefix": "##",
                "vocab": {
                    "[UNK]": 0,
                    "[BOS]": 1,
                    "[EOS]": 2,
                    "un": 3,
                    "##know": 4,
                    "##n": 5,
                    "unknown": 6,
                    "the": 7,
                    "##s": 8
                }
            },
            "pre_tokenizer": { "type": "Whitespace" },
            "added_tokens": [
                {"id": 0, "content": "[UNK]", "special": true},
                {"id": 1, "content": "[BOS]", "special": true},
                {"id": 2, "content": "[EOS]", "special": true}
            ]
        }"###;

        let tok = Tokenizer::from_hf_json_str(json).unwrap();

        let tokens = tok.encode("unknown", false).unwrap();
        assert_eq!(tokens, vec![6]);

        let tokens = tok.encode("the", false).unwrap();
        assert_eq!(tokens, vec![7]);

        let tokens = tok.encode("thes", false).unwrap();
        assert_eq!(tokens, vec![7, 8]);
    }

    // Viterbi picks the highest-scoring segmentation. "abc" (-0.1) beats every
    // split: ab + c = -1.5, a + bc = -1.5, a + b + c = -3.0.
    #[test]
    fn test_unigram_viterbi_encoding() {
        let json = r#"{
            "model": {
                "type": "Unigram",
                "unk_id": 0,
                "vocab": [
                    ["<unk>", 0.0],
                    ["<s>", 0.0],
                    ["</s>", 0.0],
                    ["a", -1.0],
                    ["b", -1.0],
                    ["c", -1.0],
                    ["ab", -0.5],
                    ["bc", -0.5],
                    ["abc", -0.1]
                ]
            },
            "pre_tokenizer": { "type": "Whitespace" },
            "added_tokens": [
                {"id": 0, "content": "<unk>", "special": true},
                {"id": 1, "content": "<s>", "special": true},
                {"id": 2, "content": "</s>", "special": true}
            ]
        }"#;

        let tok = Tokenizer::from_hf_json_str(json).unwrap();

        let tokens = tok.encode("abc", false).unwrap();
        assert_eq!(tokens, vec![8]);
    }

    // BPE with a ByteLevel pre-tokenizer applies the merge chain
    // h+e -> he, l+l -> ll, he+ll -> hell, hell+o -> hello (id 9).
    #[test]
    fn test_bpe_with_pipeline() {
        let json = r#"{
            "model": {
                "type": "BPE",
                "vocab": {
                    "<s>": 0,
                    "</s>": 1,
                    "h": 2,
                    "e": 3,
                    "l": 4,
                    "o": 5,
                    "he": 6,
                    "ll": 7,
                    "hell": 8,
                    "hello": 9,
                    " ": 10
                },
                "merges": [
                    "h e",
                    "l l",
                    "he ll",
                    "hell o"
                ]
            },
            "pre_tokenizer": {
                "type": "ByteLevel",
                "add_prefix_space": false
            },
            "added_tokens": [
                {"id": 0, "content": "<s>", "special": true},
                {"id": 1, "content": "</s>", "special": true}
            ]
        }"#;

        let tok = Tokenizer::from_hf_json_str(json).unwrap();
        assert_eq!(tok.tokenizer_type, TokenizerType::BPE);
        assert!(tok.pre_tokenizer.is_some());

        let tokens = tok.encode("hello", false).unwrap();
        assert_eq!(tokens, vec![9]);
    }

    // Normalizer JSON parsing: a plain type, BertNormalizer with two enabled
    // steps (becomes a Sequence), and an explicit Sequence.
    #[test]
    fn test_parse_normalizer_types() {
        let nfc: serde_json::Value = serde_json::from_str(r#"{"type": "NFC"}"#).unwrap();
        let result = Tokenizer::parse_normalizer(&nfc);
        assert!(matches!(result, Some(Normalizer::NFC)));

        let bert: serde_json::Value = serde_json::from_str(
            r#"{"type": "BertNormalizer", "lowercase": true, "strip_accents": true}"#,
        )
        .unwrap();
        let result = Tokenizer::parse_normalizer(&bert);
        assert!(matches!(result, Some(Normalizer::Sequence(_))));

        let seq: serde_json::Value = serde_json::from_str(
            r#"{"type": "Sequence", "normalizers": [{"type": "NFC"}, {"type": "Lowercase"}]}"#,
        )
        .unwrap();
        let result = Tokenizer::parse_normalizer(&seq);
        assert!(matches!(result, Some(Normalizer::Sequence(_))));
    }

    // Pre-tokenizer JSON parsing: ByteLevel and Metaspace flags are read, and a
    // Sequence of supported pre-tokenizers is preserved.
    #[test]
    fn test_parse_pre_tokenizer_types() {
        let bl: serde_json::Value =
            serde_json::from_str(r#"{"type": "ByteLevel", "add_prefix_space": false}"#).unwrap();
        let result = Tokenizer::parse_pre_tokenizer(&bl);
        assert!(matches!(
            result,
            Some(PreTokenizer::ByteLevel { add_prefix_space: false })
        ));

        let meta: serde_json::Value = serde_json::from_str(
            r#"{"type": "Metaspace", "replacement": "▁", "add_prefix_space": true}"#,
        )
        .unwrap();
        let result = Tokenizer::parse_pre_tokenizer(&meta);
        assert!(matches!(
            result,
            Some(PreTokenizer::Metaspace { add_prefix_space: true, .. })
        ));

        let seq: serde_json::Value = serde_json::from_str(
            r#"{"type": "Sequence", "pretokenizers": [{"type": "Whitespace"}, {"type": "Punctuation"}]}"#,
        )
        .unwrap();
        let result = Tokenizer::parse_pre_tokenizer(&seq);
        assert!(matches!(result, Some(PreTokenizer::Sequence(_))));
    }
}