use std::collections::HashMap;
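/// Prefix that marks a vocabulary entry as a continuation piece; BERT-style
/// WordPiece uses "##" for non-initial subwords.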
pub const WORDPIECE_CONTINUATION_PREFIX: &str = "##";
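/// Errors that can arise while constructing a [`WordPieceVocab`].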
#[derive(Debug, Clone, PartialEq)]
pub enum WordPieceError {
EmptyVocab,
UnkOutOfRange {
unk_id: u32,
vocab_len: usize,
},
DuplicateToken(String),
}
impl std::fmt::Display for WordPieceError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::EmptyVocab => write!(f, "vocabulary is empty"),
Self::UnkOutOfRange { unk_id, vocab_len } => {
write!(
f,
"unk_id {unk_id} is out of range for vocabulary of size {vocab_len}"
)
}
Self::DuplicateToken(t) => write!(f, "duplicate token in WordPiece vocabulary: {t:?}"),
}
}
}
impl std::error::Error for WordPieceError {}
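/// An immutable WordPiece vocabulary supporting greedy longest-match-first
/// encoding and lossy decoding.
///
/// Tokens are addressed by their index in the original token list, and
/// continuation pieces carry the [`WORDPIECE_CONTINUATION_PREFIX`]; for
/// example, a vocabulary containing "play" and "##ing" segments "playing"
/// into those two pieces.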
#[derive(Debug, Clone)]
pub struct WordPieceVocab {
    /// Maps token text to its id (the token's index in `id_to_token`).
    token_to_id: HashMap<String, u32>,
    /// Token text by id.
    id_to_token: Vec<String>,
    /// Id emitted for unknown or untokenizable words.
    unk_id: u32,
    /// Words longer than this (in chars) encode directly to `unk_id`.
    max_input_chars_per_word: usize,
}
impl WordPieceVocab {
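    /// Builds a vocabulary from an ordered token list; each token's id is its
    /// index in `tokens`.
    ///
    /// Returns [`WordPieceError::EmptyVocab`] for an empty list,
    /// [`WordPieceError::UnkOutOfRange`] if `unk_id` does not index into the
    /// list, and [`WordPieceError::DuplicateToken`] if a token appears twice.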
pub fn new(tokens: Vec<String>, unk_id: u32) -> Result<Self, WordPieceError> {
if tokens.is_empty() {
return Err(WordPieceError::EmptyVocab);
}
if unk_id as usize >= tokens.len() {
return Err(WordPieceError::UnkOutOfRange {
unk_id,
vocab_len: tokens.len(),
});
}
let mut token_to_id: HashMap<String, u32> = HashMap::with_capacity(tokens.len());
for (i, token) in tokens.iter().enumerate() {
if token_to_id.insert(token.clone(), i as u32).is_some() {
return Err(WordPieceError::DuplicateToken(token.clone()));
}
}
Ok(Self {
id_to_token: tokens,
token_to_id,
unk_id,
max_input_chars_per_word: 100,
})
}
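    /// Sets the per-word character limit (default 100); words longer than
    /// this encode directly to `unk_id`.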
pub fn with_max_input_chars(mut self, max: usize) -> Self {
self.max_input_chars_per_word = max;
self
}
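    /// Splits `text` on Unicode whitespace and tokenizes each word with
    /// greedy longest-match-first WordPiece. A word that cannot be fully
    /// segmented, or that exceeds `max_input_chars_per_word`, contributes a
    /// single `unk_id`.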
pub fn encode(&self, text: &str) -> Vec<u32> {
let mut result = Vec::new();
for word in text.split_whitespace() {
self.tokenize_word_into(word, &mut result);
}
result
}
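    /// Greedy longest-match-first segmentation of a single word: repeatedly
    /// take the longest remaining substring found in the vocabulary,
    /// prefixing non-initial pieces with "##". If any position has no match
    /// at all, the whole word collapses to a single `unk_id`.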
fn tokenize_word_into(&self, word: &str, out: &mut Vec<u32>) {
let char_count = word.chars().count();
if char_count > self.max_input_chars_per_word {
out.push(self.unk_id);
return;
}
        // Byte offsets of every char boundary, plus the end of the word, so
        // the slices below always land on valid UTF-8 boundaries.
        let char_boundaries: Vec<usize> = word
            .char_indices()
            .map(|(byte_idx, _)| byte_idx)
            .chain(std::iter::once(word.len()))
            .collect();
        let n_chars = char_boundaries.len() - 1;
        let mut start_char: usize = 0;
        let mut is_bad = false;
        // Remember how many ids were already emitted so this word's pieces
        // can be rolled back if it turns out to be untokenizable.
        let checkpoint = out.len();
'outer: while start_char < n_chars {
let byte_start = char_boundaries[start_char];
let mut end_char = n_chars;
loop {
let byte_end = char_boundaries[end_char];
let candidate: String = if start_char == 0 {
word[byte_start..byte_end].to_owned()
} else {
format!(
"{}{}",
WORDPIECE_CONTINUATION_PREFIX,
&word[byte_start..byte_end]
)
};
if let Some(&id) = self.token_to_id.get(&candidate) {
out.push(id);
start_char = end_char;
continue 'outer;
}
if end_char == start_char + 1 {
is_bad = true;
break 'outer;
}
end_char -= 1;
}
}
        if is_bad {
            // Roll back any partial pieces for this word and emit a single unk.
            out.truncate(checkpoint);
            out.push(self.unk_id);
        }
}
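    /// Reassembles tokens into text: continuation pieces are glued to the
    /// previous piece, other tokens are joined with single spaces. Ids outside
    /// the vocabulary are skipped, so decoding is lossy for unknown ids.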
pub fn decode(&self, ids: &[u32]) -> String {
let mut result = String::new();
for &id in ids {
let token = match self.id_to_token.get(id as usize) {
Some(t) => t.as_str(),
None => continue,
};
if let Some(cont) = token.strip_prefix(WORDPIECE_CONTINUATION_PREFIX) {
result.push_str(cont);
} else {
if !result.is_empty() {
result.push(' ');
}
result.push_str(token);
}
}
result
}
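    /// Returns the token string for `id`, or `None` if it is out of range.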
pub fn decode_token(&self, id: u32) -> Option<&str> {
self.id_to_token.get(id as usize).map(String::as_str)
}
pub fn vocab_size(&self) -> usize {
self.id_to_token.len()
}
pub fn unk_id(&self) -> u32 {
self.unk_id
}
pub fn max_input_chars_per_word(&self) -> usize {
self.max_input_chars_per_word
}
}
#[cfg(test)]
mod tests {
use super::*;
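    // Fixture vocabulary; ids by index: 0 [PAD], 1 [UNK], 2 [CLS], 3 [SEP],
    // 4 hello, 5 world, 6 play, 7 ##ing, 8 ##s, 9 foo, 10 ##bar.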
fn make_vocab() -> WordPieceVocab {
let tokens: Vec<String> = vec![
"[PAD]".into(),
"[UNK]".into(),
"[CLS]".into(),
"[SEP]".into(),
"hello".into(),
"world".into(),
"play".into(),
"##ing".into(),
"##s".into(),
"foo".into(),
"##bar".into(),
];
WordPieceVocab::new(tokens, 1).expect("make_vocab must succeed")
}
#[test]
fn error_empty_vocab() {
let err = WordPieceVocab::new(vec![], 0).unwrap_err();
assert_eq!(err, WordPieceError::EmptyVocab);
}
#[test]
fn error_unk_out_of_range() {
let err = WordPieceVocab::new(vec!["a".into()], 5).unwrap_err();
assert!(
matches!(
err,
WordPieceError::UnkOutOfRange {
unk_id: 5,
vocab_len: 1
}
),
"unexpected variant: {err:?}"
);
}
#[test]
fn error_duplicate_token() {
let err = WordPieceVocab::new(vec!["a".into(), "a".into()], 0).unwrap_err();
assert!(matches!(err, WordPieceError::DuplicateToken(ref t) if t == "a"));
}
#[test]
fn vocab_size_matches() {
let vocab = make_vocab();
assert_eq!(vocab.vocab_size(), 11);
}
#[test]
fn unk_id_accessor() {
let vocab = make_vocab();
assert_eq!(vocab.unk_id(), 1);
}
#[test]
fn max_input_chars_default() {
let vocab = make_vocab();
assert_eq!(vocab.max_input_chars_per_word(), 100);
}
#[test]
fn max_input_chars_builder() {
let vocab = make_vocab().with_max_input_chars(42);
assert_eq!(vocab.max_input_chars_per_word(), 42);
}
#[test]
fn encode_empty_string() {
let vocab = make_vocab();
assert_eq!(vocab.encode(""), Vec::<u32>::new());
}
#[test]
fn encode_known_word() {
let vocab = make_vocab();
assert_eq!(vocab.encode("hello"), vec![4]);
}
#[test]
fn encode_word_with_continuation() {
let vocab = make_vocab();
assert_eq!(vocab.encode("playing"), vec![6, 7]);
}
#[test]
fn encode_unknown_word_becomes_unk() {
let vocab = make_vocab();
assert_eq!(vocab.encode("xyz"), vec![1]);
}
#[test]
fn encode_multi_word() {
let vocab = make_vocab();
assert_eq!(vocab.encode("hello world"), vec![4, 5]);
}
#[test]
fn encode_word_too_long_becomes_unk() {
let vocab = WordPieceVocab::new(vec!["[UNK]".into(), "a".into()], 0)
.expect("vocab ok")
.with_max_input_chars(3);
assert_eq!(vocab.encode("aaaa"), vec![0]);
}
#[test]
fn encode_at_exact_char_limit_is_not_unk() {
let vocab = WordPieceVocab::new(vec!["[UNK]".into(), "aaa".into()], 0)
.expect("vocab ok")
.with_max_input_chars(3);
assert_eq!(vocab.encode("aaa"), vec![1]);
}
#[test]
fn foobar_segmentation() {
let vocab = make_vocab();
assert_eq!(vocab.encode("foobar"), vec![9, 10]);
}
#[test]
fn partial_bad_word_is_fully_replaced() {
let vocab = make_vocab();
assert_eq!(vocab.encode("fooxyz"), vec![1]);
}
#[test]
fn encode_multibyte_unicode_word() {
let tokens: Vec<String> = vec!["[UNK]".into(), "caf".into(), "##é".into()];
let vocab = WordPieceVocab::new(tokens, 0).expect("vocab ok");
let ids = vocab.encode("café");
assert_eq!(ids, vec![1, 2]);
}
#[test]
fn decode_strips_continuation_prefix() {
let vocab = make_vocab();
assert_eq!(vocab.decode(&[6, 7]), "playing");
}
#[test]
fn decode_multi_word() {
let vocab = make_vocab();
assert_eq!(vocab.decode(&[4, 5]), "hello world");
}
#[test]
fn decode_empty_slice() {
let vocab = make_vocab();
assert_eq!(vocab.decode(&[] as &[u32]), "");
}
#[test]
fn decode_unknown_ids_silently_ignored() {
let vocab = make_vocab();
assert_eq!(vocab.decode(&[4, 999, 5]), "hello world");
}
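    // Illustrative round-trip sketch: for input whose words are fully covered
    // by the fixture vocab, decode(encode(text)) reproduces the
    // whitespace-normalized text. Unknown words would not round-trip, since
    // they collapse to [UNK].
    #[test]
    fn encode_decode_round_trip_for_covered_text() {
        let vocab = make_vocab();
        let ids = vocab.encode("  hello   playing  world ");
        assert_eq!(ids, vec![4, 6, 7, 5]);
        assert_eq!(vocab.decode(&ids), "hello playing world");
    }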
#[test]
fn decode_token_known_id() {
let vocab = make_vocab();
assert_eq!(vocab.decode_token(4), Some("hello"));
assert_eq!(vocab.decode_token(7), Some("##ing"));
}
#[test]
fn decode_token_out_of_range() {
let vocab = make_vocab();
assert_eq!(vocab.decode_token(999), None);
}
#[test]
fn display_empty_vocab_error() {
let s = format!("{}", WordPieceError::EmptyVocab);
assert!(s.contains("empty"));
}
#[test]
fn display_unk_out_of_range_error() {
let s = format!(
"{}",
WordPieceError::UnkOutOfRange {
unk_id: 7,
vocab_len: 3
}
);
assert!(s.contains("7"));
assert!(s.contains("3"));
}
#[test]
fn display_duplicate_token_error() {
let s = format!("{}", WordPieceError::DuplicateToken("hello".to_string()));
assert!(s.contains("hello"));
}
}