use std::sync::Arc;
use crate::{TokRxInfo, TokTrie, TokenId};
/// Tokenizer environment: a token trie plus byte-level tokenization.
///
/// Implementors must supply [`TokenizerEnv::tok_trie`] and
/// [`TokenizerEnv::tokenize_bytes`]; every other method has a default
/// implementation built on those two.
pub trait TokenizerEnv: Send {
    /// The token trie describing this environment's vocabulary.
    fn tok_trie(&self) -> &TokTrie;

    /// Tokenize raw bytes.
    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId>;

    /// Tokenize `s`, interpreting each [`TokTrie::SPECIAL_TOKEN_MARKER`] byte
    /// as an escape that introduces an explicit token.
    ///
    /// Runs of bytes between markers are passed to
    /// [`TokenizerEnv::tokenize_bytes`]. After a marker byte:
    /// - `<name>` (at least three bytes must follow the marker, and the `>`
    ///   must lie within 100 bytes starting at the `<`) becomes the token whose
    ///   bytes are the marker followed by `<name>`, if the trie has one;
    /// - `[N]` (see [`parse_numeric_token`]) becomes token `N` when `N` is
    ///   below the trie's vocabulary size.
    /// If neither escape applies, the marker byte is dropped and the bytes
    /// after it are tokenized as ordinary text.
    ///
    /// Returns the tokens and the length of the prefix that ends at the last
    /// "fixed" token: either an escape resolved above, or a token from plain
    /// tokenization for which `is_special_token` holds. This is `0` if there
    /// is no such token.
    fn tokenize_bytes_marker(&self, s: &[u8]) -> (Vec<TokenId>, usize) {
        let marker = TokTrie::SPECIAL_TOKEN_MARKER;
        let trie = self.tok_trie();
        let mut tokens: Vec<TokenId> = Vec::new();
        let mut fixed = 0;
        let mut pos = 0;

        while pos < s.len() {
            // Ordinary text runs up to (not including) the next marker, or to the end.
            let plain_end = match s[pos..].iter().position(|&b| b == marker) {
                Some(offset) => pos + offset,
                None => s.len(),
            };
            if plain_end > pos {
                let chunk = self.tokenize_bytes(&s[pos..plain_end]);
                let base = tokens.len();
                // Only the last special token in the chunk determines the fixed prefix.
                if let Some(last) = chunk.iter().rposition(|&t| trie.is_special_token(t)) {
                    fixed = base + last + 1;
                }
                tokens.extend_from_slice(&chunk);
                pos = plain_end;
            }

            // Step over the marker byte (or past the end when no marker was found).
            pos += 1;

            if pos + 2 < s.len() && s[pos] == b'<' {
                let window = &s[pos..s.len().min(pos + 100)];
                if let Some(gt) = window.iter().position(|&b| b == b'>') {
                    let end = pos + gt + 1;
                    // The trie spelling of the token includes the marker at `pos - 1`.
                    if let Some(id) = trie.token_id_at_bytes(&s[pos - 1..end]) {
                        tokens.push(id);
                        fixed = tokens.len();
                        pos = end;
                    }
                }
            } else if pos < s.len() {
                if let Some((consumed, id)) = parse_numeric_token(&s[pos..]) {
                    if id < trie.vocab_size() as u32 {
                        tokens.push(id);
                        fixed = tokens.len();
                        pos += consumed;
                    }
                }
            }
        }

        (tokens, fixed)
    }

    /// Tokenize a string's UTF-8 bytes via [`TokenizerEnv::tokenize_bytes`].
    fn tokenize(&self, s: &str) -> Vec<TokenId> {
        let bytes = s.as_bytes();
        self.tokenize_bytes(bytes)
    }

    /// Tokenize a string's UTF-8 bytes via [`TokenizerEnv::tokenize_bytes_special`].
    fn tokenize_special(&self, s: &str) -> Vec<TokenId> {
        let bytes = s.as_bytes();
        self.tokenize_bytes_special(bytes)
    }

    /// Tokenize bytes with special-token handling.
    ///
    /// The default does nothing special: it simply forwards to
    /// [`TokenizerEnv::tokenize_bytes`]. Implementors may override it.
    fn tokenize_bytes_special(&self, s: &[u8]) -> Vec<TokenId> {
        self.tokenize_bytes(s)
    }

    /// EOS token id; defaults to the one recorded in the trie.
    fn eos_token(&self) -> TokenId {
        let trie = self.tok_trie();
        trie.eos_token()
    }

    /// Whether this environment's tokenization is canonical. Defaults to `true`.
    ///
    /// NOTE(review): "canonical" is presumably "matches the reference
    /// tokenizer's output" — confirm against consumers of this flag.
    fn tokenize_is_canonical(&self) -> bool {
        true
    }
}
/// Shared, reference-counted handle to a tokenizer environment.
///
/// The explicit `Sync` bound, together with the `Send` supertrait of
/// [`TokenizerEnv`], allows the handle to be shared across threads.
pub type TokEnv = Arc<dyn TokenizerEnv + Sync + 'static>;

