use daggrs::{DoubleArrayAhoCorasick, MatchKind, Trie};
use foldhash::HashMap as FoldHashMap;
use chunk::chunk;
use smallvec::SmallVec;
use std::collections::VecDeque;
use std::thread;
use crate::types::{Split, TokenId};
/// Input length in bytes below which `encode` always takes the single-threaded path;
/// inputs at or above this size are first split into chunks.
const PARALLEL_THRESHOLD: usize = 10_000;
/// Longest token, in bytes, that is stored in the whole-text `token_cache`.
/// Inputs longer than this skip the cache lookup entirely.
const MAX_CACHED_TOKEN_LEN: usize = 16;
/// Number of tokens `EncodeIter` tries to hold in its buffer before yielding the oldest one.
const ENCODE_ITER_BUFFER_SIZE: usize = 8;
/// Packs an ordered token pair into one `u64` map key.
///
/// `left` occupies the upper 32 bits and `right` the lower 32 bits, so
/// `(a, b)` and `(b, a)` produce different keys.
#[inline(always)]
fn pack_pair(left: TokenId, right: TokenId) -> u64 {
    let high = (left as u64) << 32;
    let low = right as u64;
    high | low
}
/// Splits `text` into roughly one piece per available CPU.
///
/// The target piece size is `text.len()` divided by the reported parallelism
/// (1 when the parallelism cannot be queried). Splitting is delegated to the
/// `chunk` crate, configured with space and newline as delimiters.
/// NOTE(review): `.prefix()` presumably keeps each delimiter at the start of
/// the following piece — confirm against the `chunk` crate documentation.
#[inline]
fn split_at_boundaries(text: &[u8]) -> Vec<&[u8]> {
    let workers = match thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    };
    let piece_len = text.len() / workers;
    chunk(text)
        .size(piece_len)
        .delimiters(b" \n")
        .prefix()
        .collect()
}
/// Streaming encoder state: yields the tokens of `text` one at a time while
/// keeping a small window of recent tokens that can still be backtracked.
pub struct EncodeIter<'a> {
// Encoder that supplies matching, prefix and pair-validity lookups.
encoder: &'a BacktrackingBytePairEncoder,
// Full input being encoded.
text: &'a [u8],
// Byte offset in `text` just past the last buffered token.
pos: usize,
// Tokens accepted but not yet yielded; the back is the most recent one.
buffer: VecDeque<TokenId>,
// One bit per byte offset (0..=text.len()); a cleared bit marks an offset
// from which encoding has already failed.
bitfield: Bitfield,
// Candidate token to try at `pos`, or `None` when there is nothing left to try.
next_token: Option<TokenId>,
// Set once no further tokens can be produced; the buffer is then drained.
done: bool,
}
impl<'a> EncodeIter<'a> {
    /// Creates an iterator over the tokens of `text`.
    ///
    /// For empty input the iterator starts out finished; otherwise the first
    /// candidate token is looked up immediately.
    pub(crate) fn new(encoder: &'a BacktrackingBytePairEncoder, text: &'a [u8]) -> Self {
        let done = text.is_empty();
        let next_token = if done { None } else { encoder.next_match(text) };
        Self {
            encoder,
            text,
            pos: 0,
            buffer: VecDeque::with_capacity(ENCODE_ITER_BUFFER_SIZE + 1),
            // One bit per offset, including the end-of-text offset.
            bitfield: Bitfield::new(text.len() + 1),
            next_token,
            done,
        }
    }

