use fancy_regex::Regex as FancyRegex;
use regex::Regex as FastRegex;
use rustc_hash::{FxHashMap as HashMap, FxHasher};
use std::collections::HashSet;
use std::hash::{Hash, Hasher};
#[cfg(feature = "python")]
use pyo3::prelude::*;
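/// Token id. `u32` is wide enough for tiktoken-style vocabularies.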
pub type Rank = u32;
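/// Number of per-thread regex clone slots; thread ids hash into this range.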
const MAX_NUM_THREADS: usize = 128;
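/// Pieces at least this long take the heap-based merge path
/// (`byte_pair_merge_large`) instead of the quadratic rescan.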
const LARGE_PIECE_THRESHOLD: usize = 500;
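// Map the current thread onto one of MAX_NUM_THREADS slots by hashing its
// ThreadId. Collisions are harmless: every slot holds an identical clone, so
// two threads sharing a slot only lose a little parallelism, never correctness.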
thread_local! {
static THREAD_INDEX: usize = {
let mut h = FxHasher::default();
std::thread::current().id().hash(&mut h);
(h.finish() as usize) % MAX_NUM_THREADS
};
}
#[inline]
fn thread_index() -> usize {
THREAD_INDEX.with(|&i| i)
}
#[derive(Debug)]
pub enum BuildError {
InvalidRegex(String),
VocabularyMismatch,
}
impl std::fmt::Display for BuildError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BuildError::InvalidRegex(e) => write!(f, "invalid regex pattern: {e}"),
BuildError::VocabularyMismatch => write!(
f,
"vocabulary has duplicate entries (encoder/decoder size mismatch)"
),
}
}
}
impl std::error::Error for BuildError {}
impl From<fancy_regex::Error> for BuildError {
fn from(e: fancy_regex::Error) -> Self {
BuildError::InvalidRegex(e.to_string())
}
}
#[derive(Debug)]
pub enum DecodeError {
InvalidToken(Rank),
InvalidUtf8,
}
impl std::fmt::Display for DecodeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DecodeError::InvalidToken(t) => write!(f, "invalid token id: {t}"),
DecodeError::InvalidUtf8 => write!(f, "decoded bytes are not valid UTF-8"),
}
}
}
impl std::error::Error for DecodeError {}
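/// How matched whitespace runs are shrunk to emulate the stripped
/// `\s+(?!\S)` lookahead: not at all, only for runs without `\r`/`\n`
/// (`\s+(?!\S)|\s+` patterns), or for any whitespace run (`...|\s` patterns).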
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum ShrinkMode {
None,
PlainOnly,
Unified,
}
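/// Attempt to rewrite a tiktoken-style pattern so it runs on the `regex`
/// crate, which supports neither lookarounds nor possessive quantifiers.
/// The `\s+(?!\S)|\s+` (and `...|\s`) whitespace idioms are collapsed to a
/// plain `\s+`, and the lookahead's effect is re-created at match time via
/// the returned `ShrinkMode`. Returns `None` if other lookarounds remain.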
fn try_transform_for_fast_regex(pattern: &str) -> Option<(FastRegex, ShrinkMode)> {
let shrink_mode = if pattern.contains(r"\s+(?!\S)|\s+") {
ShrinkMode::PlainOnly
} else if pattern.contains(r"\s+(?!\S)|\s") {
ShrinkMode::Unified
} else {
ShrinkMode::None
};
let mut stripped = pattern.replace(r"\s+(?!\S)|\s+", r"\s+");
stripped = stripped.replace(r"\s+(?!\S)|\s", r"\s+");
if stripped.contains("(?=")
|| stripped.contains("(?!")
|| stripped.contains("(?<=")
|| stripped.contains("(?<!")
{
return None;
}
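    // Possessive quantifiers (`?+`, `++`, `*+`, `{m,n}+`) are downgraded to
    // their greedy forms. For the tokenizer patterns targeted here this is
    // assumed not to change piece boundaries; a pattern where backtracking
    // into the quantifier matters would need the fancy engine instead.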
stripped = stripped
.replace("?+", "?")
.replace("++", "+")
.replace("*+", "*");
let range_possessive = FastRegex::new(r"(\{\d+(?:,\d*)?\})\+").ok()?;
let stripped = range_possessive.replace_all(&stripped, "$1").into_owned();
let regex = FastRegex::new(&stripped).ok()?;
Some((regex, shrink_mode))
}
#[inline]
fn is_plain_whitespace_run(s: &str) -> bool {
!s.is_empty()
&& s.chars()
.all(|c| c.is_whitespace() && c != '\n' && c != '\r')
}
#[inline]
fn is_whitespace_run(s: &str) -> bool {
!s.is_empty() && s.chars().all(|c| c.is_whitespace())
}
#[inline]
fn next_char_is_non_whitespace(text: &str, pos: usize) -> bool {
    text[pos..].chars().next().is_some_and(|c| !c.is_whitespace())
}
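/// Piece-splitting engine: the `regex` crate when the pattern can be
/// rewritten without lookarounds, falling back to `fancy-regex` otherwise.
/// Each variant holds one clone per thread slot so concurrent encodes do not
/// share a single regex's internal scratch state.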
enum SplitEngine {
Fast {
clones: Vec<FastRegex>,
shrink_mode: ShrinkMode,
},
Fancy(Vec<FancyRegex>),
}
impl SplitEngine {
fn new(pattern: &str) -> Result<Self, BuildError> {
if let Some((fast, shrink_mode)) = try_transform_for_fast_regex(pattern) {
let clones: Vec<FastRegex> = (0..MAX_NUM_THREADS).map(|_| fast.clone()).collect();
return Ok(SplitEngine::Fast {
clones,
shrink_mode,
});
}
let fancy = FancyRegex::new(pattern)?;
let clones: Vec<FancyRegex> = (0..MAX_NUM_THREADS).map(|_| fancy.clone()).collect();
Ok(SplitEngine::Fancy(clones))
}
#[cfg(test)]
fn is_fast(&self) -> bool {
matches!(self, SplitEngine::Fast { .. })
}
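    /// Run the split regex over `text`, calling `f` with each piece in order.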
#[inline]
fn find_pieces<F: FnMut(&str)>(&self, text: &str, mut f: F) {
match self {
SplitEngine::Fast {
clones,
shrink_mode,
} => {
let regex = &clones[thread_index()];
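                // Manual scan instead of find_iter: the match end may need to
                // be adjusted (whitespace shrink) before the cursor advances.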
let mut pos = 0;
while pos < text.len() {
                    let Some(m) = regex.find_at(text, pos) else {
                        break;
                    };
if m.start() > pos {
pos = m.start();
}
let start = m.start();
let mut end = m.end();
let piece = &text[start..end];
let should_shrink = match shrink_mode {
ShrinkMode::None => false,
ShrinkMode::PlainOnly => is_plain_whitespace_run(piece),
ShrinkMode::Unified => is_whitespace_run(piece),
};
                    // Emulate the stripped `\s+(?!\S)` lookahead: when a
                    // whitespace run is followed by non-whitespace, give back
                    // its final character so it can attach to the next piece.
                    // Single-character runs are never shrunk to empty.
                    if should_shrink && end < text.len() && next_char_is_non_whitespace(text, end) {
                        if let Some((last_i, _)) = piece.char_indices().next_back() {
                            if last_i > 0 {
                                end = start + last_i;
                            }
                        }
                    }
f(&text[start..end]);
if end == pos {
pos += 1;
} else {
pos = end;
}
}
}
SplitEngine::Fancy(clones) => {
let regex = &clones[thread_index()];
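                // fancy-regex can fail at match time (e.g. backtracking
                // limits); such matches are skipped rather than aborting.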
for mat in regex.find_iter(text) {
match mat {
Ok(m) => f(m.as_str()),
Err(_) => continue,
}
}
}
}
}
}
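/// A byte-level BPE tokenizer core: ordinary vocabulary, special tokens,
/// the piece-splitting engine, and per-thread special-token regex clones.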
#[cfg_attr(feature = "python", pyclass(module = "riptoken._riptoken"))]
pub struct CoreBPE {
encoder: HashMap<Vec<u8>, Rank>,
decoder: HashMap<Rank, Vec<u8>>,
special_tokens_encoder: HashMap<String, Rank>,
special_tokens_decoder: HashMap<Rank, Vec<u8>>,
split_engine: SplitEngine,
special_regex_tls: Vec<FancyRegex>,
sorted_token_bytes: Vec<Vec<u8>>,
}
#[inline(always)]
fn rank_of(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Rank {
ranks.get(piece).copied().unwrap_or(Rank::MAX)
}
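/// Quadratic-rescan BPE merge for small pieces. Each entry of the returned
/// vector is a segment start paired with the cached rank of merging that
/// segment with its successor; the list ends with Rank::MAX sentinels and a
/// final `(piece.len(), Rank::MAX)` entry, so adjacent positions delimit the
/// resulting tokens.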
#[inline]
fn byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
if piece.len() < 2 {
return vec![(0, Rank::MAX), (piece.len(), Rank::MAX)];
}
let mut parts: Vec<(usize, Rank)> = Vec::with_capacity(piece.len() + 1);
let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
for i in 0..piece.len() - 1 {
let rank = rank_of(ranks, &piece[i..i + 2]);
if rank < min_rank.0 {
min_rank = (rank, i);
}
parts.push((i, rank));
}
parts.push((piece.len() - 1, Rank::MAX));
parts.push((piece.len(), Rank::MAX));
let get_rank = |parts: &[(usize, Rank)], i: usize| -> Rank {
if i + 3 < parts.len() {
rank_of(ranks, &piece[parts[i].0..parts[i + 3].0])
} else {
Rank::MAX
}
};
while min_rank.0 != Rank::MAX {
let i = min_rank.1;
if i > 0 {
parts[i - 1].1 = get_rank(&parts, i - 1);
}
parts[i].1 = get_rank(&parts, i);
parts.remove(i + 1);
min_rank = (Rank::MAX, usize::MAX);
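        // The trailing two entries are permanent Rank::MAX sentinels, so the
        // rescan for the next minimum can skip them.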
for (j, &(_, rank)) in parts[..parts.len() - 2].iter().enumerate() {
if rank < min_rank.0 {
min_rank = (rank, j);
}
}
}
parts
}
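/// Encode one piece. Single bytes must be present in the vocabulary (byte
/// fallback); longer pieces go through the small or large merge path.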
#[inline]
fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
    if piece.len() == 1 {
        // Byte-level BPE guarantees every single byte is in the vocabulary.
        return vec![*ranks.get(piece).expect("single byte missing from vocabulary")];
    }
if piece.len() < LARGE_PIECE_THRESHOLD {
        let positions = byte_pair_merge(ranks, piece);
        positions
            .windows(2)
            .map(|w| rank_of(ranks, &piece[w[0].0..w[1].0]))
            .collect()
} else {
byte_pair_merge_large(ranks, piece)
}
}
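/// O(n log n) BPE merge for large pieces: a min-heap of candidate merges over
/// a linked list of segments, with stale heap entries discarded lazily.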
fn byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<Rank> {
use std::cmp::Reverse;
use std::collections::BinaryHeap;
#[derive(Clone)]
struct State {
prev: usize,
end: usize,
cur_rank: Rank,
}
let n = piece.len();
    // Each byte starts as its own segment; `prev`/`end` form a linked list
    // over segment starts. `cur_rank` caches the rank of merging a segment
    // with its right neighbour. Start from Rank::MAX (no merge) so the final
    // segment, which the fill loop below skips, stays inert.
    let mut state: Vec<State> = (0..n)
        .map(|i| State {
            prev: if i == 0 { usize::MAX } else { i - 1 },
            end: i + 1,
            cur_rank: Rank::MAX,
        })
        .collect();
let mut heap: BinaryHeap<(Reverse<Rank>, usize)> = BinaryHeap::with_capacity(n);
for i in 0..n.saturating_sub(1) {
let rank = rank_of(ranks, &piece[i..state[i + 1].end]);
state[i].cur_rank = rank;
if rank != Rank::MAX {
heap.push((Reverse(rank), i));
}
}
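    // Lazy deletion: a popped entry is stale unless the rank cached at
    // `start` still matches. Ranks are unique per vocab entry, so equality
    // means the candidate is current.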
while let Some((Reverse(rank), start)) = heap.pop() {
if state[start].cur_rank != rank || rank == Rank::MAX {
continue;
}
let right = state[start].end;
if right >= n {
continue;
}
let new_end = state[right].end;
state[start].end = new_end;
if new_end < n {
state[new_end].prev = start;
}
state[right].cur_rank = Rank::MAX;
let next_end = state[start].end;
if next_end < n {
let new_rank = rank_of(ranks, &piece[start..state[next_end].end]);
state[start].cur_rank = new_rank;
if new_rank != Rank::MAX {
heap.push((Reverse(new_rank), start));
}
} else {
state[start].cur_rank = Rank::MAX;
}
let prev = state[start].prev;
if prev != usize::MAX {
            let prev_next_end = state[prev].end;
            debug_assert_eq!(prev_next_end, start);
let span_end = state[start].end;
let new_rank = rank_of(ranks, &piece[prev..span_end]);
state[prev].cur_rank = new_rank;
if new_rank != Rank::MAX {
heap.push((Reverse(new_rank), prev));
}
}
}
let mut tokens = Vec::new();
let mut i = 0;
while i < n {
let end = state[i].end;
tokens.push(rank_of(ranks, &piece[i..end]));
i = end;
}
tokens
}
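/// Build one alternation of the escaped special-token literals, or `None`
/// when there are no special tokens.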
fn build_special_regex(specials: &HashMap<String, Rank>) -> Result<Option<FancyRegex>, BuildError> {
if specials.is_empty() {
return Ok(None);
}
let parts: Vec<String> = specials
.keys()
.map(|s| fancy_regex::escape(s).into_owned())
.collect();
let pattern = parts.join("|");
Ok(Some(FancyRegex::new(&pattern)?))
}
impl CoreBPE {
pub fn new(
encoder: HashMap<Vec<u8>, Rank>,
special_tokens_encoder: HashMap<String, Rank>,
pattern: &str,
) -> Result<Self, BuildError> {
let split_engine = SplitEngine::new(pattern)?;
let decoder: HashMap<Rank, Vec<u8>> =
encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
if decoder.len() != encoder.len() {
return Err(BuildError::VocabularyMismatch);
}
let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
.iter()
.map(|(k, v)| (*v, k.as_bytes().to_vec()))
.collect();
let special_regex = build_special_regex(&special_tokens_encoder)?;
let special_regex_tls: Vec<FancyRegex> = match special_regex {
Some(r) => (0..MAX_NUM_THREADS).map(|_| r.clone()).collect(),
None => Vec::new(),
};
let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
sorted_token_bytes.sort();
Ok(CoreBPE {
encoder,
decoder,
special_tokens_encoder,
special_tokens_decoder,
split_engine,
special_regex_tls,
sorted_token_bytes,
})
}
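    /// One past the highest token id across ordinary and special tokens
    /// (which may exceed the number of entries if ids are sparse).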
pub fn n_vocab(&self) -> usize {
let max_ordinary = self.encoder.values().copied().max().unwrap_or(0);
let max_special = self
.special_tokens_encoder
.values()
.copied()
.max()
.unwrap_or(0);
max_ordinary.max(max_special) as usize + 1
}
pub fn token_byte_values(&self) -> &[Vec<u8>] {
&self.sorted_token_bytes
}
#[inline]
fn tl_special_regex(&self) -> Option<&FancyRegex> {
self.special_regex_tls.get(thread_index())
}
#[inline]
fn emit_piece(&self, piece: &[u8], out: &mut Vec<Rank>) {
if let Some(&token) = self.encoder.get(piece) {
out.push(token);
return;
}
out.extend(byte_pair_encode(piece, &self.encoder));
}
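    /// Encode text without any special-token handling.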
pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
let mut ret = Vec::with_capacity(text.len() / 3 + 1);
self.split_engine.find_pieces(text, |piece| {
self.emit_piece(piece.as_bytes(), &mut ret);
});
ret
}
pub fn encode_ordinary_batch(&self, texts: &[&str]) -> Vec<Vec<Rank>> {
use rayon::prelude::*;
texts.par_iter().map(|t| self.encode_ordinary(t)).collect()
}
pub fn encode_batch(&self, texts: &[&str], allowed_special: &HashSet<&str>) -> Vec<Vec<Rank>> {
use rayon::prelude::*;
texts
.par_iter()
.map(|t| self.encode(t, allowed_special))
.collect()
}
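    /// Encode text, emitting tokens for occurrences of special tokens listed
    /// in `allowed_special`; all other text (including disallowed specials)
    /// is encoded ordinarily.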
pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> Vec<Rank> {
let special_regex = match self.tl_special_regex() {
Some(r) => r,
None => return self.encode_ordinary(text),
};
let mut ret = Vec::new();
let mut start = 0usize;
loop {
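            // Locate the next *allowed* special token; disallowed matches are
            // stepped over one byte at a time so overlapping candidates are
            // still found.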
let mut next_special: Option<(usize, usize)> = None;
let mut search_from = start;
while search_from <= text.len() {
match special_regex.find_from_pos(text, search_from) {
Ok(Some(m)) => {
if allowed_special.contains(&text[m.start()..m.end()]) {
next_special = Some((m.start(), m.end()));
break;
}
search_from = m.start() + 1;
}
_ => break,
}
}
let end = next_special.map_or(text.len(), |(s, _)| s);
self.split_engine.find_pieces(&text[start..end], |piece| {
self.emit_piece(piece.as_bytes(), &mut ret);
});
match next_special {
Some((s, e)) => {
let piece = &text[s..e];
if let Some(&tok) = self.special_tokens_encoder.get(piece) {
ret.push(tok);
}
start = e;
}
None => break,
}
}
ret
}
pub fn encode_single_token(&self, piece: &[u8]) -> Option<Rank> {
if let Some(&r) = self.encoder.get(piece) {
return Some(r);
}
if let Ok(s) = std::str::from_utf8(piece) {
if let Some(&r) = self.special_tokens_encoder.get(s) {
return Some(r);
}
}
None
}
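    /// Decode tokens to bytes. Unknown token ids are silently skipped; use
    /// `decode_single_token_bytes` for per-token validation.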
pub fn decode_bytes(&self, tokens: &[Rank]) -> Vec<u8> {
let mut ret = Vec::with_capacity(tokens.len() * 2);
for &token in tokens {
if let Some(bytes) = self.decoder.get(&token) {
ret.extend_from_slice(bytes);
} else if let Some(bytes) = self.special_tokens_decoder.get(&token) {
ret.extend_from_slice(bytes);
}
}
ret
}
pub fn decode(&self, tokens: &[Rank]) -> Result<String, DecodeError> {
String::from_utf8(self.decode_bytes(tokens)).map_err(|_| DecodeError::InvalidUtf8)
}
pub fn decode_single_token_bytes(&self, token: Rank) -> Result<Vec<u8>, DecodeError> {
if let Some(bytes) = self.decoder.get(&token) {
return Ok(bytes.clone());
}
if let Some(bytes) = self.special_tokens_decoder.get(&token) {
return Ok(bytes.clone());
}
Err(DecodeError::InvalidToken(token))
}
}
#[cfg(feature = "python")]
#[pymethods]
impl CoreBPE {
#[new]
#[pyo3(signature = (encoder, special_tokens_encoder, pattern))]
fn py_new(
encoder: HashMap<Vec<u8>, Rank>,
special_tokens_encoder: HashMap<String, Rank>,
pattern: &str,
) -> PyResult<Self> {
Self::new(encoder, special_tokens_encoder, pattern)
.map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
}
#[pyo3(name = "encode_ordinary")]
fn py_encode_ordinary(&self, py: Python<'_>, text: &str) -> Vec<Rank> {
py.detach(|| self.encode_ordinary(text))
}
#[pyo3(name = "encode")]
fn py_encode(&self, py: Python<'_>, text: &str, allowed_special: HashSet<String>) -> Vec<Rank> {
py.detach(|| {
let allowed_refs: HashSet<&str> = allowed_special.iter().map(|s| s.as_str()).collect();
self.encode(text, &allowed_refs)
})
}
#[pyo3(name = "encode_ordinary_batch")]
fn py_encode_ordinary_batch(&self, py: Python<'_>, texts: Vec<String>) -> Vec<Vec<Rank>> {
py.detach(|| {
let refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
self.encode_ordinary_batch(&refs)
})
}
#[pyo3(name = "encode_batch")]
fn py_encode_batch(
&self,
py: Python<'_>,
texts: Vec<String>,
allowed_special: HashSet<String>,
) -> Vec<Vec<Rank>> {
py.detach(|| {
let refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
let allowed_refs: HashSet<&str> = allowed_special.iter().map(|s| s.as_str()).collect();
self.encode_batch(&refs, &allowed_refs)
})
}
#[pyo3(name = "encode_single_token")]
fn py_encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> {
self.encode_single_token(piece)
.ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("token not found"))
}
#[pyo3(name = "decode_bytes")]
fn py_decode_bytes<'py>(
&self,
py: Python<'py>,
tokens: Vec<Rank>,
) -> pyo3::Bound<'py, pyo3::types::PyBytes> {
let bytes = py.detach(|| self.decode_bytes(&tokens));
pyo3::types::PyBytes::new(py, &bytes)
}
#[pyo3(name = "decode_single_token_bytes")]
fn py_decode_single_token_bytes<'py>(
&self,
py: Python<'py>,
token: Rank,
) -> PyResult<pyo3::Bound<'py, pyo3::types::PyBytes>> {
let bytes = self
.decode_single_token_bytes(token)
.map_err(|e| pyo3::exceptions::PyKeyError::new_err(e.to_string()))?;
Ok(pyo3::types::PyBytes::new(py, &bytes))
}
#[pyo3(name = "n_vocab")]
fn py_n_vocab(&self) -> usize {
self.n_vocab()
}
#[pyo3(name = "token_byte_values")]
fn py_token_byte_values<'py>(
&self,
py: Python<'py>,
) -> Vec<pyo3::Bound<'py, pyo3::types::PyBytes>> {
self.sorted_token_bytes
.iter()
.map(|b| pyo3::types::PyBytes::new(py, b))
.collect()
}
}
#[cfg(feature = "python")]
#[pymodule]
fn _riptoken(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_class::<CoreBPE>()?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
fn toy_bpe() -> CoreBPE {
let mut encoder = HashMap::default();
for (i, b) in b"helo ".iter().enumerate() {
encoder.insert(vec![*b], i as Rank);
}
encoder.insert(b"he".to_vec(), 100);
encoder.insert(b"ll".to_vec(), 101);
CoreBPE::new(encoder, HashMap::default(), r"\w+| ").unwrap()
}
#[test]
fn merge_empty_piece() {
let ranks: HashMap<Vec<u8>, Rank> = HashMap::default();
let result = byte_pair_merge(&ranks, b"");
assert_eq!(result, vec![(0, Rank::MAX), (0, Rank::MAX)]);
}
#[test]
fn merge_single_byte() {
let ranks: HashMap<Vec<u8>, Rank> = HashMap::default();
let result = byte_pair_merge(&ranks, b"a");
assert_eq!(result, vec![(0, Rank::MAX), (1, Rank::MAX)]);
}
#[test]
fn merge_two_byte_exact_match() {
let mut ranks = HashMap::default();
ranks.insert(b"ab".to_vec(), 5);
let result = byte_pair_merge(&ranks, b"ab");
let positions: Vec<usize> = result.iter().map(|&(p, _)| p).collect();
assert_eq!(positions, vec![0, 2]);
}
#[test]
fn merge_no_vocab_matches() {
let ranks: HashMap<Vec<u8>, Rank> = HashMap::default();
let result = byte_pair_merge(&ranks, b"abcd");
let positions: Vec<usize> = result.iter().map(|&(p, _)| p).collect();
assert_eq!(positions, vec![0, 1, 2, 3, 4]);
}
#[test]
fn merge_cascade() {
let mut ranks = HashMap::default();
ranks.insert(b"ab".to_vec(), 0);
ranks.insert(b"cd".to_vec(), 1);
let result = byte_pair_merge(&ranks, b"abcd");
let positions: Vec<usize> = result.iter().map(|&(p, _)| p).collect();
assert_eq!(positions, vec![0, 2, 4]);
}
#[test]
fn encode_toy() {
let bpe = toy_bpe();
let tokens = bpe.encode_ordinary("hello");
assert_eq!(tokens, vec![100, 101, 3]);
}
#[test]
fn roundtrip_toy() {
let bpe = toy_bpe();
let text = "hello";
let tokens = bpe.encode_ordinary(text);
let decoded = bpe.decode_bytes(&tokens);
assert_eq!(decoded, text.as_bytes());
assert_eq!(bpe.decode(&tokens).unwrap(), text);
}
#[test]
fn encode_single_token_and_lookup() {
let bpe = toy_bpe();
assert_eq!(bpe.encode_single_token(b"he"), Some(100));
assert_eq!(bpe.encode_single_token(b"zz"), None);
assert_eq!(bpe.decode_single_token_bytes(100).unwrap(), b"he".to_vec());
assert!(bpe.decode_single_token_bytes(9999).is_err());
}
#[test]
fn n_vocab_counts_everything() {
let mut encoder = HashMap::default();
encoder.insert(b"a".to_vec(), 0);
encoder.insert(b"b".to_vec(), 1);
let mut specials = HashMap::default();
specials.insert("<|endoftext|>".to_string(), 2);
let bpe = CoreBPE::new(encoder, specials, r"\w+").unwrap();
assert_eq!(bpe.n_vocab(), 3);
}
#[test]
fn encode_with_allowed_special() {
let mut encoder = HashMap::default();
for b in b"abcdefghijklmnopqrstuvwxyz <>|" {
encoder.insert(vec![*b], *b as Rank);
}
let mut specials = HashMap::default();
specials.insert("<|eot|>".to_string(), 999);
let bpe = CoreBPE::new(encoder, specials, r"\w+|[<|>]").unwrap();
let allowed: HashSet<&str> = std::iter::once("<|eot|>").collect();
let tokens = bpe.encode("ab<|eot|>cd", &allowed);
assert!(tokens.contains(&999));
let empty: HashSet<&str> = HashSet::new();
let tokens = bpe.encode("ab<|eot|>cd", &empty);
assert!(!tokens.contains(&999));
}
#[test]
fn fast_engine_kicks_in_on_tiktoken_patterns() {
let o200k = r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+";
let engine = SplitEngine::new(o200k).unwrap();
assert!(engine.is_fast(), "o200k_base should use fast engine");
let simple = SplitEngine::new(r"\w+|\s+").unwrap();
assert!(simple.is_fast());
}
#[test]
fn whitespace_shrink_matches_tiktoken_behavior() {
let mut encoder: HashMap<Vec<u8>, Rank> = HashMap::default();
for b in 0u8..=255 {
encoder.insert(vec![b], b as Rank);
}
encoder.insert(b" hello".to_vec(), 1000);
encoder.insert(b"hello".to_vec(), 1001);
let pattern = r"[^\r\n\p{L}\p{N}]?\p{L}+|\s+(?!\S)|\s+";
let bpe = CoreBPE::new(encoder, HashMap::default(), pattern).unwrap();
assert!(bpe.split_engine.is_fast());
        let tokens = bpe.encode_ordinary("  hello");
assert_eq!(
tokens,
vec![b' ' as Rank, 1000],
"fast path should replicate `\\s+(?!\\S)` whitespace-shrink behavior"
);
let tokens = bpe.encode_ordinary("hello ");
assert_eq!(tokens, vec![1001, b' ' as Rank]);
}
#[test]
fn whitespace_shrink_unified_mode_includes_newlines() {
let mut encoder: HashMap<Vec<u8>, Rank> = HashMap::default();
for b in 0u8..=255 {
encoder.insert(vec![b], b as Rank);
}
encoder.insert(b" hello".to_vec(), 1000);
encoder.insert(b"hello".to_vec(), 1001);
let pattern = r" ?\p{L}+|\s+$|\s+(?!\S)|\s";
let bpe = CoreBPE::new(encoder, HashMap::default(), pattern).unwrap();
assert!(bpe.split_engine.is_fast());
        let tokens = bpe.encode_ordinary("\n  hello");
assert_eq!(
tokens,
vec![b'\n' as Rank, b' ' as Rank, 1000],
"unified shrink mode must fire on whitespace runs that include newlines"
);
let tokens = bpe.encode_ordinary("hi\n");
assert_eq!(tokens, vec![b'h' as Rank, b'i' as Rank, b'\n' as Rank]);
}
#[test]
fn batch_encode_matches_sequential() {
let bpe = toy_bpe();
let texts = vec!["hello", "hello world", "the lazy fox"];
let batch = bpe.encode_ordinary_batch(&texts);
let seq: Vec<Vec<Rank>> = texts.iter().map(|t| bpe.encode_ordinary(t)).collect();
assert_eq!(batch, seq);
let empty: HashSet<&str> = HashSet::new();
let batch_sp = bpe.encode_batch(&texts, &empty);
assert_eq!(batch_sp, seq);
}
#[test]
fn large_piece_matches_small_piece() {
let mut ranks = HashMap::default();
for b in 0u8..=255 {
ranks.insert(vec![b], b as Rank);
}
ranks.insert(b"ab".to_vec(), 300);
ranks.insert(b"cd".to_vec(), 301);
ranks.insert(b"abcd".to_vec(), 302);
let piece = b"abcdabcdabcdabcd";
let small = {
let pos = byte_pair_merge(&ranks, piece);
pos.windows(2)
.map(|w| rank_of(&ranks, &piece[w[0].0..w[1].0]))
.collect::<Vec<_>>()
};
let large = byte_pair_merge_large(&ranks, piece);
assert_eq!(small, large, "heap and vec paths disagree");
}
}