use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Token representing out-of-vocabulary words.
pub const UNK_TOKEN: &str = "<unk>";
/// Token used to pad sequences to a common length.
pub const PAD_TOKEN: &str = "<pad>";
/// Beginning-of-sequence marker token.
pub const BOS_TOKEN: &str = "<bos>";
/// End-of-sequence marker token.
pub const EOS_TOKEN: &str = "<eos>";
/// Mask token, exposed for callers; note it is NOT among the tokens
/// inserted by `Vocab::with_special_tokens`.
pub const MASK_TOKEN: &str = "<mask>";
/// A bidirectional mapping between string tokens and dense `usize`
/// indices, with optional bookkeeping for unk/pad/bos/eos special tokens.
/// Serializable with serde for persistence (see `save`/`load`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Vocab {
// Forward map: token string -> its index in `idx_to_token`.
token_to_idx: HashMap<String, usize>,
// Reverse map: index -> token; insertion order defines the indices.
idx_to_token: Vec<String>,
// Token substituted for out-of-vocabulary lookups, when configured.
unk_token: Option<String>,
// Padding token, when configured.
pad_token: Option<String>,
// Beginning-of-sequence token, when configured.
bos_token: Option<String>,
// End-of-sequence token, when configured.
eos_token: Option<String>,
}
impl Vocab {
    /// Creates an empty vocabulary with no special tokens configured.
    #[must_use]
    pub fn new() -> Self {
        Self {
            token_to_idx: HashMap::new(),
            idx_to_token: Vec::new(),
            unk_token: None,
            pad_token: None,
            bos_token: None,
            eos_token: None,
        }
    }

    /// Creates a vocabulary seeded with `<pad>`, `<unk>`, `<bos>` and
    /// `<eos>` (indices 0..=3, in that order), registering each as the
    /// corresponding special token. `MASK_TOKEN` is not added here.
    #[must_use]
    pub fn with_special_tokens() -> Self {
        let mut vocab = Self::new();
        vocab.add_special_tokens(&[PAD_TOKEN, UNK_TOKEN, BOS_TOKEN, EOS_TOKEN]);
        vocab.unk_token = Some(UNK_TOKEN.to_string());
        vocab.pad_token = Some(PAD_TOKEN.to_string());
        vocab.bos_token = Some(BOS_TOKEN.to_string());
        vocab.eos_token = Some(EOS_TOKEN.to_string());
        vocab
    }

    /// Builds a vocabulary from `tokens`, guaranteeing `<unk>` and `<pad>`
    /// exist afterwards: when absent from `tokens` they are prepended (unk
    /// first, then pad); when present they keep the position `tokens`
    /// gives them. Either way they are registered as the unk/pad special
    /// tokens. Duplicate tokens keep their first index.
    #[must_use]
    pub fn from_tokens(tokens: &[&str]) -> Self {
        let mut vocab = Self::new();
        if !tokens.contains(&UNK_TOKEN) {
            vocab.add_token(UNK_TOKEN);
            vocab.unk_token = Some(UNK_TOKEN.to_string());
        }
        if !tokens.contains(&PAD_TOKEN) {
            vocab.add_token(PAD_TOKEN);
            vocab.pad_token = Some(PAD_TOKEN.to_string());
        }
        for token in tokens {
            vocab.add_token(token);
            if *token == UNK_TOKEN {
                vocab.unk_token = Some(UNK_TOKEN.to_string());
            }
            if *token == PAD_TOKEN {
                vocab.pad_token = Some(PAD_TOKEN.to_string());
            }
        }
        vocab
    }

    /// Builds a vocabulary from whitespace-separated `text`, keeping words
    /// whose count is at least `min_freq` (a `min_freq` of 0 behaves like
    /// 1). Words are indexed after the four standard special tokens,
    /// ordered by descending frequency with ties broken alphabetically so
    /// the result is deterministic.
    #[must_use]
    pub fn from_text(text: &str, min_freq: usize) -> Self {
        let mut freq: HashMap<String, usize> = HashMap::new();
        for word in text.split_whitespace() {
            *freq.entry(word.to_string()).or_insert(0) += 1;
        }
        let mut vocab = Self::with_special_tokens();
        let mut tokens: Vec<_> = freq
            .into_iter()
            .filter(|&(_, count)| count >= min_freq)
            .collect();
        // Frequency descending, then token ascending, for a stable order.
        tokens.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        for (token, _) in tokens {
            vocab.add_token(&token);
        }
        vocab
    }

