use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::error::{TokenizerError, TokenizerResult};
/// Bidirectional token <-> id mapping for a tokenizer, with a registry of
/// special tokens.
///
/// Intended invariants (maintained by `insert` / `add_special`):
/// - `token_to_id` and `id_to_token` mirror each other.
/// - every entry in `special_tokens` is also present in the two main maps
///   (`add_special` inserts into both).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Vocabulary {
    // Forward map: token string -> numeric id.
    token_to_id: HashMap<String, u32>,
    // Reverse map: numeric id -> token string.
    id_to_token: HashMap<u32, String>,
    // Tokens with reserved meaning (e.g. "<bos>"), keyed by token text.
    special_tokens: HashMap<String, u32>,
}
impl Vocabulary {
    /// Creates an empty vocabulary.
    pub fn new() -> Self {
        Self::default()
    }

    /// Maps `token` to `id` in both directions.
    ///
    /// Bug fix: previously, re-inserting a token under a new id (or reusing
    /// an id for a different token) left a stale entry in the opposite map,
    /// so `get_token(old_id)` kept resolving after a remap and the two maps
    /// could disagree. Stale reverse entries are now removed first, keeping
    /// `token_to_id` and `id_to_token` mutual inverses.
    ///
    /// NOTE(review): `special_tokens` is not pruned here; a special token
    /// remapped via plain `insert` keeps its old special-id entry.
    pub fn insert(&mut self, token: &str, id: u32) {
        // Token previously mapped to a different id: drop the old id -> token
        // entry so the reverse map stays consistent.
        if let Some(&old_id) = self.token_to_id.get(token) {
            if old_id != id {
                self.id_to_token.remove(&old_id);
            }
        }
        // Id previously mapped to a different token: drop the old
        // token -> id entry. (Disjoint field borrows: reading `id_to_token`
        // while mutating `token_to_id` is fine.)
        if let Some(old_token) = self.id_to_token.get(&id) {
            if old_token != token {
                self.token_to_id.remove(old_token);
            }
        }
        self.token_to_id.insert(token.to_owned(), id);
        self.id_to_token.insert(id, token.to_owned());
    }

    /// Registers `token` as a special token and inserts it into the main maps.
    pub fn add_special(&mut self, token: &str, id: u32) {
        self.special_tokens.insert(token.to_owned(), id);
        self.insert(token, id);
    }

    /// Looks up the id for `token`, if present.
    pub fn get_id(&self, token: &str) -> Option<u32> {
        self.token_to_id.get(token).copied()
    }

    /// Looks up the token text for `id`, if present.
    pub fn get_token(&self, id: u32) -> Option<&str> {
        self.id_to_token.get(&id).map(|s| s.as_str())
    }

    /// Number of distinct tokens (special tokens included).
    pub fn size(&self) -> usize {
        self.token_to_id.len()
    }

