use core::str;
use bytemuck_derive::{Pod, Zeroable};
use crate::{bytes::to_hex_string, tokenv::parse_numeric_token, SimpleVob};
/// Identifier of a token in the tokenizer vocabulary.
pub type TokenId = u32;
/// Fixed-layout (`#[repr(C)]`, bytemuck `Pod`) counterpart of [`TokRxInfo`] that holds only
/// the vocabulary size and the EOS token; convert with `TokRxInfo::from_bin` / `TokRxInfo::to_bin`.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)]
#[repr(C)]
pub struct BinTokRxInfo {
/// Number of tokens in the vocabulary.
pub vocab_size: u32,
/// End-of-sequence token.
pub tok_eos: TokenId,
}
/// Tokenizer metadata: vocabulary size and the ids of notable special tokens.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct TokRxInfo {
/// Number of tokens in the vocabulary; valid token ids are `0..vocab_size`.
pub vocab_size: u32,
/// Primary end-of-sequence token.
pub tok_eos: TokenId,
/// Beginning-of-sequence token, if known.
pub tok_bos: Option<TokenId>,
/// Padding token, if known.
pub tok_pad: Option<TokenId>,
/// Unknown-token id, if known.
pub tok_unk: Option<TokenId>,
/// End-of-turn token; `TokTrie::build_chat_mode_trie` uses it as EOS when set.
pub tok_end_of_turn: Option<TokenId>,
}
impl TokRxInfo {
    /// Creates token info for a vocabulary of `vocab_size` tokens with EOS `tok_eos`.
    /// All optional special tokens (BOS, PAD, UNK, end-of-turn) start out unset.
    pub fn new(vocab_size: u32, tok_eos: TokenId) -> Self {
        Self {
            tok_bos: None,
            tok_pad: None,
            tok_unk: None,
            tok_end_of_turn: None,
            vocab_size,
            tok_eos,
        }
    }

    /// Builds a [`TokRxInfo`] from its binary counterpart.
    ///
    /// The binary form carries only the vocabulary size and the EOS token, so every
    /// optional token of the result is `None`.
    pub fn from_bin(info: &BinTokRxInfo) -> Self {
        Self::new(info.vocab_size, info.tok_eos)
    }

    /// Converts to the fixed-layout binary form, dropping the optional special tokens.
    pub fn to_bin(&self) -> BinTokRxInfo {
        let Self {
            vocab_size,
            tok_eos,
            ..
        } = *self;
        BinTokRxInfo {
            vocab_size,
            tok_eos,
        }
    }
}
/// A byte-level acceptor driven by the trie walks in [`TokTrie`].
///
/// The walks push candidate bytes with [`Recognizer::try_push_byte`] and undo accepted
/// bytes with [`Recognizer::pop_bytes`]. Each walk is bracketed by
/// [`Recognizer::trie_started`] and [`Recognizer::trie_finished`].
pub trait Recognizer {
    /// Tries to push `byte`; returns `true` if it was accepted (and is now pushed).
    fn try_push_byte(&mut self, byte: u8) -> bool;

    /// Removes the last `num` accepted bytes.
    fn pop_bytes(&mut self, num: usize);

    /// Implementation-defined; not invoked by the code in this file.
    fn collapse(&mut self);

    /// Reports whether `byte` would be accepted. An accepted byte is popped again right
    /// away, so the net state change is nil provided `pop_bytes(1)` undoes a push.
    fn byte_allowed(&mut self, byte: u8) -> bool {
        let accepted = self.try_push_byte(byte);
        if accepted {
            self.pop_bytes(1);
        }
        accepted
    }

    /// Called before a trie walk begins; `_dbg_lbl` names the walk. No-op by default.
    fn trie_started(&mut self, _dbg_lbl: &str) {}

    /// Called when a trie walk is over.
    fn trie_finished(&mut self);

    /// Optional error report; `None` by default.
    fn get_error(&mut self) -> Option<String> {
        None
    }

    /// Receives the number of trie nodes visited by `add_bias`. No-op by default.
    fn save_stats(&mut self, _nodes_walked: usize) {}
}
/// Location of one token's bytes inside `TokTrie::token_data`.
#[derive(Clone, Copy)]
struct TokDesc {
/// Length of the token in bytes.
len: u32,
/// Byte offset of the token's first byte in `token_data`.
off: u32,
}
/// Byte-level trie over a tokenizer vocabulary, together with a table mapping token ids
/// to their bytes.
#[derive(Clone)]
pub struct TokTrie {
/// Vocabulary size and special-token ids.
info: TokRxInfo,
/// Per-token (length, offset) into `token_data`, indexed by token id.
token_offsets: Vec<TokDesc>,
/// Concatenated bytes of all tokens.
token_data: Vec<u8>,
/// Serialized trie in pre-order (each node followed by its subtree); index 0 is the root.
nodes: Vec<TrieNode>,
/// Length in bytes of the longest token.
max_token_len: usize,
/// Tokens treated as end-of-sequence; the first element always equals `info.tok_eos`.
eos_tokens: Vec<TokenId>,
/// Token ids ordered by their byte strings (the order used when inserting into the trie).
sorted_vocab: Vec<u32>,
}
/// Packed serialized trie node (two `u32`s, `Pod`).
///
/// * `bits`: low 8 bits hold the edge byte; the upper 24 bits hold the token id, or
///   `NO_TOKEN` when no token ends here.
/// * `bits2`: low `PARENT_BITS` bits hold `num_parents - 1`; the remaining bits hold the
///   subtree size (this node included).
#[derive(Clone, Copy, Zeroable, Pod)]
#[repr(C)]
pub struct TrieNode {
bits: u32,
bits2: u32,
}
/// Sentinel token id meaning "no token"; `eos_token_set` and `token_set_dbg` skip it.
pub const INVALID_TOKEN: TokenId = 0xffff_ffff;
/// Value of the 24-bit token field of a `TrieNode` that carries no token.
const NO_TOKEN: u32 = 0xffffff;
/// Number of low bits of `TrieNode::bits2` that store `num_parents - 1`.
const PARENT_BITS: u32 = 10;
/// Mask selecting the `num_parents - 1` field of `TrieNode::bits2`.
const PARENT_MASK: u32 = (1 << PARENT_BITS) - 1;
impl TrieNode {
    /// Packs a node from its edge `byte`, its `token_id` (or `NO_TOKEN`), and `num_parents`.
    ///
    /// # Panics
    /// If `num_parents` is 0 or greater than `1 << PARENT_BITS`, or if `token_id` does not
    /// fit in the 24-bit token field. Without that check the shift below would silently
    /// drop the high bits and map the token to a different id.
    fn new(byte: u8, token_id: u32, num_parents: usize) -> TrieNode {
        assert!(num_parents > 0);
        assert!(num_parents <= (1 << PARENT_BITS) as usize);
        assert!(
            token_id <= NO_TOKEN,
            "token id {token_id} does not fit in the 24-bit trie node field"
        );
        TrieNode {
            bits: (token_id << 8) | byte as u32,
            bits2: (num_parents - 1) as u32,
        }
    }

    /// The byte on the edge leading to this node.
    #[inline(always)]
    pub fn byte(&self) -> u8 {
        (self.bits & 0xff) as u8
    }

