use std::cell::RefCell;
use std::collections::HashMap;
use regex::Regex;
use crate::byte_encoder::{encode_byte_level_chars, METASPACE};
use crate::map::TokenizerMap;
/// Common interface for tokenizers that turn text into a sequence of token ids.
///
/// The `Send` supertrait lets implementations be moved to another thread,
/// including as `Box<dyn ITokenizer>`.
pub trait ITokenizer: Send {
    /// Converts `text` into token ids.
    fn encode(&self, text: &str) -> Vec<u32>;

    /// Identifier of this tokenizer instance.
    fn id(&self) -> &str;
}
/// Byte-pair-encoding tokenizer built from a `TokenizerMap`.
///
/// Handles the `"byte_level"` and `"metaspace"` encoders. The `RefCell` memo cache
/// makes this type `!Sync`, so a shared instance must not be encoded from several
/// threads at once.
pub struct BPETokenizer {
    /// Identifier copied from the source map's `id`.
    id: String,
    /// `"byte_level"` or `"metaspace"`; selects pre-tokenization and symbol encoding.
    encoder: String,

    /// Token string → token id.
    vocab: HashMap<String, u32>,
    /// Merge rule written as `"left right"` → rank; lower ranks are merged first.
    merge_ranks: HashMap<String, u32>,
    /// Id assigned to byte value 0 for byte fallback (byte `b` maps to this + `b`);
    /// a negative value disables byte fallback.
    byte_fallback_start: i64,

    /// Compiled `pre_tokenizer_pattern`, used for byte_level maps without a program.
    pre_tok_regex: Option<Regex>,
    /// `pre_tokenizer_program` for byte_level maps; preferred over `pre_tok_regex`.
    pre_tok_program: Option<crate::pretok_program::PreTokProgram>,

    /// Special-token string → id; matched text is emitted directly, bypassing BPE.
    special_ids: HashMap<String, u32>,
    /// Alternation over the `special_ids` keys; `None` when there are no specials.
    special_regex: Option<Regex>,

    /// Memo of pre-token piece → its encoded ids.
    cache: RefCell<HashMap<String, Vec<u32>>>,
}
impl BPETokenizer {
    /// Maximum number of pre-token pieces memoized in `cache`. Once the cache holds this
    /// many entries, new pieces are still encoded but no longer stored, so memory use stays
    /// bounded no matter how much distinct text is encoded.
    const MAX_CACHE_ENTRIES: usize = 10_000;

    /// Reports whether `map` has everything [`BPETokenizer::new`] requires: a non-empty
    /// `vocab`, a non-empty `merges` list, and an `encoder` of `"byte_level"` or `"metaspace"`.
    pub fn supports(map: &TokenizerMap) -> bool {
        let has_vocab = map.vocab.as_ref().is_some_and(|v| !v.is_empty());
        let has_merges = map.merges.as_ref().is_some_and(|v| !v.is_empty());
        let enc_ok = matches!(map.encoder.as_deref(), Some("byte_level" | "metaspace"));
        has_vocab && has_merges && enc_ok
    }

