use core::mem::size_of;
use std::io::{Read, Write};
use crate::encoder::{BacktrackingBytePairEncoder, BytePairEncoder, Encoder, EncoderType, SentencePieceBPE, UnigramEncoder, WordPieceEncoder};
use crate::decoder::{Decoder, DecoderType, VocabDecoder};
use crate::normalizer::Normalizer;
use crate::postprocessor::PostProcessor;
use crate::pretok::PretokType;
use crate::tokenizer::Tokenizer;
use crate::types::{Split, TokenId};
use daggrs::DoubleArrayAhoCorasick;
use foldhash::HashMap as FoldHashMap;
/// Signature stored in the first four bytes of every serialized tokenizer.
const MAGIC: &[u8; 4] = b"TOKI";

/// Format version written by `Tokenizer::save` (`Tokenizer::load` also accepts version 10).
const VERSION: u32 = 11;

/// Size in bytes of the fixed header: the 4-byte magic, nine `u32` fields, five
/// `(offset, crc32)` pairs of `u32`s, and 8 trailing zero bytes.
const HEADER_SIZE: usize = 4 + 9 * 4 + 5 * 2 * 4 + 8;
impl PretokType {
    /// Decodes the pre-tokenizer tag stored in the file header.
    ///
    /// Returns `None` for tags this build does not know.
    fn from_u32(v: u32) -> Option<Self> {
        let decoded = match v {
            0 => Self::None,
            1 => Self::Gpt2,
            2 => Self::Cl100k,
            3 => Self::O200k,
            4 => Self::Bert,
            5 => Self::Voyage,
            6 => Self::DeepSeek,
            7 => Self::SmolLM,
            8 => Self::Qwen35,
            _ => return None,
        };
        Some(decoded)
    }
}
impl Normalizer {
    /// Decodes the normalizer tag stored in the file header.
    ///
    /// Returns `None` for tags this build does not know.
    fn from_u32(v: u32) -> Option<Self> {
        let decoded = match v {
            0 => Self::None,
            1 => Self::BertUncased,
            2 => Self::BertCased,
            3 => Self::Nfc,
            4 => Self::Metaspace,
            5 => Self::SentencePiece,
            6 => Self::SentencePieceLowercase,
            7 => Self::MetaspaceReplace,
            _ => return None,
        };
        Some(decoded)
    }

    /// Encodes this normalizer as the tag written to the file header (inverse of `from_u32`).
    fn to_u32(&self) -> u32 {
        match *self {
            Self::None => 0,
            Self::BertUncased => 1,
            Self::BertCased => 2,
            Self::Nfc => 3,
            Self::Metaspace => 4,
            Self::SentencePiece => 5,
            Self::SentencePieceLowercase => 6,
            Self::MetaspaceReplace => 7,
        }
    }
}
impl PostProcessor {
    /// Numeric tag identifying the variant.
    ///
    /// It is stored in the file header; the variant's payload comes from `serialize`.
    fn type_id(&self) -> u32 {
        match self {
            Self::None => 0,
            Self::Bert { .. } => 1,
            Self::Prefix { .. } => 2,
            Self::Template { .. } => 3,
        }
    }

    /// Encodes the variant's payload as little-endian `u32`s:
    /// - `None`: nothing;
    /// - `Bert`: `cls_token`, `sep_token`;
    /// - `Prefix`: `bos_token`;
    /// - `Template`: the six token arrays in field order, each as `[len][len × token]`.
    fn serialize(&self) -> Vec<u8> {
        match self {
            Self::None => Vec::new(),
            Self::Bert { cls_token, sep_token } => {
                let mut buf = Vec::with_capacity(8);
                buf.extend_from_slice(&cls_token.to_le_bytes());
                buf.extend_from_slice(&sep_token.to_le_bytes());
                buf
            }
            Self::Prefix { bos_token } => bos_token.to_le_bytes().to_vec(),
            Self::Template {
                single_prefix,
                single_suffix,
                pair_a_prefix,
                pair_a_suffix,
                pair_b_prefix,
                pair_b_suffix,
            } => {
                let arrays = [
                    single_prefix,
                    single_suffix,
                    pair_a_prefix,
                    pair_a_suffix,
                    pair_b_prefix,
                    pair_b_suffix,
                ];
                let capacity: usize = arrays.iter().map(|tokens| 4 + 4 * tokens.len()).sum();
                let mut buf = Vec::with_capacity(capacity);
                for tokens in arrays {
                    buf.extend_from_slice(&(tokens.len() as u32).to_le_bytes());
                    for &token in tokens {
                        buf.extend_from_slice(&token.to_le_bytes());
                    }
                }
                buf
            }
        }
    }

    /// Decodes a payload produced by `serialize` for the variant tagged `type_id`.
    ///
    /// Returns `None` for an unknown tag or a truncated payload. Trailing bytes after a
    /// complete payload are ignored.
    fn deserialize(type_id: u32, data: &[u8]) -> Option<Self> {
        /// Reads the little-endian `u32` at byte `at`, or `None` if fewer than 4 bytes remain.
        fn u32_at(data: &[u8], at: usize) -> Option<u32> {
            let bytes = data.get(at..at.checked_add(4)?)?;
            Some(u32::from_le_bytes(bytes.try_into().ok()?))
        }

        match type_id {
            0 => Some(Self::None),
            1 => Some(Self::Bert {
                cls_token: u32_at(data, 0)?,
                sep_token: u32_at(data, 4)?,
            }),
            2 => Some(Self::Prefix {
                bos_token: u32_at(data, 0)?,
            }),
            3 => {
                let mut offset = 0usize;
                let mut read_array = || -> Option<Vec<u32>> {
                    let len = u32_at(data, offset)? as usize;
                    offset += 4;
                    // The length is untrusted: make sure the bytes actually exist *before*
                    // allocating, so a corrupt length cannot request gigabytes of memory.
                    let byte_len = len.checked_mul(4)?;
                    let body = data.get(offset..offset.checked_add(byte_len)?)?;
                    offset += byte_len;
                    Some(
                        body.chunks_exact(4)
                            .map(|word| u32::from_le_bytes(word.try_into().unwrap()))
                            .collect(),
                    )
                };
                let single_prefix = read_array()?;
                let single_suffix = read_array()?;
                let pair_a_prefix = read_array()?;
                let pair_a_suffix = read_array()?;
                let pair_b_prefix = read_array()?;
                let pair_b_suffix = read_array()?;
                Some(Self::Template {
                    single_prefix,
                    single_suffix,
                    pair_a_prefix,
                    pair_a_suffix,
                    pair_b_prefix,
                    pair_b_suffix,
                })
            }
            _ => None,
        }
    }
}
/// CRC-32 checksum of `data`, as computed by `crc32fast`.
///
/// Used to checksum each section of the serialized file.
fn crc32(data: &[u8]) -> u32 {
    let mut hasher = crc32fast::Hasher::new();
    hasher.update(data);
    hasher.finalize()
}
/// Errors produced while saving or loading a tokenizer in the TOKI binary format.
#[derive(Debug)]
pub enum SerdeError {
    /// Reading from or writing to the underlying stream or file failed.
    Io(std::io::Error),
    /// The first four bytes were not the `TOKI` signature.
    InvalidMagic,
    /// The header's version field is not one this build can read.
    UnsupportedVersion(u32),
    /// The header names an encoder type tag this build does not know.
    InvalidEncoderType(u32),
    /// The header names a pre-tokenizer tag this build does not know.
    InvalidPretokenizer(u32),
    /// The header names a normalizer tag this build does not know.
    InvalidNormalizer(u32),
    /// The post-processor tag is unknown or its payload could not be decoded.
    InvalidPostProcessor(u32),
    /// A section's CRC32 did not match the checksum recorded in the header.
    ChecksumMismatch { section: &'static str },
    /// The data is structurally malformed; the message names the problem.
    InvalidData(&'static str),
}

impl std::fmt::Display for SerdeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(e) => write!(f, "IO error: {}", e),
            Self::InvalidMagic => write!(f, "Invalid magic bytes (not a TOKI file)"),
            Self::UnsupportedVersion(v) => write!(f, "Unsupported version: {}", v),
            Self::InvalidEncoderType(v) => write!(f, "Invalid encoder type: {}", v),
            Self::InvalidPretokenizer(v) => write!(f, "Invalid pretokenizer type: {}", v),
            Self::InvalidNormalizer(v) => write!(f, "Invalid normalizer type: {}", v),
            Self::InvalidPostProcessor(v) => write!(f, "Invalid post-processor type: {}", v),
            Self::ChecksumMismatch { section } => write!(f, "Checksum mismatch in {}", section),
            Self::InvalidData(msg) => write!(f, "Invalid data: {}", msg),
        }
    }
}

