use super::vendor_tiktoken::*;
use anyhow::anyhow;
use anyhow::Result;
use fancy_regex::Regex;
use rustc_hash::FxHashMap as HashMap;
use std::collections::HashSet;
/// Private module implementing the sealed-trait pattern: because `Sealed`
/// lives in a private module, downstream crates cannot name it, so the set
/// of `FromRank` implementors is fixed to the types listed here.
mod sealed {
    /// Marker supertrait that seals `FromRank` to this crate's chosen types.
    pub trait Sealed {}

    impl Sealed for super::Rank {}
    impl Sealed for i64 {}
    impl Sealed for u64 {}
    impl Sealed for usize {}
}
/// Infallible conversion from a token `Rank` into a caller-chosen integer type.
///
/// This trait is sealed: only the types with a `sealed::Sealed` impl
/// (`Rank`, `usize`, `u64`, `i64`) can implement it, so every visible
/// implementation is a lossless widening (or identity) conversion.
pub trait FromRank: sealed::Sealed {
    /// Convert `rank` into `Self`.
    fn from_rank(rank: Rank) -> Self;
}
impl FromRank for Rank {
    /// Identity conversion: a `Rank` is already a `Rank`.
    #[inline]
    fn from_rank(r: Rank) -> Self {
        r
    }
}
impl FromRank for usize {
    /// Widen the rank into a `usize`, e.g. for use as an index.
    #[inline]
    fn from_rank(r: Rank) -> Self {
        // Plain `as` cast; `Rank` fits in `usize` on supported platforms.
        r as usize
    }
}
impl FromRank for u64 {
    /// Losslessly widen the rank to `u64`.
    #[inline]
    fn from_rank(r: Rank) -> Self {
        r.into()
    }
}
impl FromRank for i64 {
    /// Losslessly widen the rank to `i64`.
    #[inline]
    fn from_rank(r: Rank) -> Self {
        r.into()
    }
}
impl CoreBPE {
    /// Construct a `CoreBPE` from a byte-sequence encoder, a special-token
    /// encoder, and the regex pattern used to split text into pieces.
    ///
    /// # Errors
    /// Returns an error if `pattern` (or the alternation built from the
    /// escaped special tokens) fails to compile, or if `encoder` maps two
    /// different byte sequences to the same rank (the inverted decoder map
    /// would silently drop an entry).
    pub fn new(
        encoder: HashMap<Vec<u8>, Rank>,
        special_tokens_encoder: HashMap<String, Rank>,
        pattern: &str,
    ) -> Result<Self> {
        let regex = Regex::new(pattern)?;
        // Build one regex matching any special token literally; each token is
        // escaped so regex metacharacters in token text are taken verbatim.
        // NOTE(review): alternation order follows HashMap iteration order; if
        // one special token were a prefix of another, which alternative wins
        // is unspecified — confirm no such tokens exist in practice.
        let special_regex = {
            let parts = special_tokens_encoder
                .keys()
                .map(|s| fancy_regex::escape(s))
                .collect::<Vec<_>>();
            Regex::new(&parts.join("|"))?
        };

        let decoder: HashMap<Rank, Vec<u8>> =
            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
        // This constructor is already fallible, so report duplicate ranks as
        // an error instead of panicking (the original used `assert!`).
        if encoder.len() != decoder.len() {
            return Err(anyhow!(
                "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
            ));
        }

        let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
            .iter()
            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
            .collect();

        // Kept sorted; unstable sort is fine (keys are unique) and avoids the
        // allocation a stable merge sort needs.
        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
        sorted_token_bytes.sort_unstable();

        Ok(Self {
            encoder,
            special_tokens_encoder,
            decoder,
            special_tokens_decoder,
            // One pre-cloned compiled regex per potential thread.
            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
            special_regex_tls: (0..MAX_NUM_THREADS)
                .map(|_| special_regex.clone())
                .collect(),
            sorted_token_bytes,
        })
    }

    /// Encode `text` without any special-token handling, converting each rank
    /// into `T`.
    pub fn encode_ordinary_as<T: FromRank>(&self, text: &str) -> Vec<T> {
        self.encode_ordinary(text)
            .into_iter()
            .map(T::from_rank)
            .collect()
    }

