1use fancy_regex::Regex as FancyRegex;
37use regex::Regex as FastRegex;
38use regex::RegexBuilder as FastRegexBuilder;
39use regex_automata::{
40 Input,
41 dfa::{dense, regex::Regex as DfaRegex},
42 nfa::thompson,
43 util::syntax,
44};
45use rustc_hash::{FxHashMap as HashMap, FxHasher};
46use std::collections::HashSet;
47use std::hash::{Hash, Hasher};
48
49#[cfg(feature = "python")]
50use pyo3::prelude::*;
51
#[cfg(feature = "precompiled-dfa")]
mod prebuilt {
    //! Pre-serialized dense DFAs for the stock tiktoken split patterns,
    //! generated at build time into `OUT_DIR` and deserialized here instead
    //! of being recompiled from the pattern text at startup.
    use super::*;

    // Raw (untransformed) pattern strings. `try_load` keys on exact string
    // equality with these, so they must stay byte-identical to the patterns
    // the build script compiled.
    const GPT2_RAW: &str =
        r"'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s";
    const CL100K_RAW: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s";
    const O200K_RAW: &str = concat!(
        r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        r"|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        r"|\p{N}{1,3}",
        r"| ?[^\s\p{L}\p{N}]+[\r\n/]*",
        r"|\s*[\r\n]+",
        r"|\s+(?!\S)|\s+",
    );

    // Forces `bytes` to the alignment of `Align` so the embedded DFA bytes
    // can be handed to `dense::DFA::from_bytes`, which needs the `u32`
    // alignment this wrapper provides.
    #[repr(C)]
    struct AlignAs<Align, Bytes: ?Sized> {
        _align: [Align; 0],
        bytes: Bytes,
    }

    // Embeds a build-time DFA artifact with 4-byte alignment.
    macro_rules! include_dfa {
        ($path:expr) => {{
            const ALIGNED: &AlignAs<u32, [u8]> = &AlignAs {
                _align: [],
                bytes: *include_bytes!($path),
            };
            &ALIGNED.bytes
        }};
    }

    static GPT2_FWD: &[u8] = include_dfa!(concat!(env!("OUT_DIR"), "/gpt2_fwd.dfa"));
    static GPT2_REV: &[u8] = include_dfa!(concat!(env!("OUT_DIR"), "/gpt2_rev.dfa"));
    static CL100K_FWD: &[u8] = include_dfa!(concat!(env!("OUT_DIR"), "/cl100k_fwd.dfa"));
    static CL100K_REV: &[u8] = include_dfa!(concat!(env!("OUT_DIR"), "/cl100k_rev.dfa"));
    static O200K_FWD: &[u8] = include_dfa!(concat!(env!("OUT_DIR"), "/o200k_fwd.dfa"));
    static O200K_REV: &[u8] = include_dfa!(concat!(env!("OUT_DIR"), "/o200k_rev.dfa"));

    /// Returns a ready-to-use DFA regex plus the whitespace-shrink mode for
    /// `pattern`, or `None` when the pattern is not one of the three stock
    /// patterns above, or when a serialized DFA fails validation.
    pub(crate) fn try_load(pattern: &str) -> Option<(DfaRegex, ShrinkMode)> {
        let (fwd_bytes, rev_bytes, shrink_mode) = if pattern == GPT2_RAW {
            (GPT2_FWD, GPT2_REV, ShrinkMode::Unified)
        } else if pattern == CL100K_RAW {
            (CL100K_FWD, CL100K_REV, ShrinkMode::PlainOnly)
        } else if pattern == O200K_RAW {
            (O200K_FWD, O200K_REV, ShrinkMode::PlainOnly)
        } else {
            return None;
        };
        // `from_bytes` validates the serialized automata; corruption or an
        // endianness mismatch degrades gracefully to `None`.
        let (fwd, _) = dense::DFA::from_bytes(fwd_bytes).ok()?;
        let (rev, _) = dense::DFA::from_bytes(rev_bytes).ok()?;
        Some((
            DfaRegex::builder().build_from_dfas(fwd.to_owned(), rev.to_owned()),
            shrink_mode,
        ))
    }
}
119
/// Token id type used throughout the tokenizer.
pub type Rank = u32;

/// Number of per-thread clone slots; thread ids hash into
/// `0..MAX_NUM_THREADS`, so distinct threads may share a slot (harmless —
/// the clones are only ever read through `&self`).
const MAX_NUM_THREADS: usize = 128;