impl std::error::Error for SerdeError {
    /// Exposes the wrapped `std::io::Error` so error-chain reporters can show the root cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            Self::InvalidMagic
            | Self::UnsupportedVersion(_)
            | Self::InvalidEncoderType(_)
            | Self::InvalidPretokenizer(_)
            | Self::InvalidNormalizer(_)
            | Self::InvalidPostProcessor(_)
            | Self::ChecksumMismatch { .. }
            | Self::InvalidData(_) => None,
        }
    }
}

impl From<std::io::Error> for SerdeError {
    fn from(e: std::io::Error) -> Self {
        Self::Io(e)
    }
}
impl Tokenizer {
    /// Writes the tokenizer to `path` in the TOKI binary format (see [`Tokenizer::save`]).
    ///
    /// The file is created or truncated. Output goes through a `BufWriter` that is flushed
    /// explicitly before returning.
    ///
    /// # Errors
    /// Returns [`SerdeError::Io`] if the file cannot be created, written or flushed. Also
    /// returns any error reported by [`Tokenizer::save`].
    pub fn to_file(&self, path: impl AsRef<std::path::Path>) -> Result<(), SerdeError> {
        let file = std::fs::File::create(path)?;
        let mut writer = std::io::BufWriter::new(file);
        self.save(&mut writer)?;
        // `BufWriter` flushes on drop but silently discards any error there; flush here so a
        // failed final write reaches the caller.
        writer.flush()?;
        Ok(())
    }

    /// Serializes the tokenizer into `writer` in the TOKI binary format.
    ///
    /// The output has two parts:
    /// - A `HEADER_SIZE`-byte little-endian header: magic, version, encoder/pre-tokenizer/
    ///   normalizer/post-processor tags, vocabulary counts, pad token id (`0xFFFF_FFFF` when
    ///   unset), an `(offset, crc32)` pair per section, and 8 zero bytes.
    /// - The five sections in order: token table, merge/config data, DAAC automaton,
    ///   prefix-match table, post-processor payload.
    ///
    /// Sections an encoder does not use are written empty. `writer` is not flushed.
    ///
    /// # Errors
    /// Returns [`SerdeError::Io`] on write failure. Returns [`SerdeError::InvalidData`] if a
    /// section offset would not fit in the header's 32-bit offset fields; in that case nothing
    /// is written.
    pub fn save<W: Write>(&self, writer: &mut W) -> Result<(), SerdeError> {
        let encoder_type = self.encoder_type();
        let pretokenizer_type = self.pretokenizer_type();
        let normalizer = self.normalizer();
        let post_processor = self.post_processor();
        let encoder = self.encoder();
        let decoder = self.decoder();

        let token_data = serialize_vocab_decoder(decoder.vocab());
        // Payload sections per encoder family; sections an encoder does not use stay empty.
        let (merge_data, daac_data, prefix_data) = match encoder {
            Encoder::Backtracking(enc) => (
                serialize_splits(enc.split_table()),
                enc.matcher().serialize(),
                serialize_prefix_match(enc.next_prefix_match_table()),
            ),
            Encoder::Simple(enc) => (serialize_pair_lookup(enc), Vec::new(), Vec::new()),
            Encoder::WordPiece(enc) => {
                (serialize_wordpiece_config(enc), enc.matcher().serialize(), Vec::new())
            }
            Encoder::SentencePiece(enc) => {
                (serialize_sentencepiece_config(enc), Vec::new(), Vec::new())
            }
            Encoder::Unigram(enc) => {
                (serialize_unigram_config(enc), enc.matcher().serialize(), Vec::new())
            }
        };
        let pp_data = post_processor.serialize();

        // Offsets are stored as u32; refuse to produce a file whose offsets would wrap.
        let token_offset = HEADER_SIZE as u32;
        let merge_offset = next_section_offset(token_offset, &token_data)?;
        let daac_offset = next_section_offset(merge_offset, &merge_data)?;
        let prefix_offset = next_section_offset(daac_offset, &daac_data)?;
        let pp_offset = next_section_offset(prefix_offset, &prefix_data)?;

        let num_base_tokens = encoder.num_base_tokens();
        let header_fields: [u32; 9] = [
            VERSION,
            encoder_type as u32,
            pretokenizer_type as u32,
            normalizer.to_u32(),
            post_processor.type_id(),
            decoder.vocab_size() as u32,
            (encoder.vocab_size() - num_base_tokens) as u32,
            num_base_tokens as u32,
            // 0xFFFF_FFFF marks "no pad token".
            self.pad_token_id().unwrap_or(0xFFFF_FFFF),
        ];
        let section_table = [
            (token_offset, crc32(&token_data)),
            (merge_offset, crc32(&merge_data)),
            (daac_offset, crc32(&daac_data)),
            (prefix_offset, crc32(&prefix_data)),
            (pp_offset, crc32(&pp_data)),
        ];

        writer.write_all(MAGIC)?;
        for field in header_fields {
            writer.write_all(&field.to_le_bytes())?;
        }
        for (offset, checksum) in section_table {
            writer.write_all(&offset.to_le_bytes())?;
            writer.write_all(&checksum.to_le_bytes())?;
        }
        // Eight zero bytes complete the HEADER_SIZE-byte header.
        writer.write_all(&[0u8; 8])?;
        for section in [
            &token_data[..],
            &merge_data[..],
            &daac_data[..],
            &prefix_data[..],
            &pp_data[..],
        ] {
            writer.write_all(section)?;
        }
        Ok(())
    }

