use std::collections::HashMap;
/// Reasons a unigram vocabulary is rejected while it is being constructed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UnigramError {
    /// The list of vocabulary entries was empty.
    EmptyVocab,
    /// The unknown-token id does not index into the list of entries.
    UnkOutOfRange {
        /// The unknown-token id that was requested.
        unk_id: u32,
        /// Number of entries in the vocabulary that was supplied.
        vocab_len: usize,
    },
    /// The same token string appeared more than once; carries that token.
    DuplicateToken(String),
}

impl std::fmt::Display for UnigramError {
    /// Writes a single-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Constant message: no formatting machinery is needed.
            Self::EmptyVocab => f.write_str("unigram vocabulary must not be empty"),
            Self::UnkOutOfRange { unk_id, vocab_len } => {
                write!(
                    f,
                    "unk_id {unk_id} is out of range for vocabulary of size {vocab_len}"
                )
            }
            // The token is rendered with `Debug` so it appears quoted and escaped.
            Self::DuplicateToken(tok) => write!(
                f,
                "duplicate token in unigram vocabulary: {tok:?}"
            ),
        }
    }
}

impl std::error::Error for UnigramError {}
/// A unigram vocabulary: scored tokens plus the lookup tables used to encode
/// text into token ids and decode ids back into text.
#[derive(Debug)]
pub struct UnigramVocab {
    /// `(token, score)` pairs; a token's id is its index in this list.
    entries: Vec<(String, f64)>,
    /// Maps each token string to its id.
    token_to_id: HashMap<String, u32>,
    /// Score of each token, indexed by id.
    token_scores: Vec<f64>,
    /// Id emitted for a byte that is stepped over without a matching token.
    unk_id: u32,
    /// Byte length of the longest token, never less than 1.
    max_token_byte_len: usize,
}

impl UnigramVocab {
    /// Builds a vocabulary from `(token, score)` pairs; each token's id is its
    /// position in `entries`.
    ///
    /// # Errors
    ///
    /// Checked in this order:
    /// * `UnigramError::EmptyVocab` when `entries` is empty;
    /// * `UnigramError::UnkOutOfRange` when `unk_id` is not a valid index
    ///   into `entries`;
    /// * `UnigramError::DuplicateToken` for the first token string that
    ///   repeats an earlier one.
    pub fn new(entries: Vec<(String, f64)>, unk_id: u32) -> Result<Self, UnigramError> {
        let vocab_len = entries.len();
        if vocab_len == 0 {
            return Err(UnigramError::EmptyVocab);
        }
        if unk_id as usize >= vocab_len {
            return Err(UnigramError::UnkOutOfRange { unk_id, vocab_len });
        }

        let mut token_to_id: HashMap<String, u32> = HashMap::with_capacity(vocab_len);
        for (id, (piece, _)) in entries.iter().enumerate() {
            // `insert` hands back the previous id when the token was already present.
            if token_to_id.insert(piece.clone(), id as u32).is_some() {
                return Err(UnigramError::DuplicateToken(piece.clone()));
            }
        }

        let token_scores: Vec<f64> = entries.iter().map(|(_, score)| *score).collect();
        // Seeded with 1 so the encoder always probes at least one byte.
        let max_token_byte_len = entries
            .iter()
            .map(|(piece, _)| piece.len())
            .fold(1, usize::max);

        Ok(Self {
            entries,
            token_to_id,
            token_scores,
            unk_id,
            max_token_byte_len,
        })
    }