/// Piece length (in bytes) at which BPE switches from the quadratic
/// vec-based merge to the heap-based merge.
const LARGE_PIECE_THRESHOLD: usize = 500;
136
thread_local! {
    // Lazily computed slot index for this thread: hash the ThreadId with
    // FxHasher and reduce modulo MAX_NUM_THREADS. Two threads may collide
    // on a slot; that is acceptable because slots are read-only.
    static THREAD_INDEX: usize = {
        let mut h = FxHasher::default();
        std::thread::current().id().hash(&mut h);
        (h.finish() as usize) % MAX_NUM_THREADS
    };
}
144
145#[inline]
146fn thread_index() -> usize {
147 THREAD_INDEX.with(|&i| i)
148}
149
/// Errors that can occur while constructing a `CoreBPE`.
#[derive(Debug)]
pub enum BuildError {
    /// The split pattern failed to compile; carries the regex error text.
    InvalidRegex(String),
    /// Two vocabulary entries share a token id, so the inverted
    /// (decoder) map would lose entries.
    VocabularyMismatch,
}
160
161impl std::fmt::Display for BuildError {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 match self {
164 BuildError::InvalidRegex(e) => write!(f, "invalid regex pattern: {e}"),
165 BuildError::VocabularyMismatch => write!(
166 f,
167 "vocabulary has duplicate entries (encoder/decoder size mismatch)"
168 ),
169 }
170 }
171}
172
173impl std::error::Error for BuildError {}
174
175impl From<fancy_regex::Error> for BuildError {
176 fn from(e: fancy_regex::Error) -> Self {
177 BuildError::InvalidRegex(e.to_string())
178 }
179}
180
/// Errors that can occur while decoding token ids back to text.
#[derive(Debug)]
pub enum DecodeError {
    /// The id is in neither the ordinary nor the special vocabulary.
    InvalidToken(Rank),
    /// The concatenated token bytes are not valid UTF-8.
    InvalidUtf8,
}
189
190impl std::fmt::Display for DecodeError {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 match self {
193 DecodeError::InvalidToken(t) => write!(f, "invalid token id: {t}"),
194 DecodeError::InvalidUtf8 => write!(f, "decoded bytes are not valid UTF-8"),
195 }
196 }
197}
198
199impl std::error::Error for DecodeError {}
200
/// How matched whitespace runs are shortened to emulate the `\s+(?!\S)`
/// lookahead that `transform_pattern` strips from the pattern.
#[derive(Clone, Copy, PartialEq, Debug)]
enum ShrinkMode {
    /// Pattern had no whitespace lookahead; never shrink.
    None,
    /// Shrink only runs of non-newline whitespace (used when the pattern
    /// handles `\r`/`\n` via a separate `\s*[\r\n]` alternative).
    PlainOnly,
    /// Shrink any whitespace run, newlines included.
    Unified,
}
224
/// Rewrites a tiktoken-style pattern into one the non-backtracking engines
/// accept, returning the rewritten pattern plus the `ShrinkMode` needed to
/// emulate the removed `\s+(?!\S)` lookahead at match time.
///
/// Returns `None` when lookaround remains after the known whitespace idioms
/// are stripped — those patterns must fall back to `fancy_regex`.
fn transform_pattern(pattern: &str) -> Option<(String, ShrinkMode)> {
    // The lookahead idiom, combined with whether a separate `\s*[\r\n]`
    // alternative exists, decides which whitespace runs `apply_shrink`
    // later shortens.
    let has_lookahead_ws = pattern.contains(r"\s+(?!\S)");
    let has_newline_alt = pattern.contains(r"\s*[\r\n]");
    let shrink_mode = if has_lookahead_ws {
        if has_newline_alt {
            ShrinkMode::PlainOnly
        } else {
            ShrinkMode::Unified
        }
    } else {
        ShrinkMode::None
    };
    // Collapse the two known lookahead idioms into a plain `\s+`. Note the
    // second form widens a trailing `|\s` to `\s+`, deliberately
    // over-matching runs that `apply_shrink` then trims back.
    let mut stripped = pattern.replace(r"\s+(?!\S)|\s+", r"\s+");
    stripped = stripped.replace(r"\s+(?!\S)|\s", r"\s+");
    // Any remaining lookaround is unsupported by the automaton engines.
    if stripped.contains("(?=")
        || stripped.contains("(?!")
        || stripped.contains("(?<=")
        || stripped.contains("(?<!")
    {
        return None;
    }
    // Downgrade possessive quantifiers (`?+`, `++`, `*+`, `{m,n}+`) to
    // greedy ones; the automaton engines do not backtrack, so greedy
    // matching should behave identically here — NOTE(review): confirm for
    // any future pattern that nests these quantifiers unusually.
    stripped = stripped
        .replace("?+", "?")
        .replace("++", "+")
        .replace("*+", "*");
    let range_possessive = FastRegex::new(r"(\{\d+(?:,\d*)?\})\+").ok()?;
    let stripped = range_possessive.replace_all(&stripped, "$1").into_owned();
    Some((stripped, shrink_mode))
}
303
304fn try_transform_for_fast_regex(pattern: &str) -> Option<(FastRegex, ShrinkMode)> {
312 let (transformed, shrink_mode) = transform_pattern(pattern)?;
313 let regex = FastRegexBuilder::new(&transformed)
314 .dfa_size_limit(32 * (1 << 20))
315 .build()
316 .ok()?;
317 Some((regex, shrink_mode))
318}
319
320fn try_build_precompiled_dfa(pattern: &str) -> Option<(DfaRegex, ShrinkMode)> {
330 let (transformed, shrink_mode) = transform_pattern(pattern)?;
331 let dfa = DfaRegex::builder()
332 .syntax(syntax::Config::new().unicode(true).utf8(true))
333 .thompson(thompson::Config::new())
334 .dense(dense::Config::new().start_kind(regex_automata::dfa::StartKind::Unanchored))
335 .build(&transformed)
336 .ok()?;
337 Some((dfa, shrink_mode))
338}
339
/// True when `s` is a non-empty run of whitespace containing no `\r`/`\n`.
#[inline]
fn is_plain_whitespace_run(s: &str) -> bool {
    if s.is_empty() {
        return false;
    }
    s.chars()
        .all(|c| c.is_whitespace() && !matches!(c, '\n' | '\r'))
}
354
/// True when `s` is non-empty and consists solely of whitespace
/// (newlines included).
#[inline]
fn is_whitespace_run(s: &str) -> bool {
    match s.chars().next() {
        None => false,
        Some(_) => s.chars().all(char::is_whitespace),
    }
}
365
/// True when a char starts at byte offset `pos` and it is not whitespace.
/// `pos` must lie on a char boundary (callers pass match boundaries).
#[inline]
fn next_char_is_non_whitespace(text: &str, pos: usize) -> bool {
    text[pos..]
        .chars()
        .next()
        .map_or(false, |c| !c.is_whitespace())
}
375
376#[inline]
382fn apply_shrink(text: &str, start: usize, end: usize, shrink_mode: ShrinkMode) -> usize {
383 let piece = &text[start..end];
384 let should_shrink = match shrink_mode {
385 ShrinkMode::None => false,
386 ShrinkMode::PlainOnly => is_plain_whitespace_run(piece),
387 ShrinkMode::Unified => is_whitespace_run(piece),
388 };
389 if should_shrink && end < text.len() && next_char_is_non_whitespace(text, end) {
390 if let Some((last_i, _)) = piece.char_indices().next_back() {
391 if last_i > 0 {
392 return start + last_i;
393 }
394 }
395 }
396 end
397}
398
/// The piece-splitting backend, chosen at construction time in descending
/// order of speed (see `SplitEngine::new`).
enum SplitEngine {
    /// Fully-compiled dense DFA, either prebuilt at build time or built at
    /// startup from a transformed pattern.
    PrecompiledDfa {
        dfa_regex: DfaRegex,
        shrink_mode: ShrinkMode,
    },
    /// The `regex` crate on a transformed (lookaround-free) pattern; one
    /// clone per thread slot (indexed via `thread_index`).
    Fast {
        clones: Vec<FastRegex>,
        shrink_mode: ShrinkMode,
    },
    /// Backtracking fallback that runs the original pattern verbatim,
    /// lookarounds and possessive quantifiers included.
    Fancy(Vec<FancyRegex>),
}
417
impl SplitEngine {
    /// Picks the fastest backend able to execute `pattern`: a prebuilt DFA
    /// (feature-gated, stock patterns only), then a DFA built at startup,
    /// then the fast `regex` engine, and finally backtracking `fancy_regex`.
    ///
    /// # Errors
    /// `BuildError::InvalidRegex` when even `fancy_regex` rejects `pattern`.
    fn new(pattern: &str) -> Result<Self, BuildError> {
        #[cfg(feature = "precompiled-dfa")]
        if let Some((dfa_regex, shrink_mode)) = prebuilt::try_load(pattern) {
            return Ok(SplitEngine::PrecompiledDfa {
                dfa_regex,
                shrink_mode,
            });
        }
        if let Some((dfa_regex, shrink_mode)) = try_build_precompiled_dfa(pattern) {
            return Ok(SplitEngine::PrecompiledDfa {
                dfa_regex,
                shrink_mode,
            });
        }
        if let Some((fast, shrink_mode)) = try_transform_for_fast_regex(pattern) {
            let clones: Vec<FastRegex> = (0..MAX_NUM_THREADS).map(|_| fast.clone()).collect();
            return Ok(SplitEngine::Fast {
                clones,
                shrink_mode,
            });
        }
        // Last resort: run the original pattern under the backtracker.
        let fancy = FancyRegex::new(pattern)?;
        let clones: Vec<FancyRegex> = (0..MAX_NUM_THREADS).map(|_| fancy.clone()).collect();
        Ok(SplitEngine::Fancy(clones))
    }