    /// Builds a tokenizer from `map`.
    ///
    /// The special tokens are the map's `special_tokens` (empty names are ignored), plus
    /// every vocab entry shaped like `<|name|>` that the map did not already list.
    ///
    /// # Errors
    ///
    /// Returns `Err` in these cases:
    /// - `map` fails [`BPETokenizer::supports`].
    /// - A `byte_level` map has an empty `pre_tokenizer_program`, or has neither a program
    ///   nor a `pre_tokenizer_pattern`.
    /// - The pre-tokenizer pattern or the special-token alternation fails to compile.
    pub fn new(map: &TokenizerMap) -> Result<Self, String> {
        if !Self::supports(map) {
            return Err(format!(
                "BPETokenizer: map \"{}\" lacks vocab/merges/encoder. \
                 Use BPETokenizer::supports(map) to check first, or call \
                 Tokenize::pick(map) which falls back to LongestMatchTokenizer.",
                map.id
            ));
        }
        let vocab = map.vocab.as_ref().expect("supports() checked").clone();
        let merges = map.merges.as_ref().expect("supports() checked");
        let encoder = map.encoder.as_ref().expect("supports() checked").clone();
        let byte_fallback_start = map.byte_fallback_start.unwrap_or(-1);

        // A merge's rank is its position in the list; `apply_bpe` applies the lowest rank first.
        let merge_ranks: HashMap<String, u32> = merges
            .iter()
            .enumerate()
            .map(|(rank, m)| (m.clone(), rank as u32))
            .collect();

        // Only byte_level maps pre-tokenize with a program/regex. A program wins over a pattern.
        let (pre_tok_regex, pre_tok_program) = if encoder == "byte_level" {
            if let Some(prog) = map.pre_tokenizer_program.as_ref() {
                if prog.ops.is_empty() {
                    return Err(format!(
                        "BPETokenizer: byte_level map \"{}\" has empty pre_tokenizer_program.",
                        map.id
                    ));
                }
                (None, Some(prog.clone()))
            } else if let Some(pat) = map.pre_tokenizer_pattern.as_ref() {
                let re = Regex::new(pat)
                    .map_err(|e| format!("BPETokenizer: invalid pre_tokenizer_pattern: {e}"))?;
                (Some(re), None)
            } else {
                return Err(format!(
                    "BPETokenizer: byte_level map \"{}\" missing both pre_tokenizer_program and pre_tokenizer_pattern.",
                    map.id
                ));
            }
        } else {
            (None, None)
        };

        let mut special_ids: HashMap<String, u32> = HashMap::new();
        if let Some(specials) = map.special_tokens.as_ref() {
            for (name, id) in specials.iter() {
                // An empty name would add an empty branch to the alternation below. That
                // pattern matches the empty string, so `encode` would emit this id at every
                // character boundary.
                if name.is_empty() {
                    continue;
                }
                special_ids.insert(name.clone(), *id);
            }
        }
        // Vocab entries shaped like `<|name|>` are treated as specials too, unless the map
        // already assigned that name an id.
        for (tok, id) in vocab.iter() {
            if is_delimiter_shape(tok) && !special_ids.contains_key(tok) {
                special_ids.insert(tok.clone(), *id);
            }
        }

        let special_regex = if special_ids.is_empty() {
            None
        } else {
            // Longest names first: the regex crate's alternation is leftmost-first, so a
            // longer special wins over any special that is a prefix of it.
            let mut names: Vec<&str> = special_ids.keys().map(String::as_str).collect();
            names.sort_by_key(|name| std::cmp::Reverse(name.len()));
            let alternation = names
                .iter()
                .map(|name| regex::escape(name))
                .collect::<Vec<_>>()
                .join("|");
            Some(
                Regex::new(&alternation)
                    .map_err(|e| format!("BPETokenizer: bad special-token regex: {e}"))?,
            )
        };

        Ok(Self {
            id: map.id.clone(),
            vocab,
            merge_ranks,
            pre_tok_regex,
            pre_tok_program,
            encoder,
            byte_fallback_start,
            cache: RefCell::new(HashMap::new()),
            special_ids,
            special_regex,
        })
    }

    /// Encodes `text` into token ids.
    ///
    /// Each occurrence of a special token is emitted as its single id and is never passed
    /// through BPE. The text between special tokens is pre-tokenized and BPE-encoded.
    /// An empty `text` yields an empty vector.
    pub fn encode(&self, text: &str) -> Vec<u32> {
        let mut ids: Vec<u32> = Vec::new();
        if text.is_empty() {
            return ids;
        }
        let Some(re) = self.special_regex.as_ref() else {
            self.encode_chunk(text, &mut ids);
            return ids;
        };
        let mut cursor = 0usize;
        for m in re.find_iter(text) {
            // `encode_chunk` ignores empty slices, so adjacent specials need no special case.
            self.encode_chunk(&text[cursor..m.start()], &mut ids);
            // The regex is built only from `special_ids` keys, so every match is a key.
            ids.push(self.special_ids[m.as_str()]);
            cursor = m.end();
        }
        self.encode_chunk(&text[cursor..], &mut ids);
        ids
    }

