use crate::errors::{Result, TiktokenError};
use regex::Regex;
use std::collections::HashMap;
/// A token id, regular or special.
pub type Token = u32;
/// A BPE merge priority: lower rank means the pair is merged earlier.
/// Ranks double as token ids for the regular vocabulary (the merge-rank
/// map is used directly as the encoder).
pub type Rank = u32;
/// Core byte-pair-encoding tokenizer: the merge-rank vocabulary and its
/// inverse, the special-token tables, and the compiled pre-tokenization
/// regexes.
#[derive(Clone)]
pub struct CoreBPE {
    // Bytes -> token id for the regular vocabulary. Also serves as the BPE
    // merge-rank table, since ranks double as token ids.
    encoder: HashMap<Vec<u8>, Token>,
    // Token id -> bytes; inverse of `encoder`.
    decoder: HashMap<Token, Vec<u8>>,
    // Special-token text -> token id.
    special_tokens_encoder: HashMap<String, Token>,
    // Special token id -> UTF-8 bytes of its text.
    special_tokens_decoder: HashMap<Token, Vec<u8>>,
    // Pre-tokenization pattern that splits text into pieces to BPE-encode.
    regex: Regex,
    // Alternation over the escaped special tokens; `None` when there are
    // no special tokens.
    special_regex: Option<Regex>,
}
impl CoreBPE {
    /// Builds a tokenizer core from BPE merge ranks, special-token mappings
    /// and the pre-tokenization regex `pattern`.
    ///
    /// # Errors
    /// Returns an error when `pattern` (or the pattern derived from the
    /// special tokens) does not compile as a regex.
    pub fn new(
        mergeable_ranks: HashMap<Vec<u8>, Rank>,
        special_tokens: HashMap<String, Token>,
        pattern: &str,
    ) -> Result<Self> {
        let regex = Regex::new(pattern)?;
        let encoder: HashMap<Vec<u8>, Token> = mergeable_ranks;
        // Invert the encoder for decoding. Duplicate ranks would silently
        // collapse to a single entry here.
        let decoder: HashMap<Token, Vec<u8>> =
            encoder.iter().map(|(bytes, &token)| (token, bytes.clone())).collect();
        let special_tokens_decoder: HashMap<Token, Vec<u8>> = special_tokens
            .iter()
            .map(|(text, &token)| (token, text.as_bytes().to_vec()))
            .collect();
        let special_regex = if special_tokens.is_empty() {
            None
        } else {
            // Fix: order alternatives longest-first (ties broken
            // lexicographically) before joining. The regex crate's
            // alternation is leftmost-first and `HashMap` iteration order is
            // random per process, so without sorting, overlapping special
            // tokens (e.g. "<|end|>" vs "<|end|>x") would be matched
            // nondeterministically across runs.
            let mut keys: Vec<&String> = special_tokens.keys().collect();
            keys.sort_by(|a, b| b.len().cmp(&a.len()).then_with(|| a.cmp(b)));
            let escaped_tokens: Vec<String> =
                keys.iter().map(|s| regex::escape(s)).collect();
            Some(Regex::new(&escaped_tokens.join("|"))?)
        };
        Ok(CoreBPE {
            encoder,
            decoder,
            special_tokens_encoder: special_tokens,
            special_tokens_decoder,
            regex,
            special_regex,
        })
    }

    /// Encodes `text` ignoring special tokens entirely: any special-token
    /// text is BPE-encoded like ordinary text.
    pub fn encode_ordinary(&self, text: &str) -> Vec<Token> {
        let mut tokens = Vec::new();
        for mat in self.regex.find_iter(text) {
            let piece = mat.as_str().as_bytes();
            // Fast path: the whole regex piece is already a vocabulary entry.
            if let Some(&token) = self.encoder.get(piece) {
                tokens.push(token);
            } else {
                tokens.extend(self.byte_pair_encode(piece));
            }
        }
        tokens
    }