    /// Number of nodes in this node's subtree, this node included.
    #[inline(always)]
    pub fn subtree_size(&self) -> usize {
        (self.bits2 >> PARENT_BITS) as usize
    }

    /// Stores the subtree size, preserving the `num_parents` field.
    fn set_subtree_size(&mut self, size: usize) {
        assert!(size < (1 << (32 - PARENT_BITS)));
        self.bits2 = (self.bits2 & PARENT_MASK) | ((size as u32) << PARENT_BITS);
    }

    /// Number of trie levels a walker pops after leaving this node when it is a leaf.
    #[inline(always)]
    pub fn num_parents(&self) -> usize {
        ((self.bits2 & PARENT_MASK) + 1) as usize
    }

    /// The token ending at this node, if any.
    #[inline(always)]
    pub fn token_id(&self) -> Option<u32> {
        let r = self.bits >> 8;
        (r != NO_TOKEN).then_some(r)
    }
}
impl TokTrie {
/// First byte of a special token's byte string; the bytes that follow are the token's name
/// (see `get_special_token` and `token_dbg`).
pub const SPECIAL_TOKEN_MARKER: u8 = 0xff;
/// Builds a trie from the token byte strings `words`, where `words[i]` holds token `i`.
///
/// Tokens are inserted in lexicographic byte order, which is also recorded in
/// `sorted_vocab`. Empty words keep a slot in the token table but are not inserted into
/// the trie. The result is checked with `validate` before it is returned.
///
/// # Panics
/// If `info.vocab_size != words.len()`, or if token lengths or offsets overflow `u32`.
pub fn from(info: &TokRxInfo, words: &[Vec<u8>]) -> Self {
    let mut builder = TrieBuilder::new(0xff, info.vocab_size);
    let mut token_offsets = Vec::with_capacity(info.vocab_size as usize);
    let mut token_data = Vec::with_capacity(words.iter().map(Vec::len).sum());
    assert!(info.vocab_size == words.len() as u32);

    // Stable sort of ids by their bytes, so equal byte strings keep ascending id order.
    let mut order: Vec<usize> = (0..words.len()).collect();
    order.sort_by_key(|&i| &words[i]);
    for &i in &order {
        let word = &words[i];
        if !word.is_empty() {
            builder.insert(word, i as u32);
        }
    }

    let mut max_token_len: usize = 0;
    for word in words {
        max_token_len = max_token_len.max(word.len());
        token_offsets.push(TokDesc {
            len: word.len().try_into().unwrap(),
            off: token_data.len().try_into().unwrap(),
        });
        token_data.extend_from_slice(word);
    }

    let mut nodes = Vec::new();
    builder.serialize(&mut nodes, 0);
    let trie = TokTrie {
        info: *info,
        token_offsets,
        token_data,
        nodes,
        max_token_len,
        eos_tokens: vec![info.tok_eos],
        sorted_vocab: order.into_iter().map(|i| i as u32).collect(),
    };
    trie.validate();
    trie
}
/// Returns a copy of this trie restricted to the tokens allowed by `filter`.
///
/// Token ids are preserved. A disallowed token keeps its slot, maps to empty bytes, and is
/// absent from the trie. `info`, the EOS tokens and `sorted_vocab` are copied unchanged.
pub fn filter(&self, filter: &SimpleVob) -> Self {
    let mut builder = TrieBuilder::new(0xff, self.info.vocab_size);
    let mut token_offsets = Vec::with_capacity(self.vocab_size());
    let mut token_data = Vec::with_capacity(self.token_data.len());

    // Insert surviving tokens in the same sorted order the original trie used.
    for &tok in &self.sorted_vocab {
        if !filter.is_allowed(tok) {
            continue;
        }
        let bytes = self.token(tok);
        if !bytes.is_empty() {
            builder.insert(bytes, tok);
        }
    }

    let mut max_token_len: usize = 0;
    for tok in 0..(self.vocab_size() as TokenId) {
        let bytes: &[u8] = if filter.is_allowed(tok) {
            self.token(tok)
        } else {
            &[]
        };
        max_token_len = max_token_len.max(bytes.len());
        token_offsets.push(TokDesc {
            len: bytes.len().try_into().unwrap(),
            off: token_data.len().try_into().unwrap(),
        });
        token_data.extend_from_slice(bytes);
    }

    let mut nodes = Vec::new();
    builder.serialize(&mut nodes, 0);
    let trie = TokTrie {
        info: self.info,
        token_offsets,
        token_data,
        nodes,
        max_token_len,
        eos_tokens: self.eos_tokens.clone(),
        sorted_vocab: self.sorted_vocab.clone(),
    };
    trie.validate();
    trie
}
pub fn with_eos_token(&self, eos_token: TokenId) -> Self {
self.with_eos_tokens(&[eos_token])
}
pub fn with_eos_tokens(&self, eos_tokens: &[TokenId]) -> Self {
assert!(!eos_tokens.is_empty(), "eos_tokens must not be empty");
let vocab = self.vocab_size() as u32;
for &tok in eos_tokens {
assert!(
tok < vocab,
"EOS token ID {tok} is out of range (vocab_size={vocab})"
);
}
let mut r = self.clone();
r.info.tok_eos = eos_tokens[0];
r.eos_tokens = eos_tokens.to_vec();
r
}
pub fn with_info(&self, info: TokRxInfo) -> Self {
let mut r = self.clone();
r.info = info;
r.eos_tokens = vec![info.tok_eos];
r
}
pub fn build_chat_mode_trie(&self) -> Self {
self.with_eos_token(self.info.tok_end_of_turn.unwrap_or(self.info.tok_eos))
}
/// Returns the index of `n` within `self.nodes`, computed from its address relative to
/// the root node.
///
/// `n` must be a reference into `self.nodes`. For any other reference the computed index
/// is meaningless, and the bounds assertion (or a debug overflow check) is expected to fire.
fn node_offset(&self, n: &TrieNode) -> usize {
let off = (n as *const _ as usize - self.root() as *const _ as usize)
/ std::mem::size_of::<TrieNode>();
assert!(off < self.nodes.len());
off
}
/// Returns the index one past the end of `n`'s subtree in `self.nodes`, i.e. where the
/// following node (if any) starts.
fn next_node(&self, n: &TrieNode) -> usize {
self.node_offset(n) + n.subtree_size()
}
/// The tokenizer metadata this trie was built with.
pub fn info(&self) -> &TokRxInfo {
    &self.info
}

/// The primary EOS token (always the first element of [`Self::eos_tokens`]).
pub fn eos_token(&self) -> TokenId {
    self.info.tok_eos
}

/// All tokens treated as end-of-sequence.
pub fn eos_tokens(&self) -> &[TokenId] {
    self.eos_tokens.as_slice()
}

/// Number of tokens in the vocabulary.
pub fn vocab_size(&self) -> usize {
    self.info.vocab_size as usize
}

/// Allocates an empty token set sized to the vocabulary. Capacity includes one extra
/// slot (index `vocab_size`), which `add_bias` uses as a scratch "no token" bit.
pub fn alloc_token_set(&self) -> SimpleVob {
    let size = self.vocab_size();
    SimpleVob::alloc_with_capacity(size, size + 1)
}

/// Token set containing only `tok`.
pub fn singleton_token_set(&self, tok: TokenId) -> SimpleVob {
    let mut set = self.alloc_token_set();
    set.allow_token(tok);
    set
}

/// Token set containing every EOS token that is neither `INVALID_TOKEN` nor out of range.
pub fn eos_token_set(&self) -> SimpleVob {
    let mut set = self.alloc_token_set();
    let vocab = self.vocab_size() as u32;
    self.eos_tokens
        .iter()
        .copied()
        .filter(|&tok| tok != INVALID_TOKEN && tok < vocab)
        .for_each(|tok| {
            set.allow_token(tok);
        });
    set
}
/// Debug summary of a token set: `TokenSet: <set>/<vocab>; <names>`.
///
/// At most 50 token names are listed, followed by `...` when more exist. If the complement
/// is much smaller than the set (fewer than a tenth of its size), the complement is listed
/// instead, prefixed with `ALL EXCEPT `. The primary EOS token is listed first as `EOS`.
pub fn token_set_dbg(&self, ts: &SimpleVob) -> String {
    const MAX_EXAMPLES: usize = 50;
    let negated = ts.negated();
    let use_neg = negated.num_set() * 10 < ts.num_set();
    let shown = if use_neg { &negated } else { ts };
    let num_set = shown.num_set();
    let limit = num_set.min(MAX_EXAMPLES);
    let eos = self.info.tok_eos;

    let mut names = Vec::new();
    if eos != INVALID_TOKEN && shown.is_allowed(eos) {
        names.push("EOS".to_string());
    }
    for tok in (0..self.vocab_size()).map(|i| i as TokenId) {
        if tok == eos || !shown.is_allowed(tok) {
            continue;
        }
        names.push(self.token_dbg(tok));
        if names.len() >= limit {
            break;
        }
    }
    if names.len() < num_set {
        names.push("...".to_string());
    }

    let prefix = if use_neg { "ALL EXCEPT " } else { "" };
    format!(
        "TokenSet: {}/{}; {}{}",
        ts.num_set(),
        self.vocab_size(),
        prefix,
        names.join(" ")
    )
}
/// Allocates a zero-filled logits vector with `vocab_size() + 1` entries.
pub fn alloc_logits(&self) -> Vec<f32> {
    let len = self.vocab_size() + 1;
    vec![0.0; len]
}

/// Renders `toks` like [`Self::tokens_dbg`], but without the surrounding `⟦…⟧`.
pub fn test_trace_tokens(&self, toks: &[u32]) -> String {
    self.tokens_dbg_ext(toks, false)
}

/// Maximum number of tokens rendered by the `tokens_dbg` family; longer inputs keep
/// only their tail and are prefixed with `…`.
pub const MAX_DBG_TOKENS: usize = 200;

/// Human-readable rendering of a token sequence, e.g. `⟦a‧b⟧`.
pub fn tokens_dbg(&self, toks: &[u32]) -> String {
    self.tokens_dbg_ext(toks, true)
}

fn tokens_dbg_ext(&self, toks: &[u32], quote: bool) -> String {
    let skip = toks.len().saturating_sub(Self::MAX_DBG_TOKENS);
    let ellipsis = if skip > 0 { "…" } else { "" };
    let body = toks[skip..]
        .iter()
        .map(|&tok| self.token_dbg_ext(tok, false))
        .collect::<Vec<_>>()
        .join("‧");
    match (quote, ellipsis.is_empty()) {
        (true, _) => format!("⟦{ellipsis}{body}⟧"),
        (false, true) => body,
        (false, false) => format!("{ellipsis}{body}"),
    }
}
pub fn token_dbg(&self, idx: u32) -> String {
self.token_dbg_ext(idx, true)
}
fn token_dbg_ext(&self, idx: u32, quote: bool) -> String {
if idx == self.info.tok_eos {
"≺EOS≻".to_string()
} else if idx as usize >= self.vocab_size() {
format!("≺OOB[{idx}]≻")
} else {
let bytes = self.token(idx);
if bytes.len() > 1 && bytes[0] == TokTrie::SPECIAL_TOKEN_MARKER {
String::from_utf8_lossy(&bytes[1..]).to_string()
} else {
let s = String::from_utf8_lossy(bytes);
if s.is_empty() {
format!("≺EMPTY[{idx}]≻")
} else if !s.contains('\u{fffd}') {
let mut s = format!("{s:?}").replace("\\\"", "\"");
s.remove(0);
s.pop();
if quote {
format!("⟨{s}⟩")
} else {
s
}
} else {
let bytes = self.token(idx);
format!("≺HEX[{}]≻", to_hex_string(bytes))
}
}
}
}
pub fn token_str(&self, idx: u32) -> String {
String::from_utf8_lossy(self.token(idx)).to_string()
}
pub fn token_len(&self, idx: u32) -> usize {
let t = self.token(idx);
if t.is_empty() || t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
let mut idx = idx;
let mut len = 1;
while idx >= 10 {
idx /= 10;
len += 1;
}
len + 3
} else {
t.len()
}
}
pub fn token(&self, idx: u32) -> &[u8] {
if idx >= self.token_offsets.len() as u32 {
return &[];
}
let desc = self.token_offsets[idx as usize];
let len = desc.len as usize;
let off = desc.off as usize;
&self.token_data[off..(off + len)]
}
pub fn decode(&self, tokens: &[TokenId]) -> Vec<u8> {
self.decode_ext(tokens, true)
}
pub fn decode_ext(&self, tokens: &[TokenId], include_special: bool) -> Vec<u8> {
let mut res = Vec::with_capacity(tokens.len() * 6 + 32); for &tok in tokens {
let t = self.token(tok);
if t.is_empty() {
if include_special {
res.extend_from_slice(format!("<[{tok}]>").as_bytes());
}
} else if t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
if include_special {
res.extend_from_slice(&t[1..]);
}
} else {
res.extend_from_slice(t);
}
}
res
}
pub fn decode_as_special(&self, tok: TokenId) -> Vec<u8> {
let mut res = Vec::with_capacity(9);
res.push(TokTrie::SPECIAL_TOKEN_MARKER);
res.extend_from_slice(format!("[{tok}]").as_bytes());
res
}
pub fn decode_raw(&self, tokens: &[TokenId]) -> Vec<u8> {
let mut res = Vec::with_capacity(tokens.len() * 6 + 32); for &tok in tokens {
let t = self.token(tok);
if t.is_empty() || t[0] == TokTrie::SPECIAL_TOKEN_MARKER {
res.push(TokTrie::SPECIAL_TOKEN_MARKER);
res.extend_from_slice(format!("[{tok}]").as_bytes());
} else {
res.extend_from_slice(t);
}
}
res
}
pub fn decode_str(&self, tokens: &[TokenId]) -> String {
String::from_utf8_lossy(&self.decode(tokens)).to_string()
}
pub fn decode_raw_to_decode(&self, bytes: &[u8]) -> Vec<u8> {
let mut res = Vec::new();
let mut idx = 0;
while idx < bytes.len() {
if bytes[idx] == TokTrie::SPECIAL_TOKEN_MARKER {
if let Some((len, tok)) = parse_numeric_token(&bytes[(idx + 1)..]) {
res.extend_from_slice(&self.decode(&[tok]));
idx += len + 1;
} else {
res.push(bytes[idx]);
idx += 1;
}
} else {
res.push(bytes[idx]);
idx += 1;
}
}
res
}
pub fn is_special_token(&self, tok: TokenId) -> bool {
let bytes = self.token(tok);
!bytes.is_empty() && bytes[0] == TokTrie::SPECIAL_TOKEN_MARKER
}
pub fn get_special_token(&self, name: &str) -> Option<TokenId> {
self.child_at_byte(self.root(), TokTrie::SPECIAL_TOKEN_MARKER)
.and_then(|n| {
self.child_at_bytes(n, name.as_bytes())
.and_then(|n| n.token_id())
})
}
/// Collects the ids of tokens stored below the special-token marker node, depth first.
///
/// Traversal stops entirely once more than `MAX_DBG_TOKENS + 1` ids have been collected.
/// The first collected id is then dropped from the result.
///
/// # Panics
/// If the trie has no node for the special-token marker byte.
pub fn get_special_tokens(&self) -> Vec<TokenId> {
    let prefix_node = self
        .child_at_byte(self.root(), Self::SPECIAL_TOKEN_MARKER)
        .expect("missing special token prefix");
    let mut res = Vec::new();
    let mut stack = vec![prefix_node];
    'walk: while let Some(node) = stack.pop() {
        for child in self.node_children(node) {
            if let Some(tok) = child.token_id() {
                res.push(tok);
                if res.len() > Self::MAX_DBG_TOKENS + 1 {
                    // Stop the whole traversal. A plain `break` would only skip this node's
                    // remaining children, leaving the cap ineffective and the result a
                    // haphazard subset.
                    break 'walk;
                }
            }
            stack.push(child);
        }
    }
    // NOTE(review): the first collected id is dropped, as before; the reason is not visible
    // here. Guarded so an empty result no longer panics.
    if !res.is_empty() {
        res.remove(0);
    }
    res
}
/// Greedy longest-match tokenization of `bytes` using the trie.
///
/// At each position the longest token that is a prefix of the remaining input is emitted,
/// and scanning resumes right after it. If no token matches at a position, that byte is
/// skipped without emitting anything.
pub fn greedy_tokenize(&self, bytes: &[u8]) -> Vec<TokenId> {
    let mut out = Vec::new();
    let mut start = 0;
    while start < bytes.len() {
        let mut node = self.root();
        // Longest match so far: (token, index of its last byte).
        let mut best: Option<(TokenId, usize)> = None;
        for (pos, &byte) in bytes.iter().enumerate().skip(start) {
            let Some(child) = self.child_at_byte(node, byte) else {
                break;
            };
            node = child;
            if let Some(tok) = node.token_id() {
                best = Some((tok, pos));
            }
        }
        match best {
            Some((tok, end)) => {
                out.push(tok);
                start = end + 1;
            }
            None => start += 1,
        }
    }
    out
}
/// Tokenizes `s`, turning `<...>` spans that name special tokens into those tokens.
///
/// A candidate span is `<`, then 1 to 100 bytes containing neither `<` nor `>`, then `>`.
/// Candidates that `get_special_token` resolves are emitted as that token. Everything else,
/// including candidates that are not special tokens, is passed to `str_tokenize` in
/// maximal contiguous runs. Plain text is therefore never split at a non-special `<...>`,
/// which would change how the surrounding text tokenizes.
pub fn tokenize_with_special<F>(&self, s: &str, str_tokenize: F) -> Vec<TokenId>
where
    F: Fn(&str) -> Vec<TokenId>,
{
    const MAX_NAME_LEN: usize = 100;
    let bytes = s.as_bytes();
    let mut out = Vec::new();
    // Start of the plain-text run not yet handed to `str_tokenize`.
    let mut pending = 0;
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] != b'<' {
            i += 1;
            continue;
        }
        let mut valid = true;
        let mut j = i + 1;
        let mut len_inside = 0;
        while j < bytes.len() && len_inside < MAX_NAME_LEN {
            match bytes[j] {
                b'<' => {
                    valid = false;
                    break;
                }
                b'>' => break,
                _ => {
                    len_inside += 1;
                    j += 1;
                }
            }
        }
        if !valid || j >= bytes.len() || bytes[j] != b'>' || len_inside == 0 {
            i += 1;
            continue;
        }
        // `i` and `j` index ASCII bytes, so both slice ends are char boundaries.
        let name = &s[i..=j];
        if let Some(special_tok) = self.get_special_token(name) {
            if pending < i {
                out.extend(str_tokenize(&s[pending..i]));
            }
            out.push(special_tok);
            pending = j + 1;
        }
        // A non-special candidate stays in the pending run. It contains no `<`, so no
        // special token can start inside it and scanning may resume after its `>`.
        i = j + 1;
    }
    if pending < bytes.len() {
        out.extend(str_tokenize(&s[pending..]));
    }
    out
}
/// Tokenizes `bytes` with `str_tokenize` when they are valid UTF-8.
///
/// Otherwise each valid UTF-8 run goes through `str_tokenize` and each invalid run through
/// [`Self::greedy_tokenize`], in input order.
pub fn tokenize_with_greedy_fallback(
    &self,
    bytes: &[u8],
    str_tokenize: impl Fn(&str) -> Vec<TokenId>,
) -> Vec<TokenId> {
    if let Ok(text) = str::from_utf8(bytes) {
        return str_tokenize(text);
    }
    bytes.utf8_chunks().fold(Vec::new(), |mut acc, chunk| {
        let valid = chunk.valid();
        if !valid.is_empty() {
            acc.extend(str_tokenize(valid));
        }
        let invalid = chunk.invalid();
        if !invalid.is_empty() {
            acc.extend(self.greedy_tokenize(invalid));
        }
        acc
    })
}
/// Whether `bytes` is a path in the trie whose node has at least one descendant.
pub fn has_extensions(&self, bytes: &[u8]) -> bool {
    self.child_at_bytes(self.root(), bytes)
        .is_some_and(|node| node.subtree_size() > 1)
}