    /// Loads a tokenizer previously written with [`Tokenizer::to_file`].
    ///
    /// # Errors
    /// See [`Tokenizer::load`]; additionally [`SerdeError::Io`] if the file cannot be opened.
    pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self, SerdeError> {
        let file = std::fs::File::open(path)?;
        let mut reader = std::io::BufReader::new(file);
        Self::load(&mut reader)
    }

    /// Reads a tokenizer in the TOKI binary format from `reader`.
    ///
    /// The whole stream is buffered in memory first. Format versions 10 and 11 are accepted.
    /// Every section's CRC32 is verified before it is decoded.
    ///
    /// # Errors
    /// - [`SerdeError::Io`] if reading fails;
    /// - [`SerdeError::InvalidMagic`] / [`SerdeError::UnsupportedVersion`] for a foreign or
    ///   unknown file;
    /// - `InvalidEncoderType` / `InvalidPretokenizer` / `InvalidNormalizer` /
    ///   `InvalidPostProcessor` for unknown component tags;
    /// - [`SerdeError::ChecksumMismatch`] if a section is corrupt;
    /// - [`SerdeError::InvalidData`] for truncated or inconsistent data, including section
    ///   offsets that are out of order or point past the end of the input.
    pub fn load<R: Read>(reader: &mut R) -> Result<Self, SerdeError> {
        let mut data = Vec::new();
        reader.read_to_end(&mut data)?;
        if data.len() < HEADER_SIZE {
            return Err(SerdeError::InvalidData("file too small"));
        }
        if &data[0..4] != MAGIC {
            return Err(SerdeError::InvalidMagic);
        }
        // Every header field read below lies inside the first HEADER_SIZE bytes.
        let header_u32 = |at: usize| {
            u32::from_le_bytes(data[at..at + 4].try_into().expect("header length checked above"))
        };

        let version = header_u32(4);
        if version != VERSION && version != 10 {
            return Err(SerdeError::UnsupportedVersion(version));
        }
        let encoder_tag = header_u32(8);
        let encoder_type = EncoderType::from_u32(encoder_tag)
            .ok_or(SerdeError::InvalidEncoderType(encoder_tag))?;
        let pretok_tag = header_u32(12);
        let pretokenizer_type = PretokType::from_u32(pretok_tag)
            .ok_or(SerdeError::InvalidPretokenizer(pretok_tag))?;
        let normalizer_tag = header_u32(16);
        let normalizer = Normalizer::from_u32(normalizer_tag)
            .ok_or(SerdeError::InvalidNormalizer(normalizer_tag))?;
        let pp_type = header_u32(20);
        let vocab_size = header_u32(24) as usize;
        // Bytes 28..32 hold the merge count; it is not needed to rebuild the tokenizer.
        let num_base_tokens = header_u32(32) as usize;
        let pad_token_id_raw = header_u32(36);
        // The pad-token field is only honored from version 11 on; 0xFFFF_FFFF means "unset".
        let pad_token_id =
            (version >= 11 && pad_token_id_raw != 0xFFFF_FFFF).then_some(pad_token_id_raw);

        let token_offset = header_u32(40) as usize;
        let token_checksum = header_u32(44);
        let merge_offset = header_u32(48) as usize;
        let merge_checksum = header_u32(52);
        let daac_offset = header_u32(56) as usize;
        let daac_checksum = header_u32(60);
        let prefix_offset = header_u32(64) as usize;
        let prefix_checksum = header_u32(68);
        let pp_offset = header_u32(72) as usize;
        let pp_checksum = header_u32(76);

        // The offsets come straight from the input; if they are not ordered and in range, the
        // slicing below would panic instead of reporting a corrupt file.
        let offsets_valid = token_offset <= merge_offset
            && merge_offset <= daac_offset
            && daac_offset <= prefix_offset
            && prefix_offset <= pp_offset
            && pp_offset <= data.len();
        if !offsets_valid {
            return Err(SerdeError::InvalidData(
                "section offsets out of order or out of bounds",
            ));
        }

        let token_data = &data[token_offset..merge_offset];
        if crc32(token_data) != token_checksum {
            return Err(SerdeError::ChecksumMismatch { section: "token_data" });
        }
        let merge_data = &data[merge_offset..daac_offset];
        if crc32(merge_data) != merge_checksum {
            return Err(SerdeError::ChecksumMismatch { section: "merge_data" });
        }
        let daac_data = &data[daac_offset..prefix_offset];
        if crc32(daac_data) != daac_checksum {
            return Err(SerdeError::ChecksumMismatch { section: "daac_data" });
        }
        let prefix_data = &data[prefix_offset..pp_offset];
        if crc32(prefix_data) != prefix_checksum {
            return Err(SerdeError::ChecksumMismatch { section: "prefix_data" });
        }
        let pp_data = &data[pp_offset..];
        if crc32(pp_data) != pp_checksum {
            return Err(SerdeError::ChecksumMismatch { section: "pp_data" });
        }

        let post_processor = PostProcessor::deserialize(pp_type, pp_data)
            .ok_or(SerdeError::InvalidPostProcessor(pp_type))?;
        let (decoder_offsets, decoder_data) = deserialize_decoder(token_data, vocab_size)?;

        let encoder = match encoder_type {
            EncoderType::Backtracking => {
                let token_bytes = collect_token_bytes(&decoder_offsets, &decoder_data, vocab_size);
                let split_table = deserialize_splits(merge_data)?;
                let (daac, _) = DoubleArrayAhoCorasick::deserialize(daac_data)
                    .ok_or(SerdeError::InvalidData("failed to deserialize DAAC"))?;
                let next_prefix_match = deserialize_prefix_match(prefix_data)?;
                let pair_lookup = rebuild_pair_lookup(&split_table, num_base_tokens);
                // Lengths are clamped to 255 to fit in a u8.
                let token_lengths: Vec<u8> = token_bytes
                    .iter()
                    .map(|bytes| bytes.len().min(255) as u8)
                    .collect();
                let enc = BacktrackingBytePairEncoder::from_parts(
                    split_table,
                    pair_lookup,
                    token_lengths,
                    num_base_tokens,
                    daac,
                    next_prefix_match,
                    &token_bytes,
                );
                Encoder::Backtracking(enc)
            }
            EncoderType::Simple => {
                let (byte_lut, token_cache, _, _) =
                    build_token_lookups(&decoder_offsets, &decoder_data, vocab_size);
                let merges = deserialize_merges(merge_data)?;
                let enc = BytePairEncoder::from_parts(
                    &merges,
                    byte_lut,
                    token_cache,
                    vocab_size,
                    num_base_tokens,
                );
                Encoder::Simple(enc)
            }
            EncoderType::WordPiece => {
                let token_bytes = collect_token_bytes(&decoder_offsets, &decoder_data, vocab_size);
                let (unk_token, continuation_prefix, max_input_chars_per_word) =
                    deserialize_wordpiece_config(merge_data)?;
                let (daac, _) = DoubleArrayAhoCorasick::deserialize(daac_data)
                    .ok_or(SerdeError::InvalidData("failed to deserialize DAAC"))?;
                let enc = WordPieceEncoder::from_parts(
                    daac,
                    unk_token,
                    continuation_prefix,
                    vocab_size,
                    &token_bytes,
                    max_input_chars_per_word,
                );
                Encoder::WordPiece(enc)
            }
            EncoderType::SentencePiece => {
                let (mut byte_lut, mut token_cache, token_lengths, byte_tokens) =
                    build_token_lookups(&decoder_offsets, &decoder_data, vocab_size);
                let merges = deserialize_merges(merge_data)?;
                fix_byte_fallback_collisions(
                    &mut byte_lut,
                    &mut token_cache,
                    &merges,
                    &byte_tokens,
                );
                let enc = SentencePieceBPE::from_parts(
                    &merges,
                    byte_lut,
                    token_cache,
                    token_lengths,
                    vocab_size,
                    num_base_tokens,
                );
                Encoder::SentencePiece(enc)
            }
            EncoderType::Unigram => {
                let token_bytes = collect_token_bytes(&decoder_offsets, &decoder_data, vocab_size);
                let (scores, unk_token, byte_tokens, token_lengths) =
                    deserialize_unigram_config(merge_data)?;
                let (daac, _) = DoubleArrayAhoCorasick::deserialize(daac_data)
                    .ok_or(SerdeError::InvalidData("failed to deserialize DAAC"))?;
                let enc = UnigramEncoder::from_parts(
                    daac,
                    scores,
                    unk_token,
                    byte_tokens,
                    token_lengths,
                    &token_bytes,
                );
                Encoder::Unigram(enc)
            }
        };

        let decoder_type = DecoderType::from_encoder_type(encoder_type);
        let decoder = Decoder::from_parts(decoder_data, decoder_offsets, decoder_type);
        let mut tokenizer =
            Tokenizer::new(encoder, decoder, pretokenizer_type, normalizer, post_processor);
        if let Some(pad_id) = pad_token_id {
            tokenizer.set_pad_token_id(pad_id);
        }
        Ok(tokenizer)
    }
}

