//! Hugging Face tokenizer support: loads either the consolidated
//! `tokenizer.json` format or the legacy `vocab.json` + `merges.txt`
//! pair, and implements a simplified byte-pair encoding on top of them.

use std::collections::HashMap;
use std::fs;
use std::path::Path;

use crate::error::{LLMError, LLMResult};
use crate::hf_loader::HFLoader;
/// A byte-pair-encoding tokenizer loaded from Hugging Face tokenizer files.
pub struct HFTokenizer {
    vocab: HashMap<String, u32>,
    id_to_token: HashMap<u32, String>,
    merges: Vec<(String, String)>,
    special_tokens: SpecialTokens,
    added_tokens: HashMap<String, u32>,
}
#[derive(Debug, Clone, Default)]
pub struct SpecialTokens {
    pub bos_token: Option<String>,
    pub eos_token: Option<String>,
    pub unk_token: Option<String>,
    pub pad_token: Option<String>,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
    pub unk_token_id: Option<u32>,
    pub pad_token_id: Option<u32>,
}
impl HFTokenizer {
    /// Downloads the tokenizer files for `model_id` into the loader's cache
    /// and builds a tokenizer from them.
    pub fn from_pretrained(model_id: &str) -> LLMResult<Self> {
        let loader = HFLoader::new(model_id)?;
        let cache_dir = loader.cache_dir();
        Self::download_tokenizer_files(&loader)?;
        Self::from_directory(cache_dir)
    }
    /// Builds a tokenizer from a local directory containing either
    /// `tokenizer.json` or the legacy `vocab.json` / `merges.txt` pair.
    pub fn from_directory<P: AsRef<Path>>(path: P) -> LLMResult<Self> {
        let path = path.as_ref();
        let tokenizer_json = path.join("tokenizer.json");
        if tokenizer_json.exists() {
            return Self::load_tokenizer_json(&tokenizer_json);
        }
        let vocab_json = path.join("vocab.json");
        let merges_txt = path.join("merges.txt");
        if vocab_json.exists() {
            return Self::load_legacy_format(&vocab_json, &merges_txt, path);
        }
        Err(LLMError::ModelNotFound(
            "No tokenizer.json or vocab.json found".to_string(),
        ))
    }
    /// Prefers the consolidated `tokenizer.json`; falls back to `vocab.json`
    /// plus its optional companion files.
    fn download_tokenizer_files(loader: &HFLoader) -> LLMResult<()> {
        if loader.download_file_if_exists("tokenizer.json")? {
            let _ = loader.download_file_if_exists("tokenizer_config.json");
            return Ok(());
        }
        loader.download_file("vocab.json")?;
        let _ = loader.download_file_if_exists("merges.txt");
        let _ = loader.download_file_if_exists("special_tokens_map.json");
        Ok(())
    }
    /// Parses the consolidated `tokenizer.json` format.
    fn load_tokenizer_json(path: &Path) -> LLMResult<Self> {
        let content = fs::read_to_string(path).map_err(|e| LLMError::IoError(e.to_string()))?;
        let json: serde_json::Value =
            serde_json::from_str(&content).map_err(|e| LLMError::ParseError(e.to_string()))?;

        let mut vocab = HashMap::new();
        let mut id_to_token = HashMap::new();
        let mut merges = Vec::new();
        if let Some(model) = json.get("model") {
            if let Some(v) = model.get("vocab").and_then(|v| v.as_object()) {
                for (token, id) in v {
                    if let Some(id) = id.as_u64() {
                        vocab.insert(token.clone(), id as u32);
                        id_to_token.insert(id as u32, token.clone());
                    }
                }
            }
            if let Some(m) = model.get("merges").and_then(|m| m.as_array()) {
                for merge in m {
                    // Older files serialize merges as "a b" strings; newer
                    // tokenizers releases write ["a", "b"] pairs. Accept both.
                    if let Some(s) = merge.as_str() {
                        let parts: Vec<&str> = s.split(' ').collect();
                        if parts.len() == 2 {
                            merges.push((parts[0].to_string(), parts[1].to_string()));
                        }
                    } else if let Some(pair) = merge.as_array() {
                        if let (Some(a), Some(b)) = (
                            pair.first().and_then(|p| p.as_str()),
                            pair.get(1).and_then(|p| p.as_str()),
                        ) {
                            merges.push((a.to_string(), b.to_string()));
                        }
                    }
                }
            }
        }

        let mut added_tokens = HashMap::new();
        if let Some(tokens) = json.get("added_tokens").and_then(|t| t.as_array()) {
            for token in tokens {
                if let (Some(content), Some(id)) = (
                    token.get("content").and_then(|c| c.as_str()),
                    token.get("id").and_then(|i| i.as_u64()),
                ) {
                    added_tokens.insert(content.to_string(), id as u32);
                    id_to_token.insert(id as u32, content.to_string());
                }
            }
        }

        let special_tokens = Self::extract_special_tokens(&json, &vocab, &added_tokens);
        Ok(Self {
            vocab,
            id_to_token,
            merges,
            special_tokens,
            added_tokens,
        })
    }
    /// Parses the legacy `vocab.json` + `merges.txt` pair, plus an optional
    /// `special_tokens_map.json` alongside them.
    fn load_legacy_format(vocab_path: &Path, merges_path: &Path, dir: &Path) -> LLMResult<Self> {
        let vocab_content =
            fs::read_to_string(vocab_path).map_err(|e| LLMError::IoError(e.to_string()))?;
        let vocab: HashMap<String, u32> = serde_json::from_str(&vocab_content)
            .map_err(|e| LLMError::ParseError(e.to_string()))?;
        let id_to_token: HashMap<u32, String> =
            vocab.iter().map(|(k, v)| (*v, k.clone())).collect();

        let mut merges = Vec::new();
        if merges_path.exists() {
            let merges_content =
                fs::read_to_string(merges_path).map_err(|e| LLMError::IoError(e.to_string()))?;
            for line in merges_content.lines() {
                // Skip the "#version: ..." header when present (and blanks)
                // instead of unconditionally dropping the first line, which
                // would lose a merge in files without the header.
                if line.is_empty() || line.starts_with('#') {
                    continue;
                }
                let parts: Vec<&str> = line.split(' ').collect();
                if parts.len() == 2 {
                    merges.push((parts[0].to_string(), parts[1].to_string()));
                }
            }
        }

        let mut special_tokens = SpecialTokens::default();
        let special_path = dir.join("special_tokens_map.json");
        if special_path.exists() {
            if let Ok(content) = fs::read_to_string(&special_path) {
                if let Ok(json) = serde_json::from_str::<serde_json::Value>(&content) {
                    // Entries may be plain strings or AddedToken objects with
                    // a "content" field; accept both.
                    let get = |key: &str| -> Option<String> {
                        let value = json.get(key)?;
                        value
                            .as_str()
                            .or_else(|| value.get("content").and_then(|c| c.as_str()))
                            .map(|s| s.to_string())
                    };
                    if let Some(bos) = get("bos_token") {
                        special_tokens.bos_token_id = vocab.get(&bos).copied();
                        special_tokens.bos_token = Some(bos);
                    }
                    if let Some(eos) = get("eos_token") {
                        special_tokens.eos_token_id = vocab.get(&eos).copied();
                        special_tokens.eos_token = Some(eos);
                    }
                    if let Some(unk) = get("unk_token") {
                        special_tokens.unk_token_id = vocab.get(&unk).copied();
                        special_tokens.unk_token = Some(unk);
                    }
                    if let Some(pad) = get("pad_token") {
                        special_tokens.pad_token_id = vocab.get(&pad).copied();
                        special_tokens.pad_token = Some(pad);
                    }
                }
            }
        }

        Ok(Self {
            vocab,
            id_to_token,
            merges,
            special_tokens,
            added_tokens: HashMap::new(),
        })
    }
    /// Heuristically maps `added_tokens` entries flagged `special` onto
    /// BOS/EOS/UNK/PAD slots by inspecting their content.
    fn extract_special_tokens(
        json: &serde_json::Value,
        _vocab: &HashMap<String, u32>,
        _added_tokens: &HashMap<String, u32>,
    ) -> SpecialTokens {
        let mut special = SpecialTokens::default();
        if let Some(tokens) = json.get("added_tokens").and_then(|t| t.as_array()) {
            for token in tokens {
                let content = token.get("content").and_then(|c| c.as_str());
                let id = token.get("id").and_then(|i| i.as_u64()).map(|i| i as u32);
                let special_flag = token
                    .get("special")
                    .and_then(|s| s.as_bool())
                    .unwrap_or(false);
                if let (Some(content), Some(id)) = (content, id) {
                    if special_flag {
                        let lower = content.to_lowercase();
                        if lower.contains("bos") || lower == "<s>" {
                            special.bos_token = Some(content.to_string());
                            special.bos_token_id = Some(id);
                        } else if lower.contains("eos")
                            || lower == "</s>"
                            || lower == "<|endoftext|>"
                        {
                            special.eos_token = Some(content.to_string());
                            special.eos_token_id = Some(id);
                        } else if lower.contains("unk") {
                            special.unk_token = Some(content.to_string());
                            special.unk_token_id = Some(id);
                        } else if lower.contains("pad") {
                            special.pad_token = Some(content.to_string());
                            special.pad_token_id = Some(id);
                        }
                    }
                }
            }
        }
        special
    }
    /// Encodes `text`, prepending the BOS token when one is known; use
    /// `encode_with_options` for explicit control over special tokens.
    pub fn encode(&self, text: &str) -> LLMResult<Vec<u32>> {
        self.encode_with_options(text, true, false)
    }
    /// Encodes `text` with explicit control over BOS/EOS insertion.
    pub fn encode_with_options(
        &self,
        text: &str,
        add_bos: bool,
        add_eos: bool,
    ) -> LLMResult<Vec<u32>> {
        let mut tokens = Vec::new();
        if add_bos {
            if let Some(bos_id) = self.special_tokens.bos_token_id {
                tokens.push(bos_id);
            }
        }
        tokens.extend(self.bpe_encode(text)?);
        if add_eos {
            if let Some(eos_id) = self.special_tokens.eos_token_id {
                tokens.push(eos_id);
            }
        }
        Ok(tokens)
    }
    /// Simplified pre-tokenization: splits at whitespace/punctuation
    /// boundaries (keeping the delimiter with the preceding piece) rather
    /// than the full GPT-2 regex, then BPE-encodes each piece.
    fn bpe_encode(&self, text: &str) -> LLMResult<Vec<u32>> {
        let mut tokens = Vec::new();
        for word in text.split_inclusive(|c: char| c.is_whitespace() || c.is_ascii_punctuation()) {
            if word.is_empty() {
                continue;
            }
            // A whole-piece hit in either vocabulary short-circuits BPE.
            if let Some(&id) = self.vocab.get(word) {
                tokens.push(id);
                continue;
            }
            if let Some(&id) = self.added_tokens.get(word) {
                tokens.push(id);
                continue;
            }
            tokens.extend(self.bpe_tokenize_word(word)?);
        }
        Ok(tokens)
    }
    /// Applies BPE merges to a single piece, falling back to the UNK token
    /// or `<0xNN>` byte tokens for out-of-vocabulary remainders. This is a
    /// simplified single pass in merge-rank order, not the full
    /// priority-queue BPE.
    fn bpe_tokenize_word(&self, word: &str) -> LLMResult<Vec<u32>> {
        if word.is_empty() {
            return Ok(vec![]);
        }
        // Start from individual characters and apply each merge everywhere
        // it matches before moving to the next-ranked merge.
        let mut parts: Vec<String> = word.chars().map(|c| c.to_string()).collect();
        for (a, b) in &self.merges {
            let mut i = 0;
            while i < parts.len().saturating_sub(1) {
                if &parts[i] == a && &parts[i + 1] == b {
                    parts[i] = format!("{}{}", a, b);
                    parts.remove(i + 1);
                } else {
                    i += 1;
                }
            }
        }
        let mut ids = Vec::new();
        for part in parts {
            if let Some(&id) = self.vocab.get(&part) {
                ids.push(id);
            } else if let Some(unk_id) = self.special_tokens.unk_token_id {
                ids.push(unk_id);
            } else {
                // Byte fallback: emit `<0xNN>` tokens when the vocabulary
                // defines them (SentencePiece-style byte fallback).
                for byte in part.as_bytes() {
                    let byte_token = format!("<0x{:02X}>", byte);
                    if let Some(&id) = self.vocab.get(&byte_token) {
                        ids.push(id);
                    }
                }
            }
        }
        Ok(ids)
    }
    /// Decodes `ids` to text, skipping special tokens.
    pub fn decode(&self, ids: &[u32]) -> LLMResult<String> {
        self.decode_with_options(ids, true)
    }

    pub fn decode_with_options(&self, ids: &[u32], skip_special: bool) -> LLMResult<String> {
        // Accumulate raw bytes so that `<0xNN>` byte tokens spanning a
        // multi-byte UTF-8 character are reassembled correctly, rather than
        // pushing each byte as a separate Latin-1 char.
        let mut bytes: Vec<u8> = Vec::new();
        for &id in ids {
            if skip_special
                && (Some(id) == self.special_tokens.bos_token_id
                    || Some(id) == self.special_tokens.eos_token_id
                    || Some(id) == self.special_tokens.pad_token_id)
            {
                continue;
            }
            if let Some(token) = self.id_to_token.get(&id) {
                if token.len() == 6 && token.starts_with("<0x") && token.ends_with('>') {
                    if let Ok(byte) = u8::from_str_radix(&token[3..5], 16) {
                        bytes.push(byte);
                        continue;
                    }
                }
                bytes.extend_from_slice(token.as_bytes());
            }
        }
        let text = String::from_utf8_lossy(&bytes).into_owned();
        // GPT-2 marks word boundaries with Ġ; SentencePiece uses ▁.
        Ok(text.replace('Ġ', " ").replace('▁', " "))
    }
    /// Number of distinct tokens across the base and added vocabularies
    /// (added tokens that shadow base entries are counted once).
    pub fn vocab_size(&self) -> usize {
        let extra = self
            .added_tokens
            .keys()
            .filter(|token| !self.vocab.contains_key(*token))
            .count();
        self.vocab.len() + extra
    }

    pub fn special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    pub fn bos_token_id(&self) -> Option<u32> {
        self.special_tokens.bos_token_id
    }

    pub fn eos_token_id(&self) -> Option<u32> {
        self.special_tokens.eos_token_id
    }

    pub fn id_to_token(&self, id: u32) -> Option<&str> {
        self.id_to_token.get(&id).map(|s| s.as_str())
    }

    pub fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab
            .get(token)
            .copied()
            .or_else(|| self.added_tokens.get(token).copied())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_special_tokens_default() {
        let tokens = SpecialTokens::default();
        assert!(tokens.bos_token.is_none());
        assert!(tokens.eos_token.is_none());
    }
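
    // Sketch of the special-token heuristic on a hand-written fragment of
    // `tokenizer.json`; the token contents and ids here are illustrative,
    // not taken from any real model.
    #[test]
    fn test_extract_special_tokens_heuristic() {
        let json = serde_json::json!({
            "added_tokens": [
                { "id": 1, "content": "<s>", "special": true },
                { "id": 2, "content": "</s>", "special": true },
                { "id": 3, "content": "<unk>", "special": true },
                { "id": 4, "content": "plain", "special": false }
            ]
        });
        let special =
            HFTokenizer::extract_special_tokens(&json, &HashMap::new(), &HashMap::new());
        assert_eq!(special.bos_token_id, Some(1));
        assert_eq!(special.eos_token_id, Some(2));
        assert_eq!(special.unk_token_id, Some(3));
        // Entries without the `special` flag are ignored.
        assert!(special.pad_token.is_none());
    }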
    #[test]
    fn test_decode_simple() {
        let mut vocab = HashMap::new();
        vocab.insert("hello".to_string(), 0);
        vocab.insert(" ".to_string(), 1);
        vocab.insert("world".to_string(), 2);
        let tokenizer = HFTokenizer {
            vocab: vocab.clone(),
            id_to_token: vocab.iter().map(|(k, v)| (*v, k.clone())).collect(),
            merges: vec![],
            special_tokens: SpecialTokens::default(),
            added_tokens: HashMap::new(),
        };
        let text = tokenizer.decode(&[0, 1, 2]).unwrap();
        assert_eq!(text, "hello world");
    }
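
    // Minimal sketch of rank-ordered merge application: with the merges
    // ("h","e") then ("he","l"), the piece "hel" collapses to one token.
    // The tiny vocabulary is made up for this test.
    #[test]
    fn test_bpe_merges_apply_in_order() {
        let mut vocab = HashMap::new();
        vocab.insert("h".to_string(), 0);
        vocab.insert("e".to_string(), 1);
        vocab.insert("l".to_string(), 2);
        vocab.insert("he".to_string(), 3);
        vocab.insert("hel".to_string(), 4);
        let tokenizer = HFTokenizer {
            vocab: vocab.clone(),
            id_to_token: vocab.iter().map(|(k, v)| (*v, k.clone())).collect(),
            merges: vec![
                ("h".to_string(), "e".to_string()),
                ("he".to_string(), "l".to_string()),
            ],
            special_tokens: SpecialTokens::default(),
            added_tokens: HashMap::new(),
        };
        let ids = tokenizer.bpe_tokenize_word("hel").unwrap();
        assert_eq!(ids, vec![4]);
    }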
    #[test]
    fn test_vocab_size() {
        let mut vocab = HashMap::new();
        vocab.insert("a".to_string(), 0);
        vocab.insert("b".to_string(), 1);
        let mut added = HashMap::new();
        added.insert("<special>".to_string(), 2);
        let tokenizer = HFTokenizer {
            vocab,
            id_to_token: HashMap::new(),
            merges: vec![],
            special_tokens: SpecialTokens::default(),
            added_tokens: added,
        };
        assert_eq!(tokenizer.vocab_size(), 3);
    }
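
    // Round-trip sketch for the `<0xNN>` byte fallback: "é" (UTF-8 bytes
    // 0xC3 0xA9) is absent from this made-up vocabulary, so it encodes to
    // two byte tokens, and decode reassembles the multi-byte sequence.
    #[test]
    fn test_byte_fallback_round_trip() {
        let mut vocab = HashMap::new();
        vocab.insert("<0xC3>".to_string(), 0);
        vocab.insert("<0xA9>".to_string(), 1);
        let tokenizer = HFTokenizer {
            vocab: vocab.clone(),
            id_to_token: vocab.iter().map(|(k, v)| (*v, k.clone())).collect(),
            merges: vec![],
            special_tokens: SpecialTokens::default(),
            added_tokens: HashMap::new(),
        };
        let ids = tokenizer.encode("é").unwrap();
        assert_eq!(ids, vec![0, 1]);
        assert_eq!(tokenizer.decode(&ids).unwrap(), "é");
    }

    // BOS/EOS placement under encode_with_options; the ids 10 and 11 are
    // arbitrary values chosen for this test.
    #[test]
    fn test_encode_with_options_bos_eos() {
        let mut vocab = HashMap::new();
        vocab.insert("hi".to_string(), 0);
        let tokenizer = HFTokenizer {
            vocab: vocab.clone(),
            id_to_token: vocab.iter().map(|(k, v)| (*v, k.clone())).collect(),
            merges: vec![],
            special_tokens: SpecialTokens {
                bos_token: Some("<s>".to_string()),
                eos_token: Some("</s>".to_string()),
                bos_token_id: Some(10),
                eos_token_id: Some(11),
                ..SpecialTokens::default()
            },
            added_tokens: HashMap::new(),
        };
        assert_eq!(
            tokenizer.encode_with_options("hi", true, true).unwrap(),
            vec![10, 0, 11]
        );
        assert_eq!(
            tokenizer.encode_with_options("hi", false, false).unwrap(),
            vec![0]
        );
    }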
}