/// The token whose bytes are exactly `bytes`, if any.
///
/// Empty input yields `None`: no token is inserted with empty bytes, and
/// [`Self::prefix_token_id`] would panic on it.
pub fn token_id(&self, bytes: &[u8]) -> Option<TokenId> {
    if bytes.is_empty() {
        return None;
    }
    let (tok, len) = self.prefix_token_id(bytes);
    (len == bytes.len()).then_some(tok)
}

/// Longest token that is a prefix of `bytes`, as `(token, length_in_bytes)`.
///
/// Returns `(0, 0)` when no token is a prefix; callers must check the length.
///
/// # Panics
/// If `bytes` is empty.
pub fn prefix_token_id(&self, bytes: &[u8]) -> (TokenId, usize) {
    assert!(!bytes.is_empty());
    let mut best = (0, 0);
    let mut node = self.root();
    for (idx, &byte) in bytes.iter().enumerate() {
        match self.child_at_byte(node, byte) {
            Some(child) => node = child,
            None => break,
        }
        if let Some(tok) = node.token_id() {
            best = (tok, idx + 1);
        }
    }
    best
}
/// Length in bytes of the longest token in the vocabulary.
pub fn max_token_len(&self) -> usize {
    self.max_token_len
}

/// Recursively checks the subtree rooted at `n`:
/// * token ids are in range and appear at most once (tracked in `used`);
/// * the subtree ends no later than `ep`, the end of the enclosing subtree.
fn validate_node(&self, n: &TrieNode, ep: usize, used: &mut [bool]) {
    if let Some(tok) = n.token_id() {
        assert!(tok < self.info.vocab_size);
        assert!(!used[tok as usize]);
        used[tok as usize] = true;
    }
    let endp = self.next_node(n);
    assert!(endp <= ep);
    self.node_children(n)
        .for_each(|child| self.validate_node(child, endp, used));
}