    /// Performs one encoding step at the current position.
    ///
    /// Returns `true` when a token was appended to the buffer. Returns `false`
    /// when there is no candidate, or when every prefix of the candidate was
    /// rejected and the step backtracked instead.
    ///
    /// NOTE(review): backtracking can only undo tokens still in `buffer`. If
    /// the buffer is empty when a backtrack is needed, `next_token` becomes
    /// `None` and iteration ends, even if tokens were already yielded.
    fn encode_one_token(&mut self) -> bool {
        let mut candidate = match self.next_token {
            Some(token) => token,
            None => return false,
        };
        let previous = self.buffer.back().copied();

        loop {
            let end = self.pos + self.encoder.token_len(candidate);
            let reachable = self.bitfield.is_set(end);
            let compatible = match previous {
                Some(prev) => self.encoder.is_valid_pair(prev, candidate),
                None => true,
            };

            if reachable && compatible {
                self.buffer.push_back(candidate);
                self.pos = end;
                self.next_token = self.encoder.next_match(&self.text[end..]);
                return true;
            }

            if let Some(shorter) = self.encoder.next_prefix(candidate) {
                candidate = shorter;
                continue;
            }

            // Nothing fits here: mark this offset as a dead end and undo the
            // most recent buffered token so a shorter alternative is tried.
            self.bitfield.clear(self.pos);
            self.next_token = self.buffer.pop_back();
            if let Some(undone) = self.next_token {
                self.pos -= self.encoder.token_len(undone);
            }
            return false;
        }
    }
}
impl Iterator for EncodeIter<'_> {
    type Item = TokenId;

    /// Yields the oldest buffered token.
    ///
    /// Until encoding is finished, the buffer is first topped up to
    /// `ENCODE_ITER_BUFFER_SIZE` tokens so that recent tokens stay available
    /// for backtracking. Once finished, the remaining buffer is drained.
    fn next(&mut self) -> Option<TokenId> {
        while !self.done && self.buffer.len() < ENCODE_ITER_BUFFER_SIZE {
            let produced = self.encode_one_token();
            // A failed step with no candidate left means nothing more can be encoded.
            if !produced && self.next_token.is_none() {
                self.done = true;
            }
        }
        self.buffer.pop_front()
    }
}
// Once `done` is set and the buffer is empty, `next` keeps returning `None`.
impl std::iter::FusedIterator for EncodeIter<'_> {}
/// Byte-pair encoder that picks tokens greedily and backtracks when a choice
/// turns out to be incompatible with its neighbour.
#[derive(Clone)]
pub struct BacktrackingBytePairEncoder {
// Per-token split record: either a base token or the (left, right) pair it was merged from.
split_table: Vec<Split>,
// Maps a packed (left, right) pair (see `pack_pair`) to the token it merges into.
pair_lookup: FoldHashMap<u64, TokenId>,
// Byte length of each token; the constructors in this file saturate it at 255.
token_lengths: Vec<u8>,
// Number of base tokens as supplied to the constructor.
num_base_tokens: usize,
// Aho-Corasick automaton over all token byte strings.
matcher: DoubleArrayAhoCorasick,
// For each token, the token matched on its bytes minus the last byte,
// or `u32::MAX` when there is none.
next_prefix_match: Vec<TokenId>,
// Exact bytes -> token id, for tokens of at most `MAX_CACHED_TOKEN_LEN` bytes.
token_cache: FoldHashMap<Vec<u8>, TokenId>,
}
impl BacktrackingBytePairEncoder {
/// Builds an encoder from base tokens and an ordered merge list, with no added tokens.
///
/// Delegates to `from_merges_with_added` with an empty added-token list and
/// returns the encoder together with the byte string of every token.
pub fn from_merges(
merges: &[(TokenId, TokenId)],
base_tokens: &[Vec<u8>],
) -> (Self, Vec<Vec<u8>>) {
Self::from_merges_with_added(merges, base_tokens, &[])
}
/// Builds an encoder from an explicit vocabulary and a merge list.
///
/// A merge `(left, right)` is recorded only when the concatenated bytes of
/// the two tokens appear in `vocab`; the first merge producing a given token
/// defines that token's split. Tokens produced by no merge are treated as
/// base tokens. Returns the encoder and the byte string of every token.
///
/// NOTE(review): token bytes are indexed by position in `vocab`, so this
/// assumes entry `i` carries id `i` — confirm against callers.
///
/// # Panics
/// Panics if a merge references an index outside `vocab`.
pub fn from_vocab_and_merges(
    vocab: &[(u32, Vec<u8>)],
    merges: &[(TokenId, TokenId)],
    num_base_tokens: usize,
) -> (Self, Vec<Vec<u8>>) {
    let mut token_bytes: Vec<Vec<u8>> = Vec::with_capacity(vocab.len());
    let mut bytes_to_id: FoldHashMap<Vec<u8>, TokenId> = FoldHashMap::default();
    for (id, bytes) in vocab {
        token_bytes.push(bytes.clone());
        bytes_to_id.insert(bytes.clone(), *id);
    }

    let mut pair_lookup = FoldHashMap::default();
    let mut merge_creates: FoldHashMap<TokenId, (TokenId, TokenId)> = FoldHashMap::default();
    for &(left, right) in merges {
        let merged_bytes = [
            token_bytes[left as usize].as_slice(),
            token_bytes[right as usize].as_slice(),
        ]
        .concat();
        // Merges whose result is not in the vocabulary are ignored.
        let Some(&merged_id) = bytes_to_id.get(&merged_bytes) else {
            continue;
        };
        pair_lookup.insert(pack_pair(left, right), merged_id);
        merge_creates.entry(merged_id).or_insert((left, right));
    }

    let split_table: Vec<Split> = vocab
        .iter()
        .map(|(id, _)| {
            let id = *id as TokenId;
            match merge_creates.get(&id) {
                Some(&(left, right)) => Split::merge(left, right),
                None => Split::base(id),
            }
        })
        .collect();

    let (matcher, next_prefix_match) = Self::build_matcher_and_prefixes(&token_bytes);
    let token_lengths = Self::build_token_lengths(&token_bytes);

    let token_cache: FoldHashMap<Vec<u8>, TokenId> = token_bytes
        .iter()
        .enumerate()
        .filter(|(_, bytes)| bytes.len() <= MAX_CACHED_TOKEN_LEN)
        .map(|(token_id, bytes)| (bytes.clone(), token_id as TokenId))
        .collect();

    let encoder = Self {
        split_table,
        pair_lookup,
        token_lengths,
        num_base_tokens,
        matcher,
        next_prefix_match,
        token_cache,
    };
    (encoder, token_bytes)
}
/// Builds an encoder from base tokens, an ordered merge list, and extra
/// tokens that carry explicit ids.
///
/// Token ids are assigned by position: base tokens first, then one id per
/// merge in order. Before each merge, every added token whose id is less
/// than or equal to the next free slot is inserted (as a base-style entry),
/// so added tokens can be interleaved with merged ones. Added tokens left
/// over after the last merge are appended at the end in id order.
///
/// Returns the encoder and the byte string of every token, indexed by id.
///
/// # Panics
/// Panics if a merge references a token id that has not been created yet.
pub fn from_merges_with_added(
    merges: &[(TokenId, TokenId)],
    base_tokens: &[Vec<u8>],
    added_tokens: &[(u32, Vec<u8>)],
) -> (Self, Vec<Vec<u8>>) {
    let num_base_tokens = base_tokens.len();
    let mut split_table: Vec<Split> = (0..num_base_tokens as TokenId)
        .map(Split::base)
        .collect();
    let mut token_bytes: Vec<Vec<u8>> = base_tokens.to_vec();
    let mut pair_lookup = FoldHashMap::default();

    // Stable sort keeps the input order of added tokens that share an id.
    let mut added_sorted: Vec<_> = added_tokens.to_vec();
    added_sorted.sort_by_key(|(id, _)| *id);
    let mut added_iter = added_sorted.into_iter().peekable();

    for &(left, right) in merges.iter() {
        // Compare against the live table length on every pass: each inserted
        // added token moves the next free slot forward. A value captured once
        // before this loop would let the merge take the slot of a directly
        // following added token (e.g. added ids 3 and 4 before a merge).
        while let Some(&(added_id, _)) = added_iter.peek() {
            let next_id = split_table.len() as TokenId;
            if added_id > next_id {
                break;
            }
            if let Some((_, bytes)) = added_iter.next() {
                split_table.push(Split::base(next_id));
                token_bytes.push(bytes);
            }
        }

        let new_id = split_table.len() as TokenId;
        split_table.push(Split::merge(left, right));
        pair_lookup.insert(pack_pair(left, right), new_id);

        let mut bytes = token_bytes[left as usize].clone();
        bytes.extend_from_slice(&token_bytes[right as usize]);
        token_bytes.push(bytes);
    }

    // Added tokens that come after every merge.
    for (_, bytes) in added_iter {
        split_table.push(Split::base(split_table.len() as TokenId));
        token_bytes.push(bytes);
    }

    let (matcher, next_prefix_match) = Self::build_matcher_and_prefixes(&token_bytes);
    let token_lengths = Self::build_token_lengths(&token_bytes);

    let mut token_cache = FoldHashMap::default();
    for (token_id, bytes) in token_bytes.iter().enumerate() {
        if bytes.len() <= MAX_CACHED_TOKEN_LEN {
            token_cache.insert(bytes.clone(), token_id as TokenId);
        }
    }

    let encoder = Self {
        split_table,
        pair_lookup,
        token_lengths,
        num_base_tokens,
        matcher,
        next_prefix_match,
        token_cache,
    };
    (encoder, token_bytes)
}
/// Assembles an encoder from already-built components.
///
/// Only the whole-text token cache is derived here: every entry of
/// `token_bytes` no longer than `MAX_CACHED_TOKEN_LEN` is mapped to its
/// index. When two tokens share the same bytes, the higher index wins.
/// The other arguments are stored as given, without validation.
pub fn from_parts(
    split_table: Vec<Split>,
    pair_lookup: FoldHashMap<u64, TokenId>,
    token_lengths: Vec<u8>,
    num_base_tokens: usize,
    matcher: DoubleArrayAhoCorasick,
    next_prefix_match: Vec<TokenId>,
    token_bytes: &[Vec<u8>],
) -> Self {
    let token_cache: FoldHashMap<Vec<u8>, TokenId> = token_bytes
        .iter()
        .enumerate()
        .filter(|(_, bytes)| bytes.len() <= MAX_CACHED_TOKEN_LEN)
        .map(|(token_id, bytes)| (bytes.clone(), token_id as TokenId))
        .collect();

    Self {
        split_table,
        pair_lookup,
        token_lengths,
        num_base_tokens,
        matcher,
        next_prefix_match,
        token_cache,
    }
}
/// Builds the token automaton and the per-token "next shorter prefix" table.
///
/// Every token's bytes are added to a trie under the token's index, and the
/// trie is built with `MatchKind::LeftmostLongest`. For each token longer
/// than one byte, the table stores the first match the automaton reports on
/// the token's bytes minus the last byte; tokens of length 0 or 1, and
/// tokens with no such match, get `u32::MAX`.
fn build_matcher_and_prefixes(token_bytes: &[Vec<u8>]) -> (DoubleArrayAhoCorasick, Vec<TokenId>) {
    let mut trie = Trie::new();
    for (id, bytes) in token_bytes.iter().enumerate() {
        trie.add(bytes, id as TokenId);
    }
    trie.build(MatchKind::LeftmostLongest);
    let matcher = trie.compile();

    let next_prefix_match: Vec<TokenId> = token_bytes
        .iter()
        .map(|token| match token.split_last() {
            // Drop the final byte and look for the longest token on what remains.
            Some((_, prefix)) if !prefix.is_empty() => matcher
                .find_iter(prefix)
                .next()
                .map_or(u32::MAX, |m| m.pattern_id),
            _ => u32::MAX,
        })
        .collect();

    (matcher, next_prefix_match)
}
/// Returns the byte length of every token as a `u8`.
///
/// Lengths above 255 are saturated to 255, so a longer token's stored
/// length is smaller than its real length.
fn build_token_lengths(token_bytes: &[Vec<u8>]) -> Vec<u8> {
    let cap = usize::from(u8::MAX);
    let mut lengths = Vec::with_capacity(token_bytes.len());
    for token in token_bytes {
        lengths.push(token.len().min(cap) as u8);
    }
    lengths
}
/// Returns the per-token split records, indexed by token id.
pub fn split_table(&self) -> &[Split] {
&self.split_table
}
/// Returns the Aho-Corasick automaton used to find candidate tokens.
pub fn matcher(&self) -> &DoubleArrayAhoCorasick {
&self.matcher
}
/// Returns the "next shorter prefix" table, indexed by token id.
/// An entry of `u32::MAX` means the token has no shorter prefix token.
pub fn next_prefix_match_table(&self) -> &[TokenId] {
&self.next_prefix_match
}
/// Decides whether `token2` may directly follow `token1` in an encoding.
///
/// The loop repeatedly checks whether the two tokens currently facing each
/// other have a merge in `pair_lookup` whose result id is below `limit`; if
/// so the pair is rejected. Otherwise one side is replaced by its inner
/// half (the right half of `token1` or the left half of `token2`) and the
/// check repeats. The pair is accepted once both sides have reached tokens
/// that cannot be split further.
///
/// NOTE(review): the base-token tests (`right == token1`, `left + 1 == limit`)
/// assume `Split::base(id)` stores `id` in both `left` and `right` — confirm
/// in `crate::types`.
///
/// # Panics
/// Panics if either token id is outside `split_table`.
#[inline]
pub fn is_valid_pair(&self, mut token1: TokenId, mut token2: TokenId) -> bool {
// Upper bound (exclusive) on merge results that invalidate the pair.
let mut limit = u32::MAX;
loop {
// A merge of the facing pair with an id below the limit rejects the pair.
if let Some(&combined) = self.pair_lookup.get(&pack_pair(token1, token2)) {
if combined < limit {
return false;
}
}
// Split the side with the larger id first.
if token1 > token2 {
limit = token1;
let right = self.split_table[token1 as usize].right;
if right == token1 {
// `token1` cannot be split; try splitting `token2` instead.
limit = token2 + 1;
let left = self.split_table[token2 as usize].left;
if left + 1 == limit {
// Neither side can be split any further.
return true;
}
token2 = left;
} else {
token1 = right;
}
} else {
limit = token2 + 1;
let left = self.split_table[token2 as usize].left;
if left + 1 == limit {
// `token2` cannot be split; try splitting `token1` instead.
limit = token1;
let right = self.split_table[token1 as usize].right;
if right == limit {
// Neither side can be split any further.
return true;
}
token1 = right;
} else {
token2 = left;
}
}
}
}
/// Returns the stored byte length of `token`.
///
/// # Panics
/// Panics if `token` is outside the vocabulary.
#[inline]
pub fn token_len(&self, token: TokenId) -> usize {
self.token_lengths[token as usize] as usize
}
/// Returns the number of tokens, taken from the length table.
pub fn vocab_size(&self) -> usize {
self.token_lengths.len()
}
/// Returns the number of base tokens recorded at construction.
pub fn num_base_tokens(&self) -> usize {
self.num_base_tokens
}
/// Encodes `text` into token ids.
///
/// Empty input yields an empty vector, and input that is exactly one cached
/// token yields that token. Input shorter than `PARALLEL_THRESHOLD` is
/// encoded on the calling thread. Longer input is split into pieces that
/// are encoded independently on scoped threads and concatenated in order.
///
/// # Panics
/// Panics if a worker thread panics.
pub fn encode(&self, text: &[u8]) -> Vec<TokenId> {
    if text.is_empty() {
        return Vec::new();
    }
    if text.len() <= MAX_CACHED_TOKEN_LEN {
        if let Some(&token_id) = self.token_cache.get(text) {
            return vec![token_id];
        }
    }
    if text.len() < PARALLEL_THRESHOLD {
        return self.encode_sequential(text);
    }

    let pieces = split_at_boundaries(text);
    // A single piece gains nothing from a thread.
    if let [only] = pieces.as_slice() {
        return self.encode_sequential(only);
    }

    let per_piece: Vec<Vec<TokenId>> = thread::scope(|scope| {
        let workers: Vec<_> = pieces
            .iter()
            .map(|piece| scope.spawn(|| self.encode_sequential(piece)))
            .collect();
        workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .collect()
    });

    // `concat` sizes the output once from the summed lengths.
    per_piece.concat()
}
/// Returns an [`EncodeIter`] that produces the tokens of `text` lazily.
pub fn encode_iter<'a>(&'a self, text: &'a [u8]) -> EncodeIter<'a> {
EncodeIter::new(self, text)
}
/// Encodes several texts, returning one token vector per text in input order.
///
/// With a single available CPU everything runs on the calling thread. With
/// no more texts than CPUs, each text gets its own scoped thread. Otherwise
/// the texts are divided into about one contiguous group per CPU. Each text
/// is always encoded sequentially.
///
/// # Panics
/// Panics if a worker thread panics.
pub fn encode_batch(&self, texts: &[&[u8]]) -> Vec<Vec<TokenId>> {
    if texts.is_empty() {
        return Vec::new();
    }

    let workers = thread::available_parallelism().map_or(1, |p| p.get());

    if workers == 1 {
        return texts.iter().map(|t| self.encode_sequential(t)).collect();
    }

    if texts.len() <= workers {
        return thread::scope(|scope| {
            let handles: Vec<_> = texts
                .iter()
                .map(|text| scope.spawn(|| self.encode_sequential(text)))
                .collect();
            handles
                .into_iter()
                .map(|handle| handle.join().unwrap())
                .collect()
        });
    }

    // Ceiling division so that at most `workers` groups are created.
    let per_worker = (texts.len() + workers - 1) / workers;
    thread::scope(|scope| {
        let handles: Vec<_> = texts
            .chunks(per_worker)
            .map(|group| {
                scope.spawn(|| {
                    group
                        .iter()
                        .map(|t| self.encode_sequential(t))
                        .collect::<Vec<_>>()
                })
            })
            .collect();

        let mut encoded = Vec::with_capacity(texts.len());
        for handle in handles {
            encoded.extend(handle.join().unwrap());
        }
        encoded
    })
}
/// Encodes `text` on the calling thread using greedy matching with backtracking.
///
/// At each position the longest matching token is tried first, then its
/// successively shorter prefix tokens. A candidate is accepted when the
/// offset just past it has not been marked as a dead end and it forms a
/// valid pair with the previously accepted token. When no candidate fits,
/// the current offset is marked as a dead end and the previous token is
/// undone so that a shorter alternative is tried for it.
fn encode_sequential(&self, text: &[u8]) -> Vec<TokenId> {
    if text.is_empty() {
        return Vec::new();
    }
    if text.len() <= MAX_CACHED_TOKEN_LEN {
        if let Some(&whole) = self.token_cache.get(text) {
            return vec![whole];
        }
    }

    let mut accepted: SmallVec<[TokenId; 16]> = SmallVec::new();
    // One bit per offset, including the end-of-text offset.
    let mut reachable = Bitfield::new(text.len() + 1);
    let mut cursor = 0;
    let mut pending = self.next_match(&text[cursor..]);

    while let Some(first_candidate) = pending {
        let previous = accepted.last().copied();
        let mut candidate = first_candidate;

        pending = loop {
            let end = cursor + self.token_len(candidate);
            let end_is_open = reachable.is_set(end);
            let pairs_with_previous = match previous {
                Some(prev) => self.is_valid_pair(prev, candidate),
                None => true,
            };

            if end_is_open && pairs_with_previous {
                accepted.push(candidate);
                cursor = end;
                break self.next_match(&text[cursor..]);
            }

            match self.next_prefix(candidate) {
                Some(shorter) => candidate = shorter,
                None => {
                    // Dead end: undo the previous token and retry it, which
                    // now fails the reachability check and moves on to its
                    // shorter prefixes.
                    reachable.clear(cursor);
                    if let Some(undone) = accepted.pop() {
                        cursor -= self.token_len(undone);
                    }
                    break previous;
                }
            }
        };
    }

    accepted.into_vec()
}
/// Returns the pattern id of the first match the automaton reports in `text`,
/// or `None` when nothing matches.
///
/// NOTE(review): callers advance by the token's length from the start of
/// `text`, so they assume the match begins at offset 0. Presumably this
/// holds because every byte is itself a token — confirm.
#[inline]
fn next_match(&self, text: &[u8]) -> Option<TokenId> {
self.matcher.find_iter(text).next().map(|m| m.pattern_id)
}
/// Returns the next shorter prefix token recorded for `token`, or `None`
/// when the table holds the `u32::MAX` sentinel.
///
/// # Panics
/// Panics if `token` is outside the prefix table.
#[inline]
fn next_prefix(&self, token: TokenId) -> Option<TokenId> {
    match self.next_prefix_match[token as usize] {
        u32::MAX => None,
        shorter => Some(shorter),
    }
}
}
/// Fixed-size bit set in which every bit starts out set.
struct Bitfield {
    // Bits packed 64 per word; bit `i` lives in word `i / 64`.
    bits: Vec<u64>,
}

