1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::error::Error;
4use std::fmt;
5use std::fmt::{Debug, Display};
6
7use super::{DecodeError, EncodeError, Model};
8use crate::tokenizer::TokenId;
9use rustc_hash::{FxBuildHasher, FxHashMap};
10
/// Errors that can occur while constructing a [`Bpe`] tokenizer model.
#[derive(Debug)]
pub enum BpeError {
    /// An entry in the merge list is invalid: one half of the pair, or the
    /// concatenated pair, was not found in the vocabulary.
    InvalidMergeEntry(String),

    /// An entry in the vocabulary is invalid.
    // NOTE(review): not constructed anywhere in this file — presumably used
    // by callers that load vocabularies; confirm before removing.
    InvalidVocabEntry(String),

    /// A token string required by the model (eg. a single-byte token) is
    /// missing from the vocabulary.
    MissingVocabEntry(String),
}
28
29impl Display for BpeError {
30 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
31 match self {
32 BpeError::InvalidMergeEntry(entry) => write!(fmt, "invalid merge entry: {}", entry),
33 BpeError::InvalidVocabEntry(entry) => write!(fmt, "invalid vocab entry: {}", entry),
34 BpeError::MissingVocabEntry(entry) => write!(fmt, "missing vocab entry: {}", entry),
35 }
36 }
37}
38
// `BpeError` has no underlying source error; the `Display` impl carries the
// full message, so the default trait methods suffice.
impl Error for BpeError {}
40
/// Priority of a merge rule or token. Lower ranks are applied first during
/// the BPE merge loop.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct Rank(u32);
46
/// Borrowed token string in byte-encoded form, where each byte of the
/// original text has been replaced by a printable `char` (see
/// [`byte_to_char`]).
pub type EncodedByteSlice<'a> = &'a str;

/// Owned version of [`EncodedByteSlice`].
pub type EncodedBytes = String;