/// Checks trie structure and that every token id resolves to an in-bounds slice of
/// `token_data`. Panics on any inconsistency.
fn validate(&self) {
    let root = self.root();
    let mut used = vec![false; self.info.vocab_size as usize];
    self.validate_node(root, self.next_node(root), &mut used);
    for idx in 0..self.info.vocab_size {
        // Slicing inside `token` panics if the offsets are inconsistent.
        let _ = self.token(idx);
    }
}

/// The root node of the trie.
pub fn root(&self) -> &TrieNode {
    &self.nodes[0]
}
/// Asserts that this trie matches the token list `tokens` (where `tokens[i]` holds token `i`).
///
/// Each token's stored bytes must equal `tokens[i]`, and every non-empty token must be
/// reachable in the trie. When several tokens share the same bytes, the token must appear
/// as a leaf sibling carrying the final byte.
pub fn check_against(&self, tokens: &[Vec<u8>]) {
    let root = self.root();
    for (idx, bytes) in tokens.iter().enumerate() {
        let tid = idx as TokenId;
        assert!(bytes == self.token(tid));
        let Some((&last, init)) = bytes.split_last() else {
            continue;
        };
        let tid2 = self
            .child_at_bytes(root, bytes)
            .unwrap()
            .token_id()
            .unwrap();
        if tid == tid2 {
            continue;
        }
        // Duplicate byte string: look for a leaf sibling that carries this token.
        let par = self.child_at_bytes(root, init).unwrap();
        let has_it = self.node_children(par).any(|n| {
            n.subtree_size() == 1 && n.byte() == last && n.token_id() == Some(tid)
        });
        assert!(has_it);
    }
}
/// First child of `n` reached by `byte`, if any.
pub fn child_at_byte<'a>(&'a self, n: &'a TrieNode, byte: u8) -> Option<&'a TrieNode> {
    self.node_children(n).find(|child| child.byte() == byte)
}

