use super::tokenize::{Tokenize, TokenizeError};
use std::collections::HashMap;
/// BERT-style WordPiece tokenizer: splits text on whitespace and punctuation,
/// then greedily segments each word into the longest matching vocab pieces.
pub struct WordPieceTokenizer {
    /// Lookup from token string to its vocabulary id.
    token_to_id: HashMap<String, u32>,
    /// Reverse lookup: index is the token id, value is the token string.
    id_to_token: Vec<String>,
    /// Id emitted for input that cannot be matched against the vocab.
    unk_token_id: u32,
    /// Words longer than this many chars map straight to `unk_token_id`.
    max_word_len: usize,
    /// When true, input is lower-cased and accent-stripped before splitting.
    do_lower_case: bool,
    /// True if any vocab entry starts with "##" (continuation-piece marker).
    has_continuation_prefix: bool,
    /// Id of "[CLS]" if present in the vocab.
    cls_token_id: Option<u32>,
    /// Id of "[SEP]" if present in the vocab.
    sep_token_id: Option<u32>,
    /// Id of "[PAD]" if present in the vocab.
    pad_token_id: Option<u32>,
}
impl WordPieceTokenizer {
    /// Builds a tokenizer from an ordered vocabulary; a token's id is its
    /// position in `vocab`. `unk_token_id` is the id emitted for unmatchable
    /// input, and words longer than `max_word_len` chars map straight to it.
    pub fn new(
        vocab: Vec<String>,
        unk_token_id: u32,
        max_word_len: usize,
        do_lower_case: bool,
    ) -> Self {
        let token_to_id: HashMap<String, u32> = vocab
            .iter()
            .enumerate()
            .map(|(id, token)| (token.clone(), id as u32))
            .collect();
        // Some vocabularies mark continuation pieces with "##"; detect this
        // once so per-word lookups only build prefixed keys when needed.
        let has_continuation_prefix = token_to_id.keys().any(|k| k.starts_with("##"));
        Self {
            cls_token_id: token_to_id.get("[CLS]").copied(),
            sep_token_id: token_to_id.get("[SEP]").copied(),
            pad_token_id: token_to_id.get("[PAD]").copied(),
            token_to_id,
            id_to_token: vocab,
            unk_token_id,
            max_word_len,
            do_lower_case,
            has_continuation_prefix,
        }
    }

    /// Id of "[CLS]" if the vocabulary defines one.
    pub fn cls_token_id(&self) -> Option<u32> {
        self.cls_token_id
    }

    /// Id of "[SEP]" if the vocabulary defines one.
    pub fn sep_token_id(&self) -> Option<u32> {
        self.sep_token_id
    }

    /// Id of "[PAD]" if the vocabulary defines one.
    pub fn pad_token_id(&self) -> Option<u32> {
        self.pad_token_id
    }

    /// Id used for unknown (out-of-vocabulary) input.
    pub fn unk_token_id(&self) -> u32 {
        self.unk_token_id
    }

    /// Normalizes the text (lower-casing + accent stripping when configured)
    /// and splits it into words, with each punctuation char its own token.
    fn basic_tokenize(&self, text: &str) -> Vec<String> {
        let owned;
        let normalized: &str = if self.do_lower_case {
            owned = strip_accents(&text.to_lowercase());
            &owned
        } else {
            text
        };
        let mut words = Vec::new();
        for chunk in normalized.split_whitespace() {
            split_on_punctuation(chunk, &mut words);
        }
        words
    }

    /// Greedy longest-match segmentation of one word into vocab ids.
    /// A char at which no piece matches becomes a single `unk_token_id`.
    fn wordpiece_tokenize(&self, word: &str) -> Vec<u32> {
        let chars: Vec<char> = word.chars().collect();
        // Overlong words are not worth the quadratic matching; emit one UNK.
        if chars.len() > self.max_word_len {
            return vec![self.unk_token_id];
        }
        let mut ids = Vec::new();
        let mut start = 0;
        while start < chars.len() {
            // Try the longest candidate first, shrinking until one matches.
            let mut matched = None;
            for end in (start + 1..=chars.len()).rev() {
                let piece: String = chars[start..end].iter().collect();
                let key = if start > 0 && self.has_continuation_prefix {
                    format!("##{}", piece)
                } else {
                    piece
                };
                if let Some(&id) = self.token_to_id.get(&key) {
                    matched = Some((id, end));
                    break;
                }
            }
            if let Some((id, end)) = matched {
                ids.push(id);
                start = end;
            } else {
                // Nothing starts here; consume one char as UNK and move on.
                ids.push(self.unk_token_id);
                start += 1;
            }
        }
        ids
    }
}
impl Tokenize for WordPieceTokenizer {
    /// Encodes raw text into vocab ids (no [CLS]/[SEP] are added here).
    fn encode(&self, text: &str) -> Vec<u32> {
        self.basic_tokenize(text)
            .iter()
            .flat_map(|word| self.wordpiece_tokenize(word))
            .collect()
    }