/// Returns the offset of the section that follows `section`, which starts at `offset`.
///
/// Fails if that offset does not fit in the header's 32-bit offset field.
fn next_section_offset(offset: u32, section: &[u8]) -> Result<u32, SerdeError> {
    u32::try_from(section.len())
        .ok()
        .and_then(|len| offset.checked_add(len))
        .ok_or(SerdeError::InvalidData(
            "serialized tokenizer too large for 32-bit section offsets",
        ))
}

/// Copies each token's bytes out of the decoder tables.
///
/// Token `i` is `data[offsets[i]..offsets[i + 1]]`.
fn collect_token_bytes(offsets: &[u32], data: &[u8], vocab_size: usize) -> Vec<Vec<u8>> {
    (0..vocab_size)
        .map(|i| data[offsets[i] as usize..offsets[i + 1] as usize].to_vec())
        .collect()
}
/// Encodes the decoder vocabulary as `[count][count × offset][token bytes]`.
///
/// `count` is the number of offsets. The count and offsets are little-endian `u32`s.
fn serialize_vocab_decoder(decoder: &VocabDecoder) -> Vec<u8> {
    let (token_bytes, token_offsets) = decoder.as_parts();
    let table_len = 4 + 4 * token_offsets.len();
    let mut encoded = Vec::with_capacity(table_len + token_bytes.len());
    encoded.extend_from_slice(&(token_offsets.len() as u32).to_le_bytes());
    encoded.extend(token_offsets.iter().flat_map(|offset| offset.to_le_bytes()));
    encoded.extend_from_slice(token_bytes);
    encoded
}
/// Tokens whose byte length is at most this are inserted into the token cache.
const MAX_CACHED_TOKEN_LEN: usize = 16;

/// Width, in token ids, of the window scanned for a dense run of single-byte tokens.
const FALLBACK_CLUSTER_WINDOW: u32 = 300;

