use base64::Engine;
use rustc_hash::FxHashMap;
use crate::bpe::CoreBpe;
// Zstandard-compressed vocabulary files, embedded into the binary at compile
// time via `include_bytes!`. `parse_tiktoken_data` decompresses one of these
// into tiktoken-format text: one base64-encoded token and its decimal rank per
// line.
const CL100K_BASE_DATA: &[u8] = include_bytes!("encodings/cl100k_base.tiktoken.zst");
const O200K_BASE_DATA: &[u8] = include_bytes!("encodings/o200k_base.tiktoken.zst");
// Shared by `p50k_base` and `p50k_edit`.
const P50K_BASE_DATA: &[u8] = include_bytes!("encodings/p50k_base.tiktoken.zst");
const R50K_BASE_DATA: &[u8] = include_bytes!("encodings/r50k_base.tiktoken.zst");
const LLAMA3_DATA: &[u8] = include_bytes!("encodings/llama3.tiktoken.zst");
const DEEPSEEK_V3_DATA: &[u8] = include_bytes!("encodings/deepseek_v3.tiktoken.zst");
const QWEN2_DATA: &[u8] = include_bytes!("encodings/qwen2.tiktoken.zst");
const MISTRAL_V3_DATA: &[u8] = include_bytes!("encodings/mistral_v3.tiktoken.zst");
// Split regexes handed to `CoreBpe::new` by the constructors below; presumably
// used to cut input text into pieces before BPE merges — confirm in `CoreBpe`.
//
// cl100k split: case-insensitive English contractions; letter runs with an
// optional single leading non-letter/non-digit (non-newline) char; digit runs
// of at most 3; punctuation runs with an optional leading space and trailing
// newlines; newline runs; remaining whitespace. Also reused verbatim by
// `LLAMA3_PATTERN` and `MISTRAL_V3_PATTERN`.
const CL100K_PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+";
// o200k split: like cl100k, but words are also split on letter-case
// boundaries (upper/title-case run followed by lower-case run, in either of two
// forms), each with an optional case-insensitive contraction suffix attached.
const O200K_PATTERN: &str = concat!(
r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+",
r"(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
r"|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*",
r"(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
r"|\p{N}{1,3}",
r"| ?[^\s\p{L}\p{N}]+[\r\n]*",
r"|\s*[\r\n]+",
r"|\s+",
);
// p50k split (also used by `r50k_base`): contractions are matched
// case-sensitively, and letter/digit runs take an optional leading space.
// Digit runs have no length cap here (`\p{N}+`).
const P50K_PATTERN: &str = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+";
// Llama 3 uses exactly the cl100k split.
const LLAMA3_PATTERN: &str = CL100K_PATTERN;
// DeepSeek V3 split, alternatives in priority order: digit runs of at most 3;
// runs of CJK ideographs (U+4E00..U+9FA5), Hiragana and Katakana; one ASCII
// punctuation char followed by ASCII letters; letter/mark runs with an optional
// leading char that is not a letter, punctuation, symbol, CR or LF;
// punctuation/symbol runs with an optional leading space and trailing newlines;
// newline runs; whitespace; and a final `[\s\S]` catch-all so that every
// character is matched by some alternative.
const DEEPSEEK_V3_PATTERN: &str = concat!(
r"\p{N}{1,3}",
r"|[一-龥\x{3040}-\x{309F}\x{30A0}-\x{30FF}]+",
r"|[!-/:-@\[-`{-~][A-Za-z]+",
r"|[^\r\n\p{L}\p{P}\p{S}]?[\p{L}\p{M}]+",
r"| ?[\p{P}\p{S}]+[\r\n]*",
r"|\s*[\r\n]+",
r"|\s+",
r"|[\s\S]",
);
// Qwen2 split: identical to CL100K_PATTERN except digits are emitted one at a
// time (`\p{N}` instead of `\p{N}{1,3}`).
const QWEN2_PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+";
// Mistral v3 reuses the cl100k split.
// NOTE(review): confirm this matches the upstream Mistral pre-tokenizer.
const MISTRAL_V3_PATTERN: &str = CL100K_PATTERN;
/// Decompresses an embedded Zstandard vocabulary blob and parses the
/// resulting tiktoken text into a token-bytes → rank map.
///
/// # Panics
///
/// Panics if the Zstandard frame cannot be opened or decoded, or if the
/// decompressed text is not well-formed tiktoken data.
pub(crate) fn parse_tiktoken_data(compressed: &[u8]) -> FxHashMap<Vec<u8>, u32> {
    use std::io::Read;

    let mut zstd_stream =
        ruzstd::decoding::StreamingDecoder::new(compressed).expect("zstd decompression failed");
    let mut plain_text = Vec::new();
    Read::read_to_end(&mut zstd_stream, &mut plain_text).expect("zstd decompression failed");
    parse_tiktoken_lines(&plain_text)
}
/// Parses decompressed tiktoken rank data into a map from token bytes to rank.
///
/// Every non-blank line must contain exactly two whitespace-separated fields:
/// the token's bytes in standard (padded) base64, followed by its decimal rank.
/// Blank and whitespace-only lines are skipped. Fields are split on any run of
/// whitespace, as tiktoken's reference loader (`line.split()`) does, so tab- or
/// multi-space-separated files also parse. If a token appears more than once,
/// the last rank read wins.
///
/// # Panics
///
/// Panics if `data` is not UTF-8, or if a line has a missing or extra field,
/// invalid base64, or a rank that is not a valid `u32`. The data is embedded in
/// the binary, so a malformed line is a packaging bug rather than a recoverable
/// error. Each message carries the 1-based line number so the bad entry can be
/// located.
fn parse_tiktoken_lines(data: &[u8]) -> FxHashMap<Vec<u8>, u32> {
    let engine = base64::engine::general_purpose::STANDARD;
    let content = std::str::from_utf8(data).expect("tiktoken data must be valid UTF-8");

    // There is at most one entry per line. Counting newlines gives an exact
    // upper bound, so the map is sized once instead of from a byte-length
    // guess that can under-reserve and force rehashing.
    let max_entries = content.bytes().filter(|&b| b == b'\n').count() + 1;
    let mut ranks = FxHashMap::default();
    ranks.reserve(max_entries);

    for (index, line) in content.lines().enumerate() {
        let line_no = index + 1;
        let mut fields = line.split_whitespace();
        let token_b64 = match fields.next() {
            Some(field) => field,
            // Blank or whitespace-only line.
            None => continue,
        };
        let rank_str = fields
            .next()
            .unwrap_or_else(|| panic!("missing rank in tiktoken data (line {line_no})"));
        if fields.next().is_some() {
            panic!("unexpected extra field in tiktoken data (line {line_no})");
        }
        let token_bytes = engine.decode(token_b64).unwrap_or_else(|err| {
            panic!("invalid base64 in tiktoken data (line {line_no}): {err}")
        });
        let rank: u32 = rank_str.parse().unwrap_or_else(|err| {
            panic!("invalid rank in tiktoken data (line {line_no}): {err}")
        });
        ranks.insert(token_bytes, rank);
    }
    ranks
}
fn special_tokens(pairs: &[(&str, u32)]) -> FxHashMap<Vec<u8>, u32> {
pairs
.iter()
.map(|&(s, v)| (s.as_bytes().to_vec(), v))
.collect()
}
/// Builds the `cl100k_base` tokenizer.
///
/// Decodes the embedded cl100k vocabulary and registers `<|endoftext|>`, the
/// three fill-in-the-middle markers and `<|endofprompt|>` as special tokens.
/// The tokenizer is paired with the cl100k split pattern.
///
/// # Panics
///
/// Panics if the embedded vocabulary fails to decompress or parse.
pub fn cl100k_base() -> CoreBpe {
    const SPECIAL: &[(&str, u32)] = &[
        ("<|endoftext|>", 100257),
        ("<|fim_prefix|>", 100258),
        ("<|fim_middle|>", 100259),
        ("<|fim_suffix|>", 100260),
        ("<|endofprompt|>", 100276),
    ];
    CoreBpe::new(
        parse_tiktoken_data(CL100K_BASE_DATA),
        special_tokens(SPECIAL),
        CL100K_PATTERN,
    )
}
/// Builds the `p50k_base` tokenizer. Its only special token is
/// `<|endoftext|>` (id 50256).
///
/// # Panics
///
/// Panics if the embedded vocabulary fails to decompress or parse.
pub fn p50k_base() -> CoreBpe {
    CoreBpe::new(
        parse_tiktoken_data(P50K_BASE_DATA),
        special_tokens(&[("<|endoftext|>", 50256)]),
        P50K_PATTERN,
    )
}
/// Builds the `p50k_edit` tokenizer.
///
/// Uses the same vocabulary and split pattern as `p50k_base`, and adds the
/// fill-in-the-middle markers (ids 50281–50283) after `<|endoftext|>`.
///
/// # Panics
///
/// Panics if the embedded vocabulary fails to decompress or parse.
pub fn p50k_edit() -> CoreBpe {
    const SPECIAL: &[(&str, u32)] = &[
        ("<|endoftext|>", 50256),
        ("<|fim_prefix|>", 50281),
        ("<|fim_middle|>", 50282),
        ("<|fim_suffix|>", 50283),
    ];
    CoreBpe::new(
        parse_tiktoken_data(P50K_BASE_DATA),
        special_tokens(SPECIAL),
        P50K_PATTERN,
    )
}
/// Builds the `o200k_base` tokenizer, with `<|endoftext|>` (199999) and
/// `<|endofprompt|>` (200018) as special tokens.
///
/// # Panics
///
/// Panics if the embedded vocabulary fails to decompress or parse.
pub fn o200k_base() -> CoreBpe {
    let ranks = parse_tiktoken_data(O200K_BASE_DATA);
    let specials = special_tokens(&[
        ("<|endoftext|>", 199999),
        ("<|endofprompt|>", 200018),
    ]);
    CoreBpe::new(ranks, specials, O200K_PATTERN)
}
/// Builds the `r50k_base` tokenizer.
///
/// Uses its own embedded vocabulary but shares the p50k split pattern. Its only
/// special token is `<|endoftext|>` (id 50256).
///
/// # Panics
///
/// Panics if the embedded vocabulary fails to decompress or parse.
pub fn r50k_base() -> CoreBpe {
    CoreBpe::new(
        parse_tiktoken_data(R50K_BASE_DATA),
        special_tokens(&[("<|endoftext|>", 50256)]),
        P50K_PATTERN,
    )
}
/// Builds the Llama 3 tokenizer from the embedded vocabulary, using the
/// cl100k-equivalent split pattern.
///
/// NOTE(review): ids 128002, 128003 and 128005 are not registered here.
/// Presumably they are reserved placeholders; confirm they should stay
/// unregistered.
///
/// # Panics
///
/// Panics if the embedded vocabulary fails to decompress or parse.
pub fn llama3() -> CoreBpe {
    const SPECIAL_TOKENS: &[(&str, u32)] = &[
        ("<|begin_of_text|>", 128000),
        ("<|end_of_text|>", 128001),
        ("<|finetune_right_pad_id|>", 128004),
        ("<|start_header_id|>", 128006),
        ("<|end_header_id|>", 128007),
        ("<|eom_id|>", 128008),
        ("<|eot_id|>", 128009),
        ("<|python_tag|>", 128010),
    ];
    let vocabulary = parse_tiktoken_data(LLAMA3_DATA);
    CoreBpe::new(vocabulary, special_tokens(SPECIAL_TOKENS), LLAMA3_PATTERN)
}
/// Builds the DeepSeek V3 tokenizer from the embedded vocabulary and the
/// CJK-aware DeepSeek split pattern.
///
/// NOTE(review): upstream DeepSeek tokenizer configs are believed to spell the
/// sentence/pad markers with U+FF5C FULLWIDTH VERTICAL LINE (`｜`) rather than
/// ASCII `|`. Confirm that these strings match the shipped vocabulary.
///
/// # Panics
///
/// Panics if the embedded vocabulary fails to decompress or parse.
pub fn deepseek_v3() -> CoreBpe {
    const SPECIAL: &[(&str, u32)] = &[
        ("<|begin▁of▁sentence|>", 0),
        ("<|end▁of▁sentence|>", 1),
        ("<|▁pad▁|>", 2),
        ("<|EOT|>", 128805),
    ];
    CoreBpe::new(
        parse_tiktoken_data(DEEPSEEK_V3_DATA),
        special_tokens(SPECIAL),
        DEEPSEEK_V3_PATTERN,
    )
}
/// Builds the Qwen2 tokenizer.
///
/// Special tokens occupy the contiguous id range 151643–151656: end-of-text,
/// chat-turn markers, and object/box/quad/vision/image/video markers.
///
/// # Panics
///
/// Panics if the embedded vocabulary fails to decompress or parse.
pub fn qwen2() -> CoreBpe {
    const SPECIAL_TOKENS: &[(&str, u32)] = &[
        ("<|endoftext|>", 151643),
        ("<|im_start|>", 151644),
        ("<|im_end|>", 151645),
        ("<|object_ref_start|>", 151646),
        ("<|object_ref_end|>", 151647),
        ("<|box_start|>", 151648),
        ("<|box_end|>", 151649),
        ("<|quad_start|>", 151650),
        ("<|quad_end|>", 151651),
        ("<|vision_start|>", 151652),
        ("<|vision_end|>", 151653),
        ("<|vision_pad|>", 151654),
        ("<|image_pad|>", 151655),
        ("<|video_pad|>", 151656),
    ];
    let vocabulary = parse_tiktoken_data(QWEN2_DATA);
    let specials = special_tokens(SPECIAL_TOKENS);
    CoreBpe::new(vocabulary, specials, QWEN2_PATTERN)
}
/// Builds the Mistral v3 tokenizer from the embedded vocabulary, using the
/// cl100k-equivalent split pattern.
///
/// NOTE(review): special ids 0–16 are registered except 11. Confirm whether
/// that id is intentionally left out.
///
/// # Panics
///
/// Panics if the embedded vocabulary fails to decompress or parse.
pub fn mistral_v3() -> CoreBpe {
    const SPECIAL_TOKENS: &[(&str, u32)] = &[
        ("<unk>", 0),
        ("<s>", 1),
        ("</s>", 2),
        ("[INST]", 3),
        ("[/INST]", 4),
        ("[AVAILABLE_TOOLS]", 5),
        ("[/AVAILABLE_TOOLS]", 6),
        ("[TOOL_RESULTS]", 7),
        ("[/TOOL_RESULTS]", 8),
        ("[TOOL_CALLS]", 9),
        ("[IMG]", 10),
        ("[IMG_BREAK]", 12),
        ("[IMG_END]", 13),
        ("[PREFIX]", 14),
        ("[MIDDLE]", 15),
        ("[SUFFIX]", 16),
    ];
    CoreBpe::new(
        parse_tiktoken_data(MISTRAL_V3_DATA),
        special_tokens(SPECIAL_TOKENS),
        MISTRAL_V3_PATTERN,
    )
}
/// Test-only accessor: decodes the embedded cl100k_base vocabulary into its raw
/// token-bytes → rank table.
#[cfg(test)]
pub(crate) fn parse_tiktoken_data_for_test() -> FxHashMap<Vec<u8>, u32> {
    let cl100k_blob: &[u8] = CL100K_BASE_DATA;
    parse_tiktoken_data(cl100k_blob)
}