/// Every token that occurs as a substring of `bytes`: for each start position, all tokens
/// that are prefixes of the remainder, in order of start position and then length.
pub fn all_subtokens(&self, bytes: &[u8]) -> Vec<TokenId> {
    (0..bytes.len())
        .flat_map(|start| self.all_prefixes(&bytes[start..]))
        .collect()
}

/// Iterator over the direct children of `n` (which must belong to this trie).
pub fn node_children(&self, n: &TrieNode) -> NodeChildren<'_> {
    let start = self.node_offset(n);
    NodeChildren {
        trie: self,
        current_offset: start + 1,
        end_offset: start + n.subtree_size(),
    }
}

/// Follows `bytes` from `n`, returning the node reached, or `None` if the path breaks off.
pub fn child_at_bytes<'a>(&'a self, n: &'a TrieNode, bytes: &[u8]) -> Option<&'a TrieNode> {
    bytes
        .iter()
        .try_fold(n, |node, &byte| self.child_at_byte(node, byte))
}

/// Token stored at the node reached by `bytes` from the root, if any.
pub fn token_id_at_bytes(&self, bytes: &[u8]) -> Option<TokenId> {
    self.child_at_bytes(self.root(), bytes)?.token_id()
}
/// Finds how many trailing tokens of `tokens` cover the longest byte suffix that `r` can
/// still extend into some token.
///
/// Only the raw bytes (`decode_raw`) of the last 4 tokens are considered, truncated to
/// their last `max_token_len()` bytes. Suffixes are tried from longest to shortest with
/// `has_valid_extensions`. Returns `(n, len)`, where `n` is the smallest number of trailing
/// tokens whose summed `token_len` reaches the suffix length and `len` is that sum.
/// Returns `(0, 0)` when no suffix qualifies (including when `tokens` is empty).
pub fn chop_tokens(&self, r: &mut impl Recognizer, tokens: &[TokenId]) -> (usize, usize) {
let max_token_lookback = 4;
let suff_bytes =
self.decode_raw(&tokens[tokens.len().saturating_sub(max_token_lookback)..]);
let suff_bytes = &suff_bytes[suff_bytes.len().saturating_sub(self.max_token_len())..];
// idx = 0 is the longest candidate suffix.
for idx in 0..suff_bytes.len() {
let suff = &suff_bytes[idx..];
if self.has_valid_extensions(r, suff) {
let chop_bytes = suff.len();
assert!(chop_bytes > 0);
let mut curr_len = 0;
for chop_idx in 1..=tokens.len() {
curr_len += self.token_len(tokens[tokens.len() - chop_idx]);
if curr_len >= chop_bytes {
return (chop_idx, curr_len);
}
}
// `token_len` matches the byte length `decode_raw` produces per token, and the
// suffix came from at most the last 4 tokens, so the loop above always returns.
unreachable!();
}
}
(0, 0)
}
/// Returns `true` if some token lies strictly below the trie node for `start` and `r`
/// accepts all of that token's bytes after `start`.
///
/// The bytes of `start` itself are not pushed into `r`. Returns `false` without touching
/// `r` if `start` is not a path in the trie.
#[inline(never)]
pub fn has_valid_extensions(&self, r: &mut impl Recognizer, start: &[u8]) -> bool {
let n = self.child_at_bytes(self.root(), start);
if n.is_none() {
return false;
}
let n = n.unwrap();
r.trie_started("has_valid_extensions");
let off = self.node_offset(n);
let mut p = off + 1;
let endp = off + n.subtree_size();
let mut ok = false;
// Bytes to pop from `r` before visiting the next node of the pre-order walk.
let mut next_pop = 0;
while p < endp {
r.pop_bytes(next_pop);
let n = &self.nodes[p];
let b = n.byte();
if r.try_push_byte(b) {
if n.token_id().is_some() {
ok = true;
// Pushed bytes are not popped here. NOTE(review): presumably `trie_finished`
// restores the recognizer state; confirm with Recognizer implementors.
break;
}
// Leaf: unwind `num_parents` levels next; inner node: descend (pop nothing).
next_pop = if n.subtree_size() == 1 {
n.num_parents()
} else {
0
};
p += 1;
} else {
// Rejected: skip the whole subtree. This node's byte was not pushed, hence `- 1`.
p += n.subtree_size();
next_pop = n.num_parents() - 1;
}
}
r.trie_finished();
ok
}
/// Every token that is a prefix of `bytes`, shortest first.
pub fn all_prefixes(&self, bytes: &[u8]) -> Vec<TokenId> {
    let mut found = Vec::new();
    let mut node = self.root();
    for &byte in bytes {
        let Some(child) = self.child_at_byte(node, byte) else {
            break;
        };
        node = child;
        found.extend(node.token_id());
    }
    found
}
/// Marks in `toks` the tokens allowed by recognizer `r` after the byte prefix `start`.
///
/// If `start` is non-empty, tokens whose bytes are a non-empty prefix of `start` are
/// allowed first (walk with a `FixedRecognizer` over `start`). Then the subtree below the
/// trie node for `start` is walked with `r`, allowing each token whose bytes after `start`
/// `r` accepts. If `start` is not a path in the trie, nothing more is done. At the end the
/// scratch bit at index `vocab_size` (set by `add_bias_inner` for token-less nodes) is cleared.
pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob, start: &[u8]) {
if !start.is_empty() {
// Allow tokens that are prefixes of `start`.
let mut fixed = FixedRecognizer::new(start);
self.add_bias(&mut fixed, toks, &[]);
}
let n = self.child_at_bytes(self.root(), start);
if n.is_none() {
return;
}
let n = n.unwrap();
r.trie_started("add_bias");
let (next_pop, nodes_walked) = self.add_bias_inner(r, toks, n);
if start.is_empty() {
// Unwind the levels left pushed by the final node of the walk. NOTE(review): this is
// skipped for a non-empty `start`; presumably `trie_finished` handles that case.
r.pop_bytes(next_pop);
}
r.trie_finished();
r.save_stats(nodes_walked);
let defl_tok = self.vocab_size() as u32;
toks.disallow_token(defl_tok);
}
/// Hot loop of `add_bias`: walks the subtree of `n` (excluding `n`) in pre-order, pushing
/// each node's byte into `r`.
///
/// For every accepted node, its token is allowed in `toks`; token-less nodes set the
/// scratch index `vocab_size` instead. Rejected nodes have their whole subtree skipped.
/// Returns `(next_pop, nodes_walked)`. `next_pop` is the number of bytes still to pop after
/// the last node. `nodes_walked` is the subtree size minus the descendants of rejected nodes.
#[inline(never)]
fn add_bias_inner(
&self,
r: &mut impl Recognizer,
toks: &mut SimpleVob,
n: &TrieNode,
) -> (usize, usize) {
let defl_tok = self.vocab_size() as u32;
let off = self.node_offset(n);
let total_nodes = n.subtree_size();
let mut p = off + 1;
let endp = off + total_nodes;
// Slicing panics if `endp` is out of range, so afterwards `nodes.len() == endp`.
let nodes = &self.nodes[..endp];
let mut next_pop = 0;
let mut num_skip = 0;
while p < endp {
r.pop_bytes(next_pop);
// SAFETY: the loop condition guarantees `p < endp`, and `nodes.len() == endp` by
// construction of the slice above, so `p` is in bounds.
let n = unsafe {
debug_assert!(
p < nodes.len(),
"node index {} out of bounds (len: {})",
p,
nodes.len()
);
nodes.get_unchecked(p)
};
let b = n.byte();
if r.try_push_byte(b) {
let tok = n.token_id().unwrap_or(defl_tok);
debug_assert!(
tok <= self.vocab_size() as u32,
"token {} out of valid range (vocab_size: {})",
tok,
self.vocab_size()
);
// SAFETY: `tok` is either a trie token id (checked `< vocab_size` by `validate` at
// construction) or `defl_tok == vocab_size`. NOTE(review): this assumes `toks` has
// room for index `vocab_size`, as sets from `alloc_token_set` do (capacity
// `vocab_size + 1`). Confirm against SimpleVob's `allow_token_unchecked` contract.
unsafe { toks.allow_token_unchecked(tok) };
// Leaf: unwind `num_parents` levels before the next node; inner node: descend.
next_pop = if n.subtree_size() == 1 {
n.num_parents()
} else {
0
};
p += 1;
} else {
// Rejected: skip the subtree. This node's byte was not pushed, hence `- 1`.
let subtree_size = n.subtree_size();
p += subtree_size;
num_skip += subtree_size - 1;
next_pop = n.num_parents() - 1;
}
}
(next_pop, total_nodes - num_skip)
}
/// Bytes of every token, indexed by token id.
pub fn all_tokens(&self) -> Vec<Vec<u8>> {
    let mut out = Vec::with_capacity(self.vocab_size());
    for idx in 0..self.vocab_size() {
        out.push(self.token(idx as u32).to_vec());
    }
    out
}

