#[rustfmt::skip]
use std::collections::HashSet;
use std::num::NonZeroU64;
use std::thread;
use fancy_regex::Regex;
use rustc_hash::FxHashMap as HashMap;
/// Integer id of a token: the value type of the encoder maps and the key type of the decoder
/// maps. `Rank::MAX` is used by the merge routines as a "no such merge" sentinel.
pub type Rank = u32;
use std::collections::BinaryHeap;
/// Candidate merge queued in the heap used by `_byte_pair_merge_large`.
#[derive(Eq, PartialEq, Clone, Copy)]
struct Merge {
    /// Byte offset at which the left-hand token of the pair begins.
    start: usize,
    /// Rank of the byte sequence the merge would produce.
    rank: Rank,
}

impl Ord for Merge {
    /// Reversed comparison: the merge with the smallest rank (ties broken by the smallest
    /// start) compares as the greatest, so a max-heap pops it first.
    #[inline]
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match other.rank.cmp(&self.rank) {
            std::cmp::Ordering::Equal => other.start.cmp(&self.start),
            unequal => unequal,
        }
    }
}

impl PartialOrd for Merge {
    /// Always `Some`, delegating to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Per-position bookkeeping for `_byte_pair_merge_large`: one entry per byte of the piece,
/// indexed by the byte offset at which a token starts.
struct State {
/// Start offset of the token preceding this one (`usize::MAX` for the entry at offset 0).
prev: usize,
/// Exclusive end offset of the token that starts at this index.
end: usize,
/// Exclusive end offset of the token that follows this one.
next_end: usize,
/// Rank of merging this token with the following one, or `Rank::MAX` when no merge is pending.
next_rank: Rank,
/// Rank assigned when this token was produced by a merge; `Rank::MAX` if it never was.
cur_rank: Rank,
}
/// Byte-pair merges `piece` using a heap of candidate merges and returns the resulting ranks.
///
/// `byte_pair_encode` calls this for pieces of 100 bytes or more. The setup loop computes
/// `piece.len() - 1`, so `piece` must not be empty.
///
/// # Panics
///
/// Panics if a span that was never merged is not a key of `ranks` (see the final loop).
fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<Rank> {
// One `State` per byte offset; initially every byte is its own token.
let mut state = Vec::with_capacity(piece.len());
state.push(State {
prev: usize::MAX,
end: 1,
next_end: 2,
next_rank: Rank::MAX,
cur_rank: Rank::MAX,
});
let mut heap = BinaryHeap::with_capacity(piece.len());
// Seed the heap with every adjacent byte pair that has a rank.
for i in 0..piece.len() - 1 {
if let Some(&rank) = ranks.get(&piece[i..i + 2]) {
heap.push(Merge { start: i, rank });
state[i].next_rank = rank;
}
// This pushes the entry for offset `i + 1` (entry 0 was pushed before the loop).
state.push(State {
prev: i,
end: i + 2,
next_end: i + 3,
next_rank: Rank::MAX,
cur_rank: Rank::MAX,
});
}
// Records that the token at `start` is now followed by a token ending at `next_end_item`,
// clears its pending merge, and queues the new merge if that combined span has a rank.
let potential_merge = {
#[inline(always)]
|state: &mut Vec<State>,
heap: &mut BinaryHeap<Merge>,
start: usize,
next_end_item: usize| {
state[start].next_end = next_end_item;
// The old pending merge is always invalidated before a replacement is considered.
state[start].next_rank = Rank::MAX; if next_end_item <= piece.len()
&& let Some(&rank) = ranks.get(&piece[start..next_end_item])
{
heap.push(Merge { start, rank });
state[start].next_rank = rank;
}
}
};
// Pop candidates lowest rank first (see `Ord for Merge`) until the heap is drained.
while let Some(left) = heap.pop() {
if left.rank == Rank::MAX {
break;
}
// Stale heap entry: the pending merge at this position changed since it was pushed.
if left.rank != state[left.start].next_rank {
continue; }
let left_start = left.start;
let right_start = state[left_start].end;
let right_end = state[left_start].next_end;
debug_assert!(right_end == state[right_start].end);
let right_next_end = state[right_start].next_end;
// Fuse the left and right tokens into one token starting at `left_start`.
state[left_start].cur_rank = state[left_start].next_rank;
state[left_start].end = right_end;
potential_merge(&mut state, &mut heap, left_start, right_next_end);
// The token after the fused one now has the fused token as its predecessor.
if right_end < state.len() {
state[right_end].prev = left_start;
}
// The predecessor's pending merge must be recomputed against the fused token.
if left_start > 0 {
let prev_start = state[left_start].prev;
potential_merge(&mut state, &mut heap, prev_start, right_end);
}
// The right token no longer exists on its own; drop any merge queued for it.
state[right_start].next_rank = Rank::MAX;
}
// Walk the surviving tokens left to right by following each token's `end`.
let mut result = Vec::new();
let mut i = 0;
while i < state.len() {
if state[i].cur_rank != Rank::MAX {
result.push(state[i].cur_rank);
} else {
// Never merged: look the span up directly (indexing panics when it is absent).
result.push(ranks[&piece[i..state[i].end]]);
}
i = state[i].end;
}
result
}
/// Repeatedly merges the adjacent pair of parts whose combined bytes have the lowest rank in
/// `ranks`, and returns the surviving part boundaries as `(start, rank)` entries.
///
/// The returned vector ends with an entry whose start is `piece.len()`, so consecutive
/// entries delimit the final parts. Ties on rank go to the leftmost pair. Computes
/// `piece.len() - 1`, so `piece` must not be empty.
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    // Rank of a byte sequence, with `Rank::MAX` standing in for "not present in `ranks`".
    let rank_of = |bytes: &[u8]| ranks.get(bytes).copied().unwrap_or(Rank::MAX);