    /// Test helper: true for either non-backtracking engine.
    #[cfg(test)]
    fn is_fast(&self) -> bool {
        matches!(
            self,
            SplitEngine::PrecompiledDfa { .. } | SplitEngine::Fast { .. }
        )
    }

    /// Test helper: true only for the DFA-backed engine.
    #[cfg(all(test, feature = "precompiled-dfa"))]
    fn is_precompiled(&self) -> bool {
        matches!(self, SplitEngine::PrecompiledDfa { .. })
    }

    /// Invokes `f` once per matched piece of `text`, in order.
    ///
    /// The DFA and Fast arms drive the search manually so that
    /// `apply_shrink` can shorten a whitespace match and the next search can
    /// resume at the shortened end — this replicates the `\s+(?!\S)`
    /// lookahead that `transform_pattern` removed from the pattern.
    #[inline]
    fn find_pieces<F: FnMut(&str)>(&self, text: &str, mut f: F) {
        match self {
            SplitEngine::PrecompiledDfa {
                dfa_regex,
                shrink_mode,
            } => {
                let haystack = text.as_bytes();
                let mut pos = 0;
                while pos < haystack.len() {
                    let input = Input::new(haystack).range(pos..);
                    let m = match dfa_regex.find(input) {
                        Some(m) => m,
                        None => break,
                    };
                    // Skip any unmatched gap before the match.
                    if m.start() > pos {
                        pos = m.start();
                    }
                    let start = m.start();
                    let end = apply_shrink(text, start, m.end(), *shrink_mode);
                    f(&text[start..end]);
                    // Guard against a zero-width result looping forever.
                    // NOTE(review): the supported patterns should never
                    // match empty — confirm if new patterns are added.
                    if end == pos {
                        pos += 1;
                    } else {
                        pos = end;
                    }
                }
            }
            SplitEngine::Fast {
                clones,
                shrink_mode,
            } => {
                // Per-thread clone avoids sharing one regex across threads.
                let regex = &clones[thread_index()];
                let mut pos = 0;
                while pos < text.len() {
                    let m = match regex.find_at(text, pos) {
                        Some(m) => m,
                        None => break,
                    };
                    if m.start() > pos {
                        pos = m.start();
                    }
                    let start = m.start();
                    let end = apply_shrink(text, start, m.end(), *shrink_mode);
                    f(&text[start..end]);
                    // Same zero-width guard as the DFA arm above.
                    if end == pos {
                        pos += 1;
                    } else {
                        pos = end;
                    }
                }
            }
            SplitEngine::Fancy(clones) => {
                let regex = &clones[thread_index()];
                for mat in regex.find_iter(text) {
                    // Match errors (e.g. a backtracking limit) skip that
                    // match instead of aborting the whole split.
                    match mat {
                        Ok(m) => f(m.as_str()),
                        Err(_) => continue,
                    }
                }
            }
        }
    }
}
536
/// Byte-pair-encoding tokenizer core: vocabularies plus a split pattern.
/// Exposed to Python as `riptoken._riptoken.CoreBPE` under the `python`
/// feature.
#[cfg_attr(feature = "python", pyclass(module = "riptoken._riptoken"))]
pub struct CoreBPE {
    /// Token bytes -> token id for the ordinary vocabulary.
    encoder: HashMap<Vec<u8>, Rank>,
    /// Token id -> token bytes (inverse of `encoder`).
    decoder: HashMap<Rank, Vec<u8>>,
    /// Special token text -> token id.
    special_tokens_encoder: HashMap<String, Rank>,
    /// Special token id -> UTF-8 bytes of its text.
    special_tokens_decoder: HashMap<Rank, Vec<u8>>,
    /// Piece splitter chosen for the construction pattern.
    split_engine: SplitEngine,
    /// Per-thread clones of the special-token regex; empty when there are
    /// no special tokens.
    special_regex_tls: Vec<FancyRegex>,
    /// All ordinary token byte strings, sorted lexicographically.
    sorted_token_bytes: Vec<Vec<u8>>,
}
562
563#[inline(always)]
571fn rank_of(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Rank {
572 ranks.get(piece).copied().unwrap_or(Rank::MAX)
573}
574
/// Core BPE merge for small pieces: repeatedly merges the lowest-ranked
/// adjacent pair (leftmost on equal ranks) until no mergeable pair remains.
///
/// Returns boundary positions into `piece`, each paired with the rank of
/// the pair starting at that boundary; the final two entries are sentinels.
/// Callers turn consecutive boundary pairs back into token ids.
#[inline]
fn byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    if piece.len() < 2 {
        // Degenerate piece: just the two sentinel boundaries.
        return vec![(0, Rank::MAX), (piece.len(), Rank::MAX)];
    }

    let mut parts: Vec<(usize, Rank)> = Vec::with_capacity(piece.len() + 1);

    // Seed every position with the rank of the 2-byte pair starting there,
    // tracking the minimum as we go.
    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
    for i in 0..piece.len() - 1 {
        let rank = rank_of(ranks, &piece[i..i + 2]);
        if rank < min_rank.0 {
            min_rank = (rank, i);
        }
        parts.push((i, rank));
    }
    parts.push((piece.len() - 1, Rank::MAX));
    parts.push((piece.len(), Rank::MAX));

    // Rank of the pair that would exist at boundary `i` once the boundary
    // at `i + 1` is removed (hence the `i + 3` end index — this is called
    // before the `remove` below).
    let get_rank = |parts: &[(usize, Rank)], i: usize| -> Rank {
        if i + 3 < parts.len() {
            rank_of(ranks, &piece[parts[i].0..parts[i + 3].0])
        } else {
            Rank::MAX
        }
    };

    while min_rank.0 != Rank::MAX {
        let i = min_rank.1;

        // Refresh the neighbors' pair ranks before removing the merged
        // boundary; order matters here.
        if i > 0 {
            parts[i - 1].1 = get_rank(&parts, i - 1);
        }
        parts[i].1 = get_rank(&parts, i);
        parts.remove(i + 1);

        // Rescan for the next minimum; strict `<` keeps the leftmost
        // occurrence on ties.
        min_rank = (Rank::MAX, usize::MAX);
        for (j, &(_, rank)) in parts[..parts.len() - 2].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, j);
            }
        }
    }

    parts
}
639
640#[inline]
642fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
643 if piece.len() == 1 {
645 return vec![*ranks.get(piece).expect("byte fallback")];
647 }
648
649 if piece.len() < LARGE_PIECE_THRESHOLD {
650 let positions = byte_pair_merge(ranks, piece);
651 let mut out: Vec<Rank> = Vec::with_capacity(positions.len() - 1);
653 out.extend(
654 positions
655 .windows(2)
656 .map(|w| rank_of(ranks, &piece[w[0].0..w[1].0])),
657 );
658 out
659 } else {
660 byte_pair_merge_large(ranks, piece)
661 }
662}
663
664fn byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<Rank> {
674 use std::cmp::Reverse;
675 use std::collections::BinaryHeap;
676
677 #[derive(Clone)]
682 struct State {
683 prev: usize,
684 end: usize,
685 cur_rank: Rank,
686 }
687
688 let n = piece.len();
689 let mut state: Vec<State> = (0..n)
690 .map(|i| State {
691 prev: if i == 0 { usize::MAX } else { i - 1 },
692 end: i + 1,
693 cur_rank: 0,
694 })
695 .collect();
696
697 let mut heap: BinaryHeap<(Reverse<Rank>, usize)> = BinaryHeap::with_capacity(n);
699
700 for i in 0..n.saturating_sub(1) {
702 let rank = rank_of(ranks, &piece[i..state[i + 1].end]);
703 state[i].cur_rank = rank;
704 if rank != Rank::MAX {
705 heap.push((Reverse(rank), i));
706 }
707 }
708
709 while let Some((Reverse(rank), start)) = heap.pop() {
710 if state[start].cur_rank != rank || rank == Rank::MAX {
712 continue;
713 }
714
715 let right = state[start].end;
717 if right >= n {
718 continue;
719 }
720 let new_end = state[right].end;
721 state[start].end = new_end;
722
723 if new_end < n {
725 state[new_end].prev = start;
726 }
727
728 state[right].cur_rank = Rank::MAX;
730
731 let next_end = state[start].end;
733 if next_end < n {
734 let new_rank = rank_of(ranks, &piece[start..state[next_end].end]);
735 state[start].cur_rank = new_rank;
736 if new_rank != Rank::MAX {
737 heap.push((Reverse(new_rank), start));
738 }
739 } else {
740 state[start].cur_rank = Rank::MAX;
741 }
742
743 let prev = state[start].prev;
745 if prev != usize::MAX {
746 let prev_next_end = state[prev].end; debug_assert_eq!(prev_next_end, start);
748 let span_end = state[start].end;
749 let new_rank = rank_of(ranks, &piece[prev..span_end]);
750 state[prev].cur_rank = new_rank;
751 if new_rank != Rank::MAX {
752 heap.push((Reverse(new_rank), prev));
753 }
754 }
755 }
756
757 let mut tokens = Vec::new();
759 let mut i = 0;
760 while i < n {
761 let end = state[i].end;
762 tokens.push(rank_of(ranks, &piece[i..end]));
763 i = end;
764 }
765 tokens
766}
767
768fn build_special_regex(specials: &HashMap<String, Rank>) -> Result<Option<FancyRegex>, BuildError> {
775 if specials.is_empty() {
776 return Ok(None);
777 }
778 let parts: Vec<String> = specials
780 .keys()
781 .map(|s| fancy_regex::escape(s).into_owned())
782 .collect();
783 let pattern = parts.join("|");
784 Ok(Some(FancyRegex::new(&pattern)?))
785}
786
impl CoreBPE {
    /// Builds a tokenizer from an ordinary vocabulary, a special-token
    /// vocabulary, and a piece-splitting regex pattern.
    ///
    /// # Errors
    /// * `BuildError::InvalidRegex` if `pattern` fails to compile.
    /// * `BuildError::VocabularyMismatch` if two `encoder` entries share a
    ///   token id (the inverted decoder map would lose entries).
    pub fn new(
        encoder: HashMap<Vec<u8>, Rank>,
        special_tokens_encoder: HashMap<String, Rank>,
        pattern: &str,
    ) -> Result<Self, BuildError> {
        let split_engine = SplitEngine::new(pattern)?;
        let decoder: HashMap<Rank, Vec<u8>> =
            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
        // A size mismatch means at least two byte strings mapped to the
        // same id, so decoding would be ambiguous.
        if decoder.len() != encoder.len() {
            return Err(BuildError::VocabularyMismatch);
        }
        let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
            .iter()
            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
            .collect();

        // One regex clone per thread slot; empty when no specials exist.
        let special_regex = build_special_regex(&special_tokens_encoder)?;
        let special_regex_tls: Vec<FancyRegex> = match special_regex {
            Some(r) => (0..MAX_NUM_THREADS).map(|_| r.clone()).collect(),
            None => Vec::new(),
        };

        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
        sorted_token_bytes.sort();

        Ok(CoreBPE {
            encoder,
            decoder,
            special_tokens_encoder,
            special_tokens_decoder,
            split_engine,
            special_regex_tls,
            sorted_token_bytes,
        })
    }