    /// Inserts `token` if new and returns its index; an existing token's
    /// original index is returned unchanged.
    pub fn add_token(&mut self, token: &str) -> usize {
        if let Some(&idx) = self.token_to_idx.get(token) {
            return idx;
        }
        let idx = self.idx_to_token.len();
        self.token_to_idx.insert(token.to_string(), idx);
        self.idx_to_token.push(token.to_string());
        idx
    }

    /// Adds every token in `tokens` (no-op for tokens already present).
    /// This only inserts the strings; it does not register them as
    /// unk/pad/bos/eos — use the `set_*_token` methods for that.
    pub fn add_special_tokens(&mut self, tokens: &[&str]) {
        for token in tokens {
            self.add_token(token);
        }
    }

    /// Looks up `token`, falling back to the unknown token's index for
    /// out-of-vocabulary input.
    ///
    /// NOTE: when no unknown token is configured (or it is itself absent),
    /// 0 is returned as a sentinel even though index 0 may belong to an
    /// unrelated token or not exist at all; callers that must distinguish
    /// should check `contains` first.
    #[must_use]
    pub fn token_to_index(&self, token: &str) -> usize {
        self.token_to_idx
            .get(token)
            .copied()
            .or_else(|| self.unk_index())
            .unwrap_or(0)
    }

    /// Returns the token at `idx`, or `None` when out of range.
    #[must_use]
    pub fn index_to_token(&self, idx: usize) -> Option<&str> {
        self.idx_to_token.get(idx).map(String::as_str)
    }

    /// Number of tokens in the vocabulary (special tokens included).
    #[must_use]
    pub fn len(&self) -> usize {
        self.idx_to_token.len()
    }

    /// Returns `true` when the vocabulary holds no tokens.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.idx_to_token.is_empty()
    }

    /// Returns `true` when `token` is present (no unk fallback applied).
    #[must_use]
    pub fn contains(&self, token: &str) -> bool {
        self.token_to_idx.contains_key(token)
    }

    // Shared lookup for the configured special tokens.
    fn special_index(&self, token: Option<&str>) -> Option<usize> {
        token.and_then(|t| self.token_to_idx.get(t).copied())
    }

    /// Index of the unknown token, when configured and present.
    #[must_use]
    pub fn unk_index(&self) -> Option<usize> {
        self.special_index(self.unk_token.as_deref())
    }

    /// Index of the padding token, when configured and present.
    #[must_use]
    pub fn pad_index(&self) -> Option<usize> {
        self.special_index(self.pad_token.as_deref())
    }

    /// Index of the beginning-of-sequence token, when configured and present.
    #[must_use]
    pub fn bos_index(&self) -> Option<usize> {
        self.special_index(self.bos_token.as_deref())
    }

    /// Index of the end-of-sequence token, when configured and present.
    #[must_use]
    pub fn eos_index(&self) -> Option<usize> {
        self.special_index(self.eos_token.as_deref())
    }

    /// Maps each token to an index; unknown tokens use the unk fallback of
    /// `token_to_index`.
    #[must_use]
    pub fn encode(&self, tokens: &[&str]) -> Vec<usize> {
        tokens.iter().map(|t| self.token_to_index(t)).collect()
    }

    /// Maps indices back to owned tokens, silently skipping out-of-range
    /// indices (so the output may be shorter than the input).
    #[must_use]
    pub fn decode(&self, indices: &[usize]) -> Vec<String> {
        indices
            .iter()
            .filter_map(|&idx| self.index_to_token(idx).map(ToString::to_string))
            .collect()
    }

    /// Registers `token` (inserting it if absent) as the unknown token.
    pub fn set_unk_token(&mut self, token: &str) {
        self.add_token(token);
        self.unk_token = Some(token.to_string());
    }

    /// Registers `token` (inserting it if absent) as the padding token.
    pub fn set_pad_token(&mut self, token: &str) {
        self.add_token(token);
        self.pad_token = Some(token.to_string());
    }

    /// Registers `token` (inserting it if absent) as the
    /// beginning-of-sequence token. Added for parity with
    /// `set_unk_token`/`set_pad_token`.
    pub fn set_bos_token(&mut self, token: &str) {
        self.add_token(token);
        self.bos_token = Some(token.to_string());
    }

    /// Registers `token` (inserting it if absent) as the
    /// end-of-sequence token. Added for parity with
    /// `set_unk_token`/`set_pad_token`.
    pub fn set_eos_token(&mut self, token: &str) {
        self.add_token(token);
        self.eos_token = Some(token.to_string());
    }

    /// All tokens in index order.
    #[must_use]
    pub fn tokens(&self) -> &[String] {
        &self.idx_to_token
    }
}
impl Vocab {
    /// Serializes the vocabulary as pretty-printed JSON and writes it to
    /// `path`. Serialization failures surface as `InvalidData` I/O errors.
    pub fn save(&self, path: &std::path::Path) -> std::io::Result<()> {
        match serde_json::to_string_pretty(self) {
            Ok(json) => std::fs::write(path, json),
            Err(e) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e)),
        }
    }

    /// Reads a JSON-serialized vocabulary back from `path`. Parse failures
    /// surface as `InvalidData` I/O errors.
    pub fn load(path: &std::path::Path) -> std::io::Result<Self> {
        let json = std::fs::read_to_string(path)?;
        match serde_json::from_str(&json) {
            Ok(vocab) => Ok(vocab),
            Err(e) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e)),
        }
    }
}
impl Default for Vocab {
/// Equivalent to [`Vocab::new`]: an empty vocabulary with no special
/// tokens configured.
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vocab_new() {
        let v = Vocab::new();
        assert!(v.is_empty());
        assert_eq!(v.len(), 0);
    }