    /// Returns `true` if the vocabulary holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.token_to_id.is_empty()
    }

    /// Returns `true` if `token` is registered as a special token.
    pub fn is_special_token(&self, token: &str) -> bool {
        self.special_tokens.contains_key(token)
    }

    /// Returns `true` if `id` belongs to any special token.
    pub fn is_special_id(&self, id: u32) -> bool {
        self.special_tokens.values().any(|&v| v == id)
    }

    /// Iterates over all `(token, id)` pairs in arbitrary (HashMap) order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, u32)> {
        self.token_to_id.iter().map(|(k, &v)| (k.as_str(), v))
    }

    /// Parses a vocabulary from a flat JSON object (`{"token": id, ...}`).
    ///
    /// Heuristic: tokens wrapped in angle brackets (`<...>`) are registered
    /// as special tokens; everything else is a regular token.
    ///
    /// # Errors
    /// - `TokenizerError::InvalidJson` if `json` is not a valid JSON map.
    /// - `TokenizerError::InvalidVocab` if the map is empty.
    pub fn from_json(json: &str) -> TokenizerResult<Self> {
        let raw: HashMap<String, u32> =
            serde_json::from_str(json).map_err(|e| TokenizerError::InvalidJson(e.to_string()))?;
        if raw.is_empty() {
            return Err(TokenizerError::InvalidVocab(
                "vocabulary JSON must not be empty".into(),
            ));
        }
        let mut vocab = Self::new();
        for (token, id) in raw {
            if token.starts_with('<') && token.ends_with('>') {
                vocab.add_special(&token, id);
            } else {
                vocab.insert(&token, id);
            }
        }
        Ok(vocab)
    }

    /// Serializes the vocabulary (special tokens included) to a flat JSON
    /// object with keys in sorted order for deterministic output.
    ///
    /// Bug fix: previously only `"`, `\`, `\n`, `\r`, `\t` were escaped;
    /// any other control character (U+0000..=U+001F, e.g. backspace or form
    /// feed) was emitted raw, which is invalid JSON per RFC 8259 and would
    /// make `from_json` reject the output. All control characters are now
    /// escaped (`\b`, `\f`, or `\u00XX`).
    pub fn to_json(&self) -> String {
        let mut entries: Vec<(&str, u32)> = self.iter().collect();
        // Stability is irrelevant for unique keys; unstable sort avoids an
        // allocation.
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
        // Rough preallocation: quotes/colon/comma plus short token and id.
        let mut out = String::with_capacity(2 + entries.len() * 8);
        out.push('{');
        for (i, (token, id)) in entries.iter().enumerate() {
            if i > 0 {
                out.push(',');
            }
            out.push('"');
            for ch in token.chars() {
                match ch {
                    '"' => out.push_str("\\\""),
                    '\\' => out.push_str("\\\\"),
                    '\n' => out.push_str("\\n"),
                    '\r' => out.push_str("\\r"),
                    '\t' => out.push_str("\\t"),
                    '\u{08}' => out.push_str("\\b"),
                    '\u{0C}' => out.push_str("\\f"),
                    // RFC 8259: remaining control chars need \u escapes.
                    c if (c as u32) < 0x20 => {
                        out.push_str(&format!("\\u{:04x}", c as u32));
                    }
                    c => out.push(c),
                }
            }
            out.push_str("\":");
            out.push_str(&id.to_string());
        }
        out.push('}');
        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a couple of plain tokens through insert/lookup.
    #[test]
    fn insert_and_lookup() {
        let mut vocab = Vocabulary::new();
        for (tok, id) in [("hello", 1), ("world", 2)] {
            vocab.insert(tok, id);
        }

        assert_eq!(vocab.get_id("hello"), Some(1));
        assert_eq!(vocab.get_id("world"), Some(2));
        assert_eq!(vocab.get_token(1), Some("hello"));
        assert_eq!(vocab.get_token(99), None);
    }

    /// A special token must be visible both as special and as a regular entry.
    #[test]
    fn special_tokens_are_found_in_main_maps() {
        let mut vocab = Vocabulary::new();
        vocab.add_special("<bos>", 0);

        assert!(vocab.is_special_token("<bos>"));
        assert!(vocab.is_special_id(0));
        assert_eq!(vocab.get_id("<bos>"), Some(0));
    }

    /// to_json -> from_json must preserve every entry.
    #[test]
    fn json_roundtrip() {
        let mut original = Vocabulary::new();
        original.add_special("<unk>", 0);
        original.insert("a", 3);
        original.insert("b", 4);

        let reparsed =
            Vocabulary::from_json(&original.to_json()).expect("parse should succeed");

        for (tok, id) in [("a", 3), ("b", 4), ("<unk>", 0)] {
            assert_eq!(reparsed.get_id(tok), Some(id));
        }
    }

    /// An empty JSON object is rejected as an invalid vocabulary.
    #[test]
    fn empty_json_fails() {
        assert!(Vocabulary::from_json("{}").is_err());
    }

    /// Non-JSON input is rejected.
    #[test]
    fn invalid_json_fails() {
        assert!(Vocabulary::from_json("not json").is_err());
    }
}