use std::collections::HashMap;
use crate::error::{LmError, LmResult};
/// Token vocabulary: a two-way mapping between `u32` token ids and token byte
/// strings, plus a table of named special tokens.
#[derive(Debug, Clone)]
pub struct Vocab {
/// Token bytes, indexed by token id.
id_to_bytes: Vec<Vec<u8>>,
/// Reverse lookup from token bytes to token id.
bytes_to_id: HashMap<Vec<u8>, u32>,
/// Special-token name mapped to its token id.
special_tokens: HashMap<String, u32>,
}
impl Vocab {
pub fn from_tokens(
tokens: Vec<Vec<u8>>,
special_tokens: HashMap<String, u32>,
) -> LmResult<Self> {
let mut bytes_to_id = HashMap::with_capacity(tokens.len());
for (id, bytes) in tokens.iter().enumerate() {
if bytes_to_id.insert(bytes.clone(), id as u32).is_some() {
return Err(LmError::InvalidConfig {
msg: format!("duplicate token bytes at id {id}"),
});
}
}
for (name, &id) in &special_tokens {
if id as usize >= tokens.len() {
return Err(LmError::OutOfVocab { token: id });
}
let _ = name; }
Ok(Self {
id_to_bytes: tokens,
bytes_to_id,
special_tokens,
})
}
pub fn gpt2_byte_vocab() -> Self {
let tokens: Vec<Vec<u8>> = (0u8..=255).map(|b| vec![b]).collect();
let mut bytes_to_id = HashMap::with_capacity(256);
for (id, bytes) in tokens.iter().enumerate() {
bytes_to_id.insert(bytes.clone(), id as u32);
}
Self {
id_to_bytes: tokens,
bytes_to_id,
special_tokens: HashMap::new(),
}
}
pub fn with_extra_tokens(
&self,
extra: Vec<Vec<u8>>,
extra_special: HashMap<String, u32>,
) -> LmResult<Self> {
let mut tokens = self.id_to_bytes.clone();
for bytes in extra {
tokens.push(bytes);
}
let mut special = self.special_tokens.clone();
special.extend(extra_special);
Self::from_tokens(tokens, special)
}
pub fn size(&self) -> usize {
self.id_to_bytes.len()
}
pub fn bytes_to_id(&self, bytes: &[u8]) -> Option<u32> {
self.bytes_to_id.get(bytes).copied()
}
pub fn id_to_bytes(&self, id: u32) -> LmResult<&[u8]> {
self.id_to_bytes
.get(id as usize)
.map(|v| v.as_slice())
.ok_or(LmError::OutOfVocab { token: id })
}
pub fn decode_token(&self, id: u32) -> LmResult<String> {
let bytes = self.id_to_bytes(id)?;
String::from_utf8(bytes.to_vec()).map_err(|_| LmError::Utf8Decode { token: id })
}
pub fn special_id(&self, name: &str) -> Option<u32> {
self.special_tokens.get(name).copied()
}
pub fn special_tokens(&self) -> impl Iterator<Item = (&str, u32)> {
self.special_tokens.iter().map(|(k, &v)| (k.as_str(), v))
}
}
/// Unit tests for `Vocab` construction, lookup, decoding and extension.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned name-to-id map from borrowed `(name, id)` pairs.
    fn specials(pairs: &[(&str, u32)]) -> HashMap<String, u32> {
        pairs
            .iter()
            .map(|&(name, id)| (name.to_string(), id))
            .collect()
    }

    /// Five-token fixture (`a`, `b`, `c`, `ab`, `abc`) with `<eos>` mapped to id 2.
    fn small_vocab() -> Vocab {
        let tokens: Vec<Vec<u8>> = ["a", "b", "c", "ab", "abc"]
            .iter()
            .map(|text| text.as_bytes().to_vec())
            .collect();
        Vocab::from_tokens(tokens, specials(&[("<eos>", 2)])).unwrap()
    }

    #[test]
    fn vocab_size() {
        let vocab = small_vocab();
        assert_eq!(vocab.size(), 5);
    }

    #[test]
    fn bytes_to_id_found() {
        let vocab = small_vocab();
        let single = vocab.bytes_to_id(b"a");
        let pair = vocab.bytes_to_id(b"ab");
        assert_eq!(single, Some(0));
        assert_eq!(pair, Some(3));
    }

    #[test]
    fn bytes_to_id_not_found() {
        let vocab = small_vocab();
        assert!(vocab.bytes_to_id(b"d").is_none());
    }

    #[test]
    fn id_to_bytes_ok() {
        let vocab = small_vocab();
        let first = vocab.id_to_bytes(0).unwrap();
        let last = vocab.id_to_bytes(4).unwrap();
        assert_eq!(first, b"a");
        assert_eq!(last, b"abc");
    }

    #[test]
    fn id_to_bytes_out_of_range() {
        let vocab = small_vocab();
        let err = vocab.id_to_bytes(99).unwrap_err();
        assert!(matches!(err, LmError::OutOfVocab { token: 99 }));
    }

    #[test]
    fn decode_token_ascii() {
        let vocab = small_vocab();
        let single = vocab.decode_token(0).unwrap();
        let pair = vocab.decode_token(3).unwrap();
        assert_eq!(single, "a");
        assert_eq!(pair, "ab");
    }

    #[test]
    fn decode_token_utf8_error() {
        // 0xFF can never appear in well-formed UTF-8.
        let vocab = Vocab::from_tokens(vec![vec![0xFF_u8]], HashMap::new()).unwrap();
        let err = vocab.decode_token(0).unwrap_err();
        assert!(matches!(err, LmError::Utf8Decode { token: 0 }));
    }

    #[test]
    fn special_id_lookup() {
        let vocab = small_vocab();
        assert_eq!(vocab.special_id("<eos>"), Some(2));
        assert!(vocab.special_id("<unk>").is_none());
    }

    #[test]
    fn gpt2_byte_vocab_size() {
        let vocab = Vocab::gpt2_byte_vocab();
        assert_eq!(vocab.size(), 256);
        assert_eq!(vocab.bytes_to_id(&[65_u8]), Some(65));
    }

    #[test]
    fn with_extra_tokens_appends() {
        let base = Vocab::gpt2_byte_vocab();
        let extended = base
            .with_extra_tokens(vec![b"ab".to_vec()], specials(&[("<eos>", 256)]))
            .unwrap();
        assert_eq!(extended.size(), 257);
        assert_eq!(extended.bytes_to_id(b"ab"), Some(256));
        assert_eq!(extended.special_id("<eos>"), Some(256));
    }

    #[test]
    fn duplicate_token_errors() {
        let tokens = vec![b"a".to_vec(); 2];
        let err = Vocab::from_tokens(tokens, HashMap::new()).unwrap_err();
        assert!(matches!(err, LmError::InvalidConfig { .. }));
    }

    #[test]
    fn special_out_of_range_errors() {
        let tokens = vec![b"a".to_vec()];
        let err = Vocab::from_tokens(tokens, specials(&[("<eos>", 99)])).unwrap_err();
        assert!(matches!(err, LmError::OutOfVocab { token: 99 }));
    }
}