use memmap2::{Mmap, MmapOptions};
use std::collections::HashMap;
use std::fs::File;
use std::path::Path;
use std::slice;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::traits::{TokenizedInput, Tokenizer};
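// On-disk layout written by `ZeroCopyBuilder` and mapped by `ZeroCopyTokenizer`
// (all offsets stored in the header are absolute file offsets):
//
//   [ ZeroCopyHeader | vocab entry table | token string data | metadata | special tokens ]
//
// Note that `vocab_size` in the header spans both the fixed-size entry table
// and the variable-length token string data that follows it.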
/// Fixed-size file header. `repr(C, packed)` pins the on-disk layout; read
/// fields by value (never by reference), since they may be unaligned.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy)]
pub struct ZeroCopyHeader {
pub magic: [u8; 4],
pub version: u32,
pub header_size: u32,
pub vocab_offset: u64,
pub vocab_size: u64,
pub metadata_offset: u64,
pub metadata_size: u64,
pub special_tokens_offset: u64,
pub special_tokens_size: u64,
pub checksum: u64,
pub padding: [u8; 8],
}
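// Compile-time guard on the on-disk header size. Under `repr(C, packed)` the
// fields sum to 3*4 + 7*8 + 8 = 76 bytes; adjust this check if fields change.
const _: () = assert!(std::mem::size_of::<ZeroCopyHeader>() == 76);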
impl ZeroCopyHeader {
    const MAGIC: [u8; 4] = *b"TFZC";
    const VERSION: u32 = 1;
const SIZE: usize = std::mem::size_of::<Self>();
pub fn new(
vocab_offset: u64,
vocab_size: u64,
metadata_offset: u64,
metadata_size: u64,
special_tokens_offset: u64,
special_tokens_size: u64,
checksum: u64,
) -> Self {
Self {
magic: Self::MAGIC,
version: Self::VERSION,
header_size: Self::SIZE as u32,
vocab_offset,
vocab_size,
metadata_offset,
metadata_size,
special_tokens_offset,
special_tokens_size,
checksum,
padding: [0; 8],
}
}
pub fn validate(&self) -> Result<()> {
if self.magic != Self::MAGIC {
return Err(TrustformersError::serialization_error(
"Invalid magic bytes in zero-copy header".to_string(),
));
}
        // Copy packed fields to locals before formatting; formatting them in
        // place would take references to potentially unaligned fields.
        let version = self.version;
if version != Self::VERSION {
return Err(TrustformersError::serialization_error(format!(
"Unsupported version: {}, expected: {}",
version,
Self::VERSION
)));
}
let header_size = self.header_size;
if header_size != Self::SIZE as u32 {
return Err(TrustformersError::serialization_error(
"Invalid header size".to_string(),
));
}
Ok(())
}
}
/// Fixed-size vocabulary entry. Token bytes live in the string-data region
/// and are addressed by the absolute `token_offset` and `token_length`.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy)]
pub struct ZeroCopyVocabEntry {
pub id: u32,
pub token_offset: u64,
pub token_length: u32,
pub frequency: f32,
pub flags: u32,
pub padding: [u8; 4],
}
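// Compile-time guard on the on-disk entry size. Under `repr(C, packed)` the
// fields sum to 4 + 8 + 4 + 4 + 4 + 4 = 28 bytes; adjust if fields change.
const _: () = assert!(std::mem::size_of::<ZeroCopyVocabEntry>() == 28);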
impl ZeroCopyVocabEntry {
    /// Bit 0 of `flags` marks a special token.
    pub fn is_special(&self) -> bool {
(self.flags & 0x01) != 0
}
pub fn set_special(&mut self, is_special: bool) {
if is_special {
self.flags |= 0x01;
} else {
self.flags &= !0x01;
}
}
}
/// Read-side view over a memory-mapped `.zc` file. The entry table is read in
/// place from the mapping; only the token<->id hash maps are heap-allocated.
pub struct ZeroCopyTokenizer {
    /// Keeps the file mapping alive; `vocab_entries` borrows from it.
    mmap: Mmap,
    header: ZeroCopyHeader,
    /// Entry table viewed directly over the mapping. The `'static` lifetime is
    /// a private fiction: the slice is valid only while `mmap` lives, which
    /// this struct upholds by never exposing or replacing the mapping.
    vocab_entries: &'static [ZeroCopyVocabEntry],
    token_to_id: HashMap<String, u32>,
    id_to_token: HashMap<u32, String>,
}
impl ZeroCopyTokenizer {
    /// Memory-maps a `.zc` file, validates its header and section bounds, and
    /// builds the owned lookup tables.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = File::open(path)?;
        // SAFETY: the mapping stays alive for the lifetime of the tokenizer;
        // the file must not be truncated or rewritten while mapped.
        let mmap = unsafe { MmapOptions::new().map(&file)? };
        if mmap.len() < ZeroCopyHeader::SIZE {
            return Err(TrustformersError::serialization_error(
                "File too small to contain header".to_string(),
            ));
        }
        // SAFETY: the length check guarantees `SIZE` readable bytes, and
        // `read_unaligned` imposes no alignment requirement.
        let header: ZeroCopyHeader =
            unsafe { std::ptr::read_unaligned(mmap.as_ptr() as *const ZeroCopyHeader) };
        header.validate()?;
        let vocab_start = header.vocab_offset as usize;
        let vocab_end = vocab_start + header.vocab_size as usize;
        let metadata_end = header.metadata_offset as usize + header.metadata_size as usize;
        let special_end =
            header.special_tokens_offset as usize + header.special_tokens_size as usize;
        // Validate every section up front so the accessor methods and
        // `verify_integrity` can slice the mapping without panicking.
        if vocab_end > mmap.len() || metadata_end > mmap.len() || special_end > mmap.len() {
            return Err(TrustformersError::serialization_error(
                "Section extends beyond file".to_string(),
            ));
        }
        let entry_size = std::mem::size_of::<ZeroCopyVocabEntry>();
        // `vocab_size` covers both the entry table and the token string data
        // that follows it, so dividing the section size by the entry size
        // would overcount. The first entry's token_offset marks the end of
        // the table; derive the entry count from it instead.
        let num_entries = if header.vocab_size as usize >= entry_size {
            // SAFETY: at least `entry_size` bytes are in bounds (checked
            // above), and `read_unaligned` has no alignment requirement.
            let first: ZeroCopyVocabEntry = unsafe {
                std::ptr::read_unaligned(mmap[vocab_start..].as_ptr() as *const ZeroCopyVocabEntry)
            };
            let entries_end = first.token_offset as usize;
            if entries_end < vocab_start || entries_end > vocab_end {
                return Err(TrustformersError::serialization_error(
                    "Corrupt vocabulary entry table".to_string(),
                ));
            }
            (entries_end - vocab_start) / entry_size
        } else {
            0
        };
        // SAFETY: `num_entries * entry_size` bytes lie within the mapping, and
        // the packed entry type has alignment 1, so any byte offset is valid.
        let vocab_entries = unsafe {
            slice::from_raw_parts(
                mmap[vocab_start..].as_ptr() as *const ZeroCopyVocabEntry,
                num_entries,
            )
        };
        // Build owned lookup tables eagerly; the entry table itself stays
        // zero-copy, but O(1) token<->id lookups require hashing.
        let mut token_to_id = HashMap::new();
let mut id_to_token = HashMap::new();
for entry in vocab_entries {
let token_start = entry.token_offset as usize;
let token_end = token_start + entry.token_length as usize;
if token_end > mmap.len() {
return Err(TrustformersError::serialization_error(
"Token string extends beyond file".to_string(),
));
}
let token_bytes = &mmap[token_start..token_end];
let token = String::from_utf8(token_bytes.to_vec()).map_err(|e| {
TrustformersError::serialization_error(format!("Invalid UTF-8 in token: {}", e))
})?;
token_to_id.insert(token.clone(), entry.id);
id_to_token.insert(entry.id, token);
}
Ok(Self {
mmap,
header,
vocab_entries,
token_to_id,
id_to_token,
})
}
pub fn header(&self) -> &ZeroCopyHeader {
&self.header
}
pub fn vocab_size(&self) -> usize {
self.vocab_entries.len()
}
    /// Resolves an id to its token as a slice borrowed from the mapped file.
    /// Despite the name, bounds are checked; "unchecked" refers to bypassing
    /// the owned lookup tables. Note this is a linear scan of the entry table.
    pub fn get_token_unchecked(&self, id: u32) -> Option<&str> {
        self.vocab_entries.iter().find(|entry| entry.id == id).and_then(|entry| {
let token_start = entry.token_offset as usize;
let token_end = token_start + entry.token_length as usize;
if token_end <= self.mmap.len() {
std::str::from_utf8(&self.mmap[token_start..token_end]).ok()
} else {
None
}
})
}
    /// O(1) token-to-id lookup via the owned table.
    pub fn get_id_unchecked(&self, token: &str) -> Option<u32> {
self.token_to_id.get(token).copied()
}
pub fn get_vocab_entry(&self, index: usize) -> Option<&ZeroCopyVocabEntry> {
self.vocab_entries.get(index)
}
pub fn vocab_entries(&self) -> impl Iterator<Item = &ZeroCopyVocabEntry> {
self.vocab_entries.iter()
}
    /// Raw metadata section; its bounds are validated at load time.
    pub fn metadata_bytes(&self) -> &[u8] {
let start = self.header.metadata_offset as usize;
let end = start + self.header.metadata_size as usize;
&self.mmap[start..end]
}
    /// Raw special-tokens section; its bounds are validated at load time.
    pub fn special_tokens_bytes(&self) -> &[u8] {
let start = self.header.special_tokens_offset as usize;
let end = start + self.header.special_tokens_size as usize;
&self.mmap[start..end]
}
    /// Approximate memory accounting. `lookup_table_size` is a rough estimate
    /// counting struct sizes only; string heap buffers and hash-map overhead
    /// are not included.
    pub fn memory_stats(&self) -> ZeroCopyMemoryStats {
ZeroCopyMemoryStats {
file_size: self.mmap.len(),
header_size: ZeroCopyHeader::SIZE,
vocab_size: self.header.vocab_size as usize,
metadata_size: self.header.metadata_size as usize,
special_tokens_size: self.header.special_tokens_size as usize,
lookup_table_size: self.token_to_id.len()
* (std::mem::size_of::<String>() + std::mem::size_of::<u32>())
+ self.id_to_token.len()
* (std::mem::size_of::<u32>() + std::mem::size_of::<String>()),
}
}
    /// Recomputes the CRC32 over the vocabulary, metadata, and special-token
    /// sections and compares it with the checksum stored in the header.
    pub fn verify_integrity(&self) -> Result<bool> {
let mut hasher = crc32fast::Hasher::new();
let vocab_start = self.header.vocab_offset as usize;
let vocab_end = vocab_start + self.header.vocab_size as usize;
hasher.update(&self.mmap[vocab_start..vocab_end]);
let metadata_start = self.header.metadata_offset as usize;
let metadata_end = metadata_start + self.header.metadata_size as usize;
hasher.update(&self.mmap[metadata_start..metadata_end]);
let special_start = self.header.special_tokens_offset as usize;
let special_end = special_start + self.header.special_tokens_size as usize;
hasher.update(&self.mmap[special_start..special_end]);
let calculated_checksum = hasher.finalize() as u64;
Ok(calculated_checksum == self.header.checksum)
}
}
impl Tokenizer for ZeroCopyTokenizer {
    fn encode(&self, text: &str) -> Result<TokenizedInput> {
        // Whitespace tokenization; tokens absent from the vocabulary are
        // silently dropped.
        let mut input_ids = Vec::new();
        for token in text.split_whitespace() {
            if let Some(id) = self.get_id_unchecked(token) {
                input_ids.push(id);
            }
        }
        let len = input_ids.len();
        Ok(TokenizedInput {
            input_ids,
            attention_mask: vec![1; len],
token_type_ids: None,
special_tokens_mask: None,
offset_mapping: None,
overflowing_tokens: None,
})
}
    fn decode(&self, token_ids: &[u32]) -> Result<String> {
        // O(1) lookups via the owned table; `get_token_unchecked` would scan
        // the entry table once per id.
        let tokens: std::result::Result<Vec<&str>, _> = token_ids
            .iter()
            .map(|&id| {
                self.id_to_token
                    .get(&id)
                    .map(String::as_str)
                    .ok_or_else(|| TrustformersError::other(format!("Unknown token ID: {}", id)))
            })
            .collect();
        Ok(tokens?.join(" "))
    }
fn get_vocab(&self) -> HashMap<String, u32> {
self.token_to_id.clone()
}
fn token_to_id(&self, token: &str) -> Option<u32> {
self.get_id_unchecked(token)
}
    fn id_to_token(&self, id: u32) -> Option<String> {
        self.id_to_token.get(&id).cloned()
    }
    fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<TokenizedInput> {
        // Naive pair handling: the two texts are joined with a space and no
        // separator token is inserted.
        let combined = format!("{} {}", text_a, text_b);
self.encode(&combined)
}
fn vocab_size(&self) -> usize {
self.token_to_id.len()
}
}
/// Byte-level breakdown of a loaded `.zc` file plus the owned lookup tables.
#[derive(Debug, Clone)]
pub struct ZeroCopyMemoryStats {
pub file_size: usize,
pub header_size: usize,
pub vocab_size: usize,
pub metadata_size: usize,
pub special_tokens_size: usize,
pub lookup_table_size: usize,
}
impl ZeroCopyMemoryStats {
    /// Heap memory owned by the tokenizer. The mapped file is excluded: its
    /// pages are file-backed and shared rather than private allocations.
    pub fn total_memory(&self) -> usize {
        self.lookup_table_size
    }
    /// Ratio of owned heap memory to file size; lower means more of the data
    /// is served zero-copy from the mapping.
    pub fn efficiency_ratio(&self) -> f64 {
if self.file_size == 0 {
0.0
} else {
self.total_memory() as f64 / self.file_size as f64
}
}
}
/// Accumulates vocabulary, metadata, and special-token bytes, then serializes
/// them in the layout described at the top of this file.
pub struct ZeroCopyBuilder {
    vocabulary: Vec<(String, u32, f32, bool)>,
    metadata: Vec<u8>,
    special_tokens: Vec<u8>,
}
impl ZeroCopyBuilder {
pub fn new() -> Self {
Self {
vocabulary: Vec::new(),
metadata: Vec::new(),
special_tokens: Vec::new(),
}
}
pub fn add_token(
&mut self,
token: String,
id: u32,
frequency: f32,
is_special: bool,
) -> &mut Self {
self.vocabulary.push((token, id, frequency, is_special));
self
}
    /// Adds every token from a plain vocab map with a default frequency of
    /// 1.0 and no special flag.
    pub fn add_tokens_from_map(&mut self, vocab: &HashMap<String, u32>) -> &mut Self {
for (token, &id) in vocab {
self.add_token(token.clone(), id, 1.0, false);
}
self
}
pub fn set_metadata(&mut self, metadata: Vec<u8>) -> &mut Self {
self.metadata = metadata;
self
}
pub fn set_special_tokens(&mut self, special_tokens: Vec<u8>) -> &mut Self {
self.special_tokens = special_tokens;
self
}
    /// Serializes the builder contents in the on-disk layout: header, entry
    /// table, token string data, metadata, then special tokens.
    pub fn build_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        use std::fs::OpenOptions;
        use std::io::Write;
        let mut file = OpenOptions::new().create(true).write(true).truncate(true).open(path)?;
        // Compute section offsets. The vocab section holds the fixed-size
        // entry table followed immediately by the token string data.
        let header_size = ZeroCopyHeader::SIZE as u64;
let vocab_entry_size = std::mem::size_of::<ZeroCopyVocabEntry>() as u64;
let vocab_entries_size = self.vocabulary.len() as u64 * vocab_entry_size;
let string_data_size: u64 =
self.vocabulary.iter().map(|(token, _, _, _)| token.len() as u64).sum();
let vocab_offset = header_size;
let vocab_size = vocab_entries_size + string_data_size;
let metadata_offset = vocab_offset + vocab_size;
let metadata_size = self.metadata.len() as u64;
let special_tokens_offset = metadata_offset + metadata_size;
let special_tokens_size = self.special_tokens.len() as u64;
let mut vocab_entries = Vec::new();
let mut string_data = Vec::new();
let mut current_string_offset = vocab_offset + vocab_entries_size;
for (token, id, frequency, is_special) in &self.vocabulary {
let token_bytes = token.as_bytes();
let mut entry = ZeroCopyVocabEntry {
id: *id,
token_offset: current_string_offset,
token_length: token_bytes.len() as u32,
frequency: *frequency,
flags: 0,
padding: [0; 4],
};
entry.set_special(*is_special);
vocab_entries.push(entry);
string_data.extend_from_slice(token_bytes);
current_string_offset += token_bytes.len() as u64;
}
        let mut hasher = crc32fast::Hasher::new();
        // SAFETY: `ZeroCopyVocabEntry` is `repr(C, packed)` and `Copy`, so its
        // in-memory bytes are exactly its on-disk representation.
        let entries_bytes = unsafe {
            slice::from_raw_parts(
                vocab_entries.as_ptr() as *const u8,
                vocab_entries.len() * vocab_entry_size as usize,
            )
        };
hasher.update(entries_bytes);
hasher.update(&string_data);
hasher.update(&self.metadata);
hasher.update(&self.special_tokens);
let checksum = hasher.finalize() as u64;
let header = ZeroCopyHeader::new(
vocab_offset,
vocab_size,
metadata_offset,
metadata_size,
special_tokens_offset,
special_tokens_size,
checksum,
);
        // SAFETY: the packed header has no implicit padding, so its bytes are
        // exactly the on-disk representation.
        let header_bytes = unsafe {
slice::from_raw_parts(
&header as *const ZeroCopyHeader as *const u8,
ZeroCopyHeader::SIZE,
)
};
file.write_all(header_bytes)?;
file.write_all(entries_bytes)?;
file.write_all(&string_data)?;
file.write_all(&self.metadata)?;
file.write_all(&self.special_tokens)?;
file.flush()?;
Ok(())
}
}
impl Default for ZeroCopyBuilder {
fn default() -> Self {
Self::new()
}
}
/// Stateless helpers for converting, validating, and inspecting `.zc` files.
pub struct ZeroCopyUtils;
impl ZeroCopyUtils {
pub fn convert_tokenizer_to_zero_copy<T: Tokenizer, P: AsRef<Path>>(
tokenizer: &T,
path: P,
metadata: Option<&[u8]>,
special_tokens: Option<&[u8]>,
) -> Result<()> {
let mut builder = ZeroCopyBuilder::new();
let vocab = tokenizer.get_vocab();
builder.add_tokens_from_map(&vocab);
if let Some(meta) = metadata {
builder.set_metadata(meta.to_vec());
}
if let Some(special) = special_tokens {
builder.set_special_tokens(special.to_vec());
}
builder.build_to_file(path)
}
pub fn validate_file<P: AsRef<Path>>(path: P) -> Result<bool> {
let tokenizer = ZeroCopyTokenizer::from_file(path)?;
tokenizer.verify_integrity()
}
    pub fn get_file_info<P: AsRef<Path>>(path: P) -> Result<HashMap<String, String>> {
        // Load through `ZeroCopyTokenizer` so header validation and the
        // entry-count derivation are not duplicated here. Dividing the raw
        // `vocab_size` by the entry size would overcount, since that field
        // also covers the token string data.
        let tokenizer = ZeroCopyTokenizer::from_file(path)?;
        let header = tokenizer.header();
        // Copy packed fields to locals before formatting; taking references
        // to unaligned fields is undefined behavior.
        let version = header.version;
        let metadata_size = header.metadata_size;
        let special_tokens_size = header.special_tokens_size;
        let mut info = HashMap::new();
        info.insert("format".to_string(), "ZeroCopy".to_string());
        info.insert("version".to_string(), version.to_string());
        info.insert(
            "file_size".to_string(),
            tokenizer.memory_stats().file_size.to_string(),
        );
        info.insert("vocab_size".to_string(), tokenizer.vocab_size().to_string());
        info.insert("metadata_size".to_string(), metadata_size.to_string());
        info.insert(
            "special_tokens_size".to_string(),
            special_tokens_size.to_string(),
        );
        Ok(info)
    }
pub fn compare_files<P1: AsRef<Path>, P2: AsRef<Path>>(
path1: P1,
path2: P2,
) -> Result<HashMap<String, String>> {
let info1 = Self::get_file_info(path1)?;
let info2 = Self::get_file_info(path2)?;
let mut comparison = HashMap::new();
let default_value = "0".to_string();
for key in &[
"file_size",
"vocab_size",
"metadata_size",
"special_tokens_size",
] {
let val1 = info1.get(*key).unwrap_or(&default_value);
let val2 = info2.get(*key).unwrap_or(&default_value);
comparison.insert(format!("{}_1", key), val1.clone());
comparison.insert(format!("{}_2", key), val2.clone());
comparison.insert(format!("{}_equal", key), (val1 == val2).to_string());
}
Ok(comparison)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap as TestHashMap;
use tempfile::tempdir;
#[test]
fn test_zero_copy_header() {
let header = ZeroCopyHeader::new(100, 200, 300, 50, 350, 25, 0x12345678);
assert_eq!(header.magic, ZeroCopyHeader::MAGIC);
let version = header.version;
let vocab_offset = header.vocab_offset;
let vocab_size = header.vocab_size;
let checksum = header.checksum;
assert_eq!(version, ZeroCopyHeader::VERSION);
assert_eq!(vocab_offset, 100);
assert_eq!(vocab_size, 200);
assert_eq!(checksum, 0x12345678);
assert!(header.validate().is_ok());
}
#[test]
fn test_vocab_entry_flags() {
let mut entry = ZeroCopyVocabEntry {
id: 1,
token_offset: 0,
token_length: 5,
frequency: 1.0,
flags: 0,
padding: [0; 4],
};
assert!(!entry.is_special());
entry.set_special(true);
assert!(entry.is_special());
entry.set_special(false);
assert!(!entry.is_special());
}
#[test]
fn test_zero_copy_builder() {
let temp_dir = tempdir().expect("Operation failed in test");
let file_path = temp_dir.path().join("test_tokenizer.zc");
let mut builder = ZeroCopyBuilder::new();
builder
.add_token("hello".to_string(), 1, 1.0, false)
.add_token("world".to_string(), 2, 1.0, false)
.add_token("<pad>".to_string(), 0, 1.0, true);
assert!(builder.build_to_file(&file_path).is_ok());
assert!(file_path.exists());
}
#[test]
fn test_zero_copy_tokenizer_loading() {
let temp_dir = tempdir().expect("Operation failed in test");
let file_path = temp_dir.path().join("test_load.zc");
let mut builder = ZeroCopyBuilder::new();
builder
.add_token("test".to_string(), 1, 1.0, false)
.add_token("token".to_string(), 2, 1.0, false)
.add_token("[CLS]".to_string(), 0, 1.0, true);
builder.build_to_file(&file_path).expect("Operation failed in test");
let tokenizer = ZeroCopyTokenizer::from_file(&file_path).expect("Operation failed in test");
assert_eq!(tokenizer.vocab_size(), 3);
assert_eq!(tokenizer.get_id_unchecked("test"), Some(1));
assert_eq!(tokenizer.get_id_unchecked("token"), Some(2));
assert_eq!(tokenizer.get_id_unchecked("[CLS]"), Some(0));
assert_eq!(tokenizer.get_token_unchecked(1), Some("test"));
assert_eq!(tokenizer.get_token_unchecked(2), Some("token"));
assert_eq!(tokenizer.get_token_unchecked(0), Some("[CLS]"));
}
#[test]
fn test_tokenizer_interface() {
let temp_dir = tempdir().expect("Operation failed in test");
let file_path = temp_dir.path().join("test_interface.zc");
let mut builder = ZeroCopyBuilder::new();
builder.add_token("hello".to_string(), 1, 1.0, false).add_token(
"world".to_string(),
2,
1.0,
false,
);
builder.build_to_file(&file_path).expect("Operation failed in test");
let tokenizer = ZeroCopyTokenizer::from_file(&file_path).expect("Operation failed in test");
let encoded = tokenizer.encode("hello world").expect("Encoding failed");
assert_eq!(encoded.input_ids, vec![1, 2]);
let tokens: Vec<String> = encoded
.input_ids
.iter()
.map(|&id| tokenizer.id_to_token(id).expect("Operation failed in test"))
.collect();
assert_eq!(tokens, vec!["hello", "world"]);
let decoded = tokenizer.decode(&[1, 2]).expect("Decoding failed");
assert_eq!(decoded, "hello world");
let vocab = tokenizer.get_vocab();
assert_eq!(vocab.len(), 2);
assert_eq!(vocab.get("hello"), Some(&1));
assert_eq!(vocab.get("world"), Some(&2));
}
#[test]
fn test_memory_stats() {
let temp_dir = tempdir().expect("Operation failed in test");
let file_path = temp_dir.path().join("test_stats.zc");
let mut builder = ZeroCopyBuilder::new();
builder
.add_token("test".to_string(), 1, 1.0, false)
.add_token("memory".to_string(), 2, 1.0, false)
.set_metadata(b"test metadata".to_vec())
.set_special_tokens(b"special".to_vec());
builder.build_to_file(&file_path).expect("Operation failed in test");
let tokenizer = ZeroCopyTokenizer::from_file(&file_path).expect("Operation failed in test");
let stats = tokenizer.memory_stats();
assert!(stats.file_size > 0);
assert!(stats.vocab_size > 0);
        assert_eq!(stats.metadata_size, 13);
        assert_eq!(stats.special_tokens_size, 7);
        assert!(stats.lookup_table_size > 0);
}
#[test]
fn test_integrity_verification() {
let temp_dir = tempdir().expect("Operation failed in test");
let file_path = temp_dir.path().join("test_integrity.zc");
let mut builder = ZeroCopyBuilder::new();
builder.add_token("integrity".to_string(), 1, 1.0, false).add_token(
"check".to_string(),
2,
1.0,
false,
);
builder.build_to_file(&file_path).expect("Operation failed in test");
let tokenizer = ZeroCopyTokenizer::from_file(&file_path).expect("Operation failed in test");
assert!(tokenizer.verify_integrity().expect("Operation failed in test"));
}
#[test]
fn test_utils_file_info() {
let temp_dir = tempdir().expect("Operation failed in test");
let file_path = temp_dir.path().join("test_info.zc");
let mut builder = ZeroCopyBuilder::new();
builder.add_token("info".to_string(), 1, 1.0, false).add_token(
"test".to_string(),
2,
1.0,
false,
);
builder.build_to_file(&file_path).expect("Operation failed in test");
let info = ZeroCopyUtils::get_file_info(&file_path).expect("Operation failed in test");
assert_eq!(info.get("format").expect("Key not found"), "ZeroCopy");
assert_eq!(info.get("vocab_size").expect("Key not found"), "2");
assert!(info.contains_key("file_size"));
}
#[test]
fn test_utils_validation() {
let temp_dir = tempdir().expect("Operation failed in test");
let file_path = temp_dir.path().join("test_validation.zc");
let mut builder = ZeroCopyBuilder::new();
builder.add_token("validate".to_string(), 1, 1.0, false);
builder.build_to_file(&file_path).expect("Operation failed in test");
assert!(ZeroCopyUtils::validate_file(&file_path).expect("Operation failed in test"));
}
#[test]
fn test_builder_from_map() {
let temp_dir = tempdir().expect("Operation failed in test");
let file_path = temp_dir.path().join("test_from_map.zc");
let mut vocab = TestHashMap::new();
vocab.insert("from".to_string(), 1);
vocab.insert("map".to_string(), 2);
vocab.insert("test".to_string(), 3);
let mut builder = ZeroCopyBuilder::new();
builder.add_tokens_from_map(&vocab);
builder.build_to_file(&file_path).expect("Operation failed in test");
let tokenizer = ZeroCopyTokenizer::from_file(&file_path).expect("Operation failed in test");
assert_eq!(tokenizer.vocab_size(), 3);
for (token, &expected_id) in &vocab {
assert_eq!(tokenizer.get_id_unchecked(token), Some(expected_id));
}
}
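    #[test]
    fn test_entry_count_with_long_tokens() {
        // Regression check (added alongside the loader fix): once the token
        // string data is at least as large as one vocab entry, deriving the
        // entry count as vocab_size / entry_size overcounts; the loader must
        // derive it from the first entry's token_offset instead.
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_long_tokens.zc");
        let long_a = "a".repeat(40);
        let long_b = "b".repeat(40);
        let mut builder = ZeroCopyBuilder::new();
        builder
            .add_token(long_a.clone(), 1, 1.0, false)
            .add_token(long_b.clone(), 2, 1.0, false);
        builder.build_to_file(&file_path).expect("Operation failed in test");
        let tokenizer = ZeroCopyTokenizer::from_file(&file_path).expect("Operation failed in test");
        assert_eq!(tokenizer.vocab_size(), 2);
        assert_eq!(tokenizer.get_token_unchecked(1), Some(long_a.as_str()));
        assert_eq!(tokenizer.get_token_unchecked(2), Some(long_b.as_str()));
    }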
}