use ahash::AHashMap;
use anyhow::{anyhow, Result};
use base64::{engine::general_purpose::STANDARD, Engine};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use tokenizers::{
decoders::byte_level::ByteLevel as ByteLevelDecoder,
models::bpe::{BpeBuilder, Merges, Vocab},
pre_tokenizers::{
byte_level::ByteLevel,
sequence::Sequence,
split::{Split, SplitPattern},
PreTokenizerWrapper,
},
tokenizer::{normalizer::SplitDelimiterBehavior, Tokenizer},
};
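/// Converts a tiktoken `.model` file (one `<base64 token> <rank>` pair per line) into a
/// `tokenizers::Tokenizer` backed by a byte-level BPE model using the Llama 3 style split pattern.
///
/// A rough usage sketch (the path is a placeholder):
///
/// ```ignore
/// let tokenizer = convert_tiktoken_to_tokenizers("tokenizer.model")?;
/// let ids = tokenizer
///     .encode("The quick brown fox", false)
///     .map_err(|e| anyhow::anyhow!(e))?
///     .get_ids()
///     .to_vec();
/// ```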
#[allow(dead_code)]
pub fn convert_tiktoken_to_tokenizers<P: AsRef<Path>>(
tokenizer_model_path: P,
) -> Result<Tokenizer> {
let model_bytes = fs::read(&tokenizer_model_path)?;
let (vocab, merges) = extract_vocab_merges_from_model(&model_bytes)?;
let bpe_model = BpeBuilder::new()
.vocab_and_merges(vocab, merges)
.build()
.map_err(|e| anyhow!("Failed to build BPE model: {}", e))?;
let mut tokenizer = Tokenizer::new(bpe_model);
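    // Llama 3 / tiktoken split pattern: contractions, runs of letters, groups of up to
    // three digits, punctuation, and whitespace handling.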
let pattern = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
let split = Split::new(
SplitPattern::Regex(pattern.to_string()),
SplitDelimiterBehavior::Isolated,
false,
)
.map_err(|e| anyhow!("Failed to create split pre-tokenizer: {}", e))?;
    // Byte-level encoding of the split pieces
    // (add_prefix_space = false, trim_offsets = true, use_regex = true).
    let byte_level = ByteLevel::new(false, true, true);
let pre_tokenizer = Sequence::new(vec![
PreTokenizerWrapper::Split(split),
PreTokenizerWrapper::ByteLevel(byte_level),
]);
tokenizer.with_pre_tokenizer(Some(pre_tokenizer));
    // Byte-level decoder reverses the byte-to-unicode mapping when decoding ids back to text
    // (add_prefix_space = true, trim_offsets = false, use_regex = false).
    let decoder = ByteLevelDecoder::new(true, false, false);
tokenizer.with_decoder(Some(decoder));
Ok(tokenizer)
}
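/// Parses the raw bytes of a tiktoken model file into a BPE vocabulary and an
/// ordered merge list, reconstructing the merges from the token ranks alone.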
fn extract_vocab_merges_from_model(model_bytes: &[u8]) -> Result<(Vocab, Merges)> {
let mut bpe_ranks = HashMap::new();
let lines = String::from_utf8_lossy(model_bytes);
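    // Each non-empty line in a tiktoken model file is `<base64-encoded token> <rank>`.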
for line in lines.lines() {
let line = line.trim();
if line.is_empty() {
continue;
}
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() != 2 {
continue;
}
let token_base64 = parts[0];
let rank_str = parts[1];
let token_bytes = STANDARD.decode(token_base64)?;
let rank: u32 = rank_str.parse()?;
bpe_ranks.insert(token_bytes, rank);
}
let mut vocab = AHashMap::new();
let mut merges = Vec::new();
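    // Reconstruct the merge list from ranks: every multi-byte token must have been produced
    // by merging two smaller tokens, so try each split point where both halves are themselves
    // known tokens, preferring splits with lower-ranked (earlier) pieces.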
for (token, rank) in &bpe_ranks {
let token_str = token_bytes_to_string(token);
vocab.insert(token_str, *rank);
if token.len() == 1 {
continue;
}
let mut local = Vec::new();
for index in 1..token.len() {
let piece_l = &token[..index];
let piece_r = &token[index..];
if let (Some(&rank_l), Some(&rank_r)) = (bpe_ranks.get(piece_l), bpe_ranks.get(piece_r))
{
let mut concat = piece_l.to_vec();
concat.extend_from_slice(piece_r);
if bpe_ranks.contains_key(&concat) {
local.push((piece_l.to_vec(), piece_r.to_vec(), *rank, rank_l, rank_r));
}
}
}
local.sort_by_key(|(_, _, _, rank_l, rank_r)| (*rank_l, *rank_r));
for (piece_l, piece_r, rank, _, _) in local {
merges.push((piece_l, piece_r, rank));
}
}
merges.sort_by_key(|(_, _, rank)| *rank);
let merges: Vec<(String, String)> = merges
.into_iter()
.map(|(l, r, _)| (token_bytes_to_string(&l), token_bytes_to_string(&r)))
.collect();
Ok((vocab, merges))
}
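/// GPT-2 style byte-to-unicode map: every possible byte value is assigned a printable
/// unicode character so arbitrary byte sequences can be stored in a UTF-8 vocabulary.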
pub(super) fn bytes_to_unicode() -> AHashMap<u8, char> {
    // Bytes that already map to printable, visually distinct characters keep their own code point.
    let mut bs: Vec<u8> = vec![];
    bs.extend(b'!'..=b'~');
    bs.extend(0xA1u8..=0xACu8);
    bs.extend(0xAEu8..=0xFFu8);
let mut cs: Vec<u32> = bs.iter().map(|&b| b as u32).collect();
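    // Remaining (non-printable) bytes are mapped to code points starting at 256,
    // keeping the table a bijection over all 256 byte values.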
let mut n: u32 = 0;
for b in 0u8..=255 {
if !bs.contains(&b) {
bs.push(b);
cs.push(256 + n);
n += 1;
}
}
let mut byte_encoder = AHashMap::new();
for (b, c) in bs.iter().zip(cs.iter()) {
byte_encoder.insert(*b, char::from_u32(*c).unwrap());
}
byte_encoder
}
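/// Encodes raw token bytes as a string via the byte-to-unicode map above.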
pub(super) fn token_bytes_to_string(bytes: &[u8]) -> String {
let byte_encoder = bytes_to_unicode();
bytes.iter().map(|&b| byte_encoder[&b]).collect()
}
#[cfg(test)]
mod tests {
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use super::*;
#[test]
fn test_bytes_to_unicode() {
let byte_encoder = bytes_to_unicode();
assert_eq!(byte_encoder.len(), 256);
assert_eq!(byte_encoder[&b'h'], 'h');
assert_eq!(byte_encoder[&b'e'], 'e');
assert_eq!(byte_encoder[&b'l'], 'l');
assert_eq!(byte_encoder[&b'o'], 'o');
for b in 0u8..=255 {
assert!(byte_encoder.contains_key(&b));
}
}
#[test]
fn test_token_bytes_to_string() {
let test_bytes = b"hello";
let result = token_bytes_to_string(test_bytes);
assert_eq!(result, "hello");
}
#[test]
fn test_tiktoken_conversion() -> anyhow::Result<()> {
let api = ApiBuilder::new().with_progress(true).build().unwrap();
let api = api.repo(Repo::with_revision(
"EricB/mistralrs_tests".to_string(),
RepoType::Model,
"main".to_string(),
));
let converted_tokenizer = {
let tokenizer_filename = api.get("tokenizer_llama3.model").unwrap();
convert_tiktoken_to_tokenizers(tokenizer_filename).unwrap()
};
let truth_tokenizer = {
let tokenizer_filename = api.get("tokenizer_llama3.json").unwrap();
Tokenizer::from_file(tokenizer_filename).unwrap()
};
let test_cases = vec![
"The quick brown fox",
"Hello, world!",
"123456",
"🦀 Rust",
"Hello, world! \n🚀 (normal) 😶🌫️ (compound emoji, zwj sequence) ✅ (emoji as single token)\n你好世界!\nNǐ hǎo shìjiè!",
];
for test_case in test_cases {
let converted_enc = converted_tokenizer
.encode(test_case, false)
.map_err(|e| anyhow!("Failed to encode '{}': {}", test_case, e))?;
let truth_enc = truth_tokenizer
.encode(test_case, false)
.map_err(|e| anyhow!("Failed to encode '{}': {}", test_case, e))?;
assert!(
!converted_enc.get_ids().is_empty(),
"Converted tokenizer produced empty output for '{test_case}'"
);
assert!(
!truth_enc.get_ids().is_empty(),
"Truth tokenizer produced empty output for '{test_case}'"
);
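            // The conversion is expected to be exact, so the two tokenizers should
            // produce identical ids for every input.
            assert_eq!(
                converted_enc.get_ids(),
                truth_enc.get_ids(),
                "Token ids differ for '{test_case}'"
            );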
}
Ok(())
}
}