    /// Encodes `text`, emitting special-token ids for occurrences of tokens
    /// listed in `allowed_special`. Special tokens in neither list are
    /// encoded as ordinary text.
    ///
    /// # Errors
    /// Returns `TiktokenError::EncodingError` when the text contains a
    /// special token listed in `disallowed_special` that is not also in
    /// `allowed_special`.
    pub fn encode(
        &self,
        text: &str,
        allowed_special: &[&str],
        disallowed_special: &[&str],
    ) -> Result<Vec<Token>> {
        // Reject disallowed special tokens up front, before encoding anything.
        if !disallowed_special.is_empty() {
            if let Some(ref special_regex) = self.special_regex {
                for mat in special_regex.find_iter(text) {
                    let token_text = mat.as_str();
                    if disallowed_special.contains(&token_text)
                        && !allowed_special.contains(&token_text)
                    {
                        return Err(TiktokenError::EncodingError(format!(
                            "Disallowed special token: {token_text}"
                        )));
                    }
                }
            }
        }
        let mut tokens = Vec::new();
        let mut start = 0;
        // Alternate between ordinary text segments and allowed special tokens.
        while start < text.len() {
            let mut next_special_start = text.len();
            let mut next_special_end = text.len();
            let mut found_special = None;
            if let Some(ref special_regex) = self.special_regex {
                // Find the next *allowed* special token at or after `start`;
                // matches that are not allowed are skipped and later encoded
                // as ordinary text.
                for mat in special_regex.find_iter(&text[start..]) {
                    let token_text = &text[start + mat.start()..start + mat.end()];
                    if allowed_special.contains(&token_text) {
                        next_special_start = start + mat.start();
                        next_special_end = start + mat.end();
                        found_special = Some(token_text);
                        break;
                    }
                }
            }
            // Encode the ordinary stretch before the special token (or the
            // remaining tail when no allowed special token was found).
            if next_special_start > start {
                let ordinary_text = &text[start..next_special_start];
                tokens.extend(self.encode_ordinary(ordinary_text));
            }
            if let Some(special_token) = found_special {
                if let Some(&token) = self.special_tokens_encoder.get(special_token) {
                    tokens.push(token);
                }
                start = next_special_end;
            } else {
                // No further special tokens; the tail was just encoded above.
                break;
            }
        }
        Ok(tokens)
    }

    /// Decodes token ids (regular or special) to raw bytes.
    ///
    /// # Errors
    /// Returns `TiktokenError::InvalidToken` for the first id unknown to
    /// both vocabularies.
    pub fn decode_bytes(&self, tokens: &[Token]) -> Result<Vec<u8>> {
        let mut result = Vec::new();
        for &token in tokens {
            if let Some(bytes) = self.decoder.get(&token) {
                result.extend_from_slice(bytes);
            } else if let Some(bytes) = self.special_tokens_decoder.get(&token) {
                result.extend_from_slice(bytes);
            } else {
                return Err(TiktokenError::InvalidToken(token));
            }
        }
        Ok(result)
    }

    /// Decodes token ids to a `String`.
    ///
    /// # Errors
    /// Fails on unknown token ids, or when the decoded byte sequence is not
    /// valid UTF-8.
    pub fn decode(&self, tokens: &[Token]) -> Result<String> {
        let bytes = self.decode_bytes(tokens)?;
        String::from_utf8(bytes).map_err(TiktokenError::from)
    }

    /// Returns the byte representation of a single token id.
    ///
    /// # Errors
    /// Returns `TiktokenError::InvalidToken` for unknown ids.
    pub fn decode_single_token_bytes(&self, token: Token) -> Result<&[u8]> {
        if let Some(bytes) = self.decoder.get(&token) {
            Ok(bytes)
        } else if let Some(bytes) = self.special_tokens_decoder.get(&token) {
            Ok(bytes)
        } else {
            Err(TiktokenError::InvalidToken(token))
        }
    }

