use base64::{engine::general_purpose::STANDARD, Engine};
use rustc_hash::FxHashMap;
use thiserror::Error;
/// Pair of lookup tables returned by [`load_tiktoken_bpe_with_decoder`]:
/// the encoder (token bytes → rank) followed by the decoder (rank → token bytes).
pub type EncoderDecoderPair = (FxHashMap<Vec<u8>, u32>, FxHashMap<u32, Vec<u8>>);
/// Errors produced while loading a BPE vocabulary.
#[derive(Error, Debug)]
pub enum VocabError {
/// The token part of a line could not be base64-decoded.
#[error("Invalid base64 encoding: {0}")]
Base64Error(#[from] base64::DecodeError),
/// A line did not have the expected `<token> <rank>` layout; the payload
/// describes what was wrong (missing separator, non-UTF-8 rank, unparsable rank).
#[error("Invalid line format: {0}")]
ParseError(String),
/// An underlying I/O operation failed (e.g. reading the vocabulary file in
/// [`load_tiktoken_bpe_file`]).
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
}
/// Parses a tiktoken-style BPE vocabulary into a token → rank map.
///
/// `data` is processed line by line (split on `\n`). Every non-blank line has the form
/// `<base64 token> <rank>`: the bytes before the last space are decoded with the
/// `STANDARD` base64 engine, the bytes after it are parsed as a `u32` rank.
///
/// Trailing ASCII whitespace is stripped from each line before parsing, so input with
/// CRLF line endings, stray trailing spaces, or whitespace-only lines is accepted
/// instead of failing on the first such line.
///
/// If the same token appears on several lines, the rank of the last occurrence is kept.
///
/// # Errors
///
/// - [`VocabError::ParseError`] if a line has no space separator, or its rank is not
///   valid UTF-8 or does not parse as a `u32`.
/// - [`VocabError::Base64Error`] if the token part of a line is not valid base64.
pub fn load_tiktoken_bpe(data: &[u8]) -> Result<FxHashMap<Vec<u8>, u32>, VocabError> {
    let mut encoder = FxHashMap::default();
    for line in data.split(|&b| b == b'\n') {
        // Cut the line after its last non-whitespace byte (this drops the `\r` of CRLF
        // files); a line with no such byte is blank and carries no entry.
        let line = match line.iter().rposition(|b| !b.is_ascii_whitespace()) {
            Some(last) => &line[..=last],
            None => continue,
        };
        // Split on the LAST space: the rank never contains one, so everything before
        // it is the base64 token.
        let space_pos = line
            .iter()
            .rposition(|&b| b == b' ')
            .ok_or_else(|| VocabError::ParseError("Missing space separator".to_string()))?;
        let token_b64 = &line[..space_pos];
        let rank_str = &line[space_pos + 1..];
        let token = STANDARD.decode(token_b64)?;
        let rank_str = std::str::from_utf8(rank_str)
            .map_err(|_| VocabError::ParseError("Invalid UTF-8 in rank".to_string()))?;
        let rank: u32 = rank_str
            .trim()
            .parse()
            .map_err(|_| VocabError::ParseError(format!("Invalid rank: {}", rank_str)))?;
        encoder.insert(token, rank);
    }
    Ok(encoder)
}
pub fn load_tiktoken_bpe_file(path: &str) -> Result<FxHashMap<Vec<u8>, u32>, VocabError> {
let data = std::fs::read(path)?;
load_tiktoken_bpe(&data)
}
/// Parses a tiktoken-style BPE vocabulary into both an encoder (token → rank) and a
/// decoder (rank → token) map in a single pass.
///
/// `data` is processed line by line (split on `\n`). Every non-blank line has the form
/// `<base64 token> <rank>`: the bytes before the last space are decoded with the
/// `STANDARD` base64 engine, the bytes after it are parsed as a `u32` rank.
///
/// Trailing ASCII whitespace is stripped from each line before parsing, so input with
/// CRLF line endings, stray trailing spaces, or whitespace-only lines is accepted
/// instead of failing on the first such line.
///
/// Duplicate entries are treated differently by the two maps:
/// - the decoder records every rank, so each rank in the input maps back to its token
///   (a later line reusing a rank overwrites the earlier one);
/// - the encoder keeps the rank of the FIRST line on which a token appears.
///
/// # Errors
///
/// - [`VocabError::ParseError`] if a line has no space separator, or its rank is not
///   valid UTF-8 or does not parse as a `u32`.
/// - [`VocabError::Base64Error`] if the token part of a line is not valid base64.
pub fn load_tiktoken_bpe_with_decoder(data: &[u8]) -> Result<EncoderDecoderPair, VocabError> {
    let mut encoder = FxHashMap::default();
    let mut decoder = FxHashMap::default();
    for line in data.split(|&b| b == b'\n') {
        // Cut the line after its last non-whitespace byte (this drops the `\r` of CRLF
        // files); a line with no such byte is blank and carries no entry.
        let line = match line.iter().rposition(|b| !b.is_ascii_whitespace()) {
            Some(last) => &line[..=last],
            None => continue,
        };
        // Split on the LAST space: the rank never contains one, so everything before
        // it is the base64 token.
        let space_pos = line
            .iter()
            .rposition(|&b| b == b' ')
            .ok_or_else(|| VocabError::ParseError("Missing space separator".to_string()))?;
        let token_b64 = &line[..space_pos];
        let rank_str = &line[space_pos + 1..];
        let token = STANDARD.decode(token_b64)?;
        let rank_str = std::str::from_utf8(rank_str)
            .map_err(|_| VocabError::ParseError("Invalid UTF-8 in rank".to_string()))?;
        let rank: u32 = rank_str
            .trim()
            .parse()
            .map_err(|_| VocabError::ParseError(format!("Invalid rank: {}", rank_str)))?;
        // Every rank must be decodable, so the decoder is always updated.
        decoder.insert(rank, token.clone());
        // The encoder keeps the first rank seen for a given byte sequence.
        encoder.entry(token).or_insert(rank);
    }
    Ok((encoder, decoder))
}
pub fn build_decoder(encoder: &FxHashMap<Vec<u8>, u32>) -> FxHashMap<u32, Vec<u8>> {
encoder.iter().map(|(k, v)| (*v, k.clone())).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two well-formed lines produce two encoder entries.
    #[test]
    fn test_load_tiktoken_bpe() {
        let data = b"SGVsbG8= 0\nV29ybGQ= 1\n";
        let encoder = load_tiktoken_bpe(data).unwrap();
        assert_eq!(encoder.get(b"Hello".as_slice()), Some(&0));
        assert_eq!(encoder.get(b"World".as_slice()), Some(&1));
        assert_eq!(encoder.len(), 2);
    }

    /// Empty input is valid and yields an empty vocabulary.
    #[test]
    fn test_load_tiktoken_bpe_empty_input() {
        let encoder = load_tiktoken_bpe(b"").unwrap();
        assert!(encoder.is_empty());
    }

    /// A line without a space separator is rejected as a parse error.
    #[test]
    fn test_load_tiktoken_bpe_missing_separator() {
        let result = load_tiktoken_bpe(b"SGVsbG8=\n");
        assert!(matches!(result, Err(VocabError::ParseError(_))));
    }

    /// A token that is not valid base64 surfaces the decoder's error.
    #[test]
    fn test_load_tiktoken_bpe_invalid_base64() {
        let result = load_tiktoken_bpe(b"!!!! 0\n");
        assert!(matches!(result, Err(VocabError::Base64Error(_))));
    }

    /// Non-numeric and out-of-range ranks are rejected as parse errors.
    #[test]
    fn test_load_tiktoken_bpe_invalid_rank() {
        let not_a_number = load_tiktoken_bpe(b"SGVsbG8= abc\n");
        assert!(matches!(not_a_number, Err(VocabError::ParseError(_))));

        // One past u32::MAX.
        let overflow = load_tiktoken_bpe(b"SGVsbG8= 4294967296\n");
        assert!(matches!(overflow, Err(VocabError::ParseError(_))));
    }

    /// The combined loader fills both directions of the mapping.
    #[test]
    fn test_load_tiktoken_bpe_with_decoder() {
        let data = b"SGVsbG8= 0\nV29ybGQ= 1\n";
        let (encoder, decoder) = load_tiktoken_bpe_with_decoder(data).unwrap();
        assert_eq!(encoder.get(b"Hello".as_slice()), Some(&0));
        assert_eq!(encoder.get(b"World".as_slice()), Some(&1));
        assert_eq!(decoder.get(&0), Some(&b"Hello".to_vec()));
        assert_eq!(decoder.get(&1), Some(&b"World".to_vec()));
        assert_eq!(encoder.len(), 2);
        assert_eq!(decoder.len(), 2);
    }

    /// For a repeated token the encoder keeps the first rank while the decoder
    /// can still resolve every rank that appeared.
    #[test]
    fn test_load_tiktoken_bpe_with_decoder_duplicate_token() {
        let data = b"SGVsbG8= 0\nSGVsbG8= 5\n";
        let (encoder, decoder) = load_tiktoken_bpe_with_decoder(data).unwrap();
        assert_eq!(encoder.get(b"Hello".as_slice()), Some(&0));
        assert_eq!(encoder.len(), 1);
        assert_eq!(decoder.get(&0), Some(&b"Hello".to_vec()));
        assert_eq!(decoder.get(&5), Some(&b"Hello".to_vec()));
        assert_eq!(decoder.len(), 2);
    }

    /// Malformed input is rejected by the combined loader as well.
    #[test]
    fn test_load_tiktoken_bpe_with_decoder_errors() {
        let missing_sep = load_tiktoken_bpe_with_decoder(b"SGVsbG8=\n");
        assert!(matches!(missing_sep, Err(VocabError::ParseError(_))));

        let bad_base64 = load_tiktoken_bpe_with_decoder(b"!!!! 0\n");
        assert!(matches!(bad_base64, Err(VocabError::Base64Error(_))));
    }

    /// Inverting an encoder yields the rank → token table.
    #[test]
    fn test_build_decoder() {
        let mut encoder = FxHashMap::default();
        encoder.insert(b"Hello".to_vec(), 0);
        encoder.insert(b"World".to_vec(), 1);
        let decoder = build_decoder(&encoder);
        assert_eq!(decoder.get(&0), Some(&b"Hello".to_vec()));
        assert_eq!(decoder.get(&1), Some(&b"World".to_vec()));
    }
}