/// Minimum number of single-byte tokens a window must contain to count as the fallback cluster.
const FALLBACK_CLUSTER_MIN_DENSITY: usize = 200;
/// Builds the byte-level lookup tables used by the merge-based encoders.
///
/// Token `i` decodes to `decoder_data[decoder_offsets[i]..decoder_offsets[i + 1]]`. Returns,
/// in order:
/// - for each byte `b`, the lowest token id decoding to exactly `b` (`u32::MAX` if none);
/// - a map from token bytes to token id, for tokens of at most `MAX_CACHED_TOKEN_LEN` bytes.
///   Duplicate single-byte tokens keep the lowest id; all other duplicates keep the highest;
/// - every token's byte length, cast to `u16`;
/// - for each byte, every token id decoding to exactly that byte, in ascending order.
fn build_token_lookups(
    decoder_offsets: &[u32],
    decoder_data: &[u8],
    vocab_size: usize,
) -> ([TokenId; 256], FoldHashMap<Vec<u8>, TokenId>, Vec<u16>, [Vec<TokenId>; 256]) {
    // Pre-size the cache for exactly the tokens that will be stored in it.
    let cacheable = (0..vocab_size)
        .map(|id| (decoder_offsets[id + 1] - decoder_offsets[id]) as usize)
        .filter(|&len| len <= MAX_CACHED_TOKEN_LEN)
        .count();
    let mut cache: FoldHashMap<Vec<u8>, TokenId> =
        FoldHashMap::with_capacity_and_hasher(cacheable, Default::default());
    let mut lowest_for_byte = [u32::MAX; 256];
    let mut ids_for_byte: [Vec<TokenId>; 256] = std::array::from_fn(|_| Vec::new());
    let mut lengths: Vec<u16> = Vec::with_capacity(vocab_size);

    for id in 0..vocab_size {
        let token_id = id as TokenId;
        let lo = decoder_offsets[id] as usize;
        let hi = decoder_offsets[id + 1] as usize;
        let bytes = &decoder_data[lo..hi];
        lengths.push(bytes.len() as u16);
        match bytes {
            [single] => {
                let slot = usize::from(*single);
                ids_for_byte[slot].push(token_id);
                if lowest_for_byte[slot] == u32::MAX {
                    lowest_for_byte[slot] = token_id;
                }
                // First (lowest) id wins for a repeated single byte.
                cache.entry(bytes.to_vec()).or_insert(token_id);
            }
            _ if bytes.len() <= MAX_CACHED_TOKEN_LEN => {
                // Later duplicates overwrite earlier ones.
                cache.insert(bytes.to_vec(), token_id);
            }
            _ => {}
        }
    }
    (lowest_for_byte, cache, lengths, ids_for_byte)
}
/// Re-resolves single-byte lookups when several vocabulary tokens decode to the same byte.
///
/// On entry, `byte_lut[b]` holds the current choice for byte `b`, and `byte_tokens[b]` lists
/// every token id decoding to exactly `b`. For each byte claimed by more than one token, a
/// candidate replaces the current choice when either rule applies:
/// 1. the candidate is an operand of some merge and the current choice is not;
/// 2. the candidate is outside the detected fallback cluster, the current choice is inside it,
///    and the current choice is not a merge operand.
///
/// Finding the fallback cluster:
/// - Detection only runs when there are at least 256 single-byte tokens.
/// - A window starts at a single-byte token id and spans `FALLBACK_CLUSTER_WINDOW` ids.
/// - The cluster is the window containing the most single-byte tokens (earliest wins on ties).
/// - The window must contain at least `FALLBACK_CLUSTER_MIN_DENSITY` such tokens to count.
///
/// Every changed choice is also written into `token_cache`.
fn fix_byte_fallback_collisions(
    byte_lut: &mut [TokenId; 256],
    token_cache: &mut FoldHashMap<Vec<u8>, TokenId>,
    merges: &[(TokenId, TokenId, TokenId)],
    byte_tokens: &[Vec<TokenId>; 256],
) {
    // Nothing to disambiguate unless some byte is claimed by more than one token.
    if !byte_tokens.iter().any(|ids| ids.len() > 1) {
        return;
    }
    let mut all_single_byte: Vec<(TokenId, u8)> = Vec::new();
    for (byte_val, ids) in byte_tokens.iter().enumerate() {
        for &id in ids {
            all_single_byte.push((id, byte_val as u8));
        }
    }
    all_single_byte.sort_by_key(|(id, _)| *id);

    let mut fallback_ids = foldhash::HashSet::default();
    if all_single_byte.len() >= 256 {
        let mut best_start = 0;
        let mut best_density = 0usize;
        // Inclusive upper bound: the window starting MIN_DENSITY entries from the end still
        // holds exactly MIN_DENSITY entries and can qualify.
        let last_start = all_single_byte
            .len()
            .saturating_sub(FALLBACK_CLUSTER_MIN_DENSITY);
        for start_idx in 0..=last_start {
            // Saturate so ids near u32::MAX cannot overflow the window bound.
            let window_end = all_single_byte[start_idx]
                .0
                .saturating_add(FALLBACK_CLUSTER_WINDOW);
            let count = all_single_byte[start_idx..]
                .iter()
                .take_while(|(id, _)| *id < window_end)
                .count();
            if count > best_density && count >= FALLBACK_CLUSTER_MIN_DENSITY {
                best_density = count;
                best_start = start_idx;
            }
        }
        if best_density >= FALLBACK_CLUSTER_MIN_DENSITY {
            let range_end = all_single_byte[best_start]
                .0
                .saturating_add(FALLBACK_CLUSTER_WINDOW);
            for &(id, _) in &all_single_byte[best_start..] {
                if id < range_end {
                    fallback_ids.insert(id);
                } else {
                    break;
                }
            }
        }
    }

    let mut merge_operands = foldhash::HashSet::default();
    for &(left, right, _) in merges {
        merge_operands.insert(left);
        merge_operands.insert(right);
    }

    for (byte_val, ids) in byte_tokens.iter().enumerate() {
        if ids.len() <= 1 {
            continue;
        }
        let mut best = byte_lut[byte_val];
        for &id in ids {
            if merge_operands.contains(&id) && !merge_operands.contains(&best) {
                best = id;
            } else if !fallback_ids.contains(&id)
                && fallback_ids.contains(&best)
                && !merge_operands.contains(&best)
            {
                best = id;
            }
        }
        if best != byte_lut[byte_val] {
            byte_lut[byte_val] = best;
            token_cache.insert(vec![byte_val as u8], best);
        }
    }
}
/// Decodes the token table written by `serialize_vocab_decoder`.
///
/// Layout: `[count][count × offset][token bytes]`, where `count` and the offsets are
/// little-endian `u32`s. `count` must equal `vocab_size + 1`. Returns the offsets and an
/// owned copy of the token bytes.
///
/// # Errors
/// Returns `SerdeError::InvalidData` in any of these cases:
/// - the buffer is truncated;
/// - the offset count does not match `vocab_size`;
/// - the offsets decrease or point past the end of the token bytes.
fn deserialize_decoder(data: &[u8], vocab_size: usize) -> Result<(Vec<u32>, Vec<u8>), SerdeError> {
    if data.len() < 4 {
        return Err(SerdeError::InvalidData("decoder data too small"));
    }
    let num_offsets = u32::from_le_bytes(data[0..4].try_into().unwrap()) as usize;
    if vocab_size.checked_add(1) != Some(num_offsets) {
        return Err(SerdeError::InvalidData("offset count mismatch"));
    }
    // Checked arithmetic: the count is untrusted and could overflow `usize` on 32-bit targets.
    let offsets_end = num_offsets
        .checked_mul(4)
        .and_then(|n| n.checked_add(4))
        .filter(|&end| end <= data.len())
        .ok_or(SerdeError::InvalidData("decoder data truncated"))?;
    let offsets: Vec<u32> = data[4..offsets_end]
        .chunks_exact(4)
        .map(|word| u32::from_le_bytes(word.try_into().unwrap()))
        .collect();
    let token_data = data[offsets_end..].to_vec();
    // Token `i` is later sliced as `token_data[offsets[i]..offsets[i + 1]]`. Reject offsets
    // that would make that slicing panic, and report the corruption instead.
    let ascending = offsets.windows(2).all(|pair| pair[0] <= pair[1]);
    let within_data = offsets
        .last()
        .map_or(true, |&end| end as usize <= token_data.len());
    if !(ascending && within_data) {
        return Err(SerdeError::InvalidData(
            "decoder offsets out of order or out of bounds",
        ));
    }
    Ok((offsets, token_data))
}
fn serialize_splits(splits: &[Split]) -> Vec<u8> {
let mut buf = Vec::with_capacity(splits.len() * 8);
for split in splits {
buf.extend_from_slice(&split.left.to_le_bytes());
buf.extend_from_slice(&split.right.to_le_bytes());
}
buf
}
/// On-disk size of one split record: `left` and `right` as little-endian `u32`s.
///
/// This matches what `serialize_splits` writes. It is deliberately independent of
/// `size_of::<Split>()`, which describes the in-memory layout and could change (for example
/// through padding or alignment) without the file format changing.
const SPLIT_RECORD_BYTES: usize = 2 * size_of::<u32>();