/// All tokens reachable in the trie with their bytes, in trie (pre-order) order.
pub fn sorted_tokens(&self) -> Vec<(u32, Vec<u8>)> {
    let root = self.root();
    let first = self.node_offset(root) + 1;
    let end = self.next_node(root);
    let mut out = vec![];
    // Bytes on the path from the root to the current node.
    let mut path: Vec<u8> = vec![];
    let mut pending_pops = 0;
    for node in &self.nodes[first..end] {
        path.truncate(path.len() - pending_pops);
        path.push(node.byte());
        if let Some(tok) = node.token_id() {
            out.push((tok, path.clone()));
        }
        pending_pops = if node.subtree_size() == 1 {
            node.num_parents()
        } else {
            0
        };
    }
    out
}
/// Counts nodes, and token-carrying nodes, at depths `1..=depth` below the root.
/// Returns `(nodes, token_nodes)`.
fn count_until_depth(&self, depth: usize) -> (usize, usize) {
    let (mut nodes, mut tokens) = (0, 0);
    let mut stack = vec![(self.root(), 0)];
    while let Some((node, level)) = stack.pop() {
        if level == depth {
            continue;
        }
        for child in self.node_children(node) {
            nodes += 1;
            tokens += usize::from(child.token_id().is_some());
            stack.push((child, level + 1));
        }
    }
    (nodes, tokens)
}
/// One-line summary: node count, token-node count, token byte total and max token length.
///
/// Extra diagnostics can be enabled by flipping the `SHOW_*` flags below. All are off, so
/// they add nothing to the output.
pub fn trie_stats(&self) -> String {
    const SHOW_CHILD_HISTOGRAM: bool = false;
    const SHOW_ROOT_CHILDREN: bool = false;
    const SHOW_DEPTH_COUNTS: bool = false;

    let root = self.root();
    let start = self.node_offset(root) + 1;
    let end = self.node_offset(root) + root.subtree_size();

    // children_hist[k]: number of nodes with k children (k capped at 9).
    let mut children_hist = vec![0; 256];
    let mut token_nodes = 0;
    for p in start..end {
        let node = &self.nodes[p];
        if node.token_id().is_some() {
            token_nodes += 1;
        }
        let num_children = self.node_children(node).count();
        children_hist[num_children.min(9)] += 1;
    }

    let mut extra = String::new();
    if SHOW_CHILD_HISTOGRAM {
        let parts: Vec<String> = children_hist
            .iter()
            .enumerate()
            .filter(|&(_, &count)| count > 0)
            .map(|(idx, count)| format!("{idx}:{count}"))
            .collect();
        extra.push_str(&parts.join(", "));
    }
    if SHOW_ROOT_CHILDREN {
        for child in self.node_children(root) {
            extra.push_str(&format!(
                "\n{} => {} {}",
                child.byte(),
                self.node_children(child).count(),
                child.subtree_size()
            ));
        }
    }
    if SHOW_DEPTH_COUNTS {
        for depth in 0..30 {
            let (count, num_tokens) = self.count_until_depth(depth);
            extra.push_str(&format!(
                "\ndepth {depth}: {count} nodes {num_tokens} tokens"
            ));
        }
    }

    let prefix = if extra.is_empty() {
        String::new()
    } else {
        format!("\n{extra}")
    };
    format!(
        "{}{} nodes, {} token nodes, {} token bytes, {} max len",
        prefix,
        self.nodes.len(),
        token_nodes,
        self.token_data.len(),
        self.max_token_len,
    )
}
}
/// Iterator over the direct children of a trie node, in serialized order.
pub struct NodeChildren<'a> {
    /// Trie whose `nodes` are iterated.
    trie: &'a TokTrie,
    /// Index of the next child to yield.
    current_offset: usize,
    /// One past the end of the parent's subtree.
    end_offset: usize,
}

