use daggrs::{DoubleArrayAhoCorasick, Trie};
use foldhash::HashMap as FoldHashMap;
use crate::types::TokenId;
/// Default continuation prefix (`##`); this is the prefix `from_vocab_default`
/// passes on to `from_vocab`.
pub const DEFAULT_CONTINUATION_PREFIX: &[u8] = b"##";
/// WordPiece encoder that tokenizes one word at a time by walking a compiled
/// double-array automaton, with fast paths for single bytes and short
/// whole-word vocabulary hits.
#[derive(Clone)]
pub struct WordPieceEncoder {
/// Compiled automaton over the vocabulary; `encode` walks its states directly.
matcher: DoubleArrayAhoCorasick,
/// Token id returned when a word cannot be tokenized.
unk_token: TokenId,
/// Continuation prefix bytes this encoder was built with.
continuation_prefix: Vec<u8>,
/// Vocabulary size: `vocab.len()` in `from_vocab`, caller-supplied in `from_parts`.
vocab_size: usize,
/// Token id for every possible single-byte word; bytes that are not a
/// vocabulary entry on their own map to `unk_token`.
byte_lut: [TokenId; 256],
/// Exact-match lookup for vocabulary entries of at most 16 bytes.
token_cache: FoldHashMap<Vec<u8>, TokenId>,
/// Words with more characters than this are encoded as `unk_token`.
max_input_chars_per_word: usize,
}
impl WordPieceEncoder {
/// Builds an encoder from `(token bytes, token id)` vocabulary entries.
///
/// Every entry is inserted into a trie, which is prepared with
/// `build_wordpiece(continuation_prefix)` and compiled into the matcher.
/// Two fast-path tables are derived from the same entries:
///
/// * a 256-slot table for single-byte words (default: `unk_token`), and
/// * an exact-match cache for entries of at most 16 bytes.
///
/// If the same bytes appear more than once in `vocab`, the later entry
/// wins in both tables.
pub fn from_vocab(
    vocab: &[(Vec<u8>, TokenId)],
    unk_token: TokenId,
    continuation_prefix: &[u8],
    max_input_chars_per_word: usize,
) -> Self {
    // Compile the vocabulary into the matching automaton.
    let mut trie = Trie::new();
    for (piece, id) in vocab {
        trie.add(piece, *id);
    }
    trie.build_wordpiece(continuation_prefix);
    let matcher = trie.compile();

    // Populate both fast-path tables in a single pass over the vocabulary.
    let mut byte_lut = [unk_token; 256];
    let mut token_cache = FoldHashMap::default();
    for (piece, id) in vocab {
        if let [single] = piece.as_slice() {
            byte_lut[usize::from(*single)] = *id;
        }
        if piece.len() <= 16 {
            token_cache.insert(piece.clone(), *id);
        }
    }

    Self {
        matcher,
        unk_token,
        continuation_prefix: continuation_prefix.to_vec(),
        vocab_size: vocab.len(),
        byte_lut,
        token_cache,
        max_input_chars_per_word,
    }
}
/// Builds an encoder with the default settings: continuation prefix
/// `DEFAULT_CONTINUATION_PREFIX` and a limit of 100 characters per word.
pub fn from_vocab_default(vocab: &[(Vec<u8>, TokenId)], unk_token: TokenId) -> Self {
    // Words longer than this many characters are encoded as `unk_token`.
    const DEFAULT_MAX_INPUT_CHARS_PER_WORD: usize = 100;

    Self::from_vocab(
        vocab,
        unk_token,
        DEFAULT_CONTINUATION_PREFIX,
        DEFAULT_MAX_INPUT_CHARS_PER_WORD,
    )
}
/// Assembles an encoder around an already-compiled matcher.
///
/// `token_bytes` is indexed by token id: the entry at position `i` holds the
/// bytes of token `i`. It is used only to rebuild the single-byte lookup
/// table and the exact-match cache (entries of at most 16 bytes); the
/// matcher itself is taken as given. `vocab_size` is stored as supplied and
/// is not checked against `token_bytes.len()`.
pub fn from_parts(
    matcher: DoubleArrayAhoCorasick,
    unk_token: TokenId,
    continuation_prefix: Vec<u8>,
    vocab_size: usize,
    token_bytes: &[Vec<u8>],
    max_input_chars_per_word: usize,
) -> Self {
    let mut byte_lut = [unk_token; 256];
    let mut token_cache = FoldHashMap::default();

    // One pass fills both tables; the position in `token_bytes` is the token id.
    for (index, piece) in token_bytes.iter().enumerate() {
        let id = index as TokenId;
        if let [single] = piece.as_slice() {
            byte_lut[usize::from(*single)] = id;
        }
        if piece.len() <= 16 {
            token_cache.insert(piece.clone(), id);
        }
    }

    Self {
        matcher,
        unk_token,
        continuation_prefix,
        vocab_size,
        byte_lut,
        token_cache,
        max_input_chars_per_word,
    }
}
/// Encodes a single word into WordPiece token ids.
///
/// Matching is greedy and longest-first: starting from the automaton's start
/// state the word is consumed byte by byte while remembering the most recent
/// position at which a vocabulary entry ended. When no further transition is
/// possible, that entry is emitted and matching resumes right after it from
/// the matcher's `anchor` state (presumably the state reached after the
/// continuation prefix — confirm against `daggrs`).
///
/// # Returns
///
/// * an empty vector for an empty `word`;
/// * `[unk_token]` when the word has more than `max_input_chars_per_word`
///   characters, when the matcher has no anchor state, or when **any** part
///   of the word cannot be covered by vocabulary entries — a partially
///   matched word is never returned;
/// * otherwise the token ids of the pieces, in order.
pub fn encode(&self, word: &[u8]) -> Vec<TokenId> {
    if word.is_empty() {
        return Vec::new();
    }

    // The character count never exceeds the byte count, so characters are
    // only counted when the byte length is already over the limit. Invalid
    // UTF-8 falls back to the byte length.
    if word.len() > self.max_input_chars_per_word {
        let char_count = std::str::from_utf8(word)
            .map(|s| s.chars().count())
            .unwrap_or(word.len());
        if char_count > self.max_input_chars_per_word {
            return vec![self.unk_token];
        }
    }

    // Fast path: single-byte words are answered from the lookup table.
    if word.len() == 1 {
        return vec![self.byte_lut[word[0] as usize]];
    }

    // Fast path: the whole word is itself a (short) vocabulary entry.
    if let Some(&token_id) = self.token_cache.get(word) {
        return vec![token_id];
    }

    // Without an anchor state there is nowhere to resume after the first
    // piece, so a multi-piece word cannot be tokenized.
    let anchor = match self.matcher.anchor {
        Some(a) => a,
        None => return vec![self.unk_token],
    };

    let mut result = Vec::new();
    let mut pos = 0usize;
    let mut state = self.matcher.start_state();
    // End position and token id of the longest match seen in the current piece.
    let mut last_match: Option<(usize, TokenId)> = None;

    loop {
        while pos < word.len() {
            if let Some(next_state) = self.try_transition(state, word[pos]) {
                state = next_state;
                pos += 1;
                if let Some(output) = self.matcher.outputs(state).next() {
                    last_match = Some((pos, output.pattern_id));
                }
            } else if let Some((end_pos, token_id)) = last_match.take() {
                // Dead end: emit the longest match and resume right after it.
                result.push(token_id);
                pos = end_pos;
                state = anchor;
            } else {
                // Dead end with nothing matched in this piece.
                return vec![self.unk_token];
            }
        }

        // The input is exhausted. The current piece always contains at least
        // one byte, so a missing match here means the tail of the word is not
        // covered by the vocabulary; the whole word is then unknown rather
        // than being returned with its tail silently dropped.
        match last_match.take() {
            Some((end_pos, token_id)) => {
                result.push(token_id);
                if end_pos == word.len() {
                    return result;
                }
                // The longest match ended before the end of the word; rescan
                // the remaining bytes as a continuation piece.
                pos = end_pos;
                state = anchor;
            }
            None => return vec![self.unk_token],
        }
    }
}
/// Follows the goto transition for `byte` out of `state`, if one exists.
///
/// The candidate target is `base ^ byte`; it is a real transition only when
/// it lies inside the state table and its `check` field points back at
/// `state`. Failure links are not consulted here.
///
/// Panics if `state` is not a valid index into the matcher's state table.
#[inline]
fn try_transition(&self, state: u32, byte: u8) -> Option<u32> {
    let states = &self.matcher.states;
    let target = states[state as usize].base ^ u32::from(byte);
    match states.get(target as usize) {
        Some(node) if node.check == state => Some(target),
        _ => None,
    }
}
/// Returns the vocabulary size recorded when the encoder was constructed.
pub fn vocab_size(&self) -> usize {
self.vocab_size
}
/// Returns the token id emitted for words that cannot be tokenized.
pub fn unk_token(&self) -> TokenId {
self.unk_token
}
/// Returns the continuation prefix bytes this encoder was built with.
pub fn continuation_prefix(&self) -> &[u8] {
&self.continuation_prefix
}
/// Returns a reference to the compiled automaton used for matching.
pub fn matcher(&self) -> &DoubleArrayAhoCorasick {
&self.matcher
}
/// Returns the character limit above which a word is encoded as the unknown token.
pub fn max_input_chars_per_word(&self) -> usize {
self.max_input_chars_per_word
}
/// Always returns `true`; both arguments are ignored.
// NOTE(review): presumably exists to satisfy an interface shared with other
// encoders — confirm against callers.
pub fn is_valid_pair(&self, _token1: TokenId, _token2: TokenId) -> bool {
true
}
/// Returns the number of base tokens, which for this encoder is the same
/// value as `vocab_size()`.
pub fn num_base_tokens(&self) -> usize {
self.vocab_size
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture vocabulary; each piece's token id is its position in the list,
    /// so `[UNK]` is id 0.
    fn make_test_vocab() -> Vec<(Vec<u8>, TokenId)> {
        let pieces: [&[u8]; 6] = [
            b"[UNK]",
            b"un",
            b"break",
            b"##break",
            b"##able",
            b"##ing",
        ];
        pieces
            .iter()
            .enumerate()
            .map(|(id, piece)| (piece.to_vec(), id as TokenId))
            .collect()
    }

    /// Encodes `word` with a default encoder over the fixture vocabulary.
    fn encode_with_test_vocab(word: &[u8]) -> Vec<TokenId> {
        let vocab = make_test_vocab();
        WordPieceEncoder::from_vocab_default(&vocab, 0).encode(word)
    }

    #[test]
    fn test_wordpiece_single_token() {
        assert_eq!(encode_with_test_vocab(b"un"), vec![1]);
    }

    #[test]
    fn test_wordpiece_continuation() {
        // "un" + "##break" + "##able"
        assert_eq!(encode_with_test_vocab(b"unbreakable"), vec![1, 3, 4]);
    }

    #[test]
    fn test_wordpiece_unknown() {
        assert_eq!(encode_with_test_vocab(b"xyz"), vec![0]);
    }

    #[test]
    fn test_wordpiece_empty() {
        assert!(encode_with_test_vocab(b"").is_empty());
    }

    #[test]
    fn test_wordpiece_partial_unknown() {
        // A known leading piece followed by an unmatched remainder makes the
        // whole word unknown.
        assert_eq!(encode_with_test_vocab(b"unxyz"), vec![0]);
    }
}