use foldhash::HashMap as FoldHashMap;
use chunk::chunk;
use crate::types::TokenId;
/// Sentinel index meaning "no neighbouring symbol" in the symbol linked list.
const NONE: u32 = !0;
/// UTF-8 encoding of U+2581 (`▁`), the SentencePiece metaspace; used as the chunk split pattern.
const METASPACE: [u8; 3] = *b"\xE2\x96\x81";
/// Packs an ordered token pair into a single `u64` lookup key.
///
/// `left` occupies the high 32 bits and `right` the low 32 bits, so `(a, b)` and `(b, a)`
/// produce different keys.
#[inline(always)]
fn pack_pair(left: TokenId, right: TokenId) -> u64 {
    let high = u64::from(left) << 32;
    high | u64::from(right)
}
/// Returns how many bytes to examine for the character that starts with byte `b`.
///
/// * `0x00..=0x7F` (ASCII) → 1
/// * `0xC0..=0xDF` (2-byte lead) → 2
/// * `0xE0..=0xEF` (3-byte lead) → 3
/// * `0xF0..=0xF7` (4-byte lead) → 4
///
/// Bytes that can never start a UTF-8 character map to 1. These are stray continuation bytes
/// (`0x80..=0xBF`) and `0xF8..=0xFF`. Treating them as a one-byte unit keeps a malformed byte
/// from being grouped with the bytes that follow it.
#[inline]
fn utf8_char_len(b: u8) -> usize {
    match b {
        0x00..=0x7F => 1,
        // Continuation byte in lead position: not the start of a character.
        0x80..=0xBF => 1,
        0xC0..=0xDF => 2,
        0xE0..=0xEF => 3,
        0xF0..=0xF7 => 4,
        // Never valid in UTF-8.
        0xF8..=0xFF => 1,
    }
}
/// One node of the index-linked list of symbols that BPE merging operates on.
///
/// Symbols are stored contiguously and refer to their neighbours by index, with `NONE` marking
/// a missing neighbour. A merge keeps the left symbol and retires the right one by setting its
/// `len` to 0.
#[derive(Clone, Copy, Debug)]
struct Symbol {
    /// Current token id of this symbol.
    token: TokenId,
    /// Index of the previous live symbol, or `NONE`.
    prev: u32,
    /// Index of the next live symbol, or `NONE`.
    next: u32,
    /// Length of the token used for merge bookkeeping (its byte length when the encoder is
    /// built from a vocabulary); 0 marks a symbol that has been merged away.
    len: u16,
}
/// A candidate merge of two adjacent symbols, as stored in the merge priority queue.
///
/// `key` packs the merge rank into the high 32 bits and the left symbol index into the low
/// 32 bits. Comparing keys therefore orders entries by rank first and by leftmost position
/// second.
#[derive(Clone, Copy)]
struct HeapEntry {
    /// `(rank << 32) | left`.
    key: u64,
    /// Index of the right symbol of the pair.
    right: u32,
    /// Combined length of both symbols at the time the entry was created.
    size: u32,
}

impl HeapEntry {
    /// Builds an entry for the pair `(left, right)` merging at `rank`.
    #[inline(always)]
    fn new(rank: u32, left: u32, right: u32, size: u32) -> Self {
        let key = (u64::from(rank) << 32) | u64::from(left);
        Self { key, right, size }
    }

    /// Index of the left symbol (low 32 bits of `key`).
    #[inline(always)]
    fn left(&self) -> u32 {
        (self.key & u64::from(u32::MAX)) as u32
    }

    /// Merge rank (high 32 bits of `key`).
    #[cfg(test)]
    #[inline(always)]
    fn rank(&self) -> u32 {
        (self.key >> 32) as u32
    }
}
/// Monotone radix heap of [`HeapEntry`] values, popped in ascending `key` order.
///
/// Keys at or above `last_min` live in `buckets`. `last_min` is the key of the most recent
/// minimum taken from a non-zero bucket. Bucket 0 holds keys equal to `last_min`, and bucket
/// `i > 0` holds keys whose highest bit differing from `last_min` is bit `i - 1`. A key pushed
/// below `last_min` cannot be bucketed and goes to the unsorted `overflow` list instead.
struct RadixHeap {
    /// One bucket for `key == last_min` plus one per bit position of a `u64` key.
    buckets: [Vec<HeapEntry>; 65],
    /// Reference key for bucket placement; only increases, except when reset by `clear`.
    last_min: u64,
    /// Total number of entries across `buckets` and `overflow`.
    len: usize,
    /// Entries whose key was below `last_min` when they were pushed.
    overflow: Vec<HeapEntry>,
}

