use std::collections::HashMap;
use crate::error::{Result, RullamaError};
use crate::gguf::GgufReader;
/// `tokenizer.ggml.token_type` code for ordinary vocabulary tokens.
pub const TOKEN_TYPE_NORMAL: u32 = 1;
/// `tokenizer.ggml.token_type` code for the unknown token.
pub const TOKEN_TYPE_UNKNOWN: u32 = 2;
/// `tokenizer.ggml.token_type` code for control tokens; `BpeTokenizer` treats these as specials.
pub const TOKEN_TYPE_CONTROL: u32 = 3;
/// `tokenizer.ggml.token_type` code for user-defined tokens; `BpeTokenizer` treats these as specials.
pub const TOKEN_TYPE_USER_DEFINED: u32 = 4;
/// `tokenizer.ggml.token_type` code for unused tokens.
pub const TOKEN_TYPE_UNUSED: u32 = 5;
/// `tokenizer.ggml.token_type` code for byte tokens.
pub const TOKEN_TYPE_BYTE: u32 = 6;
/// Word-boundary marker (U+2581, `▁`) that the tokenizer substitutes for every ASCII space.
pub const SPM_SPACE: char = '\u{2581}';
/// Byte-pair-encoding tokenizer whose vocabulary, merge ranks and special
/// tokens are loaded from GGUF tokenizer metadata.
pub struct BpeTokenizer {
    /// Token strings indexed by token id.
    rev_vocab: Vec<String>,
    /// Token string -> token id.
    vocab: HashMap<String, u32>,
    /// Adjacent `(left, right)` pair -> merge rank; the lowest rank is merged first.
    merges: HashMap<(String, String), u32>,
    /// For each byte value, the id of its `<0xXX>` token when the vocabulary has one.
    byte_fallback: [Option<u32>; 256],
    /// Special token strings with their ids, kept longest-first and matched
    /// literally before BPE runs.
    specials: Vec<(String, u32)>,
}
impl BpeTokenizer {
    /// Builds a tokenizer from the `tokenizer.ggml.*` metadata of a GGUF file.
    ///
    /// Reads the token strings (`tokenizer.ggml.tokens`), their types
    /// (`tokenizer.ggml.token_type`) and the ranked merge list
    /// (`tokenizer.ggml.merges`). Tokens typed [`TOKEN_TYPE_CONTROL`] or
    /// [`TOKEN_TYPE_USER_DEFINED`] become specials that [`Self::encode`] matches
    /// verbatim, and `<0xXX>` tokens, when present, form the byte fallback.
    ///
    /// # Errors
    ///
    /// Propagates any error the reader returns for these keys. Returns
    /// [`RullamaError::Tokenizer`] when the token and token-type arrays differ
    /// in length.
    pub fn from_gguf(r: &GgufReader) -> Result<Self> {
        let tokens = r.get("tokenizer.ggml.tokens")?.as_string_array()?.to_vec();
        let types = r.get("tokenizer.ggml.token_type")?.as_u32_array()?;
        if types.len() != tokens.len() {
            return Err(RullamaError::Tokenizer(format!(
                "token_type len {} != tokens len {}",
                types.len(),
                tokens.len()
            )));
        }

        // A string that appears more than once keeps its last (highest) id.
        let mut vocab: HashMap<String, u32> = HashMap::with_capacity(tokens.len());
        for (i, t) in tokens.iter().enumerate() {
            vocab.insert(t.clone(), i as u32);
        }

        // Sorted longest-first so that, when one special contains another, the
        // longer one is split out before the shorter one can break it apart.
        // Ties are broken by id for determinism.
        let mut specials: Vec<(String, u32)> = tokens
            .iter()
            .enumerate()
            .filter(|(i, _)| {
                types[*i] == TOKEN_TYPE_CONTROL || types[*i] == TOKEN_TYPE_USER_DEFINED
            })
            .map(|(i, s)| (s.clone(), i as u32))
            .filter(|(s, _)| !s.is_empty())
            .collect();
        specials.sort_by(|a, b| b.0.len().cmp(&a.0.len()).then_with(|| a.1.cmp(&b.1)));

        // Each entry is "left right", and its position in the list is its rank
        // (lower ranks merge first). A merge whose concatenation is not itself a
        // vocab token could never be emitted. It is dropped here so that
        // `encode_text` skips it instead of stopping all further merging.
        let merge_strs = r.get("tokenizer.ggml.merges")?.as_string_array()?;
        let mut merges: HashMap<(String, String), u32> = HashMap::with_capacity(merge_strs.len());
        let mut dropped = 0usize;
        for (rank, m) in merge_strs.iter().enumerate() {
            let Some((left, right)) = m.split_once(' ') else {
                continue;
            };
            if !vocab.contains_key(&format!("{left}{right}")) {
                dropped += 1;
                continue;
            }
            merges.insert((left.to_string(), right.to_string()), rank as u32);
        }
        if dropped > 0 {
            log::debug!("dropped {dropped} merges whose result is not in the vocab");
        }

        let mut byte_fallback = [None; 256];
        for b in 0u32..256 {
            let key = format!("<0x{:02X}>", b);
            if let Some(&id) = vocab.get(&key) {
                byte_fallback[b as usize] = Some(id);
            }
        }

        Ok(Self {
            vocab,
            rev_vocab: tokens,
            merges,
            specials,
            byte_fallback,
        })
    }

    /// Number of token ids (the length of the token list read from GGUF).
    pub fn vocab_size(&self) -> u32 {
        self.rev_vocab.len() as u32
    }

