use aho_corasick::AhoCorasick;
use lru::LruCache;
#[cfg(feature = "pcre2")]
use pcre2::bytes::Regex as Pcre2Regex;
#[cfg(feature = "rayon")]
use rayon::prelude::*;
use regexr::{Regex as RegexrRegex, RegexBuilder};
use rustc_hash::{FxHashMap, FxHasher};
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::sync::Mutex;
use thiserror::Error;

use super::bpe::byte_pair_encode;
use super::byte_level::{byte_level_decode_bytes, byte_level_encode};
use super::vocab::{build_decoder, load_tiktoken_bpe, load_tiktoken_bpe_file, VocabError};
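/// Errors produced while building or using a [`Tokenizer`].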
#[derive(Error, Debug)]
pub enum TokenizerError {
#[error("Regex compilation error (regexr): {0}")]
RegexrError(#[from] regexr::Error),
#[cfg(feature = "pcre2")]
#[error("Regex compilation error (PCRE2): {0}")]
Pcre2Error(#[from] pcre2::Error),
#[error("Vocabulary error: {0}")]
VocabError(#[from] VocabError),
#[error("Decoding error: invalid UTF-8")]
Utf8Error,
#[error("Aho-Corasick build error: {0}")]
AhoCorasickError(#[from] aho_corasick::BuildError),
#[error("PCRE2 feature not enabled. Compile with --features pcre2")]
Pcre2NotEnabled,
#[error("Unknown pretrained model: {0}")]
UnknownPretrained(String),
}
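/// Pre-tokenization split pattern used by OpenAI's cl100k_base encoding (GPT-3.5/GPT-4).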
pub const CL100K_BASE_PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
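/// Pre-tokenization split pattern used by OpenAI's o200k_base encoding (GPT-4o).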
pub const O200K_BASE_PATTERN: &str = r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+";
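/// Llama 3 uses the same pre-tokenization pattern as cl100k_base.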
pub const LLAMA3_PATTERN: &str = CL100K_BASE_PATTERN;
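/// Minimal split for SentencePiece-style models: alternating runs of non-whitespace and whitespace.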
pub const SENTENCEPIECE_PATTERN: &str = r"[^\s]+|\s+";
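/// Split pattern for Mistral's v3 tokenizer; unlike cl100k/o200k it splits digits individually rather than in groups of up to three.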
pub const MISTRAL_V3_PATTERN: &str = r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+";
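/// Agent/control special-token IDs layered on top of cl100k_base, starting just above
/// its last reserved special token (`<|endofprompt|>` = 100276).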
pub mod cl100k_agent_tokens {
pub const SYSTEM: u32 = 100277;
pub const USER: u32 = 100278;
pub const ASSISTANT: u32 = 100279;
pub const IM_START: u32 = 100280;
pub const IM_END: u32 = 100281;
pub const THINK: u32 = 100282;
pub const THINK_END: u32 = 100283;
pub const PLAN: u32 = 100284;
pub const PLAN_END: u32 = 100285;
pub const STEP: u32 = 100286;
pub const STEP_END: u32 = 100287;
pub const ACT: u32 = 100288;
pub const ACT_END: u32 = 100289;
pub const OBSERVE: u32 = 100290;
pub const OBSERVE_END: u32 = 100291;
pub const FUNCTION: u32 = 100292;
pub const FUNCTION_END: u32 = 100293;
pub const RESULT: u32 = 100294;
pub const RESULT_END: u32 = 100295;
pub const ERROR: u32 = 100296;
pub const ERROR_END: u32 = 100297;
pub const CODE: u32 = 100298;
pub const CODE_END: u32 = 100299;
pub const OUTPUT: u32 = 100300;
pub const OUTPUT_END: u32 = 100301;
pub const LANG: u32 = 100302;
pub const LANG_END: u32 = 100303;
pub const CONTEXT: u32 = 100304;
pub const CONTEXT_END: u32 = 100305;
pub const QUOTE: u32 = 100306;
pub const QUOTE_END: u32 = 100307;
pub const CITE: u32 = 100308;
pub const CITE_END: u32 = 100309;
pub const SOURCE: u32 = 100310;
pub const SOURCE_END: u32 = 100311;
pub const MEMORY: u32 = 100312;
pub const MEMORY_END: u32 = 100313;
pub const RECALL: u32 = 100314;
pub const RECALL_END: u32 = 100315;
pub const PAD: u32 = 100316;
pub const STOP: u32 = 100317;
pub const SEP: u32 = 100318;
pub const IMAGE: u32 = 100319;
pub const IMAGE_END: u32 = 100320;
pub const AUDIO: u32 = 100321;
pub const AUDIO_END: u32 = 100322;
pub const VIDEO: u32 = 100323;
pub const VIDEO_END: u32 = 100324;
pub const TITLE: u32 = 100325;
pub const TITLE_END: u32 = 100326;
pub const SECTION: u32 = 100327;
pub const SECTION_END: u32 = 100328;
pub const SUMMARY: u32 = 100329;
pub const SUMMARY_END: u32 = 100330;
}
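/// Agent/control special-token IDs layered on top of o200k_base, starting just above
/// its last reserved special token (`<|endofprompt|>` = 200018) and mirroring
/// [`cl100k_agent_tokens`].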
pub mod o200k_agent_tokens {
pub const SYSTEM: u32 = 200019;
pub const USER: u32 = 200020;
pub const ASSISTANT: u32 = 200021;
pub const IM_START: u32 = 200022;
pub const IM_END: u32 = 200023;
pub const THINK: u32 = 200024;
pub const THINK_END: u32 = 200025;
pub const PLAN: u32 = 200026;
pub const PLAN_END: u32 = 200027;
pub const STEP: u32 = 200028;
pub const STEP_END: u32 = 200029;
pub const ACT: u32 = 200030;
pub const ACT_END: u32 = 200031;
pub const OBSERVE: u32 = 200032;
pub const OBSERVE_END: u32 = 200033;
pub const FUNCTION: u32 = 200034;
pub const FUNCTION_END: u32 = 200035;
pub const RESULT: u32 = 200036;
pub const RESULT_END: u32 = 200037;
pub const ERROR: u32 = 200038;
pub const ERROR_END: u32 = 200039;
pub const CODE: u32 = 200040;
pub const CODE_END: u32 = 200041;
pub const OUTPUT: u32 = 200042;
pub const OUTPUT_END: u32 = 200043;
pub const LANG: u32 = 200044;
pub const LANG_END: u32 = 200045;
pub const CONTEXT: u32 = 200046;
pub const CONTEXT_END: u32 = 200047;
pub const QUOTE: u32 = 200048;
pub const QUOTE_END: u32 = 200049;
pub const CITE: u32 = 200050;
pub const CITE_END: u32 = 200051;
pub const SOURCE: u32 = 200052;
pub const SOURCE_END: u32 = 200053;
pub const MEMORY: u32 = 200054;
pub const MEMORY_END: u32 = 200055;
pub const RECALL: u32 = 200056;
pub const RECALL_END: u32 = 200057;
pub const PAD: u32 = 200058;
pub const STOP: u32 = 200059;
pub const SEP: u32 = 200060;
pub const IMAGE: u32 = 200061;
pub const IMAGE_END: u32 = 200062;
pub const AUDIO: u32 = 200063;
pub const AUDIO_END: u32 = 200064;
pub const VIDEO: u32 = 200065;
pub const VIDEO_END: u32 = 200066;
pub const TITLE: u32 = 200067;
pub const TITLE_END: u32 = 200068;
pub const SECTION: u32 = 200069;
pub const SECTION_END: u32 = 200070;
pub const SUMMARY: u32 = 200071;
pub const SUMMARY_END: u32 = 200072;
}
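/// Default capacity of the per-tokenizer LRU chunk cache.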
const DEFAULT_CACHE_SIZE: usize = 4096;
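/// The regex engine used for pre-tokenization: `regexr` by default, or PCRE2 when that feature is enabled.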
enum RegexBackend {
Regexr(Box<RegexrRegex>),
#[cfg(feature = "pcre2")]
Pcre2(Pcre2Regex),
}
impl RegexBackend {
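    /// Collects the byte-offset `(start, end)` spans of all non-overlapping matches.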
    fn find_iter(&self, text: &str) -> Vec<(usize, usize)> {
match self {
RegexBackend::Regexr(regex) => regex
.find_iter(text)
.map(|m| (m.start(), m.end()))
.collect(),
#[cfg(feature = "pcre2")]
RegexBackend::Pcre2(regex) => regex
.find_iter(text.as_bytes())
.filter_map(|m| m.ok())
.map(|m| (m.start(), m.end()))
.collect(),
}
}
}
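/// A tiktoken-style BPE tokenizer: a regex pre-tokenizer splits text into chunks,
/// each chunk is byte-pair encoded against a ranked vocabulary, special tokens are
/// matched with Aho-Corasick, and per-chunk encodings are memoized in an LRU cache.
/// Optional GPT-2 byte-level and SentencePiece preprocessing modes are supported.
///
/// A minimal usage sketch (illustrative only; the vocabulary path is a placeholder):
///
/// ```ignore
/// let mut special_tokens = FxHashMap::default();
/// special_tokens.insert("<|endoftext|>".to_string(), 100257);
/// let tok = Tokenizer::from_file("cl100k_base.tiktoken", CL100K_BASE_PATTERN, special_tokens)?;
/// let ids = tok.encode_with_special("Hello world<|endoftext|>");
/// assert_eq!(tok.decode(&ids)?, "Hello world<|endoftext|>");
/// ```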
pub struct Tokenizer {
encoder: FxHashMap<Vec<u8>, u32>,
decoder: FxHashMap<u32, Vec<u8>>,
special_tokens: FxHashMap<String, u32>,
special_tokens_decoder: FxHashMap<u32, String>,
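    /// Pattern list backing `special_matcher`; order must stay in sync with the automaton's pattern IDs.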
special_token_strings: Vec<String>,
regex: RegexBackend,
pattern: String,
special_matcher: Option<AhoCorasick>,
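    /// Per-chunk BPE results keyed by the FxHash of the chunk bytes. Keying by the
    /// 64-bit hash avoids an owned-key allocation per lookup, at the (theoretical)
    /// cost that a hash collision would return another chunk's tokens.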
chunk_cache: Mutex<LruCache<u64, Vec<u32>>>,
use_byte_level: bool,
use_sentencepiece: bool,
cache_size: usize,
use_jit: bool,
use_pcre2: bool,
}
impl Tokenizer {
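    /// Builds a tokenizer with the default cache size and no byte-level or SentencePiece preprocessing.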
pub fn new(
encoder: FxHashMap<Vec<u8>, u32>,
special_tokens: FxHashMap<String, u32>,
pattern: &str,
) -> Result<Self, TokenizerError> {
Self::with_options(encoder, special_tokens, pattern, DEFAULT_CACHE_SIZE, false)
}
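    /// Like [`Tokenizer::new`], but maps each chunk through the GPT-2 byte-level alphabet before BPE.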
pub fn new_byte_level(
encoder: FxHashMap<Vec<u8>, u32>,
special_tokens: FxHashMap<String, u32>,
pattern: &str,
) -> Result<Self, TokenizerError> {
Self::with_options(encoder, special_tokens, pattern, DEFAULT_CACHE_SIZE, true)
}
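    /// Like [`Tokenizer::new`], but with SentencePiece-style handling: on encode, runs
    /// of spaces become "▁" (U+2581) markers attached to the following word; on decode,
    /// "▁" is mapped back to a space.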
pub fn new_sentencepiece(
encoder: FxHashMap<Vec<u8>, u32>,
special_tokens: FxHashMap<String, u32>,
pattern: &str,
) -> Result<Self, TokenizerError> {
Self::with_full_options(
encoder,
special_tokens,
pattern,
DEFAULT_CACHE_SIZE,
false,
true,
)
}
pub fn with_cache_size(
encoder: FxHashMap<Vec<u8>, u32>,
special_tokens: FxHashMap<String, u32>,
pattern: &str,
cache_size: usize,
) -> Result<Self, TokenizerError> {
Self::with_options(encoder, special_tokens, pattern, cache_size, false)
}
pub fn with_options(
encoder: FxHashMap<Vec<u8>, u32>,
special_tokens: FxHashMap<String, u32>,
pattern: &str,
cache_size: usize,
use_byte_level: bool,
) -> Result<Self, TokenizerError> {
Self::with_full_options(
encoder,
special_tokens,
pattern,
cache_size,
use_byte_level,
false,
)
}
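    /// Fully parameterized constructor; the other `new*` and `with_*` constructors funnel into it.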
pub fn with_full_options(
encoder: FxHashMap<Vec<u8>, u32>,
special_tokens: FxHashMap<String, u32>,
pattern: &str,
cache_size: usize,
use_byte_level: bool,
use_sentencepiece: bool,
) -> Result<Self, TokenizerError> {
let decoder = build_decoder(&encoder);
let special_tokens_decoder: FxHashMap<u32, String> = special_tokens
.iter()
.map(|(k, v)| (*v, k.clone()))
.collect();
let regex = RegexBuilder::new(pattern).jit(true).build()?;
let special_token_strings: Vec<String> = special_tokens.keys().cloned().collect();
        let special_matcher = if special_token_strings.is_empty() {
            None
        } else {
            // Leftmost-longest matching so overlapping special tokens resolve to the
            // longest one rather than the first to finish.
            Some(
                AhoCorasick::builder()
                    .match_kind(aho_corasick::MatchKind::LeftmostLongest)
                    .build(&special_token_strings)?,
            )
        };
let cache_size_nz = NonZeroUsize::new(cache_size.max(1)).unwrap();
let chunk_cache = Mutex::new(LruCache::new(cache_size_nz));
Ok(Self {
encoder,
decoder,
special_tokens,
special_tokens_decoder,
special_token_strings,
regex: RegexBackend::Regexr(Box::new(regex)),
pattern: pattern.to_string(),
special_matcher,
chunk_cache,
use_byte_level,
use_sentencepiece,
cache_size,
use_jit: true,
use_pcre2: false,
})
}
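    /// Switches the pre-tokenization regex to PCRE2 (or back to `regexr`), recompiling the pattern.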
#[cfg(feature = "pcre2")]
pub fn pcre2(mut self, use_pcre2: bool) -> Result<Self, TokenizerError> {
self.use_pcre2 = use_pcre2;
if use_pcre2 {
let mut regex_builder = pcre2::bytes::RegexBuilder::new();
if self.use_jit {
regex_builder.jit_if_available(true);
}
regex_builder.utf(true);
regex_builder.ucp(true);
let regex = regex_builder.build(&self.pattern)?;
self.regex = RegexBackend::Pcre2(regex);
} else {
let regex = RegexBuilder::new(&self.pattern).jit(self.use_jit).build()?;
self.regex = RegexBackend::Regexr(Box::new(regex));
}
Ok(self)
}
#[cfg(not(feature = "pcre2"))]
pub fn pcre2(self, use_pcre2: bool) -> Result<Self, TokenizerError> {
if use_pcre2 {
Err(TokenizerError::Pcre2NotEnabled)
} else {
Ok(self)
}
}
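    /// Enables or disables JIT compilation for the active regex backend, recompiling the pattern.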
#[cfg(feature = "pcre2")]
pub fn jit(mut self, use_jit: bool) -> Result<Self, TokenizerError> {
self.use_jit = use_jit;
if self.use_pcre2 {
let mut regex_builder = pcre2::bytes::RegexBuilder::new();
if use_jit {
regex_builder.jit_if_available(true);
}
regex_builder.utf(true);
regex_builder.ucp(true);
let regex = regex_builder.build(&self.pattern)?;
self.regex = RegexBackend::Pcre2(regex);
} else {
let regex = RegexBuilder::new(&self.pattern).jit(use_jit).build()?;
self.regex = RegexBackend::Regexr(Box::new(regex));
}
Ok(self)
}
#[cfg(not(feature = "pcre2"))]
pub fn jit(mut self, use_jit: bool) -> Result<Self, TokenizerError> {
self.use_jit = use_jit;
let regex = RegexBuilder::new(&self.pattern).jit(use_jit).build()?;
self.regex = RegexBackend::Regexr(Box::new(regex));
Ok(self)
}
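    /// Loads a tiktoken-format vocabulary from disk and builds a tokenizer with default options.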
pub fn from_file(
vocab_path: &str,
pattern: &str,
special_tokens: FxHashMap<String, u32>,
) -> Result<Self, TokenizerError> {
let encoder = load_tiktoken_bpe_file(vocab_path)?;
Self::new(encoder, special_tokens, pattern)
}
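    /// Builds a tokenizer with default options from in-memory tiktoken-format vocabulary bytes.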
pub fn from_bytes(
vocab_data: &[u8],
pattern: &str,
special_tokens: FxHashMap<String, u32>,
) -> Result<Self, TokenizerError> {
let encoder = load_tiktoken_bpe(vocab_data)?;
Self::new(encoder, special_tokens, pattern)
}
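    /// Byte-level variant of [`Tokenizer::from_bytes`].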
pub fn from_bytes_byte_level(
vocab_data: &[u8],
pattern: &str,
special_tokens: FxHashMap<String, u32>,
) -> Result<Self, TokenizerError> {
let encoder = load_tiktoken_bpe(vocab_data)?;
Self::new_byte_level(encoder, special_tokens, pattern)
}
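    /// SentencePiece variant of [`Tokenizer::from_bytes`].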
pub fn from_bytes_sentencepiece(
vocab_data: &[u8],
pattern: &str,
special_tokens: FxHashMap<String, u32>,
) -> Result<Self, TokenizerError> {
let encoder = load_tiktoken_bpe(vocab_data)?;
Self::new_sentencepiece(encoder, special_tokens, pattern)
}
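    /// SentencePiece constructor that keeps the decoder byte sequences supplied by the
    /// vocabulary file (rather than deriving them from the encoder) and injects the
    /// special tokens into that decoder.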
pub fn from_bytes_sentencepiece_with_decoder(
vocab_data: &[u8],
pattern: &str,
special_tokens: FxHashMap<String, u32>,
) -> Result<Self, TokenizerError> {
use crate::core::vocab::load_tiktoken_bpe_with_decoder;
let (encoder, mut decoder) = load_tiktoken_bpe_with_decoder(vocab_data)?;
for (token_str, id) in &special_tokens {
decoder.insert(*id, token_str.as_bytes().to_vec());
}
let special_tokens_decoder: FxHashMap<u32, String> = special_tokens
.iter()
.map(|(k, v)| (*v, k.clone()))
.collect();
let regex = RegexBuilder::new(pattern).jit(true).build()?;
let special_token_strings: Vec<String> = special_tokens.keys().cloned().collect();
        let special_matcher = if special_token_strings.is_empty() {
            None
        } else {
            // Leftmost-longest matching, as in `with_full_options`.
            Some(
                AhoCorasick::builder()
                    .match_kind(aho_corasick::MatchKind::LeftmostLongest)
                    .build(&special_token_strings)?,
            )
        };
let cache_size_nz = NonZeroUsize::new(DEFAULT_CACHE_SIZE.max(1)).unwrap();
let chunk_cache = Mutex::new(LruCache::new(cache_size_nz));
Ok(Self {
encoder,
decoder,
special_tokens,
special_tokens_decoder,
special_token_strings,
regex: RegexBackend::Regexr(Box::new(regex)),
pattern: pattern.to_string(),
special_matcher,
chunk_cache,
use_byte_level: false,
use_sentencepiece: true,
cache_size: DEFAULT_CACHE_SIZE,
use_jit: true,
use_pcre2: false,
})
}
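    /// FxHash of a byte slice; used as the chunk-cache key.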
#[inline]
fn hash_slice(slice: &[u8]) -> u64 {
let mut hasher = FxHasher::default();
slice.hash(&mut hasher);
hasher.finish()
}
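    /// Encodes one chunk, optionally prefixing the SentencePiece word marker "▁" (U+2581).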
#[allow(dead_code)]
fn encode_chunk_sentencepiece(&self, slice: &[u8], add_prefix: bool) -> Vec<u32> {
let bytes_to_encode: std::borrow::Cow<[u8]> = if add_prefix {
let mut with_prefix = Vec::with_capacity(slice.len() + 3);
with_prefix.extend_from_slice("▁".as_bytes()); with_prefix.extend_from_slice(slice);
std::borrow::Cow::Owned(with_prefix)
} else {
std::borrow::Cow::Borrowed(slice)
};
self.encode_bytes_with_cache(bytes_to_encode.as_ref())
}
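    /// Encodes raw bytes: a whole-chunk vocabulary hit first, then the LRU cache,
    /// then full byte-pair encoding (whose result is cached).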
fn encode_bytes_with_cache(&self, bytes: &[u8]) -> Vec<u32> {
if let Some(&rank) = self.encoder.get(bytes) {
return vec![rank];
}
let hash = Self::hash_slice(bytes);
if let Ok(mut cache) = self.chunk_cache.lock() {
if let Some(cached) = cache.get(&hash) {
return cached.clone();
}
}
let result = byte_pair_encode(bytes, &self.encoder);
if let Ok(mut cache) = self.chunk_cache.lock() {
cache.put(hash, result.clone());
}
result
}
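    /// Encodes one regex chunk, applying byte-level preprocessing when enabled.
    /// `_position` is accepted for API symmetry but currently unused.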
fn encode_chunk_with_position(&self, slice: &[u8], _position: usize) -> Vec<u32> {
let bytes_to_encode: std::borrow::Cow<[u8]> = if self.use_byte_level {
let byte_level_str = byte_level_encode(slice);
std::borrow::Cow::Owned(byte_level_str.into_bytes())
} else {
std::borrow::Cow::Borrowed(slice)
};
if let Some(&rank) = self.encoder.get(bytes_to_encode.as_ref()) {
return vec![rank];
}
let hash = Self::hash_slice(bytes_to_encode.as_ref());
if let Ok(mut cache) = self.chunk_cache.lock() {
if let Some(cached) = cache.get(&hash) {
return cached.clone();
}
}
let result = byte_pair_encode(bytes_to_encode.as_ref(), &self.encoder);
if let Ok(mut cache) = self.chunk_cache.lock() {
cache.put(hash, result.clone());
}
result
}
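    /// Encodes text to token IDs. Special tokens are *not* recognized here; use
    /// [`Tokenizer::encode_with_special`] for that. In SentencePiece mode, runs of
    /// spaces are converted to "▁" markers attached to the following word.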
pub fn encode(&self, text: &str) -> Vec<u32> {
let text_bytes = text.as_bytes();
let chunks = self.regex.find_iter(text);
if chunks.is_empty() {
return vec![];
}
if self.use_sentencepiece {
let mut results = Vec::new();
let mut pending_underscores = 0usize;
for &(start, end) in chunks.iter() {
let slice = &text_bytes[start..end];
if slice.is_empty() {
continue;
}
if slice[0].is_ascii_whitespace() {
for &b in slice {
if b == b' ' {
pending_underscores += 1;
} else {
if pending_underscores > 0 {
let underscores = "▁".repeat(pending_underscores);
results
.extend(self.encode_bytes_with_cache(underscores.as_bytes()));
pending_underscores = 0;
}
results.extend(self.encode_bytes_with_cache(&[b]));
}
}
} else {
if pending_underscores > 0 {
let mut with_prefix =
Vec::with_capacity(pending_underscores * 3 + slice.len());
for _ in 0..pending_underscores {
with_prefix.extend_from_slice("▁".as_bytes());
}
with_prefix.extend_from_slice(slice);
results.extend(self.encode_bytes_with_cache(&with_prefix));
pending_underscores = 0;
} else {
results.extend(self.encode_bytes_with_cache(slice));
}
}
}
if pending_underscores > 0 {
let underscores = "▁".repeat(pending_underscores);
results.extend(self.encode_bytes_with_cache(underscores.as_bytes()));
}
results
} else {
let results: Vec<Vec<u32>> = chunks
.iter()
.map(|&(start, end)| {
let slice = &text_bytes[start..end];
self.encode_chunk_with_position(slice, start)
})
.collect();
results.into_iter().flatten().collect()
}
}
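    /// Like [`Tokenizer::encode`], but encodes regex chunks in parallel when the
    /// `rayon` feature is enabled. SentencePiece mode falls back to the sequential
    /// path, since its space handling carries state across chunks.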
pub fn encode_rayon(&self, text: &str) -> Vec<u32> {
if self.use_sentencepiece {
return self.encode(text);
}
let text_bytes = text.as_bytes();
let chunks = self.regex.find_iter(text);
if chunks.is_empty() {
return vec![];
}
#[cfg(feature = "rayon")]
let results: Vec<Vec<u32>> = chunks
.par_iter()
.map(|&(start, end)| {
let slice = &text_bytes[start..end];
self.encode_chunk_with_position(slice, start)
})
.collect();
#[cfg(not(feature = "rayon"))]
let results: Vec<Vec<u32>> = chunks
.iter()
.map(|&(start, end)| {
let slice = &text_bytes[start..end];
self.encode_chunk_with_position(slice, start)
})
.collect();
results.into_iter().flatten().collect()
}
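    /// Encodes text, mapping each registered special token to its single ID and
    /// encoding the text between matches normally.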
pub fn encode_with_special(&self, text: &str) -> Vec<u32> {
let Some(ref special_matcher) = self.special_matcher else {
return self.encode(text);
};
let text_bytes = text.as_bytes();
let mut result = Vec::new();
let mut last_end = 0;
for m in special_matcher.find_iter(text_bytes) {
let start = m.start();
let end = m.end();
if start > last_end {
let slice = &text[last_end..start];
result.extend(self.encode(slice));
}
let pattern_idx = m.pattern().as_usize();
let token_str = &self.special_token_strings[pattern_idx];
if let Some(&rank) = self.special_tokens.get(token_str) {
result.push(rank);
}
last_end = end;
}
if last_end < text.len() {
result.extend(self.encode(&text[last_end..]));
}
result
}
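    /// Decodes token IDs to raw bytes, reversing byte-level preprocessing when
    /// enabled. Unknown IDs are silently skipped.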
pub fn decode_bytes(&self, tokens: &[u32]) -> Vec<u8> {
let mut result = Vec::with_capacity(tokens.len() * 4);
for &token in tokens {
if let Some(bytes) = self.decoder.get(&token) {
if self.use_byte_level {
if let Some(decoded) = byte_level_decode_bytes(bytes) {
result.extend_from_slice(&decoded);
} else {
result.extend_from_slice(bytes);
}
} else {
result.extend_from_slice(bytes);
}
} else if let Some(special) = self.special_tokens_decoder.get(&token) {
result.extend_from_slice(special.as_bytes());
}
}
result
}
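    /// Decodes token IDs to a `String`, failing with [`TokenizerError::Utf8Error`]
    /// if the result is not valid UTF-8.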
pub fn decode(&self, tokens: &[u32]) -> Result<String, TokenizerError> {
let bytes = self.decode_bytes(tokens);
let text = String::from_utf8(bytes).map_err(|_| TokenizerError::Utf8Error)?;
Ok(self.postprocess_decode(text))
}
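    /// Like [`Tokenizer::decode`], but replaces invalid UTF-8 with U+FFFD instead of failing.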
pub fn decode_lossy(&self, tokens: &[u32]) -> String {
let bytes = self.decode_bytes(tokens);
let text = String::from_utf8_lossy(&bytes).into_owned();
self.postprocess_decode(text)
}
#[inline]
fn postprocess_decode(&self, text: String) -> String {
if self.use_sentencepiece {
text.replace('\u{2581}', " ")
} else {
text
}
}
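    /// Encodes a batch of texts, in parallel when the `rayon` feature is enabled.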
pub fn encode_batch(&self, texts: &[String]) -> Vec<Vec<u32>> {
#[cfg(feature = "rayon")]
{
texts.par_iter().map(|text| self.encode(text)).collect()
}
#[cfg(not(feature = "rayon"))]
{
texts.iter().map(|text| self.encode(text)).collect()
}
}
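    /// Batch variant of [`Tokenizer::encode_with_special`].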
pub fn encode_batch_with_special(&self, texts: &[String]) -> Vec<Vec<u32>> {
#[cfg(feature = "rayon")]
{
texts
.par_iter()
.map(|text| self.encode_with_special(text))
.collect()
}
#[cfg(not(feature = "rayon"))]
{
texts
.iter()
.map(|text| self.encode_with_special(text))
.collect()
}
}
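    /// Decodes a batch of token lists, in parallel when the `rayon` feature is enabled.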
pub fn decode_batch(&self, token_lists: &[Vec<u32>]) -> Result<Vec<String>, TokenizerError> {
#[cfg(feature = "rayon")]
{
token_lists
.par_iter()
.map(|tokens| self.decode(tokens))
.collect()
}
#[cfg(not(feature = "rayon"))]
{
token_lists
.iter()
.map(|tokens| self.decode(tokens))
.collect()
}
}
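    /// Batch variant of [`Tokenizer::decode_lossy`].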
pub fn decode_batch_lossy(&self, token_lists: &[Vec<u32>]) -> Vec<String> {
#[cfg(feature = "rayon")]
{
token_lists
.par_iter()
.map(|tokens| self.decode_lossy(tokens))
.collect()
}
#[cfg(not(feature = "rayon"))]
{
token_lists
.iter()
.map(|tokens| self.decode_lossy(tokens))
.collect()
}
}
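    /// Highest known token ID plus one, across both the base vocabulary and the
    /// special tokens (gaps included); zero for an empty tokenizer.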
    pub fn vocab_size(&self) -> usize {
        if self.decoder.is_empty() && self.special_tokens.is_empty() {
            return 0;
        }
        let max_decoder_id = self.decoder.keys().max().copied().unwrap_or(0);
        let max_special_id = self.special_tokens.values().max().copied().unwrap_or(0);
        (max_decoder_id.max(max_special_id) + 1) as usize
    }
pub fn encoder(&self) -> &FxHashMap<Vec<u8>, u32> {
&self.encoder
}
pub fn decoder(&self) -> &FxHashMap<u32, Vec<u8>> {
&self.decoder
}
pub fn special_tokens(&self) -> &FxHashMap<String, u32> {
&self.special_tokens
}
pub fn special_tokens_decoder(&self) -> &FxHashMap<u32, String> {
&self.special_tokens_decoder
}
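    /// Empties the chunk cache.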
pub fn clear_cache(&self) {
if let Ok(mut cache) = self.chunk_cache.lock() {
cache.clear();
}
}
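    /// Number of entries currently in the chunk cache.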
pub fn cache_len(&self) -> usize {
self.chunk_cache.lock().map(|c| c.len()).unwrap_or(0)
}
}
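// Manual `Clone`: the compiled regex, the Aho-Corasick automaton, and the
// mutex-guarded cache are not `Clone`, so they are rebuilt from the stored
// configuration; the clone starts with an empty chunk cache.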
impl Clone for Tokenizer {
fn clone(&self) -> Self {
let regex = match &self.regex {
RegexBackend::Regexr(_) => {
let regex = RegexBuilder::new(&self.pattern)
.jit(self.use_jit)
.build()
.unwrap();
RegexBackend::Regexr(Box::new(regex))
}
#[cfg(feature = "pcre2")]
RegexBackend::Pcre2(_) => {
let mut regex_builder = pcre2::bytes::RegexBuilder::new();
if self.use_jit {
regex_builder.jit_if_available(true);
}
regex_builder.utf(true);
regex_builder.ucp(true);
let regex = regex_builder.build(&self.pattern).unwrap();
RegexBackend::Pcre2(regex)
}
};
let cache_size_nz = NonZeroUsize::new(self.cache_size.max(1)).unwrap();
let chunk_cache = Mutex::new(LruCache::new(cache_size_nz));
        let special_matcher = if self.special_token_strings.is_empty() {
            None
        } else {
            Some(
                AhoCorasick::builder()
                    .match_kind(aho_corasick::MatchKind::LeftmostLongest)
                    .build(&self.special_token_strings)
                    .unwrap(),
            )
        };
Self {
encoder: self.encoder.clone(),
decoder: self.decoder.clone(),
special_tokens: self.special_tokens.clone(),
special_tokens_decoder: self.special_tokens_decoder.clone(),
special_token_strings: self.special_token_strings.clone(),
regex,
pattern: self.pattern.clone(),
special_matcher,
chunk_cache,
use_byte_level: self.use_byte_level,
use_sentencepiece: self.use_sentencepiece,
cache_size: self.cache_size,
use_jit: self.use_jit,
use_pcre2: self.use_pcre2,
}
}
}
impl super::tokenize::Tokenize for Tokenizer {
fn encode(&self, text: &str) -> Vec<u32> {
self.encode(text)
}
fn decode(&self, ids: &[u32]) -> Result<String, super::tokenize::TokenizeError> {
self.decode(ids).map_err(|e| match e {
TokenizerError::Utf8Error => super::tokenize::TokenizeError::Utf8Error,
other => super::tokenize::TokenizeError::Other(other.to_string()),
})
}
fn vocab_size(&self) -> usize {
self.vocab_size()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_test_tokenizer() -> Tokenizer {
let mut encoder = FxHashMap::default();
for b in 32u8..=126 {
encoder.insert(vec![b], b as u32);
}
encoder.insert(b"Hello".to_vec(), 200);
encoder.insert(b"World".to_vec(), 201);
encoder.insert(b" World".to_vec(), 202);
let mut special_tokens = FxHashMap::default();
special_tokens.insert("<|endoftext|>".to_string(), 50256);
let pattern = r"\S+|\s+";
Tokenizer::new(encoder, special_tokens, pattern).unwrap()
}
#[test]
fn test_encode_decode() {
let tokenizer = make_test_tokenizer();
let text = "Hello World";
let tokens = tokenizer.encode(text);
let decoded = tokenizer.decode(&tokens).unwrap();
assert_eq!(decoded, text);
}
#[test]
fn test_encode_with_special() {
let tokenizer = make_test_tokenizer();
let text = "Hello<|endoftext|>World";
let tokens = tokenizer.encode_with_special(text);
assert!(tokens.contains(&50256));
}
#[test]
fn test_batch_encode() {
let tokenizer = make_test_tokenizer();
let texts = vec!["Hello".to_string(), "World".to_string()];
let batch_tokens = tokenizer.encode_batch(&texts);
assert_eq!(batch_tokens.len(), 2);
}
#[test]
fn test_vocab_size() {
let tokenizer = make_test_tokenizer();
assert!(tokenizer.vocab_size() > 0);
}
#[test]
fn test_cache_works() {
let tokenizer = make_test_tokenizer();
let text = "HelloWorld";
let tokens1 = tokenizer.encode(text);
let tokens2 = tokenizer.encode(text);
assert_eq!(tokens1, tokens2);
assert!(tokenizer.cache_len() > 0);
}
#[test]
fn test_clear_cache() {
let tokenizer = make_test_tokenizer();
tokenizer.encode("HelloWorld");
assert!(tokenizer.cache_len() > 0);
tokenizer.clear_cache();
assert_eq!(tokenizer.cache_len(), 0);
}
#[cfg(feature = "pcre2")]
#[test]
fn test_pcre2_backend() {
let tokenizer = make_test_tokenizer().pcre2(true).unwrap();
let text = "Hello World";
let tokens = tokenizer.encode(text);
let decoded = tokenizer.decode(&tokens).unwrap();
assert_eq!(decoded, text);
}
#[cfg(not(feature = "pcre2"))]
#[test]
fn test_pcre2_not_enabled() {
let tokenizer = make_test_tokenizer();
let result = tokenizer.pcre2(true);
assert!(result.is_err());
}
#[test]
fn test_jit_disable() {
let tokenizer = make_test_tokenizer().jit(false).unwrap();
let text = "Hello World";
let tokens = tokenizer.encode(text);
let decoded = tokenizer.decode(&tokens).unwrap();
assert_eq!(decoded, text);
}
#[test]
fn test_jit_enable() {
let tokenizer = make_test_tokenizer().jit(true).unwrap();
let text = "Hello World";
let tokens = tokenizer.encode(text);
let decoded = tokenizer.decode(&tokens).unwrap();
assert_eq!(decoded, text);
}
#[cfg(feature = "pcre2")]
#[test]
fn test_pcre2_switch_back_to_regexr() {
let tokenizer = make_test_tokenizer()
.pcre2(true)
.unwrap()
.pcre2(false)
.unwrap();
let text = "Hello World";
let tokens = tokenizer.encode(text);
let decoded = tokenizer.decode(&tokens).unwrap();
assert_eq!(decoded, text);
}
#[cfg(feature = "pcre2")]
#[test]
fn test_pcre2_with_jit_disabled() {
let tokenizer = make_test_tokenizer()
.jit(false)
.unwrap()
.pcre2(true)
.unwrap();
let text = "Hello World";
let tokens = tokenizer.encode(text);
let decoded = tokenizer.decode(&tokens).unwrap();
assert_eq!(decoded, text);
}
const _: () = {
assert!(super::cl100k_agent_tokens::SYSTEM > 100276);
assert!(super::cl100k_agent_tokens::SUMMARY_END == 100330);
assert!(super::o200k_agent_tokens::SYSTEM > 200018);
assert!(super::o200k_agent_tokens::SUMMARY_END == 200072);
assert!(super::cl100k_agent_tokens::USER == super::cl100k_agent_tokens::SYSTEM + 1);
assert!(super::o200k_agent_tokens::USER == super::o200k_agent_tokens::SYSTEM + 1);
};
}