    /// BPE-encodes one regex piece that is not itself a vocabulary entry.
    ///
    /// # Panics
    /// Panics if a resulting segment (including a lone byte) is missing from
    /// the encoder — the vocabulary is expected to contain every single byte.
    fn byte_pair_encode(&self, piece: &[u8]) -> Vec<Token> {
        // Fix: guard the empty slice — `byte_pair_merge` computes
        // `piece.len() - 1`, which underflows (panics) for len 0.
        if piece.is_empty() {
            return Vec::new();
        }
        if piece.len() == 1 {
            return vec![self.encoder[piece]];
        }
        let parts = self.byte_pair_merge(piece);
        // Adjacent boundary offsets delimit the final merged segments.
        parts
            .windows(2)
            .map(|window| {
                let start = window[0].0;
                let end = window[1].0;
                self.encoder[&piece[start..end]]
            })
            .collect()
    }

    /// Runs BPE merging over `piece` (callers guarantee len >= 2) and returns
    /// the surviving segment boundaries as `(byte_offset, rank)` entries,
    /// including a one-past-the-end sentinel, so consecutive offsets delimit
    /// the merged segments.
    fn byte_pair_merge(&self, piece: &[u8]) -> Vec<(usize, Rank)> {
        let mut parts = Vec::with_capacity(piece.len() + 1);
        // Seed with every single-byte boundary and the rank of the byte pair
        // starting there, tracking the minimum rank seen so far.
        let mut min_rank = (Rank::MAX, usize::MAX);
        for i in 0..piece.len() - 1 {
            let pair = &piece[i..i + 2];
            let rank = self.encoder.get(pair).copied().unwrap_or(Rank::MAX);
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
            parts.push((i, rank));
        }
        // Sentinels: the last byte's boundary plus one-past-the-end.
        parts.push((piece.len() - 1, Rank::MAX));
        parts.push((piece.len(), Rank::MAX));
        // Repeatedly merge the lowest-ranked adjacent pair until no mergeable
        // pair remains (all ranks are Rank::MAX).
        while min_rank.0 != Rank::MAX {
            let i = min_rank.1;
            // Recompute the ranks of the pairs touching the merge point.
            // These use a lookahead of 3 because parts[i + 1] is removed
            // just below.
            if i > 0 {
                parts[i - 1].1 = self.get_pair_rank(piece, &parts, i - 1);
            }
            parts[i].1 = self.get_pair_rank(piece, &parts, i);
            parts.remove(i + 1);
            // Linear rescan for the new minimum; pieces are short in
            // practice, so the O(n) scan per merge is acceptable.
            min_rank = (Rank::MAX, usize::MAX);
            for (idx, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
                if rank < min_rank.0 {
                    min_rank = (rank, idx);
                }
            }
        }
        parts
    }

    /// Rank of the segment spanning `parts[i].0 .. parts[i + 3].0` — i.e. the
    /// pair formed at `i` once the entry at `i + 1` is removed by the pending
    /// merge. `Rank::MAX` when no such pair exists or it is not mergeable.
    fn get_pair_rank(&self, piece: &[u8], parts: &[(usize, Rank)], i: usize) -> Rank {
        if i + 3 < parts.len() {
            let start = parts[i].0;
            let end = parts[i + 3].0;
            self.encoder.get(&piece[start..end]).copied().unwrap_or(Rank::MAX)
        } else {
            Rank::MAX
        }
    }

    /// Map from special-token text to token id.
    pub fn special_tokens(&self) -> &HashMap<String, Token> {
        &self.special_tokens_encoder
    }

    /// Whether `token` is a special-token id.
    pub fn is_special_token(&self, token: Token) -> bool {
        self.special_tokens_decoder.contains_key(&token)
    }

