use foldhash::HashMap as FoldHashMap;
use smallvec::SmallVec;
use crate::types::TokenId;
/// Longest token, in bytes, that is stored in the exact-match token cache.
const MAX_CACHED_TOKEN_LEN: usize = 16;

/// Packs an ordered token pair into a single `u64` map key: `left` fills the
/// high 32 bits and `right` the low 32 bits, so `(a, b)` and `(b, a)` differ.
#[inline(always)]
fn pack_pair(left: TokenId, right: TokenId) -> u64 {
    let high = u64::from(left) << 32;
    let low = u64::from(right);
    high | low
}
/// Byte-level BPE encoder driven by a ranked merge table.
///
/// Besides the merge table it keeps two shortcuts used while encoding: a
/// per-byte lookup table and an exact-match cache for short tokens.
#[derive(Clone)]
pub struct BytePairEncoder {
    /// Token id for each single byte value; seeds the merge loop.
    byte_lut: [TokenId; 256],
    /// `pack_pair(left, right)` -> `(merged token id, rank)`; lower ranks merge first.
    pair_lookup: FoldHashMap<u64, (TokenId, u32)>,
    /// Exact byte string -> token id, for tokens of at most `MAX_CACHED_TOKEN_LEN` bytes.
    token_cache: FoldHashMap<Vec<u8>, TokenId>,
    /// Number of base (pre-merge) tokens recorded at construction.
    num_base_tokens: usize,
    /// Total number of tokens recorded at construction.
    vocab_size: usize,
}
impl BytePairEncoder {
    /// Builds an encoder from base tokens and an ordered merge list.
    ///
    /// Same as [`BytePairEncoder::from_merges_with_added`] with no added
    /// tokens. Returns the encoder and the byte string of every token,
    /// indexed by token id.
    pub fn from_merges(
        merges: &[(TokenId, TokenId)],
        base_tokens: &[Vec<u8>],
    ) -> (Self, Vec<Vec<u8>>) {
        Self::from_merges_with_added(merges, base_tokens, &[])
    }

    /// Builds an encoder from base tokens, an ordered merge list, and extra
    /// "added" tokens that request explicit ids.
    ///
    /// Base tokens take ids `0..base_tokens.len()`. Merge `i` creates the next
    /// free id and gets rank `i` (lower ranks are applied first by
    /// [`BytePairEncoder::encode`]). Added tokens are sorted by id and any whose
    /// id is `<=` the next free id are appended before that merge's token;
    /// an id that is already taken is not honored, the token is simply
    /// appended at the next free id. Added tokens left over after the last
    /// merge are appended at the end in id order.
    ///
    /// Byte values with no single-byte base token map to token `0`; when
    /// several base tokens share one byte, the highest id wins.
    ///
    /// Returns the encoder and the byte string of every token, indexed by id.
    ///
    /// # Panics
    ///
    /// Panics if a merge references a token id that does not exist yet.
    pub fn from_merges_with_added(
        merges: &[(TokenId, TokenId)],
        base_tokens: &[Vec<u8>],
        added_tokens: &[(u32, Vec<u8>)],
    ) -> (Self, Vec<Vec<u8>>) {
        let mut token_bytes: Vec<Vec<u8>> = base_tokens.to_vec();
        let mut pair_lookup = FoldHashMap::default();

        let mut byte_lut: [TokenId; 256] = [0; 256];
        for (token_id, bytes) in base_tokens.iter().enumerate() {
            if let [byte] = bytes.as_slice() {
                byte_lut[*byte as usize] = token_id as TokenId;
            }
        }

        let mut added_sorted: Vec<(u32, Vec<u8>)> = added_tokens.to_vec();
        added_sorted.sort_by_key(|(id, _)| *id);
        let mut added_iter = added_sorted.into_iter().peekable();

        for (merge_index, &(left, right)) in merges.iter().enumerate() {
            // Splice in every added token that is due at or before the next
            // free id. The free id is re-read for each candidate: placing an
            // added token at id `n` moves the next free id to `n + 1`, so a
            // consecutive added id must also land before this merge's token.
            while let Some((_, bytes)) =
                added_iter.next_if(|(id, _)| *id as usize <= token_bytes.len())
            {
                token_bytes.push(bytes);
            }

            let new_id = token_bytes.len() as TokenId;
            pair_lookup.insert(pack_pair(left, right), (new_id, merge_index as u32));

            let mut merged = token_bytes[left as usize].clone();
            merged.extend_from_slice(&token_bytes[right as usize]);
            token_bytes.push(merged);
        }
        token_bytes.extend(added_iter.map(|(_, bytes)| bytes));

        let encoder = Self {
            pair_lookup,
            byte_lut,
            token_cache: Self::build_token_cache(&token_bytes),
            vocab_size: token_bytes.len(),
            num_base_tokens: base_tokens.len(),
        };
        (encoder, token_bytes)
    }

