use daggrs::{DoubleArrayAhoCorasick, MatchKind, Trie};
use foldhash::HashMap as FoldHashMap;
use smallvec::SmallVec;
use crate::types::TokenId;
/// Returns how many bytes the UTF-8 sequence introduced by lead byte `b` occupies.
///
/// Only bytes that can begin a well-formed multi-byte sequence (RFC 3629:
/// `0xC2..=0xDF`, `0xE0..=0xEF`, `0xF0..=0xF4`) report 2, 3 or 4. ASCII bytes,
/// stray continuation bytes (`0x80..=0xBF`) and bytes that never occur in UTF-8
/// (`0xC0`, `0xC1`, `0xF5..=0xFF`) report 1. A malformed byte is therefore treated
/// as a character on its own instead of claiming the bytes that follow it.
#[inline]
fn utf8_char_len(b: u8) -> usize {
    match b {
        0xC2..=0xDF => 2,
        0xE0..=0xEF => 3,
        0xF0..=0xF4 => 4,
        _ => 1,
    }
}
/// Inputs longer than this many bytes (4 MiB) are routed by `encode` through
/// `encode_chunked` instead of a single Viterbi pass.
const DEFAULT_CHUNK_SIZE: usize = 4 * 1024 * 1024;
/// Unigram tokenizer over byte strings: Viterbi segmentation over vocabulary
/// matches, with optional `<0xHH>` byte fallback and an unknown token.
#[derive(Clone)]
pub struct UnigramEncoder {
/// Automaton over every non-empty vocabulary entry; its pattern ids are token ids.
matcher: DoubleArrayAhoCorasick,
/// Score of each token id; `encode_single` picks the segmentation with the highest sum.
scores: Vec<f32>,
/// Token emitted for input that no vocabulary match or byte token covers.
unk_token: TokenId,
/// Byte-fallback token id for each byte value, or `u32::MAX` when the byte has none.
byte_tokens: [TokenId; 256],
/// Length in bytes of each token.
token_lengths: Vec<u16>,
/// Vocabulary size reported by `vocab_size()` / `num_base_tokens()`.
vocab_size: usize,
/// Exact-bytes lookup of tokens no longer than `MAX_CACHED_TOKEN_LEN`, used by the
/// greedy fallback in `encode_with_unk_bridging`.
token_cache: FoldHashMap<Vec<u8>, TokenId>,
/// True when at least one byte value has a byte-fallback token; changes how the
/// unknown-token penalty is chosen in `encode_single`.
has_byte_fallback: bool,
}
/// Longest token, in bytes, stored in `token_cache`; also the longest prefix the
/// greedy fallback looks up there.
const MAX_CACHED_TOKEN_LEN: usize = 16;
// Debug output lists only the vocabulary size and the unknown-token id; the
// automaton, score tables and cache are left out.
impl std::fmt::Debug for UnigramEncoder {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = out.debug_struct("UnigramEncoder");
        view.field("vocab_size", &self.vocab_size);
        view.field("unk_token", &self.unk_token);
        view.finish()
    }
}
impl UnigramEncoder {
pub fn from_vocab_with_scores(
vocab: &[(u32, Vec<u8>, f32)],
unk_token: TokenId,
) -> (Self, Vec<Vec<u8>>) {
let token_bytes: Vec<Vec<u8>> = vocab.iter().map(|(_, bytes, _)| bytes.clone()).collect();
let scores: Vec<f32> = vocab.iter().map(|(_, _, score)| *score).collect();
let mut byte_tokens = [u32::MAX; 256];
for (id, bytes, _) in vocab {
if bytes.len() == 6 && bytes.starts_with(b"<0x") && bytes.ends_with(b">") {
if let Ok(byte_val) = u8::from_str_radix(
std::str::from_utf8(&bytes[3..5]).unwrap_or(""),
16,
) {
if byte_tokens[byte_val as usize] == u32::MAX {
byte_tokens[byte_val as usize] = *id;
}
}
}
}
let mut trie = Trie::new();
for (id, bytes, _) in vocab {
if !bytes.is_empty() {
trie.add(bytes, *id);
}
}
trie.build(MatchKind::Overlapping);
let matcher = trie.compile();
let token_lengths: Vec<u16> = token_bytes.iter().map(|b| b.len() as u16).collect();
let mut token_cache = FoldHashMap::default();
for (id, bytes, _) in vocab {
if bytes.len() <= MAX_CACHED_TOKEN_LEN {
token_cache.insert(bytes.clone(), *id);
}
}
let has_byte_fallback = byte_tokens.iter().any(|&t| t != u32::MAX);
let encoder = Self {
matcher,
scores,
unk_token,
byte_tokens,
token_lengths,
vocab_size: vocab.len(),
token_cache,
has_byte_fallback,
};
(encoder, token_bytes)
}
pub fn from_parts(
matcher: DoubleArrayAhoCorasick,
scores: Vec<f32>,
unk_token: TokenId,
byte_tokens: [TokenId; 256],
token_lengths: Vec<u16>,
token_bytes: &[Vec<u8>],
) -> Self {
let vocab_size = scores.len();
let mut token_cache = FoldHashMap::default();
for (id, bytes) in token_bytes.iter().enumerate() {
if bytes.len() <= MAX_CACHED_TOKEN_LEN {
token_cache.insert(bytes.clone(), id as TokenId);
}
}
let has_byte_fallback = byte_tokens.iter().any(|&t| t != u32::MAX);
Self {
matcher,
scores,
unk_token,
byte_tokens,
token_lengths,
vocab_size,
token_cache,
has_byte_fallback,
}
}
/// Vocabulary size recorded when the encoder was built.
pub fn vocab_size(&self) -> usize {
    self.vocab_size
}
/// Number of base tokens; identical to [`Self::vocab_size`] for this encoder.
pub fn num_base_tokens(&self) -> usize {
    self.vocab_size()
}
/// Id of the token emitted for input that the vocabulary cannot cover.
pub fn unk_token(&self) -> TokenId {
    self.unk_token
}
/// Per-token scores, indexed by token id.
pub fn scores(&self) -> &[f32] {
    self.scores.as_slice()
}
/// Token id used for each raw byte value, or `u32::MAX` when that byte has no
/// fallback token.
pub fn byte_tokens(&self) -> &[TokenId; 256] {
    &self.byte_tokens
}
/// Length in bytes of each token, indexed by token id.
pub fn token_lengths(&self) -> &[u16] {
    self.token_lengths.as_slice()
}
/// The automaton used to find vocabulary matches in the input.
pub fn matcher(&self) -> &DoubleArrayAhoCorasick {
    &self.matcher
}
/// Length in bytes of `token`.
///
/// # Panics
/// If `token` is outside the token-length table.
#[inline]
pub fn token_len(&self, token: TokenId) -> usize {
    usize::from(self.token_lengths[token as usize])
}
/// Always `true`: this encoder places no restriction on which tokens may be adjacent.
#[inline]
pub fn is_valid_pair(&self, _token1: TokenId, _token2: TokenId) -> bool {
    true
}
/// Tokenizes `text`, returning token ids in input order.
///
/// Empty input yields no tokens. Inputs longer than `DEFAULT_CHUNK_SIZE` bytes are
/// handed to `encode_chunked`; everything else is segmented in one Viterbi pass.
pub fn encode(&self, text: &[u8]) -> Vec<TokenId> {
    match text.len() {
        0 => Vec::new(),
        len if len > DEFAULT_CHUNK_SIZE => self.encode_chunked(text, DEFAULT_CHUNK_SIZE),
        _ => self.encode_single(text),
    }
}
/// Segments `text` in one Viterbi pass, returning the token sequence with the
/// highest total score.
///
/// From each reachable byte position the lattice has one edge for every vocabulary
/// match starting there. It also has either the byte-fallback token for that byte,
/// or an unknown-token edge covering one UTF-8 character; the unknown edge is added
/// only when there is neither a byte token nor a match.
///
/// With finite scores every reachable position has an outgoing edge, so the end is
/// reached. If it is not (e.g. infinite or NaN scores), the greedy
/// `encode_with_unk_bridging` is used instead.
///
/// # Panics
/// If the matcher reports a pattern id that has no entry in `scores`.
pub fn encode_single(&self, text: &[u8]) -> Vec<TokenId> {
    let n = text.len();
    // best_score[i]: best total score of any segmentation of text[..i].
    // backptr[i]: (last token, its start) on that best segmentation.
    let mut best_score = vec![f64::NEG_INFINITY; n + 1];
    let mut backptr: Vec<(TokenId, usize)> = vec![(0, 0); n + 1];
    best_score[0] = 0.0;
    // With byte fallback the unknown token's own score is the penalty. A missing
    // score entry falls back to -100.0 instead of panicking on every call.
    let unk_penalty = if self.has_byte_fallback {
        self.scores
            .get(self.unk_token as usize)
            .map_or(-100.0, |&score| f64::from(score))
    } else {
        -100.0
    };
    // All vocabulary matches go in one flat list ordered by start position. This
    // avoids an inline bucket (~130+ bytes) per input byte, which reached hundreds of
    // MB for a 4 MiB chunk. The sort is stable, so matches sharing a start keep the
    // matcher's order and ties resolve exactly as with per-position buckets.
    let mut matches: SmallVec<[(usize, usize, TokenId); 16]> = self
        .matcher
        .find_iter(text)
        .map(|m| (m.start, m.end, m.pattern_id))
        .collect();
    matches.sort_by_key(|&(start, _, _)| start);

    let mut next_match = 0;
    for pos in 0..n {
        // Step over the matches that start here, even when `pos` is unreachable.
        let first = next_match;
        while next_match < matches.len() && matches[next_match].0 == pos {
            next_match += 1;
        }
        let starting_here = &matches[first..next_match];

        let current_score = best_score[pos];
        if current_score == f64::NEG_INFINITY {
            continue;
        }

        for &(_, end, token_id) in starting_here {
            let candidate = current_score + f64::from(self.scores[token_id as usize]);
            if candidate > best_score[end] {
                best_score[end] = candidate;
                backptr[end] = (token_id, pos);
            }
        }

        let lead = text[pos];
        let byte_token = self.byte_tokens[usize::from(lead)];
        if byte_token != u32::MAX {
            let candidate = current_score + f64::from(self.scores[byte_token as usize]);
            if candidate > best_score[pos + 1] {
                best_score[pos + 1] = candidate;
                backptr[pos + 1] = (byte_token, pos);
            }
        } else if starting_here.is_empty() {
            // Unknown character: span the lead byte plus only the continuation bytes
            // actually present, so a truncated sequence cannot swallow valid text.
            let declared_end = (pos + utf8_char_len(lead)).min(n);
            let continuation = text[pos + 1..declared_end]
                .iter()
                .take_while(|&&b| (b & 0xC0) == 0x80)
                .count();
            let end = pos + 1 + continuation;
            let candidate = current_score + unk_penalty;
            if candidate > best_score[end] {
                best_score[end] = candidate;
                backptr[end] = (self.unk_token, pos);
            }
        }
    }

    if best_score[n] == f64::NEG_INFINITY {
        return self.encode_with_unk_bridging(text, &best_score, &backptr);
    }
    self.collect_tokens_from_backptr(&backptr, n)
}
/// Walks the back-pointers from `end` down to position 0 and returns the tokens in
/// input order. Each run of consecutive `unk_token`s is collapsed into a single one.
///
/// NOTE(review): the collapse cannot tell unknown-character edges from genuine
/// vocabulary matches of `unk_token`, so repeated matches of that token merge too.
#[inline]
fn collect_tokens_from_backptr(&self, backptr: &[(TokenId, usize)], end: usize) -> Vec<TokenId> {
    let mut path = Vec::new();
    let mut cursor = end;
    while cursor != 0 {
        let (token_id, previous) = backptr[cursor];
        path.push(token_id);
        cursor = previous;
    }
    path.reverse();
    let unk = self.unk_token;
    if path.iter().any(|&token| token == unk) {
        path.dedup_by(|later, earlier| *later == unk && *earlier == unk);
    }
    path
}
fn encode_with_unk_bridging(
&self,
text: &[u8],
_best_score: &[f64],
_backptr: &[(TokenId, usize)],
) -> Vec<TokenId> {
let n = text.len();
let mut tokens = Vec::new();
let mut pos = 0;
while pos < n {
let max_len = (n - pos).min(MAX_CACHED_TOKEN_LEN);
let mut best_match: Option<(usize, TokenId)> = None;
for len in (1..=max_len).rev() {
let substr = &text[pos..pos + len];
if let Some(&token_id) = self.token_cache.get(substr) {
best_match = Some((len, token_id));
break;
}
}
let remaining = &text[pos..];
if let Some(m) = self.matcher.find_iter(remaining).next() {
if m.start == 0 && (best_match.is_none() || m.end > best_match.unwrap().0) {
best_match = Some((m.end, m.pattern_id));
}
}
if let Some((len, token_id)) = best_match {
tokens.push(token_id);
pos += len;
} else {
let byte_val = text[pos];
let byte_token = self.byte_tokens[byte_val as usize];
if byte_token != u32::MAX {
tokens.push(byte_token);
} else {
tokens.push(self.unk_token);
}
pos += 1;
}
}
tokens
}
/// Encodes `text` piecewise. The chunker splits it into chunks bounded by
/// `chunk_size` bytes, and each chunk is segmented independently with
/// `encode_single`.
///
/// Inputs no longer than `chunk_size`, and a `chunk_size` of 0, are encoded in one
/// piece. Because chunks are encoded independently, no token spans a chunk boundary.
pub fn encode_chunked(&self, text: &[u8], chunk_size: usize) -> Vec<TokenId> {
    // A zero chunk size is meaningless. Rather than depend on how the external
    // chunker treats it, encode the whole input at once.
    if chunk_size == 0 || text.len() <= chunk_size {
        return self.encode_single(text);
    }
    // Capacity heuristic: about one token per three input bytes.
    let mut result = Vec::with_capacity(text.len() / 3);
    // U+2581 ("▁", the SentencePiece metaspace) encoded as UTF-8, passed to the
    // chunker as its split pattern. NOTE(review): `prefix()` presumably keeps the
    // marker at the start of the following chunk; confirm against the `chunk` crate.
    static METASPACE: [u8; 3] = [0xE2, 0x96, 0x81];
    for chunk_bytes in chunk::chunk(text)
        .size(chunk_size)
        .pattern(&METASPACE)
        .prefix()
        .consecutive()
        .forward_fallback()
    {
        result.extend(self.encode_single(chunk_bytes));
    }
    result
}
/// `encode_chunked` with the default chunk size `DEFAULT_CHUNK_SIZE` (4 MiB).
#[inline]
pub fn encode_chunked_default(&self, text: &[u8]) -> Vec<TokenId> {
    self.encode_chunked(text, DEFAULT_CHUNK_SIZE)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One `(id, bytes, score)` vocabulary row.
    fn piece(id: u32, bytes: &[u8], score: f32) -> (u32, Vec<u8>, f32) {
        (id, bytes.to_vec(), score)
    }

    /// Builds an encoder whose unknown token is id 0, as every test here does.
    fn build(vocab: &[(u32, Vec<u8>, f32)]) -> UnigramEncoder {
        let (encoder, _token_bytes) = UnigramEncoder::from_vocab_with_scores(vocab, 0);
        encoder
    }

    #[test]
    fn test_basic_unigram() {
        let encoder = build(&[
            piece(0, b"h", -1.0),
            piece(1, b"e", -1.0),
            piece(2, b"l", -1.0),
            piece(3, b"o", -1.0),
            piece(4, b"hell", -2.0),
            piece(5, b"hello", -3.0),
        ]);
        // "hello" (-3.0) ties "hell" + "o" (-3.0). The whole-word edge is recorded
        // first, and a later equal score does not replace it.
        assert_eq!(encoder.encode(b"hello"), vec![5]);
        assert_eq!(encoder.encode(b"h"), vec![0]);
    }

    #[test]
    fn test_viterbi_chooses_best_path() {
        // Two cheap single-byte pieces (-0.2 in total) beat the merged piece (-10.0).
        let encoder = build(&[
            piece(0, b"a", -0.1),
            piece(1, b"b", -0.1),
            piece(2, b"ab", -10.0),
        ]);
        assert_eq!(encoder.encode(b"ab"), vec![0, 1]);
    }

    #[test]
    fn test_byte_fallback() {
        let encoder = build(&[
            piece(0, b"<0x00>", -5.0),
            piece(1, b"<0x01>", -5.0),
            piece(2, b"<0xFF>", -5.0),
            piece(3, b"hello", -1.0),
        ]);
        let table = encoder.byte_tokens();
        assert_eq!(table[0x00], 0);
        assert_eq!(table[0x01], 1);
        assert_eq!(table[0xFF], 2);
    }

    #[test]
    fn test_empty_input() {
        let encoder = build(&[piece(0, b"a", -1.0)]);
        assert!(encoder.encode(b"").is_empty());
    }

    #[test]
    fn test_vocab_size() {
        let encoder = build(&[
            piece(0, b"a", -1.0),
            piece(1, b"b", -1.0),
            piece(2, b"c", -1.0),
        ]);
        assert_eq!(encoder.vocab_size(), 3);
        assert_eq!(encoder.num_base_tokens(), 3);
    }
}