use crate::vocab::Vocab;
use once_cell::sync::Lazy;
use regex::Regex;
use std::collections::{HashMap, HashSet};
use std::sync::RwLock;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::traits::{TokenizedInput, Tokenizer};
#[derive(Debug)]
pub struct TiktokenTokenizer {
vocab: Vocab,
encoder: HashMap<Vec<u8>, usize>,
decoder: HashMap<usize, Vec<u8>>,
special_tokens: HashMap<String, usize>,
pattern: Regex,
cache: RwLock<HashMap<String, Vec<usize>>>,
}
impl Clone for TiktokenTokenizer {
fn clone(&self) -> Self {
Self {
vocab: self.vocab.clone(),
encoder: self.encoder.clone(),
decoder: self.decoder.clone(),
special_tokens: self.special_tokens.clone(),
pattern: self.pattern.clone(),
cache: RwLock::new(HashMap::new()), }
}
}
static TIKTOKEN_PATTERN: Lazy<Regex> = Lazy::new(|| {
Regex::new(
r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+"
).expect("TIKTOKEN_PATTERN regex must be valid")
});
impl TiktokenTokenizer {
pub fn new(
encoder: HashMap<Vec<u8>, usize>,
special_tokens: HashMap<String, usize>,
pattern: Option<Regex>,
) -> Self {
let decoder: HashMap<usize, Vec<u8>> =
encoder.iter().map(|(k, &v)| (v, k.clone())).collect();
let vocab_map: HashMap<String, u32> = encoder
.iter()
.map(|(bytes, &rank)| {
let token = String::from_utf8_lossy(bytes).to_string();
(token, rank as u32)
})
.collect();
Self {
vocab: Vocab::from_map(vocab_map),
encoder,
decoder,
special_tokens,
pattern: pattern.unwrap_or_else(|| TIKTOKEN_PATTERN.clone()),
cache: RwLock::new(HashMap::new()),
}
}
pub fn cl100k_base() -> Self {
let mut encoder = HashMap::new();
let mut special_tokens = HashMap::new();
for i in 0..256 {
encoder.insert(vec![i as u8], i);
}
special_tokens.insert("<|endoftext|>".to_string(), 100257);
special_tokens.insert("<|fim_prefix|>".to_string(), 100258);
special_tokens.insert("<|fim_middle|>".to_string(), 100259);
special_tokens.insert("<|fim_suffix|>".to_string(), 100260);
special_tokens.insert("<|endofprompt|>".to_string(), 100276);
Self::new(encoder, special_tokens, None)
}
pub fn r50k_base() -> Self {
let mut encoder = HashMap::new();
let mut special_tokens = HashMap::new();
for i in 0..256 {
encoder.insert(vec![i as u8], i);
}
special_tokens.insert("<|endoftext|>".to_string(), 50256);
Self::new(encoder, special_tokens, None)
}
pub fn from_tiktoken_file(
encoder_path: &str,
special_tokens_path: Option<&str>,
) -> Result<Self> {
use std::fs::File;
use std::io::{BufRead, BufReader};
let mut encoder = HashMap::new();
let file = File::open(encoder_path).map_err(|e| {
TrustformersError::other(format!(
"Failed to open encoder file {}: {}",
encoder_path, e
))
})?;
let reader = BufReader::new(file);
for (line_num, line) in reader.lines().enumerate() {
let line = line.map_err(|e| {
TrustformersError::other(format!("Failed to read line {}: {}", line_num, e))
})?;
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue; }
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() != 2 {
return Err(TrustformersError::other(format!(
"Invalid encoder format at line {}: expected 'token rank', got '{}'",
line_num + 1,
line
)));
}
let token_bytes = Self::decode_tiktoken_token(parts[0]).map_err(|e| {
TrustformersError::other(format!(
"Failed to decode token at line {}: {}",
line_num + 1,
e
))
})?;
let rank: usize = parts[1].parse().map_err(|e| {
TrustformersError::other(format!("Invalid rank at line {}: {}", line_num + 1, e))
})?;
encoder.insert(token_bytes, rank);
}
let mut special_tokens = HashMap::new();
if let Some(special_path) = special_tokens_path {
let special_file = File::open(special_path).map_err(|e| {
TrustformersError::other(format!(
"Failed to open special tokens file {}: {}",
special_path, e
))
})?;
let special_reader = BufReader::new(special_file);
for (line_num, line) in special_reader.lines().enumerate() {
let line = line.map_err(|e| {
TrustformersError::other(format!(
"Failed to read special tokens line {}: {}",
line_num, e
))
})?;
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() != 2 {
return Err(TrustformersError::other(format!(
"Invalid special token format at line {}: expected 'token rank'",
line_num + 1
)));
}
let token = parts[0].to_string();
let rank: usize = parts[1].parse().map_err(|e| {
TrustformersError::other(format!(
"Invalid special token rank at line {}: {}",
line_num + 1,
e
))
})?;
special_tokens.insert(token, rank);
}
}
let pattern = TIKTOKEN_PATTERN.clone();
Ok(Self::new(encoder, special_tokens, Some(pattern)))
}
fn decode_tiktoken_token(encoded: &str) -> std::result::Result<Vec<u8>, String> {
let decoded = if encoded.starts_with("b'") && encoded.ends_with("'") {
let inner = &encoded[2..encoded.len() - 1];
Self::decode_python_bytes_literal(inner)?
} else {
use base64::{engine::general_purpose::STANDARD, Engine as _};
STANDARD.decode(encoded).map_err(|e| format!("Base64 decode error: {}", e))?
};
Ok(decoded)
}
fn decode_python_bytes_literal(literal: &str) -> std::result::Result<Vec<u8>, String> {
let mut bytes = Vec::new();
let mut chars = literal.chars().peekable();
while let Some(ch) = chars.next() {
match ch {
'\\' => {
match chars.next().ok_or("Unexpected end of escape sequence")? {
'n' => bytes.push(b'\n'),
'r' => bytes.push(b'\r'),
't' => bytes.push(b'\t'),
'\\' => bytes.push(b'\\'),
'\'' => bytes.push(b'\''),
'"' => bytes.push(b'"'),
'x' => {
let hex1 = chars.next().ok_or("Incomplete hex escape")?;
let hex2 = chars.next().ok_or("Incomplete hex escape")?;
let hex_str = format!("{}{}", hex1, hex2);
let byte_val = u8::from_str_radix(&hex_str, 16)
.map_err(|_| format!("Invalid hex escape: \\x{}", hex_str))?;
bytes.push(byte_val);
},
'0'..='7' => {
let mut octal = String::new();
octal.push(chars.peek().copied().unwrap_or('0')); chars.next();
for _ in 0..2 {
if let Some(&next_char) = chars.peek() {
if next_char.is_ascii_digit() && next_char <= '7' {
octal.push(next_char);
chars.next();
} else {
break;
}
}
}
let byte_val = u8::from_str_radix(&octal, 8)
.map_err(|_| format!("Invalid octal escape: \\{}", octal))?;
bytes.push(byte_val);
},
other => return Err(format!("Unknown escape sequence: \\{}", other)),
}
},
_ => {
if ch.is_ascii() {
bytes.push(ch as u8);
} else {
let mut utf8_buf = [0u8; 4];
let utf8_bytes = ch.encode_utf8(&mut utf8_buf).as_bytes();
bytes.extend_from_slice(utf8_bytes);
}
},
}
}
Ok(bytes)
}
fn encode_bytes(&self, piece: &[u8]) -> Vec<usize> {
if piece.len() == 1 {
return vec![self.encoder.get(piece).copied().unwrap_or(0)];
}
let mut word: Vec<Vec<u8>> = piece.iter().map(|&b| vec![b]).collect();
while word.len() > 1 {
let mut min_rank = usize::MAX;
let mut merge_idx = None;
for i in 0..word.len() - 1 {
let mut merged = word[i].clone();
merged.extend_from_slice(&word[i + 1]);
if let Some(&rank) = self.encoder.get(&merged) {
if rank < min_rank {
min_rank = rank;
merge_idx = Some(i);
}
}
}
if let Some(idx) = merge_idx {
let mut merged = word[idx].clone();
merged.extend_from_slice(&word[idx + 1]);
word[idx] = merged;
word.remove(idx + 1);
} else {
break;
}
}
word.into_iter()
.map(|bytes| self.encoder.get(&bytes).copied().unwrap_or(0))
.collect()
}
pub fn encode_text(&self, text: &str) -> Vec<usize> {
if let Ok(cache) = self.cache.read() {
if let Some(cached) = cache.get(text) {
return cached.clone();
}
}
let mut tokens = vec![];
let _remaining_text = text;
for (special_token, &token_id) in &self.special_tokens {
if text.contains(special_token) {
let parts: Vec<&str> = text.split(special_token).collect();
for (i, part) in parts.iter().enumerate() {
if i > 0 {
tokens.push(token_id);
}
if !part.is_empty() {
tokens.extend(self.encode_text_without_special(part));
}
}
if let Ok(mut cache) = self.cache.write() {
cache.insert(text.to_string(), tokens.clone());
}
return tokens;
}
}
let result = self.encode_text_without_special(text);
if let Ok(mut cache) = self.cache.write() {
cache.insert(text.to_string(), result.clone());
}
result
}
fn encode_text_without_special(&self, text: &str) -> Vec<usize> {
let mut tokens = vec![];
for mat in self.pattern.find_iter(text) {
let piece = mat.as_str().as_bytes();
tokens.extend(self.encode_bytes(piece));
}
tokens
}
pub fn decode_tokens(&self, tokens: &[usize]) -> Result<String> {
let mut bytes = vec![];
for &token_id in tokens {
let is_special = self.special_tokens.values().any(|&id| id == token_id);
if is_special {
if let Some((special_str, _)) =
self.special_tokens.iter().find(|(_, &id)| id == token_id)
{
bytes.extend_from_slice(special_str.as_bytes());
}
} else if let Some(token_bytes) = self.decoder.get(&token_id) {
bytes.extend_from_slice(token_bytes);
}
}
String::from_utf8(bytes)
.map_err(|e| TrustformersError::other(format!("Failed to decode UTF-8: {}", e)))
}
pub fn vocab_size(&self) -> usize {
self.encoder.len() + self.special_tokens.len()
}
pub fn special_tokens(&self) -> &HashMap<String, usize> {
&self.special_tokens
}
pub fn is_special_token(&self, token_id: usize) -> bool {
self.special_tokens.values().any(|&id| id == token_id)
}
pub fn encode_with_special_tokens(
&self,
text: &str,
allowed_special: &HashSet<String>,
) -> Vec<usize> {
let mut result = vec![];
let mut start = 0;
while start < text.len() {
let mut found_special = false;
for special_token in allowed_special {
if text[start..].starts_with(special_token) {
if let Some(&token_id) = self.special_tokens.get(special_token) {
result.push(token_id);
start += special_token.len();
found_special = true;
break;
}
}
}
if !found_special {
let mut end = text.len();
for special_token in allowed_special {
if let Some(pos) = text[start..].find(special_token) {
end = end.min(start + pos);
}
}
if start < end {
result.extend(self.encode_text_without_special(&text[start..end]));
start = end;
}
}
}
result
}
}
impl Tokenizer for TiktokenTokenizer {
fn encode(&self, text: &str) -> Result<TokenizedInput> {
let tokens = self.encode_text(text);
let input_ids: Vec<u32> = tokens.into_iter().map(|t| t as u32).collect();
let attention_mask = vec![1u8; input_ids.len()];
Ok(TokenizedInput {
input_ids,
attention_mask,
token_type_ids: None,
special_tokens_mask: None,
offset_mapping: None,
overflowing_tokens: None,
})
}
fn encode_pair(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
let combined = format!("{} {}", text, text2);
self.encode(&combined)
}
fn decode(&self, ids: &[u32]) -> Result<String> {
let tokens: Vec<usize> = ids.iter().map(|&id| id as usize).collect();
self.decode_tokens(&tokens)
}
fn vocab_size(&self) -> usize {
self.vocab.len()
}
fn get_vocab(&self) -> HashMap<String, u32> {
self.vocab.iter().map(|(k, &v)| (k.clone(), v)).collect()
}
fn token_to_id(&self, token: &str) -> Option<u32> {
self.vocab.get(token).copied()
}
fn id_to_token(&self, id: u32) -> Option<String> {
self.vocab.get_token(id)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tiktoken_cl100k_base() {
let tokenizer = TiktokenTokenizer::cl100k_base();
assert!(tokenizer.vocab_size() > 250); assert!(tokenizer.special_tokens().contains_key("<|endoftext|>"));
assert!(tokenizer.special_tokens().contains_key("<|fim_prefix|>"));
}
#[test]
fn test_tiktoken_r50k_base() {
let tokenizer = TiktokenTokenizer::r50k_base();
assert!(tokenizer.vocab_size() > 250); assert!(tokenizer.special_tokens().contains_key("<|endoftext|>"));
}
#[test]
fn test_tiktoken_encode_decode() {
let tokenizer = TiktokenTokenizer::cl100k_base();
let text = "Hello, world!";
let tokens = tokenizer.encode_text(text);
assert!(!tokens.is_empty());
let decoded = tokenizer.decode_tokens(&tokens).expect("Operation failed in test");
assert!(decoded.contains("Hello") || decoded.contains("world"));
}
#[test]
fn test_special_tokens() {
let tokenizer = TiktokenTokenizer::cl100k_base();
let text = "Hello <|endoftext|> world";
let tokens = tokenizer.encode_text(text);
assert!(tokens.contains(&100257)); }
#[test]
fn test_tokenizer_trait() {
let tokenizer = TiktokenTokenizer::cl100k_base();
let result = tokenizer.encode("Hello, world!").expect("Encoding failed");
assert!(!result.input_ids.is_empty());
assert_eq!(result.input_ids.len(), result.attention_mask.len());
let decoded = tokenizer.decode(&result.input_ids).expect("Decoding failed");
assert!(!decoded.is_empty());
}
#[test]
fn test_encode_with_special_tokens() {
let tokenizer = TiktokenTokenizer::cl100k_base();
let mut allowed_special = HashSet::new();
allowed_special.insert("<|endoftext|>".to_string());
let text = "Hello <|endoftext|> world";
let tokens = tokenizer.encode_with_special_tokens(text, &allowed_special);
assert!(!tokens.is_empty());
assert!(tokens.contains(&100257)); }
#[test]
fn test_is_special_token() {
let tokenizer = TiktokenTokenizer::cl100k_base();
assert!(tokenizer.is_special_token(100257)); assert!(!tokenizer.is_special_token(0)); }
#[test]
fn test_cache_functionality() {
let tokenizer = TiktokenTokenizer::cl100k_base();
let text = "Hello, world!";
let tokens1 = tokenizer.encode_text(text);
let tokens2 = tokenizer.encode_text(text);
assert_eq!(tokens1, tokens2);
}
}