#[allow(clippy::wildcard_imports)]
use super::*;
use crate::text::Tokenizer;
use crate::AprenderError;
use std::collections::HashMap;
/// Maximum token length in characters, shared by training (`train`) and
/// encoding (`encode`) so both sides consider the same candidate substrings.
const MAX_TOKEN_LEN: usize = 16;

impl UnigramTokenizer {
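    /// Trains a unigram tokenizer on `corpus`: every character n-gram of
    /// up to `MAX_TOKEN_LEN` characters is counted, the `vocab_size` most
    /// frequent candidates are kept, and each surviving token is scored
    /// by its log relative frequency.
    ///
    /// # Errors
    ///
    /// Returns [`AprenderError::InvalidHyperparameter`] if `vocab_size`
    /// is less than 10.
    ///
    /// # Example
    ///
    /// A minimal sketch; the import path assumes `UnigramTokenizer` is
    /// re-exported from `crate::text`, and the vocabulary size is chosen
    /// large enough to retain every n-gram of this tiny corpus.
    ///
    /// ```ignore
    /// use aprender::text::UnigramTokenizer;
    ///
    /// let tok = UnigramTokenizer::train(&["the cat sat", "the dog sat"], 256).unwrap();
    /// let ids = tok.encode("the cat").unwrap();
    /// assert_eq!(tok.decode(&ids).unwrap(), "the cat");
    /// ```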
pub fn train(corpus: &[&str], vocab_size: usize) -> Result<Self, AprenderError> {
if vocab_size < 10 {
return Err(AprenderError::InvalidHyperparameter {
param: "vocab_size".to_string(),
value: vocab_size.to_string(),
constraint: "must be at least 10".to_string(),
});
}
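        // Seed the special tokens with an artificially large count so they
        // sort to the top and survive the frequency truncation below.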
let unk_token = "<unk>".to_string();
let bos_token = "<s>".to_string();
let eos_token = "</s>".to_string();
let mut token_counts: HashMap<String, usize> = HashMap::new();
token_counts.insert(unk_token.clone(), 1_000_000);
token_counts.insert(bos_token.clone(), 1_000_000);
token_counts.insert(eos_token.clone(), 1_000_000);
        // SentencePiece-style word-boundary marker; every word is prefixed
        // with it so boundaries survive tokenization.
        let word_boundary = "▁".to_string();
        token_counts.insert(word_boundary.clone(), 1_000_000);
        let max_ngram = MAX_TOKEN_LEN;
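        // Count every character n-gram (up to MAX_TOKEN_LEN chars) in the
        // boundary-marked corpus; these are the candidate vocabulary entries.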
for doc in corpus {
let mut processed = String::new();
for w in doc.split_whitespace() {
processed.push_str(&word_boundary);
processed.push_str(w);
}
let chars: Vec<char> = processed.chars().collect();
for start in 0..chars.len() {
for end in start + 1..=std::cmp::min(start + max_ngram, chars.len()) {
let ngram: String = chars[start..end].iter().collect();
*token_counts.entry(ngram).or_insert(0) += 1;
}
}
}
        let mut vocab_items: Vec<(String, usize)> = token_counts.into_iter().collect();
        // Sort by descending count, breaking ties lexicographically so the
        // vocabulary (and hence the token ids) is deterministic across runs.
        vocab_items.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        vocab_items.truncate(vocab_size);
let total: f64 = vocab_items.iter().map(|(_, c)| *c as f64).sum();
let mut vocab: HashMap<String, (u32, f64)> = HashMap::new();
let mut inverse_vocab: HashMap<u32, String> = HashMap::new();
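        // Assign ids in frequency order and score each token by its log
        // relative frequency over the retained counts.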
for (id, (token, count)) in vocab_items.iter().enumerate() {
let log_prob = ((*count as f64) / total).ln();
vocab.insert(token.clone(), (id as u32, log_prob));
inverse_vocab.insert(id as u32, token.clone());
}
        // Safety net: re-insert any special token that was pushed out of the
        // truncated vocabulary (possible if corpus n-gram counts exceed the
        // seeded 1_000_000), giving it a low default score.
        for special in [&unk_token, &bos_token, &eos_token] {
            if !vocab.contains_key(special.as_str()) {
                let id = vocab.len() as u32;
                vocab.insert(special.clone(), (id, -10.0));
                inverse_vocab.insert(id, special.clone());
            }
        }
Ok(Self {
vocab,
inverse_vocab,
unk_token,
bos_token,
eos_token,
})
}
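    /// Builds a tokenizer from an existing `token -> (id, log_prob)` map,
    /// using the default `<unk>`, `<s>`, and `</s>` special token strings.
    /// Note that if `<unk>` is absent from `vocab`, out-of-vocabulary
    /// pieces are silently dropped during encoding.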
#[must_use]
pub fn from_vocab(vocab: HashMap<String, (u32, f64)>) -> Self {
let inverse_vocab: HashMap<u32, String> =
vocab.iter().map(|(k, (id, _))| (*id, k.clone())).collect();
Self {
vocab,
inverse_vocab,
unk_token: "<unk>".to_string(),
bos_token: "<s>".to_string(),
eos_token: "</s>".to_string(),
}
}
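    /// Encodes `text` into token ids. Words are prefixed with the "▁"
    /// boundary marker, then the highest-probability segmentation is found
    /// with a Viterbi pass over character positions; characters that no
    /// vocabulary token covers fall back to `<unk>`.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; the `Result` mirrors the rest of the
    /// tokenizer API.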
pub fn encode(&self, text: &str) -> Result<Vec<u32>, AprenderError> {
let word_boundary = "▁";
let mut processed = String::new();
for w in text.split_whitespace() {
processed.push_str(word_boundary);
processed.push_str(w);
}
if processed.is_empty() {
return Ok(Vec::new());
}
let chars: Vec<char> = processed.chars().collect();
let n = chars.len();
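        // Viterbi lattice: best[i] holds (best score of any segmentation of
        // chars[..i], start index of the last token, the token itself).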
let mut best: Vec<(f64, usize, String)> =
vec![(f64::NEG_INFINITY, 0, String::new()); n + 1];
best[0] = (0.0, 0, String::new());
for i in 0..n {
if best[i].0 == f64::NEG_INFINITY {
continue;
}
            // Try every vocabulary token that starts at character i.
            for j in i + 1..=std::cmp::min(i + MAX_TOKEN_LEN, n) {
let substr: String = chars[i..j].iter().collect();
if let Some(&(_, log_prob)) = self.vocab.get(&substr) {
let score = best[i].0 + log_prob;
if score > best[j].0 {
best[j] = (score, i, substr);
}
}
}
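            // Fallback: if nothing reaches position i + 1, consume a single
            // character, using its own score if known and a harsh penalty
            // (-100.0) otherwise, so the pass can always make progress.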
if best[i + 1].0 == f64::NEG_INFINITY {
let char_str = chars[i].to_string();
let log_prob = self.vocab.get(&char_str).map_or(-100.0, |(_, p)| *p);
best[i + 1] = (best[i].0 + log_prob, i, char_str);
}
}
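        // Follow the backpointers from the end of the string to recover the
        // winning segmentation in reverse order.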
let mut tokens = Vec::new();
let mut pos = n;
while pos > 0 {
let token = &best[pos].2;
let prev = best[pos].1;
            if let Some(&(id, _)) = self.vocab.get(token) {
                tokens.push(id);
            } else if let Some(&(id, _)) = self.vocab.get(&self.unk_token) {
                // Fallback characters recorded in the lattice may be out of
                // vocabulary; map them to <unk>.
                tokens.push(id);
            }
pos = prev;
}
tokens.reverse();
Ok(tokens)
}
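    /// Decodes `ids` back into text: special tokens are skipped and each
    /// "▁" word-boundary marker becomes a single space (except at the
    /// start of the output).
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; unknown ids are treated as the unk
    /// token and skipped.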
pub fn decode(&self, ids: &[u32]) -> Result<String, AprenderError> {
let word_boundary = '▁';
let mut result = String::new();
for &id in ids {
let token = self.inverse_vocab.get(&id).map_or(&self.unk_token, |t| t);
if token == &self.unk_token || token == &self.bos_token || token == &self.eos_token {
continue;
}
for c in token.chars() {
if c == word_boundary {
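                    // Boundary markers become spaces, except the leading one.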
if !result.is_empty() {
result.push(' ');
}
} else {
result.push(c);
}
}
}
Ok(result)
}
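    /// Returns the number of tokens in the vocabulary.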
#[must_use]
pub fn vocab_size(&self) -> usize {
self.vocab.len()
}
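    /// Returns a copy of the vocabulary as a `token -> id` map, without
    /// the log probabilities.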
#[must_use]
pub fn vocab_ids(&self) -> HashMap<String, u32> {
self.vocab
.iter()
.map(|(k, (id, _))| (k.clone(), *id))
.collect()
}
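    /// Returns the log probability of `token`, or `None` if it is not in
    /// the vocabulary.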
#[must_use]
pub fn log_prob(&self, token: &str) -> Option<f64> {
self.vocab.get(token).map(|(_, p)| *p)
}
}
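
// The generic `Tokenizer` trait is implemented on top of `encode`,
// returning the surface strings instead of ids.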
impl Tokenizer for UnigramTokenizer {
fn tokenize(&self, text: &str) -> Result<Vec<String>, AprenderError> {
let ids = self.encode(text)?;
let tokens: Vec<String> = ids
.iter()
.filter_map(|id| self.inverse_vocab.get(id).cloned())
.collect();
Ok(tokens)
}
}
#[cfg(test)]
#[path = "tests.rs"]
mod tests;