    /// Highest token id across both vocabularies, plus one. This is a
    /// bound, not a count: id gaps inflate it, and a tokenizer with no
    /// tokens at all still reports 1 (both maxima default to 0).
    pub fn n_vocab(&self) -> usize {
        let max_ordinary = self.encoder.values().copied().max().unwrap_or(0);
        let max_special = self
            .special_tokens_encoder
            .values()
            .copied()
            .max()
            .unwrap_or(0);
        max_ordinary.max(max_special) as usize + 1
    }

    /// All ordinary token byte strings, sorted lexicographically.
    pub fn token_byte_values(&self) -> &[Vec<u8>] {
        &self.sorted_token_bytes
    }

    /// This thread's clone of the special-token regex; `None` when no
    /// special tokens were registered (the TLS vec is empty then).
    #[inline]
    fn tl_special_regex(&self) -> Option<&FancyRegex> {
        self.special_regex_tls.get(thread_index())
    }

    /// Appends the token id(s) for one regex piece: a direct vocabulary hit
    /// emits a single token, otherwise the piece is BPE-merged.
    #[inline]
    fn emit_piece(&self, piece: &[u8], out: &mut Vec<Rank>) {
        if let Some(&token) = self.encoder.get(piece) {
            out.push(token);
            return;
        }
        out.extend(byte_pair_encode(piece, &self.encoder));
    }