impl<'a> Iterator for NodeChildren<'a> {
    type Item = &'a TrieNode;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current_offset >= self.end_offset {
            return None;
        }
        let child = &self.trie.nodes[self.current_offset];
        // Children sit back to back; jump over this child's whole subtree.
        self.current_offset += child.subtree_size();
        Some(child)
    }
}
/// Sentinel meaning "no node" in `BuilderNode` links and `TrieBuilder::root_children`.
const NO_NODE: u32 = 0xffff_ffff;
/// Mutable trie node used during construction; children form a singly linked sibling list.
struct BuilderNode {
/// Token ending at this node, or `NO_TOKEN`.
token_id: u32,
/// Edge byte leading to this node.
byte: u8,
/// Index of the first child in `TrieBuilder::nodes`, or `NO_NODE`.
first_child: u32,
/// Index of the next sibling, or `NO_NODE`.
next_sibling: u32,
/// Index of the last child (for O(1) appends), or `NO_NODE`.
last_child: u32,
}
/// Incremental trie builder; node 0 is the root.
struct TrieBuilder {
nodes: Vec<BuilderNode>,
/// Per-byte shortcut to the first root child with that byte, or `NO_NODE`.
root_children: [u32; 256],
}
impl TrieBuilder {
/// Creates a builder holding only the root node (byte `root_byte`, no token).
/// `vocab_size` only sizes the initial allocation (3 nodes per token, at least 1024).
fn new(root_byte: u8, vocab_size: u32) -> TrieBuilder {
let estimated_nodes = (vocab_size as usize).saturating_mul(3).max(1024);
let mut builder = TrieBuilder {
nodes: Vec::with_capacity(estimated_nodes),
root_children: [NO_NODE; 256],
};
builder.nodes.push(BuilderNode {
token_id: NO_TOKEN,
byte: root_byte,
first_child: NO_NODE,
next_sibling: NO_NODE,
last_child: NO_NODE,
});
builder
}
/// Inserts `word` as token `token_id`.
///
/// Existing nodes along the path are reused. When the node for the final byte already
/// carries a token (duplicate byte strings), a new sibling with the same byte is created,
/// so every token gets its own node. New children are appended after existing siblings.
/// An empty `word` assigns the token to the root, which must not already carry one.
fn insert(&mut self, word: &[u8], token_id: u32) {
if word.is_empty() {
assert!(self.nodes[0].token_id == NO_TOKEN);
self.nodes[0].token_id = token_id;
return;
}
let mut curr_node_idx = 0;
for (i, &byte) in word.iter().enumerate() {
let is_last_byte = i == word.len() - 1;
let mut found_existing_path = false;
if curr_node_idx == 0 {
// At the root: O(1) lookup of the first root child with this byte.
let root_child_idx = self.root_children[byte as usize];
if root_child_idx != NO_NODE {
let child_node = &self.nodes[root_child_idx as usize];
if is_last_byte && child_node.token_id != NO_TOKEN {
// Node already owns a token: fall through and create a new sibling.
} else {
curr_node_idx = root_child_idx as usize;
found_existing_path = true;
}
}
} else {
// Below the root: linear scan of the sibling list.
let mut child_idx = self.nodes[curr_node_idx].first_child;
while child_idx != NO_NODE {
let child_node = &self.nodes[child_idx as usize];
if child_node.byte == byte {
if is_last_byte && child_node.token_id != NO_TOKEN {
// Node already owns a token: keep scanning the remaining siblings.
} else {
curr_node_idx = child_idx as usize;
found_existing_path = true;
break;
}
}
child_idx = child_node.next_sibling;
}
}
if !found_existing_path {
let new_node_idx = self.nodes.len() as u32;
self.nodes.push(BuilderNode {
token_id: NO_TOKEN,
byte,
first_child: NO_NODE,
next_sibling: NO_NODE,
last_child: NO_NODE,
});
// Only the first root child per byte is cached.
if curr_node_idx == 0 && self.root_children[byte as usize] == NO_NODE {
self.root_children[byte as usize] = new_node_idx;
}
// Append the new node as the last child of the current node.
let last_child_idx = self.nodes[curr_node_idx].last_child;
if last_child_idx == NO_NODE {
self.nodes[curr_node_idx].first_child = new_node_idx;
} else {
self.nodes[last_child_idx as usize].next_sibling = new_node_idx;
}
self.nodes[curr_node_idx].last_child = new_node_idx;
curr_node_idx = new_node_idx as usize;
}
}
self.nodes[curr_node_idx].token_id = token_id;
}
/// Appends the subtree rooted at `node_idx` to `data` in pre-order, then records its size.
///
/// The node stores `num_parents` (0 is stored as 1). Its last child receives
/// `num_parents + 1` and every other child receives 1. Trie walkers use this as the number
/// of levels to pop after leaving a leaf.
fn serialize_node(&self, node_idx: usize, data: &mut Vec<TrieNode>, num_parents: usize) {
let node = &self.nodes[node_idx];
let idx = data.len();
// Count children first so the last one can be recognized below.
let mut num_ch = 0;
let mut child = node.first_child;
while child != NO_NODE {
num_ch += 1;
child = self.nodes[child as usize].next_sibling;
}
data.push(TrieNode::new(
node.byte,
node.token_id,
if num_parents == 0 { 1 } else { num_parents },
));
let mut child = node.first_child;
while child != NO_NODE {
num_ch -= 1;
self.serialize_node(
child as usize,
data,
if num_ch == 0 { num_parents + 1 } else { 1 },
);
child = self.nodes[child as usize].next_sibling;
}
// Subtree size includes this node and everything serialized after it.
let subtree_size = data.len() - idx;
data[idx].set_subtree_size(subtree_size);
}
/// Serializes the whole trie, starting at the root (node 0), into `data`.
/// Takes `&mut self` although it only reads the builder.
fn serialize(&mut self, data: &mut Vec<TrieNode>, num_parents: usize) {
self.serialize_node(0, data, num_parents);
}
}
/// Recognizer that accepts only prefixes of a fixed byte sequence, consumed in order.
struct FixedRecognizer {
    /// The sequence to accept.
    bytes: Vec<u8>,
    /// Number of bytes of `bytes` accepted so far.
    bytes_ptr: usize,
}