    /// Splits `text` into token ids.
    ///
    /// For every byte offset the best summed score of a segmentation ending
    /// there is tracked. Tokens may only start and end on UTF-8 character
    /// boundaries. A one-byte step that emits `unk_id` and costs `UNK_PENALTY`
    /// is tried whenever no token improved anything from the current offset,
    /// when the next offset is still unreachable, or when the current offset
    /// lies inside a multi-byte character. The chosen steps are then followed
    /// backwards from the end of the text.
    ///
    /// Returns an empty vector for empty input.
    pub fn encode(&self, text: &str) -> Vec<u32> {
        const UNK_PENALTY: f64 = -1e6;

        // Adopts `(id, width)` as the step into `end` when `cand` strictly
        // beats the score recorded there; reports whether it did.
        fn relax(
            scores: &mut [f64],
            links: &mut [Option<(u32, usize)>],
            end: usize,
            cand: f64,
            id: u32,
            width: usize,
        ) -> bool {
            let wins = cand > scores[end];
            if wins {
                scores[end] = cand;
                links[end] = Some((id, width));
            }
            wins
        }

        if text.is_empty() {
            return Vec::new();
        }
        let total = text.len();
        let mut scores = vec![f64::NEG_INFINITY; total + 1];
        let mut links: Vec<Option<(u32, usize)>> = vec![None; total + 1];
        scores[0] = 0.0;

        for start in 0..total {
            // Steps out of `start` only write to later offsets, so this value
            // is stable for the whole iteration.
            let base = scores[start];
            if base == f64::NEG_INFINITY {
                continue;
            }

            // Stays false inside a multi-byte character, where no token may
            // start, so the one-byte fallback below is always tried there.
            let mut improved = false;
            if text.is_char_boundary(start) {
                let longest = self.max_token_byte_len.min(total - start);
                let widths = (1..=longest).filter(|&w| text.is_char_boundary(start + w));
                for width in widths {
                    let end = start + width;
                    if let Some(&id) = self.token_to_id.get(&text[start..end]) {
                        let cand = base + self.token_scores[id as usize];
                        improved |= relax(&mut scores, &mut links, end, cand, id, width);
                    }
                }
            }

            if !improved || scores[start + 1] == f64::NEG_INFINITY {
                relax(
                    &mut scores,
                    &mut links,
                    start + 1,
                    base + UNK_PENALTY,
                    self.unk_id,
                    1,
                );
            }
        }

        // Walk the recorded steps from the end of the text back to offset 0.
        // `links[0]` is never written (every step ends at offset >= 1), so the
        // walk stops there.
        let mut ids: Vec<u32> = Vec::new();
        let mut cursor = total;
        while let Some((id, width)) = links[cursor] {
            ids.push(id);
            cursor -= width;
        }
        ids.reverse();
        ids
    }

    /// Number of entries in the vocabulary.
    pub fn token_count(&self) -> usize {
        self.entries.len()
    }

    /// Returns the token string for `id`, or `None` when `id` is out of range.
    pub fn decode_token(&self, id: u32) -> Option<&str> {
        let (piece, _) = self.entries.get(id as usize)?;
        Some(piece.as_str())
    }

    /// Concatenates the token strings for `ids`; every id that is out of range
    /// contributes U+FFFD (the replacement character) instead.
    pub fn decode(&self, ids: &[u32]) -> String {
        ids.iter()
            .map(|&id| self.decode_token(id).unwrap_or("\u{FFFD}"))
            .collect()
    }

