use std::collections::HashSet;
use std::path::Path;
use base64::Engine as _;
use rayon::prelude::*;
use rustc_hash::FxHashMap;
use tiktoken_rs::CoreBPE;
use super::{
Encoding, Error, Result, TokenIdType,
traits::{Decoder, Encoder, Tokenizer},
};
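/// Number of reserved special-token slots appended after the base vocabulary
/// when a model directory does not define (enough) special tokens of its own.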
const DEFAULT_NUM_RESERVED_SPECIAL_TOKENS: u32 = 256;
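/// Pre-tokenization regex for Kimi / DeepSeek-V3 style tiktoken vocabularies:
/// it splits out runs of Han characters, cased words with optional English
/// contractions, digit groups of up to three, punctuation runs, and whitespace.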
const KIMI_PATTERN: &str = r#"[\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"#;
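/// BPE tokenizer backed by `tiktoken_rs::CoreBPE`, built from a tiktoken rank
/// file plus a pre-tokenization pattern and a map of special tokens to IDs.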
pub struct TikTokenTokenizer {
bpe: CoreBPE,
special_token_ids: HashSet<u32>,
}
impl TikTokenTokenizer {
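    /// Build a tokenizer from an explicit rank file, pre-tokenization pattern,
    /// and map of special-token strings to token IDs.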
pub fn from_file(
path: &str,
pattern: &str,
special_tokens: FxHashMap<String, u32>,
) -> Result<Self> {
let encoder = parse_tiktoken_file(path)?;
let special_token_ids: HashSet<u32> = special_tokens.values().copied().collect();
let bpe = CoreBPE::new(encoder, special_tokens, pattern)
.map_err(|err| Error::msg(format!("Error creating tiktoken BPE: {err}")))?;
Ok(Self {
bpe,
special_token_ids,
})
}
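    /// Build a tokenizer from a rank file, inferring the pre-tokenization
    /// pattern and special tokens from `config.json` and
    /// `tokenizer_config.json` in the same directory.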
pub fn from_file_auto(path: &str) -> Result<Self> {
let file_path = Path::new(path);
let directory = file_path
.parent()
.ok_or_else(|| Error::msg("Cannot determine parent directory of tiktoken file"))?;
let pattern = detect_bpe_pattern(directory)?;
let encoder = parse_tiktoken_file(path)?;
let num_base_tokens = encoder.values().max().map_or(0, |&m| m + 1) as usize;
let special_tokens = load_special_tokens(directory, num_base_tokens)?;
let special_token_ids: HashSet<u32> = special_tokens.values().copied().collect();
let bpe = CoreBPE::new(encoder, special_tokens, pattern)
.map_err(|err| Error::msg(format!("Error creating tiktoken BPE: {err}")))?;
Ok(Self {
bpe,
special_token_ids,
})
}
}
impl Encoder for TikTokenTokenizer {
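    /// Encode `input`, recognizing registered special tokens in the text.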
fn encode(&self, input: &str) -> Result<Encoding> {
let token_ids: Vec<u32> = self.bpe.encode_with_special_tokens(input);
Ok(Encoding::Sp(token_ids))
}
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
inputs.par_iter().map(|input| self.encode(input)).collect()
}
}
impl Decoder for TikTokenTokenizer {
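    /// Decode token IDs back into text. Decoding goes through the byte-level
    /// path and `from_utf8_lossy`, so sequences that stop mid-codepoint yield
    /// U+FFFD instead of an error. When `skip_special_tokens` is set, IDs
    /// registered as special tokens are dropped before decoding.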
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
let ids: Vec<u32> = if skip_special_tokens {
token_ids
.iter()
.filter(|&&id| !self.special_token_ids.contains(&id))
.copied()
.collect()
} else {
token_ids.to_vec()
};
let bytes: Vec<u8> = self.bpe._decode_native_and_split(ids).flatten().collect();
Ok(String::from_utf8_lossy(&bytes).into_owned())
}
}
impl Tokenizer for TikTokenTokenizer {}
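/// Parse a tiktoken rank file: one base64-encoded token and its integer rank
/// per line, separated by whitespace; blank lines are skipped.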
fn parse_tiktoken_file(path: &str) -> Result<FxHashMap<Vec<u8>, u32>> {
let contents = std::fs::read_to_string(path)
.map_err(|err| Error::msg(format!("Failed to read tiktoken file '{path}': {err}")))?;
let engine = base64::engine::general_purpose::STANDARD;
let mut encoder = FxHashMap::default();
for line in contents.lines() {
let line = line.trim();
if line.is_empty() {
continue;
}
let mut parts = line.split_whitespace();
let token_b64 = parts
.next()
.ok_or_else(|| Error::msg(format!("Invalid tiktoken line (no token): {line}")))?;
let rank_str = parts
.next()
.ok_or_else(|| Error::msg(format!("Invalid tiktoken line (no rank): {line}")))?;
let token_bytes = engine
.decode(token_b64)
.map_err(|err| Error::msg(format!("Invalid base64 in tiktoken file: {err}")))?;
let rank: u32 = rank_str
.parse()
.map_err(|err| Error::msg(format!("Invalid rank in tiktoken file: {err}")))?;
encoder.insert(token_bytes, rank);
}
Ok(encoder)
}
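/// Read `model_type` from `config.json` in `directory` and map it to a known
/// pre-tokenization regex.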
fn detect_bpe_pattern(directory: &Path) -> Result<&'static str> {
let model_type: String = crate::file_json_field(&directory.join("config.json"), "model_type")
.map_err(|err| {
Error::msg(format!("Failed to read model_type from config.json: {err}"))
})?;
match model_type.as_str() {
"kimi" | "kimi_k2" | "kimi_k25" | "deepseek_v3" => Ok(KIMI_PATTERN),
_ => Err(Error::msg(format!(
"Unsupported tiktoken model_type '{model_type}'. \
Currently supported: kimi, kimi_k2, kimi_k25, deepseek_v3. \
To add a new model type, extend detect_bpe_pattern() in tokenizers/tiktoken.rs \
with the appropriate BPE regex pattern. \
Alternatively, provide a tokenizer.json (HuggingFace format) instead."
))),
}
}
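/// Load special tokens from `tokenizer_config.json` (`added_tokens_decoder`)
/// if present, then pad the remaining reserved slots after `num_base_tokens`
/// with `<|reserved_token_{id}|>` placeholders named by absolute token ID.
/// Without a config file, every reserved slot gets a placeholder name.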
fn load_special_tokens(directory: &Path, num_base_tokens: usize) -> Result<FxHashMap<String, u32>> {
let config_path = directory.join("tokenizer_config.json");
let mut special_tokens = FxHashMap::default();
if !config_path.exists() {
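        // No tokenizer_config.json: synthesize a full block of reserved
        // placeholders starting right after the base vocabulary.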
for i in 0..DEFAULT_NUM_RESERVED_SPECIAL_TOKENS {
let id = num_base_tokens as u32 + i;
special_tokens.insert(format!("<|reserved_token_{id}|>"), id);
}
return Ok(special_tokens);
}
let contents = std::fs::read_to_string(&config_path)
.map_err(|err| Error::msg(format!("Failed to read tokenizer_config.json: {err}")))?;
let config: serde_json::Value = serde_json::from_str(&contents)
.map_err(|err| Error::msg(format!("Failed to parse tokenizer_config.json: {err}")))?;
if let Some(added_tokens) = config
.get("added_tokens_decoder")
.and_then(|v| v.as_object())
{
for (id_str, token_def) in added_tokens {
let id: u32 = id_str.parse().map_err(|err| {
Error::msg(format!(
"Invalid token ID '{id_str}' in added_tokens_decoder: {err}"
))
})?;
let content = token_def
.get("content")
.and_then(|v| v.as_str())
.unwrap_or_else(|| {
tracing::warn!("Missing 'content' field for token ID {id}");
""
});
if !content.is_empty() {
special_tokens.insert(content.to_string(), id);
}
}
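        // Fill any reserved slot not claimed by the config with a placeholder
        // keyed by its absolute token ID, keeping the special-token range
        // contiguous.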
let used_ids: HashSet<u32> = special_tokens.values().copied().collect();
for i in 0..DEFAULT_NUM_RESERVED_SPECIAL_TOKENS {
let id = num_base_tokens as u32 + i;
if !used_ids.contains(&id) {
special_tokens.insert(format!("<|reserved_token_{id}|>"), id);
}
}
} else {
for i in 0..DEFAULT_NUM_RESERVED_SPECIAL_TOKENS {
let id = num_base_tokens as u32 + i;
special_tokens.insert(format!("<|reserved_token_{id}|>"), id);
}
}
Ok(special_tokens)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizers::DecodeStream;
use std::io::Write;
use std::sync::Arc;
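    /// Write a tiny rank file whose merges are just enough to round-trip
    /// "hello world".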
fn create_test_tiktoken_file(dir: &Path) -> String {
let engine = base64::engine::general_purpose::STANDARD;
let mut content = String::new();
let tokens: Vec<(&[u8], u32)> = vec![
(b"h", 0),
(b"e", 1),
(b"l", 2),
(b"o", 3),
(b" ", 4),
(b"w", 5),
(b"r", 6),
(b"d", 7),
(b"he", 8),
(b"ll", 9),
(b"lo", 10),
(b"wo", 11),
(b"rl", 12),
(b"hel", 13),
(b"llo", 14),
(b"wor", 15),
(b"hell", 16),
(b"ello", 17),
(b"worl", 18),
(b"hello", 19),
(b"world", 20),
];
for (token, rank) in tokens {
let encoded = engine.encode(token);
content.push_str(&format!("{encoded} {rank}\n"));
}
let file_path = dir.join("tiktoken.model");
let mut file = std::fs::File::create(&file_path).unwrap();
file.write_all(content.as_bytes()).unwrap();
file_path.to_str().unwrap().to_string()
}
fn create_test_config(dir: &Path, model_type: &str) {
let config = serde_json::json!({
"model_type": model_type,
"max_position_embeddings": 32768,
"eos_token_id": [21]
});
let file_path = dir.join("config.json");
std::fs::write(file_path, serde_json::to_string_pretty(&config).unwrap()).unwrap();
}
fn create_test_tokenizer_config(dir: &Path, num_base_tokens: usize) {
let mut added_tokens = serde_json::Map::new();
let bos_id = num_base_tokens;
let eos_id = num_base_tokens + 1;
added_tokens.insert(
bos_id.to_string(),
serde_json::json!({"content": "[BOS]", "special": true}),
);
added_tokens.insert(
eos_id.to_string(),
serde_json::json!({"content": "[EOS]", "special": true}),
);
let config = serde_json::json!({
"added_tokens_decoder": added_tokens
});
let file_path = dir.join("tokenizer_config.json");
std::fs::write(file_path, serde_json::to_string_pretty(&config).unwrap()).unwrap();
}
#[test]
fn test_parse_tiktoken_file() {
let dir = tempfile::tempdir().unwrap();
let file_path = create_test_tiktoken_file(dir.path());
let encoder = parse_tiktoken_file(&file_path).unwrap();
assert_eq!(encoder.len(), 21);
assert_eq!(encoder[b"hello".as_slice()], 19);
assert_eq!(encoder[b"world".as_slice()], 20);
}
#[test]
fn test_parse_tiktoken_file_missing() {
let result = parse_tiktoken_file("/nonexistent/path/tiktoken.model");
assert!(result.is_err());
}
#[test]
fn test_tiktoken_from_file() {
let dir = tempfile::tempdir().unwrap();
let file_path = create_test_tiktoken_file(dir.path());
let mut special_tokens = FxHashMap::default();
special_tokens.insert("[BOS]".to_string(), 21_u32);
special_tokens.insert("[EOS]".to_string(), 22_u32);
let pattern = r"[\w]+|[^\w\s]+|\s+";
let tokenizer = TikTokenTokenizer::from_file(&file_path, pattern, special_tokens).unwrap();
let encoding = tokenizer.encode("hello world").unwrap();
let ids = encoding.token_ids();
assert!(!ids.is_empty());
let decoded = tokenizer.decode(ids, false).unwrap();
assert_eq!(decoded, "hello world");
}
#[test]
fn test_tiktoken_encoding_variant() {
let dir = tempfile::tempdir().unwrap();
let file_path = create_test_tiktoken_file(dir.path());
let special_tokens = FxHashMap::default();
let pattern = r"[\w]+|[^\w\s]+|\s+";
let tokenizer = TikTokenTokenizer::from_file(&file_path, pattern, special_tokens).unwrap();
let encoding = tokenizer.encode("hello").unwrap();
match &encoding {
Encoding::Sp(_) => {}
other => panic!("Expected Encoding::Sp, got {:?}", other),
}
}
#[test]
fn test_tiktoken_skip_special_tokens() {
let dir = tempfile::tempdir().unwrap();
let file_path = create_test_tiktoken_file(dir.path());
let mut special_tokens = FxHashMap::default();
special_tokens.insert("[BOS]".to_string(), 21_u32);
special_tokens.insert("[EOS]".to_string(), 22_u32);
let pattern = r"[\w]+|[^\w\s]+|\s+";
let tokenizer = TikTokenTokenizer::from_file(&file_path, pattern, special_tokens).unwrap();
let encoding = tokenizer.encode("hello").unwrap();
        let mut ids = vec![21u32];
        ids.extend(encoding.token_ids());
ids.push(22);
let decoded_skip = tokenizer.decode(&ids, true).unwrap();
assert_eq!(decoded_skip, "hello");
let decoded_all = tokenizer.decode(&ids, false).unwrap();
assert!(decoded_all.contains("hello"));
}
#[test]
fn test_tiktoken_from_file_auto() {
let dir = tempfile::tempdir().unwrap();
let file_path = create_test_tiktoken_file(dir.path());
create_test_config(dir.path(), "kimi");
create_test_tokenizer_config(dir.path(), 21);
let tokenizer = TikTokenTokenizer::from_file_auto(&file_path).unwrap();
let encoding = tokenizer.encode("hello world").unwrap();
let ids = encoding.token_ids();
assert!(!ids.is_empty());
let decoded = tokenizer.decode(ids, false).unwrap();
assert_eq!(decoded, "hello world");
}
#[test]
fn test_detect_bpe_pattern_unknown() {
let dir = tempfile::tempdir().unwrap();
create_test_config(dir.path(), "unknown_model");
let result = detect_bpe_pattern(dir.path());
assert!(result.is_err());
}
#[test]
fn test_load_special_tokens_no_config() {
let dir = tempfile::tempdir().unwrap();
let tokens = load_special_tokens(dir.path(), 100).unwrap();
assert_eq!(tokens.len(), 256);
assert_eq!(tokens["<|reserved_token_100|>"], 100);
assert_eq!(tokens["<|reserved_token_355|>"], 355);
}
#[test]
fn test_load_special_tokens_with_config() {
let dir = tempfile::tempdir().unwrap();
create_test_tokenizer_config(dir.path(), 100);
let tokens = load_special_tokens(dir.path(), 100).unwrap();
assert_eq!(tokens["[BOS]"], 100);
assert_eq!(tokens["[EOS]"], 101);
assert!(tokens.len() > 2);
}
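    /// Rank file where some tokens are single bytes of multi-byte UTF-8
    /// sequences (the three bytes of U+4F60 "你" and the four bytes of the
    /// U+1F600 emoji), so decode tests can exercise incomplete sequences.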
fn create_test_tiktoken_file_with_byte_tokens(dir: &Path) -> String {
let engine = base64::engine::general_purpose::STANDARD;
let mut content = String::new();
let tokens: Vec<(&[u8], u32)> = vec![
(b"h", 0),
(b"e", 1),
(b"l", 2),
(b"o", 3),
(b" ", 4),
(b"hello", 5),
];
for (token, rank) in &tokens {
let encoded = engine.encode(token);
content.push_str(&format!("{encoded} {rank}\n"));
}
let byte_tokens: Vec<(Vec<u8>, u32)> =
vec![(vec![0xE4], 100), (vec![0xBD], 101), (vec![0xA0], 102)];
for (token, rank) in &byte_tokens {
let encoded = engine.encode(token);
content.push_str(&format!("{encoded} {rank}\n"));
}
let emoji_tokens: Vec<(Vec<u8>, u32)> = vec![
(vec![0xF0], 200),
(vec![0x9F], 201),
(vec![0x98], 202),
(vec![0x80], 203),
];
for (token, rank) in &emoji_tokens {
let encoded = engine.encode(token);
content.push_str(&format!("{encoded} {rank}\n"));
}
let file_path = dir.join("tiktoken.model");
let mut file = std::fs::File::create(&file_path).unwrap();
file.write_all(content.as_bytes()).unwrap();
file_path.to_str().unwrap().to_string()
}
fn create_byte_token_tokenizer(dir: &Path) -> TikTokenTokenizer {
let file_path = create_test_tiktoken_file_with_byte_tokens(dir);
let special_tokens = FxHashMap::default();
let pattern = r"[\w]+|[^\w\s]+|\s+";
TikTokenTokenizer::from_file(&file_path, pattern, special_tokens).unwrap()
}
#[test]
fn test_decode_single_incomplete_utf8_byte_does_not_error() {
let dir = tempfile::tempdir().unwrap();
let tokenizer = create_byte_token_tokenizer(dir.path());
let result = tokenizer.decode(&[100], false);
assert!(
result.is_ok(),
"decode() should not error on incomplete UTF-8 bytes"
);
let text = result.unwrap();
assert!(
text.contains('\u{FFFD}'),
"incomplete UTF-8 byte should produce replacement character, got: {:?}",
text
);
}
#[test]
fn test_decode_two_of_three_utf8_bytes_does_not_error() {
let dir = tempfile::tempdir().unwrap();
let tokenizer = create_byte_token_tokenizer(dir.path());
let result = tokenizer.decode(&[100, 101], false);
assert!(result.is_ok());
let text = result.unwrap();
assert!(
text.contains('\u{FFFD}'),
"incomplete 2-of-3 UTF-8 bytes should produce replacement character, got: {:?}",
text
);
}
#[test]
fn test_decode_complete_multibyte_utf8_produces_correct_char() {
let dir = tempfile::tempdir().unwrap();
let tokenizer = create_byte_token_tokenizer(dir.path());
let result = tokenizer.decode(&[100, 101, 102], false);
assert!(result.is_ok());
        assert_eq!(result.unwrap(), "你");
}
#[test]
fn test_decode_complete_4byte_emoji_from_byte_tokens() {
let dir = tempfile::tempdir().unwrap();
let tokenizer = create_byte_token_tokenizer(dir.path());
let result = tokenizer.decode(&[200, 201, 202, 203], false);
assert!(result.is_ok());
assert_eq!(result.unwrap(), "😀");
}
#[test]
fn test_decode_partial_emoji_does_not_error() {
let dir = tempfile::tempdir().unwrap();
let tokenizer = create_byte_token_tokenizer(dir.path());
let result = tokenizer.decode(&[200], false);
assert!(result.is_ok());
assert!(result.unwrap().contains('\u{FFFD}'));
}
#[test]
fn test_decode_mixed_ascii_and_incomplete_bytes() {
let dir = tempfile::tempdir().unwrap();
let tokenizer = create_byte_token_tokenizer(dir.path());
let result = tokenizer.decode(&[5, 100], false);
assert!(result.is_ok());
let text = result.unwrap();
assert!(
text.starts_with("hello"),
"should start with 'hello', got: {:?}",
text
);
assert!(
text.contains('\u{FFFD}'),
"trailing incomplete byte should produce U+FFFD"
);
}
#[test]
fn test_decode_stream_incremental_multibyte_reassembly() {
let dir = tempfile::tempdir().unwrap();
let tokenizer = create_byte_token_tokenizer(dir.path());
let tokenizer_arc: Arc<dyn crate::tokenizers::traits::Tokenizer> = Arc::new(tokenizer);
let mut stream = DecodeStream::new(tokenizer_arc, &[5], false);
let r1 = stream.step(100).unwrap();
assert_eq!(r1, None, "first byte of 3-byte char should be buffered");
let r2 = stream.step(101).unwrap();
assert_eq!(r2, None, "second byte of 3-byte char should be buffered");
let r3 = stream.step(102).unwrap();
assert!(r3.is_some(), "third byte should complete the character");
        assert_eq!(r3.unwrap(), "你");
}
#[test]
fn test_decode_stream_incremental_emoji_reassembly() {
let dir = tempfile::tempdir().unwrap();
let tokenizer = create_byte_token_tokenizer(dir.path());
let tokenizer_arc: Arc<dyn crate::tokenizers::traits::Tokenizer> = Arc::new(tokenizer);
let mut stream = DecodeStream::new(tokenizer_arc, &[5], false);
let r1 = stream.step(200).unwrap();
assert_eq!(r1, None, "byte 1/4 of emoji should be buffered");
let r2 = stream.step(201).unwrap();
assert_eq!(r2, None, "byte 2/4 of emoji should be buffered");
let r3 = stream.step(202).unwrap();
assert_eq!(r3, None, "byte 3/4 of emoji should be buffered");
let r4 = stream.step(203).unwrap();
assert!(r4.is_some(), "byte 4/4 should complete the emoji");
assert_eq!(r4.unwrap(), "😀");
}
#[test]
fn test_tiktoken_encode_batch() {
let dir = tempfile::tempdir().unwrap();
let file_path = create_test_tiktoken_file(dir.path());
let special_tokens = FxHashMap::default();
let pattern = r"[\w]+|[^\w\s]+|\s+";
let tokenizer = TikTokenTokenizer::from_file(&file_path, pattern, special_tokens).unwrap();
let inputs = &["hello", "world"];
let encodings = tokenizer.encode_batch(inputs).unwrap();
assert_eq!(encodings.len(), 2);
for (encoding, input) in encodings.iter().zip(inputs.iter()) {
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
assert_eq!(decoded, *input);
}
}
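    /// Rank file with one token per byte value (rank == byte), i.e. a minimal
    /// byte-level base vocabulary of 256 entries.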
fn create_byte_level_tiktoken_file(dir: &Path) -> String {
let engine = base64::engine::general_purpose::STANDARD;
let mut content = String::new();
for byte_val in 0u16..256 {
let encoded = engine.encode([byte_val as u8]);
content.push_str(&format!("{encoded} {byte_val}\n"));
}
let file_path = dir.join("tiktoken.model");
std::fs::write(&file_path, &content).unwrap();
file_path.to_str().unwrap().to_string()
}
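    // Regression test: reserved placeholder names must use absolute token IDs.
    // With a 256-entry byte-level base vocab, "<|reserved_token_258|>" has to
    // be recognized as the single special token 258, not re-tokenized as text.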
#[test]
fn test_reserved_token_absolute_id_naming_kimi_k25_regression() {
let dir = tempfile::tempdir().unwrap();
let file_path = create_byte_level_tiktoken_file(dir.path());
create_test_config(dir.path(), "kimi");
create_test_tokenizer_config(dir.path(), 256);
let tokenizer = TikTokenTokenizer::from_file_auto(&file_path).unwrap();
let single = "<|reserved_token_258|>";
let enc = tokenizer.encode(single).unwrap();
assert_eq!(
enc.token_ids().len(),
1,
"'{single}' should be 1 special token, got {} tokens: {:?}. \
This means fallback naming still uses relative offsets instead of absolute IDs.",
enc.token_ids().len(),
enc.token_ids()
);
assert_eq!(enc.token_ids()[0], 258);
let multi: String = (258u32..268)
.map(|id| format!("<|reserved_token_{id}|>"))
.collect();
let enc_multi = tokenizer.encode(&multi).unwrap();
assert_eq!(
enc_multi.token_ids().len(),
10,
"10 reserved token strings should produce exactly 10 tokens, got {}: {:?}",
enc_multi.token_ids().len(),
enc_multi.token_ids()
);
let expected_ids: Vec<u32> = (258..268).collect();
assert_eq!(enc_multi.token_ids(), &expected_ids);
}
#[test]
fn test_relative_offset_naming_causes_inflation() {
let dir = tempfile::tempdir().unwrap();
let file_path = create_byte_level_tiktoken_file(dir.path());
let _encoder = parse_tiktoken_file(&file_path).unwrap();
let num_base_tokens = 256usize;
let mut bad_special_tokens: FxHashMap<String, u32> = FxHashMap::default();
bad_special_tokens.insert("[BOS]".to_string(), 256);
bad_special_tokens.insert("[EOS]".to_string(), 257);
for i in 0..DEFAULT_NUM_RESERVED_SPECIAL_TOKENS {
let id = num_base_tokens as u32 + i;
if id != 256 && id != 257 {
bad_special_tokens.insert(format!("<|reserved_token_{i}|>"), id);
}
}
let bad_tokenizer =
TikTokenTokenizer::from_file(&file_path, KIMI_PATTERN, bad_special_tokens).unwrap();
let input = "<|reserved_token_258|>";
let enc = bad_tokenizer.encode(input).unwrap();
assert!(
enc.token_ids().len() > 1,
"With buggy relative-offset naming, '{}' should NOT be recognized as a \
single special token. Got {} token(s): {:?}",
input,
enc.token_ids().len(),
enc.token_ids()
);
let multi: String = (258u32..268)
.map(|id| format!("<|reserved_token_{id}|>"))
.collect();
let enc_multi = bad_tokenizer.encode(&multi).unwrap();
assert!(
enc_multi.token_ids().len() > 10,
"With buggy naming, 10 reserved token strings should inflate beyond 10 tokens. \
Got {}",
enc_multi.token_ids().len(),
);
}
}