    /// Builds an encoder from an explicit vocabulary and a merge list.
    ///
    /// The byte table, byte lookup and token cache are all indexed by
    /// position in `vocab`, and merges index `vocab` by position; the `u32`
    /// stored with each entry is used only as the result id of a merge.
    /// NOTE(review): this presumes `vocab[i].0 == i` — confirm with callers.
    ///
    /// A merge whose concatenated bytes are not in the vocabulary is skipped;
    /// if a pair is listed more than once, its first (lowest-rank) merge is
    /// kept. Byte values with no single-byte token map to `u32::MAX`; when
    /// several single-byte tokens share one byte, the lowest position wins.
    ///
    /// # Panics
    ///
    /// Panics if a merge references a position outside `vocab`.
    pub fn from_vocab_and_merges(
        vocab: &[(u32, Vec<u8>)],
        merges: &[(TokenId, TokenId)],
        num_base_tokens: usize,
    ) -> (Self, Vec<Vec<u8>>) {
        let token_bytes: Vec<Vec<u8>> = vocab.iter().map(|(_, bytes)| bytes.clone()).collect();

        let mut byte_lut = [u32::MAX; 256];
        for (token_id, bytes) in token_bytes.iter().enumerate() {
            if let [byte] = bytes.as_slice() {
                let slot = &mut byte_lut[*byte as usize];
                if *slot == u32::MAX {
                    *slot = token_id as TokenId;
                }
            }
        }

        let all_bytes_to_id: FoldHashMap<Vec<u8>, TokenId> = vocab
            .iter()
            .map(|(id, bytes)| (bytes.clone(), *id))
            .collect();

        let mut pair_lookup = FoldHashMap::default();
        for (merge_index, &(left, right)) in merges.iter().enumerate() {
            let mut merged_bytes = token_bytes[left as usize].clone();
            merged_bytes.extend_from_slice(&token_bytes[right as usize]);
            if let Some(&merged_id) = all_bytes_to_id.get(&merged_bytes) {
                pair_lookup
                    .entry(pack_pair(left, right))
                    .or_insert((merged_id, merge_index as u32));
            }
        }

        let encoder = Self {
            pair_lookup,
            byte_lut,
            token_cache: Self::build_token_cache(&token_bytes),
            vocab_size: vocab.len(),
            num_base_tokens,
        };
        (encoder, token_bytes)
    }

    /// Maps each token's bytes to its id (its index in `token_bytes`) for
    /// every token of at most `MAX_CACHED_TOKEN_LEN` bytes. On duplicate byte
    /// strings the later id wins.
    fn build_token_cache(token_bytes: &[Vec<u8>]) -> FoldHashMap<Vec<u8>, TokenId> {
        token_bytes
            .iter()
            .enumerate()
            .filter(|(_, bytes)| bytes.len() <= MAX_CACHED_TOKEN_LEN)
            .map(|(token_id, bytes)| (bytes.clone(), token_id as TokenId))
            .collect()
    }

    /// Returns the vocabulary size recorded at construction.
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Returns the number of base (pre-merge) tokens recorded at construction.
    pub fn num_base_tokens(&self) -> usize {
        self.num_base_tokens
    }

    /// Returns the merge table: `pack_pair(left, right)` -> `(merged id, rank)`.
    pub fn pair_lookup(&self) -> &FoldHashMap<u64, (TokenId, u32)> {
        &self.pair_lookup
    }

    /// Returns `true` if `(token1, token2)` has an entry in the merge table,
    /// i.e. [`BytePairEncoder::encode`] may merge these two adjacent tokens.
    #[inline]
    pub fn is_valid_pair(&self, token1: TokenId, token2: TokenId) -> bool {
        self.pair_lookup.contains_key(&pack_pair(token1, token2))
    }