/// Borrowed-or-owned version of [`EncodedByteSlice`].
pub type EncodedBytesCow<'a> = Cow<'a, str>;
60
/// Return true if `c` renders as a visible glyph on its own: it is not a
/// control character, not whitespace, and not the soft hyphen (U+00AD).
fn is_printable(c: char) -> bool {
    if c.is_control() || c.is_whitespace() {
        return false;
    }
    c != '\u{ad}'
}
68
/// Build the byte => printable-char table used for byte-level BPE.
///
/// Printable bytes map to themselves; each non-printable byte is assigned a
/// consecutive code point starting at U+0100, in byte order. This matches
/// the GPT-2 byte encoder (eg. a space, byte 0x20, maps to 'Ġ', U+0120).
fn byte_to_char() -> [char; 256] {
    // Same predicate as `is_printable`, inlined so the table is
    // self-contained.
    fn printable(c: char) -> bool {
        !c.is_control() && !c.is_whitespace() && c != '\u{ad}'
    }

    let mut table = ['\x00'; 256];
    let mut next_substitute = 256u32;

    for (byte, slot) in table.iter_mut().enumerate() {
        let ch = char::from(byte as u8);
        if printable(ch) {
            *slot = ch;
        } else {
            // 256..512 are all valid scalar values, so this cannot fail.
            *slot = char::from_u32(next_substitute).expect("valid code point");
            next_substitute += 1;
        }
    }

    table
}
94
95pub fn char_to_byte() -> HashMap<char, u8> {
98 byte_to_char()
99 .iter()
100 .copied()
101 .enumerate()
102 .map(|(byte, ch)| (ch, byte as u8))
103 .collect()
104}
105
106fn bpe_merge(
111 tokens: &mut Vec<TokenId>,
112 merges: &FxHashMap<(TokenId, TokenId), (Rank, TokenId)>,
113) -> usize {
114 loop {
115 let min_pair: Option<((TokenId, TokenId), (Rank, TokenId))> = tokens
118 .windows(2)
119 .filter_map(|pair| {
120 let [first, second] = pair.try_into().unwrap();
121 merges
122 .get(&(first, second))
123 .map(|&rank_id| ((first, second), rank_id))
124 })
125 .min_by_key(|((_first, _second), (rank, _merged_id))| *rank);
126
127 let Some(((first, second), (_rank, merged_id))) = min_pair else {
128 break;
129 };
130
131 let mut i = 0;
132 while i < tokens.len() - 1 {
133 if tokens[i] == first && tokens[i + 1] == second {
134 tokens[i] = merged_id;
135 tokens.remove(i + 1);
136 }
137 i += 1;
138 }
139 }
140 tokens.len()
141}
142
/// Map from adjacent token id pair to (merge priority, merged token id).
type MergeMap = FxHashMap<(TokenId, TokenId), (Rank, TokenId)>;
145
146fn build_merge_map(
149 vocab: &FxHashMap<EncodedBytes, TokenId>,
150 merges: &[(EncodedBytesCow, EncodedBytesCow)],
151) -> Result<MergeMap, BpeError> {
152 let mut merged_str = String::new();
153 let mut merge_map = HashMap::with_capacity_and_hasher(merges.len(), FxBuildHasher);
154
155 for (i, (a, b)) in merges.iter().enumerate() {
156 let a_id = *vocab.get(a.as_ref()).ok_or_else(|| {
157 BpeError::InvalidMergeEntry(format!(
158 "first entry in merge pair \"{a} {b}\" not found in vocab"
159 ))
160 })?;
161 let b_id = *vocab.get(b.as_ref()).ok_or_else(|| {
162 BpeError::InvalidMergeEntry(format!(
163 "second entry in merge pair \"{a} {b}\" not found in vocab"
164 ))
165 })?;
166
167 merged_str.clear();
168 merged_str.push_str(a);
169 merged_str.push_str(b);
170
171 let merged_id = *vocab.get(&merged_str).ok_or_else(|| {
172 BpeError::InvalidMergeEntry(format!("merged pair \"{a} {b}\" not found in vocab"))
173 })?;
174 let rank = Rank(i as u32);
175 merge_map.insert((a_id, b_id), (rank, merged_id));
176 }
177
178 Ok(merge_map)
179}
180
181pub fn merge_pairs_from_lines(
185 lines: &[impl AsRef<str>],
186) -> Vec<(EncodedBytesCow<'static>, EncodedBytesCow<'static>)> {
187 lines
188 .iter()
189 .filter_map(|line| {
190 let line = line.as_ref();
191 if line.starts_with("#version") || line.trim().is_empty() {
192 None
193 } else {
194 line.split_once(' ')
197 .map(|(a, b)| (a.to_string().into(), b.to_string().into()))
198 }
199 })
200 .collect()
201}
202
203fn build_vocab(
208 merges: &[(EncodedBytesCow, EncodedBytesCow)],
209 end_of_word_suffix: Option<EncodedByteSlice>,
210) -> FxHashMap<EncodedBytes, TokenId> {
211 let mut vocab = FxHashMap::default();
212
213 fn byte_to_rank() -> [Rank; 256] {
214 let mut ranks = [Rank(0); 256];
215
216 let mut rank = 0;
217 for byte in 0..=255u8 {
218 if is_printable(char::from(byte)) {
219 ranks[byte as usize] = Rank(rank);
220 rank += 1;
221 }
222 }
223
224 for byte in 0..=255u8 {
225 if !is_printable(char::from(byte)) {
226 ranks[byte as usize] = Rank(rank);
227 rank += 1;
228 }
229 }
230
231 ranks
232 }
233
234 for (ch, rank) in byte_to_char().into_iter().zip(byte_to_rank()) {
236 vocab.insert(ch.into(), rank.0);
237 }
238
239 if let Some(eow_suffix) = end_of_word_suffix {
242 let start_id = vocab.len() as u32;
243 for (ch, rank) in byte_to_char().into_iter().zip(byte_to_rank()) {
244 let mut bytes: EncodedBytes = ch.into();
245 bytes.push_str(eow_suffix);
246 vocab.insert(bytes, start_id + rank.0);
247 }
248 }
249
250 let start_id = vocab.len() as u32;
252 vocab.extend(
253 merges
254 .iter()
255 .enumerate()
256 .map(|(i, (a, b))| ([a.as_ref(), b.as_ref()].concat(), start_id + i as u32)),
257 );
258
259 vocab
260}
261
/// Configuration for constructing a [`Bpe`] model.
#[derive(Default)]
pub struct BpeOptions<'a> {
    /// Ordered list of merge pairs. A pair's position in the list is its
    /// priority: earlier merges are applied first.
    pub merges: &'a [(EncodedBytesCow<'a>, EncodedBytesCow<'a>)],

    /// Map from encoded token string to token id. If `None`, a vocabulary is
    /// generated from `merges` (see `build_vocab`).
    pub vocab: Option<FxHashMap<EncodedBytes, TokenId>>,

    /// Special tokens (eg. `<|endoftext|>`) that bypass the BPE algorithm.
    pub added_tokens: FxHashMap<TokenId, String>,

    /// Suffix appended to the encoded form of word-final tokens (eg.
    /// `"</w>"`). An empty string is treated the same as `None`.
    pub end_of_word_suffix: Option<String>,

    /// If true, pieces that exactly match a vocab entry are encoded as that
    /// single token without running the merge algorithm.
    pub ignore_merges: bool,
}
291
/// Byte Pair Encoding tokenizer model.
pub struct Bpe {
    /// Map from adjacent token id pair to (merge priority, merged id).
    merges: MergeMap,