    /// Largest token id across the regular and special vocabularies
    /// (0 when both are empty).
    pub fn max_token_value(&self) -> Token {
        let max_regular = self.decoder.keys().max().copied().unwrap_or(0);
        let max_special = self.special_tokens_decoder.keys().max().copied().unwrap_or(0);
        max_regular.max(max_special)
    }

    /// Total number of token ids (regular + special).
    pub fn vocab_size(&self) -> usize {
        self.decoder.len() + self.special_tokens_decoder.len()
    }
}
/// A named tokenizer encoding wrapping a `CoreBPE`.
#[derive(Clone)]
pub struct Encoding {
    // Encoding name chosen by the caller at construction.
    pub name: String,
    // Underlying BPE engine that performs all encoding/decoding work.
    core: CoreBPE,
}
impl Encoding {
pub fn new(
name: String,
mergeable_ranks: HashMap<Vec<u8>, Rank>,
special_tokens: HashMap<String, Token>,
pattern: &str,
) -> Result<Self> {
let core = CoreBPE::new(mergeable_ranks, special_tokens, pattern)?;
Ok(Encoding { name, core })
}
pub fn encode_ordinary(&self, text: &str) -> Vec<Token> {
self.core.encode_ordinary(text)
}
pub fn encode(
&self,
text: &str,
allowed_special: &[&str],
disallowed_special: &[&str],
) -> Result<Vec<Token>> {
self.core.encode(text, allowed_special, disallowed_special)
}
pub fn decode(&self, tokens: &[Token]) -> Result<String> {
self.core.decode(tokens)
}
pub fn decode_bytes(&self, tokens: &[Token]) -> Result<Vec<u8>> {
self.core.decode_bytes(tokens)
}
pub fn decode_single_token_bytes(&self, token: Token) -> Result<&[u8]> {
self.core.decode_single_token_bytes(token)
}
pub fn special_tokens(&self) -> &HashMap<String, Token> {
self.core.special_tokens()
}
pub fn is_special_token(&self, token: Token) -> bool {
self.core.is_special_token(token)
}
pub fn max_token_value(&self) -> Token {
self.core.max_token_value()
}
pub fn vocab_size(&self) -> usize {
self.core.vocab_size()
}
pub fn count_tokens(&self, text: &str) -> usize {
self.encode_ordinary(text).len()
}
pub fn encode_batch(
&self,
texts: &[&str],
allowed_special: &[&str],
disallowed_special: &[&str],
) -> Result<Vec<Vec<Token>>> {
texts
.iter()
.map(|&text| self.encode(text, allowed_special, disallowed_special))
.collect()
}
pub fn encode_ordinary_batch(&self, texts: &[&str]) -> Vec<Vec<Token>> {
texts.iter().map(|&text| self.encode_ordinary(text)).collect()
}
pub fn decode_batch(&self, token_sequences: &[&[Token]]) -> Result<Vec<String>> {
token_sequences.iter().map(|&tokens| self.decode(tokens)).collect()
}
pub fn decode_bytes_batch(&self, token_sequences: &[&[Token]]) -> Result<Vec<Vec<u8>>> {
token_sequences.iter().map(|&tokens| self.decode_bytes(tokens)).collect()
}
pub fn encode_single_token(&self, text: &str) -> Result<Token> {
let bytes = text.as_bytes();
if let Some(&token) = self.special_tokens().get(text) {
return Ok(token);
}
if let Some(&token) = self.core.encoder.get(bytes) {
return Ok(token);
}
Err(TiktokenError::EncodingError(format!(
"Text '{text}' does not correspond to a single token"
)))
}
pub fn token_byte_values(&self) -> Vec<Vec<u8>> {
let mut values: Vec<Vec<u8>> = self.core.decoder.values().cloned().collect();
values.extend(self.core.special_tokens_decoder.values().cloned());
values.sort();
values
}
pub fn eot_token(&self) -> Option<Token> {
self.special_tokens().get("<|endoftext|>").copied()
}
}