    /// Encodes `text` treating special-token text like any other text
    /// (no special ids are ever emitted).
    pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
        // Rough pre-size: roughly one token per ~3 bytes of input.
        let mut ret = Vec::with_capacity(text.len() / 3 + 1);
        self.split_engine.find_pieces(text, |piece| {
            self.emit_piece(piece.as_bytes(), &mut ret);
        });
        ret
    }

    /// Parallel (rayon) version of [`Self::encode_ordinary`]; output order
    /// matches input order.
    pub fn encode_ordinary_batch(&self, texts: &[&str]) -> Vec<Vec<Rank>> {
        use rayon::prelude::*;
        texts.par_iter().map(|t| self.encode_ordinary(t)).collect()
    }

    /// Parallel (rayon) version of [`Self::encode`]; output order matches
    /// input order.
    pub fn encode_batch(&self, texts: &[&str], allowed_special: &HashSet<&str>) -> Vec<Vec<Rank>> {
        use rayon::prelude::*;
        texts
            .par_iter()
            .map(|t| self.encode(t, allowed_special))
            .collect()
    }

    /// Encodes `text`, emitting each occurrence of an `allowed_special`
    /// token as its single special id. Special tokens NOT in the allow-set
    /// are tokenized as ordinary text.
    pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> Vec<Rank> {
        let special_regex = match self.tl_special_regex() {
            Some(r) => r,
            None => return self.encode_ordinary(text),
        };

        let mut ret = Vec::new();
        let mut start = 0usize;
        loop {
            // Find the next *allowed* special occurrence at or after `start`.
            let mut next_special: Option<(usize, usize)> = None;
            let mut search_from = start;
            while search_from <= text.len() {
                match special_regex.find_from_pos(text, search_from) {
                    Ok(Some(m)) => {
                        if allowed_special.contains(&text[m.start()..m.end()]) {
                            next_special = Some((m.start(), m.end()));
                            break;
                        }
                        // Disallowed match: resume just past its first byte.
                        // NOTE(review): assumes special tokens begin with an
                        // ASCII char; a multi-byte first char would make
                        // `start + 1` a non-boundary offset — confirm no
                        // non-ASCII-leading specials are ever registered.
                        search_from = m.start() + 1;
                    }
                    _ => break,
                }
            }

            // Ordinary-encode everything up to the special (or to the end).
            let end = next_special.map_or(text.len(), |(s, _)| s);

            self.split_engine.find_pieces(&text[start..end], |piece| {
                self.emit_piece(piece.as_bytes(), &mut ret);
            });

            match next_special {
                Some((s, e)) => {
                    let piece = &text[s..e];
                    if let Some(&tok) = self.special_tokens_encoder.get(piece) {
                        ret.push(tok);
                    }
                    start = e;
                }
                None => break,
            }
        }
        ret
    }

    /// Id for an exact token byte string, checking the ordinary vocabulary
    /// first and then (for valid UTF-8 input) the special vocabulary.
    pub fn encode_single_token(&self, piece: &[u8]) -> Option<Rank> {
        if let Some(&r) = self.encoder.get(piece) {
            return Some(r);
        }
        if let Ok(s) = std::str::from_utf8(piece) {
            if let Some(&r) = self.special_tokens_encoder.get(s) {
                return Some(r);
            }
        }
        None
    }

    /// Concatenates the byte strings of `tokens`. Ids found in neither
    /// vocabulary are silently skipped rather than reported.
    pub fn decode_bytes(&self, tokens: &[Rank]) -> Vec<u8> {
        let mut ret = Vec::with_capacity(tokens.len() * 2);
        for &token in tokens {
            if let Some(bytes) = self.decoder.get(&token) {
                ret.extend_from_slice(bytes);
            } else if let Some(bytes) = self.special_tokens_decoder.get(&token) {
                ret.extend_from_slice(bytes);
            }
        }
        ret
    }

    /// Decodes `tokens` to a `String`.
    ///
    /// # Errors
    /// `DecodeError::InvalidUtf8` when the concatenated bytes are not
    /// valid UTF-8.
    pub fn decode(&self, tokens: &[Rank]) -> Result<String, DecodeError> {
        String::from_utf8(self.decode_bytes(tokens)).map_err(|_| DecodeError::InvalidUtf8)
    }

    /// Byte string for one token id.
    ///
    /// # Errors
    /// `DecodeError::InvalidToken` when the id is in neither vocabulary.
    pub fn decode_single_token_bytes(&self, token: Rank) -> Result<Vec<u8>, DecodeError> {
        if let Some(bytes) = self.decoder.get(&token) {
            return Ok(bytes.clone());
        }
        if let Some(bytes) = self.special_tokens_decoder.get(&token) {
            return Ok(bytes.clone());
        }
        Err(DecodeError::InvalidToken(token))
    }
}
1012
#[cfg(feature = "python")]
#[pymethods]
impl CoreBPE {
    /// Python constructor: mirrors `CoreBPE::new`, mapping build errors to
    /// `ValueError`.
    #[new]
    #[pyo3(signature = (encoder, special_tokens_encoder, pattern))]
    fn py_new(
        encoder: HashMap<Vec<u8>, Rank>,
        special_tokens_encoder: HashMap<String, Rank>,
        pattern: &str,
    ) -> PyResult<Self> {
        Self::new(encoder, special_tokens_encoder, pattern)
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
    }

