use std::collections::HashMap;
pub enum VocabSource {
JsonFile(String),
Bytes(Vec<u8>),
BuiltIn(BuiltInVocab),
}
pub enum BuiltInVocab {
Cl100kBase,
P50kBase,
Llama,
}
pub struct Vocabulary {
encoder: HashMap<Vec<u8>, u32>,
decoder: HashMap<u32, Vec<u8>>,
size: usize,
}
impl Vocabulary {
pub fn load(source: VocabSource) -> Result<Self, VocabError> {
match source {
VocabSource::JsonFile(path) => Self::from_json_file(&path),
VocabSource::Bytes(bytes) => Self::from_bytes(&bytes),
VocabSource::BuiltIn(builtin) => Self::built_in(builtin),
}
}
fn built_in(vocab: BuiltInVocab) -> Result<Self, VocabError> {
let mut encoder = HashMap::new();
let mut decoder = HashMap::new();
for i in 0..=255u8 {
encoder.insert(vec![i], i as u32);
decoder.insert(i as u32, vec![i]);
}
let avg_chars_per_token = match vocab {
BuiltInVocab::Cl100kBase => 4.0, BuiltInVocab::P50kBase => 3.5, BuiltInVocab::Llama => 3.8, };
let ratio_bytes = (avg_chars_per_token * 100.0) as u32;
encoder.insert(b"__ratio__".to_vec(), ratio_bytes);
Ok(Self {
size: encoder.len(),
encoder,
decoder,
})
}
fn from_json_file(path: &str) -> Result<Self, VocabError> {
let content = std::fs::read_to_string(path)
.map_err(|e| VocabError::LoadFailed(e.to_string()))?;
let map: HashMap<String, u32> = serde_json::from_str(&content)
.map_err(|e| VocabError::ParseFailed(e.to_string()))?;
let mut encoder = HashMap::new();
let mut decoder = HashMap::new();
for (token, rank) in &map {
let bytes = token.as_bytes().to_vec();
encoder.insert(bytes.clone(), *rank);
decoder.insert(*rank, bytes);
}
Ok(Self { size: encoder.len(), encoder, decoder })
}
fn from_bytes(data: &[u8]) -> Result<Self, VocabError> {
let map: HashMap<String, u32> = serde_json::from_slice(data)
.map_err(|e| VocabError::ParseFailed(e.to_string()))?;
let mut encoder = HashMap::new();
let mut decoder = HashMap::new();
for (token, rank) in &map {
let bytes = token.as_bytes().to_vec();
encoder.insert(bytes.clone(), *rank);
decoder.insert(*rank, bytes);
}
Ok(Self { size: encoder.len(), encoder, decoder })
}
pub fn encode_token(&self, bytes: &[u8]) -> Option<u32> {
self.encoder.get(bytes).copied()
}
pub fn decode_token(&self, id: u32) -> Option<&[u8]> {
self.decoder.get(&id).map(|v| v.as_slice())
}
pub fn size(&self) -> usize { self.size }
pub fn chars_per_token_ratio(&self) -> f64 {
self.encoder.get(b"__ratio__".as_slice())
.map(|r| *r as f64 / 100.0)
.unwrap_or(4.0)
}
}
#[derive(Debug)]
pub enum VocabError {
LoadFailed(String),
ParseFailed(String),
}
impl std::fmt::Display for VocabError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::LoadFailed(e) => write!(f, "vocab load failed: {}", e),
Self::ParseFailed(e) => write!(f, "vocab parse failed: {}", e),
}
}
}
impl std::error::Error for VocabError {}