#![allow(clippy::borrow_deref_ref)]
use std::collections::HashSet;
use std::thread;
use anyhow::anyhow;
use anyhow::Result;
use fancy_regex::Regex;
use rustc_hash::FxHashMap as HashMap;
#[cfg(feature = "python")]
use pyo3::exceptions;
#[cfg(feature = "python")]
use pyo3::prelude::*;
#[cfg(feature = "python")]
use pyo3::types::{PyBytes, PyList, PyTuple};
#[cfg(feature = "python")]
use pyo3::PyResult;
/// Splits `piece` into contiguous byte ranges by greedily merging adjacent parts.
///
/// The piece starts out as one part per byte. On every round the adjacent pair whose
/// concatenated bytes have the smallest rank in `ranks` is merged (the leftmost pair wins
/// ties); merging stops once no adjacent pair is present in `ranks`.
///
/// Returns `f(range)` for each remaining part, in order, where `range` is the part's byte
/// range inside `piece`. An empty `piece` yields an empty vector.
fn _byte_pair_merge<T>(
    piece: &[u8],
    ranks: &HashMap<Vec<u8>, usize>,
    f: impl Fn(std::ops::Range<usize>) -> T,
) -> Vec<T> {
    // One entry per part boundary: `(start_offset, rank)`. `rank` is the rank of the bytes
    // covered by this part plus the next one, or `usize::MAX` when that span is not in
    // `ranks`. The final entry is a sentinel marking the end of `piece`.
    let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect();

    // Rank of the span starting at part `start_idx` and ending where part
    // `start_idx + skip + 2` begins. `skip = 1` is used just before a removal to look past
    // the entry that is about to disappear.
    let get_rank = |parts: &[(usize, usize)], start_idx: usize, skip: usize| -> Option<usize> {
        let end_idx = start_idx + skip + 2;
        if end_idx < parts.len() {
            ranks
                .get(&piece[parts[start_idx].0..parts[end_idx].0])
                .copied()
        } else {
            None
        }
    };

    // Seed the ranks of all initial two-byte pairs. `saturating_sub` keeps the bound at
    // zero for an empty piece (where `parts` has a single entry) instead of underflowing.
    for i in 0..parts.len().saturating_sub(2) {
        if let Some(rank) = get_rank(&parts, i, 0) {
            debug_assert!(rank != usize::MAX);
            parts[i].1 = rank;
        }
    }

    while parts.len() > 1 {
        // Leftmost pair with the strictly smallest rank; the sentinel entry is excluded.
        let mut min_rank: (usize, usize) = (usize::MAX, 0);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }
        if min_rank.0 == usize::MAX {
            // Nothing left that can be merged.
            break;
        }

        let i = min_rank.1;
        // Recompute the ranks affected by the merge *before* removing `parts[i + 1]`,
        // which is why both lookups skip one extra entry.
        let merged_rank = get_rank(&parts, i, 1).unwrap_or(usize::MAX);
        parts[i].1 = merged_rank;
        if i > 0 {
            let previous_rank = get_rank(&parts, i - 1, 1).unwrap_or(usize::MAX);
            parts[i - 1].1 = previous_rank;
        }
        parts.remove(i + 1);
    }

    // Consecutive boundaries delimit the surviving parts.
    parts
        .windows(2)
        .map(|pair| f(pair[0].0..pair[1].0))
        .collect()
}
/// Encodes `piece` into token ranks using the merge table `ranks`.
///
/// A one-byte piece is looked up directly; anything else is split with `_byte_pair_merge`
/// and every resulting part is looked up in `ranks`.
///
/// # Panics
///
/// Panics if a resulting part (or the single byte) is not a key of `ranks`.
pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
    match piece.len() {
        1 => vec![ranks[piece]],
        _ => _byte_pair_merge(piece, ranks, |span| ranks[&piece[span]]),
    }
}
/// Splits `piece` into the byte slices that byte-pair merging with `ranks` produces.
///
/// A one-byte piece is returned as is; otherwise each part found by `_byte_pair_merge`
/// is returned as a sub-slice borrowed from `piece`.
pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
    match piece.len() {
        1 => vec![piece],
        _ => _byte_pair_merge(piece, ranks, |span| &piece[span]),
    }
}
use std::num::NonZeroU64;
/// Stand-in with the layout that `hash_current_thread` assumes for `std::thread::ThreadId`:
/// a single `NonZeroU64`. It is the target type of the transmute performed there.
pub struct FakeThreadId(NonZeroU64);
fn hash_current_thread() -> usize {
const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
let x = unsafe {
std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
};
u64::from(x) as usize
}
/// Number of `Regex` clones that `CoreBPE::new` stores for each pattern. `_get_tl_regex` and
/// `_get_tl_special_regex` reduce the current thread's id hash modulo this value to pick one.
const MAX_NUM_THREADS: usize = 128;
// Tokenizer state for the Python build (exposed through `#[pyclass]`).
// Plain `//` comments are used so the Python-visible class docstring is unaffected.
#[cfg(feature = "python")]
#[pyclass]
#[derive(Clone)]
pub struct CoreBPE {
// Byte sequence -> token rank/id.
encoder: HashMap<Vec<u8>, usize>,
// Special token text -> token id.
special_tokens_encoder: HashMap<String, usize>,
// Inverse of `encoder` (built in `new`).
decoder: HashMap<usize, Vec<u8>>,
// Inverse of `special_tokens_encoder`, storing each token's UTF-8 bytes (built in `new`).
special_tokens_decoder: HashMap<usize, Vec<u8>>,
// `MAX_NUM_THREADS` clones of the compiled splitting pattern; indexed by thread hash.
regex_tls: Vec<Regex>,
// `MAX_NUM_THREADS` clones of the regex matching any special token; indexed by thread hash.
special_regex_tls: Vec<Regex>,
// All keys of `encoder`, sorted, so prefix searches can use `partition_point`.
sorted_token_bytes: Vec<Vec<u8>>,
}
/// Tokenizer state for the pure-Rust build (no `python` feature).
///
/// `CoreBPE::new` fills every field consistently; the fields are public, so a value built
/// by hand must keep `regex_tls` and `special_regex_tls` at `MAX_NUM_THREADS` entries,
/// because the per-thread lookups index them modulo that constant.
#[cfg(not(feature = "python"))]
#[derive(Debug, Clone)]
#[allow(dead_code)] pub struct CoreBPE {
/// Byte sequence -> token rank/id.
pub encoder: HashMap<Vec<u8>, usize>,
/// Special token text -> token id.
pub special_tokens_encoder: HashMap<String, usize>,
/// Inverse of `encoder` (built in `new`).
pub decoder: HashMap<usize, Vec<u8>>,
/// Inverse of `special_tokens_encoder`, storing each token's UTF-8 bytes (built in `new`).
pub special_tokens_decoder: HashMap<usize, Vec<u8>>,
/// `MAX_NUM_THREADS` clones of the compiled splitting pattern; indexed by thread hash.
pub regex_tls: Vec<Regex>,
/// `MAX_NUM_THREADS` clones of the regex matching any special token; indexed by thread hash.
pub special_regex_tls: Vec<Regex>,
/// All keys of `encoder`, sorted, so prefix searches can use `partition_point`.
pub sorted_token_bytes: Vec<Vec<u8>>,
}
impl CoreBPE {
/// Returns the clone of the splitting pattern assigned to the current thread.
///
/// The slot is the current thread's id hash modulo `MAX_NUM_THREADS`; indexing panics if
/// `regex_tls` holds fewer entries than that.
fn _get_tl_regex(&self) -> &Regex {
    let slot = hash_current_thread() % MAX_NUM_THREADS;
    &self.regex_tls[slot]
}
/// Returns the clone of the special-token regex assigned to the current thread.
///
/// The slot is the current thread's id hash modulo `MAX_NUM_THREADS`; indexing panics if
/// `special_regex_tls` holds fewer entries than that.
fn _get_tl_special_regex(&self) -> &Regex {
    let slot = hash_current_thread() % MAX_NUM_THREADS;
    &self.special_regex_tls[slot]
}
/// Concatenates the bytes of every token in `tokens`.
///
/// Each token is looked up in `decoder` first and in `special_tokens_decoder` otherwise.
///
/// # Panics
///
/// Panics if a token is in neither map.
fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
    // Initial capacity guess of two bytes per token; the vector grows as needed.
    let mut bytes = Vec::with_capacity(tokens.len() * 2);
    for token in tokens {
        let token_bytes = match self.decoder.get(token) {
            Some(found) => found,
            None => &self.special_tokens_decoder[token],
        };
        bytes.extend_from_slice(token_bytes);
    }
    bytes
}
/// Lazily yields an owned copy of the bytes of each token in `tokens`, in order.
///
/// Tokens are resolved through `decoder`, falling back to `special_tokens_decoder`; the
/// returned iterator panics on a token that is in neither map.
#[allow(clippy::needless_lifetimes)]
fn _decode_native_and_split<'a>(
    &'a self,
    tokens: Vec<usize>,
) -> impl Iterator<Item = Vec<u8>> + '_ {
    tokens
        .into_iter()
        .map(move |token| match self.decoder.get(&token) {
            Some(bytes) => bytes.clone(),
            None => self.special_tokens_decoder[&token].clone(),
        })
}
/// Encodes `text` without any special-token handling.
///
/// The text is split with the thread's splitting regex; a piece that is itself a key of
/// `encoder` becomes one token, any other piece goes through `byte_pair_encode`.
///
/// # Panics
///
/// Panics if the regex engine reports an error while matching.
fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
    let regex = self._get_tl_regex();
    let mut tokens = Vec::new();
    for found in regex.find_iter(text) {
        let piece = found.unwrap().as_str().as_bytes();
        match self.encoder.get(piece) {
            Some(&token) => tokens.push(token),
            None => tokens.extend(byte_pair_encode(piece, &self.encoder)),
        }
    }
    tokens
}
/// Encodes `text`, emitting special tokens only for matches listed in `allowed_special`.
///
/// Returns the tokens together with `last_piece_token_len`: the number of tokens produced
/// by the last ordinary regex piece, or `0` when the output ends with a special token or
/// no ordinary piece was encoded.
///
/// # Panics
///
/// Panics if either regex reports an error while matching.
fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
    let special_regex = self._get_tl_special_regex();
    let regex = self._get_tl_regex();
    let mut ret = vec![];
    let mut start = 0;
    let mut last_piece_token_len = 0;

    loop {
        // Find the next special-token match that the caller allows. Disallowed matches
        // are skipped by resuming the search one byte past their start.
        let mut start_find = start;
        let next_special = loop {
            match special_regex.find_from_pos(text, start_find).unwrap() {
                Some(m) if allowed_special.contains(&text[m.start()..m.end()]) => {
                    break Some(m);
                }
                Some(m) => start_find = m.start() + 1,
                None => break None,
            }
        };

        // Ordinary text runs up to the special token, or to the end of the input.
        let end = match next_special {
            Some(m) => m.start(),
            None => text.len(),
        };

        for mat in regex.find_iter(&text[start..end]) {
            let piece = mat.unwrap().as_str().as_bytes();
            match self.encoder.get(piece) {
                Some(&token) => {
                    last_piece_token_len = 1;
                    ret.push(token);
                }
                None => {
                    let tokens = byte_pair_encode(piece, &self.encoder);
                    last_piece_token_len = tokens.len();
                    ret.extend(tokens);
                }
            }
        }

        match next_special {
            Some(m) => {
                // Emit the special token and continue right after it.
                ret.push(self.special_tokens_encoder[m.as_str()]);
                start = m.end();
                last_piece_token_len = 0;
            }
            None => break,
        }
    }

    (ret, last_piece_token_len)
}
/// Extends `last_piece_token_len` backwards over preceding whitespace-only tokens.
///
/// If the first token of the last piece decodes (via `decoder`) to nothing but spaces,
/// newlines and tabs, the count grows for as long as the token just before the counted
/// tail is also whitespace-only. Tokens missing from `decoder` are treated as
/// non-whitespace. `tokens` is returned unchanged alongside the new count.
fn _increase_last_piece_token_len(
    &self,
    tokens: Vec<usize>,
    mut last_piece_token_len: usize,
) -> (Vec<usize>, usize) {
    let token_is_all_space = |token: &usize| {
        self.decoder.get(token).map_or(false, |token_bytes| {
            token_bytes
                .iter()
                .all(|byte| matches!(*byte, b' ' | b'\n' | b'\t'))
        })
    };

    if last_piece_token_len > 0
        && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
    {
        while last_piece_token_len < tokens.len() {
            let preceding = &tokens[tokens.len() - last_piece_token_len - 1];
            if !token_is_all_space(preceding) {
                break;
            }
            last_piece_token_len += 1;
        }
    }

    debug_assert!(last_piece_token_len <= tokens.len());
    (tokens, last_piece_token_len)
}
/// Encodes `text` and separates the trailing tokens whose encoding could change if more
/// text were appended.
///
/// Returns `(stable_tokens, completions)`: `stable_tokens` is the encoding with the
/// trailing "unstable" tokens removed, and `completions` is a set of token sequences
/// computed from the bytes of those removed tokens (see the three collection steps below).
/// `completions` is empty when `_encode_native` reports no trailing ordinary piece.
fn _encode_unstable_native(
&self,
text: &str,
allowed_special: &HashSet<&str>,
) -> (Vec<usize>, HashSet<Vec<usize>>) {
let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
// A length of 0 means the output ended with a special token or had no ordinary piece.
if last_piece_token_len == 0 {
return (tokens, HashSet::new());
}
// Widen the unstable tail over preceding whitespace-only tokens.
let (mut tokens, last_piece_token_len) =
self._increase_last_piece_token_len(tokens, last_piece_token_len);
// Decode the tail to raw bytes, then drop those tokens from the stable output.
let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
tokens.truncate(tokens.len() - last_piece_token_len);
let mut completions = HashSet::new();
if unstable_bytes.is_empty() {
return (tokens, completions);
}
// Step 1: every single token whose bytes start with all of `unstable_bytes`.
// `sorted_token_bytes` is sorted, so matching entries form a contiguous run starting at
// the partition point.
let mut point = self
.sorted_token_bytes
.partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
while point < self.sorted_token_bytes.len()
&& self.sorted_token_bytes[point].starts_with(&unstable_bytes)
{
completions.insert(vec![
self.encoder[self.sorted_token_bytes[point].as_slice()],
]);
point += 1;
}
// Step 2: for every split of `unstable_bytes` into a non-empty prefix and suffix, take
// each token that starts with the suffix, re-encode `prefix + token`, and keep the
// leading tokens until they cover at least `unstable_bytes.len()` bytes.
for i in 1..unstable_bytes.len() {
let prefix = &unstable_bytes[..i];
let suffix = &unstable_bytes[i..];
let mut point = self
.sorted_token_bytes
.partition_point(|x| x.as_slice() < suffix);
while point < self.sorted_token_bytes.len()
&& self.sorted_token_bytes[point].starts_with(suffix)
{
let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
// Valid UTF-8 goes through the regex-splitting encoder; other byte strings are
// byte-pair encoded directly.
let encoded = match std::str::from_utf8(&possibility) {
Ok(s) => self._encode_ordinary_native(s),
Err(_) => byte_pair_encode(&possibility, &self.encoder),
};
let mut seq = Vec::new();
let mut seq_len = 0;
for token in encoded {
seq.push(token);
seq_len += self.decoder[&token].len();
if seq_len >= unstable_bytes.len() {
break;
}
}
completions.insert(seq);
point += 1;
}
}
// Step 3: when the last decoded character of `unstable_bytes` is whitespace and is
// preceded by other bytes, also add the encoding obtained by encoding the part before
// that character and the character itself separately. `last_decoded` is used as
// `(Option<char>, byte_length)`.
if unstable_bytes.len() > 1 {
let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
if unstable_bytes.len() - last_decoded.1 > 0
&& last_decoded.0.map_or(false, |c| c.is_whitespace())
{
let mut reencoded = byte_pair_encode(
&unstable_bytes[..unstable_bytes.len() - last_decoded.1],
&self.encoder,
);
reencoded.extend(byte_pair_encode(
&unstable_bytes[unstable_bytes.len() - last_decoded.1..],
&self.encoder,
));
completions.insert(reencoded);
}
}
(tokens, completions)
}
}
impl CoreBPE {
/// Builds a tokenizer from a rank table, a special-token table and a splitting pattern.
///
/// * `encoder` maps byte sequences to token ids.
/// * `special_tokens_encoder` maps special token strings to token ids.
/// * `pattern` is the regex used to split text into pieces before byte-pair encoding.
///
/// # Errors
///
/// Returns an error if `pattern` (or the regex generated from the special tokens) fails to
/// compile, or if two entries of `encoder` share the same token id, which would make
/// decoding ambiguous.
#[cfg(not(feature = "python"))]
pub fn new(
    encoder: HashMap<Vec<u8>, usize>,
    special_tokens_encoder: HashMap<String, usize>,
    pattern: &str,
) -> Result<Self> {
    let regex = Regex::new(pattern).map_err(|e| anyhow!(e.to_string()))?;

    // Alternation of all special tokens, escaped so they match literally.
    let special_regex = {
        let _parts = special_tokens_encoder
            .keys()
            .map(|s| fancy_regex::escape(s))
            .collect::<Vec<_>>();
        Regex::new(&_parts.join("|")).map_err(|e| anyhow!(e.to_string()))?
    };

    let decoder: HashMap<usize, Vec<u8>> =
        encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
    // Duplicate ids collapse into one decoder entry; report that as an error instead of
    // panicking, since this constructor is already fallible.
    if encoder.len() != decoder.len() {
        return Err(anyhow!(
            "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
        ));
    }

    let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
        .iter()
        .map(|(k, v)| (*v, k.as_bytes().to_vec()))
        .collect();

    // Sorted copy of all token byte strings, used for prefix searches.
    let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
    sorted_token_bytes.sort();

    Ok(CoreBPE {
        encoder,
        special_tokens_encoder,
        decoder,
        special_tokens_decoder,
        // One clone per slot; threads pick a slot by hashing their thread id.
        regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
        special_regex_tls: (0..MAX_NUM_THREADS)
            .map(|_| special_regex.clone())
            .collect(),
        sorted_token_bytes,
    })
}
/// Encodes `text` into tokens, treating any special-token text as ordinary text.
///
/// Delegates to `_encode_ordinary_native`.
pub fn encode_ordinary(&self, text: &str) -> Vec<usize> {
self._encode_ordinary_native(text)
}
/// Encodes `text` into tokens, emitting special tokens only for the strings listed in
/// `allowed_special`; other special-token text is encoded as ordinary text.
pub fn encode(&self, text: &str, allowed_special: HashSet<&str>) -> Vec<usize> {
    let (tokens, _last_piece_token_len) = self._encode_native(text, &allowed_special);
    tokens
}
pub fn encode_with_special_tokens(&self, text: &str) -> Vec<usize> {
let allowed_special = self
.special_tokens_encoder
.keys()
.map(|s| s.as_str())
.collect();
self._encode_native(text, &allowed_special).0
}
/// Decodes `tokens` back into a string.
///
/// # Errors
///
/// Returns an error if the concatenated token bytes are not valid UTF-8.
///
/// # Panics
///
/// Panics if a token is unknown to both the ordinary and the special-token decoder.
pub fn decode(&self, tokens: Vec<usize>) -> Result<String> {
    let bytes = self._decode_native(&tokens);
    String::from_utf8(bytes)
        .map_err(|e| anyhow!("Unable to decode into a valid UTF-8 string: {}", e))
}
/// Tokenizes `text` and returns the text of each token as a separate `String`.
///
/// `use_special_tokens` selects between `encode_with_special_tokens` and
/// `encode_ordinary`. This is the collected form of `split_by_token_iter`.
pub fn split_by_token<'a>(
    &'a self,
    text: &'a str,
    use_special_tokens: bool,
) -> Result<Vec<String>> {
    let pieces = self.split_by_token_iter(text, use_special_tokens);
    pieces.collect::<Result<Vec<String>>>()
}
pub fn split_by_token_iter<'a>(
&'a self,
text: &'a str,
use_special_tokens: bool,
) -> impl Iterator<Item = Result<String>> + 'a {
let encoded = match use_special_tokens {
true => self.encode_with_special_tokens(text),
false => self.encode_ordinary(text),
};
self._decode_native_and_split(encoded).map(|token| {
Ok(String::from_utf8_lossy(token.as_slice()).to_string())
})
}
/// Tokenizes `text` without special-token handling and returns the text of each token.
///
/// Equivalent to `split_by_token(text, false)`.
pub fn split_by_token_ordinary<'a>(&'a self, text: &'a str) -> Result<Vec<String>> {
self.split_by_token(text, false)
}
/// Iterator form of `split_by_token_ordinary`.
///
/// Equivalent to `split_by_token_iter(text, false)`.
pub fn split_by_token_ordinary_iter<'a>(
&'a self,
text: &'a str,
) -> impl Iterator<Item = Result<String>> + 'a {
self.split_by_token_iter(text, false)
}
}
#[cfg(feature = "python")]
#[pymethods]
impl CoreBPE {
// Python constructor: builds a tokenizer from a rank table (`encoder`), a special-token
// table (`special_tokens_encoder`) and the regex `pattern` used to split text into pieces.
//
// Raises `ValueError` if `pattern` (or the regex generated from the special tokens) fails
// to compile, or if two entries of `encoder` share the same token id.
#[new]
pub fn new(
    encoder: HashMap<Vec<u8>, usize>,
    special_tokens_encoder: HashMap<String, usize>,
    pattern: &str,
) -> PyResult<Self> {
    let regex = Regex::new(pattern)
        .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;

    // Alternation of all special tokens, escaped so they match literally.
    let special_regex = {
        let _parts = special_tokens_encoder
            .keys()
            .map(|s| fancy_regex::escape(s))
            .collect::<Vec<_>>();
        Regex::new(&_parts.join("|"))
            .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
    };

    let decoder: HashMap<usize, Vec<u8>> =
        encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
    // Duplicate ids collapse into one decoder entry; raise a regular Python exception
    // instead of panicking, since this constructor is already fallible.
    if encoder.len() != decoder.len() {
        return Err(PyErr::new::<exceptions::PyValueError, _>(
            "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?",
        ));
    }

    let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
        .iter()
        .map(|(k, v)| (*v, k.as_bytes().to_vec()))
        .collect();

    // Sorted copy of all token byte strings, used for prefix searches.
    let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
    sorted_token_bytes.sort();

    Ok(CoreBPE {
        encoder,
        special_tokens_encoder,
        decoder,
        special_tokens_decoder,
        // One clone per slot; threads pick a slot by hashing their thread id.
        regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
        special_regex_tls: (0..MAX_NUM_THREADS)
            .map(|_| special_regex.clone())
            .collect(),
        sorted_token_bytes,
    })
}
// Python wrapper around `_encode_unstable_native`.
// Returns a Python tuple `(tokens, completions)` where `completions` is a list of lists
// of token ids. The native encoding runs inside `py.allow_threads`.
fn encode_with_unstable(
&self,
py: Python,
text: &str,
allowed_special: HashSet<&str>,
) -> Py<PyTuple> {
let (tokens, completions) =
py.allow_threads(|| self._encode_unstable_native(text, &allowed_special));
// Convert the set of token sequences into a Python list of lists.
let py_completions =
PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..])));
(tokens, py_completions).into_py(py)
}
// Looks up the token id of `piece`, which must be exactly one token.
// Ordinary tokens are tried first; if `piece` is valid UTF-8 the special tokens are tried
// next. Raises `KeyError` (carrying the bytes) when neither table contains it.
fn encode_single_token(&self, piece: &[u8]) -> PyResult<usize> {
    self.encoder
        .get(piece)
        .copied()
        .or_else(|| {
            std::str::from_utf8(piece)
                .ok()
                .and_then(|piece_str| self.special_tokens_encoder.get(piece_str).copied())
        })
        .ok_or_else(|| PyErr::new::<exceptions::PyKeyError, _>(piece.to_owned()))
}
// Encodes one already-split piece: a direct `encoder` hit yields a single token,
// anything else is byte-pair encoded. No regex splitting is applied.
pub fn encode_single_piece(&self, piece: &[u8]) -> Vec<usize> {
    match self.encoder.get(piece) {
        Some(&token) => vec![token],
        None => byte_pair_encode(piece, &self.encoder),
    }
}
// Decodes `tokens` to raw bytes and returns them as a Python `bytes` object.
// The native decoding runs inside `py.allow_threads`; `_decode_native` panics on a token
// that is in neither decoder table.
pub fn decode_bytes(&self, py: Python, tokens: Vec<usize>) -> Py<PyBytes> {
let bytes = py.allow_threads(|| self._decode_native(&tokens));
PyBytes::new(py, &bytes).into()
}
// Returns the bytes of a single token as a Python `bytes` object.
// Ordinary tokens are tried first, then special tokens. Raises `KeyError` (carrying the
// token id as a string) when the token is in neither table.
pub fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult<Py<PyBytes>> {
    self.decoder
        .get(&token)
        .or_else(|| self.special_tokens_decoder.get(&token))
        .map(|bytes| -> Py<PyBytes> { PyBytes::new(py, bytes).into() })
        .ok_or_else(|| PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
}
// Returns the byte string of every ordinary token as Python `bytes` objects, in the
// sorted order of `sorted_token_bytes`.
pub fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> {
    let mut values: Vec<Py<PyBytes>> = Vec::with_capacity(self.sorted_token_bytes.len());
    for token_bytes in &self.sorted_token_bytes {
        values.push(PyBytes::new(py, token_bytes).into());
    }
    values
}
}
// Python extension module entry point (compiled only with the `python` feature).
// Registers the `CoreBPE` class on the module.
#[cfg(feature = "python")]
#[pymodule]
fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<CoreBPE>()?;
Ok(())
}
// Unit tests for the free byte-pair helpers.
#[cfg(test)]
mod tests {
use rustc_hash::FxHashMap as HashMap;
use super::byte_pair_split;
// With only "ab" and "cd" ranked, "abcd" must split into exactly those two parts.
#[test]
fn very_simple_test() {
let mut ranks = HashMap::default();
ranks.insert(b"ab".to_vec(), 1);
ranks.insert(b"cd".to_vec(), 2);
let res = byte_pair_split(b"abcd", &ranks);
assert_eq!(res, vec![b"ab", b"cd"]);
}
}