    /// Decodes ids back into text, dropping special tokens; fails on any
    /// id outside the vocabulary.
    fn decode(&self, ids: &[u32]) -> Result<String, TokenizeError> {
        if self.has_continuation_prefix {
            return self.decode_with_prefix(ids);
        }
        self.decode_without_prefix(ids)
    }

    /// Number of entries in the vocabulary.
    fn vocab_size(&self) -> usize {
        self.id_to_token.len()
    }
}
impl WordPieceTokenizer {
    /// Decoding for vocabularies that use the "##" continuation marker:
    /// continuation pieces are glued onto the previous word, regular pieces
    /// are separated by single spaces, special tokens are skipped.
    fn decode_with_prefix(&self, ids: &[u32]) -> Result<String, TokenizeError> {
        let mut text = String::new();
        let mut wrote_any = false;
        for &id in ids {
            let token = match self.id_to_token.get(id as usize) {
                Some(tok) => tok,
                None => return Err(TokenizeError::InvalidTokenId(id)),
            };
            if is_special_token(token) {
                continue;
            }
            match token.strip_prefix("##") {
                Some(rest) => text.push_str(rest),
                None => {
                    if wrote_any {
                        text.push(' ');
                    }
                    text.push_str(token);
                }
            }
            wrote_any = true;
        }
        Ok(text)
    }