/// A [`TokenizerEnv`] that reports a caller-supplied [`TokTrie`] while
/// delegating all tokenization to a wrapped environment.
pub struct TokEnvWithTrie {
    base_env: TokEnv,
    tok_trie: TokTrie,
}

impl TokEnvWithTrie {
    /// Wrap `base_env` so that `tok_trie()` returns `tok_trie`.
    ///
    /// NOTE(review): token ids produced by `base_env` are passed through
    /// unchanged, so `tok_trie` is assumed to share `base_env`'s vocabulary —
    /// confirm with callers.
    pub fn new(base_env: TokEnv, tok_trie: TokTrie) -> Self {
        Self { base_env, tok_trie }
    }
}

impl TokenizerEnv for TokEnvWithTrie {
    fn tok_trie(&self) -> &TokTrie {
        &self.tok_trie
    }

    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
        self.base_env.tokenize_bytes(s)
    }

    // Forward to the wrapped env. The trait default would call plain
    // `tokenize_bytes` and silently bypass any special-token handling that
    // `base_env` implements.
    fn tokenize_bytes_special(&self, s: &[u8]) -> Vec<TokenId> {
        self.base_env.tokenize_bytes_special(s)
    }

    // Tokenization is performed by `base_env`, so its canonicity is what
    // holds. The trait default (`true`) would misreport a non-canonical base.
    fn tokenize_is_canonical(&self) -> bool {
        self.base_env.tokenize_is_canonical()
    }
}
/// Parse a numeric token reference of the form `[1234]` at the start of `s`.
///
/// The first `]` must occur within the first 20 bytes of `s`, `s` must start
/// with `[`, and the bytes between the brackets must parse as a `u32` via
/// `str::parse`.
///
/// Returns `Some((consumed, id))`, where `consumed` counts both brackets.
/// Returns `None` if `s` does not start with such a reference.
pub fn parse_numeric_token(s: &[u8]) -> Option<(usize, TokenId)> {
    let window = &s[..s.len().min(20)];
    let close = window.iter().position(|&b| b == b']')?;
    if s.first() != Some(&b'[') {
        return None;
    }
    let digits = std::str::from_utf8(&s[1..close]).ok()?;
    let id: u32 = digits.parse().ok()?;
    Some((close + 1, id as TokenId))
}
/// A [`TokenizerEnv`] backed solely by its own [`TokTrie`].
pub struct ApproximateTokEnv {
    trie: TokTrie,
    // Value reported by `tokenize_is_canonical`.
    canonical: bool,
}

impl ApproximateTokEnv {
    /// Wrap `trie`. The resulting env reports `tokenize_is_canonical() == false`.
    pub fn new(trie: TokTrie) -> Self {
        Self {
            canonical: false,
            trie,
        }
    }

    /// Build an env over a 262-entry vocabulary, in this order:
    /// - the 256 single bytes `0..=255`;
    /// - six `\xFF`-prefixed special spellings: `<|tool|>`, `<|/tool|>`,
    ///   `<|user|>`, `<|system|>`, `<|assistant|>` and `<|end|>`.
    ///
    /// `tok_eos` is set to the last index (`<|end|>`). The env reports
    /// `tokenize_is_canonical() == true`.
    pub fn single_byte() -> Self {
        let specials: [&[u8]; 6] = [
            b"\xFF<|tool|>",
            b"\xFF<|/tool|>",
            b"\xFF<|user|>",
            b"\xFF<|system|>",
            b"\xFF<|assistant|>",
            b"\xFF<|end|>",
        ];
        let mut words: Vec<Vec<u8>> = (0..=u8::MAX).map(|byte| vec![byte]).collect();
        words.extend(specials.iter().map(|word| word.to_vec()));

        let vocab_size = words.len() as u32;
        let info = TokRxInfo {
            vocab_size,
            tok_eos: vocab_size - 1,
            tok_bos: None,
            tok_pad: None,
            tok_unk: None,
            tok_end_of_turn: None,
        };
        Self {
            trie: TokTrie::from(&info, &words),
            canonical: true,
        }
    }

    /// [`ApproximateTokEnv::single_byte`], wrapped in a shared [`TokEnv`] handle.
    pub fn single_byte_env() -> TokEnv {
        let env = Self::single_byte();
        Arc::new(env)
    }
}

impl TokenizerEnv for ApproximateTokEnv {
    fn tok_trie(&self) -> &TokTrie {
        &self.trie
    }

    /// Delegates to `TokTrie::greedy_tokenize` on this env's trie.
    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
        let trie = &self.trie;
        trie.greedy_tokenize(s)
    }

    fn tokenize_is_canonical(&self) -> bool {
        self.canonical
    }
}