    /// Token id for each byte value; the initial per-byte tokenization of a
    /// piece before merges are applied.
    byte_to_token_id: [TokenId; 256],

    /// Printable char used to represent each byte value in encoded token
    /// strings.
    byte_to_char: [char; 256],

    /// Inverse vocabulary: token id => encoded token string.
    token_id_to_encoded_bytes: FxHashMap<TokenId, EncodedBytes>,

    /// Forward vocabulary (encoded string => id). Only retained when
    /// `ignore_merges` is set, as it is needed only for whole-piece lookups.
    vocab: Option<FxHashMap<EncodedBytes, TokenId>>,

    /// Special tokens that bypass the BPE algorithm.
    added_tokens: FxHashMap<TokenId, String>,

    /// Suffix appended to the encoded form of word-final tokens.
    end_of_word_suffix: Option<String>,

    /// If true, pieces that exactly match a vocab entry are encoded as that
    /// single token without applying merges.
    ignore_merges: bool,
}
334
impl Bpe {
    /// Create a BPE model from configuration.
    ///
    /// If no vocabulary is supplied, one is generated from the merge list.
    /// Returns [`BpeError::MissingVocabEntry`] if any of the 256 single-byte
    /// tokens is absent from the vocabulary, or
    /// [`BpeError::InvalidMergeEntry`] if a merge refers to unknown tokens.
    pub fn new(config: BpeOptions) -> Result<Bpe, BpeError> {
        let BpeOptions {
            merges,
            vocab,
            added_tokens,
            mut end_of_word_suffix,
            ignore_merges,
        } = config;

        // Normalize: an empty suffix behaves the same as no suffix.
        end_of_word_suffix.take_if(|suffix| suffix.is_empty());

        let vocab = vocab.unwrap_or_else(|| build_vocab(merges, end_of_word_suffix.as_deref()));

        let merges = build_merge_map(&vocab, merges)?;

        // Every byte value must have a corresponding single-char token in
        // the vocab, since per-byte ids are the starting point of encoding.
        let mut byte_to_token_id = [0; 256];
        for (i, ch) in byte_to_char().into_iter().enumerate() {
            let mut ch_buf = [0u8; 4];
            let ch_str = ch.encode_utf8(&mut ch_buf);
            if let Some(id) = vocab.get(ch_str).copied() {
                byte_to_token_id[i] = id;
            } else {
                return Err(BpeError::MissingVocabEntry(ch_str.to_string()));
            }
        }

        // The forward vocab is only needed at encode time when
        // `ignore_merges` is set; otherwise it is consumed to build the
        // inverse map and dropped to save memory.
        let (vocab, token_id_to_encoded_bytes) = if ignore_merges {
            let token_id_to_encoded_bytes = vocab
                .iter()
                .map(|(token, id)| (*id, token.clone()))
                .collect();
            (Some(vocab), token_id_to_encoded_bytes)
        } else {
            let token_id_to_encoded_bytes =
                vocab.into_iter().map(|(token, id)| (id, token)).collect();
            (None, token_id_to_encoded_bytes)
        };

        Ok(Bpe {
            added_tokens,
            byte_to_char: byte_to_char(),
            byte_to_token_id,
            end_of_word_suffix,
            ignore_merges,
            merges,
            token_id_to_encoded_bytes,
            vocab,
        })
    }