    /// Decoding for vocabularies without continuation markers: every kept
    /// token is a whole word, joined with single spaces.
    fn decode_without_prefix(&self, ids: &[u32]) -> Result<String, TokenizeError> {
        let mut words = Vec::with_capacity(ids.len());
        for &id in ids {
            let token = match self.id_to_token.get(id as usize) {
                Some(tok) => tok,
                None => return Err(TokenizeError::InvalidTokenId(id)),
            };
            if !is_special_token(token) {
                words.push(token.as_str());
            }
        }
        Ok(words.join(" "))
    }
}
/// True for control tokens ([CLS], [SEP], [PAD], [UNK], [MASK]) and for
/// BERT's reserved "[unusedN]" slots; decoding drops these tokens.
fn is_special_token(token: &str) -> bool {
    match token {
        "[CLS]" | "[SEP]" | "[PAD]" | "[UNK]" | "[MASK]" => true,
        other => other.starts_with("[unused") && other.ends_with(']'),
    }
}
/// Removes diacritics by NFD-decomposing the text and dropping every
/// combining mark (e.g. "café" becomes "cafe").
fn strip_accents(text: &str) -> String {
    use unicode_normalization::UnicodeNormalization;
    let mut out = String::with_capacity(text.len());
    for ch in text.nfd() {
        if !unicode_normalization::char::is_combining_mark(ch) {
            out.push(ch);
        }
    }
    out
}
/// Appends the sub-tokens of `word` to `out`: runs of non-punctuation chars
/// stay together, while each punctuation char becomes its own token.
fn split_on_punctuation(word: &str, out: &mut Vec<String>) {
    let mut pending = String::new();
    for ch in word.chars() {
        if !is_punctuation(ch) {
            pending.push(ch);
            continue;
        }
        // Flush the word accumulated so far, then emit the punctuation alone.
        if !pending.is_empty() {
            out.push(std::mem::replace(&mut pending, String::new()));
        }
        out.push(ch.to_string());
    }
    if !pending.is_empty() {
        out.push(pending);
    }
}
fn is_punctuation(c: char) -> bool {
matches!(c, '\x21'..='\x2F' | '\x3A'..='\x40' | '\x5B'..='\x60' | '\x7B'..='\x7E')
|| c.is_ascii_punctuation()
|| {
let cat = unicode_general_category::get_general_category(c);
matches!(
cat,
unicode_general_category::GeneralCategory::ConnectorPunctuation
| unicode_general_category::GeneralCategory::DashPunctuation
| unicode_general_category::GeneralCategory::ClosePunctuation
| unicode_general_category::GeneralCategory::FinalPunctuation
| unicode_general_category::GeneralCategory::InitialPunctuation
| unicode_general_category::GeneralCategory::OtherPunctuation
| unicode_general_category::GeneralCategory::OpenPunctuation
)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small lower-casing tokenizer. Ids follow vocab order:
    /// 0=[PAD] 1=[UNK] 2=[CLS] 3=[SEP] 4=hello 5=world 6=##ing 7=##s
    /// 8=un 9=##know 10=##n 11=, 12=the 13=a
    fn make_tokenizer() -> WordPieceTokenizer {
        let vocab = vec![
            "[PAD]".to_string(),
            "[UNK]".to_string(),
            "[CLS]".to_string(),
            "[SEP]".to_string(),
            "hello".to_string(),
            "world".to_string(),
            "##ing".to_string(),
            "##s".to_string(),
            "un".to_string(),
            "##know".to_string(),
            "##n".to_string(),
            ",".to_string(),
            "the".to_string(),
            "a".to_string(),
        ];
        WordPieceTokenizer::new(vocab, 1, 200, true)
    }

    // Whole-word vocab hits encode to their direct ids.
    #[test]
    fn test_encode_basic() {
        let tok = make_tokenizer();
        let ids = tok.encode("hello world");
        assert_eq!(ids, vec![4, 5]);
    }

    // "unknown" is segmented greedily into un + ##know + ##n.
    #[test]
    fn test_encode_subwords() {
        let tok = make_tokenizer();
        let ids = tok.encode("unknown");
        assert_eq!(ids, vec![8, 9, 10]);
    }

    // Punctuation is split into its own token even without surrounding spaces.
    #[test]
    fn test_encode_punctuation() {
        let tok = make_tokenizer();
        let ids = tok.encode("hello, world");
        assert_eq!(ids, vec![4, 11, 5]);
    }

    // Round-trip of whole-word tokens joins with single spaces.
    #[test]
    fn test_decode_basic() {
        let tok = make_tokenizer();
        let text = tok.decode(&[4, 5]).unwrap();
        assert_eq!(text, "hello world");
    }

    // "##"-prefixed continuation pieces are glued back onto the word.
    #[test]
    fn test_decode_subwords() {
        let tok = make_tokenizer();
        let text = tok.decode(&[8, 9, 10]).unwrap();
        assert_eq!(text, "unknown");
    }

    // [CLS]/[SEP] surrounding the sequence are dropped during decoding.
    #[test]
    fn test_decode_skips_special() {
        let tok = make_tokenizer();
        let text = tok.decode(&[2, 4, 5, 3]).unwrap();
        assert_eq!(text, "hello world");
    }

    #[test]
    fn test_vocab_size() {
        let tok = make_tokenizer();
        assert_eq!(tok.vocab_size(), 14);
    }

    // Special-token accessors reflect vocab positions; UNK id is the one
    // passed to the constructor.
    #[test]
    fn test_special_token_ids() {
        let tok = make_tokenizer();
        assert_eq!(tok.cls_token_id(), Some(2));
        assert_eq!(tok.sep_token_id(), Some(3));
        assert_eq!(tok.pad_token_id(), Some(0));
        assert_eq!(tok.unk_token_id(), 1);
    }

    // Fully out-of-vocab input yields only UNK ids.
    #[test]
    fn test_unknown_word() {
        let tok = make_tokenizer();
        let ids = tok.encode("xyz");
        assert!(ids.iter().all(|&id| id == 1));
    }

    // do_lower_case = true folds case before lookup.
    #[test]
    fn test_lowercase() {
        let tok = make_tokenizer();
        let ids = tok.encode("Hello WORLD");
        assert_eq!(ids, vec![4, 5]);
    }

    // do_lower_case = false keeps "Hello" and "hello" distinct.
    #[test]
    fn test_case_sensitive() {
        let vocab = vec![
            "[UNK]".to_string(),
            "Hello".to_string(),
            "hello".to_string(),
        ];
        let tok = WordPieceTokenizer::new(vocab, 0, 200, false);
        let ids = tok.encode("Hello");
        assert_eq!(ids, vec![1]);
        let ids = tok.encode("hello");
        assert_eq!(ids, vec![2]);
    }

    // An id past the end of the vocab is a decode error, not a panic.
    #[test]
    fn test_decode_invalid_id() {
        let tok = make_tokenizer();
        let result = tok.decode(&[999]);
        assert!(result.is_err());
    }
}