/// Decodes the split table written by `serialize_splits`.
///
/// # Errors
/// `SerdeError::InvalidData` if the length is not a whole number of records.
fn deserialize_splits(data: &[u8]) -> Result<Vec<Split>, SerdeError> {
    if data.len() % SPLIT_RECORD_BYTES != 0 {
        return Err(SerdeError::InvalidData("split data size not aligned"));
    }
    let splits = data
        .chunks_exact(SPLIT_RECORD_BYTES)
        .map(|record| {
            let (left_bytes, right_bytes) = record.split_at(4);
            Split {
                left: u32::from_le_bytes(left_bytes.try_into().unwrap()),
                right: u32::from_le_bytes(right_bytes.try_into().unwrap()),
            }
        })
        .collect();
    Ok(splits)
}
/// Encodes the next-prefix-match table as consecutive little-endian `u32` token ids.
fn serialize_prefix_match(prefixes: &[TokenId]) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(4 * prefixes.len());
    for prefix in prefixes.iter().copied() {
        encoded.extend(prefix.to_le_bytes());
    }
    encoded
}
/// Decodes the next-prefix-match table written by `serialize_prefix_match`.
///
/// # Errors
/// `SerdeError::InvalidData` if the length is not a multiple of 4.
fn deserialize_prefix_match(data: &[u8]) -> Result<Vec<TokenId>, SerdeError> {
    if data.len() % 4 == 0 {
        Ok(data
            .chunks_exact(4)
            .map(|word| u32::from_le_bytes([word[0], word[1], word[2], word[3]]))
            .collect())
    } else {
        Err(SerdeError::InvalidData("prefix data size not aligned"))
    }
}
/// Packs a token pair into one `u64` key.
///
/// `left` goes in the high 32 bits and `right` in the low 32 bits.
#[inline(always)]
fn pack_pair(left: TokenId, right: TokenId) -> u64 {
    let high = u64::from(left) << 32;
    high | u64::from(right)
}
/// Splits a key produced by `pack_pair` back into `(left, right)`.
#[inline(always)]
fn unpack_pair(packed: u64) -> (TokenId, TokenId) {
    // Casting to u32 keeps exactly the low 32 bits.
    ((packed >> 32) as TokenId, packed as TokenId)
}
/// Serializes the simple BPE merge table.
///
/// Each merge is written as a `(left, right, merged)` triple of little-endian `u32`s
/// (12 bytes), ordered by ascending merge rank.
fn serialize_pair_lookup(enc: &BytePairEncoder) -> Vec<u8> {
    let mut ranked: Vec<(u32, TokenId, TokenId, TokenId)> = Vec::new();
    for (&packed, &(merged, rank)) in enc.pair_lookup().iter() {
        let (left, right) = unpack_pair(packed);
        ranked.push((rank, left, right, merged));
    }
    ranked.sort_by_key(|entry| entry.0);

    let mut encoded = Vec::with_capacity(12 * ranked.len());
    for (_, left, right, merged) in ranked {
        for word in [left, right, merged] {
            encoded.extend_from_slice(&word.to_le_bytes());
        }
    }
    encoded
}
/// Serializes the SentencePiece BPE merge table.
///
/// Each merge is written as a `(left, right, merged)` triple of little-endian `u32`s,
/// ordered by ascending merge rank.
fn serialize_sentencepiece_config(enc: &SentencePieceBPE) -> Vec<u8> {
    let mut ranked: Vec<(u32, TokenId, TokenId, TokenId)> = enc
        .pair_lookup()
        .iter()
        .map(|(&key, &(merged, rank))| {
            let (left, right) = unpack_pair(key);
            (rank, left, right, merged)
        })
        .collect();
    // Stable sort, same ordering as sorting by the rank key.
    ranked.sort_by(|a, b| a.0.cmp(&b.0));
    ranked
        .into_iter()
        .flat_map(|(_, left, right, merged)| [left, right, merged])
        .flat_map(u32::to_le_bytes)
        .collect()
}
/// Decodes `(left, right, merged)` merge triples.
///
/// Each triple is 12 bytes of little-endian `u32`s, in stored order.
///
/// # Errors
/// `SerdeError::InvalidData` if the length is not a multiple of 12.
fn deserialize_merges(data: &[u8]) -> Result<Vec<(TokenId, TokenId, TokenId)>, SerdeError> {
    const RECORD_BYTES: usize = 12;
    if data.len() % RECORD_BYTES != 0 {
        return Err(SerdeError::InvalidData("merge data size not aligned (expected 12 bytes per merge)"));
    }
    let word_at = |record: &[u8], at: usize| {
        u32::from_le_bytes(record[at..at + 4].try_into().unwrap())
    };
    Ok(data
        .chunks_exact(RECORD_BYTES)
        .map(|record| (word_at(record, 0), word_at(record, 4), word_at(record, 8)))
        .collect())
}
/// Rebuilds the `(left, right) -> merged token` map from the split table.
///
/// Entries before `num_base_tokens` are skipped. Every later entry `id` maps
/// `pack_pair(left, right)` to `id`; if two entries share a pair, the later one wins.
fn rebuild_pair_lookup(
    splits: &[Split],
    num_base_tokens: usize,
) -> FoldHashMap<u64, TokenId> {
    let mut lookup = FoldHashMap::default();
    let merged_splits = splits.get(num_base_tokens..).unwrap_or(&[]);
    for (offset, split) in merged_splits.iter().enumerate() {
        let id = (num_base_tokens + offset) as TokenId;
        lookup.insert(pack_pair(split.left, split.right), id);
    }
    lookup
}
/// Serializes the WordPiece settings.
///
/// Layout: `[unk_token][prefix_len][prefix bytes][max_input_chars_per_word]`, where the
/// integer fields are little-endian `u32`s.
fn serialize_wordpiece_config(enc: &WordPieceEncoder) -> Vec<u8> {
    let continuation = enc.continuation_prefix();
    let mut encoded = Vec::with_capacity(continuation.len() + 12);
    encoded.extend_from_slice(&enc.unk_token().to_le_bytes());
    let continuation_len = continuation.len() as u32;
    encoded.extend_from_slice(&continuation_len.to_le_bytes());
    encoded.extend_from_slice(continuation);
    let max_chars = enc.max_input_chars_per_word() as u32;
    encoded.extend_from_slice(&max_chars.to_le_bytes());
    encoded
}
/// Decodes the WordPiece settings written by `serialize_wordpiece_config`.
///
/// Layout: `[unk_token][prefix_len][prefix bytes][max_input_chars_per_word]`, where the
/// integer fields are little-endian `u32`s. The trailing `max_input_chars_per_word` is
/// optional; when fewer than 4 bytes follow the prefix, it defaults to 100.
///
/// # Errors
/// `SerdeError::InvalidData` if the fixed fields or the prefix bytes are truncated.
fn deserialize_wordpiece_config(data: &[u8]) -> Result<(TokenId, Vec<u8>, usize), SerdeError> {
    /// Used when the config carries no `max_input_chars_per_word` field.
    const DEFAULT_MAX_INPUT_CHARS_PER_WORD: usize = 100;

    if data.len() < 8 {
        return Err(SerdeError::InvalidData("wordpiece config too small"));
    }
    let unk_token = u32::from_le_bytes(data[0..4].try_into().unwrap());
    let prefix_len = u32::from_le_bytes(data[4..8].try_into().unwrap()) as usize;
    // Slice with `get` instead of computing `8 + prefix_len`: an untrusted length cannot
    // overflow `usize` (e.g. on 32-bit targets) and turn into a panic.
    let body = &data[8..];
    let continuation_prefix = body
        .get(..prefix_len)
        .ok_or(SerdeError::InvalidData("wordpiece prefix truncated"))?
        .to_vec();
    let max_input_chars_per_word = match body[prefix_len..].get(..4) {
        Some(raw) => u32::from_le_bytes(raw.try_into().unwrap()) as usize,
        None => DEFAULT_MAX_INPUT_CHARS_PER_WORD,
    };
    Ok((unk_token, continuation_prefix, max_input_chars_per_word))
}
/// Serializes the Unigram model, all little-endian, in this order:
/// - `vocab_size` (`u32`);
/// - `unk_token` (`u32`);
/// - the 256-entry byte-token table (`u32` each);
/// - one `f32` score per token;
/// - one `u16` length per token.
fn serialize_unigram_config(enc: &UnigramEncoder) -> Vec<u8> {
    let scores = enc.scores();
    let byte_tokens = enc.byte_tokens();
    let token_lengths = enc.token_lengths();
    let vocab_size = enc.vocab_size();

    let fixed_part = 4 + 4 + 256 * 4;
    let per_token = 4 + 2;
    let mut encoded = Vec::with_capacity(fixed_part + vocab_size * per_token);
    encoded.extend_from_slice(&(vocab_size as u32).to_le_bytes());
    encoded.extend_from_slice(&enc.unk_token().to_le_bytes());
    for token in byte_tokens.iter() {
        encoded.extend_from_slice(&token.to_le_bytes());
    }
    for score in scores {
        encoded.extend_from_slice(&score.to_le_bytes());
    }
    for length in token_lengths {
        encoded.extend_from_slice(&length.to_le_bytes());
    }
    encoded
}
/// Decodes the Unigram model written by `serialize_unigram_config`.
///
/// Returns `(scores, unk_token, byte_tokens, token_lengths)`. Bytes after the last token
/// length are ignored.
///
/// # Errors
/// `SerdeError::InvalidData` if the fixed part or the per-token arrays are truncated.
fn deserialize_unigram_config(data: &[u8]) -> Result<(Vec<f32>, TokenId, [TokenId; 256], Vec<u16>), SerdeError> {
    const FIXED_HEADER_BYTES: usize = 8;
    const BYTE_TABLE_BYTES: usize = 256 * 4;
    let scores_offset = FIXED_HEADER_BYTES + BYTE_TABLE_BYTES;
    if data.len() < scores_offset {
        return Err(SerdeError::InvalidData("unigram config too small"));
    }
    let word_at = |at: usize| u32::from_le_bytes(data[at..at + 4].try_into().unwrap());
    let vocab_size = word_at(0) as usize;
    let unk_token = word_at(4);
    let mut byte_tokens = [0u32; 256];
    for (i, slot) in byte_tokens.iter_mut().enumerate() {
        *slot = word_at(FIXED_HEADER_BYTES + i * 4);
    }

    // `vocab_size` is untrusted: checked math keeps a huge value from overflowing `usize`
    // (e.g. on 32-bit targets) and slipping past the length check.
    let lengths_offset = vocab_size
        .checked_mul(4)
        .and_then(|n| n.checked_add(scores_offset));
    let expected_len = lengths_offset
        .and_then(|offset| vocab_size.checked_mul(2).and_then(|n| n.checked_add(offset)));
    let (lengths_offset, expected_len) = match (lengths_offset, expected_len) {
        (Some(lengths), Some(end)) if data.len() >= end => (lengths, end),
        _ => return Err(SerdeError::InvalidData("unigram config truncated")),
    };

    let scores: Vec<f32> = data[scores_offset..lengths_offset]
        .chunks_exact(4)
        .map(|raw| f32::from_le_bytes(raw.try_into().unwrap()))
        .collect();
    let token_lengths: Vec<u16> = data[lengths_offset..expected_len]
        .chunks_exact(2)
        .map(|raw| u16::from_le_bytes(raw.try_into().unwrap()))
        .collect();
    Ok((scores, unk_token, byte_tokens, token_lengths))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TokenId;

    #[test]
    fn test_crc32() {
        assert_eq!(crc32(b""), 0);
        assert_eq!(crc32(b"hello"), crc32(b"hello"));
        assert_ne!(crc32(b"hello"), crc32(b"world"));
    }

    /// `save` writes `pretokenizer_type as u32` and `load` decodes it with `from_u32`, so every
    /// variant must round-trip through that pair.
    #[test]
    fn test_pretokenizer_type_roundtrip() {
        for typ in [
            PretokType::None,
            PretokType::Gpt2,
            PretokType::Cl100k,
            PretokType::O200k,
            PretokType::Bert,
            PretokType::Voyage,
            PretokType::DeepSeek,
            PretokType::SmolLM,
            PretokType::Qwen35,
        ] {
            assert_eq!(PretokType::from_u32(typ as u32), Some(typ));
        }
        assert_eq!(PretokType::from_u32(9), None);
    }

    #[test]
    fn test_normalizer_roundtrip() {
        for tag in 0..8u32 {
            let normalizer = Normalizer::from_u32(tag).expect("known normalizer tag");
            assert_eq!(normalizer.to_u32(), tag);
        }
        assert!(Normalizer::from_u32(8).is_none());
    }

    #[test]
    fn test_post_processor_roundtrip() {
        let variants = [
            PostProcessor::None,
            PostProcessor::Bert { cls_token: 101, sep_token: 102 },
            PostProcessor::Prefix { bos_token: 1 },
            PostProcessor::Template {
                single_prefix: vec![1],
                single_suffix: vec![2, 3],
                pair_a_prefix: vec![],
                pair_a_suffix: vec![4],
                pair_b_prefix: vec![5, 6, 7],
                pair_b_suffix: vec![],
            },
        ];
        for pp in variants {
            let bytes = pp.serialize();
            let restored = PostProcessor::deserialize(pp.type_id(), &bytes)
                .expect("post-processor payload should decode");
            assert_eq!(restored.type_id(), pp.type_id());
            assert_eq!(restored.serialize(), bytes);
        }
        assert!(PostProcessor::deserialize(99, &[]).is_none());
        assert!(PostProcessor::deserialize(1, &[0; 7]).is_none());
    }

    fn make_test_tokenizer() -> Tokenizer {
        let base_tokens: Vec<Vec<u8>> = (0u8..=255).map(|b| vec![b]).collect();
        let merges: Vec<(TokenId, TokenId)> = vec![
            (b'a' as u32, b'b' as u32),
            (b'c' as u32, b'd' as u32),
            (256, 257),
        ];
        let (encoder, token_bytes) =
            crate::encoder::BacktrackingBytePairEncoder::from_merges(&merges, &base_tokens);
        let decoder = Decoder::new(token_bytes);
        Tokenizer::new(
            Encoder::Backtracking(encoder),
            decoder,
            PretokType::Gpt2,
            Normalizer::None,
            PostProcessor::None,
        )
    }

    fn make_simple_test_tokenizer() -> Tokenizer {
        let base_tokens: Vec<Vec<u8>> = (0u8..=255).map(|b| vec![b]).collect();
        let merges: Vec<(TokenId, TokenId)> = vec![
            (b'a' as u32, b'b' as u32),
            (b'c' as u32, b'd' as u32),
            (256, 257),
        ];
        let (encoder, token_bytes) = BytePairEncoder::from_merges(&merges, &base_tokens);
        let decoder = Decoder::new(token_bytes);
        Tokenizer::new(
            Encoder::Simple(encoder),
            decoder,
            PretokType::Gpt2,
            Normalizer::None,
            PostProcessor::None,
        )
    }

    #[test]
    fn test_save_load_roundtrip() {
        let tokenizer = make_test_tokenizer();
        let mut buf = Vec::new();
        tokenizer.save(&mut buf).expect("save failed");
        let mut cursor = std::io::Cursor::new(&buf);
        let loaded = Tokenizer::load(&mut cursor).expect("load failed");
        assert_eq!(tokenizer.vocab_size(), loaded.vocab_size());
        let test_texts = ["Hello world", "abcd", "test 123", "abcdabcd"];
        for text in test_texts {
            let original_tokens = tokenizer.encode(text, false).ids;
            let loaded_tokens = loaded.encode(text, false).ids;
            assert_eq!(
                original_tokens, loaded_tokens,
                "encoding mismatch for '{}'",
                text
            );
        }
        let tokens = tokenizer.encode("Hello world", false).ids;
        let original_decoded = tokenizer.decode(&tokens);
        let loaded_decoded = loaded.decode(&tokens);
        assert_eq!(original_decoded, loaded_decoded);
    }

    #[test]
    fn test_save_load_file() {
        let tokenizer = make_test_tokenizer();
        // Per-process file name so concurrent test runs cannot clobber each other's file.
        let temp_path =
            std::env::temp_dir().join(format!("tokie_test_{}.bin", std::process::id()));
        tokenizer.to_file(&temp_path).expect("to_file failed");
        let loaded = Tokenizer::from_file(&temp_path);
        std::fs::remove_file(&temp_path).ok();
        let loaded = loaded.expect("from_file failed");
        let text = "Hello world test";
        assert_eq!(tokenizer.encode(text, false).ids, loaded.encode(text, false).ids);
    }

    #[test]
    fn test_load_invalid_magic() {
        let mut bad_data = vec![0u8; HEADER_SIZE + 100];
        bad_data[0..4].copy_from_slice(b"BADM");
        let mut cursor = std::io::Cursor::new(&bad_data);
        let result = Tokenizer::load(&mut cursor);
        assert!(matches!(result, Err(SerdeError::InvalidMagic)));
    }

    #[test]
    fn test_load_unsupported_version() {
        let mut data = Vec::new();
        data.extend_from_slice(MAGIC);
        data.extend_from_slice(&99u32.to_le_bytes());
        data.resize(HEADER_SIZE + 100, 0);
        let mut cursor = std::io::Cursor::new(&data);
        let result = Tokenizer::load(&mut cursor);
        assert!(matches!(result, Err(SerdeError::UnsupportedVersion(99))));
    }

    #[test]
    fn test_pad_token_id_roundtrip() {
        let mut tokenizer = make_test_tokenizer();
        tokenizer.set_pad_token_id(42);
        let mut buf = Vec::new();
        tokenizer.save(&mut buf).expect("save failed");
        let mut cursor = std::io::Cursor::new(&buf);
        let loaded = Tokenizer::load(&mut cursor).expect("load failed");
        assert_eq!(loaded.pad_token_id(), Some(42));
    }

    #[test]
    fn test_pad_token_id_none_roundtrip() {
        let tokenizer = make_test_tokenizer();
        assert_eq!(tokenizer.pad_token_id(), None);
        let mut buf = Vec::new();
        tokenizer.save(&mut buf).expect("save failed");
        let mut cursor = std::io::Cursor::new(&buf);
        let loaded = Tokenizer::load(&mut cursor).expect("load failed");
        assert_eq!(loaded.pad_token_id(), None);
    }

    #[test]
    fn test_simple_encoder_save_load_roundtrip() {
        let tokenizer = make_simple_test_tokenizer();
        assert_eq!(tokenizer.encoder_type(), EncoderType::Simple);
        let mut buf = Vec::new();
        tokenizer.save(&mut buf).expect("save failed");
        let mut cursor = std::io::Cursor::new(&buf);
        let loaded = Tokenizer::load(&mut cursor).expect("load failed");
        assert_eq!(loaded.encoder_type(), EncoderType::Simple);
        assert_eq!(tokenizer.vocab_size(), loaded.vocab_size());
        let test_texts = ["Hello world", "abcd", "test 123", "abcdabcd"];
        for text in test_texts {
            let original_tokens = tokenizer.encode(text, false).ids;
            let loaded_tokens = loaded.encode(text, false).ids;
            assert_eq!(
                original_tokens, loaded_tokens,
                "encoding mismatch for '{}'",
                text
            );
        }
        let tokens = tokenizer.encode("Hello world", false).ids;
        let original_decoded = tokenizer.decode(&tokens);
        let loaded_decoded = loaded.decode(&tokens);
        assert_eq!(original_decoded, loaded_decoded);
    }
}