impl RadixHeap {
    /// Creates an empty heap.
    fn new() -> Self {
        Self {
            buckets: std::array::from_fn(|_| Vec::new()),
            last_min: 0,
            len: 0,
            overflow: Vec::new(),
        }
    }

    /// Bucket that `key` belongs in relative to the current `last_min`.
    ///
    /// Only meaningful for `key >= last_min`.
    #[inline]
    fn bucket_index(&self, key: u64) -> usize {
        if key == self.last_min {
            0
        } else {
            let diff = key ^ self.last_min;
            (64 - diff.leading_zeros()) as usize
        }
    }

    /// Inserts `entry`.
    #[inline]
    fn push(&mut self, entry: HeapEntry) {
        if entry.key < self.last_min {
            self.overflow.push(entry);
        } else {
            let idx = self.bucket_index(entry.key);
            self.buckets[idx].push(entry);
        }
        self.len += 1;
    }

    /// Position of the smallest key in `entries`; the first such position on ties.
    #[inline]
    fn min_position(entries: &[HeapEntry]) -> Option<usize> {
        entries
            .iter()
            .enumerate()
            .min_by_key(|(_, entry)| entry.key)
            .map(|(i, _)| i)
    }

    /// Removes and returns an entry with the smallest key, or `None` if the heap is empty.
    ///
    /// Among entries with equal keys, which one is returned first is unspecified.
    fn pop(&mut self) -> Option<HeapEntry> {
        if self.len == 0 {
            return None;
        }

        // Overflow keys were below `last_min` when pushed, and `last_min` never decreases.
        // Every bucketed key is >= `last_min`. So a non-empty overflow always holds the global
        // minimum, and the buckets do not need to be inspected.
        if let Some(pos) = Self::min_position(&self.overflow) {
            self.len -= 1;
            return Some(self.overflow.swap_remove(pos));
        }

        let bucket_idx = self.buckets.iter().position(|bucket| !bucket.is_empty())?;
        if bucket_idx == 0 {
            // Every key in bucket 0 equals `last_min`, so any of them is a minimum.
            self.len -= 1;
            return self.buckets[0].pop();
        }

        // The lowest non-empty bucket holds the minimum. Make it the new `last_min` and
        // redistribute the rest of that bucket. Each remaining entry lands in a strictly
        // lower bucket.
        let bucket = &mut self.buckets[bucket_idx];
        let pos = Self::min_position(bucket.as_slice())?;
        let min_entry = bucket.swap_remove(pos);
        let rest = std::mem::take(bucket);
        self.last_min = min_entry.key;
        for entry in rest {
            let idx = self.bucket_index(entry.key);
            self.buckets[idx].push(entry);
        }
        self.len -= 1;
        Some(min_entry)
    }

    /// Removes every entry and resets `last_min`, keeping the bucket allocations.
    fn clear(&mut self) {
        for bucket in &mut self.buckets {
            bucket.clear();
        }
        self.last_min = 0;
        self.len = 0;
        self.overflow.clear();
    }
}
/// Reusable scratch buffers for [`SentencePieceBPE::encode_with_state`].
///
/// Keeping one of these alive across calls lets repeated encodes reuse the symbol list, the
/// merge heap, and the output buffer instead of reallocating them.
pub struct EncodeState {
    symbols: Vec<Symbol>,
    heap: RadixHeap,
    result: Vec<TokenId>,
}

impl EncodeState {
    /// Creates an empty state with no preallocated capacity.
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates a state sized for inputs of about `text_len` bytes.
    ///
    /// Reserves one symbol slot per input byte and one output slot per four bytes.
    pub fn with_capacity(text_len: usize) -> Self {
        let result = Vec::with_capacity(text_len / 4);
        let symbols = Vec::with_capacity(text_len);
        Self {
            symbols,
            heap: RadixHeap::new(),
            result,
        }
    }

    /// Empties all buffers in place.
    fn clear(&mut self) {
        let Self {
            symbols,
            heap,
            result,
        } = self;
        symbols.clear();
        heap.clear();
        result.clear();
    }
}