    // Each entry is (start offset, rank of merging the part at this entry with the next one).
    let mut parts: Vec<(usize, Rank)> = Vec::with_capacity(piece.len() + 1);
    parts.extend((0..piece.len() - 1).map(|start| (start, rank_of(&piece[start..start + 2]))));
    parts.push((piece.len() - 1, Rank::MAX));
    parts.push((piece.len(), Rank::MAX));

    // Rank of the pair at `i` as it will be once entry `i + 1` is removed; the span ends at
    // entry `i + 3` because the removal has not happened yet when this is called.
    let merged_rank = |parts: &[(usize, Rank)], i: usize| match parts.get(i + 3) {
        Some(&(end, _)) => rank_of(&piece[parts[i].0..end]),
        None => Rank::MAX,
    };

    // Leftmost entry with the smallest rank, ignoring the final boundary entry.
    let lowest = |parts: &[(usize, Rank)]| {
        let mut best: (Rank, usize) = (Rank::MAX, usize::MAX);
        for (idx, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < best.0 {
                best = (rank, idx);
            }
        }
        best
    };

    loop {
        let (rank, i) = lowest(&parts);
        if rank == Rank::MAX {
            break;
        }
        // Refresh the two affected ranks first, then drop the boundary that was merged away.
        if i > 0 {
            parts[i - 1].1 = merged_rank(&parts, i - 1);
        }
        parts[i].1 = merged_rank(&parts, i);
        parts.remove(i + 1);
    }
    parts
}
/// Encodes `piece` into the ranks of its byte-pair-merged parts.
///
/// An empty `piece` yields an empty vector. A single byte is looked up directly. Pieces
/// shorter than 100 bytes go through `_byte_pair_merge`, longer ones through
/// `_byte_pair_merge_large`.
///
/// # Panics
///
/// Panics if a resulting part (including a lone byte) is not a key of `ranks`.
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
    let piece_len = piece.len();
    if piece_len == 0 {
        // Both merge helpers compute `piece.len() - 1`, which underflows on empty input;
        // there is nothing to encode, so answer directly instead of panicking.
        return Vec::new();
    }
    if piece_len == 1 {
        return vec![ranks[piece]];
    }
    if piece_len < 100 {
        // Consecutive boundaries delimit the merged parts; map each part to its rank.
        return _byte_pair_merge(ranks, piece)
            .windows(2)
            .map(|part| ranks[&piece[part[0].0..part[1].0]])
            .collect();
    }
    _byte_pair_merge_large(ranks, piece)
}
/// Splits `piece` into the byte slices produced by byte-pair merging it against `ranks`.
///
/// # Panics
///
/// Panics if `piece` is shorter than two bytes.
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
    assert!(piece.len() > 1);
    let boundaries = _byte_pair_merge(ranks, piece);
    let mut chunks = Vec::with_capacity(boundaries.len().saturating_sub(1));
    // Each pair of consecutive boundaries delimits one merged part.
    for pair in boundaries.windows(2) {
        let (from, to) = (pair[0].0, pair[1].0);
        chunks.push(&piece[from..to]);
    }
    chunks
}
/// Stand-in with a readable field that `hash_current_thread` transmutes a
/// `std::thread::ThreadId` into in order to obtain the id as an integer.
pub struct FakeThreadId(NonZeroU64);
fn hash_current_thread() -> usize {
const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
let x = unsafe {
std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
};
u64::from(x) as usize
}
/// Error returned when a token has no entry in either decoder map.
#[derive(Debug, Clone)]
pub struct DecodeKeyError {
    /// The token that could not be decoded.
    pub token: Rank,
}

