use rustc_hash::FxHashMap;
use super::tokenizer::{
Tokenizer, TokenizerError, CL100K_BASE_PATTERN, MISTRAL_V3_PATTERN, O200K_BASE_PATTERN,
SENTENCEPIECE_PATTERN,
};
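// Vocabulary data files, embedded into the binary at compile time so that
// `from_pretrained` needs no filesystem or network access at runtime.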
pub const CL100K_BASE_VOCAB: &[u8] =
include_bytes!("../../python/splintr/vocabs/cl100k_base.tiktoken");
pub const O200K_BASE_VOCAB: &[u8] =
include_bytes!("../../python/splintr/vocabs/o200k_base.tiktoken");
pub const LLAMA3_VOCAB: &[u8] = include_bytes!("../../python/splintr/vocabs/llama3.tiktoken");
pub const DEEPSEEK_V3_VOCAB: &[u8] =
include_bytes!("../../python/splintr/vocabs/deepseek_v3.tiktoken");
pub const MISTRAL_VOCAB: &[u8] = include_bytes!("../../python/splintr/vocabs/mistral.tiktoken");
pub const MISTRAL_V2_VOCAB: &[u8] =
include_bytes!("../../python/splintr/vocabs/mistral_v2.tiktoken");
pub const MISTRAL_V3_VOCAB: &[u8] =
include_bytes!("../../python/splintr/vocabs/mistral_v3_tekken.tiktoken");
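/// A vocabulary bundled with the crate, loadable via [`from_vocab`] or by
/// name via [`from_pretrained`].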
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PretrainedVocab {
Cl100kBase,
O200kBase,
Llama3,
DeepseekV3,
MistralV1,
MistralV2,
MistralV3,
}
impl PretrainedVocab {
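    /// Resolves a vocabulary from a user-facing name. Point releases of
    /// Llama 3 ("llama3.1" through "llama3.3") and the aliases
    /// "deepseek-v3" and "mistral" are accepted; unknown names yield `None`.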
pub fn from_name(name: &str) -> Option<Self> {
match name {
"cl100k_base" => Some(Self::Cl100kBase),
"o200k_base" => Some(Self::O200kBase),
"llama3" | "llama3.1" | "llama3.2" | "llama3.3" => Some(Self::Llama3),
"deepseek_v3" | "deepseek-v3" => Some(Self::DeepseekV3),
"mistral" | "mistral_v1" => Some(Self::MistralV1),
"mistral_v2" => Some(Self::MistralV2),
"mistral_v3" => Some(Self::MistralV3),
_ => None,
}
}
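    /// Every name (including aliases) accepted by [`Self::from_name`], used
    /// to build the error message in [`from_pretrained`].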
pub fn supported_names() -> &'static [&'static str] {
&[
"cl100k_base",
"o200k_base",
"llama3",
"llama3.1",
"llama3.2",
"llama3.3",
"deepseek_v3",
"deepseek-v3",
"mistral",
"mistral_v1",
"mistral_v2",
"mistral_v3",
]
}
}
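/// Builds a [`Tokenizer`] for a pretrained vocabulary name, returning
/// [`TokenizerError::UnknownPretrained`] (with the supported names) for
/// anything [`PretrainedVocab::from_name`] does not recognize.
///
/// A minimal usage sketch; the `splintr` crate path is an assumption drawn
/// from the vocab file locations above:
///
/// ```ignore
/// let tok = splintr::from_pretrained("llama3")?; // assumed re-export path
/// let ids = tok.encode("Hello, world!");
/// assert_eq!(tok.decode(&ids)?, "Hello, world!");
/// ```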
pub fn from_pretrained(name: &str) -> Result<Tokenizer, TokenizerError> {
let vocab = PretrainedVocab::from_name(name).ok_or_else(|| {
TokenizerError::UnknownPretrained(format!(
"{}. Supported: {}",
name,
PretrainedVocab::supported_names().join(", ")
))
})?;
from_vocab(vocab)
}
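/// Builds a [`Tokenizer`] from an already-resolved [`PretrainedVocab`].
///
/// DeepSeek V3 and Mistral v3 (Tekken) ship byte-level vocabularies;
/// Mistral v1/v2 go through the SentencePiece constructors (v2 via
/// `from_bytes_sentencepiece_with_decoder`); the rest load as plain
/// tiktoken vocabularies.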
pub fn from_vocab(vocab: PretrainedVocab) -> Result<Tokenizer, TokenizerError> {
let special = special_tokens(vocab);
let pat = pattern(vocab);
match vocab {
PretrainedVocab::Cl100kBase => Tokenizer::from_bytes(CL100K_BASE_VOCAB, pat, special),
PretrainedVocab::O200kBase => Tokenizer::from_bytes(O200K_BASE_VOCAB, pat, special),
PretrainedVocab::Llama3 => Tokenizer::from_bytes(LLAMA3_VOCAB, pat, special),
PretrainedVocab::DeepseekV3 => {
Tokenizer::from_bytes_byte_level(DEEPSEEK_V3_VOCAB, pat, special)
}
PretrainedVocab::MistralV1 => {
Tokenizer::from_bytes_sentencepiece(MISTRAL_VOCAB, pat, special)
}
PretrainedVocab::MistralV2 => {
Tokenizer::from_bytes_sentencepiece_with_decoder(MISTRAL_V2_VOCAB, pat, special)
}
PretrainedVocab::MistralV3 => {
Tokenizer::from_bytes_byte_level(MISTRAL_V3_VOCAB, pat, special)
}
}
}
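/// The pre-tokenization regex for `vocab`. Llama 3 and DeepSeek V3 reuse
/// the o200k_base pattern; Mistral v1/v2 use the SentencePiece pattern.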
pub fn pattern(vocab: PretrainedVocab) -> &'static str {
match vocab {
PretrainedVocab::Cl100kBase => CL100K_BASE_PATTERN,
PretrainedVocab::O200kBase => O200K_BASE_PATTERN,
        PretrainedVocab::Llama3 | PretrainedVocab::DeepseekV3 => O200K_BASE_PATTERN,
        PretrainedVocab::MistralV1 | PretrainedVocab::MistralV2 => SENTENCEPIECE_PATTERN,
        PretrainedVocab::MistralV3 => MISTRAL_V3_PATTERN,
    }
}
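/// Whether `vocab` is loaded as a byte-level (GPT-2 style) vocabulary.
/// Currently only DeepSeek V3 is reported here, even though Mistral v3
/// (Tekken) also goes through the byte-level loader in [`from_vocab`].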
pub fn uses_byte_level(vocab: PretrainedVocab) -> bool {
matches!(vocab, PretrainedVocab::DeepseekV3)
}
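/// The end-of-sequence token id (`<|endoftext|>`, `<|end_of_text|>`,
/// `</s>`, etc. depending on the vocabulary).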
pub fn eos_token_id(vocab: PretrainedVocab) -> u32 {
match vocab {
        PretrainedVocab::Cl100kBase => 100257,
        PretrainedVocab::O200kBase => 199999,
        PretrainedVocab::Llama3 => 128001,
        PretrainedVocab::DeepseekV3 => 1,
        PretrainedVocab::MistralV1 | PretrainedVocab::MistralV2 | PretrainedVocab::MistralV3 => 2,
    }
}
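/// [`eos_token_id`] looked up by name; falls back to `0` when the name is
/// unknown, so resolve via [`PretrainedVocab::from_name`] first if you need
/// to detect bad names.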
pub fn eos_token_id_by_name(name: &str) -> u32 {
PretrainedVocab::from_name(name)
.map(eos_token_id)
.unwrap_or(0)
}
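/// The beginning-of-sequence token id, if the vocabulary defines one; the
/// OpenAI vocabularies (cl100k_base, o200k_base) have none.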
pub fn bos_token_id(vocab: PretrainedVocab) -> Option<u32> {
match vocab {
        PretrainedVocab::Cl100kBase | PretrainedVocab::O200kBase => None,
        PretrainedVocab::Llama3 => Some(128000),
        PretrainedVocab::DeepseekV3 => Some(0),
        PretrainedVocab::MistralV1 | PretrainedVocab::MistralV2 | PretrainedVocab::MistralV3 => {
            Some(1)
        }
    }
}
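/// [`bos_token_id`] looked up by name; `None` for unknown names or
/// vocabularies without a BOS token.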
pub fn bos_token_id_by_name(name: &str) -> Option<u32> {
PretrainedVocab::from_name(name).and_then(bos_token_id)
}
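/// The padding token id for each vocabulary.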
pub fn pad_token_id(vocab: PretrainedVocab) -> Option<u32> {
match vocab {
        PretrainedVocab::Cl100kBase => Some(100316),
        PretrainedVocab::O200kBase => Some(200058),
        PretrainedVocab::Llama3 => Some(128339),
        PretrainedVocab::DeepseekV3 => Some(2),
        PretrainedVocab::MistralV1 => Some(32039),
        PretrainedVocab::MistralV2 => Some(32807),
        PretrainedVocab::MistralV3 => Some(131111),
    }
}
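/// The full special-token map for `vocab`: the vocabulary's native special
/// tokens plus the shared agent-token block (see `insert_agent_tokens`).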
pub fn special_tokens(vocab: PretrainedVocab) -> FxHashMap<String, u32> {
match vocab {
PretrainedVocab::Cl100kBase => cl100k_base_special_tokens(),
PretrainedVocab::O200kBase => o200k_base_special_tokens(),
PretrainedVocab::Llama3 => llama3_special_tokens(),
PretrainedVocab::DeepseekV3 => deepseek_v3_special_tokens(),
PretrainedVocab::MistralV1 => mistral_v1_special_tokens(),
PretrainedVocab::MistralV2 => mistral_v2_special_tokens(),
PretrainedVocab::MistralV3 => mistral_v3_special_tokens(),
}
}
pub fn cl100k_base_special_tokens() -> FxHashMap<String, u32> {
let mut special = FxHashMap::default();
special.insert("<|endoftext|>".to_string(), 100257);
special.insert("<|fim_prefix|>".to_string(), 100258);
special.insert("<|fim_middle|>".to_string(), 100259);
special.insert("<|fim_suffix|>".to_string(), 100260);
special.insert("<|endofprompt|>".to_string(), 100276);
insert_agent_tokens(&mut special, 100277);
special
}
pub fn o200k_base_special_tokens() -> FxHashMap<String, u32> {
let mut special = FxHashMap::default();
special.insert("<|endoftext|>".to_string(), 199999);
special.insert("<|endofprompt|>".to_string(), 200018);
insert_agent_tokens(&mut special, 200019);
special
}
pub fn llama3_special_tokens() -> FxHashMap<String, u32> {
let mut special = FxHashMap::default();
special.insert("<|begin_of_text|>".to_string(), 128000);
special.insert("<|end_of_text|>".to_string(), 128001);
special.insert("<|reserved_special_token_0|>".to_string(), 128002);
special.insert("<|reserved_special_token_1|>".to_string(), 128003);
special.insert("<|finetune_right_pad_id|>".to_string(), 128004);
special.insert("<|step_id|>".to_string(), 128005);
special.insert("<|start_header_id|>".to_string(), 128006);
special.insert("<|end_header_id|>".to_string(), 128007);
special.insert("<|eom_id|>".to_string(), 128008);
special.insert("<|eot_id|>".to_string(), 128009);
special.insert("<|python_tag|>".to_string(), 128010);
special.insert("<|image|>".to_string(), 128256);
special.insert("<|/image|>".to_string(), 128257);
special.insert("<|audio|>".to_string(), 128258);
special.insert("<|/audio|>".to_string(), 128259);
special.insert("<|video|>".to_string(), 128260);
special.insert("<|/video|>".to_string(), 128261);
insert_agent_tokens_llama3(&mut special, 128300);
special
}
pub fn deepseek_v3_special_tokens() -> FxHashMap<String, u32> {
let mut special = FxHashMap::default();
special.insert("<|begin▁of▁sentence|>".to_string(), 0);
special.insert("<|end▁of▁sentence|>".to_string(), 1);
special.insert("<|▁pad▁|>".to_string(), 2);
special.insert("<think>".to_string(), 128798);
special.insert("</think>".to_string(), 128799);
special.insert("<|fim▁hole|>".to_string(), 128800);
special.insert("<|fim▁begin|>".to_string(), 128801);
special.insert("<|fim▁end|>".to_string(), 128802);
special.insert("<|User|>".to_string(), 128803);
special.insert("<|Assistant|>".to_string(), 128804);
special.insert("<|EOT|>".to_string(), 128805);
special.insert("<|tool▁calls▁begin|>".to_string(), 128806);
special.insert("<|tool▁calls▁end|>".to_string(), 128807);
special.insert("<|tool▁call▁begin|>".to_string(), 128808);
special.insert("<|tool▁call▁end|>".to_string(), 128809);
special.insert("<|tool▁outputs▁begin|>".to_string(), 128810);
special.insert("<|tool▁outputs▁end|>".to_string(), 128811);
special.insert("<|tool▁output▁begin|>".to_string(), 128812);
special.insert("<|tool▁output▁end|>".to_string(), 128813);
special.insert("<|tool▁sep|>".to_string(), 128814);
insert_agent_tokens(&mut special, 128900);
special
}
pub fn mistral_v1_special_tokens() -> FxHashMap<String, u32> {
let mut special = FxHashMap::default();
special.insert("<unk>".to_string(), 0);
special.insert("<s>".to_string(), 1);
special.insert("</s>".to_string(), 2);
insert_agent_tokens(&mut special, 32000);
special
}
pub fn mistral_v2_special_tokens() -> FxHashMap<String, u32> {
let mut special = FxHashMap::default();
special.insert("[INST]".to_string(), 3);
special.insert("[/INST]".to_string(), 4);
special.insert("[TOOL_CALLS]".to_string(), 5);
special.insert("[AVAILABLE_TOOLS]".to_string(), 6);
special.insert("[/AVAILABLE_TOOLS]".to_string(), 7);
special.insert("[TOOL_RESULTS]".to_string(), 8);
special.insert("[/TOOL_RESULTS]".to_string(), 9);
insert_agent_tokens(&mut special, 32768);
special
}
pub fn mistral_v3_special_tokens() -> FxHashMap<String, u32> {
let mut special = FxHashMap::default();
special.insert("<unk>".to_string(), 0);
special.insert("<s>".to_string(), 1);
special.insert("</s>".to_string(), 2);
special.insert("[INST]".to_string(), 3);
special.insert("[/INST]".to_string(), 4);
special.insert("[AVAILABLE_TOOLS]".to_string(), 5);
special.insert("[/AVAILABLE_TOOLS]".to_string(), 6);
special.insert("[TOOL_RESULTS]".to_string(), 7);
special.insert("[/TOOL_RESULTS]".to_string(), 8);
special.insert("[TOOL_CALLS]".to_string(), 9);
insert_agent_tokens(&mut special, 131072);
special
}
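/// Inserts the shared block of 54 agent/chat markup tokens (roles, ChatML
/// delimiters, reasoning/tool/citation tags, multimodal delimiters) at
/// consecutive ids starting from `base`.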
fn insert_agent_tokens(special: &mut FxHashMap<String, u32>, base: u32) {
special.insert("<|system|>".to_string(), base);
special.insert("<|user|>".to_string(), base + 1);
special.insert("<|assistant|>".to_string(), base + 2);
special.insert("<|im_start|>".to_string(), base + 3);
special.insert("<|im_end|>".to_string(), base + 4);
special.insert("<|think|>".to_string(), base + 5);
special.insert("<|/think|>".to_string(), base + 6);
special.insert("<|plan|>".to_string(), base + 7);
special.insert("<|/plan|>".to_string(), base + 8);
special.insert("<|step|>".to_string(), base + 9);
special.insert("<|/step|>".to_string(), base + 10);
special.insert("<|act|>".to_string(), base + 11);
special.insert("<|/act|>".to_string(), base + 12);
special.insert("<|observe|>".to_string(), base + 13);
special.insert("<|/observe|>".to_string(), base + 14);
special.insert("<|function|>".to_string(), base + 15);
special.insert("<|/function|>".to_string(), base + 16);
special.insert("<|result|>".to_string(), base + 17);
special.insert("<|/result|>".to_string(), base + 18);
special.insert("<|error|>".to_string(), base + 19);
special.insert("<|/error|>".to_string(), base + 20);
special.insert("<|code|>".to_string(), base + 21);
special.insert("<|/code|>".to_string(), base + 22);
special.insert("<|output|>".to_string(), base + 23);
special.insert("<|/output|>".to_string(), base + 24);
special.insert("<|lang|>".to_string(), base + 25);
special.insert("<|/lang|>".to_string(), base + 26);
special.insert("<|context|>".to_string(), base + 27);
special.insert("<|/context|>".to_string(), base + 28);
special.insert("<|quote|>".to_string(), base + 29);
special.insert("<|/quote|>".to_string(), base + 30);
special.insert("<|cite|>".to_string(), base + 31);
special.insert("<|/cite|>".to_string(), base + 32);
special.insert("<|source|>".to_string(), base + 33);
special.insert("<|/source|>".to_string(), base + 34);
special.insert("<|memory|>".to_string(), base + 35);
special.insert("<|/memory|>".to_string(), base + 36);
special.insert("<|recall|>".to_string(), base + 37);
special.insert("<|/recall|>".to_string(), base + 38);
special.insert("<|pad|>".to_string(), base + 39);
special.insert("<|stop|>".to_string(), base + 40);
special.insert("<|sep|>".to_string(), base + 41);
special.insert("<|image|>".to_string(), base + 42);
special.insert("<|/image|>".to_string(), base + 43);
special.insert("<|audio|>".to_string(), base + 44);
special.insert("<|/audio|>".to_string(), base + 45);
special.insert("<|video|>".to_string(), base + 46);
special.insert("<|/video|>".to_string(), base + 47);
special.insert("<|title|>".to_string(), base + 48);
special.insert("<|/title|>".to_string(), base + 49);
special.insert("<|section|>".to_string(), base + 50);
special.insert("<|/section|>".to_string(), base + 51);
special.insert("<|summary|>".to_string(), base + 52);
special.insert("<|/summary|>".to_string(), base + 53);
}
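/// Same agent-token block as `insert_agent_tokens`, minus the multimodal
/// delimiters at `base + 42..=base + 47`: Llama 3 already defines its own
/// `<|image|>`/`<|audio|>`/`<|video|>` pairs at 128256..=128261. The
/// offsets of the remaining tokens are unchanged so ids stay aligned
/// across vocabularies.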
fn insert_agent_tokens_llama3(special: &mut FxHashMap<String, u32>, base: u32) {
special.insert("<|system|>".to_string(), base);
special.insert("<|user|>".to_string(), base + 1);
special.insert("<|assistant|>".to_string(), base + 2);
special.insert("<|im_start|>".to_string(), base + 3);
special.insert("<|im_end|>".to_string(), base + 4);
special.insert("<|think|>".to_string(), base + 5);
special.insert("<|/think|>".to_string(), base + 6);
special.insert("<|plan|>".to_string(), base + 7);
special.insert("<|/plan|>".to_string(), base + 8);
special.insert("<|step|>".to_string(), base + 9);
special.insert("<|/step|>".to_string(), base + 10);
special.insert("<|act|>".to_string(), base + 11);
special.insert("<|/act|>".to_string(), base + 12);
special.insert("<|observe|>".to_string(), base + 13);
special.insert("<|/observe|>".to_string(), base + 14);
special.insert("<|function|>".to_string(), base + 15);
special.insert("<|/function|>".to_string(), base + 16);
special.insert("<|result|>".to_string(), base + 17);
special.insert("<|/result|>".to_string(), base + 18);
special.insert("<|error|>".to_string(), base + 19);
special.insert("<|/error|>".to_string(), base + 20);
special.insert("<|code|>".to_string(), base + 21);
special.insert("<|/code|>".to_string(), base + 22);
special.insert("<|output|>".to_string(), base + 23);
special.insert("<|/output|>".to_string(), base + 24);
special.insert("<|lang|>".to_string(), base + 25);
special.insert("<|/lang|>".to_string(), base + 26);
special.insert("<|context|>".to_string(), base + 27);
special.insert("<|/context|>".to_string(), base + 28);
special.insert("<|quote|>".to_string(), base + 29);
special.insert("<|/quote|>".to_string(), base + 30);
special.insert("<|cite|>".to_string(), base + 31);
special.insert("<|/cite|>".to_string(), base + 32);
special.insert("<|source|>".to_string(), base + 33);
special.insert("<|/source|>".to_string(), base + 34);
special.insert("<|memory|>".to_string(), base + 35);
special.insert("<|/memory|>".to_string(), base + 36);
special.insert("<|recall|>".to_string(), base + 37);
special.insert("<|/recall|>".to_string(), base + 38);
special.insert("<|pad|>".to_string(), base + 39);
special.insert("<|stop|>".to_string(), base + 40);
special.insert("<|sep|>".to_string(), base + 41);
special.insert("<|title|>".to_string(), base + 48);
special.insert("<|/title|>".to_string(), base + 49);
special.insert("<|section|>".to_string(), base + 50);
special.insert("<|/section|>".to_string(), base + 51);
special.insert("<|summary|>".to_string(), base + 52);
special.insert("<|/summary|>".to_string(), base + 53);
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_from_pretrained_llama3() {
let tokenizer = from_pretrained("llama3").unwrap();
assert!(tokenizer.vocab_size() > 100000);
}
#[test]
fn test_from_pretrained_cl100k() {
let tokenizer = from_pretrained("cl100k_base").unwrap();
assert!(tokenizer.vocab_size() > 90000);
}
#[test]
fn test_eos_token_ids() {
assert_eq!(eos_token_id(PretrainedVocab::Cl100kBase), 100257);
assert_eq!(eos_token_id(PretrainedVocab::O200kBase), 199999);
assert_eq!(eos_token_id(PretrainedVocab::Llama3), 128001);
assert_eq!(eos_token_id(PretrainedVocab::DeepseekV3), 1);
assert_eq!(eos_token_id(PretrainedVocab::MistralV1), 2);
}
#[test]
fn test_vocab_from_name() {
assert_eq!(
PretrainedVocab::from_name("llama3"),
Some(PretrainedVocab::Llama3)
);
assert_eq!(
PretrainedVocab::from_name("llama3.1"),
Some(PretrainedVocab::Llama3)
);
assert_eq!(
PretrainedVocab::from_name("deepseek_v3"),
Some(PretrainedVocab::DeepseekV3)
);
assert_eq!(
PretrainedVocab::from_name("mistral"),
Some(PretrainedVocab::MistralV1)
);
assert_eq!(PretrainedVocab::from_name("unknown"), None);
}
#[test]
fn test_from_pretrained_mistral() {
let tokenizer = from_pretrained("mistral").unwrap();
assert!(tokenizer.vocab_size() >= 31000);
}
#[test]
fn test_mistral_encode_decode() {
let tokenizer = from_pretrained("mistral").unwrap();
let text = "Hello, world!";
let tokens = tokenizer.encode(text);
assert!(!tokens.is_empty());
let decoded = tokenizer.decode(&tokens).unwrap();
assert_eq!(decoded, text, "Encoding should be reversible");
}
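    // Additional sketches grounded in the ids defined in this module: they
    // pin the documented BOS ids and the agent-token base offsets so
    // regressions in the tables above are caught directly.
    #[test]
    fn test_bos_token_ids() {
        assert_eq!(bos_token_id(PretrainedVocab::Cl100kBase), None);
        assert_eq!(bos_token_id(PretrainedVocab::Llama3), Some(128000));
        assert_eq!(bos_token_id(PretrainedVocab::DeepseekV3), Some(0));
        assert_eq!(bos_token_id(PretrainedVocab::MistralV1), Some(1));
    }
    #[test]
    fn test_agent_token_bases() {
        let cl100k = cl100k_base_special_tokens();
        assert_eq!(cl100k.get("<|system|>"), Some(&100277));
        assert_eq!(cl100k.get("<|/summary|>"), Some(&(100277 + 53)));
        let o200k = o200k_base_special_tokens();
        assert_eq!(o200k.get("<|system|>"), Some(&200019));
    }
    #[test]
    fn test_llama3_agent_tokens_skip_multimodal() {
        let special = llama3_special_tokens();
        // Llama 3 keeps its own <|image|> at 128256; the agent block must
        // not re-insert one in the 128300.. range.
        assert_eq!(special.get("<|image|>"), Some(&128256));
        assert_eq!(special.get("<|title|>"), Some(&(128300 + 48)));
    }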
}