    /// Encode a single piece (eg. a word produced by pre-tokenization) into
    /// token ids. `end_of_word` indicates whether the piece ends a word,
    /// which matters only when an end-of-word suffix is configured.
    fn encode_piece(&self, piece: &str, end_of_word: bool) -> Vec<TokenId> {
        // Fast path: return a direct vocab match without applying merges.
        if self.ignore_merges
            && let Some(vocab) = self.vocab.as_ref()
        {
            let encoded: EncodedBytes = piece
                .as_bytes()
                .iter()
                .map(|&b| self.byte_to_char[b as usize])
                .collect();
            if let Some(&id) = vocab.get(&encoded) {
                return [id].into();
            }
        }

        // Start from one token per byte, then merge.
        let mut tokens: Vec<TokenId> = piece
            .as_bytes()
            .iter()
            .map(|&b| self.byte_to_token_id[b as usize])
            .collect();

        if self.end_of_word_suffix.is_some()
            && end_of_word
            && let Some(last) = tokens.pop()
        {
            // Assumes the end-of-word variant of a byte token has id
            // `base id + 256`, which holds for vocabs from `build_vocab`.
            // NOTE(review): may not hold for user-supplied vocabs — confirm.
            tokens.push(last + 256);
        }

        bpe_merge(&mut tokens, &self.merges);

        tokens
    }
}
436
437impl Model for Bpe {
438 fn get_token_str(&self, id: TokenId) -> Option<String> {
439 if let Some(tok_str) = self.added_tokens.get(&id) {
440 return Some(tok_str.to_string());
441 }
442 self.token_id_to_encoded_bytes.get(&id).cloned()
443 }
444
445 fn get_token_id(&self, mut text: &str) -> Option<TokenId> {
446 if let Some((&id, _str)) = self.added_tokens.iter().find(|(_id, str)| *str == text) {
447 return Some(id);
448 }
449
450 let mut end_of_word = false;
454 if let Some(suffix) = self.end_of_word_suffix.as_deref()
455 && text.ends_with(suffix)
456 {
457 text = &text[..text.len() - suffix.len()];
458 end_of_word = true;
459 }
460
461 let tokens = self.encode_piece(text, end_of_word);
462 if tokens.len() == 1 {
463 Some(tokens[0])
464 } else {
465 None
466 }
467 }
468
469 fn encode_with_offsets(
470 &self,
471 piece: &str,
472 on_token: &mut dyn FnMut(usize, TokenId),
473 ) -> Result<(), EncodeError> {
474 if piece.is_empty() {
475 return Ok(());
476 }
477 for token in self.encode_piece(piece, true ) {
478 on_token(0, token)
479 }
480 Ok(())
481 }
482
483 fn decode(&self, ids: &[TokenId]) -> Result<String, DecodeError> {
484 let char_to_byte = char_to_byte();
485
486 let mut bytes = Vec::new();
487 for &id in ids {
488 if let Some(tok_str) = self.added_tokens.get(&id) {
489 bytes.extend(tok_str.as_bytes());
490 } else if let Some(encoded_bytes) = self.token_id_to_encoded_bytes.get(&id) {
491 bytes.extend(
492 encoded_bytes
493 .chars()
494 .map(|ch| char_to_byte.get(&ch).copied().unwrap()),
495 );
496 } else {
497 return Err(DecodeError::InvalidTokenId(id));
498 }
499 }
500
501 String::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8)
502 }
503}
504
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;
    use rustc_hash::FxHashMap;

    use super::{Bpe, BpeOptions, EncodedBytes, merge_pairs_from_lines};
    use crate::pre_tokenizers::Split;
    use crate::tokenizer::{TokenId, Tokenizer};

    // Truncated version of the GPT-2 merge list, in the text format used by
    // `merges.txt` files (one space-separated pair per line).
    const MINI_GPT2: &str = "
