use std::collections::HashMap;
use crate::error::{Result, TextError};
const SPIECE_UNDERLINE: char = '▁';
const NUM_BYTE_TOKENS: usize = 256;
const BYTE_TOKEN_OFFSET: u32 = 3;
const ALPHA_TOKEN_OFFSET: u32 = BYTE_TOKEN_OFFSET + NUM_BYTE_TOKENS as u32;
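// Minimal-vocab ID layout produced by `new_minimal`, derived from the constants
// above: 0..=2 specials, 3..=258 the `<0xNN>` byte tokens, 259..=310 the ASCII
// letters, 311 the '▁' word-boundary piece.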
#[derive(Debug, Clone)]
pub struct LlamaTokenizerConfig {
pub unk_id: u32,
pub bos_id: u32,
pub eos_id: u32,
    /// `-1` follows the SentencePiece convention for "no pad token".
    pub pad_id: i64,
pub add_bos: bool,
pub add_eos: bool,
pub normalize_whitespace: bool,
}
impl Default for LlamaTokenizerConfig {
fn default() -> Self {
LlamaTokenizerConfig {
unk_id: 0,
bos_id: 1,
eos_id: 2,
pad_id: -1,
add_bos: true,
add_eos: false,
normalize_whitespace: true,
}
}
}
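/// Byte-fallback BPE tokenizer in the Llama/SentencePiece style.
///
/// A minimal usage sketch (ids shown assume the default config):
/// ```ignore
/// let tok = LlamaTokenizer::new_minimal(512, LlamaTokenizerConfig::default());
/// let ids = tok.encode("hello world");
/// assert_eq!(ids[0], 1); // BOS
/// assert_eq!(tok.decode(&ids), "hello world");
/// ```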
#[derive(Debug, Clone)]
pub struct LlamaTokenizer {
pub vocab: HashMap<String, u32>,
id_to_piece: HashMap<u32, String>,
pub merges: Vec<(u32, u32)>,
pub byte_tokens: [u32; NUM_BYTE_TOKENS],
pub config: LlamaTokenizerConfig,
}
impl LlamaTokenizer {
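    /// Build a small self-contained tokenizer for tests and demos: the three
    /// specials, all 256 `<0xNN>` byte tokens, the ASCII letters, '▁', and a
    /// handful of hard-coded demo merges. A `vocab_size` below 312 is padded
    /// up so the fixed tokens always fit.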
pub fn new_minimal(vocab_size: usize, config: LlamaTokenizerConfig) -> Self {
        // 3 specials + 256 byte tokens + 52 ASCII letters + '▁' = 312 fixed tokens.
        let min_size = 312usize;
        let effective_size = vocab_size.max(min_size);
let mut vocab: HashMap<String, u32> = HashMap::with_capacity(effective_size);
let mut id_to_piece: HashMap<u32, String> = HashMap::with_capacity(effective_size);
let specials = [
(config.unk_id, "<unk>"),
(config.bos_id, "<s>"),
(config.eos_id, "</s>"),
];
for (id, tok) in &specials {
vocab.insert(tok.to_string(), *id);
id_to_piece.insert(*id, tok.to_string());
}
let mut byte_tokens = [0u32; NUM_BYTE_TOKENS];
        for b in 0u32..NUM_BYTE_TOKENS as u32 {
let piece = format!("<0x{:02X}>", b);
let id = BYTE_TOKEN_OFFSET + b;
vocab.insert(piece.clone(), id);
id_to_piece.insert(id, piece);
byte_tokens[b as usize] = id;
}
let mut alpha_id = ALPHA_TOKEN_OFFSET;
for ch in ('a'..='z').chain('A'..='Z') {
let piece = ch.to_string();
vocab.insert(piece.clone(), alpha_id);
id_to_piece.insert(alpha_id, piece);
alpha_id += 1;
}
        let spiece_id = alpha_id;
        let spiece_str = SPIECE_UNDERLINE.to_string();
vocab.insert(spiece_str.clone(), spiece_id);
id_to_piece.insert(spiece_id, spiece_str.clone());
let mut next_id = spiece_id + 1;
        let common_digrams: &[(&str, &str)] = &[
            ("h", "e"),
            ("l", "l"),
            ("he", "ll"),
            ("o", "w"),
            ("w", "o"),
            ("▁", "h"),
            ("▁", "w"),
            ("r", "l"),
            ("l", "d"),
            ("o", "r"),
        ];
let mut merges: Vec<(u32, u32)> = Vec::new();
for (left_str, right_str) in common_digrams {
if next_id as usize >= effective_size {
break;
}
let left_id = vocab.get(*left_str).copied();
let right_id = vocab.get(*right_str).copied();
let merged = format!("{}{}", left_str, right_str);
            if let (Some(lid), Some(rid)) = (left_id, right_id) {
                // Register the merged piece once; record the merge rule either way.
                if !vocab.contains_key(&merged) {
                    vocab.insert(merged.clone(), next_id);
                    id_to_piece.insert(next_id, merged);
                    next_id += 1;
                }
                merges.push((lid, rid));
            }
}
LlamaTokenizer {
vocab,
id_to_piece,
merges,
byte_tokens,
config,
}
}
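    /// Construct from an explicit vocab and string-pair merges. Merge pairs
    /// whose pieces are missing from the vocab are silently dropped; all 256
    /// `<0xNN>` byte tokens must be present or a `VocabularyError` is returned.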
pub fn from_vocab_and_merges(
vocab: HashMap<String, u32>,
merges: Vec<(String, String)>,
config: LlamaTokenizerConfig,
) -> Result<Self> {
let id_to_piece: HashMap<u32, String> =
vocab.iter().map(|(k, v)| (*v, k.clone())).collect();
let merges_ids: Vec<(u32, u32)> = merges
.iter()
.filter_map(|(l, r)| {
let lid = vocab.get(l.as_str()).copied()?;
let rid = vocab.get(r.as_str()).copied()?;
Some((lid, rid))
})
.collect();
let mut byte_tokens = [0u32; NUM_BYTE_TOKENS];
for b in 0usize..NUM_BYTE_TOKENS {
let piece = format!("<0x{:02X}>", b);
byte_tokens[b] = vocab.get(&piece).copied().ok_or_else(|| {
TextError::VocabularyError(format!(
"LlamaTokenizer: byte token '{}' missing from vocab",
piece
))
})?;
}
Ok(LlamaTokenizer {
vocab,
id_to_piece,
merges: merges_ids,
byte_tokens,
config,
})
}
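    /// Encode `text` to token ids: optional '▁' whitespace normalization,
    /// per-character lookup with byte fallback, then greedy lowest-rank-first
    /// BPE merging, with BOS/EOS attached per the config.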
pub fn encode(&self, text: &str) -> Vec<u32> {
        if text.is_empty() {
            // Honor both flags so `add_eos` is not silently dropped on empty input.
            let mut ids = Vec::new();
            if self.config.add_bos {
                ids.push(self.config.bos_id);
            }
            if self.config.add_eos {
                ids.push(self.config.eos_id);
            }
            return ids;
        }
        let normalised: String = if self.config.normalize_whitespace {
            // SentencePiece-style normalization: spaces become '▁' and a dummy
            // '▁' prefix is prepended; `decode` strips the matching leading space.
            let replaced = text.replace(' ', &SPIECE_UNDERLINE.to_string());
            format!("{}{}", SPIECE_UNDERLINE, replaced)
} else {
text.to_string()
};
let mut token_ids: Vec<u32> = self.text_to_initial_ids(&normalised);
if !self.merges.is_empty() {
token_ids = self.apply_bpe_merges(token_ids);
}
let mut result = Vec::with_capacity(token_ids.len() + 2);
if self.config.add_bos {
result.push(self.config.bos_id);
}
result.extend(token_ids);
if self.config.add_eos {
result.push(self.config.eos_id);
}
result
}
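    /// Decode ids to text: BOS/EOS are skipped, `<0xNN>` pieces emit raw bytes,
    /// '▁' maps back to a space, and the leading dummy-prefix space is stripped.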
pub fn decode(&self, ids: &[u32]) -> String {
        // Accumulate raw bytes so `<0xNN>` pieces and UTF-8 text interleave correctly.
        let mut byte_buf: Vec<u8> = Vec::new();
for &id in ids {
if id == self.config.bos_id || id == self.config.eos_id {
continue;
}
if let Some(piece) = self.id_to_piece.get(&id) {
if let Some(b) = parse_byte_token(piece) {
byte_buf.push(b);
} else {
let text_piece = piece.replace(SPIECE_UNDERLINE, " ");
byte_buf.extend_from_slice(text_piece.as_bytes());
}
}
}
let decoded = String::from_utf8_lossy(&byte_buf).into_owned();
if decoded.starts_with(' ') {
decoded[1..].to_string()
} else {
decoded
}
}
pub fn vocab_size(&self) -> usize {
self.vocab.len()
}
pub fn token_to_id(&self, token: &str) -> Option<u32> {
self.vocab.get(token).copied()
}
pub fn id_to_token(&self, id: u32) -> Option<&str> {
self.id_to_piece.get(&id).map(|s| s.as_str())
}
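    /// Serialize to a Hugging Face `tokenizer.json`-style string using the
    /// crate's `hf_json` helpers, re-expanding merges to "left right" pairs.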
pub fn to_hf_json(&self) -> String {
use super::hf_json::{HfAddedToken, HfModel, HfTokenizerJson};
use std::collections::HashMap;
let merges_strs: Vec<String> = self
.merges
.iter()
.filter_map(|(lid, rid)| {
let l = self.id_to_piece.get(lid)?;
let r = self.id_to_piece.get(rid)?;
Some(format!("{} {}", l, r))
})
.collect();
let model = HfModel {
model_type: "BPE".to_string(),
vocab: self.vocab.clone(),
merges: Some(merges_strs),
unk_token: Some("<unk>".to_string()),
continuing_subword_prefix: None,
max_input_chars_per_word: None,
};
let mut added_tokens = Vec::new();
for (id, name) in [
(self.config.unk_id, "<unk>"),
(self.config.bos_id, "<s>"),
(self.config.eos_id, "</s>"),
] {
added_tokens.push(HfAddedToken::special(id, name));
}
let hf = HfTokenizerJson {
version: "1.0".to_string(),
model,
added_tokens,
normalizer_json: None,
pre_tokenizer_json: None,
post_processor_json: None,
decoder_json: None,
};
hf.to_json_string()
}
    fn text_to_initial_ids(&self, text: &str) -> Vec<u32> {
        let mut ids = Vec::new();
        for ch in text.chars() {
            if let Some(&id) = self.vocab.get(ch.to_string().as_str()) {
                ids.push(id);
            } else {
                // Byte fallback: emit one `<0xNN>` token per UTF-8 byte of the char.
                let mut buf = [0u8; 4];
                for &b in ch.encode_utf8(&mut buf).as_bytes() {
                    ids.push(self.byte_tokens[b as usize]);
                }
            }
        }
        ids
    }
fn apply_bpe_merges(&self, mut ids: Vec<u32>) -> Vec<u32> {
if ids.len() <= 1 {
return ids;
}
        // Rank-ordered lookup: (left, right) -> (merge rank, merged id). Rebuilt on
        // every call; bulk encoders may want to cache this on the tokenizer.
        let merge_lookup: HashMap<(u32, u32), (usize, u32)> = self
.merges
.iter()
.enumerate()
.filter_map(|(rank, (lid, rid))| {
let left_piece = self.id_to_piece.get(lid)?;
let right_piece = self.id_to_piece.get(rid)?;
let merged_piece = format!("{}{}", left_piece, right_piece);
let merged_id = self.vocab.get(&merged_piece).copied()?;
Some(((*lid, *rid), (rank, merged_id)))
})
.collect();
        // Repeatedly apply the lowest-ranked applicable merge until none remains.
        loop {
if ids.len() < 2 {
break;
}
let best = ids
.windows(2)
.enumerate()
.filter_map(|(i, w)| {
merge_lookup
.get(&(w[0], w[1]))
.map(|&(rank, mid)| (i, rank, mid))
})
.min_by_key(|&(_, rank, _)| rank);
match best {
None => break,
Some((pos, _, merged_id)) => {
let mut new_ids = Vec::with_capacity(ids.len() - 1);
new_ids.extend_from_slice(&ids[..pos]);
new_ids.push(merged_id);
new_ids.extend_from_slice(&ids[pos + 2..]);
ids = new_ids;
}
}
}
ids
}
}
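/// Parse a `<0xNN>` byte-fallback piece into its byte value, if it is one.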
fn parse_byte_token(piece: &str) -> Option<u8> {
if piece.len() != 6 {
return None;
}
let bytes = piece.as_bytes();
if bytes[0] != b'<' || bytes[1] != b'0' || bytes[2] != b'x' || bytes[5] != b'>' {
return None;
}
let hex = &piece[3..5];
u8::from_str_radix(hex, 16).ok()
}
#[cfg(test)]
mod tests {
use super::*;
fn make_tokenizer() -> LlamaTokenizer {
LlamaTokenizer::new_minimal(512, LlamaTokenizerConfig::default())
}
#[test]
fn new_minimal_special_tokens() {
let tok = make_tokenizer();
assert_eq!(tok.token_to_id("<unk>"), Some(0));
assert_eq!(tok.token_to_id("<s>"), Some(1));
assert_eq!(tok.token_to_id("</s>"), Some(2));
}
#[test]
fn new_minimal_byte_tokens_at_correct_ids() {
let tok = make_tokenizer();
assert_eq!(tok.token_to_id("<0x00>"), Some(BYTE_TOKEN_OFFSET));
assert_eq!(tok.token_to_id("<0xFF>"), Some(BYTE_TOKEN_OFFSET + 255));
}
#[test]
fn encode_returns_nonempty() {
let tok = make_tokenizer();
let ids = tok.encode("hello");
assert!(!ids.is_empty());
}
#[test]
fn encode_prepends_bos_when_configured() {
let tok = make_tokenizer();
let ids = tok.encode("hello");
assert_eq!(ids[0], 1, "first token should be BOS=1");
}
#[test]
fn encode_empty_string_with_bos() {
let tok = make_tokenizer();
let ids = tok.encode("");
assert_eq!(ids, vec![1], "empty + add_bos → [BOS]");
}
#[test]
fn encode_empty_string_without_bos() {
let config = LlamaTokenizerConfig {
add_bos: false,
..Default::default()
};
let tok = LlamaTokenizer::new_minimal(512, config);
let ids = tok.encode("");
assert!(ids.is_empty());
}
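    #[test]
    fn encode_empty_string_with_eos() {
        // The empty-input path honors add_eos as well, so "" yields [BOS, EOS].
        let config = LlamaTokenizerConfig {
            add_eos: true,
            ..Default::default()
        };
        let tok = LlamaTokenizer::new_minimal(512, config);
        assert_eq!(tok.encode(""), vec![1, 2]);
    }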
#[test]
    fn encode_without_bos_has_no_bos_prefix() {
let config = LlamaTokenizerConfig {
add_bos: false,
..Default::default()
};
let tok = LlamaTokenizer::new_minimal(512, config);
let ids = tok.encode("hello");
assert_ne!(ids[0], 1);
}
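    #[test]
    fn encode_appends_eos_when_configured() {
        // EOS counterpart of the BOS tests above.
        let config = LlamaTokenizerConfig {
            add_eos: true,
            ..Default::default()
        };
        let tok = LlamaTokenizer::new_minimal(512, config);
        let ids = tok.encode("hello");
        assert_eq!(ids.last().copied(), Some(2), "last token should be EOS=2");
    }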
#[test]
fn ascii_roundtrip() {
let tok = make_tokenizer();
let text = "hello world";
let ids = tok.encode(text);
let decoded = tok.decode(&ids);
assert_eq!(decoded, text, "roundtrip failed: got '{}'", decoded);
}
#[test]
fn cjk_byte_fallback_no_panic() {
let tok = make_tokenizer();
let ids = tok.encode("こんにちは");
assert!(!ids.is_empty(), "encoding CJK should produce tokens");
}
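    #[test]
    fn cjk_roundtrip_via_byte_fallback() {
        // Byte-fallback decode should reassemble the original UTF-8 exactly.
        let tok = make_tokenizer();
        let text = "こんにちは";
        let ids = tok.encode(text);
        assert_eq!(tok.decode(&ids), text);
    }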
#[test]
fn vocab_size_gt_3() {
let tok = make_tokenizer();
assert!(tok.vocab_size() > 3);
}
#[test]
fn unk_token_lookup() {
let tok = make_tokenizer();
assert_eq!(tok.token_to_id("<unk>"), Some(0));
}
#[test]
fn id_to_token_bos() {
let tok = make_tokenizer();
let s = tok.id_to_token(1);
assert_eq!(s, Some("<s>"));
}
#[test]
fn all_256_bytes_have_ids() {
let tok = make_tokenizer();
for b in 0usize..256 {
let piece = format!("<0x{:02X}>", b);
assert!(
tok.token_to_id(&piece).is_some(),
"byte token {} not found",
piece
);
}
}
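    #[test]
    fn from_vocab_and_merges_rejects_missing_byte_tokens() {
        // Construction should fail when the 256 `<0xNN>` tokens are absent.
        let mut vocab = HashMap::new();
        vocab.insert("<unk>".to_string(), 0u32);
        let result = LlamaTokenizer::from_vocab_and_merges(
            vocab,
            Vec::new(),
            LlamaTokenizerConfig::default(),
        );
        assert!(result.is_err());
    }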
#[test]
fn decode_byte_tokens_reconstructs_bytes() {
let tok = make_tokenizer();
let byte_id = tok.token_to_id("<0x41>").expect("byte token missing");
let decoded = tok.decode(&[byte_id]);
assert_eq!(decoded, "A", "expected 'A', got '{}'", decoded);
}
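    #[test]
    fn decode_skips_unknown_ids() {
        // Ids absent from the vocab are ignored rather than panicking.
        let tok = make_tokenizer();
        assert_eq!(tok.decode(&[999_999]), "");
    }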
#[test]
fn to_hf_json_contains_bpe() {
let tok = make_tokenizer();
let json = tok.to_hf_json();
assert!(json.contains("BPE"), "HF JSON must contain BPE model type");
}
#[test]
fn config_defaults() {
let cfg = LlamaTokenizerConfig::default();
assert_eq!(cfg.unk_id, 0);
assert_eq!(cfg.bos_id, 1);
assert_eq!(cfg.eos_id, 2);
assert_eq!(cfg.pad_id, -1);
assert!(cfg.add_bos);
assert!(!cfg.add_eos);
assert!(cfg.normalize_whitespace);
}
}