    /// `encode_ordinary` wrapper; releases the GIL while encoding.
    #[pyo3(name = "encode_ordinary")]
    fn py_encode_ordinary(&self, py: Python<'_>, text: &str) -> Vec<Rank> {
        py.detach(|| self.encode_ordinary(text))
    }

    /// `encode` wrapper; a missing/None allow-set means no specials are
    /// allowed. Releases the GIL while encoding.
    #[pyo3(name = "encode", signature = (text, allowed_special = None))]
    fn py_encode(
        &self,
        py: Python<'_>,
        text: &str,
        allowed_special: Option<HashSet<String>>,
    ) -> Vec<Rank> {
        py.detach(|| {
            let allowed = allowed_special.unwrap_or_default();
            let allowed_refs: HashSet<&str> = allowed.iter().map(|s| s.as_str()).collect();
            self.encode(text, &allowed_refs)
        })
    }

    /// Batched `encode_ordinary`; releases the GIL so rayon can parallelize.
    #[pyo3(name = "encode_ordinary_batch")]
    fn py_encode_ordinary_batch(&self, py: Python<'_>, texts: Vec<String>) -> Vec<Vec<Rank>> {
        py.detach(|| {
            let refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
            self.encode_ordinary_batch(&refs)
        })
    }

    /// Batched `encode`; releases the GIL so rayon can parallelize.
    #[pyo3(name = "encode_batch", signature = (texts, allowed_special = None))]
    fn py_encode_batch(
        &self,
        py: Python<'_>,
        texts: Vec<String>,
        allowed_special: Option<HashSet<String>>,
    ) -> Vec<Vec<Rank>> {
        py.detach(|| {
            let refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
            let allowed = allowed_special.unwrap_or_default();
            let allowed_refs: HashSet<&str> = allowed.iter().map(|s| s.as_str()).collect();
            self.encode_batch(&refs, &allowed_refs)
        })
    }