impl Bitfield {
    /// Creates a bit set with room for at least `size` bits, all set.
    fn new(size: usize) -> Self {
        let words = (size + 63) / 64;
        let mut bits = Vec::with_capacity(words);
        bits.resize(words, u64::MAX);
        Self { bits }
    }

    /// Clears the bit at `pos`. Panics if `pos` lies beyond the allocated words.
    #[inline]
    fn clear(&mut self, pos: usize) {
        let mask = 1u64 << (pos % 64);
        self.bits[pos / 64] &= !mask;
    }

    /// Reports whether the bit at `pos` is set. Panics if `pos` lies beyond
    /// the allocated words.
    #[inline]
    fn is_set(&self, pos: usize) -> bool {
        let mask = 1u64 << (pos % 64);
        (self.bits[pos / 64] & mask) != 0
    }
}
// Unit tests built on tiny hand-written vocabularies of single-byte base tokens.
#[cfg(test)]
mod tests {
use super::*;
use crate::decoder::VocabDecoder;
// Merged tokens get ids after the base tokens and carry the concatenated bytes.
#[test]
fn test_from_merges() {
let base_tokens = vec![vec![b'a'], vec![b'b'], vec![b'c']];
let merges = vec![(0, 1), (3, 2)];
let (encoder, token_bytes) = BacktrackingBytePairEncoder::from_merges(&merges, &base_tokens);
let decoder = VocabDecoder::new(token_bytes);
assert_eq!(encoder.vocab_size(), 5);
assert_eq!(encoder.num_base_tokens(), 3);
assert_eq!(decoder.token_to_bytes(0), b"a");
assert_eq!(decoder.token_to_bytes(3), b"ab");
assert_eq!(decoder.token_to_bytes(4), b"abc");
}
// A pair that has a merge is invalid; pairs without one are valid.
#[test]
fn test_is_valid_pair() {
let base_tokens = vec![vec![b'a'], vec![b'b'], vec![b'c']];
let merges = vec![(0, 1)];
let (encoder, _) = BacktrackingBytePairEncoder::from_merges(&merges, &base_tokens);
assert!(!encoder.is_valid_pair(0, 1));
assert!(encoder.is_valid_pair(3, 2));
assert!(encoder.is_valid_pair(1, 2));
}
// Text covered by a merge encodes to the merged token.
#[test]
fn test_encode_merged_token() {
let base_tokens = vec![vec![b'a'], vec![b'b'], vec![b'c']];
let merges = vec![(0, 1)];
let (encoder, _) = BacktrackingBytePairEncoder::from_merges(&merges, &base_tokens);
assert_eq!(encoder.encode(b"ab"), vec![3]);
assert_eq!(encoder.encode(b"abc"), vec![3, 2]);
}
// Text that is exactly one token encodes to that single token.
#[test]
fn test_early_exit() {
let base_tokens = vec![vec![b'a'], vec![b'b'], vec![b'c']];
let merges = vec![(0, 1), (3, 2)];
let (encoder, _) = BacktrackingBytePairEncoder::from_merges(&merges, &base_tokens);
assert_eq!(encoder.encode(b"a"), vec![0]);
assert_eq!(encoder.encode(b"ab"), vec![3]);
assert_eq!(encoder.encode(b"abc"), vec![4]);
}
// Decoding the encoded tokens gives back the original bytes.
#[test]
fn test_encode_decode_roundtrip() {
let base_tokens = vec![vec![b'a'], vec![b'b'], vec![b'c'], vec![b'd']];
let merges = vec![(0, 1), (2, 3), (4, 5)];
let (encoder, token_bytes) = BacktrackingBytePairEncoder::from_merges(&merges, &base_tokens);
let decoder = VocabDecoder::new(token_bytes);
for text in [b"abcd".as_slice(), b"ab", b"cd", b"abcdabcd", b"a", b""] {
let encoded = encoder.encode(text);
let decoded = decoder.decode(&encoded);
assert_eq!(decoded, text);
}
}
// The streaming iterator yields the same tokens as `encode`.
#[test]
fn test_encode_iter_matches_encode() {
let base_tokens = vec![vec![b'a'], vec![b'b'], vec![b'c'], vec![b'd']];
let merges = vec![(0, 1), (2, 3), (4, 5)];
let (encoder, _) = BacktrackingBytePairEncoder::from_merges(&merges, &base_tokens);
for text in [b"".as_slice(), b"a", b"ab", b"abcd", b"abcdabcdabcdabcdabcd"] {
let encoded = encoder.encode(text);
let iter_encoded: Vec<_> = encoder.encode_iter(text).collect();
assert_eq!(encoded, iter_encoded);
}
}
}