    /// Encode `text` with all special tokens enabled, converting each rank
    /// into `T`.
    pub fn encode_with_special_tokens_as<T: FromRank>(&self, text: &str) -> Vec<T> {
        self.encode_with_special_tokens(text)
            .into_iter()
            .map(T::from_rank)
            .collect()
    }

    /// Encode `text`, treating only `allowed_special` as special tokens.
    ///
    /// Returns the converted tokens plus the length (in tokens) of the last
    /// piece, as reported by [`CoreBPE::encode`].
    pub fn encode_as<T: FromRank>(
        &self,
        text: &str,
        allowed_special: &HashSet<&str>,
    ) -> (Vec<T>, usize) {
        let (tokens, last_piece_token_len) = self.encode(text, allowed_special);
        (
            tokens.into_iter().map(T::from_rank).collect(),
            last_piece_token_len,
        )
    }

    /// Number of tokens produced by ordinary (no special tokens) encoding.
    pub fn count_ordinary(&self, text: &str) -> usize {
        self.encode_ordinary(text).len()
    }

    /// Number of tokens produced by encoding with `allowed_special` enabled.
    pub fn count(&self, text: &str, allowed_special: &HashSet<&str>) -> usize {
        self.encode(text, allowed_special).0.len()
    }

    /// Number of tokens produced by encoding with all special tokens enabled.
    pub fn count_with_special_tokens(&self, text: &str) -> usize {
        self.encode_with_special_tokens(text).len()
    }

    /// Decode `tokens` into a `String`.
    ///
    /// # Errors
    /// Fails if `decode_bytes` fails or if the decoded bytes are not valid
    /// UTF-8.
    pub fn decode(&self, tokens: &[Rank]) -> Result<String> {
        String::from_utf8(self.decode_bytes(tokens)?)
            .map_err(|e| anyhow!("Unable to decode into a valid UTF-8 string: {}", e))
    }

    /// Map each token to its byte sequence, yielding one `Vec<u8>` per token.
    ///
    /// Looks up ordinary tokens first, then special tokens. Panics (via map
    /// indexing) if a token is in neither decoder.
    pub fn _decode_native_and_split(
        &self,
        tokens: Vec<Rank>,
    ) -> impl Iterator<Item = Vec<u8>> + '_ {
        tokens.into_iter().map(|token| {
            let token_bytes = self
                .decoder
                .get(&token)
                .unwrap_or_else(|| &self.special_tokens_decoder[&token]);
            token_bytes.clone()
        })
    }

    /// Tokenize `text` and return each token's text, collected into a `Vec`.
    ///
    /// `use_special_tokens` selects between special-token-aware and ordinary
    /// encoding.
    pub fn split_by_token<'a>(
        &'a self,
        text: &'a str,
        use_special_tokens: bool,
    ) -> Result<Vec<String>> {
        self.split_by_token_iter(text, use_special_tokens).collect()
    }

    /// Iterator form of [`CoreBPE::split_by_token`]; each item is the
    /// (lossily UTF-8-decoded) text of one token.
    pub fn split_by_token_iter<'a>(
        &'a self,
        text: &'a str,
        use_special_tokens: bool,
    ) -> impl Iterator<Item = Result<String>> + 'a {
        let encoded = match use_special_tokens {
            true => self.encode_with_special_tokens(text),
            false => self.encode_ordinary(text),
        };
        // Items are always `Ok`; the `Result` wrapper is kept for interface
        // compatibility with callers that collect into `Result<Vec<_>>`.
        self._decode_native_and_split(encoded).map(|token| {
            Ok(String::from_utf8_lossy(token.as_slice()).to_string())
        })
    }

    /// [`CoreBPE::split_by_token`] with special tokens disabled.
    pub fn split_by_token_ordinary<'a>(&'a self, text: &'a str) -> Result<Vec<String>> {
        self.split_by_token(text, false)
    }

    /// [`CoreBPE::split_by_token_iter`] with special tokens disabled.
    pub fn split_by_token_ordinary_iter<'a>(
        &'a self,
        text: &'a str,
    ) -> impl Iterator<Item = Result<String>> + 'a {
        self.split_by_token_iter(text, false)
    }
}