#version: 0.2
Ġ t
Ġ a
h e
i n
r e
o n
Ġt he
e r
Ġ s
a t
Ġ w
Ġ o
e n
Ġ c
i t
i s
a n
o r
e s
Ġ b
e d
Ġ f
in g";

    // A single added token mirroring GPT-2's end-of-text marker.
    fn added_tokens() -> FxHashMap<TokenId, String> {
        [(50256, "<|endoftext|>")]
            .into_iter()
            .map(|(id, str)| (id, str.to_string()))
            .collect()
    }

    // Generate a vocabulary from `MINI_GPT2`: the 256 single-char byte
    // tokens first, then one token per merge entry, with ids from 1000.
    fn gen_vocab() -> FxHashMap<EncodedBytes, TokenId> {
        let mut next_token_id = 1000;
        let mut vocab = minimal_vocab(next_token_id);
        next_token_id += vocab.len() as u32;

        for line in MINI_GPT2.lines().map(|l| l.trim()) {
            if line.starts_with("#version") || line.is_empty() {
                continue;
            }
            let token_str: EncodedBytes = line.chars().filter(|ch| *ch != ' ').collect();
            vocab.insert(token_str, next_token_id);
            next_token_id += 1;
        }

        vocab
    }

    // Vocabulary containing only the 256 single-character byte tokens.
    fn minimal_vocab(start_token_id: u32) -> FxHashMap<EncodedBytes, TokenId> {
        let mut vocab = FxHashMap::default();
        let mut next_token_id = start_token_id;
        for ch in super::char_to_byte().keys() {
            vocab.insert(ch.to_string(), next_token_id);
            next_token_id += 1;
        }
        vocab
    }

    #[test]
    fn test_encode() {
        #[derive(Debug)]
        struct Case<'a> {
            text: &'a str,
            expected_tokens: &'a [&'a str],
            merges: &'a str,
            vocab: Option<FxHashMap<EncodedBytes, TokenId>>,
            end_of_word_suffix: Option<String>,
            ignore_merges: bool,
        }

        impl<'a> Default for Case<'a> {
            fn default() -> Self {
                Self {
                    text: "",
                    expected_tokens: &[],
                    merges: "",
                    vocab: None,
                    end_of_word_suffix: None,
                    ignore_merges: false,
                }
            }
        }

        let cases = [
            // Encoding with the truncated GPT-2 merge list.
            Case {
                text: "the cat is in the bed",
                expected_tokens: &[
                    "t", "he", "Ġc", "at", "Ġ", "is", "Ġ", "in", "Ġthe", "Ġb", "ed",
                ],
                merges: MINI_GPT2,
                ..Default::default()
            },
            // Merged tokens that are themselves inputs to later merges.
            Case {
                text: "--------",
                expected_tokens: &["--------"],
                merges: "
- -
-- --
---- ----
-------- --------
",
                ..Default::default()
            },
            // Encoding with an end-of-word suffix.
            Case {
                text: "barbar",
                expected_tokens: &["bar", "bar</w>"],
                merges: "
b a
ba r
ba r</w>
",
                end_of_word_suffix: Some("</w>".to_string()),
                ..Default::default()
            },
            // An empty end-of-word suffix must behave the same as `None`.
            Case {
                text: "barbar",
                expected_tokens: &["bar", "bar"],
                merges: "
b a
ba r",
                end_of_word_suffix: Some("".to_string()),
                ..Default::default()
            },
            // `ignore_merges` returns whole-piece vocab matches directly.
            Case {
                text: "foobar",
                expected_tokens: &["foobar"],
                ignore_merges: true,
                vocab: {
                    let mut vocab = minimal_vocab(0);
                    vocab.insert("foobar".to_string(), vocab.len() as u32);
                    Some(vocab)
                },
                ..Default::default()
            },
        ];

        cases.test_each(|case| {
            let Case {
                text,
                expected_tokens: tokens,
                merges,
                vocab,
                end_of_word_suffix,
                ignore_merges,
            } = case;

            let merges: Vec<&str> = merges.lines().collect();
            let merge_pairs = merge_pairs_from_lines(&merges);
            let bpe_opts = BpeOptions {
                merges: &merge_pairs,
                vocab: vocab.clone(),
                end_of_word_suffix: end_of_word_suffix.clone(),
                ignore_merges: *ignore_merges,
                added_tokens: Default::default(),
            };
            let model = Bpe::new(bpe_opts).unwrap();
            let tokenizer = Tokenizer::new(model, Default::default())
                .with_pre_tokenizer(Box::new(Split::gpt2()));
            let encoded = tokenizer.encode(*text, None).unwrap();
            assert_eq!(
                tokenizer.model().get_tokens(encoded.token_ids()).unwrap(),
                *tokens
            );
        })
    }

    #[test]
    fn test_get_token_str() {
        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            encoded_str: &'a str,
        }

        let cases = [
            // Printable ASCII chars are their own encoded form.
            Case {
                input: "a",
                encoded_str: "a",
            },
            // Non-printable bytes map to substitute chars (space => 'Ġ').
            Case {
                input: " ",
                encoded_str: "Ġ",
            },
            // Added tokens round-trip as their literal text.
            Case {
                input: "<|endoftext|>",
                encoded_str: "<|endoftext|>",
            },
        ];

        let merges: Vec<&str> = MINI_GPT2.lines().collect();
        let merge_pairs = merge_pairs_from_lines(&merges);

        cases.test_each(|case| {
            let bpe_opts = BpeOptions {
                merges: &merge_pairs,
                added_tokens: added_tokens(),
                ..Default::default()
            };
            let model = Bpe::new(bpe_opts).unwrap();
            let tokenizer = Tokenizer::new(model, Default::default())
                .with_pre_tokenizer(Box::new(Split::gpt2()));

            let tok_id = tokenizer.model().get_token_id(case.input).unwrap();
            let token_str = tokenizer.model().get_token_str(tok_id).unwrap();
            assert_eq!(token_str, case.encoded_str);
        })
    }

    #[test]
    fn test_decode() {
        #[derive(Debug)]
        struct Case<'a> {
            text: &'a str,
            add_eos: bool,
            expected: &'a str,
            vocab: Option<FxHashMap<EncodedBytes, TokenId>>,
        }

        let vocab = gen_vocab();

        let cases = [
            Case {
                text: "foo bar",
                add_eos: false,
                expected: "foo bar",
                vocab: None,
            },
            Case {
                text: "foo bar",
                add_eos: true,
                expected: "foo bar<|endoftext|>",
                vocab: None,
            },
            Case {
                text: "the cat is in the bed",
                add_eos: false,
                expected: "the cat is in the bed",
                vocab: None,
            },
            // Same input, but with an explicit (rather than generated) vocab.
            Case {
                text: "the cat is in the bed",
                add_eos: false,
                expected: "the cat is in the bed",
                vocab: Some(vocab),
            },
        ];

        cases.test_each(|case| {
            let Case {
                text,
                add_eos,
                expected,
                vocab,
            } = case;

            let merges: Vec<&str> = MINI_GPT2.lines().collect();
            let merge_pairs = merge_pairs_from_lines(&merges);
            let bpe_opts = BpeOptions {
                merges: &merge_pairs,
                vocab: vocab.clone(),
                added_tokens: added_tokens(),
                ..Default::default()
            };
            let model = Bpe::new(bpe_opts).unwrap();
            let tokenizer = Tokenizer::new(model, Default::default())
                .with_pre_tokenizer(Box::new(Split::gpt2()));

            let encoded = tokenizer.encode(*text, None).unwrap();
            let mut token_ids = encoded.token_ids().to_vec();
            if *add_eos {
                // Id of the `<|endoftext|>` added token.
                token_ids.push(50256);
            }
            let decoded = tokenizer.decode(&token_ids).unwrap();
            assert_eq!(decoded, *expected);
        })
    }
}