    /// Exact-token lookup; raises `KeyError` when the bytes are not a token.
    #[pyo3(name = "encode_single_token")]
    fn py_encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> {
        self.encode_single_token(piece)
            .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("token not found"))
    }

    /// Decodes token ids to a Python `bytes` object.
    #[pyo3(name = "decode_bytes")]
    fn py_decode_bytes<'py>(
        &self,
        py: Python<'py>,
        tokens: Vec<Rank>,
    ) -> pyo3::Bound<'py, pyo3::types::PyBytes> {
        let bytes = py.detach(|| self.decode_bytes(&tokens));
        pyo3::types::PyBytes::new(py, &bytes)
    }

    /// Decodes token ids to `str`. Lossy: invalid UTF-8 sequences are
    /// replaced with U+FFFD instead of raising (unlike the Rust `decode`).
    #[pyo3(name = "decode")]
    fn py_decode(&self, py: Python<'_>, tokens: Vec<Rank>) -> String {
        py.detach(|| {
            let bytes = self.decode_bytes(&tokens);
            String::from_utf8_lossy(&bytes).into_owned()
        })
    }

    /// Single-token decode; raises `KeyError` for an unknown id.
    #[pyo3(name = "decode_single_token_bytes")]
    fn py_decode_single_token_bytes<'py>(
        &self,
        py: Python<'py>,
        token: Rank,
    ) -> PyResult<pyo3::Bound<'py, pyo3::types::PyBytes>> {
        let bytes = self
            .decode_single_token_bytes(token)
            .map_err(|e| pyo3::exceptions::PyKeyError::new_err(e.to_string()))?;
        Ok(pyo3::types::PyBytes::new(py, &bytes))
    }

    /// See `CoreBPE::n_vocab` (max id + 1, not a count).
    #[pyo3(name = "n_vocab")]
    fn py_n_vocab(&self) -> usize {
        self.n_vocab()
    }

    /// Sorted ordinary token byte strings as Python `bytes` objects.
    #[pyo3(name = "token_byte_values")]
    fn py_token_byte_values<'py>(
        &self,
        py: Python<'py>,
    ) -> Vec<pyo3::Bound<'py, pyo3::types::PyBytes>> {
        self.sorted_token_bytes
            .iter()
            .map(|b| pyo3::types::PyBytes::new(py, b))
            .collect()
    }
}
1129
/// Python extension-module entry point: registers `CoreBPE` under
/// `riptoken._riptoken`.
#[cfg(feature = "python")]
#[pymodule]
fn _riptoken(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
    m.add_class::<CoreBPE>()?;
    Ok(())
}
1136
#[cfg(test)]
mod tests {
    use super::*;

    // Tiny vocabulary: one token per byte of "helo " plus the merges "he"
    // and "ll" — enough to exercise both the merge path and direct hits.
    fn toy_bpe() -> CoreBPE {
        let mut encoder = HashMap::default();
        for (i, b) in b"helo ".iter().enumerate() {
            encoder.insert(vec![*b], i as Rank);
        }
        encoder.insert(b"he".to_vec(), 100);
        encoder.insert(b"ll".to_vec(), 101);
        CoreBPE::new(encoder, HashMap::default(), r"\w+| ").unwrap()
    }

    #[test]
    fn merge_empty_piece() {
        let ranks: HashMap<Vec<u8>, Rank> = HashMap::default();
        let result = byte_pair_merge(&ranks, b"");
        assert_eq!(result, vec![(0, Rank::MAX), (0, Rank::MAX)]);
    }

    #[test]
    fn merge_single_byte() {
        let ranks: HashMap<Vec<u8>, Rank> = HashMap::default();
        let result = byte_pair_merge(&ranks, b"a");
        assert_eq!(result, vec![(0, Rank::MAX), (1, Rank::MAX)]);
    }

    #[test]
    fn merge_two_byte_exact_match() {
        let mut ranks = HashMap::default();
        ranks.insert(b"ab".to_vec(), 5);
        let result = byte_pair_merge(&ranks, b"ab");
        let positions: Vec<usize> = result.iter().map(|&(p, _)| p).collect();
        assert_eq!(positions, vec![0, 2]);
    }