    /// Returns the string for token `id`, or `None` if `id` is out of range.
    pub fn id_to_str(&self, id: u32) -> Option<&str> {
        self.rev_vocab.get(id as usize).map(|s| s.as_str())
    }

    /// Returns the id of the token whose text is exactly `s`.
    ///
    /// This is a constant-time lookup in the same map that [`Self::encode`] uses.
    /// When the vocabulary contains duplicate strings, it therefore yields the
    /// same id that `encode` would emit (the last occurrence).
    pub fn str_to_id(&self, s: &str) -> Option<u32> {
        self.vocab.get(s).copied()
    }

    /// Tokenizes `s` into token ids.
    ///
    /// A leading [`SPM_SPACE`] is prepended unless `s` is empty, already starts
    /// with `' '` or `▁`, or starts with a special token. Special-token strings
    /// are then split out and emitted as their ids. The text between them is
    /// BPE-encoded.
    pub fn encode(&self, s: &str) -> Vec<u32> {
        let needs_prefix = !s.is_empty()
            && !s.starts_with(' ')
            && !s.starts_with(SPM_SPACE)
            && !self.specials.iter().any(|(sp, _)| s.starts_with(sp.as_str()));
        let text = if needs_prefix {
            format!("{SPM_SPACE}{s}")
        } else {
            s.to_string()
        };

        // Peel off each special in turn (longest first); text fragments left
        // over are re-split by later, shorter specials.
        let mut frags: Vec<Frag> = vec![Frag::Text(text)];
        for (special, sid) in &self.specials {
            let mut next: Vec<Frag> = Vec::with_capacity(frags.len());
            for f in frags {
                match f {
                    Frag::Special(_) => next.push(f),
                    Frag::Text(t) => split_around(&t, special, *sid, &mut next),
                }
            }
            frags = next;
        }

        let mut out = Vec::new();
        for f in frags {
            match f {
                Frag::Special(id) => out.push(id),
                Frag::Text(t) => self.encode_text(&t, &mut out),
            }
        }
        out
    }

    /// BPE-encodes one special-free text fragment, appending ids to `out`.
    ///
    /// Spaces are first replaced by [`SPM_SPACE`]. If the whole fragment is a
    /// vocab token, its id is emitted directly. Otherwise the fragment is split
    /// into characters, and the lowest-ranked adjacent merge (the leftmost one on
    /// a tie) is applied repeatedly until none applies. A resulting piece missing
    /// from the vocab is emitted byte-by-byte through `byte_fallback`. Bytes with
    /// no fallback token are logged and dropped.
    fn encode_text(&self, raw: &str, out: &mut Vec<u32>) {
        if raw.is_empty() {
            return;
        }
        let normalized: String = raw
            .chars()
            .map(|c| if c == ' ' { SPM_SPACE } else { c })
            .collect();
        // Fast path: the whole fragment is already a single token.
        if let Some(&id) = self.vocab.get(&normalized) {
            out.push(id);
            return;
        }

        let mut toks: Vec<String> = normalized.chars().map(String::from).collect();
        // `merges` is keyed by owned pairs. Refilling one reusable key with
        // `clone_from` avoids two fresh allocations per probed pair.
        let mut key = (String::new(), String::new());
        loop {
            // `min_by_key` returns the first minimum, i.e. the leftmost pair on ties.
            let best = toks
                .windows(2)
                .enumerate()
                .filter_map(|(i, pair)| {
                    key.0.clone_from(&pair[0]);
                    key.1.clone_from(&pair[1]);
                    self.merges.get(&key).map(|&rank| (i, rank))
                })
                .min_by_key(|&(_, rank)| rank);
            let Some((i, _)) = best else {
                break;
            };
            let right = toks.remove(i + 1);
            toks[i].push_str(&right);
            // `from_gguf` only keeps merges whose result is a vocab token.
            debug_assert!(self.vocab.contains_key(&toks[i]));
        }

        for tok in &toks {
            if let Some(&id) = self.vocab.get(tok) {
                out.push(id);
                continue;
            }
            for &b in tok.as_bytes() {
                match self.byte_fallback[b as usize] {
                    Some(id) => out.push(id),
                    None => log::debug!("unknown byte token: 0x{:02X}", b),
                }
            }
        }
    }
}
/// One piece of the input produced while splitting out special tokens.
enum Frag {
    /// A special token, already resolved to its id.
    Special(u32),
    /// Plain text that still has to be BPE-encoded.
    Text(String),
}
/// Appends to `out` the pieces of `text` split around `special`.
///
/// Each non-overlapping occurrence of `special` (scanned left to right) becomes
/// `Frag::Special(sid)`. Each non-empty run of text between occurrences becomes
/// `Frag::Text`. An empty `text` contributes nothing, and an empty `special`
/// leaves `text` as a single `Frag::Text`.
fn split_around(text: &str, special: &str, sid: u32, out: &mut Vec<Frag>) {
    match (text.is_empty(), special.is_empty()) {
        (true, _) => return,
        (false, true) => {
            out.push(Frag::Text(text.to_string()));
            return;
        }
        (false, false) => {}
    }

    let mut cursor = 0usize;
    for (at, hit) in text.match_indices(special) {
        if cursor < at {
            out.push(Frag::Text(text[cursor..at].to_string()));
        }
        out.push(Frag::Special(sid));
        cursor = at + hit.len();
    }

    let tail = &text[cursor..];
    if !tail.is_empty() {
        out.push(Frag::Text(tail.to_string()));
    }
}