    /// Reassembles an encoder from precomputed parts.
    ///
    /// `merges` lists `(left, right, merged_id)` in rank order: the entry at
    /// index `i` gets rank `i`. If a `(left, right)` pair is listed more than
    /// once, the later entry replaces the earlier one. The other arguments are
    /// stored as given.
    pub fn from_parts(
        merges: &[(TokenId, TokenId, TokenId)],
        byte_lut: [TokenId; 256],
        token_cache: FoldHashMap<Vec<u8>, TokenId>,
        vocab_size: usize,
        num_base_tokens: usize,
    ) -> Self {
        let pair_lookup: FoldHashMap<u64, (TokenId, u32)> = merges
            .iter()
            .enumerate()
            .map(|(rank, &(left, right, merged_id))| {
                (pack_pair(left, right), (merged_id, rank as u32))
            })
            .collect();
        Self {
            pair_lookup,
            byte_lut,
            token_cache,
            vocab_size,
            num_base_tokens,
        }
    }

    /// Encodes `text` into token ids.
    ///
    /// Empty input yields no tokens, and a single byte maps straight through
    /// the byte lookup table. Input of at most `MAX_CACHED_TOKEN_LEN` bytes
    /// that exactly matches a cached token is returned as that one token.
    /// Otherwise every byte is mapped to its token, then the lowest-rank
    /// adjacent pair (leftmost on ties) is merged repeatedly until no adjacent
    /// pair appears in the merge table.
    #[inline]
    pub fn encode(&self, text: &[u8]) -> Vec<TokenId> {
        match text {
            [] => return Vec::new(),
            [byte] => return vec![self.byte_lut[*byte as usize]],
            _ => {}
        }
        if text.len() <= MAX_CACHED_TOKEN_LEN {
            if let Some(&token_id) = self.token_cache.get(text) {
                return vec![token_id];
            }
        }

        let mut tokens: SmallVec<[TokenId; 16]> = text
            .iter()
            .map(|&b| self.byte_lut[b as usize])
            .collect();
        // Live tokens are `tokens[..len]`; merging shifts the tail left in
        // place instead of reallocating.
        let mut len = tokens.len();
        while len > 1 {
            // `min_by_key` returns the first minimum, so ties on rank go to
            // the leftmost position.
            let best = tokens[..len]
                .windows(2)
                .enumerate()
                .filter_map(|(pos, pair)| {
                    self.pair_lookup
                        .get(&pack_pair(pair[0], pair[1]))
                        .map(|&(merged, rank)| (rank, pos, merged))
                })
                .min_by_key(|&(rank, _, _)| rank);
            let Some((_, pos, merged)) = best else {
                break;
            };
            tokens[pos] = merged;
            tokens.copy_within(pos + 2..len, pos + 1);
            len -= 1;
        }
        tokens.truncate(len);
        tokens.into_vec()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoder::VocabDecoder;

    /// One single-byte base token per byte of `alphabet`, in order, so
    /// `alphabet[i]` becomes token id `i`.
    fn single_byte_tokens(alphabet: &[u8]) -> Vec<Vec<u8>> {
        alphabet.iter().map(|&byte| vec![byte]).collect()
    }

    #[test]
    fn test_encode_basic() {
        // Merge 0: a+b -> 3; merge 1: ab+c -> 4.
        let (encoder, token_bytes) =
            BytePairEncoder::from_merges(&[(0, 1), (3, 2)], &single_byte_tokens(b"abc"));
        let decoder = VocabDecoder::new(token_bytes);

        let ids = encoder.encode(b"abc");
        assert_eq!(ids, vec![4]);
        assert_eq!(decoder.decode(&ids), b"abc");
    }

    #[test]
    fn test_single_byte_fast_path() {
        let (encoder, _) = BytePairEncoder::from_merges(&[(0, 1)], &single_byte_tokens(b"abc"));
        for (expected, byte) in (0u32..).zip(b"abc") {
            assert_eq!(encoder.encode(&[*byte]), vec![expected]);
        }
    }

    #[test]
    fn test_early_exit_multi_byte_token() {
        let (encoder, _) =
            BytePairEncoder::from_merges(&[(0, 1), (3, 2)], &single_byte_tokens(b"abc"));
        assert_eq!(encoder.encode(b"ab"), vec![3]);
        assert_eq!(encoder.encode(b"abc"), vec![4]);
    }

    #[test]
    fn test_encode_roundtrip() {
        // Merges: ab -> 4, cd -> 5, abcd -> 6.
        let (encoder, token_bytes) = BytePairEncoder::from_merges(
            &[(0, 1), (2, 3), (4, 5)],
            &single_byte_tokens(b"abcd"),
        );
        let decoder = VocabDecoder::new(token_bytes);
        let samples: [&[u8]; 6] = [b"abcd", b"ab", b"cd", b"abcdabcd", b"a", b""];
        for sample in samples {
            let ids = encoder.encode(sample);
            assert_eq!(decoder.decode(&ids), sample);
        }
    }
}