    /// The id this vocabulary uses for the unknown token.
    pub fn unk_id(&self) -> u32 {
        self.unk_id
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a vocabulary from borrowed `(token, score)` pairs, panicking if
    /// construction fails.
    fn make_vocab(entries: &[(&str, f64)], unk_id: u32) -> UnigramVocab {
        let owned: Vec<(String, f64)> = entries
            .iter()
            .map(|&(tok, score)| (tok.to_string(), score))
            .collect();
        UnigramVocab::new(owned, unk_id).unwrap()
    }

    /// An empty entry list is rejected.
    #[test]
    fn empty_vocab_errors() {
        assert_eq!(
            UnigramVocab::new(Vec::new(), 0).unwrap_err(),
            UnigramError::EmptyVocab
        );
    }

    /// An unknown-token id past the end of the entries is rejected, and the
    /// error reports both the id and the vocabulary size.
    #[test]
    fn unk_out_of_range_errors() {
        let entries = vec![(String::from("<unk>"), 0.0), (String::from("a"), -1.0)];
        let err = UnigramVocab::new(entries, 5).unwrap_err();
        assert_eq!(
            err,
            UnigramError::UnkOutOfRange {
                unk_id: 5,
                vocab_len: 2
            }
        );
    }

    /// A repeated token string is rejected, and the error carries that token.
    #[test]
    fn duplicate_token_errors() {
        let entries: Vec<(String, f64)> = [("a", -1.0), ("b", -1.5), ("a", -2.0)]
            .iter()
            .map(|&(tok, score)| (tok.to_string(), score))
            .collect();
        let tok = match UnigramVocab::new(entries, 0).unwrap_err() {
            UnigramError::DuplicateToken(tok) => tok,
            other => panic!("expected DuplicateToken, got {other:?}"),
        };
        assert_eq!(tok, "a");
    }

    /// Text that is exactly one token encodes to that token's id.
    #[test]
    fn single_token_encodes() {
        let vocab = make_vocab(&[("<unk>", 0.0), ("hello", -1.0)], 0);
        assert_eq!(vocab.encode("hello"), vec![1]);
    }

    /// The merged token wins when it scores higher than the two-piece split.
    #[test]
    fn ambiguous_prefers_higher_score() {
        let pieces = [("<unk>", 0.0), ("a", -2.0), ("b", -2.0), ("ab", -1.0)];
        let vocab = make_vocab(&pieces, 0);
        assert_eq!(vocab.encode("ab"), vec![3]);
    }

    /// The two-piece split wins when its summed score beats the merged token.
    #[test]
    fn lower_score_path_loses_to_higher() {
        let pieces = [("<unk>", 0.0), ("a", -0.5), ("b", -0.5), ("ab", -1.5)];
        let vocab = make_vocab(&pieces, 0);
        assert_eq!(vocab.encode("ab"), vec![1, 2]);
    }

    /// A two-byte character present in the vocabulary is matched as a token.
    #[test]
    fn multibyte_utf8_boundary_respected() {
        let pieces = [
            ("<unk>", 0.0),
            ("c", -1.0),
            ("a", -1.0),
            ("f", -1.0),
            ("é", -1.0),
        ];
        let ids = make_vocab(&pieces, 0).encode("café");
        assert!(!ids.is_empty());
        assert!(ids.contains(&4));
    }

    /// A byte with no matching token is encoded as the unknown id.
    #[test]
    fn unk_fallback_for_unknown_byte() {
        let vocab = make_vocab(&[("<unk>", 0.0), ("a", -1.0)], 0);
        assert_eq!(vocab.encode("z"), vec![0]);
    }

    /// Decoding the ids of fully covered text reproduces the text.
    #[test]
    fn decode_roundtrip() {
        let pieces = [
            ("<unk>", 0.0),
            ("hello", -1.0),
            (" ", -0.5),
            ("world", -1.0),
        ];
        let vocab = make_vocab(&pieces, 0);
        let text = "hello world";
        let decoded = vocab.decode(&vocab.encode(text));
        assert_eq!(decoded, text);
    }

    /// Empty input produces no ids.
    #[test]
    fn empty_string_encodes_to_empty() {
        let vocab = make_vocab(&[("<unk>", 0.0), ("a", -1.0)], 0);
        assert!(vocab.encode("").is_empty());
    }

    /// An out-of-range id decodes to the replacement character.
    #[test]
    fn decode_unknown_id_produces_replacement_char() {
        let vocab = make_vocab(&[("<unk>", 0.0), ("a", -1.0)], 0);
        assert_eq!(vocab.decode(&[999]), "\u{FFFD}");
    }

    /// Looking up an out-of-range id yields `None`.
    #[test]
    fn decode_token_out_of_range_returns_none() {
        let vocab = make_vocab(&[("<unk>", 0.0), ("a", -1.0)], 0);
        assert!(vocab.decode_token(999).is_none());
    }

    /// The reported token count equals the number of entries supplied.
    #[test]
    fn token_count_matches_entries() {
        let vocab = make_vocab(&[("<unk>", 0.0), ("a", -1.0), ("b", -2.0)], 0);
        assert_eq!(vocab.token_count(), 3);
    }

    /// The empty-vocabulary message mentions "empty".
    #[test]
    fn error_display_empty_vocab() {
        let s = UnigramError::EmptyVocab.to_string();
        assert!(s.contains("empty"));
    }

    /// The out-of-range message includes both numbers.
    #[test]
    fn error_display_unk_out_of_range() {
        let s = UnigramError::UnkOutOfRange {
            unk_id: 10,
            vocab_len: 3,
        }
        .to_string();
        assert!(s.contains("10"));
        assert!(s.contains("3"));
    }

    /// The duplicate-token message includes the offending token.
    #[test]
    fn error_display_duplicate_token() {
        let s = UnigramError::DuplicateToken(String::from("foo")).to_string();
        assert!(s.contains("foo"));
    }
}