    /// Pre-tokenizes `text`, which contains no special tokens, then BPE-encodes each piece
    /// and appends the ids to `out`. Per-piece results are memoized in `self.cache`.
    fn encode_chunk(&self, text: &str, out: &mut Vec<u32>) {
        if text.is_empty() {
            return;
        }
        for piece in self.pre_tokenize(text) {
            // `try_borrow*` falls back to an uncached encode instead of panicking if the cell
            // is already borrowed.
            if let Ok(cache) = self.cache.try_borrow() {
                if let Some(cached) = cache.get(&piece) {
                    out.extend_from_slice(cached);
                    continue;
                }
            }
            let symbols = self.encode_piece_to_vocab_space(&piece);
            let merged = self.apply_bpe(symbols);
            let piece_ids = self.lookup(&merged);
            out.extend_from_slice(&piece_ids);
            if let Ok(mut cache) = self.cache.try_borrow_mut() {
                if cache.len() < Self::MAX_CACHE_ENTRIES {
                    cache.insert(piece, piece_ids);
                }
            }
        }
    }

    /// Splits `text` into the pieces that BPE runs on.
    ///
    /// - `byte_level`: uses the pre-tokenizer program if present, otherwise the regex matches.
    /// - `metaspace`: collapses runs of spaces/tabs to one space and drops the spaces. Each
    ///   word, and each other whitespace character, becomes its own piece prefixed with
    ///   `METASPACE`.
    fn pre_tokenize(&self, text: &str) -> Vec<String> {
        if self.encoder == "byte_level" {
            if let Some(prog) = self.pre_tok_program.as_ref() {
                return crate::pretok_program::run_pretok_program(prog, text);
            }
            let re = self
                .pre_tok_regex
                .as_ref()
                .expect("byte_level requires regex or program");
            return re.find_iter(text).map(|m| m.as_str().to_string()).collect();
        }
        let collapsed = collapse_spaces_and_tabs(text);
        split_keep_whitespace(&collapsed)
            .into_iter()
            .filter(|p| p.as_str() != " ")
            .map(|p| {
                let mut s = String::with_capacity(METASPACE.len_utf8() + p.len());
                s.push(METASPACE);
                s.push_str(&p);
                s
            })
            .collect()
    }

    /// Maps a piece to its initial BPE symbols, one `String` per char. For `byte_level`,
    /// the piece's UTF-8 bytes are first mapped through `encode_byte_level_chars`.
    fn encode_piece_to_vocab_space(&self, piece: &str) -> Vec<String> {
        if self.encoder == "byte_level" {
            codepoints(&encode_byte_level_chars(piece.as_bytes()))
        } else {
            codepoints(piece)
        }
    }

    /// Repeatedly merges the adjacent pair with the lowest rank in `merge_ranks`. Every
    /// left-to-right, non-overlapping occurrence of that pair is merged in the same round.
    /// Stops when no adjacent pair has a rank.
    fn apply_bpe(&self, tokens: Vec<String>) -> Vec<String> {
        let mut parts = tokens;
        // Reused lookup key in the `"left right"` format of the merge list.
        let mut key = String::new();
        while parts.len() >= 2 {
            let mut best_idx: Option<usize> = None;
            let mut best_rank: u32 = u32::MAX;
            for i in 0..parts.len() - 1 {
                key.clear();
                key.push_str(&parts[i]);
                key.push(' ');
                key.push_str(&parts[i + 1]);
                if let Some(&rank) = self.merge_ranks.get(&key) {
                    if rank < best_rank {
                        best_rank = rank;
                        best_idx = Some(i);
                    }
                }
            }
            let Some(idx) = best_idx else {
                break;
            };
            let left = parts[idx].clone();
            let right = parts[idx + 1].clone();
            let merged = format!("{left}{right}");
            let mut next: Vec<String> = Vec::with_capacity(parts.len());
            let mut j = 0;
            while j < parts.len() {
                if j + 1 < parts.len() && parts[j] == left && parts[j + 1] == right {
                    next.push(merged.clone());
                    j += 2;
                } else {
                    // Only indices >= j are inspected later, so moving out is safe.
                    next.push(std::mem::take(&mut parts[j]));
                    j += 1;
                }
            }
            parts = next;
        }
        parts
    }

