use std::collections::HashMap;
/// Integer id assigned to each word stored in a [`WordVocabulary`].
pub type WordId = u32;
/// Id of [`EOS_TOKEN`]: `WordVocabulary` constructors register that token first,
/// so it always receives id 0.
pub const EOS_WORD_ID: WordId = 0;
/// Id of [`UNK_TOKEN`]: `WordVocabulary` constructors register that token second,
/// so it always receives id 1. Words missing from a vocabulary encode to this id.
pub const UNK_WORD_ID: WordId = 1;
/// Reserved EOS token string (conventionally an end-of-sentence/sequence marker).
pub const EOS_TOKEN: &str = "</s>";
/// Reserved token string standing in for words that are not in the vocabulary.
pub const UNK_TOKEN: &str = "<unk>";
/// Bidirectional mapping between word strings and dense [`WordId`]s.
///
/// Every constructor pre-registers [`EOS_TOKEN`] and [`UNK_TOKEN`], so ids 0 and 1
/// are always taken. Words added afterwards receive consecutive ids in insertion
/// order.
#[derive(Clone, Debug)]
pub struct WordVocabulary {
    // word -> id; kept in sync with `id_to_word` by `add_word`.
    word_to_id: HashMap<String, WordId>,
    // id -> word; an entry's index is its id, which is why ids are dense.
    id_to_word: Vec<String>,
}
impl WordVocabulary {
pub fn new() -> Self {
let mut vocab = Self {
word_to_id: HashMap::new(),
id_to_word: Vec::new(),
};
vocab.add_word(EOS_TOKEN);
vocab.add_word(UNK_TOKEN);
vocab
}
pub fn with_capacity(capacity: usize) -> Self {
let mut vocab = Self {
word_to_id: HashMap::with_capacity(capacity + 2), id_to_word: Vec::with_capacity(capacity + 2),
};
vocab.add_word(EOS_TOKEN);
vocab.add_word(UNK_TOKEN);
vocab
}
pub fn add_word(&mut self, word: &str) -> WordId {
if let Some(&id) = self.word_to_id.get(word) {
return id;
}
let id = self.id_to_word.len() as WordId;
self.id_to_word.push(word.to_string());
self.word_to_id.insert(word.to_string(), id);
id
}
#[inline]
pub fn get_id(&self, word: &str) -> Option<WordId> {
self.word_to_id.get(word).copied()
}
#[inline]
pub fn get_id_or_unk(&self, word: &str) -> WordId {
self.word_to_id.get(word).copied().unwrap_or(UNK_WORD_ID)
}
#[inline]
pub fn get_word(&self, id: WordId) -> Option<&str> {
self.id_to_word.get(id as usize).map(|s| s.as_str())
}
#[inline]
pub fn len(&self) -> usize {
self.id_to_word.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.id_to_word.len() <= 2 }
#[inline]
pub fn contains(&self, word: &str) -> bool {
self.word_to_id.contains_key(word)
}
pub fn iter(&self) -> impl Iterator<Item = (&str, WordId)> + '_ {
self.id_to_word
.iter()
.enumerate()
.map(|(id, word)| (word.as_str(), id as WordId))
}
pub fn word_ids(&self) -> impl Iterator<Item = WordId> + '_ {
2..self.id_to_word.len() as WordId
}
#[inline]
pub const fn eos_id(&self) -> WordId {
EOS_WORD_ID
}
#[inline]
pub const fn unk_id(&self) -> WordId {
UNK_WORD_ID
}
pub fn encode(&self, words: &[&str]) -> Vec<WordId> {
words.iter().map(|w| self.get_id_or_unk(w)).collect()
}
pub fn decode(&self, ids: &[WordId]) -> Option<Vec<&str>> {
ids.iter().map(|&id| self.get_word(id)).collect()
}
}
impl Default for WordVocabulary {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_vocabulary() {
        let vocab = WordVocabulary::new();
        assert_eq!(vocab.len(), 2);
        assert_eq!(vocab.get_id(EOS_TOKEN), Some(EOS_WORD_ID));
        assert_eq!(vocab.get_id(UNK_TOKEN), Some(UNK_WORD_ID));
        assert_eq!(vocab.eos_id(), EOS_WORD_ID);
        assert_eq!(vocab.unk_id(), UNK_WORD_ID);
    }