impl FixedRecognizer {
    fn new(bytes: &[u8]) -> FixedRecognizer {
        Self {
            bytes: Vec::from(bytes),
            bytes_ptr: 0,
        }
    }
}

impl Recognizer for FixedRecognizer {
    fn try_push_byte(&mut self, byte: u8) -> bool {
        let matches = self.bytes.get(self.bytes_ptr) == Some(&byte);
        if matches {
            self.bytes_ptr += 1;
        }
        matches
    }

    fn pop_bytes(&mut self, num: usize) {
        self.bytes_ptr -= num;
    }

    fn collapse(&mut self) {}

    fn trie_finished(&mut self) {}
}
/// Stateless recognizer that accepts every byte.
pub struct AnythingGoes;

impl Recognizer for AnythingGoes {
    fn try_push_byte(&mut self, _byte: u8) -> bool {
        true
    }

    fn pop_bytes(&mut self, _num: usize) {}

    fn trie_finished(&mut self) {}

    fn collapse(&mut self) {}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Four single-byte tokens `a`, `b`, `c`, `d` (ids 0..=3) with the given EOS.
    fn make_test_trie(eos: TokenId) -> TokTrie {
        let info = TokRxInfo::new(4, eos);
        let words = vec![b"a".to_vec(), b"b".to_vec(), b"c".to_vec(), b"d".to_vec()];
        TokTrie::from(&info, &words)
    }

    #[test]
    fn test_default_single_eos() {
        let trie = make_test_trie(2);
        assert_eq!(trie.eos_token(), 2);
        assert_eq!(trie.eos_tokens(), &[2]);
    }

    #[test]
    fn test_with_eos_tokens_multiple() {
        let trie = make_test_trie(0).with_eos_tokens(&[1, 3]);
        assert_eq!(trie.eos_token(), 1);
        assert_eq!(trie.eos_tokens(), &[1, 3]);
        assert_eq!(trie.info().tok_eos, 1);
    }

    #[test]
    fn test_with_eos_token_backwards_compat() {
        let trie = make_test_trie(0).with_eos_token(2);
        assert_eq!(trie.eos_token(), 2);
        assert_eq!(trie.eos_tokens(), &[2]);
    }

    #[test]
    fn test_with_info_resets_eos_tokens() {
        let trie = make_test_trie(0).with_eos_tokens(&[1, 2]);
        let trie2 = trie.with_info(TokRxInfo::new(4, 3));
        assert_eq!(trie2.eos_token(), 3);
        assert_eq!(trie2.eos_tokens(), &[3]);
    }

    #[test]
    fn test_filter_preserves_eos_tokens() {
        let trie = make_test_trie(0).with_eos_tokens(&[1, 2]);
        let mut filter = trie.alloc_token_set();
        for i in 0..4 {
            filter.allow_token(i);
        }
        let filtered = trie.filter(&filter);
        assert_eq!(filtered.eos_tokens(), &[1, 2]);
    }

    #[test]
    fn test_filter_drops_disallowed_tokens() {
        let trie = make_test_trie(3);
        let mut filter = trie.alloc_token_set();
        filter.allow_token(0);
        filter.allow_token(2);
        let filtered = trie.filter(&filter);
        assert_eq!(filtered.vocab_size(), 4);
        assert_eq!(filtered.token(1), b"");
        assert_eq!(filtered.token(2), b"c");
        assert_eq!(filtered.token_id(b"b"), None);
    }

    #[test]
    #[should_panic(expected = "eos_tokens must not be empty")]
    fn test_with_eos_tokens_empty_panics() {
        make_test_trie(0).with_eos_tokens(&[]);
    }

    #[test]
    fn test_eos_token_set_single() {
        let trie = make_test_trie(2);
        let set = trie.eos_token_set();
        assert!(set.is_allowed(2));
        assert!(!set.is_allowed(0));
        assert!(!set.is_allowed(1));
        assert_eq!(set.num_set(), 1);
    }

    #[test]
    fn test_eos_token_set_multiple() {
        let trie = make_test_trie(0).with_eos_tokens(&[1, 3]);
        let set = trie.eos_token_set();
        assert!(set.is_allowed(1));
        assert!(set.is_allowed(3));
        assert!(!set.is_allowed(0));
        assert!(!set.is_allowed(2));
        assert_eq!(set.num_set(), 2);
    }

    #[test]
    fn test_token_lookup() {
        let trie = make_test_trie(3);
        let words = vec![b"a".to_vec(), b"b".to_vec(), b"c".to_vec(), b"d".to_vec()];
        trie.check_against(&words);
        assert_eq!(trie.token(1), b"b");
        assert_eq!(trie.token(99), b"");
        assert_eq!(trie.token_id(b"c"), Some(2));
        assert_eq!(trie.token_id(b"ab"), None);
        assert_eq!(trie.prefix_token_id(b"ab"), (0, 1));
        assert_eq!(trie.token_id_at_bytes(b"d"), Some(3));
    }

    #[test]
    fn test_decode_variants() {
        let trie = make_test_trie(3);
        assert_eq!(trie.decode(&[0, 1]), b"ab".to_vec());
        assert_eq!(trie.decode_ext(&[7], true), b"<[7]>".to_vec());
        assert!(trie.decode_ext(&[7], false).is_empty());
        assert_eq!(trie.decode_raw(&[0, 7]), b"a\xff[7]".to_vec());
        assert_eq!(trie.decode_str(&[2, 1]), "cb");
    }

    #[test]
    fn test_greedy_tokenize_skips_unknown_bytes() {
        let trie = make_test_trie(3);
        assert_eq!(trie.greedy_tokenize(b"abxd"), vec![0, 1, 3]);
    }

    #[test]
    fn test_token_len_and_dbg() {
        let trie = make_test_trie(3);
        assert_eq!(trie.token_len(0), 1);
        // Out-of-range id renders raw as "\xff[12]": 2 digits + 3 bytes.
        assert_eq!(trie.token_len(12), 5);
        assert_eq!(trie.token_dbg(0), "⟨a⟩");
        assert_eq!(trie.token_dbg(3), "≺EOS≻");
        assert_eq!(trie.token_dbg(9), "≺OOB[9]≻");
        assert_eq!(trie.tokens_dbg(&[0, 1]), "⟦a‧b⟧");
    }

    #[test]
    fn test_sorted_tokens_and_prefixes() {
        let trie = make_test_trie(3);
        assert_eq!(
            trie.sorted_tokens(),
            vec![
                (0, b"a".to_vec()),
                (1, b"b".to_vec()),
                (2, b"c".to_vec()),
                (3, b"d".to_vec()),
            ]
        );
        assert_eq!(trie.all_prefixes(b"ab"), vec![0]);
        assert_eq!(trie.all_subtokens(b"ab"), vec![0, 1]);
        assert!(!trie.has_extensions(b"a"));
    }
}