impl Default for EncodeState {
    fn default() -> Self {
        Self::new()
    }
}
/// SentencePiece-style BPE encoder.
///
/// Merges are kept in a table keyed by the packed `(left, right)` token pair. Lower ranks are
/// applied first during encoding.
#[derive(Clone)]
pub struct SentencePieceBPE {
    /// `pack_pair(left, right)` → `(merged token id, merge rank)`.
    pair_lookup: FoldHashMap<u64, (TokenId, u32)>,
    /// Largest merge rank seen while building `pair_lookup`.
    max_rank: u32,
    num_base_tokens: usize,
    vocab_size: usize,
    /// Byte strings that map directly to a token id.
    token_cache: FoldHashMap<Vec<u8>, TokenId>,
    /// Token id for each single byte value, or `u32::MAX` if there is none.
    byte_lut: [TokenId; 256],
    /// Length of each token, indexed by token id.
    token_lengths: Vec<u16>,
}

impl std::fmt::Debug for SentencePieceBPE {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let merge_count = self.pair_lookup.len();
        let mut out = f.debug_struct("SentencePieceBPE");
        out.field("vocab_size", &self.vocab_size);
        out.field("num_base_tokens", &self.num_base_tokens);
        out.field("merges", &merge_count);
        out.field("max_rank", &self.max_rank);
        out.finish()
    }
}
impl SentencePieceBPE {
pub fn from_vocab_and_merges(
vocab: &[(u32, Vec<u8>)],
merges: &[(TokenId, TokenId)],
num_base_tokens: usize,
byte_fallback_ids: &foldhash::HashSet<u32>,
) -> (Self, Vec<Vec<u8>>) {
let token_bytes: Vec<Vec<u8>> = vocab.iter().map(|(_, bytes)| bytes.clone()).collect();
let bytes_to_id: FoldHashMap<Vec<u8>, TokenId> = vocab
.iter()
.map(|(id, bytes)| (bytes.clone(), *id))
.collect();
let mut pair_lookup: FoldHashMap<u64, (TokenId, u32)> = FoldHashMap::default();
let mut max_rank = 0u32;
for (merge_rank, &(left, right)) in merges.iter().enumerate() {
let mut merged_bytes = token_bytes[left as usize].clone();
merged_bytes.extend_from_slice(&token_bytes[right as usize]);
if let Some(&merged_id) = bytes_to_id.get(&merged_bytes) {
pair_lookup
.entry(pack_pair(left, right))
.or_insert((merged_id, merge_rank as u32));
max_rank = max_rank.max(merge_rank as u32);
}
}
let mut byte_lut = [u32::MAX; 256];
for (id, bytes) in vocab {
if bytes.len() == 1 {
let byte_val = bytes[0] as usize;
let is_fallback = byte_fallback_ids.contains(id);
if byte_lut[byte_val] == u32::MAX
|| (!is_fallback && byte_fallback_ids.contains(&byte_lut[byte_val]))
{
byte_lut[byte_val] = *id;
}
}
}
let token_lengths: Vec<u16> = token_bytes.iter().map(|b| b.len() as u16).collect();
let mut token_cache: FoldHashMap<Vec<u8>, TokenId> = vocab
.iter()
.filter(|(_, bytes)| bytes.len() > 1)
.map(|(id, bytes)| (bytes.clone(), *id))
.collect();
for (byte_val, &token_id) in byte_lut.iter().enumerate() {
if token_id != u32::MAX {
token_cache.insert(vec![byte_val as u8], token_id);
}
}
let encoder = Self {
pair_lookup,
max_rank,
num_base_tokens,
vocab_size: vocab.len(),
token_cache,
byte_lut,
token_lengths,
};
(encoder, token_bytes)
}
#[inline]
pub fn vocab_size(&self) -> usize {
self.vocab_size
}
#[inline]
pub fn num_base_tokens(&self) -> usize {
self.num_base_tokens
}
#[inline]
pub fn pair_lookup(&self) -> &FoldHashMap<u64, (TokenId, u32)> {
&self.pair_lookup
}
pub fn from_parts(
merges: &[(TokenId, TokenId, TokenId)], byte_lut: [TokenId; 256],
token_cache: FoldHashMap<Vec<u8>, TokenId>,
token_lengths: Vec<u16>,
vocab_size: usize,
num_base_tokens: usize,
) -> Self {
let mut pair_lookup: FoldHashMap<u64, (TokenId, u32)> = FoldHashMap::default();
let mut max_rank = 0u32;
for (merge_rank, &(left, right, merged_id)) in merges.iter().enumerate() {
pair_lookup
.entry(pack_pair(left, right))
.or_insert((merged_id, merge_rank as u32));
max_rank = max_rank.max(merge_rank as u32);
}
Self {
pair_lookup,
max_rank,
num_base_tokens,
vocab_size,
token_cache,
byte_lut,
token_lengths,
}
}
#[inline]
pub fn is_valid_pair(&self, _token1: TokenId, _token2: TokenId) -> bool {
true
}
#[inline]
fn get_merge(&self, left: TokenId, right: TokenId) -> Option<(TokenId, u32)> {
self.pair_lookup.get(&pack_pair(left, right)).copied()
}
pub fn encode(&self, text: &[u8]) -> Vec<TokenId> {
if text.is_empty() {
return Vec::new();
}
if let Some(&token_id) = self.token_cache.get(text) {
return vec![token_id];
}
let mut symbols = self.init_symbols(text);
if symbols.is_empty() {
return Vec::new();
}
let mut heap = RadixHeap::new();
self.init_heap(&symbols, &mut heap);
self.merge_loop(&mut symbols, &mut heap);
self.collect_results(&symbols)
}
pub fn encode_with_state<'a>(&self, text: &[u8], state: &'a mut EncodeState) -> &'a [TokenId] {
state.clear();
if text.is_empty() {
return &state.result;
}
if let Some(&token_id) = self.token_cache.get(text) {
state.result.push(token_id);
return &state.result;
}
self.init_symbols_into(text, &mut state.symbols);
if state.symbols.is_empty() {
return &state.result;
}
self.init_heap(&state.symbols, &mut state.heap);
self.merge_loop(&mut state.symbols, &mut state.heap);
self.collect_results_into(&state.symbols, &mut state.result);
&state.result
}
pub fn encode_chunked(
&self,
text: &[u8],
state: &mut EncodeState,
chunk_size: usize,
) -> Vec<TokenId> {
if text.len() <= chunk_size {
return self.encode_with_state(text, state).to_vec();
}
let mut result = Vec::with_capacity(text.len() / 4);
for chunk_bytes in chunk(text)
.size(chunk_size)
.pattern(&METASPACE)
.prefix()
.consecutive()
.forward_fallback()
{
let chunk_tokens = self.encode_with_state(chunk_bytes, state);
result.extend_from_slice(chunk_tokens);
}
result
}
fn init_symbols(&self, text: &[u8]) -> Vec<Symbol> {
let mut symbols = Vec::with_capacity(text.len());
self.init_symbols_into(text, &mut symbols);
symbols
}
fn init_symbols_into(&self, text: &[u8], symbols: &mut Vec<Symbol>) {
let mut pos = 0;
while pos < text.len() {
let char_len = utf8_char_len(text[pos]);
let end = (pos + char_len).min(text.len());
let char_bytes = &text[pos..end];
let (token, len) = if let Some(&token_id) = self.token_cache.get(char_bytes) {
(token_id, char_bytes.len())
} else {
let byte_token = self.byte_lut[text[pos] as usize];
(byte_token, 1)
};
if token != u32::MAX {
let idx = symbols.len() as u32;
symbols.push(Symbol {
token,
prev: if idx == 0 { NONE } else { idx - 1 },
next: NONE,
len: self.token_lengths.get(token as usize).copied().unwrap_or(len as u16),
});
if idx > 0 {
symbols[(idx - 1) as usize].next = idx;
}
}
pos += len;
}
}
fn init_heap(&self, symbols: &[Symbol], heap: &mut RadixHeap) {
for i in 0..symbols.len().saturating_sub(1) {
let left_sym = &symbols[i];
let right_sym = &symbols[i + 1];
if let Some((_, rank)) = self.get_merge(left_sym.token, right_sym.token) {
heap.push(HeapEntry::new(
rank,
i as u32,
(i + 1) as u32,
left_sym.len as u32 + right_sym.len as u32,
));
}
}
}
fn merge_loop(&self, symbols: &mut [Symbol], heap: &mut RadixHeap) {
while let Some(entry) = heap.pop() {
let left_idx = entry.left() as usize;
let right_idx = entry.right as usize;
let left = &symbols[left_idx];
let right = &symbols[right_idx];
if left.len == 0 || right.len == 0 {
continue; }
if left.next != entry.right {
continue; }
if (left.len as u32 + right.len as u32) != entry.size {
continue; }
let (merged_token, _) = self.get_merge(left.token, right.token).unwrap();
let new_len = left.len + right.len;
let right_next = right.next;
symbols[left_idx].token = merged_token;
symbols[left_idx].len = new_len;
symbols[left_idx].next = right_next;
symbols[right_idx].len = 0;
if right_next != NONE {
symbols[right_next as usize].prev = entry.left();
}
let left_prev = symbols[left_idx].prev;
if left_prev != NONE {
let prev = &symbols[left_prev as usize];
if prev.len > 0 {
if let Some((_, rank)) = self.get_merge(prev.token, merged_token) {
heap.push(HeapEntry::new(
rank,
left_prev,
entry.left(),
prev.len as u32 + new_len as u32,
));
}
}
}
if right_next != NONE {
let next = &symbols[right_next as usize];
if next.len > 0 {
if let Some((_, rank)) = self.get_merge(merged_token, next.token) {
heap.push(HeapEntry::new(
rank,
entry.left(),
right_next,
new_len as u32 + next.len as u32,
));
}
}
}
}
}
fn collect_results(&self, symbols: &[Symbol]) -> Vec<TokenId> {
let mut result = Vec::new();
self.collect_results_into(symbols, &mut result);
result
}
fn collect_results_into(&self, symbols: &[Symbol], result: &mut Vec<TokenId>) {
let mut idx = 0u32;
while idx != NONE && (idx as usize) < symbols.len() {
let sym = &symbols[idx as usize];
if sym.len > 0 {
result.push(sym.token);
}
idx = sym.next;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_merge() {
        let vocab = vec![(0, b"a".to_vec()), (1, b"b".to_vec()), (2, b"ab".to_vec())];
        let merges = vec![(0, 1)];
        let (encoder, _) =
            SentencePieceBPE::from_vocab_and_merges(&vocab, &merges, 2, &Default::default());
        assert_eq!(encoder.encode(b"ab"), vec![2]);
        assert_eq!(encoder.encode(b"a"), vec![0]);
        assert_eq!(encoder.encode(b"ba"), vec![1, 0]);
        // Inputs that are not a single vocab entry go through the merge loop.
        assert_eq!(encoder.encode(b"aab"), vec![0, 2]);
        assert_eq!(encoder.encode(b"abab"), vec![2, 2]);
    }

    #[test]
    fn test_merge_rank_priority() {
        let vocab = vec![
            (0, b"a".to_vec()),
            (1, b"b".to_vec()),
            (2, b"c".to_vec()),
            (3, b"ab".to_vec()),
            (4, b"bc".to_vec()),
            (5, b"abc".to_vec()),
        ];
        let merges = vec![
            (0, 1), // a + b -> ab   (rank 0)
            (3, 2), // ab + c -> abc (rank 1)
            (1, 2), // b + c -> bc   (rank 2)
        ];
        let (encoder, _) =
            SentencePieceBPE::from_vocab_and_merges(&vocab, &merges, 3, &Default::default());
        assert_eq!(encoder.encode(b"abc"), vec![5]);
        // "abca" is not a vocab entry, so this exercises rank ordering: a+b must win over b+c.
        assert_eq!(encoder.encode(b"abca"), vec![5, 0]);
    }

    #[test]
    fn test_equal_rank_prefers_leftmost() {
        let vocab = vec![(0, b"a".to_vec()), (1, b"aa".to_vec())];
        let merges = vec![(0, 0)];
        let (encoder, _) =
            SentencePieceBPE::from_vocab_and_merges(&vocab, &merges, 1, &Default::default());
        assert_eq!(encoder.encode(b"aaa"), vec![1, 0]);
        assert_eq!(encoder.encode(b"aaaa"), vec![1, 1]);
    }

    #[test]
    fn test_unicode_char() {
        let vocab = vec![
            (0, "▁".as_bytes().to_vec()),
            (1, "H".as_bytes().to_vec()),
            (2, "▁H".as_bytes().to_vec()),
        ];
        let merges = vec![(0, 1)];
        let (encoder, _) =
            SentencePieceBPE::from_vocab_and_merges(&vocab, &merges, 2, &Default::default());
        assert_eq!(encoder.encode("▁H".as_bytes()), vec![2]);
        assert_eq!(encoder.encode("▁".as_bytes()), vec![0]);
        assert_eq!(encoder.encode("▁H▁H".as_bytes()), vec![2, 2]);
    }

    #[test]
    fn test_empty_input() {
        let vocab = vec![(0, b"a".to_vec()), (1, b"b".to_vec()), (2, b"ab".to_vec())];
        let merges = vec![(0, 1)];
        let (encoder, _) =
            SentencePieceBPE::from_vocab_and_merges(&vocab, &merges, 2, &Default::default());
        assert!(encoder.encode(b"").is_empty());
        let mut state = EncodeState::new();
        assert!(encoder.encode_with_state(b"", &mut state).is_empty());
        assert!(encoder.encode_chunked(b"", &mut state, 4).is_empty());
    }

    #[test]
    fn test_radix_heap() {
        let mut heap = RadixHeap::new();
        heap.push(HeapEntry::new(5, 0, 1, 2));
        heap.push(HeapEntry::new(3, 1, 2, 2));
        heap.push(HeapEntry::new(7, 2, 3, 2));
        heap.push(HeapEntry::new(1, 3, 4, 2));
        assert_eq!(heap.pop().unwrap().rank(), 1);
        assert_eq!(heap.pop().unwrap().rank(), 3);
        assert_eq!(heap.pop().unwrap().rank(), 5);
        assert_eq!(heap.pop().unwrap().rank(), 7);
        assert!(heap.pop().is_none());
    }

    #[test]
    fn test_radix_heap_tie_breaking() {
        let mut heap = RadixHeap::new();
        heap.push(HeapEntry::new(5, 3, 4, 2));
        heap.push(HeapEntry::new(5, 1, 2, 2));
        heap.push(HeapEntry::new(5, 2, 3, 2));
        assert_eq!(heap.pop().unwrap().left(), 1);
        assert_eq!(heap.pop().unwrap().left(), 2);
        assert_eq!(heap.pop().unwrap().left(), 3);
        assert!(heap.pop().is_none());
    }

    #[test]
    fn test_radix_heap_overflow() {
        let mut heap = RadixHeap::new();
        heap.push(HeapEntry::new(4, 0, 1, 2));
        heap.push(HeapEntry::new(6, 1, 2, 2));
        assert_eq!(heap.pop().unwrap().rank(), 4);
        // Keys below the last popped minimum go through the overflow list.
        heap.push(HeapEntry::new(2, 5, 6, 2));
        heap.push(HeapEntry::new(1, 7, 8, 2));
        assert_eq!(heap.pop().unwrap().rank(), 1);
        assert_eq!(heap.pop().unwrap().rank(), 2);
        assert_eq!(heap.pop().unwrap().rank(), 6);
        assert!(heap.pop().is_none());
    }

    #[test]
    fn test_encode_with_state() {
        let vocab = vec![(0, b"a".to_vec()), (1, b"b".to_vec()), (2, b"ab".to_vec())];
        let merges = vec![(0, 1)];
        let (encoder, _) =
            SentencePieceBPE::from_vocab_and_merges(&vocab, &merges, 2, &Default::default());
        let mut state = EncodeState::new();
        assert_eq!(encoder.encode_with_state(b"ab", &mut state), &[2]);
        assert_eq!(encoder.encode_with_state(b"a", &mut state), &[0]);
        assert_eq!(encoder.encode_with_state(b"ba", &mut state), &[1, 0]);
        for _ in 0..5 {
            assert_eq!(encoder.encode_with_state(b"ab", &mut state), &[2]);
        }
        assert_eq!(encoder.encode_with_state(b"abab", &mut state), &[2, 2]);
    }

    #[test]
    fn test_encode_chunked() {
        let vocab = vec![
            (0, "▁".as_bytes().to_vec()),
            (1, "a".as_bytes().to_vec()),
            (2, "b".as_bytes().to_vec()),
            (3, "c".as_bytes().to_vec()),
            (4, "▁a".as_bytes().to_vec()),
            (5, "▁ab".as_bytes().to_vec()),
        ];
        let merges = vec![
            (0, 1), // ▁ + a -> ▁a
            (4, 2), // ▁a + b -> ▁ab
        ];
        let (encoder, _) =
            SentencePieceBPE::from_vocab_and_merges(&vocab, &merges, 4, &Default::default());
        let mut state = EncodeState::new();
        let text = "▁ab▁ab▁ab".as_bytes();
        let regular = encoder.encode(text);
        let chunked = encoder.encode_chunked(text, &mut state, 6);
        assert_eq!(regular, chunked);
    }
}