    #[test]
    fn test_with_capacity_and_default_match_new() {
        let expected: Vec<_> = WordVocabulary::new().iter().collect();
        for capacity in [0usize, 1, 16].iter() {
            let vocab = WordVocabulary::with_capacity(*capacity);
            let pairs: Vec<_> = vocab.iter().collect();
            assert_eq!(pairs, expected);
        }
        let default_vocab = WordVocabulary::default();
        let default_pairs: Vec<_> = default_vocab.iter().collect();
        assert_eq!(default_pairs, expected);
    }

    #[test]
    fn test_add_word() {
        let mut vocab = WordVocabulary::new();
        let id1 = vocab.add_word("hello");
        let id2 = vocab.add_word("world");
        let id3 = vocab.add_word("hello");
        assert_eq!(id1, 2);
        assert_eq!(id2, 3);
        assert_eq!(id3, id1);
        assert_eq!(vocab.len(), 4);
    }

    #[test]
    fn test_is_empty_ignores_reserved_tokens() {
        let mut vocab = WordVocabulary::new();
        assert!(vocab.is_empty());
        assert_eq!(vocab.len(), 2);
        vocab.add_word("x");
        assert!(!vocab.is_empty());
    }

    #[test]
    fn test_contains() {
        let mut vocab = WordVocabulary::new();
        assert!(vocab.contains(EOS_TOKEN));
        assert!(vocab.contains(UNK_TOKEN));
        assert!(!vocab.contains("x"));
        vocab.add_word("x");
        assert!(vocab.contains("x"));
    }

    #[test]
    fn test_get_word_and_id() {
        let mut vocab = WordVocabulary::new();
        let id = vocab.add_word("test");
        assert_eq!(vocab.get_id("test"), Some(id));
        assert_eq!(vocab.get_word(id), Some("test"));
        assert_eq!(vocab.get_id("nonexistent"), None);
        assert_eq!(vocab.get_word(999), None);
    }

    #[test]
    fn test_get_id_or_unk() {
        let mut vocab = WordVocabulary::new();
        let known = vocab.add_word("known");
        assert_eq!(vocab.get_id_or_unk("known"), known);
        assert_eq!(vocab.get_id_or_unk("unknown"), UNK_WORD_ID);
    }

    #[test]
    fn test_encode_decode() {
        let mut vocab = WordVocabulary::new();
        vocab.add_word("the");
        vocab.add_word("quick");
        vocab.add_word("fox");
        let words = ["the", "quick", "fox"];
        let ids = vocab.encode(&words);
        let decoded = vocab.decode(&ids).expect("decode failed");
        assert_eq!(decoded, words);
    }

    #[test]
    fn test_encode_maps_unknown_to_unk() {
        let mut vocab = WordVocabulary::new();
        let the = vocab.add_word("the");
        assert_eq!(vocab.encode(&["the", "missing"]), vec![the, UNK_WORD_ID]);
    }

    #[test]
    fn test_decode_rejects_out_of_range_ids() {
        let vocab = WordVocabulary::new();
        assert_eq!(vocab.decode(&[EOS_WORD_ID, 999]), None);
        assert_eq!(vocab.decode(&[]), Some(Vec::new()));
        assert_eq!(vocab.decode(&[UNK_WORD_ID]), Some(vec![UNK_TOKEN]));
    }

    #[test]
    fn test_word_ids_skip_reserved() {
        let mut vocab = WordVocabulary::new();
        assert_eq!(vocab.word_ids().count(), 0);
        let a = vocab.add_word("a");
        let b = vocab.add_word("b");
        let ids: Vec<_> = vocab.word_ids().collect();
        assert_eq!(ids, vec![a, b]);
    }

    #[test]
    fn test_iter() {
        let mut vocab = WordVocabulary::new();
        vocab.add_word("a");
        vocab.add_word("b");
        let pairs: Vec<_> = vocab.iter().collect();
        assert_eq!(pairs.len(), 4);
        assert_eq!(pairs[0], (EOS_TOKEN, EOS_WORD_ID));
        assert_eq!(pairs[1], (UNK_TOKEN, UNK_WORD_ID));
        assert_eq!(pairs[2], ("a", 2));
        assert_eq!(pairs[3], ("b", 3));
    }
}