    #[test]
    fn test_vocab_add_token() {
        let mut v = Vocab::new();
        let first = v.add_token("hello");
        let second = v.add_token("world");
        // Re-adding an existing token must return its original index.
        let again = v.add_token("hello");
        assert_eq!(first, 0);
        assert_eq!(second, 1);
        assert_eq!(again, 0);
        assert_eq!(v.len(), 2);
    }

    #[test]
    fn test_vocab_token_to_index() {
        let mut v = Vocab::new();
        v.add_token("hello");
        v.add_token("world");
        assert_eq!(v.token_to_index("hello"), 0);
        assert_eq!(v.token_to_index("world"), 1);
    }

    #[test]
    fn test_vocab_index_to_token() {
        let mut v = Vocab::new();
        v.add_token("hello");
        v.add_token("world");
        assert_eq!(v.index_to_token(0), Some("hello"));
        assert_eq!(v.index_to_token(1), Some("world"));
        // Out-of-range lookups yield None rather than panicking.
        assert_eq!(v.index_to_token(2), None);
    }

    #[test]
    fn test_vocab_with_special_tokens() {
        let v = Vocab::with_special_tokens();
        for tok in &[PAD_TOKEN, UNK_TOKEN, BOS_TOKEN, EOS_TOKEN] {
            assert!(v.contains(tok));
        }
        assert!(v.pad_index().is_some());
        assert!(v.unk_index().is_some());
        assert!(v.bos_index().is_some());
        assert!(v.eos_index().is_some());
    }

    #[test]
    fn test_vocab_unknown_token() {
        let v = Vocab::with_special_tokens();
        let unk = v.unk_index().unwrap();
        // Missing tokens resolve to the unk index.
        assert_eq!(v.token_to_index("nonexistent"), unk);
    }

    #[test]
    fn test_vocab_encode_decode() {
        let mut v = Vocab::with_special_tokens();
        v.add_token("hello");
        v.add_token("world");
        let ids = v.encode(&["hello", "world", "hello"]);
        assert_eq!(v.decode(&ids), vec!["hello", "world", "hello"]);
    }

    #[test]
    fn test_vocab_from_tokens() {
        let v = Vocab::from_tokens(&["a", "b", "c"]);
        // Three tokens plus auto-inserted <unk> and <pad>.
        assert_eq!(v.len(), 5);
        assert!(v.unk_index().is_some());
        assert!(v.pad_index().is_some());
        assert!(v.contains("a"));
        assert!(v.contains("b"));
        assert!(v.contains("c"));
        assert_eq!(v.token_to_index("unknown"), v.unk_index().unwrap());
    }

    #[test]
    fn test_vocab_from_text() {
        let v = Vocab::from_text("the quick brown fox jumps over the lazy dog the", 1);
        assert!(v.contains("the"));
        assert!(v.contains("quick"));
        assert!(v.contains("fox"));
    }

    #[test]
    fn test_vocab_from_text_min_freq() {
        let v = Vocab::from_text("the the the quick quick brown", 2);
        assert!(v.contains("the"));
        assert!(v.contains("quick"));
        // "brown" appears once, below the threshold of 2.
        assert!(!v.contains("brown"));
    }

    #[test]
    fn test_vocab_contains() {
        let mut v = Vocab::new();
        v.add_token("hello");
        assert!(v.contains("hello"));
        assert!(!v.contains("world"));
    }
}