impl std::fmt::Display for DecodeKeyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { token } = self;
        write!(f, "Invalid token for decoding: {}", token)
    }
}

impl std::error::Error for DecodeKeyError {}
/// Error carrying a free-form description of a decoding failure.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DecodeError {
    /// Human-readable description appended to the display text.
    pub message: String,
}

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message } = self;
        write!(f, "Could not decode tokens: {}", message)
    }
}

impl std::error::Error for DecodeError {}
/// Error carrying a free-form description of an encoding failure.
#[derive(Debug, Clone)]
pub struct EncodeError {
    /// Human-readable description appended to the display text.
    pub message: String,
}

impl std::fmt::Display for EncodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message } = self;
        write!(f, "Could not encode string: {}", message)
    }
}

impl std::error::Error for EncodeError {}
/// Modulus applied to the thread hash when indexing `CoreBPE::regex_tls` and
/// `CoreBPE::special_regex_tls`; those vectors must hold at least this many entries.
pub const MAX_NUM_THREADS: usize = 128;
/// Byte-pair-encoding tokenizer state: vocabulary maps plus the regexes used to split text.
#[derive(Clone)]
pub struct CoreBPE {
/// Maps a byte sequence to its token rank.
pub(crate) encoder: HashMap<Vec<u8>, Rank>,
/// Maps a special-token string to its token rank.
pub(crate) special_tokens_encoder: HashMap<String, Rank>,
/// Maps a token rank back to its bytes; consulted first when decoding.
pub(crate) decoder: HashMap<Rank, Vec<u8>>,
/// Maps a special-token rank back to its bytes; consulted when `decoder` has no entry.
pub(crate) special_tokens_decoder: HashMap<Rank, Vec<u8>>,
/// Splitting regexes, indexed by `hash_current_thread() % MAX_NUM_THREADS`.
pub(crate) regex_tls: Vec<Regex>,
/// Special-token regexes, indexed the same way as `regex_tls`.
pub(crate) special_regex_tls: Vec<Regex>,
/// Token byte strings searched with `partition_point` in `_encode_unstable_native`;
/// NOTE(review): presumably kept sorted ascending by whoever constructs this — confirm.
#[allow(dead_code)]
pub(crate) sorted_token_bytes: Vec<Vec<u8>>,
}
impl CoreBPE {
/// Returns the splitting regex in the slot selected by the calling thread's hash.
///
/// Panics if `regex_tls` has fewer than `MAX_NUM_THREADS` entries and the slot is missing.
fn _get_tl_regex(&self) -> &Regex {
    let slot = hash_current_thread() % MAX_NUM_THREADS;
    &self.regex_tls[slot]
}
/// Returns the special-token regex in the slot selected by the calling thread's hash.
///
/// Panics if `special_regex_tls` has fewer than `MAX_NUM_THREADS` entries and the slot is
/// missing.
fn _get_tl_special_regex(&self) -> &Regex {
    let slot = hash_current_thread() % MAX_NUM_THREADS;
    &self.special_regex_tls[slot]
}
/// Concatenates the bytes of every token in `tokens`.
///
/// Each token is looked up in `decoder` first and in `special_tokens_decoder` second.
///
/// # Errors
///
/// Returns `DecodeKeyError` carrying the first token found in neither map.
pub fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
    let mut out = Vec::with_capacity(tokens.len() * 2);
    for &token in tokens {
        let bytes = self
            .decoder
            .get(&token)
            .or_else(|| self.special_tokens_decoder.get(&token))
            .ok_or(DecodeKeyError { token })?;
        out.extend(bytes);
    }
    Ok(out)
}
/// Encodes `text` without any special-token handling.
///
/// The text is split with the calling thread's regex. Each piece is emitted as a single
/// token when it is an `encoder` key, and byte-pair encoded otherwise.
///
/// # Panics
///
/// Panics if the regex reports an error while matching.
pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
    let splitter = self._get_tl_regex();
    let mut tokens = vec![];
    for found in splitter.find_iter(text) {
        let piece = found.unwrap().as_str().as_bytes();
        if let Some(&token) = self.encoder.get(piece) {
            tokens.push(token);
        } else {
            tokens.extend(byte_pair_encode(piece, &self.encoder));
        }
    }
    tokens
}
/// Encodes `text`, emitting special tokens only for matches listed in `allowed_special`.
///
/// Returns the tokens together with the number of trailing tokens that came from the last
/// regex piece. That count is 0 when the output ends with a special token or when nothing
/// was encoded.
///
/// # Errors
///
/// Returns `EncodeError` if either the special-token regex or the splitting regex reports an
/// error while matching.
///
/// # Panics
///
/// Panics if an allowed special-token match is not a key of `special_tokens_encoder`.
pub fn encode(
    &self,
    text: &str,
    allowed_special: &HashSet<&str>,
) -> Result<(Vec<Rank>, usize), EncodeError> {
    let special_regex = self._get_tl_special_regex();
    let regex = self._get_tl_regex();
    let mut ret = vec![];

    let mut start = 0;
    let mut last_piece_token_len = 0;
    loop {
        // Find the next special-token match that the caller allows, if any.
        let mut next_special;
        let mut start_find = start;
        loop {
            // A regex runtime failure is reported the same way as for the splitting regex
            // below instead of panicking.
            next_special = special_regex
                .find_from_pos(text, start_find)
                .map_err(|e| EncodeError {
                    message: format!("Regex error while tokenizing: {e}"),
                })?;
            match next_special {
                Some(m) => {
                    if allowed_special.contains(&text[m.start()..m.end()]) {
                        break;
                    }
                    // Disallowed match: resume the search one byte past its start.
                    start_find = m.start() + 1;
                }
                None => break,
            }
        }
        let end = next_special.map_or(text.len(), |m| m.start());

        // Ordinary encoding of the text up to the special token (or to the end of `text`).
        for mat_res in regex.find_iter(&text[start..end]) {
            let mat = match mat_res {
                Ok(m) => m,
                Err(e) => {
                    return Err(EncodeError {
                        message: format!("Regex error while tokenizing: {e}"),
                    });
                }
            };

            let piece = mat.as_str().as_bytes();
            if let Some(token) = self.encoder.get(piece) {
                last_piece_token_len = 1;
                ret.push(*token);
                continue;
            }

            let tokens = byte_pair_encode(piece, &self.encoder);
            last_piece_token_len = tokens.len();
            ret.extend(&tokens);
        }

        match next_special {
            Some(m) => {
                // Emit the special token itself and continue after it.
                let piece = m.as_str();
                let token = self.special_tokens_encoder[piece];
                ret.push(token);
                start = m.end();
                last_piece_token_len = 0;
            }
            None => break,
        }
    }

    Ok((ret, last_piece_token_len))
}
/// Widens `last_piece_token_len` to cover preceding tokens made up only of space, newline or
/// tab bytes.
///
/// The count only grows when the first token of the current tail is itself such a token. It
/// then keeps growing while the token just before the tail also is. `tokens` is returned
/// unchanged alongside the new count.
fn _increase_last_piece_token_len(
    &self,
    tokens: Vec<Rank>,
    mut last_piece_token_len: usize,
) -> (Vec<Rank>, usize) {
    // A token missing from `decoder` is treated as not being whitespace.
    let is_blank = |token: &Rank| match self.decoder.get(token) {
        Some(bytes) => bytes.iter().rev().all(|b| b" \n\t".contains(b)),
        None => false,
    };
    let total = tokens.len();
    if last_piece_token_len > 0 && is_blank(&tokens[total - last_piece_token_len]) {
        while last_piece_token_len < total
            && is_blank(&tokens[total - last_piece_token_len - 1])
        {
            last_piece_token_len += 1;
        }
    }
    debug_assert!(last_piece_token_len <= total);
    (tokens, last_piece_token_len)
}
/// Encodes `text` and separates the result into a stable token prefix and a set of candidate
/// token sequences ("completions") for the trailing bytes.
///
/// The trailing bytes are those of the last regex piece, widened by
/// `_increase_last_piece_token_len`. They are removed from the returned tokens.
///
/// # Panics
///
/// Panics if `encode` or `decode_bytes` returns an error, or if a token produced while
/// re-encoding is missing from `decoder`.
pub fn _encode_unstable_native(
&self,
text: &str,
allowed_special: &HashSet<&str>,
) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
let (tokens, last_piece_token_len) = self.encode(text, allowed_special).unwrap();
// `encode` reports 0 when the output ended with a special token (or nothing was encoded);
// there are no trailing bytes to reconsider in that case.
if last_piece_token_len == 0 {
return (tokens, HashSet::new());
}
let (mut tokens, last_piece_token_len) =
self._increase_last_piece_token_len(tokens, last_piece_token_len);
// Decode the trailing tokens back to bytes, then drop them from the returned tokens.
let unstable_bytes = self
.decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
.unwrap();
tokens.truncate(tokens.len() - last_piece_token_len);
let mut completions = HashSet::new();
if unstable_bytes.is_empty() {
return (tokens, completions);
}
// First pass: every single token whose bytes start with `unstable_bytes` (an exact match
// included). `partition_point` assumes `sorted_token_bytes` is sorted ascending.
let mut point = self
.sorted_token_bytes
.partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
while point < self.sorted_token_bytes.len()
&& self.sorted_token_bytes[point].starts_with(&unstable_bytes)
{
completions.insert(vec![
self.encoder[self.sorted_token_bytes[point].as_slice()],
]);
point += 1;
}
// Second pass: for every interior split point, take each token that starts with the
// suffix, glue it onto the prefix and re-encode the combination.
for i in 1..unstable_bytes.len() {
let prefix = &unstable_bytes[..i];
let suffix = &unstable_bytes[i..];
let mut point = self
.sorted_token_bytes
.partition_point(|x| x.as_slice() < suffix);
while point < self.sorted_token_bytes.len()
&& self.sorted_token_bytes[point].starts_with(suffix)
{
let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
// Valid UTF-8 goes through the regex-splitting encoder; anything else is
// byte-pair encoded as one piece.
let encoded = match std::str::from_utf8(&possibility) {
Ok(s) => self.encode_ordinary(s),
Err(_) => byte_pair_encode(&possibility, &self.encoder),
};
// Keep only as many leading tokens as are needed to cover `unstable_bytes`.
let mut seq = Vec::new();
let mut seq_len = 0;
for token in encoded {
seq.push(token);
seq_len += self.decoder[&token].len();
if seq_len >= unstable_bytes.len() {
break;
}
}
completions.insert(seq);
point += 1;
}
}
// Extra candidate: when the trailing bytes end in a whitespace character preceded by at
// least one byte, encode the part before that character and the character separately.
if unstable_bytes.len() > 1 {
// `last_decoded.0` is the final decoded character (if any); `.1` is used as its byte length.
let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
if unstable_bytes.len() - last_decoded.1 > 0
&& last_decoded.0.is_some_and(|c| c.is_whitespace())
{
let mut reencoded = byte_pair_encode(
&unstable_bytes[..unstable_bytes.len() - last_decoded.1],
&self.encoder,
);
reencoded.extend(byte_pair_encode(
&unstable_bytes[unstable_bytes.len() - last_decoded.1..],
&self.encoder,
));
completions.insert(reencoded);
}
}
(tokens, completions)
}
/// Returns the set of all special-token strings known to this encoder.
pub fn special_tokens(&self) -> HashSet<&str> {
    let mut names = HashSet::new();
    for name in self.special_tokens_encoder.keys() {
        names.insert(name.as_str());
    }
    names
}
/// Encodes `text` with every known special token allowed.
///
/// # Panics
///
/// Panics if `encode` returns an error.
pub fn encode_with_special_tokens(&self, text: &str) -> Vec<Rank> {
    let every_special = self.special_tokens();
    let (tokens, _last_piece_token_len) = self.encode(text, &every_special).unwrap();
    tokens
}
}
#[cfg(test)]
mod tests {
    use rustc_hash::FxHashMap as HashMap;
    use crate::{Rank, byte_pair_split};

    /// Two-entry vocabulary shared by the tests: "ab" has rank 0 and "cd" has rank 1.
    fn sample_ranks() -> HashMap<Vec<u8>, Rank> {
        let mut ranks = HashMap::default();
        ranks.insert(b"ab".to_vec(), 0);
        ranks.insert(b"cd".to_vec(), 1);
        ranks
    }

    /// Two distinct known pairs are split into exactly those pairs.
    #[test]
    fn test_simple_characters() {
        let pieces = byte_pair_split(b"abcd", &sample_ranks());
        assert_eq!(pieces, vec![b"ab", b"cd"]);
    }

    /// The same known pair occurring twice is merged at both positions.
    #[test]
    fn test_repeated_characters() {
        let pieces = byte_pair_split(b"abab", &sample_ranks());
        assert_eq!(pieces, vec![b"ab", b"ab"]);
    }
}