    #[test]
    fn merge_no_vocab_matches() {
        let ranks: HashMap<Vec<u8>, Rank> = HashMap::default();
        let result = byte_pair_merge(&ranks, b"abcd");
        let positions: Vec<usize> = result.iter().map(|&(p, _)| p).collect();
        assert_eq!(positions, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn merge_cascade() {
        let mut ranks = HashMap::default();
        ranks.insert(b"ab".to_vec(), 0);
        ranks.insert(b"cd".to_vec(), 1);
        let result = byte_pair_merge(&ranks, b"abcd");
        let positions: Vec<usize> = result.iter().map(|&(p, _)| p).collect();
        assert_eq!(positions, vec![0, 2, 4]);
    }

    #[test]
    fn encode_toy() {
        let bpe = toy_bpe();
        let tokens = bpe.encode_ordinary("hello");
        assert_eq!(tokens, vec![100, 101, 3]);
    }

    #[test]
    fn roundtrip_toy() {
        let bpe = toy_bpe();
        let text = "hello";
        let tokens = bpe.encode_ordinary(text);
        let decoded = bpe.decode_bytes(&tokens);
        assert_eq!(decoded, text.as_bytes());
        assert_eq!(bpe.decode(&tokens).unwrap(), text);
    }

    #[test]
    fn encode_single_token_and_lookup() {
        let bpe = toy_bpe();
        assert_eq!(bpe.encode_single_token(b"he"), Some(100));
        assert_eq!(bpe.encode_single_token(b"zz"), None);
        assert_eq!(bpe.decode_single_token_bytes(100).unwrap(), b"he".to_vec());
        assert!(bpe.decode_single_token_bytes(9999).is_err());
    }

    #[test]
    fn n_vocab_counts_everything() {
        let mut encoder = HashMap::default();
        encoder.insert(b"a".to_vec(), 0);
        encoder.insert(b"b".to_vec(), 1);
        let mut specials = HashMap::default();
        specials.insert("<|endoftext|>".to_string(), 2);
        let bpe = CoreBPE::new(encoder, specials, r"\w+").unwrap();
        assert_eq!(bpe.n_vocab(), 3);
    }

    #[test]
    fn encode_with_allowed_special() {
        let mut encoder = HashMap::default();
        for b in b"abcdefghijklmnopqrstuvwxyz <>|" {
            encoder.insert(vec![*b], *b as Rank);
        }
        let mut specials = HashMap::default();
        specials.insert("<|eot|>".to_string(), 999);
        let bpe = CoreBPE::new(encoder, specials, r"\w+|[<|>]").unwrap();

        let allowed: HashSet<&str> = std::iter::once("<|eot|>").collect();
        let tokens = bpe.encode("ab<|eot|>cd", &allowed);
        assert!(tokens.contains(&999));

        // A disallowed special must be tokenized as ordinary text.
        let empty: HashSet<&str> = HashSet::new();
        let tokens = bpe.encode("ab<|eot|>cd", &empty);
        assert!(!tokens.contains(&999));
    }

    #[test]
    fn fast_engine_kicks_in_on_tiktoken_patterns() {
        let o200k = r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+";
        let engine = SplitEngine::new(o200k).unwrap();
        assert!(engine.is_fast(), "o200k_base should use fast engine");

        let simple = SplitEngine::new(r"\w+|\s+").unwrap();
        assert!(simple.is_fast());
    }

    #[test]
    #[cfg(feature = "precompiled-dfa")]
    fn prebuilt_dfa_used_for_stock_patterns() {
        let o200k_raw = concat!(
            r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
            r"|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
            r"|\p{N}{1,3}",
            r"| ?[^\s\p{L}\p{N}]+[\r\n/]*",
            r"|\s*[\r\n]+",
            r"|\s+(?!\S)|\s+",
        );
        let engine = SplitEngine::new(o200k_raw).unwrap();
        assert!(
            engine.is_precompiled(),
            "o200k_base stock pattern should use prebuilt DFA"
        );

        let gpt2_raw =
            r"'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s";
        let engine = SplitEngine::new(gpt2_raw).unwrap();
        assert!(
            engine.is_precompiled(),
            "gpt2 stock pattern should use prebuilt DFA"
        );

        let cl100k_raw = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s";
        let engine = SplitEngine::new(cl100k_raw).unwrap();
        assert!(
            engine.is_precompiled(),
            "cl100k_base stock pattern should use prebuilt DFA"
        );

        let custom = SplitEngine::new(r"\w+|\s+").unwrap();
        assert!(custom.is_fast(), "custom pattern should use fast engine");
    }

    #[test]
    fn whitespace_shrink_matches_tiktoken_behavior() {
        let mut encoder: HashMap<Vec<u8>, Rank> = HashMap::default();
        for b in 0u8..=255 {
            encoder.insert(vec![b], b as Rank);
        }
        encoder.insert(b" hello".to_vec(), 1000);
        encoder.insert(b"hello".to_vec(), 1001);

        let pattern = r"[^\r\n\p{L}\p{N}]?\p{L}+|\s+(?!\S)|\s+";
        let bpe = CoreBPE::new(encoder, HashMap::default(), pattern).unwrap();
        assert!(bpe.split_engine.is_fast());

        let tokens = bpe.encode_ordinary("  hello");
        assert_eq!(
            tokens,
            vec![b' ' as Rank, 1000],
            "fast path should replicate `\\s+(?!\\S)` whitespace-shrink behavior"
        );

        let tokens = bpe.encode_ordinary("hello ");
        assert_eq!(tokens, vec![1001, b' ' as Rank]);
    }

    #[test]
    fn whitespace_shrink_unified_mode_includes_newlines() {
        let mut encoder: HashMap<Vec<u8>, Rank> = HashMap::default();
        for b in 0u8..=255 {
            encoder.insert(vec![b], b as Rank);
        }
        encoder.insert(b" hello".to_vec(), 1000);
        encoder.insert(b"hello".to_vec(), 1001);

        let pattern = r" ?\p{L}+|\s+$|\s+(?!\S)|\s";
        let bpe = CoreBPE::new(encoder, HashMap::default(), pattern).unwrap();
        assert!(bpe.split_engine.is_fast());

        let tokens = bpe.encode_ordinary("\n  hello");
        assert_eq!(
            tokens,
            vec![b'\n' as Rank, b' ' as Rank, 1000],
            "unified shrink mode must fire on whitespace runs that include newlines"
        );

        let tokens = bpe.encode_ordinary("hi\n");
        assert_eq!(tokens, vec![b'h' as Rank, b'i' as Rank, b'\n' as Rank]);
    }

    #[test]
    fn batch_encode_matches_sequential() {
        let bpe = toy_bpe();
        let texts = vec!["hello", "hello world", "the lazy fox"];
        let batch = bpe.encode_ordinary_batch(&texts);
        let seq: Vec<Vec<Rank>> = texts.iter().map(|t| bpe.encode_ordinary(t)).collect();
        assert_eq!(batch, seq);

        let empty: HashSet<&str> = HashSet::new();
        let batch_sp = bpe.encode_batch(&texts, &empty);
        assert_eq!(batch_sp, seq);
    }

    // The heap-based large-piece path must agree with the vec-based path.
    #[test]
    fn large_piece_matches_small_piece() {
        let mut ranks = HashMap::default();
        for b in 0u8..=255 {
            ranks.insert(vec![b], b as Rank);
        }
        ranks.insert(b"ab".to_vec(), 300);
        ranks.insert(b"cd".to_vec(), 301);
        ranks.insert(b"abcd".to_vec(), 302);

        let piece = b"abcdabcdabcdabcd";
        let small = {
            let pos = byte_pair_merge(&ranks, piece);
            pos.windows(2)
                .map(|w| rank_of(&ranks, &piece[w[0].0..w[1].0]))
                .collect::<Vec<_>>()
        };
        let large = byte_pair_merge_large(&ranks, piece);
        assert_eq!(small, large, "heap and vec paths disagree");
    }
}