    /// Converts merged symbols to ids. A symbol missing from `vocab` becomes one
    /// byte-fallback id per UTF-8 byte when `byte_fallback_start >= 0`.
    fn lookup(&self, tokens: &[String]) -> Vec<u32> {
        let mut ids: Vec<u32> = Vec::with_capacity(tokens.len());
        for tok in tokens {
            if let Some(&id) = self.vocab.get(tok) {
                ids.push(id);
            } else if self.byte_fallback_start >= 0 {
                ids.extend(
                    tok.bytes()
                        .map(|b| (self.byte_fallback_start + i64::from(b)) as u32),
                );
            }
            // NOTE(review): with byte fallback disabled, an unknown symbol is dropped silently;
            // confirm this is intended rather than mapping it to an `<unk>` id.
        }
        ids
    }
}
/// Trait adapter that delegates to the inherent [`BPETokenizer`] methods.
impl ITokenizer for BPETokenizer {
    fn encode(&self, text: &str) -> Vec<u32> {
        // The path call resolves to the inherent method, not back to this trait method.
        BPETokenizer::encode(self, text)
    }

    fn id(&self) -> &str {
        self.id.as_str()
    }
}
/// Splits `s` into one owned `String` per Unicode scalar value, in order.
fn codepoints(s: &str) -> Vec<String> {
    let mut symbols: Vec<String> = Vec::new();
    symbols.extend(s.chars().map(String::from));
    symbols
}
/// Replaces every run of spaces and/or tabs with a single ASCII space. All other
/// characters, including newlines, are copied unchanged.
fn collapse_spaces_and_tabs(s: &str) -> String {
    let mut collapsed = String::with_capacity(s.len());
    for ch in s.chars() {
        let is_blank = matches!(ch, ' ' | '\t');
        // Only the collapsing branch ever pushes ' ', so a trailing ' ' means "inside a run".
        match (is_blank, collapsed.ends_with(' ')) {
            (true, true) => {}
            (true, false) => collapsed.push(' '),
            (false, _) => collapsed.push(ch),
        }
    }
    collapsed
}
/// Returns `true` when `tok` looks like `<|name|>`, where `name` is non-empty and made
/// only of ASCII letters, ASCII digits, `_` or `-`.
fn is_delimiter_shape(tok: &str) -> bool {
    let body = match tok
        .strip_prefix("<|")
        .and_then(|rest| rest.strip_suffix("|>"))
    {
        Some(inner) if !inner.is_empty() => inner,
        _ => return false,
    };
    body.chars()
        .all(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-'))
}
/// Splits `s` into maximal runs of non-whitespace characters, emitting each whitespace
/// character (per `char::is_whitespace`) as its own single-character part, in order.
fn split_keep_whitespace(s: &str) -> Vec<String> {
    let mut pieces: Vec<String> = Vec::new();
    // Byte offset where the current non-whitespace run began, if we are inside one.
    let mut word_start: Option<usize> = None;
    for (idx, ch) in s.char_indices() {
        if ch.is_whitespace() {
            if let Some(start) = word_start.take() {
                pieces.push(s[start..idx].to_string());
            }
            pieces.push(ch.to_string());
        } else if word_start.is_none() {
            word_start = Some(idx);
        }
    }
    if let Some(start) = word_start {
        pieces.push(s[start..].to_string());
    }
    pieces
}